
Quickstart

psyphy logo

Psychophysical Modeling and Adaptive Trial Placement

Installation | Documentation | Examples | Contributing


Quick-start walkthrough — fit your first covariance ellipse

The snippet below shows the minimal end-to-end workflow: simulate a handful of oddity-task trials at a single reference point, fit the WPPM with MAP optimization, and visualize the result. No GPU needed — runs in under 2 min on CPU.

The complete runnable script is quick_start.py. A step-by-step explanation lives in the Quick-start example.

Imports

import jax
import jax.numpy as jnp
import jax.random as jr

from psyphy.data import TrialData  # batched trial container
from psyphy.inference import MAPOptimizer  # fitter
from psyphy.model import (
    WPPM,
    GaussianNoise,
    OddityTask,
    OddityTaskConfig,
    Prior,
    WPPMCovarianceField,  # fast Σ(x) evaluation
)

Compute settings

MC_SAMPLES = 50  # MC samples per trial in the likelihood (full example: 500)
NUM_TRIALS = 100  # total simulated trials (full example: 4000 × 25)
NUM_STEPS = 200  # optimizer steps (full example: 2000)

learning_rate = 5e-4  # full example: 5e-5; a smaller learning rate needs more steps
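The snippets below also reference several hyperparameters (input_dim, basis_degree, extra_dims, decay_rate, variance_scale, bandwidth, diag_term, momentum) whose values live in quick_start.py. One illustrative set of placeholder values (not necessarily those of the full script) that makes the snippets self-contained:

```python
input_dim = 2          # stimulus space is 2-D (refs have shape (N, 2))
basis_degree = 4       # placeholder: degree of the basis expansion
extra_dims = 1         # placeholder: extra embedding dimensions
decay_rate = 0.5       # placeholder: prior decay rate
variance_scale = 1.0   # placeholder: overall prior variance scale
bandwidth = 1e-3       # placeholder: oddity-task simulation bandwidth
diag_term = 1e-4       # diagonal jitter, matching the model definition below
momentum = 0.9         # placeholder: optimizer momentum
```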

Ground-truth model + simulate data

Ground-truth model
# bandwidth and the other hyperparameters below are set in the full quick_start.py script
task = OddityTask(
    config=OddityTaskConfig(num_samples=int(MC_SAMPLES), bandwidth=float(bandwidth))
)
noise = GaussianNoise(sigma=0.1)

# Set all Wishart process hyperparameters in Prior
truth_prior = Prior(
    input_dim=input_dim,
    basis_degree=basis_degree,
    extra_embedding_dims=extra_dims,
    decay_rate=decay_rate,
    variance_scale=variance_scale,
)
truth_model = WPPM(
    input_dim=input_dim,
    extra_dims=extra_dims,
    prior=truth_prior,
    task=task,
    noise=noise,
    diag_term=diag_term,
)

# Sample ground-truth Wishart process weights
truth_params = truth_model.init_params(jax.random.PRNGKey(123))

Simulate data
# Single reference point at the centre of the stimulus space.
ref_point = jnp.array([[0.0, 0.0]])  # shape (1, 2) — kept as a batch for generality

seed = 3
key = jr.PRNGKey(seed)

# Repeat the reference point for every trial.
refs = jnp.repeat(ref_point, repeats=NUM_TRIALS, axis=0)  # (NUM_TRIALS, 2)

# Evaluate Σ at the reference point.
truth_field = WPPMCovarianceField(truth_model, truth_params)
Sigmas_ref = truth_field(refs)  # (NUM_TRIALS, 2, 2)

# Sample unit directions and build covariance-scaled probe displacements.
k_dir, k_pred, k_y = jr.split(key, 3)
angles = jr.uniform(k_dir, shape=(NUM_TRIALS,), minval=0.0, maxval=2.0 * jnp.pi)
unit_dirs = jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=1)  # (N, 2)

# Constant Mahalanobis radius: probe = ref + MAHAL_RADIUS * chol(Σ_ref) @ unit_dir
MAHAL_RADIUS = 2.8
L = jnp.linalg.cholesky(Sigmas_ref)  # (N, 2, 2)
# location of comparisons = ref+delta
deltas = MAHAL_RADIUS * jnp.einsum("nij,nj->ni", L, unit_dirs)  # (N, 2)
comparisons = jnp.clip(refs + deltas, -1.0, 1.0)
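This constant-radius construction can be checked in isolation: with Σ = L Lᵀ and a unit vector u, the displacement δ = r L u satisfies δᵀ Σ⁻¹ δ = r² uᵀu = r², so every probe sits at Mahalanobis distance r from its reference. A standalone NumPy check with a toy covariance (values are illustrative only):

```python
import numpy as np

# Toy 2x2 covariance standing in for one Sigma_ref (illustrative values).
Sigma = np.array([[0.04, 0.01], [0.01, 0.02]])
L_chol = np.linalg.cholesky(Sigma)  # Sigma = L L^T

r = 2.8  # plays the role of MAHAL_RADIUS
u = np.array([np.cos(0.7), np.sin(0.7)])  # a unit direction

delta = r * L_chol @ u  # probe displacement, as in the jnp code above
mahal = np.sqrt(delta @ np.linalg.solve(Sigma, delta))  # Mahalanobis norm
# mahal recovers r up to floating point
```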

# Compute p(correct) via MC simulation of the oddity task.
trial_pred_keys = jr.split(k_pred, NUM_TRIALS)


def _p_correct_one(ref: jnp.ndarray, comp: jnp.ndarray, kk: jnp.ndarray) -> jnp.ndarray:
    return task._simulate_trial_mc(
        params=truth_params,
        ref=ref,
        comparison=comp,
        model=truth_model,
        noise=truth_model.noise,
        num_samples=int(task.config.num_samples),
        bandwidth=float(task.config.bandwidth),
        key=kk,
    )


p_correct = jax.vmap(_p_correct_one)(refs, comparisons, trial_pred_keys)

# Sample observed responses y ~ Bernoulli(p_correct).
ys = jr.bernoulli(k_y, p_correct, shape=(NUM_TRIALS,)).astype(jnp.int32)

data = TrialData(
    refs=refs, comparisons=comparisons, responses=ys
)  # contains 3 JAX arrays
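The response-sampling step can also be sanity-checked on its own: for a fixed p(correct), the empirical accuracy of many Bernoulli draws should concentrate around p. A standalone NumPy analogue of the jr.bernoulli call (the p value is illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
p = 0.75  # stands in for one entry of p_correct
ys_check = rng.binomial(1, p, size=100_000)  # 0/1 responses, like ys above
empirical = ys_check.mean()  # should sit close to p
```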

Build model and fit

Model definition
prior = Prior(
    input_dim=input_dim,
    basis_degree=basis_degree,
    extra_embedding_dims=extra_dims,
    decay_rate=decay_rate,
    variance_scale=variance_scale,
)
model = WPPM(
    input_dim=input_dim,
    prior=prior,
    task=task,
    noise=noise,
    diag_term=1e-4,
)
Fit with MAPOptimizer
# momentum and init_params come from the full quick_start.py script
map_optimizer = MAPOptimizer(
    steps=NUM_STEPS,
    learning_rate=learning_rate,
    momentum=momentum,
    track_history=True,
    log_every=1,
)

map_posterior = map_optimizer.fit(model, data, init_params=init_params, seed=4)

Results

Ground truth (black), prior sample (blue), and MAP-fitted (red) covariance ellipses at the single reference point.

Learning curve

Negative log-likelihood over optimizer steps.
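For reference, ellipses like those in the figure can be traced directly from a 2x2 covariance: with Σ = L Lᵀ, the points center + scale · L [cos t, sin t]ᵀ lie on the level set where the Mahalanobis distance from center equals scale. A standalone NumPy sketch (toy covariance values; plotting with matplotlib is left to the reader):

```python
import numpy as np

def covariance_ellipse(Sigma, center=(0.0, 0.0), scale=1.0, num=100):
    """Points of shape (2, num) tracing (x - c)^T Sigma^{-1} (x - c) = scale^2."""
    t = np.linspace(0.0, 2.0 * np.pi, num)
    circle = np.stack([np.cos(t), np.sin(t)])  # unit circle, shape (2, num)
    L_chol = np.linalg.cholesky(Sigma)         # Sigma = L L^T
    return np.asarray(center)[:, None] + scale * L_chol @ circle

# Toy covariance with the same Mahalanobis radius used above (illustrative values)
pts = covariance_ellipse(np.array([[0.04, 0.01], [0.01, 0.02]]), scale=2.8)
```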