
Full WPPM fit (end-to-end) — simulated 2D data

This tutorial explains what the example script full_wppm_fit_example.py is doing, and where the key functions live in the psyphy codebase.

  • Goal: Fit a spatially varying covariance field \(\Sigma(x)\) over a 2D stimulus space \(x \in [-1,1]^2\) using the Wishart Process Psychophysical Model (WPPM).
  • Data: synthetic oddity-task responses simulated from a “ground-truth” WPPM.
  • Inference: MAP (maximum a posteriori) optimization of the WPPM parameters.

You can treat this as a “recipe” for using the Wishart Process Psychophysical Model (WPPM) in your own project: build a model, initialize parameters, fit the model, and visualize the fitted model's predicted thresholds.

NOTE: Running this script takes about 3 minutes on an A100 (40 GB). To speed it up, for example to run it on a CPU, significantly decrease the number of MC samples (e.g., to 5) and the number of steps the optimizer runs for.


What the Wishart Process Psychophysical Model (WPPM) is in a nutshell

WPPM defines a covariance matrix field \(\Sigma(x)\) over stimulus space (e.g., colors represented in RGB). Intuitively, \(\Sigma(x)\) describes the local noise/uncertainty ellipse around stimulus \(x\): stimuli within that ellipse are perceived as identical by a human observer.

The model represents \(\Sigma(x)\) as

\[ \Sigma(x) = U(x)U(x)^\top + \varepsilon I, \]

where \(U(x)\) is a smooth, basis-expanded matrix-valued function and \(\varepsilon\) is a small diagonal “jitter” (diag_term) to avoid numerical issues. Alternatively, in Gaussian Process (GP) terms, you can think of \(U(x)\) as defining a GP in weight space, i.e., a “Bayesian linear model”.
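As a concrete illustration, here is how the covariance at a single stimulus could be assembled from a square-root factor; the values below are placeholders, not the psyphy implementation:

import jax.numpy as jnp

# Illustrative only: U_x stands in for the basis-expanded U(x) at one stimulus x.
D, E = 2, 3                    # input_dim = 2, embedding_dim = input_dim + 1 extra dim
U_x = 0.1 * jnp.ones((D, E))   # placeholder values; in WPPM this comes from W and the basis features
eps = 1e-4                     # the diag_term "jitter"

Sigma_x = U_x @ U_x.T + eps * jnp.eye(D)   # (D, D), symmetric and positive definite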

A psychophysical task model (here: OddityTask) uses \(\Sigma\) to compute the probability of a correct response on each trial, and MAPOptimizer fits the WPPM parameters by maximizing

\[ \log p(\theta \mid \mathcal{D}) = \log p(\mathcal{D} \mid \theta) + \log p(\theta). \]

For more details on how the Wishart Process Psychophysical Model (WPPM) works, and on the psychophysical task used in this example, please see the paper by Hong et al. (2025) and this tutorial.


Step 0 — Imports and setup

The following imports are needed to set the model up and fit it to data:

Ground-truth model + prior sample
# (imports above are included via mkdocs-snippets)
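The script's actual imports are pulled in from the example file via mkdocs-snippets. If they do not appear above, a plausible set looks like the following; the module paths are assumptions based on the source files referenced later in this tutorial, so check full_wppm_fit_example.py for the exact statements:

import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt

# Module paths below are guesses inferred from the files referenced in this tutorial.
from psyphy.model.prior import Prior                      # src/psyphy/model/prior.py
from psyphy.model.wppm import WPPM                        # src/psyphy/model/wppm.py
from psyphy.model.task import OddityTask                  # src/psyphy/model/task.py
from psyphy.model.covariance_field import WPPMCovarianceField
from psyphy.inference.map_optimizer import MAPOptimizer   # src/psyphy/inference/map_optimizer.py
# GaussianNoise and ResponseData are also needed; their module paths are not listed
# in this tutorial, so see the example script for where they live.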

Step 1 — Define the prior (how weights are distributed initially)

The WPPM parameters are basis weights stored as a dict:

  • params = {"W": W}

where W is a tensor of Chebyshev-basis coefficients.

Prior distribution over weights

See src/psyphy/model/prior.py:

  • Prior.sample_params(key) samples weights W from a zero-mean Gaussian with a degree-dependent variance.

For 2D, the weight tensor shape is

\[ W \in \mathbb{R}^{(d+1) \times (d+1) \times D \times E}, \]

where:

  • \(d\) = basis_degree
  • \(D\) = input_dim (here 2)
  • \(E\) = embedding_dim = input_dim + extra_embedding_dims

The prior variance decays with basis “total degree”. In code:

  • Prior._compute_basis_degree_grid() constructs degrees \(i+j\) (2D) or \(i+j+k\) (3D).
  • Prior._compute_W_prior_variances() returns
\[ \sigma^2_{ij} = \texttt{variance_scale} \cdot (\texttt{decay_rate})^{(i+j)}. \]
  • Prior.sample_params(...) then samples
\[ W_{ijde} \sim \mathcal{N}(0, \sigma^2_{ij}). \]

This is the prior behavior of the WPPM: before seeing any data, it draws smooth random fields, because high-frequency coefficients are shrunk by the decay.
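A minimal sketch of this degree-dependent shrinkage (illustrative hyperparameter values; not the psyphy internals):

import jax
import jax.numpy as jnp

basis_degree = 5        # d
variance_scale = 1.0    # illustrative value
decay_rate = 0.4        # illustrative value

deg = jnp.arange(basis_degree + 1)
total_degree = deg[:, None] + deg[None, :]               # (d+1, d+1) grid of i + j
variances = variance_scale * decay_rate ** total_degree  # sigma^2_{ij}

# Each W[i, j, :, :] slice is drawn from N(0, sigma^2_{ij}); here D = 2, E = 3:
key = jax.random.PRNGKey(0)
W = jnp.sqrt(variances)[:, :, None, None] * jax.random.normal(
    key, (basis_degree + 1, basis_degree + 1, 2, 3)
)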

Ground-truth model + prior sample
# Set all Wishart process hyperparameters in Prior
truth_prior = Prior(
    input_dim=input_dim,  # (2D)
    basis_degree=basis_degree,  # (5)
    extra_embedding_dims=extra_dims,  # (1)
    decay_rate=decay_rate,  # for basis functions
    variance_scale=variance_scale,  # how big covariance matrices
    # are before fitting
)
truth_model = WPPM(
    input_dim=input_dim,
    extra_dims=extra_dims,
    prior=truth_prior,
    task=task,  # oddity task ("pick the odd-one out among 3 stimuli")
    noise=noise,  # (Gaussian noise)
    diag_term=diag_term,  # ensure positive-definite covariances
)

# Sample ground-truth Wishart process weights
truth_params = truth_model.init_params(jax.random.PRNGKey(123))

A sample from the prior


Step 2 — Build the model (WPPM + task + noise)

We create a model by combining its prior and likelihood. Note that the task inherently defines the likelihood, so you can think of the two interchangeably. In psyphy, the model is simply a container for the prior and the likelihood:

  • All prior-specific hyperparameters are owned by the Prior.
  • Likewise, all likelihood-specific hyperparameters are owned by the task.
  • Besides being the container for prior and likelihood, the model also takes a few compute-specific arguments, such as diag_term, which ensures numerical stability by keeping the covariances positive definite.

Model definition
prior = Prior(
    input_dim=input_dim,  # (2D)
    basis_degree=basis_degree,  # 5
    extra_embedding_dims=extra_dims,  # 1
    decay_rate=decay_rate,  # for basis functions (how quickly they vary)
    variance_scale=variance_scale,  # how big covariance matrices
    # are before fitting
)
model = WPPM(
    input_dim=input_dim,
    prior=prior,
    task=task,
    noise=noise,
    diag_term=1e-4,  # ensure positive-definite covariances
)

Step 3 — Evaluate the covariance field \(\Sigma(x)\)

The example simulates oddity-task trials from the ground-truth model, using the WPPMCovarianceField convenience wrapper to evaluate \(\Sigma(x)\) at the reference points:

Covariance field evaluation (Σ(x))
data = ResponseData()
num_trials_per_ref = NUM_TRIALS_TOTAL  # (trials per reference point)
n_ref_grid = 5  # NUM_GRID_PTS
ref_grid = jnp.linspace(-1, 1, n_ref_grid)  # [-1,1] space
ref_points = jnp.stack(jnp.meshgrid(ref_grid, ref_grid), axis=-1).reshape(-1, 2)

# --- Stimulus design: covariance-scaled probe displacements ---
#
# Rather than sampling probes at a fixed Euclidean radius, we scale the probe
# displacement by sqrt(Σ_ref). This tends to equalize trial difficulty across
# space (roughly constant Mahalanobis radius).

seed = 3
key = jr.PRNGKey(seed)

# Build a batched reference list by repeating each grid point.
n_ref = ref_points.shape[0]
refs = jnp.repeat(ref_points, repeats=num_trials_per_ref, axis=0)  # (N, 2)
num_trials_total = int(refs.shape[0])

# Evaluate Σ(ref) in batch using the psyphy covariance-field wrapper.
truth_field = WPPMCovarianceField(truth_model, truth_params)
Sigmas_ref = truth_field(refs)  # (N, 2, 2)

# Sample unit directions on the circle.
k_dir, k_pred, k_y = jr.split(key, 3)
angles = jr.uniform(k_dir, shape=(num_trials_total,), minval=0.0, maxval=2.0 * jnp.pi)
unit_dirs = jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=1)  # (N, 2)


# Displacement scale/orientation follows the local ellipse: each probe is offset by
# MAHAL_RADIUS * chol(Sigma_ref) @ unit_dir, i.e. an (approximately) constant
# Mahalanobis radius around every reference point.
MAHAL_RADIUS = 2.8

# Compute sqrt(Σ) via Cholesky (Σ should be SPD; diag_term/noise keep it stable).
L = jnp.linalg.cholesky(Sigmas_ref)  # (N, 2, 2)
deltas = MAHAL_RADIUS * jnp.einsum("nij,nj->ni", L, unit_dirs)  # (N, 2)
comparisons = jnp.clip(refs + deltas, -1.0, 1.0)

# Compute p(correct) in batch. We vmap the single-trial predictor.
trial_pred_keys = jr.split(k_pred, num_trials_total)


def _p_correct_one(ref: jnp.ndarray, comp: jnp.ndarray, kk: jnp.ndarray) -> jnp.ndarray:
    # Task MC settings (num_samples/bandwidth) come from OddityTaskConfig.
    # Only the randomness is threaded dynamically.
    return task._simulate_trial_mc(
        params=truth_params,
        ref=ref,
        comparison=comp,
        model=truth_model,
        noise=truth_model.noise,
        num_samples=int(task.config.num_samples),
        bandwidth=float(task.config.bandwidth),
        key=kk,
    )


p_correct = jax.vmap(_p_correct_one)(refs, comparisons, trial_pred_keys)

# Sample observed y ~ Bernoulli(p_correct) in batch.
ys = jr.bernoulli(k_y, p_correct, shape=(num_trials_total,)).astype(jnp.int32)

# Populate ResponseData (kept as Python loop since ResponseData is a Python container).
for ref, comp, y in zip(
    jax.device_get(refs), jax.device_get(comparisons), jax.device_get(ys)
):
    data.add_trial(ref=jnp.array(ref), comparison=jnp.array(comp), resp=int(y))

A sample from the prior

What field(x) does

At a high level:

  • Input: x with shape (D,) or (..., D).
  • Output: covariance matrix/matrices \(\Sigma(x)\) with shape (..., D, D).
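For example, with the ground-truth field built earlier (shapes follow the two bullets above):

Sigma_single = truth_field(jnp.array([0.0, 0.0]))   # (2, 2): covariance at one point
Sigma_batch = truth_field(ref_points)                # (25, 2, 2): one matrix per 5x5 grid point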

Mathematically:

  1. Compute a basis feature vector \(\phi(x)\) (Chebyshev basis products).
  2. Form a matrix
\[ U(x) = \sum_{i,j} W_{ij}\, \phi_{ij}(x) \]

(indices are suppressed; the actual tensor contraction is done via an einsum).

  3. Produce
\[ \Sigma(x) = U(x)U(x)^\top + \varepsilon I. \]

In the code, the name “sqrt” is often used for \(U(x)\): it is a square-root factor of the covariance (up to the diagonal term).

If you’re looking for the implementation details of the “sqrt” computation, search src/psyphy/model/wppm.py for a helper named something like _compute_sqrt. That’s where you’ll find the einsum contraction that turns W and the basis features into U(x).
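The contraction itself is conceptually simple. Here is a standalone sketch for a single 2D point, using hypothetical helper names (this is not the actual psyphy code):

import jax.numpy as jnp

def chebyshev_features(x_scalar, degree):
    # T_0(x), ..., T_degree(x) via the recurrence T_n = 2 x T_{n-1} - T_{n-2}, for x in [-1, 1]
    feats = [jnp.ones(()), x_scalar]
    for _ in range(2, degree + 1):
        feats.append(2.0 * x_scalar * feats[-1] - feats[-2])
    return jnp.stack(feats[: degree + 1])

def compute_U(W, x, degree):
    # W: (degree+1, degree+1, D, E); x: (2,)
    phi_i = chebyshev_features(x[0], degree)   # (degree+1,)
    phi_j = chebyshev_features(x[1], degree)   # (degree+1,)
    # U(x) = sum_{i,j} W[i, j] * T_i(x_0) * T_j(x_1), contracted with einsum:
    return jnp.einsum("i,j,ijde->de", phi_i, phi_j, W)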

Corresponding code block in the example

  • Field wrapper construction:
      • truth_field = WPPMCovarianceField(truth_model, truth_params)
      • init_field = WPPMCovarianceField(model, init_params)
      • map_field = WPPMCovarianceField(model, map_posterior.params)
  • Batched evaluation:
      • gt_covs = truth_field(ref_points)

Step 5 — Fit with MAP optimization

We fit the parameters with SGD + momentum:

Fitting with psyphy (MAPOptimizer)
map_optimizer = MAPOptimizer(
    steps=steps, learning_rate=lr, momentum=momentum, track_history=True, log_every=1
)
# Initialize at prior sample
init_params = model.init_params(jax.random.PRNGKey(42))

map_posterior = map_optimizer.fit(
    model,
    data,
    init_params=init_params,
)

What is being optimized

MAP fitting finds

\[ \theta_\text{MAP} = \arg\max_{\theta} \big[\log p(\mathcal{D}\mid\theta) + \log p(\theta)\big]. \]
  • \(\log p(\theta)\) is from Prior.log_prob(params) (see prior.py).
  • \(\log p(\mathcal{D}\mid\theta)\) is computed by the task’s log-likelihood (here via Monte Carlo inside OddityTask.loglik).

The result in this example is a MAPPosterior object that contains a point estimate map_posterior.params.
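Schematically, the optimizer minimizes the negative of this quantity. A sketch, assuming the building blocks described above (the exact loglik signature in psyphy may differ from what is shown here):

def negative_log_posterior(params):
    # Monte Carlo log-likelihood of all trials (signature assumed for illustration)
    log_lik = task.loglik(params, data, model=model)
    # Gaussian prior over the basis weights W
    log_prior = prior.log_prob(params)
    return -(log_lik + log_prior)

# MAPOptimizer then follows gradients of this scalar, e.g. via jax.grad(negative_log_posterior).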


Step 6 — Visualize fit vs. truth vs. prior sample

Fitted ellipsoids overlaid with the ground truth and the model initialization (a sample from the prior).
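The plotting code itself is not reproduced here, but turning a fitted 2×2 covariance into a threshold ellipse is straightforward. A minimal matplotlib sketch (the scale factor is illustrative):

import numpy as np
import matplotlib.pyplot as plt

def plot_cov_ellipse(ax, center, Sigma, scale=2.8, **kwargs):
    # Draw the curve {center + scale * chol(Sigma) @ u : u on the unit circle}.
    theta = np.linspace(0.0, 2.0 * np.pi, 100)
    circle = np.stack([np.cos(theta), np.sin(theta)])         # (2, 100) unit circle
    L = np.linalg.cholesky(np.asarray(Sigma))                 # square-root factor (Sigma must be SPD)
    pts = np.asarray(center)[:, None] + scale * L @ circle    # (2, 100) ellipse points
    ax.plot(pts[0], pts[1], **kwargs)

fig, ax = plt.subplots()
for ref, Sigma in zip(np.asarray(ref_points), np.asarray(map_field(ref_points))):
    plot_cov_ellipse(ax, ref, Sigma, color="tab:blue")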


Access learning curve
steps_hist, loss_hist = map_optimizer.get_history()
print(f"num steps: {len(steps_hist)}, num losses: {len(loss_hist)}")
if steps_hist:
    print(f"history step range: [{steps_hist[0]}, {steps_hist[-1]}]")
if steps_hist and loss_hist:
    fig2, ax2 = plt.subplots(figsize=(6, 4))
    ax2.set_xlim(steps_hist[0], steps_hist[-1])
    ax2.plot(steps_hist, loss_hist, color="#4444aa")
    ax2.set_title(
        f"Learning curve — lr={lr}, steps={steps} - MC-samples={MC_SAMPLES}, num-trials={num_trials_total}"
    )
    ax2.set_xlabel("Step")
    ax2.set_ylabel("Neg log likelihood")
    ax2.grid(True, alpha=0.3)
    plt.tight_layout()
    fig2.savefig(
        os.path.join(PLOTS_DIR, "learning_curve.png"),
        dpi=200,
        bbox_inches="tight",
    )
else:
    print("No history recorded — set track_history=True in MAPOptimizer to enable.")

Learning curve.


Minimal recipe (copy/paste mental model)

To use WPPM on your own data, these are the essential calls:

  1. Create task + noise + prior:
      • task = OddityTask()
      • noise = GaussianNoise(sigma=...)
      • prior = Prior(input_dim=..., basis_degree=..., extra_embedding_dims=..., decay_rate=..., variance_scale=...)
  2. Create WPPM:
      • model = WPPM(input_dim=..., prior=prior, task=task, noise=noise, diag_term=...)
  3. Initialize parameters:
      • params0 = model.init_params(jax.random.PRNGKey(...)) (draws from Prior.sample_params)
  4. Load/build a dataset:
      • data = ResponseData(); data.add_trial(ref=..., comparison=..., resp=...)
  5. Fit:
      • map_posterior = MAPOptimizer(...).fit(model, data, init_params=params0, ...)
  6. Inspect \(\Sigma(x)\):
      • field = WPPMCovarianceField(model, map_posterior.params)
      • Sigmas = field(xs)
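Putting the recipe together as one sketch (uses the imports from Step 0; the numeric values are illustrative, not recommendations):

task = OddityTask()
noise = GaussianNoise(sigma=0.05)                        # illustrative noise level
prior = Prior(input_dim=2, basis_degree=5, extra_embedding_dims=1,
              decay_rate=0.4, variance_scale=1.0)        # illustrative hyperparameters
model = WPPM(input_dim=2, prior=prior, task=task, noise=noise, diag_term=1e-4)

params0 = model.init_params(jax.random.PRNGKey(0))

data = ResponseData()
data.add_trial(ref=jnp.array([0.0, 0.0]), comparison=jnp.array([0.1, 0.0]), resp=1)
# ... add the rest of your trials ...

map_posterior = MAPOptimizer(steps=500, learning_rate=1e-2, momentum=0.9).fit(
    model, data, init_params=params0
)

field = WPPMCovarianceField(model, map_posterior.params)
Sigmas = field(jnp.array([[0.0, 0.0], [0.5, -0.5]]))     # (2, 2, 2)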

Notes and pitfalls

  • CPU vs GPU: this example can be heavy because the oddity likelihood uses Monte Carlo. A GPU can help a lot.
  • Positive definiteness: diag_term is important. If you ever see a non-PD covariance, increase diag_term slightly.
  • MC variance: optimization stability depends on MC_SAMPLES. Too small means noisy gradients.

Next places to explore

  • Read the API docs in docs/reference/ (especially model + inference sections).
  • Inspect src/psyphy/model/prior.py if you want to change smoothness/regularization.
  • Inspect src/psyphy/model/covariance_field.py if you want faster / vmapped field evaluation patterns.

If you’re curious about some of the implementation details and want to “follow the call graph”, check out these files:

  1. WPPM.init_params(...) (defined in src/psyphy/model/wppm.py) → delegates to the prior’s Prior.sample_params(...) (defined in src/psyphy/model/prior.py).
  2. OddityTask.predict_with_kwargs(...) / OddityTask.loglik(...) (defined in src/psyphy/model/task.py) → calls into the model to get \(\Sigma(x)\) and then runs the task’s decision rule (Monte Carlo in the full model).
  3. WPPMCovarianceField(model, params) (defined in src/psyphy/model/covariance_field.py) → provides a callable field(x) that returns \(\Sigma(x)\) for single points or batches.
  4. MAPOptimizer.fit(...) (defined in src/psyphy/inference/map_optimizer.py) → runs gradient-based optimization of the negative log likelihood.