
Quick start — fit your first covariance ellipse

Goal: run the full psyphy workflow — simulate data, fit a model, inspect the result — in one short script with no GPU required.

The complete runnable script is quick_start.py. For a spatially-varying field over a 2-D stimulus grid, see the full example.


Runtime

| Hardware                      | Approximate time |
| ----------------------------- | ---------------- |
| GPU (any modern CUDA device)  | < 5 s            |
| CPU (laptop / M-series Mac)   | < 1 min          |

The compute settings that control runtime:

Compute settings (quick start defaults)
# Compute knobs for the quick start — smaller values run faster but noisier.
MC_SAMPLES = 50  # MC samples per trial in the likelihood (full example: 500)
NUM_TRIALS = 100  # total simulated trials (full example: 4000 × 25)
NUM_STEPS = 200  # optimizer steps (full example: 2000)

learning_rate = 5e-4  # full example: 5e-5. The smaller the lr, the more steps
# are required.

Step 0 — Imports

Imports
from psyphy.data import TrialData  # batched trial container
from psyphy.inference import MAPOptimizer  # fitter
from psyphy.model import (
    WPPM,
    GaussianNoise,
    OddityTask,
    OddityTaskConfig,
    Prior,
    WPPMCovarianceField,  # fast Σ(x) evaluation
)

Step 1 — Define a ground-truth model and sample parameters

We create a WPPM with known parameters to act as the synthetic observer. Data will be generated from it so we have a ground truth to compare against.

Ground-truth model
# Oddity task whose likelihood is estimated with MC_SAMPLES Monte-Carlo draws;
# bandwidth presumably smooths the MC decision rule — TODO confirm in OddityTaskConfig docs.
task = OddityTask(
    config=OddityTaskConfig(num_samples=int(MC_SAMPLES), bandwidth=float(bandwidth))
)
# Isotropic observation noise shared by truth and fitted models below.
noise = GaussianNoise(sigma=0.1)

# Set all Wishart process hyperparameters in Prior
truth_prior = Prior(
    input_dim=input_dim,
    basis_degree=basis_degree,
    extra_embedding_dims=extra_dims,
    decay_rate=decay_rate,
    variance_scale=variance_scale,
)
# The synthetic observer: data is generated from this model so we have a
# known ground truth to compare the fit against.
truth_model = WPPM(
    input_dim=input_dim,
    extra_dims=extra_dims,
    prior=truth_prior,
    task=task,
    noise=noise,
    diag_term=diag_term,
)

# Sample ground-truth Wishart process weights
# (fixed PRNG seed 123 makes the ground truth reproducible).
truth_params = truth_model.init_params(jax.random.PRNGKey(123))

Step 2 — Simulate trials at a single reference point

We generate NUM_TRIALS oddity-task responses at a single reference stimulus ref = [0, 0]. Probe displacements are scaled by the local covariance (constant Mahalanobis radius), so trial difficulty stays roughly uniform.

Simulate data
# Single reference point at the centre of the stimulus space.
ref_point = jnp.array([[0.0, 0.0]])  # shape (1, 2) — kept as a batch for generality

seed = 3
key = jr.PRNGKey(seed)

# Repeat the reference point for every trial.
refs = jnp.repeat(ref_point, repeats=NUM_TRIALS, axis=0)  # (NUM_TRIALS, 2)

# Evaluate Σ at the reference point.
truth_field = WPPMCovarianceField(truth_model, truth_params)
Sigmas_ref = truth_field(refs)  # (NUM_TRIALS, 2, 2)

# Sample unit directions and build covariance-scaled probe displacements.
# One sub-key per downstream random draw: directions, MC predictions, responses.
k_dir, k_pred, k_y = jr.split(key, 3)
angles = jr.uniform(k_dir, shape=(NUM_TRIALS,), minval=0.0, maxval=2.0 * jnp.pi)
unit_dirs = jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=1)  # (N, 2)

# Constant Mahalanobis radius: probe = ref + MAHAL_RADIUS * chol(Σ_ref) @ unit_dir
MAHAL_RADIUS = 2.8
L = jnp.linalg.cholesky(Sigmas_ref)  # (N, 2, 2)
# location of comparisons = ref+delta
# einsum performs a per-trial matrix-vector product: delta_n = L_n @ u_n.
deltas = MAHAL_RADIUS * jnp.einsum("nij,nj->ni", L, unit_dirs)  # (N, 2)
# Clip so probes stay inside the [-1, 1] stimulus square.
comparisons = jnp.clip(refs + deltas, -1.0, 1.0)

