Build your own optimizer with JAX

This example is for users who want to build their own optimizers with JAX.

To that end, we walk through how the MAP optimizer is implemented in psyphy.

You can run this from-scratch implementation on a toy example yourself with the following script:

python docs/examples/wppm/full_wppm_fit_example.py


Model
# Set all Wishart process hyperparameters in Prior
truth_prior = Prior(
    input_dim=input_dim,  # (2D)
    basis_degree=basis_degree,  # (5)
    extra_embedding_dims=extra_dims,  # (1)
    decay_rate=decay_rate,  # for basis functions
    variance_scale=variance_scale,  # overall scale of the covariance matrices before fitting
)
truth_model = WPPM(
    input_dim=input_dim,
    extra_dims=extra_dims,
    prior=truth_prior,
    task=task,  # oddity task ("pick the odd-one out among 3 stimuli")
    noise=noise,  # (Gaussian noise)
    diag_term=diag_term,  # ensure positive-definite covariances
)

# Sample ground-truth Wishart process weights
truth_params = truth_model.init_params(jax.random.PRNGKey(123))
data = ResponseData()
num_trials_per_ref = NUM_TRIALS_TOTAL  # (trials per reference point)
n_ref_grid = 5  # NUM_GRID_PTS
ref_grid = jnp.linspace(-1, 1, n_ref_grid)  # [-1,1] space
ref_points = jnp.stack(jnp.meshgrid(ref_grid, ref_grid), axis=-1).reshape(-1, 2)

# --- Stimulus design: covariance-scaled probe displacements ---
#
# Rather than sampling probes at a fixed Euclidean radius, we scale the probe
# displacement by sqrt(Σ_ref). This tends to equalize trial difficulty across
# space (roughly constant Mahalanobis radius).

seed = 3
key = jr.PRNGKey(seed)

# Build a batched reference list by repeating each grid point.
n_ref = ref_points.shape[0]
refs = jnp.repeat(ref_points, repeats=num_trials_per_ref, axis=0)  # (N, 2)
num_trials_total = int(refs.shape[0])

# Evaluate Σ(ref) in batch using the psyphy covariance-field wrapper.
truth_field = WPPMCovarianceField(truth_model, truth_params)
Sigmas_ref = truth_field(refs)  # (N, 2, 2)

# Sample unit directions on the circle.
k_dir, k_pred, k_y = jr.split(key, 3)
angles = jr.uniform(k_dir, shape=(num_trials_total,), minval=0.0, maxval=2.0 * jnp.pi)
unit_dirs = jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=1)  # (N, 2)


# Displacement scale/orientation follows the local ellipse: each probe is offset by
# MAHAL_RADIUS * chol(Sigma_ref) @ unit_dir, i.e. a roughly constant Mahalanobis
# radius around every reference point.
MAHAL_RADIUS = 2.8

# Compute sqrt(Σ) via Cholesky (Σ should be SPD; diag_term/noise keep it stable).
L = jnp.linalg.cholesky(Sigmas_ref)  # (N, 2, 2)
deltas = MAHAL_RADIUS * jnp.einsum("nij,nj->ni", L, unit_dirs)  # (N, 2)
comparisons = jnp.clip(refs + deltas, -1.0, 1.0)

# Compute p(correct) in batch. We vmap the single-trial predictor.
trial_pred_keys = jr.split(k_pred, num_trials_total)


def _p_correct_one(ref: jnp.ndarray, comp: jnp.ndarray, kk: jnp.ndarray) -> jnp.ndarray:
    # Task MC settings (num_samples/bandwidth) come from OddityTaskConfig.
    # Only the randomness is threaded dynamically.
    return task._simulate_trial_mc(
        params=truth_params,
        ref=ref,
        comparison=comp,
        model=truth_model,
        noise=truth_model.noise,
        num_samples=int(task.config.num_samples),
        bandwidth=float(task.config.bandwidth),
        key=kk,
    )


p_correct = jax.vmap(_p_correct_one)(refs, comparisons, trial_pred_keys)

# Sample observed y ~ Bernoulli(p_correct) in batch.
ys = jr.bernoulli(k_y, p_correct, shape=(num_trials_total,)).astype(jnp.int32)

# Populate ResponseData (kept as Python loop since ResponseData is a Python container).
for ref, comp, y in zip(
    jax.device_get(refs), jax.device_get(comparisons), jax.device_get(ys)
):
    data.add_trial(ref=jnp.array(ref), comparison=jnp.array(comp), resp=int(y))

prior = Prior(
    input_dim=input_dim,  # (2D)
    basis_degree=basis_degree,  # 5
    extra_embedding_dims=extra_dims,  # 1
    decay_rate=decay_rate,  # for basis functions (how quickly they vary)
    variance_scale=variance_scale,  # overall scale of the covariance matrices before fitting
)
model = WPPM(
    input_dim=input_dim,
    prior=prior,
    task=task,
    noise=noise,
    diag_term=1e-4,  # ensure positive-definite covariances
)

Training / Fitting

Here, we define the training loop that minimizes the model’s negative log posterior using stochastic gradient descent with momentum, first with psyphy's built-in MAPOptimizer and then from scratch.
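
Concretely, writing θ for the Wishart process weights and D for the collected trials, the MAP objective is the negative log posterior -log p(θ | D) = -log p(D | θ) - log p(θ) + const, where the likelihood term is the Bernoulli log probability of the observed binary responses under the model's predicted p(correct), and the prior term comes from the Wishart process prior configured via Prior.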

Fitting with psyphy
map_optimizer = MAPOptimizer(
    steps=steps, learning_rate=lr, momentum=momentum, track_history=True, log_every=1
)
# Initialize at prior sample
init_params = model.init_params(jax.random.PRNGKey(42))

map_posterior = map_optimizer.fit(
    model,
    data,
    init_params=init_params,
)

Implementing optimizers with JAX: exposing psyphy's MAP implementation

Below we illustrate the implementation of MAP, one of the optimizers in psyphy's inference module. In particular, we point out why and how we use JAX and Optax.

A note on JAX: The key feature here is JAX’s Just-In-Time (JIT) compilation, which transforms our Python function into a single, optimized computation graph that runs efficiently on CPU, GPU, or TPU. To make this work, we represent parameters and optimizer states as PyTrees (nested dictionaries or tuples of arrays) — a core JAX data structure that supports efficient vectorization and differentiation. This approach lets us scale optimization and inference routines from small CPU experiments to large GPU-accelerated Bayesian models with minimal code changes.
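
As a small, self-contained illustration of the PyTree idea (the names below are illustrative and not part of psyphy), a hand-rolled SGD-with-momentum update can be written once with jax.tree_util.tree_map and applied to any nested parameter structure, and jax.jit compiles the whole update into one fused step:

import jax
import jax.numpy as jnp

# Parameters and optimizer state represented as PyTrees (nested dicts of arrays).
toy_params = {"weights": {"W": jnp.ones((3, 2)), "b": jnp.zeros(3)}, "log_scale": jnp.array(0.0)}
toy_velocity = jax.tree_util.tree_map(jnp.zeros_like, toy_params)  # momentum buffer, same structure


@jax.jit
def momentum_update(params, velocity, grads, lr=1e-2, momentum=0.9):
    # tree_map applies the same elementwise rule to every leaf array of the PyTree.
    velocity = jax.tree_util.tree_map(lambda v, g: momentum * v + g, velocity, grads)
    params = jax.tree_util.tree_map(lambda p, v: p - lr * v, params, velocity)
    return params, velocity


# Stand-in gradients with the same structure as the parameters.
toy_grads = jax.tree_util.tree_map(jnp.ones_like, toy_params)
toy_params, toy_velocity = momentum_update(toy_params, toy_velocity, toy_grads)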

From scratch: training loop exposing psyphy's MAP implementation in JAX
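The MAP loop in psyphy follows the pattern above: take the gradient of the negative log posterior with jax.value_and_grad, apply an Optax SGD-with-momentum update, and compile the step with jax.jit. The sketch below is a minimal, self-contained version of that loop. The objective and data here (toy_neg_log_posterior, toy_X, toy_y) are illustrative stand-ins for the WPPM negative log posterior built from model and ResponseData; they are not part of the psyphy API.

import jax
import jax.numpy as jnp
import optax

# Illustrative stand-in data and objective (NOT the WPPM likelihood): a logistic
# model of binary responses plus a Gaussian prior on the parameters, so the loss
# has the same "negative log likelihood + negative log prior" structure.
toy_X = jax.random.normal(jax.random.PRNGKey(0), (200, 2))
toy_y = (toy_X @ jnp.array([1.0, -2.0]) > 0).astype(jnp.float32)


def toy_neg_log_posterior(params):
    logits = toy_X @ params["w"] + params["b"]
    nll = jnp.sum(optax.sigmoid_binary_cross_entropy(logits, toy_y))  # -log p(D | theta)
    neg_log_prior = 0.5 * sum(jnp.sum(leaf**2) for leaf in jax.tree_util.tree_leaves(params))
    return nll + neg_log_prior  # -log p(theta | D) up to a constant


# SGD with momentum via Optax; the optimizer state is itself a PyTree.
steps, lr, momentum = 300, 1e-3, 0.9  # toy hyperparameters for this sketch
optimizer = optax.sgd(learning_rate=lr, momentum=momentum)

params = {"w": jnp.zeros(2), "b": jnp.zeros(())}  # initial parameter PyTree
opt_state = optimizer.init(params)


@jax.jit
def train_step(params, opt_state):
    # Differentiate the scalar loss with respect to the whole parameter PyTree.
    loss, grads = jax.value_and_grad(toy_neg_log_posterior)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss


for i in range(steps):
    params, opt_state, loss = train_step(params, opt_state)
    if i % 50 == 0:
        print(f"step {i:4d}  negative log posterior = {loss:.4f}")

In the psyphy version, the toy objective is replaced by the model's negative log posterior over the trials stored in ResponseData, init_params provides the initial parameter PyTree, and MAPOptimizer.fit returns the fitted result (map_posterior above).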