# Ground-truth prior: all Wishart-process hyperparameters are collected in
# Prior. This instance is used only to sample the ground-truth weights below.
truth_prior = Prior(
input_dim=input_dim, # stimulus dimensionality (2D)
basis_degree=basis_degree, # degree of the basis expansion (5)
extra_embedding_dims=extra_dims, # extra embedding dimensions (1)
decay_rate=decay_rate, # decay rate of the basis functions
variance_scale=variance_scale, # how big covariance matrices
# are before fitting
)
# Ground-truth model used to simulate observer responses; its weights are
# sampled from truth_prior below.
truth_model = WPPM(
input_dim=input_dim,
extra_dims=extra_dims,
prior=truth_prior,
task=task, # oddity task ("pick the odd one out among 3 stimuli")
noise=noise, # observation noise model (Gaussian noise)
diag_term=diag_term, # jitter added to keep covariances positive-definite
)
# Draw the ground-truth Wishart-process weights from the prior.
truth_params = truth_model.init_params(jax.random.PRNGKey(123))
# Container for the simulated trials (filled in further below).
data = ResponseData()
# Number of trials collected at each reference location.
# NOTE(review): the constant is named NUM_TRIALS_TOTAL but is used here as a
# per-reference count — confirm the intended semantics.
num_trials_per_ref = NUM_TRIALS_TOTAL
# Reference points: a regular n x n grid over [-1, 1]^2 (NUM_GRID_PTS per axis).
n_ref_grid = 5
ref_grid = jnp.linspace(-1, 1, n_ref_grid)
grid_x, grid_y = jnp.meshgrid(ref_grid, ref_grid)
ref_points = jnp.column_stack([grid_x.ravel(), grid_y.ravel()])
# --- Stimulus design: covariance-scaled probe displacements ---
#
# Probes are displaced from each reference along sqrt(Σ_ref) instead of at a
# fixed Euclidean radius. Scaling by the local covariance keeps the
# Mahalanobis radius — and hence trial difficulty — roughly constant across
# the stimulus space.
seed = 3
key = jr.PRNGKey(seed)
# Repeat every grid point so that each reference receives its block of trials.
n_ref = ref_points.shape[0]
refs = jnp.repeat(ref_points, repeats=num_trials_per_ref, axis=0)  # (N, 2)
num_trials_total = int(refs.shape[0])
# Batched evaluation of the ground-truth covariance field Σ(ref).
truth_field = WPPMCovarianceField(truth_model, truth_params)
Sigmas_ref = truth_field(refs)  # (N, 2, 2)
# Independent RNG streams: probe directions, trial predictions, responses.
k_dir, k_pred, k_y = jr.split(key, 3)
# Uniform angles -> unit direction vectors on the circle.
angles = jr.uniform(k_dir, shape=(num_trials_total,), minval=0.0, maxval=2.0 * jnp.pi)
unit_dirs = jnp.column_stack([jnp.cos(angles), jnp.sin(angles)])  # (N, 2)
# Probe displacement follows the local ellipse at a fixed Mahalanobis radius:
#   probe = ref + MAHAL_RADIUS * chol(Σ_ref) @ unit_dir
MAHAL_RADIUS = 2.8
# Matrix square root via Cholesky; Σ should be SPD (diag_term/noise keep it stable).
L = jnp.linalg.cholesky(Sigmas_ref)  # (N, 2, 2)
deltas = MAHAL_RADIUS * jnp.squeeze(L @ unit_dirs[:, :, None], axis=-1)  # (N, 2)
comparisons = jnp.clip(refs + deltas, -1.0, 1.0)
# Predict p(correct) for all trials by vmapping the single-trial simulator.
trial_pred_keys = jr.split(k_pred, num_trials_total)

def _predict_one(ref: jnp.ndarray, comp: jnp.ndarray, kk: jnp.ndarray) -> jnp.ndarray:
    """Monte-Carlo p(correct) for a single (reference, comparison) pair.

    The MC settings (num_samples / bandwidth) are fixed values taken from
    OddityTaskConfig; only the PRNG key varies across trials.
    """
    return task._simulate_trial_mc(
        params=truth_params,
        ref=ref,
        comparison=comp,
        model=truth_model,
        noise=truth_model.noise,
        num_samples=int(task.config.num_samples),
        bandwidth=float(task.config.bandwidth),
        key=kk,
    )

p_correct = jax.vmap(_predict_one)(refs, comparisons, trial_pred_keys)
# Observed binary responses: y ~ Bernoulli(p_correct), drawn in one batch.
ys = jr.bernoulli(k_y, p_correct, shape=(num_trials_total,)).astype(jnp.int32)
# Populate ResponseData. Transfer each array to the host once, then fill the
# Python-side container trial by trial (ResponseData is a Python container,
# so the loop itself stays in Python).
host_refs = jax.device_get(refs)
host_comps = jax.device_get(comparisons)
host_ys = jax.device_get(ys)
for r, c, y in zip(host_refs, host_comps, host_ys):
    data.add_trial(ref=jnp.array(r), comparison=jnp.array(c), resp=int(y))
# Prior for the model to be fitted; hyperparameters match the ground-truth
# prior used for simulation above.
prior = Prior(
input_dim=input_dim, # stimulus dimensionality (2D)
basis_degree=basis_degree, # degree of the basis expansion (5)
extra_embedding_dims=extra_dims, # extra embedding dimensions (1)
decay_rate=decay_rate, # decay rate of the basis functions (how quickly they vary)
variance_scale=variance_scale, # how big covariance matrices
# are before fitting
)
# Model to be fitted to the simulated data.
model = WPPM(
input_dim=input_dim,
prior=prior,
task=task,
noise=noise,
# NOTE(review): truth_model above passes diag_term=diag_term and also sets
# extra_dims; here diag_term is hard-coded to 1e-4 and extra_dims is
# omitted — confirm whether this asymmetry is intentional.
diag_term=1e-4, # ensure positive-definite covariances
)