# Fixed seed so the simulated dataset below can be regenerated exactly.
seed = 3
key = jr.PRNGKey(seed)

# One reference location (the origin), stored with a leading batch axis: shape (1, 2).
ref_point = jnp.array([[0.0, 0.0]])

# Every trial uses that same reference, so tile it along the batch axis.
refs = jnp.tile(ref_point, (NUM_TRIALS, 1))  # (NUM_TRIALS, 2)

# Ground-truth covariance field, evaluated at each trial's reference location.
truth_field = WPPMCovarianceField(truth_model, truth_params)
Sigmas_ref = truth_field(refs)  # expected shape (NUM_TRIALS, 2, 2)
# Mahalanobis radius applied to every probe displacement.
MAHAL_RADIUS = 2.8

# Separate PRNG keys for: probe directions, MC prediction, response sampling.
k_dir, k_pred, k_y = jr.split(key, num=3)

# One angle per trial, uniform on [0, 2*pi), mapped onto the unit circle.
angles = jr.uniform(k_dir, (NUM_TRIALS,), minval=0.0, maxval=2.0 * jnp.pi)
unit_dirs = jnp.stack((jnp.cos(angles), jnp.sin(angles)), axis=-1)  # (N, 2)

# Cholesky factor of each reference covariance; it maps unit directions
# into covariance-scaled displacements.
L = jnp.linalg.cholesky(Sigmas_ref)  # (N, 2, 2)

# Per trial: delta = MAHAL_RADIUS * (L @ unit_dir).
deltas = jnp.einsum("bij,bj->bi", L, unit_dirs) * MAHAL_RADIUS  # (N, 2)

# Comparison location = reference + displacement, clamped to [-1, 1] per coordinate.
# NOTE: clamping shortens any displacement that would otherwise leave that box,
# so those probes no longer sit at exactly MAHAL_RADIUS.
comparisons = jnp.clip(refs + deltas, -1.0, 1.0)
# Estimate p(correct) per trial by Monte Carlo simulation of the oddity task.
# Each trial receives its own PRNG key.
trial_pred_keys = jr.split(k_pred, num=NUM_TRIALS)


def _p_correct_one(ref: jnp.ndarray, comp: jnp.ndarray, kk: jnp.ndarray) -> jnp.ndarray:
    """Run the task's MC trial simulation for one (reference, comparison) pair.

    Uses the ground-truth model, parameters and noise, with the sample count
    and bandwidth taken from ``task.config``.

    Args:
        ref: Reference stimulus for this trial.
        comp: Comparison stimulus for this trial.
        kk: PRNG key driving this trial's simulation.

    Returns:
        The value produced by ``task._simulate_trial_mc``; the caller uses it
        as the probability of a correct response.
    """
    mc_config = task.config
    estimate = task._simulate_trial_mc(
        params=truth_params,
        ref=ref,
        comparison=comp,
        model=truth_model,
        noise=truth_model.noise,
        num_samples=int(mc_config.num_samples),
        bandwidth=float(mc_config.bandwidth),
        key=kk,
    )
    return estimate


# Map the single-trial simulation over the leading (trial) axis of all three inputs.
p_correct = jax.vmap(_p_correct_one, in_axes=(0, 0, 0))(refs, comparisons, trial_pred_keys)
# Draw one binary response per trial, y ~ Bernoulli(p_correct), stored as int32 (0 or 1).
ys = jr.bernoulli(k_y, p=p_correct, shape=(NUM_TRIALS,))
ys = ys.astype(jnp.int32)

# Bundle the three per-trial JAX arrays into a single TrialData container.
data = TrialData(
    refs=refs,
    comparisons=comparisons,
    responses=ys,
)