# Simulating a Heterogeneous Cryo-EM Dataset
## Load the config file
In [1]:
Copied!
import matplotlib.pyplot as plt
import numpy as np
from cryojax.ndimage import LowpassFilter
from cryospax import RelionParticleDataset, RelionParticleParameterFile
import cryojax_eo as cxeo
import matplotlib.pyplot as plt
import numpy as np
from cryojax.ndimage import LowpassFilter
from cryospax import RelionParticleDataset, RelionParticleParameterFile
import cryojax_eo as cxeo
In [2]:
Copied!
# Load the data-simulation settings from the YAML file next to this notebook.
config = cxeo.load_config("./config_data_simulation.yaml", config_mode="data simulation")
# Display the configured noise SNR as a quick sanity check of the loaded config.
config.noise_snr
Out[2]:
0.1
The image generation can also be run from the command line:
`simulate_data --config config_data_simulation.yaml`
If the command is not found, make sure the virtual environment in which `cryojax_eo` is installed is activated. The command runs the following function:
In [3]:
Copied!
# Simulate the dataset described by `config` (the same work as the `simulate_data` CLI).
# Call it exactly once: a second call presumably re-simulates and rewrites the output
# files (TODO confirm), which is slow and unnecessary.
# The trailing `;` hides the returned dataset's uninformative object repr.
cxeo.dataset.simulate_relion_dataset(config);
Out[3]:
<cryospax._dataset.relion.RelionParticleDataset at 0x1553736c6d50>
## Visualize the images
In [4]:
Copied!
# Open the simulated RELION dataset read-only, using the STAR file and project
# directory recorded in the config.
stack_dataset = RelionParticleDataset(
    RelionParticleParameterFile(
        path_to_starfile=config.path_to_starfile,
        mode="r",
        # presumably skips reading CTF envelope parameters from the STAR file — TODO confirm
        options=dict(loads_envelope=False),
    ),
    path_to_relion_project=config.path_to_relion_project,
    mode="r",
)
In [5]:
Copied!
# Build a low-pass filter on the frequency grid of the first particle's image config,
# with the cutoff at 0.7 of the maximum frequency.
# TODO: `BasicImageConfig.frequency_grid_in_pixels` is deprecated and will be removed in
# cryoJAX 0.6.0 (it emits a FutureWarning); migrate to `BasicImageConfig.get_frequency_grid`
# once the arguments that reproduce the unpadded, pixel-unit grid are confirmed.
lowpass_filter = LowpassFilter(
    stack_dataset[0]["parameters"]["image_config"].frequency_grid_in_pixels,
    frequency_cutoff_fraction=0.7,
)
.../site-packages/equinox/_module/_module.py:549: FutureWarning: `BasicImageConfig.frequency_grid_in_pixels` has been deprecated and will be removed in cryoJAX 0.6.0. Instead, make the appropriate call to `BasicImageConfig.get_frequency_grid`. out = super().__getattribute__(name)
In [6]:
Copied!
# Show the first four simulated particle images on a 2x2 grid.
fig, ax = plt.subplots(2, 2, figsize=(10, 10))
images = stack_dataset[0:4]["images"]
# To display low-pass filtered images instead, uncomment the next line. It needs
# `rfftn`/`irfftn` in scope (presumably importable from `cryojax.ndimage` — TODO confirm).
# images = irfftn(lowpass_filter(rfftn(images)))
for i, panel in enumerate(ax.flat):
    panel.imshow(images[i], cmap="gray")
    panel.set_title(f"Image {i}")
    # Pixel-index ticks add clutter without aiding interpretation.
    panel.set_axis_off()
fig.tight_layout()
## Metadata
Per-image information, namely the SNR of each image and the index of the ensemble member assigned to it, is saved to the metadata file `tutorial_data/metadata.npz`.
In [7]:
Copied!
# Load every array from the metadata archive into a plain dict and close the file
# right away. `np.load` on an .npz returns a lazily-reading NpzFile that keeps the
# file handle open until it is closed.
with np.load("tutorial_data/metadata.npz") as metadata_file:
    metadata = dict(metadata_file)
# List the stored field names (same order as `NpzFile.files`).
list(metadata)
Out[7]:
['snr_per_image', 'ensemble_indices_per_image']
In [8]:
Copied!
# Ensemble-member index recorded for each simulated image.
metadata["ensemble_indices_per_image"]
Out[8]:
array([0, 0, 1, ..., 0, 1, 0], dtype=int32)
In [9]:
Copied!
# Empirical ensemble weights: the fraction of images assigned to each ensemble member.
# The indices are integer labels, so count them exactly with `bincount` instead of
# float-comparing with `np.isclose`. `minlength=2` keeps both weights defined (as 0.0)
# even if one member was assigned no images.
ensemble_indices = metadata["ensemble_indices_per_image"]
ensemble_weights = np.bincount(ensemble_indices, minlength=2) / ensemble_indices.size
weight_0, weight_1 = ensemble_weights[0], ensemble_weights[1]
weight_0, weight_1
Out[9]:
(0.6912, 0.3088)