# Compute p(correct) via MC simulation of the oddity task.
# Independent key per trial so MC draws are uncorrelated across trials.
trial_pred_keys = jr.split(k_pred, NUM_TRIALS)


def _p_correct_one(ref: jnp.ndarray, comp: jnp.ndarray, kk: jnp.ndarray) -> jnp.ndarray:
    """Monte-Carlo estimate of p(correct) for one (ref, comparison) trial.

    Thin wrapper over the task's MC simulator, binding the ground-truth
    model/params so it can be vmapped over per-trial inputs and keys.
    """
    cfg = task.config
    p = task._simulate_trial_mc(
        params=truth_params,
        ref=ref,
        comparison=comp,
        model=truth_model,
        noise=truth_model.noise,
        num_samples=int(cfg.num_samples),
        bandwidth=float(cfg.bandwidth),
        key=kk,
    )
    return p


# Vectorize the per-trial MC simulation over all trials at once.
p_correct = jax.vmap(_p_correct_one)(refs, comparisons, trial_pred_keys)  # (NUM_TRIALS,)

# Sample observed responses y ~ Bernoulli(p_correct).
ys = jr.bernoulli(k_y, p_correct, shape=(NUM_TRIALS,)).astype(jnp.int32)

# Canonical container consumed by the fitting step below.
data = TrialData(
    refs=refs, comparisons=comparisons, responses=ys
)  # contains 3 JAX arrays

The TrialData container is the canonical input for fitting:

Data container
data = TrialData(
    refs=refs, comparisons=comparisons, responses=ys
)  # contains 3 JAX arrays

Step 3 — Build the model to fit

We build a fresh WPPM with the same hyperparameters but independent random weights, then take one draw from the prior as the starting point for optimization.

Model definition
# Same Wishart-process hyperparameters as the ground truth; only the random
# weights (sampled below) differ between the truth and the model being fit.
prior = Prior(
    input_dim=input_dim,
    basis_degree=basis_degree,
    extra_embedding_dims=extra_dims,
    decay_rate=decay_rate,
    variance_scale=variance_scale,
)
# Mirror the truth model's constructor arguments (extra_dims, diag_term)
# so the fitted model shares the truth model's parametrization exactly.
model = WPPM(
    input_dim=input_dim,
    extra_dims=extra_dims,
    prior=prior,
    task=task,
    noise=noise,
    diag_term=diag_term,
)
Prior sample (initialization)
# Initialize parameters at a sample from the prior
# (different seed than the truth model, so the starting point is independent).
init_params = model.init_params(jax.random.PRNGKey(42))
# Bind (model, params) into a callable Σ(x) field, used for plotting later.
init_field = WPPMCovarianceField(model, init_params)
# Evaluate prior covariance at the reference point
covs_prior = init_field(ref_point)  # (1, 2, 2)

Step 4 — Fit with MAP optimization

Fit with MAPOptimizer
# SGD + momentum; track_history=True with log_every=1 records the loss at
# every step so the learning curve can be plotted afterwards (Step 6).
map_optimizer = MAPOptimizer(
    steps=NUM_STEPS,
    learning_rate=learning_rate,
    momentum=momentum,
    track_history=True,
    log_every=1,
)

# Returns a MAPPosterior point estimate at the MAP weights.
# seed presumably controls the MC-likelihood PRNG — TODO confirm in MAPOptimizer.fit docs.
map_posterior = map_optimizer.fit(model, data, init_params=init_params, seed=4)

MAPOptimizer runs SGD + momentum and returns a MAPPosterior — a point estimate at \(W_\text{MAP}\).


Step 5 — Inspect the fitted covariance ellipse

WPPMCovarianceField binds a (model, params) pair into a single callable that returns \(\Sigma(x)\) for any stimulus x:

Evaluate covariance fields
# Evaluate any covariance-field object at a single point or a batch of points.
covs_truth = truth_field(vis_points)  # (1, 2, 2)
covs_init = init_field(vis_points)  # (1, 2, 2)
covs_map = map_field(vis_points)  # (1, 2, 2)

