Quick start — fit your first covariance ellipse
Goal: run the full psyphy workflow — simulate data, fit a model, inspect the result — in one short script with no GPU required.
The complete runnable script is quick_start.py.
For a spatially-varying field over a 2-D stimulus grid, see the full example.
Runtime
| Hardware | Approximate time |
|---|---|
| GPU (any modern CUDA device) | < 5 s |
| CPU (laptop / M-series Mac) | < 1 min |
The four knobs that control runtime:
| Compute settings (quick start defaults) |
|---|
| MC_SAMPLES = 50 # MC samples per trial in the likelihood (full example: 500)
NUM_TRIALS = 100 # total simulated trials (full example: 4000 × 25)
NUM_STEPS = 200 # optimizer steps (full example: 2000)
learning_rate = 5e-4 # full example: 5e-5. The smaller the lr, the more steps
# are required.
|
Step 0 — Imports
| Imports |
|---|
| from psyphy.data import TrialData # batched trial container
from psyphy.inference import MAPOptimizer # fitter
from psyphy.model import (
WPPM,
GaussianNoise,
OddityTask,
OddityTaskConfig,
Prior,
WPPMCovarianceField, # fast Σ(x) evaluation
)
|
Step 1 — Define a ground-truth model and sample parameters
We create a WPPM with known parameters to act as the synthetic observer.
Data will be generated from it so we have a ground truth to compare against.
| Ground-truth model |
|---|
| task = OddityTask(
config=OddityTaskConfig(num_samples=int(MC_SAMPLES), bandwidth=float(bandwidth))
)
noise = GaussianNoise(sigma=0.1)
# Set all Wishart process hyperparameters in Prior
truth_prior = Prior(
input_dim=input_dim,
basis_degree=basis_degree,
extra_embedding_dims=extra_dims,
decay_rate=decay_rate,
variance_scale=variance_scale,
)
truth_model = WPPM(
input_dim=input_dim,
extra_dims=extra_dims,
prior=truth_prior,
task=task,
noise=noise,
diag_term=diag_term,
)
# Sample ground-truth Wishart process weights
truth_params = truth_model.init_params(jax.random.PRNGKey(123))
|
Step 2 — Simulate trials at a single reference point
We generate NUM_TRIALS oddity-task responses at a single reference stimulus
ref = [0, 0]. Probe displacements are scaled by the local covariance
(constant Mahalanobis radius), so trial difficulty stays roughly uniform.
| Simulate data |
|---|
| # Single reference point at the centre of the stimulus space.
ref_point = jnp.array([[0.0, 0.0]]) # shape (1, 2) — kept as a batch for generality
seed = 3
key = jr.PRNGKey(seed)
# Repeat the reference point for every trial.
refs = jnp.repeat(ref_point, repeats=NUM_TRIALS, axis=0) # (NUM_TRIALS, 2)
# Evaluate Σ at the reference point.
truth_field = WPPMCovarianceField(truth_model, truth_params)
Sigmas_ref = truth_field(refs) # (NUM_TRIALS, 2, 2)
# Sample unit directions and build covariance-scaled probe displacements.
k_dir, k_pred, k_y = jr.split(key, 3)
angles = jr.uniform(k_dir, shape=(NUM_TRIALS,), minval=0.0, maxval=2.0 * jnp.pi)
unit_dirs = jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=1) # (N, 2)
# Constant Mahalanobis radius: probe = ref + MAHAL_RADIUS * chol(Σ_ref) @ unit_dir
MAHAL_RADIUS = 2.8
L = jnp.linalg.cholesky(Sigmas_ref) # (N, 2, 2)
# location of comparisons = ref+delta
deltas = MAHAL_RADIUS * jnp.einsum("nij,nj->ni", L, unit_dirs) # (N, 2)
comparisons = jnp.clip(refs + deltas, -1.0, 1.0)
# Compute p(correct) via MC simulation of the oddity task.
trial_pred_keys = jr.split(k_pred, NUM_TRIALS)
def _p_correct_one(ref: jnp.ndarray, comp: jnp.ndarray, kk: jnp.ndarray) -> jnp.ndarray:
return task._simulate_trial_mc(
params=truth_params,
ref=ref,
comparison=comp,
model=truth_model,
noise=truth_model.noise,
num_samples=int(task.config.num_samples),
bandwidth=float(task.config.bandwidth),
key=kk,
)
p_correct = jax.vmap(_p_correct_one)(refs, comparisons, trial_pred_keys)
# Sample observed responses y ~ Bernoulli(p_correct).
ys = jr.bernoulli(k_y, p_correct, shape=(NUM_TRIALS,)).astype(jnp.int32)
data = TrialData(
refs=refs, comparisons=comparisons, responses=ys
) # contains 3 JAX arrays
|
The TrialData container is the canonical input for fitting:
| Data container |
|---|
| data = TrialData(
refs=refs, comparisons=comparisons, responses=ys
) # contains 3 JAX arrays
|
Step 3 — Build the model to fit
We build a fresh WPPM with the same hyperparameters but independent random
weights, then take one draw from the prior as the starting point for
optimization.
| Model definition |
|---|
| prior = Prior(
input_dim=input_dim,
basis_degree=basis_degree,
extra_embedding_dims=extra_dims,
decay_rate=decay_rate,
variance_scale=variance_scale,
)
model = WPPM(
input_dim=input_dim,
prior=prior,
task=task,
noise=noise,
diag_term=1e-4,
)
|
| Prior sample (initialization) |
|---|
| # Initialize parameters at a sample from the prior
init_params = model.init_params(jax.random.PRNGKey(42))
init_field = WPPMCovarianceField(model, init_params)
# Evaluate prior covariance at the reference point
covs_prior = init_field(ref_point) # (1, 2, 2)
|
Step 4 — Fit with MAP optimization
| Fit with MAPOptimizer |
|---|
| map_optimizer = MAPOptimizer(
steps=NUM_STEPS,
learning_rate=learning_rate,
momentum=momentum,
track_history=True,
log_every=1,
)
map_posterior = map_optimizer.fit(model, data, init_params=init_params, seed=4)
|
MAPOptimizer runs SGD + momentum and returns a MAPPosterior — a point
estimate at \(W_\text{MAP}\).
Step 5 — Inspect the fitted covariance ellipse
WPPMCovarianceField binds a (model, params) pair into a single callable
that returns \(\Sigma(x)\) for any stimulus x:
| Evaluate covariance fields |
|---|
| # Evaluate any covariance-field object at a single point or a batch of points.
covs_truth = truth_field(vis_points)  # (1, 2, 2)
covs_init = init_field(vis_points)  # (1, 2, 2)
covs_map = map_field(vis_points)  # (1, 2, 2)
|
(Here `vis_points` and `map_field` are defined in the plotting snippet below: `vis_points = ref_point` and `map_field = WPPMCovarianceField(model, map_posterior.params)`.)
The ellipse plot below overlays the ground truth (black), the prior
initialization (blue), and the MAP fit (red) at the single reference point:
| Plot ellipses |
|---|
| _PLOT_JITTER = 0.0
# Use the single reference point as the visualization centre.
vis_points = ref_point # (1, 2)
map_field = WPPMCovarianceField(model, map_posterior.params)
# Evaluate any covariance-field object at a single point or a batch of points.
covs_truth = truth_field(vis_points) # (1, 2, 2)
covs_init = init_field(vis_points) # (1, 2, 2)
covs_map = map_field(vis_points) # (1, 2, 2)
# Scale ellipses so they are visually readable.
gt_scale = float(jnp.sqrt(jnp.mean(jnp.linalg.eigvalsh(covs_truth[0]))))
ellipse_scale = max(0.3, 0.4 * gt_scale / 0.01) # keep readable on the unit square
fig, ax = plt.subplots(figsize=(6, 6))
labels = ["Ground Truth", "Prior Sample (init)", "Fitted (MAP)"]
colors = ["k", "b", "r"]
fields = [truth_field, init_field, map_field]
non_pd_counts = []
for field, color, label in zip(fields, colors, labels):
covs = field(vis_points)
segments, valid = _ellipse_segments_from_covs(
vis_points,
covs,
scale=ellipse_scale,
plot_jitter=_PLOT_JITTER,
unit_circle=_UNIT_CIRCLE,
)
non_pd_counts.append(int((~valid).sum()))
lc = LineCollection(
jax.device_get(segments),
colors=color,
linewidths=2.0,
alpha=0.8,
)
ax.add_collection(lc)
ax.plot([], [], color=color, alpha=0.8, linewidth=1.5, label=label)
ax.scatter(
vis_points[:, 0], vis_points[:, 1], c="g", s=40, zorder=5, label="Reference Point"
)
ax.set_xlim(-0.6, 0.6)
ax.set_ylim(-0.6, 0.6)
ax.set_aspect("equal", adjustable="box")
ax.set_xlabel("Stimulus dimension 1")
ax.set_ylabel("Stimulus dimension 2")
ax.set_title(
f"Covariance ellipse at ref={ref_point[0].tolist()}\n"
f"lr={learning_rate}, steps={NUM_STEPS}, MC-samples={MC_SAMPLES}, trials={NUM_TRIALS}"
)
ax.grid(True, alpha=0.3)
ax.legend(loc="upper right")
plt.tight_layout()
os.makedirs(PLOTS_DIR, exist_ok=True)
fig.savefig(
os.path.join(PLOTS_DIR, "quick_start_ellipses.png"), dpi=200, bbox_inches="tight"
)
print(f" Saved → {PLOTS_DIR}/quick_start_ellipses.png")
|
Ground truth (black), prior sample (blue), and MAP-fitted (red) covariance ellipses at the single reference point. Note how the red ellipse has moved from the blue (init) closer to the ground truth (black).
Step 6 — Learning curve
| Access learning curve |
|---|
| steps_hist, loss_hist = map_optimizer.get_history()
|
Negative log-likelihood over optimizer steps.
Next steps