Running Ensemble Optimization with cryojax_eo
In this notebook, we show how to run our ensemble optimization method, first using the default settings and then by building a custom pipeline from its individual components.
Warning
This notebook assumes you have already run the data simulation tutorial. If you haven't, the required data will not exist.
import os
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import mdtraj
import optax
from cryospax import RelionParticleDataset, RelionParticleParameterFile
import cryojax_eo as cxeo
config = cxeo.load_config(
    "./ensemble_optimization_config.yaml", config_mode="ensemble optimization"
)
# create output directory if it does not exist
os.makedirs(config.path_to_output, exist_ok=True)
Running the default version of the method
This is the easiest way to run the method, but it comes with a few limitations:
- Only the noise-variance-marginalized image-to-structure likelihood function is available
- The Steered MD biasing constant scheduler can only be constant or linear
Below we show how the method can be run via a simple function call or from the terminal.
Running from the command line
You can also run the optimization directly from the terminal:
run_ensemble_optimization --config ensemble_optimization_config.yaml
Note that the command-line interface does not support custom simulation functions or likelihood functions.
Using a function call
walkers, weights = cxeo.run_ensemble_optimization_with_md(config)
Iter 100/100, Neg Log-Likelihood: 79561.2188: 100%|██████████| 100/100 [02:45<00:00, 1.66s/it]
Computing full likelihood matrix: 100%|██████████| 10/10 [00:02<00:00, 4.55 batch/s]
weights
Array([0.30819213, 0.69180787], dtype=float32)
# The trajectories can be found in the output directory
traj0 = mdtraj.load(
    os.path.join(config.path_to_output, "traj_walker_0.xtc"),
    top=os.path.join(config.path_to_output, "final_walker_0.pdb"),
)
traj1 = mdtraj.load(
    os.path.join(config.path_to_output, "traj_walker_1.xtc"),
    top=os.path.join(config.path_to_output, "final_walker_1.pdb"),
)
# Load the true ensemble members
true_ensemble_structure0 = mdtraj.load("../atomic_models/ala_A.pdb")
true_ensemble_structure1 = mdtraj.load("../atomic_models/ala_B.pdb")
atom_indices = true_ensemble_structure0.topology.select("not element H")
# Compute the RMSD of each trajectory to the true ensemble members
traj0.superpose(true_ensemble_structure0, atom_indices=atom_indices)
traj1.superpose(true_ensemble_structure0, atom_indices=atom_indices)
rmsd00 = mdtraj.rmsd(traj0, true_ensemble_structure0, 0, atom_indices=atom_indices) * 10.0
rmsd10 = mdtraj.rmsd(traj1, true_ensemble_structure0, 0, atom_indices=atom_indices) * 10.0
traj0.superpose(true_ensemble_structure1, atom_indices=atom_indices)
traj1.superpose(true_ensemble_structure1, atom_indices=atom_indices)
rmsd01 = mdtraj.rmsd(traj0, true_ensemble_structure1, 0, atom_indices=atom_indices) * 10.0
rmsd11 = mdtraj.rmsd(traj1, true_ensemble_structure1, 0, atom_indices=atom_indices) * 10.0
fig, ax = plt.subplots(1, 2, figsize=(12, 6), sharey=True)
ax[0].plot(rmsd00, label="Walker 0")
ax[0].plot(rmsd10, label="Walker 1")
ax[0].set_title("RMSD to true ensemble member 0")
ax[0].set_xlabel("Frame")
ax[0].set_ylabel("RMSD (Angstroms)")
ax[0].legend()
ax[1].plot(rmsd01, label="Walker 0")
ax[1].plot(rmsd11, label="Walker 1")
ax[1].set_title("RMSD to true ensemble member 1")
ax[1].set_xlabel("Frame")
ax[1].set_ylabel("RMSD (Angstroms)")
ax[1].legend()
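mdtraj.rmsd returns, for each frame, the RMSD in nanometers after optimally superposing that frame onto the chosen reference frame, which is why the results above are multiplied by 10.0 to obtain Angstroms. To make that computation concrete, the following NumPy sketch implements the same idea with the Kabsch algorithm. kabsch_rmsd is a hypothetical helper written for illustration (mdtraj uses its own optimized implementation), and it assumes both coordinate arrays list the same atoms in the same order.

```python
import numpy as np


def kabsch_rmsd(mobile: np.ndarray, reference: np.ndarray) -> float:
    """RMSD between two (n_atoms, 3) coordinate arrays after optimal superposition."""
    # Remove the translation by centering both structures on their centroids
    mobile_c = mobile - mobile.mean(axis=0)
    reference_c = reference - reference.mean(axis=0)
    # SVD of the 3x3 cross-covariance matrix gives the optimal rotation
    u, _, vt = np.linalg.svd(mobile_c.T @ reference_c)
    # Flip the last axis if needed so the result is a proper rotation (no reflection)
    d = np.sign(np.linalg.det(vt.T @ u.T))
    rotation = vt.T @ np.diag([1.0, 1.0, d]) @ u.T
    # Rotate the mobile structure onto the reference and average the squared deviations
    aligned = mobile_c @ rotation.T
    return float(np.sqrt(np.mean(np.sum((aligned - reference_c) ** 2, axis=1))))


# Usage: a rotated and translated copy of a structure has an RMSD of (numerically) zero
rng = np.random.default_rng(0)
reference = rng.normal(size=(20, 3))
theta = np.pi / 3
rot_z = np.array(
    [[np.cos(theta), -np.sin(theta), 0.0], [np.sin(theta), np.cos(theta), 0.0], [0.0, 0.0, 1.0]]
)
mobile = reference @ rot_z.T + np.array([1.0, -2.0, 0.5])
print(kabsch_rmsd(mobile, reference))
```

Applied to the heavy-atom coordinates of one frame and of the reference (both in nanometers), the result times 10.0 should agree with the plotted mdtraj values up to floating-point precision.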
Building a custom pipeline
Building the pipeline from its components removes the limitations of the default version listed above. First, load the atomic models, the reference structure, and the restraint atom list:
- The atomic models contain the atomic positions and scattering factors used to simulate images
- The reference structure is used for alignment (it must be aligned to the frame of reference of the particle images)
- The restraint atom list selects which atoms are optimized and which are used in the biasing force of the steered MD
config = config.model_dump()  # convert the config object into a plain nested dictionary
ref_structure = mdtraj.load(config["alignment_params"]["path_to_prealigned_atomic_model"])
# center the reference structure
ref_structure = ref_structure.center_coordinates()
restrain_atom_list = ref_structure.topology.select(config["atom_selection"])
Parse the atomic models, then keep only the scattering parameters of the restrained atoms:
initial_walkers, amplitudes, variances = cxeo.io.read_walkers_from_pdbs(
    config["path_to_atomic_models"],
    loads_b_factors=config["loads_b_factors"],
)
amplitudes = amplitudes[restrain_atom_list]
variances = variances[restrain_atom_list]
initial_walkers.shape, variances.shape, amplitudes.shape
((2, 42, 3), (20, 5), (20, 5))
Load the cryo-EM dataset. The dataloader produces random batches of particle images.
stack_dataset = RelionParticleDataset(
    RelionParticleParameterFile(
        path_to_starfile=config["data_params"]["path_to_starfile"],
        options=dict(
            loads_envelope=config["data_params"]["loads_envelope"],
            broadcasts_image_config=False,
        ),
    ),
    path_to_relion_project=config["data_params"]["path_to_relion_project"],
)
key = jax.random.PRNGKey(config["rng_seed"])
key_data, key_pipeline = jax.random.split(key)
dataloader = cxeo.dataset.create_dataloader(
    stack_dataset,
    batch_size=config["likelihood_optimizer_params"]["batch_size"],
    shuffle=True,
    drop_last=False,
    jax_prng_key=key_data,
    per_particle_args=None,
)
The per_particle_args argument can be used to pass additional per-particle information that is not stored in the STAR file, such as noise parameters. This is useful when defining custom log-likelihood functions, as shown below.
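per_particle_args is only useful if each batch of it lines up with the corresponding batch of particles, so a shuffled dataloader has to draw the same rows from both. The JAX sketch below illustrates this with a hypothetical iterate_batches helper and hypothetical noise_variance and particle_id leaves. It mimics the shuffle and drop_last options passed to create_dataloader above, but it is not the actual cxeo.dataset implementation.

```python
import jax
import jax.numpy as jnp


def iterate_batches(key, per_particle_args, n_particles, batch_size, drop_last=False):
    """Yield (indices, batch) pairs, slicing every PyTree leaf with the same shuffled indices."""
    # One random permutation of the particle indices per pass over the dataset
    order = jax.random.permutation(key, n_particles)
    for start in range(0, n_particles, batch_size):
        indices = order[start : start + batch_size]
        # Optionally skip the final, incomplete batch
        if drop_last and indices.shape[0] < batch_size:
            break
        # Every leaf must have a leading dimension of n_particles (STAR-file order)
        batch = jax.tree_util.tree_map(lambda leaf: leaf[indices], per_particle_args)
        yield indices, batch


# Hypothetical per-particle arguments for a 10-particle dataset
n_particles = 10
example_per_particle_args = {
    "noise_variance": jnp.linspace(1.0, 2.0, n_particles),
    "particle_id": jnp.arange(n_particles),
}
for indices, batch in iterate_batches(
    jax.random.PRNGKey(0), example_per_particle_args, n_particles, batch_size=4
):
    print(indices, batch["noise_variance"])
```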
Construct the objects needed for the ensemble optimization
The prior projector
The prior projector is an ensemble of steered MD simulators, one per walker. The base_state_file_path argument specifies where each simulator saves its simulation state at every iteration of the ensemble optimization. This path must be unique for each projector; the saved states can be used to resume a previous run.
projector_list = []
for i in range(len(config["path_to_atomic_models"])):
    projector_list.append(
        cxeo.ensemble_optimization.SteeredMDSimulator(
            path_to_initial_pdb=config["path_to_atomic_models"][i],
            n_steps=config["projector_params"]["n_steps"],
            restrain_atom_list=restrain_atom_list,
            parameters_for_md={
                "platform": config["projector_params"]["platform"],
                "properties": config["projector_params"]["platform_properties"],
            },
            base_state_file_path=os.path.join(
                config["path_to_output"], f"states_proj_{i}/state_"
            ),
            make_simulation_fn=None,  # see below!
        )
    )
md_projector = cxeo.ensemble_optimization.EnsembleSteeredMDSimulator(projector_list)
The steered MD simulation can be customized by passing your own make_simulation_fn. It receives the parameters_for_md dictionary and an OpenMM topology (built from path_to_initial_pdb), in that order, and returns an openmm.app.Simulation. For example, the default function is defined as:
import openmm
import openmm.app as openmm_app
import openmm.unit as openmm_unit
def _create_forcefield(parameters_for_md: dict) -> openmm_app.ForceField:
    return openmm_app.ForceField(
        parameters_for_md["forcefield"], parameters_for_md["water_model"]
    )
def _create_integrator(parameters_for_md: dict) -> openmm.Integrator:
    return openmm.LangevinIntegrator(
        parameters_for_md["temperature"],
        parameters_for_md["friction"],
        parameters_for_md["timestep"],
    )
def _create_system(
    parameters_for_md: dict,
    forcefield: openmm_app.ForceField,
    topology: openmm_app.Topology,
) -> openmm.System:
    system = forcefield.createSystem(
        topology,
        nonbondedMethod=parameters_for_md["nonbondedMethod"],
        nonbondedCutoff=parameters_for_md["nonbondedCutoff"],
        constraints=parameters_for_md["constraints"],
    )
    return system
def _create_platform(parameters_for_md: dict) -> openmm.Platform:
    return openmm.Platform.getPlatformByName(parameters_for_md["platform"])
def _default_make_sim_fn(
    parameters_for_md: dict, topology: openmm_app.Topology
) -> openmm_app.Simulation:
    forcefield = _create_forcefield(parameters_for_md)
    integrator = _create_integrator(parameters_for_md)
    platform = _create_platform(parameters_for_md)
    system = _create_system(parameters_for_md, forcefield, topology)
    simulation = openmm_app.Simulation(
        topology,
        system,
        integrator,
        platform,
        parameters_for_md["properties"],
    )
    return simulation
By default, parameters_for_md is populated with these values:
DEFAULT_MD_PARAMS = {
    "forcefield": "amber14-all.xml",
    "water_model": "amber14/tip3p.xml",
    "nonbondedMethod": openmm_app.PME,
    "nonbondedCutoff": 1.0 * openmm_unit.nanometer,
    "constraints": openmm_app.HBonds,
    "temperature": 300.0 * openmm_unit.kelvin,
    "friction": 1.0 / openmm_unit.picosecond,
    "timestep": 0.002 * openmm_unit.picoseconds,
    "platform": "CPU",
    "properties": {"Threads": "1"},
}
The likelihood optimizer
img_to_walker_likelihood_fn = cxeo.ensemble_optimization.MargGaussianWhiteLogLikelihoodFn(
    variances=variances,
    amplitudes=amplitudes,
    data_sign=config["data_params"]["data_sign"],
)
ensemble_likelihood_fn = cxeo.ensemble_optimization.ImagesToEnsembleLikelihoodFn(
    image_to_walker_likelihood_fn=img_to_walker_likelihood_fn,
    n_walkers_in_parallel=1,  # modify based on your memory and speed requirements
    n_images_in_parallel=1,  # modify based on your memory and speed requirements
)
likelihood_optimizer = cxeo.ensemble_optimization.IterativeEnsembleLikelihoodOptimizer(
    step_size=config["likelihood_optimizer_params"]["step_size"],
    n_steps=config["likelihood_optimizer_params"]["n_steps"],
    ensemble_likelihood_fn=ensemble_likelihood_fn,
    n_batches_per_step=config["likelihood_optimizer_params"]["n_batches_per_step"],
)
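For intuition about what the likelihood optimizer works with: in ensemble reweighting approaches like this one, the likelihood of an image under the ensemble is commonly modeled as a weighted mixture of its likelihoods under the individual walkers, so the total log-likelihood is the sum over images i of log(sum over walkers k of w_k p(image_i | walker_k)). The JAX sketch below assumes this mixture form and a precomputed image-by-walker log-likelihood matrix (compare the "Computing full likelihood matrix" step in the progress output). It evaluates the negative log-likelihood with logsumexp for numerical stability and, purely as an illustration, updates the weights with expectation-maximization (EM). EM is a swapped-in technique here: IterativeEnsembleLikelihoodOptimizer takes a step_size and may update the weights differently.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def ensemble_neg_log_likelihood(log_lik_matrix, weights):
    """Negative log-likelihood of all images under a weighted mixture of walkers.

    log_lik_matrix: (n_images, n_walkers) array of log p(image_i | walker_k).
    weights: (n_walkers,) non-negative weights summing to one.
    """
    # log sum_k w_k p(image_i | walker_k), evaluated stably in log space
    per_image = logsumexp(log_lik_matrix + jnp.log(weights)[None, :], axis=1)
    return -jnp.sum(per_image)


def em_weight_update(log_lik_matrix, weights):
    """One expectation-maximization step for the mixture weights (illustrative only)."""
    log_joint = log_lik_matrix + jnp.log(weights)[None, :]
    # Responsibility of walker k for image i, normalized over walkers
    responsibilities = jnp.exp(log_joint - logsumexp(log_joint, axis=1, keepdims=True))
    # New weight of each walker = its average responsibility over all images
    return responsibilities.mean(axis=0)


# Usage with a made-up log-likelihood matrix for 100 images and 2 walkers
toy_log_lik = 3.0 * jax.random.normal(jax.random.PRNGKey(0), (100, 2))
toy_weights = jnp.array([0.5, 0.5])
for _ in range(10):
    toy_weights = em_weight_update(toy_log_lik, toy_weights)
print(toy_weights, ensemble_neg_log_likelihood(toy_log_lik, toy_weights))
```

EM makes a compact illustration because each step keeps the weights on the probability simplex and never decreases the mixture likelihood.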
The likelihood optimizer can also be customized. Any callable with the following signature can be passed as image_to_walker_likelihood_fn:
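from __future__ import annotations  # keeps the type hints below unevaluated
from typing import Any, Optional
import jax.numpy as jnp
from jaxtyping import Array, Float
# ParticleStackInfo and DilatedMask: the particle-stack and mask types expected by cryojax_eo
def my_loglikelihood_fn(
    walker: Float[Array, "n_atoms 3"],
    relion_stack: ParticleStackInfo,
    amplitudes: Float[Array, "n_atoms n_gaussians_per_atom"],
    variances: Float[Array, "n_atoms n_gaussians_per_atom"],
    dilated_mask: Optional[DilatedMask] = None,
    image_sign: Float[Array, ""] = jnp.array(1.0),
    per_particle_args: Any = None,  # pytree with one entry per image
):
    log_likelihood = ...  # compute the log-likelihood of each image given the walker
    return log_likelihood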
This function must be JIT-compilable with Equinox. Pass per_particle_args to the dataloader (see above) as a PyTree whose leaves have the same leading batch dimension as the particle stack; the entry for each image must follow the particle order in the STAR file.
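As an example of what might go inside such a function, the sketch below computes a Gaussian white-noise log-likelihood for a batch of images when the noise variance of each image is known and supplied through per_particle_args. It deliberately skips image formation: simulated_images stands in for the images you would compute from walker, amplitudes, and variances, and the noise_variance name is hypothetical. It is not the noise-marginalized likelihood implemented by MargGaussianWhiteLogLikelihoodFn.

```python
import jax
import jax.numpy as jnp


@jax.jit
def gaussian_white_log_likelihood(observed_images, simulated_images, noise_variance):
    """Log-likelihood of each image under independent Gaussian white noise of known variance.

    observed_images, simulated_images: (n_images, ny, nx) real-space arrays.
    noise_variance: (n_images,) noise variance of each image, e.g. from per_particle_args.
    """
    n_pixels = observed_images.shape[-2] * observed_images.shape[-1]
    # Squared residual summed over the pixels of each image
    sq_residual = jnp.sum((observed_images - simulated_images) ** 2, axis=(-2, -1))
    # log N(y | mu, sigma^2 I) = -0.5 * (||y - mu||^2 / sigma^2 + N * log(2 * pi * sigma^2))
    return -0.5 * (sq_residual / noise_variance + n_pixels * jnp.log(2.0 * jnp.pi * noise_variance))


# Usage: noisy copies of a blank image with a different (known) noise level per image
clean = jnp.zeros((3, 8, 8))
sigma2 = jnp.array([0.5, 1.0, 2.0])
noisy = clean + jnp.sqrt(sigma2)[:, None, None] * jax.random.normal(jax.random.PRNGKey(0), clean.shape)
print(gaussian_white_log_likelihood(noisy, clean, sigma2))
```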
The ensemble refinement pipeline
ensemble_refinement_pipeline = cxeo.ensemble_optimization.EnsembleOptimizationPipeline(
    prior_projector=md_projector,
    likelihood_optimizer=likelihood_optimizer,
    n_steps=config["n_steps"],
    prealigned_structure=ref_structure,
    atom_indices_for_opt=restrain_atom_list,
    runs_postprocessing=True,
)
Define the initial walkers and weights and run the pipeline.
If restarting from a previous run, the initial_state_for_projector argument should be a list of paths to the saved OpenMM state files.
initial_weights = jnp.array(config["likelihood_optimizer_params"]["initial_weights"])
# In this case we will use a constant schedule, but any optax scheduler is compatible
bias_constant_scheduler = optax.constant_schedule(
    config["projector_params"]["bias_constant_in_kjpermol"]
)
walkers, weights = ensemble_refinement_pipeline.run(
    key=key_pipeline,
    initial_walkers=initial_walkers.copy(),
    initial_weights=initial_weights,
    dataloader=dataloader,
    bias_constant_scheduler=bias_constant_scheduler,
    output_directory=config["path_to_output"],
    initial_state_for_projector=None,
)
Iter 100/100, Neg Log-Likelihood: 79796.7422: 100%|██████████| 100/100 [06:34<00:00, 3.95s/it]
Computing full likelihood matrix: 100%|██████████| 10/10 [00:02<00:00, 4.16 batch/s]
# These should be around 0.7 and 0.3
print(weights)
[0.69180787 0.30819213]
# The trajectories can be found in the output directory
traj0 = mdtraj.load(
    os.path.join(config["path_to_output"], "traj_walker_0.xtc"),
    top=os.path.join(config["path_to_output"], "final_walker_0.pdb"),
)
traj1 = mdtraj.load(
    os.path.join(config["path_to_output"], "traj_walker_1.xtc"),
    top=os.path.join(config["path_to_output"], "final_walker_1.pdb"),
)
# Let's also load the true members of the ensemble
true_ensemble_structure0 = mdtraj.load("../atomic_models/ala_A.pdb")
true_ensemble_structure1 = mdtraj.load("../atomic_models/ala_B.pdb")
atom_indices = true_ensemble_structure0.topology.select("not element H")
# Now compute the heavy-atom RMSD of each trajectory to the true ensemble members
# (mdtraj.rmsd superposes each frame optimally and returns nm; x10.0 converts to Angstroms)
rmsd00 = mdtraj.rmsd(traj0, true_ensemble_structure0, 0, atom_indices=atom_indices) * 10.0
rmsd01 = mdtraj.rmsd(traj0, true_ensemble_structure1, 0, atom_indices=atom_indices) * 10.0
rmsd10 = mdtraj.rmsd(traj1, true_ensemble_structure0, 0, atom_indices=atom_indices) * 10.0
rmsd11 = mdtraj.rmsd(traj1, true_ensemble_structure1, 0, atom_indices=atom_indices) * 10.0
fig, ax = plt.subplots(1, 2, figsize=(12, 6), sharey=True)
ax[0].plot(rmsd00, label="Walker 0")
ax[0].plot(rmsd10, label="Walker 1")
ax[0].set_title("RMSD to true ensemble member 0")
ax[0].set_xlabel("Frame")
ax[0].set_ylabel("RMSD (Angstroms)")
ax[0].legend()
ax[1].plot(rmsd01, label="Walker 0")
ax[1].plot(rmsd11, label="Walker 1")
ax[1].set_title("RMSD to true ensemble member 1")
ax[1].set_xlabel("Frame")
ax[1].set_ylabel("RMSD (Angstroms)")
ax[1].legend()