The ellipse plot below overlays the ground truth (black), the prior initialization (blue), and the MAP fit (red) at the single reference point:

Plot ellipses
# No jitter: ellipses are drawn exactly at the reference point.
_PLOT_JITTER = 0.0

# Use the single reference point as the visualization centre.
vis_points = ref_point  # (1, 2)

# Σ(x) field bound to the MAP-fitted parameters.
map_field = WPPMCovarianceField(model, map_posterior.params)

# Evaluate any covariance-field object at a single point or a batch of points.
covs_truth = truth_field(vis_points)  # (1, 2, 2)
covs_init = init_field(vis_points)  # (1, 2, 2)
covs_map = map_field(vis_points)  # (1, 2, 2)

# Scale ellipses so they are visually readable.
# gt_scale is the RMS standard deviation (sqrt of mean eigenvalue) of the
# ground-truth covariance at the reference point.
gt_scale = float(jnp.sqrt(jnp.mean(jnp.linalg.eigvalsh(covs_truth[0]))))
ellipse_scale = max(0.3, 0.4 * gt_scale / 0.01)  # keep readable on the unit square

fig, ax = plt.subplots(figsize=(6, 6))

labels = ["Ground Truth", "Prior Sample (init)", "Fitted (MAP)"]
colors = ["k", "b", "r"]
fields = [truth_field, init_field, map_field]
non_pd_counts = []  # per-field count of points whose ellipse was invalid

for field, color, label in zip(fields, colors, labels):
    covs = field(vis_points)
    # Helper returns drawable line segments plus a validity mask
    # (presumably False where Σ was not positive-definite — TODO confirm).
    segments, valid = _ellipse_segments_from_covs(
        vis_points,
        covs,
        scale=ellipse_scale,
        plot_jitter=_PLOT_JITTER,
        unit_circle=_UNIT_CIRCLE,
    )
    non_pd_counts.append(int((~valid).sum()))
    # device_get moves the JAX array to host memory for matplotlib.
    lc = LineCollection(
        jax.device_get(segments),
        colors=color,
        linewidths=2.0,
        alpha=0.8,
    )
    ax.add_collection(lc)
    # Empty proxy plot so the LineCollection appears in the legend.
    ax.plot([], [], color=color, alpha=0.8, linewidth=1.5, label=label)

ax.scatter(
    vis_points[:, 0], vis_points[:, 1], c="g", s=40, zorder=5, label="Reference Point"
)
ax.set_xlim(-0.6, 0.6)
ax.set_ylim(-0.6, 0.6)
ax.set_aspect("equal", adjustable="box")
ax.set_xlabel("Stimulus dimension 1")
ax.set_ylabel("Stimulus dimension 2")
ax.set_title(
    f"Covariance ellipse at ref={ref_point[0].tolist()}\n"
    f"lr={learning_rate}, steps={NUM_STEPS}, MC-samples={MC_SAMPLES}, trials={NUM_TRIALS}"
)
ax.grid(True, alpha=0.3)
ax.legend(loc="upper right")
plt.tight_layout()

os.makedirs(PLOTS_DIR, exist_ok=True)
fig.savefig(
    os.path.join(PLOTS_DIR, "quick_start_ellipses.png"), dpi=200, bbox_inches="tight"
)
print(f"  Saved → {PLOTS_DIR}/quick_start_ellipses.png")
Covariance ellipses: ground truth (black), prior (blue), MAP fit (red)

Ground truth (black), prior sample (blue), and MAP-fitted (red) covariance ellipses at the single reference point. Note how the red ellipse has moved from the blue (init) closer to the ground truth (black).


Step 6 — Learning curve

Access learning curve
# Step indices and losses recorded during fitting (track_history=True, log_every=1).
steps_hist, loss_hist = map_optimizer.get_history()
Learning curve

Negative log-likelihood over optimizer steps.


Next steps

  • Spatially-varying field: scale up to a full 2-D grid → full example.
  • Your own data: replace the simulated TrialData with your own refs, comparisons, and responses arrays.
  • API reference: see MAPOptimizer, WPPM, and WPPMCovarianceField.