psyphy

Psychophysical modeling and adaptive trial placement.

This package implements the Wishart Process Psychophysical Model (WPPM) with modular components for priors, task likelihoods, and noise models. The model can be fitted to incoming subject data and used to adaptively select the next trials to present, which makes it possible to estimate psychophysical quantities (e.g. threshold contours) efficiently, from relatively few trials.


Workflow
Core design
  1. WPPM (model/wppm.py):
     • Structural definition of the psychophysical model.
     • Maintains parameterization of local covariance fields.
     • Computes discriminability between stimuli.
     • Delegates trial likelihoods and predictions to the task.

  2. Prior (model/prior.py):
     • Defines the distribution over model parameters.
     • WPPM: structured prior over basis weights and decay_rate-controlled covariance fields.

  3. TaskLikelihood (model/task.py):
     • Encodes the psychophysical decision rule.
     • WPPM: loglik and predict implemented via Monte Carlo observer simulations, using the noise model explicitly.

  4. NoiseModel (model/noise.py):
     • Defines the distribution of internal representation noise.
     • WPPM: GaussianNoise or StudentTNoise option.
Unified import style

Top-level (core models + session):

from psyphy import WPPM, Prior, OddityTask, GaussianNoise, MAPOptimizer
from psyphy import ExperimentSession, ResponseData, TrialBatch

Subpackages:

from psyphy.model import WPPM, Prior, OddityTask, GaussianNoise, StudentTNoise
from psyphy.inference import MAPOptimizer, LangevinSampler, LaplaceApproximation
from psyphy.posterior import Posterior, effective_sample_size, rhat
from psyphy.acquisition import expected_improvement, upper_confidence_bound, mutual_information
from psyphy.acquisition import optimize_acqf, optimize_acqf_discrete, optimize_acqf_random
from psyphy.trial_placement import GridPlacement, SobolPlacement
from psyphy.utils import grid_candidates, sobol_candidates, custom_candidates, chebyshev_basis
from psyphy.utils import bootstrap_predictions, bootstrap_statistic, bootstrap_compare_models

Data flow
  • A ResponseData object (psyphy.data) contains trial stimuli and responses.
  • WPPM.init_params(prior) samples parameter initialization.
  • Inference engines optimize the log posterior: log_posterior = task.loglik(params, data, model=WPPM, noise=NoiseModel) + prior.log_prob(params)
  • Posterior predictions (p(correct), threshold ellipses) are always obtained through WPPM delegating to TaskLikelihood.
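A minimal sketch of this flow, using components documented on this page (illustrative hyperparameters and a single toy trial):

import jax.numpy as jnp

from psyphy.model import WPPM, Prior, OddityTask, GaussianNoise
from psyphy.inference import MAPOptimizer
from psyphy.data.dataset import ResponseData

# 1. Collected trials live in a ResponseData container.
data = ResponseData()
data.add_trial(jnp.array([0.0, 0.0]), jnp.array([0.3, 0.2]), resp=1)

# 2. The model bundles prior, task, and noise components.
model = WPPM(
    input_dim=2,
    prior=Prior(input_dim=2, basis_degree=3),
    task=OddityTask(),
    noise=GaussianNoise(sigma=0.03),
)

# 3. The inference engine maximizes loglik(params, data) + prior.log_prob(params).
posterior = MAPOptimizer(steps=200).fit(model, data, seed=0)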
Extensibility
  • To add a new task: subclass TaskLikelihood and implement predict/loglik.
  • To add a new noise model: subclass NoiseModel and implement log_prob/sample_standard (the interface used by GaussianNoise and StudentTNoise). A sketch of both extension points follows.
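Below is a minimal sketch of both extension points. The base-class import paths (psyphy.model.task.TaskLikelihood, psyphy.model.noise.NoiseModel) are assumed from the module layout above, the class names TwoAFCTask and LaplaceNoise are hypothetical, and only the documented method names are implemented:

import jax
import jax.numpy as jnp
import jax.random as jr

from psyphy.model.task import TaskLikelihood   # import path assumed
from psyphy.model.noise import NoiseModel      # import path assumed


class TwoAFCTask(TaskLikelihood):
    """Hypothetical two-alternative task illustrating the required interface."""

    def predict(self, params, stimuli, model, noise):
        ref, comparison = stimuli
        # Toy link from stimulus separation to p(correct); a real task would use
        # the model's covariance field rather than Euclidean distance.
        d = jnp.linalg.norm(comparison - ref)
        return 0.5 + 0.5 * (1.0 - jnp.exp(-d))

    def loglik(self, params, data, model, noise, **kwargs):
        refs, comparisons, responses = data.to_numpy()
        probs = jnp.array(
            [self.predict(params, (r, c), model, noise) for r, c in zip(refs, comparisons)]
        )
        return jnp.sum(jnp.where(responses == 1, jnp.log(probs), jnp.log(1.0 - probs)))


class LaplaceNoise(NoiseModel):
    """Hypothetical heavy-tailed noise model implementing the documented interface."""

    def __init__(self, scale: float = 1.0):
        self.scale = scale

    def log_prob(self, residual: float) -> float:
        # Laplace log-density with zero mean and scale self.scale.
        return float(-jnp.abs(residual) / self.scale - jnp.log(2.0 * self.scale))

    def sample_standard(self, key: jax.Array, shape: tuple[int, ...]) -> jax.Array:
        return jr.laplace(key, shape)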

Classes:

Name Description
ExperimentSession

High-level experiment orchestrator.

GaussianNoise
LangevinSampler

Langevin sampler (stub).

LaplaceApproximation

Laplace approximation around MAP estimate.

MAPOptimizer

MAP (Maximum A Posteriori) optimizer.

OddityTask

Three-alternative forced-choice oddity task (MC-based only).

Prior

Prior distribution over WPPM parameters

ResponseData

Container for psychophysical trial data.

StudentTNoise
TrialBatch

Container for a proposed batch of trials

WPPM

Wishart Process Psychophysical Model (WPPM).

Attributes:

Name Type Description
Posterior

Posterior

Posterior = MAPPosterior

ExperimentSession

ExperimentSession(
    model, inference, placement, init_placement=None
)

High-level experiment orchestrator.

Parameters:

Name Type Description Default
model WPPM

(Psychophysical) model instance.

required
inference InferenceEngine

Inference engine (MAP, Langevin, etc.).

required
placement TrialPlacement

Adaptive trial placement strategy.

required
init_placement TrialPlacement

Initial placement strategy (e.g., Sobol exploration).

None

Attributes:

Name Type Description
data ResponseData

Stores all collected trials.

posterior Posterior or None

Current posterior estimate (None before initialization).

Methods:

Name Description
initialize

Fit an initial posterior before any adaptive placement.

next_batch

Propose the next batch of trials.

update

Refit posterior with accumulated data.

Source code in src/psyphy/session/experiment_session.py
def __init__(self, model, inference, placement, init_placement=None):
    self.model = model
    self.inference = inference
    self.placement = placement
    self.init_placement = init_placement

    # Data store starts empty
    self.data = ResponseData()

    # Posterior will be set after initialize() or update()
    self.posterior = None

data

data = ResponseData()

inference

inference = inference

init_placement

init_placement = init_placement

model

model = model

placement

placement = placement

posterior

posterior = None

initialize

initialize()

Fit an initial posterior before any adaptive placement.

Returns:

Type Description
Posterior

Posterior object wrapping fitted parameters.

Notes

MVP: the posterior is fitted to empty data (prior only). Full WPPM mode: could use pilot data or pre-collected trials (e.g. along a grid).

Source code in src/psyphy/session/experiment_session.py
def initialize(self):
    """
    Fit an initial posterior before any adaptive placement.

    Returns
    -------
    Posterior
        Posterior object wrapping fitted parameters.

    Notes
    -----
    MVP:
        Posterior is fitted to empty data (prior only).
    Full WPPM mode:
        Could use pilot data or pre-collected trials along grid etc.
    """
    self.posterior = self.inference.fit(self.model, self.data)
    return self.posterior

next_batch

next_batch(batch_size: int)

Propose the next batch of trials.

Parameters:

Name Type Description Default
batch_size int

Number of trials to propose.

required

Returns:

Type Description
TrialBatch

Batch of proposed (reference, probe) stimuli.

Notes

MVP: Always calls placement.propose() on current posterior. Full WPPM mode: Could support hybrid placement (init strategy -> adaptive strategy).

Source code in src/psyphy/session/experiment_session.py
def next_batch(self, batch_size: int):
    """
    Propose the next batch of trials.

    Parameters
    ----------
    batch_size : int
        Number of trials to propose.

    Returns
    -------
    TrialBatch
        Batch of proposed (reference, probe) stimuli.

    Notes
    -----
    MVP:
        Always calls placement.propose() on current posterior.
    Full WPPM mode:
        Could support hybrid placement (init strategy -> adaptive strategy).
    """
    if self.posterior is None:
        raise RuntimeError("Posterior not initialized. Call initialize() first.")
    return self.placement.propose(self.posterior, batch_size)

update

update()

Refit posterior with accumulated data.

Returns:

Type Description
Posterior

Updated posterior.

Notes

MVP: Re-optimizes from scratch using all data. Full WPPM mode: Could support warm-start or online parameter updates.

Source code in src/psyphy/session/experiment_session.py
def update(self):
    """
    Refit posterior with accumulated data.

    Returns
    -------
    Posterior
        Updated posterior.

    Notes
    -----
    MVP:
        Re-optimizes from scratch using all data.
    Full WPPM mode:
        Could support warm-start or online parameter updates.
    """
    self.posterior = self.inference.fit(self.model, self.data)
    return self.posterior
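A sketch of the full session loop. SobolPlacement's constructor arguments are assumed, and run_trials stands in for a hypothetical experimenter-side function that presents the proposed trials and collects subject responses:

from psyphy import (
    WPPM, Prior, OddityTask, GaussianNoise, MAPOptimizer, ExperimentSession,
)
from psyphy.trial_placement import SobolPlacement

model = WPPM(
    input_dim=2,
    prior=Prior(input_dim=2, basis_degree=3),
    task=OddityTask(),
    noise=GaussianNoise(sigma=0.03),
)
session = ExperimentSession(
    model=model,
    inference=MAPOptimizer(steps=200),
    placement=SobolPlacement(),  # constructor arguments assumed
)

session.initialize()                          # prior-only posterior before adaptive placement
for _ in range(5):
    batch = session.next_batch(batch_size=20)
    responses = run_trials(batch)             # hypothetical: present trials, collect 0/1 responses
    session.data.add_batch(responses, batch)
    session.update()                          # refit posterior with all accumulated data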

GaussianNoise

GaussianNoise(sigma: float = 1.0)

Methods:

Name Description
log_prob
sample_standard

Sample from standard Gaussian (mean=0, var=1).

Attributes:

Name Type Description
sigma float

sigma

sigma: float = 1.0

log_prob

log_prob(residual: float) -> float
Source code in src/psyphy/model/noise.py
def log_prob(self, residual: float) -> float:
    _ = residual
    return -0.5

sample_standard

sample_standard(
    key: Array, shape: tuple[int, ...]
) -> Array

Sample from standard Gaussian (mean=0, var=1).

Source code in src/psyphy/model/noise.py
def sample_standard(self, key: jax.Array, shape: tuple[int, ...]) -> jax.Array:
    """Sample from standard Gaussian (mean=0, var=1)."""
    return jr.normal(key, shape)
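A brief usage sketch (illustrative shapes); sample_standard returns unit-variance draws, while sigma is kept as an attribute:

import jax.random as jr
from psyphy.model import GaussianNoise

noise = GaussianNoise(sigma=0.03)
eps = noise.sample_standard(jr.PRNGKey(0), (1000, 2))  # standard normal draws, shape (1000, 2)
print(float(eps.mean()), float(eps.std()))             # approximately 0 and 1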

LangevinSampler

LangevinSampler(
    steps: int = 1000,
    step_size: float = 0.001,
    temperature: float = 1.0,
)

Langevin sampler (stub).

Parameters:

Name Type Description Default
steps int

Number of Langevin steps.

1000
step_size float

Integration step size.

1e-3
temperature float

Noise scale (temperature).

1.0

Methods:

Name Description
fit

Fit model parameters with Langevin dynamics (stub).

Attributes:

Name Type Description
step_size
steps
temperature
Source code in src/psyphy/inference/langevin.py
def __init__(
    self, steps: int = 1000, step_size: float = 1e-3, temperature: float = 1.0
):
    self.steps = steps
    self.step_size = step_size
    self.temperature = temperature

step_size

step_size = step_size

steps

steps = steps

temperature

temperature = temperature

fit

fit(model, data) -> Posterior

Fit model parameters with Langevin dynamics (stub).

Parameters:

Name Type Description Default
model WPPM

Model instance.

required
data ResponseData

Observed trials.

required

Returns:

Type Description
Posterior

Posterior wrapper (MVP: params from init).

Source code in src/psyphy/inference/langevin.py
def fit(self, model, data) -> Posterior:
    """
    Fit model parameters with Langevin dynamics (stub).

    Parameters
    ----------
    model : WPPM
        Model instance.
    data : ResponseData
        Observed trials.

    Returns
    -------
    Posterior
        Posterior wrapper (MVP: params from init).
    """
    return Posterior(params=model.init_params(None), model=model)

LaplaceApproximation

Laplace approximation around MAP estimate.

Methods:

Name Description
from_map

Construct a Gaussian approximation centered at MAP.

from_map

from_map(map_posterior: Posterior) -> Posterior

Return posterior approximation from MAP.

Parameters:

Name Type Description Default
map_posterior Posterior

Posterior object from MAP optimization.

required

Returns:

Type Description
Posterior

Same posterior object (MVP).

Source code in src/psyphy/inference/laplace.py
def from_map(self, map_posterior: Posterior) -> Posterior:
    """
    Return posterior approximation from MAP.

    Parameters
    ----------
    map_posterior : Posterior
        Posterior object from MAP optimization.

    Returns
    -------
    Posterior
        Same posterior object (MVP).
    """
    return map_posterior

MAPOptimizer

MAPOptimizer(
    steps: int = 500,
    learning_rate: float = 5e-05,
    momentum: float = 0.9,
    optimizer: GradientTransformation | None = None,
    *,
    track_history: bool = True,
    log_every: int = 1,
    progress_every: int = 10,
    show_progress: bool = False,
    max_grad_norm: float | None = 1.0,
)

Bases: InferenceEngine

MAP (Maximum A Posteriori) optimizer.

Parameters:

Name Type Description Default
steps int

Number of optimization steps.

500
optimizer GradientTransformation

Optax optimizer to use. Default: SGD with momentum.

None
Notes
  • Loss function = negative log posterior.
  • Gradients computed with jax.grad.

Create a MAP optimizer.

Parameters:

Name Type Description Default
steps int

Number of optimization steps.

500
optimizer GradientTransformation | None

Optax optimizer to use.

None
learning_rate float

Learning rate for the default optimizer (SGD with momentum).

5e-05
momentum float

Momentum for the default optimizer (SGD with momentum).

0.9
track_history bool

When True, record loss history during fitting for plotting.

True
log_every int

Record every N steps (also records the last step).

1
progress_every int

Update the progress-bar loss display every N steps (and the last step) when show_progress=True. This is kept separate from log_every so you can record loss at high frequency for plotting (e.g. log_every=1) without forcing a device->host sync for the progress UI every step.

10
show_progress bool

When True, display a tqdm progress bar during fitting. This is a UI feature: if tqdm is not installed, fitting proceeds without a progress bar.

False
max_grad_norm float | None

If set, clip gradients by global norm to this value before applying optimizer updates. This stabilizes optimization when gradients blow up.

1.0

Methods:

Name Description
fit

Fit model parameters with MAP optimization.

get_history

Return (steps, losses) recorded during the last fit when tracking was enabled.

Attributes:

Name Type Description
log_every
loss_history list[float]
loss_steps list[int]
max_grad_norm
optimizer
progress_every
show_progress
steps
track_history
Source code in src/psyphy/inference/map_optimizer.py
def __init__(
    self,
    steps: int = 500,
    learning_rate: float = 5e-5,
    momentum: float = 0.9,
    optimizer: optax.GradientTransformation | None = None,
    *,
    track_history: bool = True,
    log_every: int = 1,
    progress_every: int = 10,
    show_progress: bool = False,
    max_grad_norm: float | None = 1.0,
):
    """Create a MAP optimizer.

    Parameters
    ----------
    steps : int
        Number of optimization steps.
    optimizer : optax.GradientTransformation | None
        Optax optimizer to use.
    learning_rate : float, optional
        Learning rate for the default optimizer (SGD with momentum).
    momentum : float, optional
        Momentum for the default optimizer (SGD with momentum).
    track_history : bool, optional
        When True, record loss history during fitting for plotting.
    log_every : int, optional
        Record every N steps (also records the last step).
    progress_every : int, optional
        Update the progress-bar loss display every N steps (and the last step)
        when show_progress=True.
        This is kept separate from log_every so you can record loss at high
        frequency for plotting (e.g. log_every=1) without forcing a device->host
        sync for the progress UI every step.
    show_progress : bool, optional
        When True, display a tqdm progress bar during fitting.
        This is a UI feature: if tqdm is not installed,
        fitting proceeds without a progress bar.
    max_grad_norm : float | None, optional
        If set, clip gradients by global norm to this value before applying
        optimizer updates. This stabilizes optimization when gradients blow up.
    """
    self.steps = steps
    base_optimizer = optimizer or optax.sgd(
        learning_rate=learning_rate, momentum=momentum
    )
    if max_grad_norm is None:
        self.optimizer = base_optimizer
    else:
        self.optimizer = optax.chain(
            optax.clip_by_global_norm(float(max_grad_norm)),
            base_optimizer,
        )

    self.track_history = track_history
    self.log_every = max(1, int(log_every))
    self.progress_every = max(1, int(progress_every))
    self.show_progress = bool(show_progress)
    self.max_grad_norm = max_grad_norm
    # Exposed after fit() when tracking is enabled
    self.loss_steps: list[int] = []
    self.loss_history: list[float] = []

log_every

log_every = max(1, int(log_every))

loss_history

loss_history: list[float] = []

loss_steps

loss_steps: list[int] = []

max_grad_norm

max_grad_norm = max_grad_norm

optimizer

optimizer = base_optimizer

progress_every

progress_every = max(1, int(progress_every))

show_progress

show_progress = bool(show_progress)

steps

steps = steps

track_history

track_history = track_history

fit

fit(
    model,
    data,
    init_params: dict | None = None,
    seed: int | None = None,
) -> MAPPosterior

Fit model parameters with MAP optimization.

Parameters:

Name Type Description Default
model WPPM

Model instance.

required
data ResponseData

Observed trials.

required
init_params dict | None

Initial parameter PyTree to start optimization from. If provided, this takes precedence over the seed.

None
seed int | None

PRNG seed used to draw initial parameters from the model's prior when init_params is not provided. If None, defaults to 0.

None

Returns:

Type Description
MAPPosterior

Posterior wrapper around MAP params and model.

Source code in src/psyphy/inference/map_optimizer.py
def fit(
    self,
    model,
    data,
    init_params: dict | None = None,
    seed: int | None = None,
) -> MAPPosterior:
    """
    Fit model parameters with MAP optimization.

    Parameters
    ----------
    model : WPPM
        Model instance.
    data : ResponseData
        Observed trials.
    init_params : dict | None, optional
        Initial parameter PyTree to start optimization from. If provided,
        this takes precedence over the seed.
    seed : int | None, optional
        PRNG seed used to draw initial parameters from the model's prior
        when init_params is not provided. If None, defaults to 0.

    Returns
    -------
    MAPPosterior
        Posterior wrapper around MAP params and model.
    """

    def loss_fn(params):
        return -model.log_posterior_from_data(params, data)

    # Initialize parameters
    if init_params is not None:
        params = init_params
    else:
        rng_seed = 0 if seed is None else int(seed)
        params = model.init_params(jax.random.PRNGKey(rng_seed))
    opt_state = self.optimizer.init(params)

    @jax.jit
    def step(params, opt_state):
        # Ensure params and opt_state are JAX PyTrees for JIT compatibility
        loss, grads = jax.value_and_grad(loss_fn)(params)  # auto-diff
        updates, opt_state = self.optimizer.update(
            grads, opt_state, params
        )  # optimizer update
        params = optax.apply_updates(params, updates)  # apply updates
        # Only return JAX-compatible types (PyTrees of arrays, scalars)
        return params, opt_state, loss

    # clear any previous history
    if self.track_history:
        self.loss_steps.clear()
        self.loss_history.clear()

    # Optional progress bar.
    #
    # Why we *manually* advance the bar:
    # - When JAX runs on GPU, the first `step(...)` call can spend a long time in
    #   compilation, and tqdm may not visibly advance if the underlying iterator
    #   doesn't get a chance to redraw.
    # - By keeping a normal `range(self.steps)` loop and calling `pbar.update(1)`
    #   ourselves, we ensure the bar advances exactly once per iteration.
    #
    # Performance note: *displaying the loss* requires transferring `loss` from
    # device -> host, which can add sync overhead. We therefore only attach a
    # loss postfix every `progress_every` steps.
    pbar = None
    if self.show_progress:
        try:
            from tqdm.auto import tqdm

            pbar = tqdm(total=self.steps, desc="MAP fit", leave=False)
        except Exception:
            # Soft dependency: tqdm not available (or terminal unsuitable).
            pbar = None

    for i in range(self.steps):
        params, opt_state, loss = step(params, opt_state)

        # Non-finite guard: if loss becomes NaN/Inf, optimization has diverged.
        # Stop early so downstream plots don’t look “truncated” due to NaNs.
        if not bool(jax.numpy.isfinite(loss)):
            if self.track_history:
                try:
                    self.loss_steps.append(i)
                    self.loss_history.append(float(loss))
                except Exception:
                    pass
            print(
                f"[MAPOptimizer] Non-finite loss at step {i}: {loss}. "
                "Stopping early."
            )
            if pbar is not None:
                with contextlib.suppress(Exception):
                    pbar.update(1)
            break

        if self.track_history and (
            (i % self.log_every == 0) or (i == self.steps - 1)
        ):
            # Pull scalar to host and record
            try:
                self.loss_steps.append(i)
                self.loss_history.append(float(loss))
            except Exception:
                #  do not break fitting if logging fails
                pass

        #  progress bar loss display (avoid host sync every step)
        if pbar is not None and (
            (i % self.progress_every == 0) or (i == self.steps - 1)
        ):
            with contextlib.suppress(Exception):
                pbar.set_postfix(loss=float(loss))

        if pbar is not None:
            with contextlib.suppress(Exception):
                pbar.update(1)
                # Encourage a redraw occasionally in environments with buffered/stale
                # TTY updates.
                if (i % self.progress_every == 0) or (i == self.steps - 1):
                    pbar.refresh()

    if pbar is not None:
        with contextlib.suppress(Exception):
            pbar.close()

    return MAPPosterior(params=params, model=model)

get_history

get_history() -> tuple[list[int], list[float]]

Return (steps, losses) recorded during the last fit when tracking was enabled.

Source code in src/psyphy/inference/map_optimizer.py
def get_history(self) -> tuple[list[int], list[float]]:
    """Return (steps, losses) recorded during the last fit when tracking was enabled."""
    return self.loss_steps, self.loss_history
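A usage sketch combining fit and get_history; the matplotlib plotting is illustrative and not part of psyphy:

import jax.numpy as jnp
import matplotlib.pyplot as plt

from psyphy.model import WPPM, Prior, OddityTask, GaussianNoise
from psyphy.inference import MAPOptimizer
from psyphy.data.dataset import ResponseData

model = WPPM(
    input_dim=2,
    prior=Prior(input_dim=2, basis_degree=3),
    task=OddityTask(),
    noise=GaussianNoise(sigma=0.03),
)
data = ResponseData()
data.add_trial(jnp.array([0.0, 0.0]), jnp.array([0.3, 0.2]), resp=1)

opt = MAPOptimizer(steps=300, learning_rate=5e-5, show_progress=True)
posterior = opt.fit(model, data, seed=0)

steps, losses = opt.get_history()  # recorded because track_history=True by default
plt.plot(steps, losses)
plt.xlabel("step")
plt.ylabel("negative log posterior")
plt.show()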

OddityTask

OddityTask(config: OddityTaskConfig | None = None)

Bases: TaskLikelihood

Three-alternative forced-choice oddity task (MC-based only).

Implements the full 3-stimulus oddity task using Monte Carlo simulation:
  • Samples three internal representations per trial (z0, z1, z2)
  • Uses the proper oddity decision rule with three pairwise distances
  • Suitable for complex covariance structures

Notes

MC simulation in loglik() (full 3-stimulus oddity):
  1. Sample three internal representations: z_ref, z_refprime ~ N(ref, Σ_ref), z_comparison ~ N(comparison, Σ_comparison)
  2. Compute the average covariance: Σ_avg = (2/3) Σ_ref + (1/3) Σ_comparison
  3. Compute three pairwise Mahalanobis distances:
     • d^2(z_ref, z_refprime): distance between the two reference samples
     • d^2(z_ref, z_comparison): distance from ref to comparison
     • d^2(z_refprime, z_comparison): distance from reference_prime to comparison
  4. Apply the oddity decision rule: delta = min(d^2(z_ref, z_comparison), d^2(z_refprime, z_comparison)) - d^2(z_ref, z_refprime)
  5. Logistic smoothing: P(correct) ≈ logistic.cdf(delta / bandwidth)
  6. Average over samples
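The same steps, written as a self-contained numerical sketch for a single trial. The covariances here are fixed isotropic placeholders rather than the WPPM covariance field, and the logistic CDF is written as a sigmoid:

import jax
import jax.numpy as jnp
import jax.random as jr

def p_correct_mc(key, ref, comparison, cov_ref, cov_cmp, num_samples=1000, bandwidth=1e-2):
    """Illustrative MC estimate of P(correct) for one oddity trial with fixed covariances."""
    dim = ref.shape[0]
    cov_avg = (2.0 / 3.0) * cov_ref + (1.0 / 3.0) * cov_cmp
    prec = jnp.linalg.inv(cov_avg)

    k1, k2, k3 = jr.split(key, 3)
    L_ref = jnp.linalg.cholesky(cov_ref)
    L_cmp = jnp.linalg.cholesky(cov_cmp)
    z_ref = ref + jr.normal(k1, (num_samples, dim)) @ L_ref.T         # first reference sample
    z_refprime = ref + jr.normal(k2, (num_samples, dim)) @ L_ref.T    # second reference sample
    z_cmp = comparison + jr.normal(k3, (num_samples, dim)) @ L_cmp.T  # comparison sample

    def maha(a, b):
        d = a - b
        return jnp.einsum("nd,de,ne->n", d, prec, d)  # squared Mahalanobis distance per sample

    delta = jnp.minimum(maha(z_ref, z_cmp), maha(z_refprime, z_cmp)) - maha(z_ref, z_refprime)
    return jnp.mean(jax.nn.sigmoid(delta / bandwidth))  # logistic smoothing, averaged over samples

cov = (0.03 ** 2) * jnp.eye(2)
p = p_correct_mc(jr.PRNGKey(0), jnp.array([0.0, 0.0]), jnp.array([0.3, 0.2]), cov, cov)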

Examples:

>>> from psyphy.model.task import OddityTask
>>> from psyphy.model.task import OddityTaskConfig
>>> from psyphy.model import WPPM, Prior
>>> from psyphy.model.noise import GaussianNoise
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>>
>>> # Create task and model (task-owned MC controls)
>>> task = OddityTask(config=OddityTaskConfig(num_samples=1000, bandwidth=1e-2))
>>> model = WPPM(
...     input_dim=2, prior=Prior(input_dim=2), task=task, noise=GaussianNoise()
... )
>>> params = model.init_params(jr.PRNGKey(0))
>>> # MC simulation on one illustrative trial
>>> from psyphy.data.dataset import ResponseData
>>> ref = jnp.array([0.0, 0.0])
>>> comparison = jnp.array([0.3, 0.2])
>>> data = ResponseData()
>>> data.add_trial(ref, comparison, resp=1)
>>> ll_mc = task.loglik(params, data, model, model.noise, key=jr.PRNGKey(42))
>>> print(f"Log-likelihood (MC): {ll_mc:.4f}")

Methods:

Name Description
loglik
Compute log-likelihood via Monte Carlo observer simulation.
predict

Predict p(correct) for a single (ref, comparison) stimulus.

Attributes:

Name Type Description
config
Source code in src/psyphy/model/task.py
def __init__(self, config: OddityTaskConfig | None = None) -> None:
    # No analytical parameters in MC-only mode.
    self.config = config or OddityTaskConfig()

config

config = config or OddityTaskConfig()

loglik

loglik(
    params: Any,
    data: Any,
    model: Any,
    noise: Any,
    **kwargs: Any,
) -> ndarray
Compute log-likelihood via Monte Carlo observer simulation.

This method implements the FULL 3-stimulus oddity task. Instead of using
an analytical approximation, we:
1. Sample three internal noisy representations per trial:
   - z_ref, z_refprime ~ N(ref, Σ_ref)  [two samples from reference]
   - z_comparison ~ N(comparison, Σ_comparison)           [one sample from comparison]
2. Compute three pairwise Mahalanobis distances
3. Apply oddity decision rule: comparison is odd if it's farther from BOTH ref and reference_prime
4. Apply logistic smoothing to approximate P(correct)
5. Average over MC samples

Parameters
params : Any
    Model parameters as expected by ``model._compute_sqrt``.
data : ResponseData
    Trial data with refs, comparisons, and responses
model : WPPM
    Model instance providing ``_compute_sqrt`` for covariance computation.
noise : NoiseModel
    Observer noise model (provides ``sample_standard``).
key : jax.random.PRNGKey, optional
    Random key for reproducible sampling.
    If None, uses ``OddityTaskConfig.default_key_seed``.

Returns
jnp.ndarray
    Scalar sum of log-likelihoods over all trials.
    Same shape and interpretation as ``loglik``.

Raises
TypeError
    If ``num_samples`` or ``bandwidth`` are provided as kwargs.
ValueError
    If the task configuration is invalid (e.g. ``num_samples <= 0``).

Notes
Monte Carlo controls (``num_samples``, ``bandwidth``) are owned by the
task configuration:

- Create the task with ``OddityTask(config=OddityTaskConfig(...))``.
- Pass only the PRNG ``key`` at call time when you want to control
  randomness.

Notes
**Full 3-stimulus oddity task algorithm:**

For each trial (ref, comparison, response):
1. Compute covariances:
   - Σ_ref = U_ref @ U_ref.T + σ^2 I
   - Σ_comparison = U_comparison @ U_comparison.T + σ^2 I
   - Σ_avg = (2/3) Σ_ref + (1/3) Σ_comparison  [weighted by stimulus frequency]

2. Sample three internal representations:
   - z_ref, z_refprime ~ N(ref, Σ_ref)  [2 samples from reference, num_samples times each]
   - z_comparison ~ N(comparison, Σ_comparison)           [1 sample from comparison, num_samples times]

3. Compute three pairwise Mahalanobis distances:
   - d^2(z_ref, z_refprime) = (z_ref - z_refprime).T @ Σ_avg^{-1} @ (z_ref - z_refprime)  [ref vs reference_prime]
   - d^2(z_ref, z_comparison) = (z_ref - z_comparison).T @ Σ_avg^{-1} @ (z_ref - z_comparison)  [ref vs comparison]
   - d^2(z_refprime, z_comparison) = (z_refprime - z_comparison).T @ Σ_avg^{-1} @ (z_refprime - z_comparison)  [reference_prime vs comparison]

4. Apply oddity decision rule:
   - delta = min(d^2(z_ref,z_comparison), d^2(z_refprime,z_comparison)) - d^2(z_ref,z_refprime)
   - delta > 0 means comparison is farther from BOTH ref and reference_prime -> correct identification

5. Apply logistic smoothing:
   - P(correct) ≈ mean(logistic.cdf(delta / bandwidth))

6. Bernoulli log-likelihood:
   - LL = Σ [y * log(p) + (1-y) * log(1-p)]

Performance:
- Memory: O(num_samples * input_dim) per trial
- Vectorized across trials using jax.vmap for GPU acceleration
- Can be JIT-compiled for additional speed (future optimization)

Examples
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> from psyphy.model import WPPM, Prior
>>> from psyphy.model.task import OddityTask, OddityTaskConfig
>>> from psyphy.model.noise import GaussianNoise
>>> from psyphy.data.dataset import ResponseData
>>>
>>> # Setup (MC controls live on the task config, not in loglik kwargs)
>>> model = WPPM(
...     input_dim=2,
...     prior=Prior(input_dim=2, basis_degree=3),
...     task=OddityTask(config=OddityTaskConfig(num_samples=5000, bandwidth=1e-3)),
...     noise=GaussianNoise(sigma=0.03),
... )
>>> params = model.init_params(jr.PRNGKey(0))
>>>
>>> # Create trial data
>>> data = ResponseData()
>>> data.add_trial(
...     ref=jnp.array([0.0, 0.0]), comparison=jnp.array([0.3, 0.2]), resp=1
... )
>>>
>>> loglik = model.task.loglik(params, data, model, model.noise, key=jr.PRNGKey(42))
>>> print(f"MC (N=5000): {loglik:.4f}")
Source code in src/psyphy/model/task.py
def loglik(
    self, params: Any, data: Any, model: Any, noise: Any, **kwargs: Any
) -> jnp.ndarray:
    """
        Compute log-likelihood via Monte Carlo observer simulation.

        This method implements the FULL 3-stimulus oddity task. Instead of using
        an analytical approximation, we:
        1. Sample three internal noisy representations per trial:
           - z_ref, z_refprime ~ N(ref, Σ_ref)  [two samples from reference]
           - z_comparison ~ N(comparison, Σ_comparison)           [one sample from comparison]
        2. Compute three pairwise Mahalanobis distances
        3. Apply oddity decision rule: comparison is odd if it's farther from BOTH ref and reference_prime
        4. Apply logistic smoothing to approximate P(correct)
        5. Average over MC samples

        Parameters
        ----------
        params : Any
            Model parameters as expected by ``model._compute_sqrt``.
        data : ResponseData
            Trial data with refs, comparisons, and responses
        model : WPPM
            Model instance providing ``_compute_sqrt`` for covariance computation.
        noise : NoiseModel
            Observer noise model (provides ``sample_standard``).
        key : jax.random.PRNGKey, optional
            Random key for reproducible sampling.
            If None, uses ``OddityTaskConfig.default_key_seed``.

        Returns
        -------
        jnp.ndarray
            Scalar sum of log-likelihoods over all trials.
            Same shape and interpretation as ``loglik``.

        Raises
        ------
        TypeError
            If ``num_samples`` or ``bandwidth`` are provided as kwargs.
        ValueError
            If the task configuration is invalid (e.g. ``num_samples <= 0``).

        Notes
        -----
        Monte Carlo controls (``num_samples``, ``bandwidth``) are owned by the
        task configuration:

        - Create the task with ``OddityTask(config=OddityTaskConfig(...))``.
        - Pass only the PRNG ``key`` at call time when you want to control
          randomness.

        Notes
        -----
        **Full 3-stimulus oddity task algorithm:**

        For each trial (ref, comparison, response):
        1. Compute covariances:
           - Σ_ref = U_ref @ U_ref.T + σ^2 I
           - Σ_comparison = U_comparison @ U_comparison.T + σ^2 I
           - Σ_avg = (2/3) Σ_ref + (1/3) Σ_comparison  [weighted by stimulus frequency]

        2. Sample three internal representations:
           - z_ref, z_refprime ~ N(ref, Σ_ref)  [2 samples from reference, num_samples times each]
           - z_comparison ~ N(comparison, Σ_comparison)           [1 sample from comparison, num_samples times]

        3. Compute three pairwise Mahalanobis distances:
           - d^2(z_ref, z_refprime) = (z_ref - z_refprime).T @ Σ_avg^{-1} @ (z_ref - z_refprime)  [ref vs reference_prime]
           - d^2(z_ref, z_comparison) = (z_ref - z_comparison).T @ Σ_avg^{-1} @ (z_ref - z_comparison)  [ref vs comparison]
           - d^2(z_refprime, z_comparison) = (z_refprime - z_comparison).T @ Σ_avg^{-1} @ (z_refprime - z_comparison)  [reference_prime vs comparison]

        4. Apply oddity decision rule:
           - delta = min(d^2(z_ref,z_comparison), d^2(z_refprime,z_comparison)) - d^2(z_ref,z_refprime)
           - delta > 0 means comparison is farther from BOTH ref and reference_prime -> correct identification

        5. Apply logistic smoothing:
           - P(correct) \approx mean(logistic.cdf(delta / bandwidth))

        6. Bernoulli log-likelihood:
           - LL = Σ [y * log(p) + (1-y) * log(1-p)]

        Performance:
        - Memory: O(num_samples * input_dim) per trial
        - Vectorized across trials using jax.vmap for GPU acceleration
    - Can be JIT-compiled for additional speed (future optimization)

        Examples
        --------
        >>> import jax.numpy as jnp
        >>> import jax.random as jr
        >>> from psyphy.model import WPPM, Prior
        >>> from psyphy.model.task import OddityTask, OddityTaskConfig
        >>> from psyphy.model.noise import GaussianNoise
        >>> from psyphy.data.dataset import ResponseData
        >>>
        >>> # Setup (MC controls live on the task config, not in loglik kwargs)
        >>> model = WPPM(
        ...     input_dim=2,
        ...     prior=Prior(input_dim=2, basis_degree=3),
        ...     task=OddityTask(config=OddityTaskConfig(num_samples=5000, bandwidth=1e-3)),
        ...     noise=GaussianNoise(sigma=0.03),
        ... )
        >>> params = model.init_params(jr.PRNGKey(0))
        >>>
        >>> # Create trial data
        >>> data = ResponseData()
        >>> data.add_trial(
        ...     ref=jnp.array([0.0, 0.0]), comparison=jnp.array([0.3, 0.2]), resp=1
        ... )
        >>>
        >>> loglik = model.task.loglik(params, data, model, model.noise, key=jr.PRNGKey(42))
        >>> print(f"MC (N=5000): {loglik:.4f}")


    """
    # Task is the single source of truth for MC controls.
    num_samples = int(self.config.num_samples)
    bandwidth = float(self.config.bandwidth)

    # Only PRNG key is accepted dynamically.
    key = kwargs.pop("key", None)
    if "num_samples" in kwargs or "bandwidth" in kwargs:
        raise TypeError(
            "OddityTask.loglik does not accept 'num_samples'/'bandwidth' overrides. "
            "Configure them via OddityTaskConfig when constructing the task."
        )
    if kwargs:
        unexpected = ", ".join(sorted(kwargs.keys()))
        raise TypeError(
            f"Unexpected keyword arguments for OddityTask.loglik: {unexpected}"
        )

    if num_samples <= 0:
        raise ValueError(f"num_samples must be > 0, got {num_samples}")
    if bandwidth <= 0:
        raise ValueError(f"bandwidth must be > 0, got {bandwidth}")

    if key is None:
        key = jr.PRNGKey(int(self.config.default_key_seed))

    # Unpack trial data
    refs, comparisons, responses = data.to_numpy()
    n_trials = len(refs)

    # Split keys for each trial (ensures independent sampling)
    trial_keys = jr.split(key, n_trials)

    # Vectorized computation of P(correct) for all trials
    # This processes all trials in parallel using jax.vmap
    # Note: probabilities are already clipped in _simulate_trial_mc()
    probs = self._simulate_trials_mc_vectorized(
        params=params,
        refs=refs,
        comparisons=comparisons,
        model=model,
        noise=noise,
        num_samples=num_samples,
        bandwidth=bandwidth,
        trial_keys=trial_keys,
    )

    # Bernoulli log-likelihood: LL = Σ [y log(p) + (1-y) log(1-p)]
    # Probabilities are already clipped to [eps, 1-eps] so log is safe
    log_likelihoods = jnp.where(
        responses == 1,
        jnp.log(probs),  # Correct response
        jnp.log(1.0 - probs),  # Incorrect response
    )

    return jnp.sum(log_likelihoods)

predict

predict(
    params: Any, stimuli: Stimulus, model: Any, noise: Any
) -> ndarray

Predict p(correct) for a single (ref, comparison) stimulus.

Even though OddityTask is MC-only, we still implement predict. Reason: large parts of the library (posterior predictive, acquisition functions, diagnostics, etc.) need a forward model that returns p(correct) at candidate stimuli. Historically this used an analytical approximation, but in MC-only mode we compute it via simulation.

Notes
  • This method is intentionally lightweight: it performs the same single-trial Monte Carlo simulation used by loglik.
  • If you need to control MC fidelity/smoothing, set OddityTaskConfig(num_samples=..., bandwidth=...) when you construct the task.
  • If you need reproducible randomness, pass key=... to loglik.
Source code in src/psyphy/model/task.py
def predict(
    self, params: Any, stimuli: Stimulus, model: Any, noise: Any
) -> jnp.ndarray:
    """Predict p(correct) for a single (ref, comparison) stimulus.

    Even though OddityTask is *MC-only*, we still implement ``predict``.
    Reason: large parts of the library (posterior predictive, acquisition
    functions, diagnostics, etc.) need a forward model that returns
    p(correct) at candidate stimuli. Historically this used an analytical
    approximation, but in MC-only mode we compute it via simulation.

    Notes
    -----
    - This method is intentionally lightweight: it performs the same
      single-trial Monte Carlo simulation used by ``loglik``.
            - If you need to control MC fidelity/smoothing, set
                ``OddityTaskConfig(num_samples=..., bandwidth=...)`` when you
                construct the task.
            - If you need reproducible randomness, pass ``key=...`` to ``loglik``.
    """

    num_samples = int(self.config.num_samples)
    bandwidth = float(self.config.bandwidth)
    key = jr.PRNGKey(int(self.config.default_key_seed))

    ref, comparison = stimuli
    return self._simulate_trial_mc(
        params=params,
        ref=ref,
        comparison=comparison,
        model=model,
        noise=noise,
        num_samples=num_samples,
        bandwidth=bandwidth,
        key=key,
    )
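A usage sketch for predict on a single candidate stimulus pair (illustrative values):

import jax.numpy as jnp
import jax.random as jr
from psyphy.model import WPPM, Prior, OddityTask, GaussianNoise

task = OddityTask()
model = WPPM(
    input_dim=2,
    prior=Prior(input_dim=2, basis_degree=3),
    task=task,
    noise=GaussianNoise(sigma=0.03),
)
params = model.init_params(jr.PRNGKey(0))

stimuli = (jnp.array([0.0, 0.0]), jnp.array([0.3, 0.2]))  # (reference, comparison)
p = task.predict(params, stimuli, model, model.noise)      # MC-estimated p(correct)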

Prior

Prior(
    input_dim: int,
    basis_degree: int | None = None,
    variance_scale: float = 1.0,
    decay_rate: float = 0.5,
    extra_embedding_dims: int = 0,
)

Prior distribution over WPPM parameters

Parameters:

Name Type Description Default
input_dim int

Dimensionality of the model space (same as WPPM.input_dim)

required
basis_degree int | None

Degree of Chebyshev basis for Wishart process. If set, uses Wishart mode with W coefficients.

None
variance_scale float

Prior variance for degree-0 (constant) coefficient in Wishart mode. Controls overall scale of covariances.

1.0
decay_rate float

Geometric decay rate for prior variance over higher-degree coefficients. Prior variance for degree-d coefficient = variance_scale * (decay_rate^d). Smaller decay_rate -> stronger smoothness prior.

0.5
extra_embedding_dims int

Additional latent dimensions in U matrices beyond input dimensions. Allows richer ellipsoid shapes in Wishart mode.

0

Methods:

Name Description
log_prob

Compute log prior density (up to a constant)

sample_params

Sample initial parameters from the prior.

Attributes:

Name Type Description
basis_degree int | None
decay_rate float
extra_embedding_dims int
input_dim int
variance_scale float

basis_degree

basis_degree: int | None = None

decay_rate

decay_rate: float = 0.5

extra_embedding_dims

extra_embedding_dims: int = 0

input_dim

input_dim: int

variance_scale

variance_scale: float = 1.0

log_prob

log_prob(params: Params) -> ndarray

Compute log prior density (up to a constant)

Gaussian prior on W with smoothness via decay_rate:
    log p(W) = Σ_ij log N(W_ij | 0, σ_ij^2), where σ_ij^2 is the prior variance.

Parameters:

Name Type Description Default
params dict

Parameter dictionary

required

Returns:

Name Type Description
log_prob float

Log prior probability (up to normalizing constant)

Source code in src/psyphy/model/prior.py
def log_prob(self, params: Params) -> jnp.ndarray:
    """
    Compute log prior density (up to a constant)

    Gaussian prior on W with smoothness via decay_rate
        log p(W) = Σ_ij log N(W_ij | 0, σ_ij^2) where σ_ij^2 = prior variance

    Parameters
    ----------
    params : dict
        Parameter dictionary

    Returns
    -------
    log_prob : float
        Log prior probability (up to normalizing constant)
    """

    if "W" in params:
        # Wishart mode
        W = params["W"]
        variances = self._compute_W_prior_variances()

        # Gaussian log probability for each entry
        # log N(x | 0, σ^2) = -0.5 * (x^2/σ^2 + log(2πσ^2))
        # Up to constant: -0.5 * x^2/σ^2

        if self.input_dim == 2:
            # Each W[i,j,:,:] ~ Normal(0, variance[i,j] * I)
            return -0.5 * jnp.sum((W**2) / (variances[:, :, None, None] + 1e-10))
        elif self.input_dim == 3:
            return -0.5 * jnp.sum((W**2) / (variances[:, :, :, None, None] + 1e-10))

    raise ValueError("params must contain weights 'W'")

sample_params

sample_params(key: Any) -> Params

Sample initial parameters from the prior.

Returns {"W": shape (degree+1, degree+1, input_dim, embedding_dim)} for 2D, where embedding_dim = input_dim + extra_embedding_dims

Note: The 3rd dimension is input_dim (output space dimension). This matches the einsum in _compute_sqrt: U = einsum("ijde,ij->de", W, phi) where d indexes input_dim.

Parameters:

Name Type Description Default
key JAX random key
required

Returns:

Name Type Description
params dict

Parameter dictionary

Source code in src/psyphy/model/prior.py
def sample_params(self, key: Any) -> Params:
    """
    Sample initial parameters from the prior.


    Returns {"W": shape (degree+1, degree+1, input_dim, embedding_dim)}
    for 2D, where embedding_dim = input_dim + extra_embedding_dims

    Note: The 3rd dimension is input_dim (output space dimension).
    This matches the einsum in _compute_sqrt:
    U = einsum("ijde,ij->de", W, phi) where d indexes input_dim.

    Parameters
    ----------
    key : JAX random key

    Returns
    -------
    params : dict
        Parameter dictionary
    """
    if self.basis_degree is None:
        raise ValueError(
            "'basis_degree' is None; please set "
            "`Prior.basis_degree` to an integer >0."
        )

    # Basis function coefficients W
    variances = self._compute_W_prior_variances()
    embedding_dim = self.input_dim + self.extra_embedding_dims

    if self.input_dim == 2:
        # Sample W ~ Normal(0, variances) for each matrix entry
        # Shape: (degree+1, degree+1, input_dim, embedding_dim)
        # Note: degree+1 to match number of basis functions [T_0, ..., T_degree]
        W = jnp.sqrt(variances)[:, :, None, None] * jr.normal(
            key,
            shape=(
                self.basis_degree + 1,
                self.basis_degree + 1,
                self.input_dim,
                embedding_dim,
            ),
        )
    elif self.input_dim == 3:
        # Shape: (degree+1, degree+1, degree+1, input_dim, embedding_dim)
        W = jnp.sqrt(variances)[:, :, :, None, None] * jr.normal(
            key,
            shape=(
                self.basis_degree + 1,
                self.basis_degree + 1,
                self.basis_degree + 1,
                self.input_dim,
                embedding_dim,
            ),
        )
    else:
        raise NotImplementedError(
            f"Wishart process only supports 2D and 3D. Got input_dim={self.input_dim}"
        )

    return {"W": W}

ResponseData

ResponseData()

Container for psychophysical trial data.

Attributes:

Name Type Description
refs List[Any]

List of reference stimuli.

comparisons List[Any]

List of comparison stimuli.

responses List[int]

List of subject responses (e.g., 0/1 or categorical).

Methods:

Name Description
add_batch

Append responses for a batch of trials.

add_trial

Append a single trial.

copy

Create a deep copy of this dataset.

from_arrays

Construct ResponseData from arrays.

merge

Merge another dataset into this one (in-place).

tail

Return last n trials as a new ResponseData.

to_numpy

Return refs, comparisons, responses as numpy arrays.

Source code in src/psyphy/data/dataset.py
def __init__(self) -> None:
    self.refs: list[Any] = []
    self.comparisons: list[Any] = []
    self.responses: list[int] = []

comparisons

comparisons: list[Any] = []

refs

refs: list[Any] = []

responses

responses: list[int] = []

trials

trials: list[tuple[Any, Any, int]]

Return list of (ref, comparison, response) tuples.

Returns:

Type Description
list[tuple]

Each element is (ref, comparison, resp)

add_batch

add_batch(
    responses: list[int], trial_batch: TrialBatch
) -> None

Append responses for a batch of trials.

Parameters:

Name Type Description Default
responses List[int]

Responses corresponding to each (ref, comparison) in the trial batch.

required
trial_batch TrialBatch

The batch of proposed trials.

required
Source code in src/psyphy/data/dataset.py
def add_batch(self, responses: list[int], trial_batch: TrialBatch) -> None:
    """
    Append responses for a batch of trials.

    Parameters
    ----------
    responses : List[int]
        Responses corresponding to each (ref, comparison) in the trial batch.
    trial_batch : TrialBatch
        The batch of proposed trials.
    """
    for (ref, comparison), resp in zip(trial_batch.stimuli, responses):
        self.add_trial(ref, comparison, resp)

add_trial

add_trial(ref: Any, comparison: Any, resp: int) -> None

Append a single trial.

Parameters:

Name Type Description Default
ref Any

Reference stimulus (numpy array, list, etc.)

required
comparison Any

Probe stimulus

required
resp int

Subject response (binary or categorical)

required
Source code in src/psyphy/data/dataset.py
def add_trial(self, ref: Any, comparison: Any, resp: int) -> None:
    """
    append a single trial.

    Parameters
    ----------
    ref : Any
        Reference stimulus (numpy array, list, etc.)
    comparison : Any
        Probe stimulus
    resp : int
        Subject response (binary or categorical)
    """
    self.refs.append(ref)
    self.comparisons.append(comparison)
    self.responses.append(resp)

copy

copy() -> ResponseData

Create a deep copy of this dataset.

Returns:

Type Description
ResponseData

New dataset with copied data

Source code in src/psyphy/data/dataset.py
def copy(self) -> ResponseData:
    """
    Create a deep copy of this dataset.

    Returns
    -------
    ResponseData
        New dataset with copied data
    """
    new_data = ResponseData()
    new_data.refs = list(self.refs)
    new_data.comparisons = list(self.comparisons)
    new_data.responses = list(self.responses)
    return new_data

from_arrays

from_arrays(
    X: jnp.ndarray | np.ndarray,
    y: jnp.ndarray | np.ndarray,
    *,
    comparisons: jnp.ndarray | np.ndarray | None = None,
) -> ResponseData

Construct ResponseData from arrays.

Parameters:

Name Type Description Default
X (array, shape(n_trials, 2, input_dim) or (n_trials, input_dim))

Stimuli. If 3D, second axis is [reference, comparison]. If 2D, comparisons must be provided separately.

required
y (array, shape(n_trials))

Responses

required
comparisons (array, shape(n_trials, input_dim))

Probe stimuli. Only needed if X is 2D.

None

Returns:

Type Description
ResponseData

Data container

Examples:

>>> # From paired stimuli
>>> X = jnp.array([[[0, 0], [1, 0]], [[1, 1], [2, 1]]])
>>> y = jnp.array([1, 0])
>>> data = ResponseData.from_arrays(X, y)
>>> # From separate refs and comparisons
>>> refs = jnp.array([[0, 0], [1, 1]])
>>> comparisons = jnp.array([[1, 0], [2, 1]])
>>> data = ResponseData.from_arrays(refs, y, comparisons=comparisons)
Source code in src/psyphy/data/dataset.py
@classmethod
def from_arrays(
    cls,
    X: jnp.ndarray | np.ndarray,
    y: jnp.ndarray | np.ndarray,
    *,
    comparisons: jnp.ndarray | np.ndarray | None = None,
) -> ResponseData:
    """
    Construct ResponseData from arrays.

    Parameters
    ----------
    X : array, shape (n_trials, 2, input_dim) or (n_trials, input_dim)
        Stimuli. If 3D, second axis is [reference, comparison].
        If 2D, comparisons must be provided separately.
    y : array, shape (n_trials,)
        Responses
    comparisons : array, shape (n_trials, input_dim), optional
        Probe stimuli. Only needed if X is 2D.

    Returns
    -------
    ResponseData
        Data container

    Examples
    --------
    >>> # From paired stimuli
    >>> X = jnp.array([[[0, 0], [1, 0]], [[1, 1], [2, 1]]])
    >>> y = jnp.array([1, 0])
    >>> data = ResponseData.from_arrays(X, y)

    >>> # From separate refs and comparisons
    >>> refs = jnp.array([[0, 0], [1, 1]])
    >>> comparisons = jnp.array([[1, 0], [2, 1]])
    >>> data = ResponseData.from_arrays(refs, y, comparisons=comparisons)
    """
    data = cls()

    X = np.asarray(X)
    y = np.asarray(y)

    if X.ndim == 3:
        # X is (n_trials, 2, input_dim)
        refs = X[:, 0, :]
        comparisons_arr = X[:, 1, :]
    elif X.ndim == 2 and comparisons is not None:
        refs = X
        comparisons_arr = np.asarray(comparisons)
    else:
        raise ValueError(
            "X must be shape (n_trials, 2, input_dim) or "
            "(n_trials, input_dim) with comparisons argument"
        )

    for ref, comparison, response in zip(refs, comparisons_arr, y):
        data.add_trial(ref, comparison, int(response))

    return data

merge

merge(other: ResponseData) -> None

Merge another dataset into this one (in-place).

Parameters:

Name Type Description Default
other ResponseData

Dataset to merge

required
Source code in src/psyphy/data/dataset.py
def merge(self, other: ResponseData) -> None:
    """
    Merge another dataset into this one (in-place).

    Parameters
    ----------
    other : ResponseData
        Dataset to merge
    """
    self.refs.extend(other.refs)
    self.comparisons.extend(other.comparisons)
    self.responses.extend(other.responses)

tail

tail(n: int) -> ResponseData

Return last n trials as a new ResponseData.

Parameters:

Name Type Description Default
n int

Number of trials to keep

required

Returns:

Type Description
ResponseData

New dataset with last n trials

Source code in src/psyphy/data/dataset.py
def tail(self, n: int) -> ResponseData:
    """
    Return last n trials as a new ResponseData.

    Parameters
    ----------
    n : int
        Number of trials to keep

    Returns
    -------
    ResponseData
        New dataset with last n trials
    """
    new_data = ResponseData()
    new_data.refs = self.refs[-n:]
    new_data.comparisons = self.comparisons[-n:]
    new_data.responses = self.responses[-n:]
    return new_data

to_numpy

to_numpy() -> tuple[ndarray, ndarray, ndarray]

Return refs, comparisons, responses as numpy arrays.

Returns:

Name Type Description
refs ndarray
comparisons ndarray
responses ndarray
Source code in src/psyphy/data/dataset.py
def to_numpy(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Return refs, comparisons, responses as numpy arrays.

    Returns
    -------
    refs : np.ndarray
    comparisons : np.ndarray
    responses : np.ndarray
    """
    return (
        np.array(self.refs),
        np.array(self.comparisons),
        np.array(self.responses),
    )
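A round-trip sketch of the container API (illustrative stimuli):

import numpy as np
from psyphy.data.dataset import ResponseData

data = ResponseData()
data.add_trial(np.array([0.0, 0.0]), np.array([0.3, 0.2]), resp=1)
data.add_trial(np.array([0.0, 0.0]), np.array([0.1, 0.0]), resp=0)

refs, comparisons, responses = data.to_numpy()  # arrays of shape (2, 2), (2, 2), (2,)
recent = data.tail(1)                           # new dataset with only the last trial
snapshot = data.copy()
snapshot.merge(recent)                          # in-place merge of another dataset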

StudentTNoise

StudentTNoise(df: float = 3.0, scale: float = 1.0)

Methods:

Name Description
log_prob
sample_standard

Sample from standard Student-t (df=self.df).

Attributes:

Name Type Description
df float
scale float

df

df: float = 3.0

scale

scale: float = 1.0

log_prob

log_prob(residual: float) -> float
Source code in src/psyphy/model/noise.py
def log_prob(self, residual: float) -> float:
    _ = residual
    return -0.5

sample_standard

sample_standard(
    key: Array, shape: tuple[int, ...]
) -> Array

Sample from standard Student-t (df=self.df).

Source code in src/psyphy/model/noise.py
def sample_standard(self, key: jax.Array, shape: tuple[int, ...]) -> jax.Array:
    """Sample from standard Student-t (df=self.df)."""
    return jr.t(key, self.df, shape)

TrialBatch

TrialBatch(stimuli: list[tuple[Any, Any]])

Container for a proposed batch of trials

Attributes:

Name Type Description
stimuli List[Tuple[Any, Any]]

Each trial is a (reference, comparison) tuple.

Methods:

Name Description
from_stimuli

Construct a TrialBatch from a list of stimuli (ref, comparison) pairs.

Source code in src/psyphy/data/dataset.py
def __init__(self, stimuli: list[tuple[Any, Any]]) -> None:
    self.stimuli = list(stimuli)

stimuli

stimuli = list(stimuli)

from_stimuli

from_stimuli(pairs: list[tuple[Any, Any]]) -> TrialBatch

Construct a TrialBatch from a list of stimuli (ref, comparison) pairs.

Source code in src/psyphy/data/dataset.py
@classmethod
def from_stimuli(cls, pairs: list[tuple[Any, Any]]) -> TrialBatch:
    """
    Construct a TrialBatch from a list of stimuli (ref, comparison) pairs.
    """
    return cls(pairs)

WPPM

WPPM(
    input_dim: int,
    prior: Prior,
    task: TaskLikelihood,
    noise: Any | None = None,
    *,
    online_config: OnlineConfig | None = None,
    extra_dims: int = 0,
    variance_scale: float = 1.0,
    decay_rate: float = 1.0,
    diag_term: float = 1e-06,
    **model_kwargs: Any,
)

Bases: Model

Wishart Process Psychophysical Model (WPPM).

Parameters:

Name Type Description Default
input_dim int

Dimensionality of the input stimulus space (e.g., 2 for isoluminant plane, 3 for RGB). Both reference and probe live in R^{input_dim}.

required
prior Prior

Prior distribution over model parameters. Controls basis_degree in WPPM (basis expansion). The WPPM delegates to prior.basis_degree to ensure consistency between parameter sampling and basis evaluation.

required
task TaskLikelihood

Psychophysical task mapping that defines how discriminability translates to p(correct) and how log-likelihood of responses is computed. (e.g., OddityTask)

required
noise Any

Noise model describing internal representation noise (e.g., GaussianNoise).

None
Forward-compatible hyperparameters

extra_dims : int, default=0
    Additional embedding dimensions for basis expansions (beyond input_dim). embedding_dim = input_dim + extra_dims.
variance_scale : float, default=1.0
    Global scaling factor for covariance magnitude.
decay_rate : float, default=1.0
    Smoothness/length-scale for spatial covariance variation.
diag_term : float, default=1e-6
    Small positive value added to the covariance diagonal for numerical stability.
online_config : OnlineConfig | None, optional (keyword-only)
    Base-model lifecycle / online-learning policy. This is the supported way to configure buffering and refit scheduling via Model.condition_on_observations.
model_kwargs : Any
    Reserved for future keyword arguments accepted by the base Model.__init__. Do not pass WPPM math knobs or task/likelihood knobs here.

Methods:

Name Description
condition_on_observations

Update model with new observations (online learning).

discriminability

Compute scalar discriminability d >= 0 for a (reference, probe) pair

fit

Fit model to data.

init_params

Sample initial parameters from the prior.

local_covariance

Return local covariance Σ(x) at stimulus location x.

log_likelihood

Compute the log-likelihood for arrays of trials.

log_likelihood_from_data

Compute log-likelihood directly from a ResponseData object.

log_posterior_from_data

Compute log posterior from data.

posterior

Return posterior distribution.

predict_prob

Predict probability of a correct response for a single stimulus.

predict_with_params

Evaluate model at specific parameter values (no marginalization).

Attributes:

Name Type Description
basis_degree int | None

Chebyshev polynomial degree for Wishart process basis expansion.

decay_rate
diag_term
embedding_dim int

Dimension of the embedding space (perceptual space).

extra_dims
input_dim
noise
online_config
prior
task
variance_scale
Source code in src/psyphy/model/wppm.py
def __init__(
    self,
    input_dim: int,
    prior: Prior,
    task: TaskLikelihood,
    noise: Any | None = None,
    *,  # everything after here is keyword-only
    online_config: OnlineConfig | None = None,
    extra_dims: int = 0,
    variance_scale: float = 1.0,
    decay_rate: float = 1.0,
    diag_term: float = 1e-6,
    **model_kwargs: Any,
) -> None:
    # Base-model configuration (lifecycle / online learning).
    #
    # `online_config` is the explicit, user-facing knob for online learning
    # and data retention (see `psyphy.model.base.OnlineConfig`).
    #
    # `model_kwargs` is reserved for *future* base `Model.__init__` kwargs.
    # It should NOT be used for WPPM-specific math (e.g. alternative covariance
    # parameterizations) or for task-specific likelihood knobs.
    if model_kwargs:
        known_misuses = {"num_samples", "bandwidth"}
        bad = sorted(known_misuses.intersection(model_kwargs.keys()))
        if bad:
            raise TypeError(
                "Do not pass task-specific kwargs via WPPM(..., **model_kwargs). "
                f"Move {bad} into the task config (e.g. OddityTaskConfig)."
            )

    super().__init__(online_config=online_config, **model_kwargs)

    # --- core components ---
    self.input_dim = int(input_dim)  # stimulus-space dimensionality
    self.prior = prior  # prior over parameter PyTree

    if self.prior.input_dim != self.input_dim:
        raise ValueError(
            f"Dimension mismatch: Model initialized with input_dim={self.input_dim}, "
            f"but Prior expects input_dim={self.prior.input_dim}."
        )

    self.task = task  # task mapping and likelihood
    self.noise = noise  # noise model

    self.extra_dims = int(extra_dims)
    self.variance_scale = float(variance_scale)
    self.decay_rate = float(decay_rate)
    self.diag_term = float(diag_term)

basis_degree

basis_degree: int | None

Chebyshev polynomial degree for Wishart process basis expansion.

This property delegates to self.prior.basis_degree to ensure consistency between parameter sampling and basis evaluation.

Returns:

Type Description
int | None

Degree of Chebyshev polynomial basis (0 = constant, 1 = linear, etc.)

Notes

WPPM gets its basis_degree parameter from Prior.basis_degree.

decay_rate

decay_rate = float(decay_rate)

diag_term

diag_term = float(diag_term)

embedding_dim

embedding_dim: int

Dimension of the embedding space (perceptual space).

embedding_dim = input_dim + extra_dims. This represents the full perceptual space where:

- The first input_dim dimensions correspond to observable stimulus features
- The remaining extra_dims are latent dimensions

Returns:

Type Description
int

input_dim + extra_dims

Notes

This is a computed property, not a constructor parameter.
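
For instance, reusing the hypothetical prior and task from the construction sketch above, with input_dim=2 and extra_dims=1:

>>> model = WPPM(input_dim=2, prior=prior, task=OddityTask(), extra_dims=1)
>>> model.embedding_dim
3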

extra_dims

extra_dims = int(extra_dims)

input_dim

input_dim = int(input_dim)

noise

noise = noise

online_config

online_config = online_config or OnlineConfig()

prior

prior = prior

task

task = task

variance_scale

variance_scale = float(variance_scale)

condition_on_observations

condition_on_observations(X: ndarray, y: ndarray) -> Model

Update model with new observations (online learning).

Behavior depends on self.online_config.strategy:

- "full": Accumulate all data, refit periodically
- "sliding_window": Keep only recent window_size trials
- "reservoir": Random sampling of window_size trials
- "none": Refit from scratch (no caching)

Returns a NEW model instance (immutable update).

Parameters:

Name Type Description Default
X ndarray

New stimuli

required
y ndarray

New responses

required

Returns:

Type Description
Model

Updated model (new instance)

Examples:

>>> # Online learning loop
>>> model = WPPM(...).fit(X_init, y_init)
>>> for X_new, y_new in stream:
...     model = model.condition_on_observations(X_new, y_new)
...     # Model automatically manages memory and refitting
Source code in src/psyphy/model/base.py
def condition_on_observations(self, X: jnp.ndarray, y: jnp.ndarray) -> Model:
    """
    Update model with new observations (online learning).

    Behavior depends on self.online_config.strategy:
    - "full": Accumulate all data, refit periodically
    - "sliding_window": Keep only recent window_size trials
    - "reservoir": Random sampling of window_size trials
    - "none": Refit from scratch (no caching)

    Returns a NEW model instance (immutable update).

    Parameters
    ----------
    X : jnp.ndarray
        New stimuli
    y : jnp.ndarray
        New responses

    Returns
    -------
    Model
        Updated model (new instance)

    Examples
    --------
    >>> # Online learning loop
    >>> model = WPPM(...).fit(X_init, y_init)
    >>> for X_new, y_new in stream:
    ...     model = model.condition_on_observations(X_new, y_new)
    ...     # Model automatically manages memory and refitting
    """
    from psyphy.data import ResponseData

    # Convert new data
    new_data = ResponseData.from_arrays(X, y)

    # Update data buffer according to strategy
    if self.online_config.strategy == "none":
        data_to_fit = new_data

    elif self.online_config.strategy == "full":
        if self._data_buffer is None:
            self._data_buffer = ResponseData()
        self._data_buffer.merge(new_data)
        data_to_fit = self._data_buffer

    elif self.online_config.strategy == "sliding_window":
        if self._data_buffer is None:
            self._data_buffer = ResponseData()
        self._data_buffer.merge(new_data)

        # Keep only last N trials
        window_size = self.online_config.window_size
        assert window_size is not None, "window_size must be set for sliding_window"
        if len(self._data_buffer) > window_size:
            self._data_buffer = self._data_buffer.tail(window_size)
        data_to_fit = self._data_buffer

    elif self.online_config.strategy == "reservoir":
        if self._data_buffer is None:
            self._data_buffer = ResponseData()

        # Reservoir sampling
        window_size = self.online_config.window_size
        assert window_size is not None, "window_size must be set for reservoir"
        self._data_buffer = self._reservoir_update(
            self._data_buffer,
            new_data,
            window_size,
        )
        data_to_fit = self._data_buffer
    else:
        raise ValueError(f"Unknown strategy: {self.online_config.strategy}")

    # Decide whether to refit
    self._n_updates += 1
    should_refit = self._n_updates % self.online_config.refit_interval == 0

    if not should_refit:
        # Return clone with updated buffer but old posterior
        new_model = self._clone()
        new_model._data_buffer = data_to_fit
        return new_model

    # Refit with optional warm start
    inference = self._inference_engine
    assert inference is not None, (
        "Model must be fit before condition_on_observations"
    )

    if self.online_config.warm_start and self._posterior is not None:
        # TODO: Add warm-start support to inference engines
        # For now, just refit from cached params
        pass

    new_model = self._clone()
    new_model._data_buffer = data_to_fit
    new_model._posterior = inference.fit(new_model, data_to_fit)
    new_model._inference_engine = inference

    return new_model
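
The buffering strategy and refit cadence are controlled by OnlineConfig. Below is a sliding-window sketch, assuming OnlineConfig exposes the strategy, window_size, and refit_interval fields referenced in the source above; the concrete values are illustrative.

>>> from psyphy.model.base import OnlineConfig
>>> cfg = OnlineConfig(strategy="sliding_window", window_size=200, refit_interval=5)
>>> model = WPPM(input_dim=2, prior=prior, task=OddityTask(), online_config=cfg)
>>> model = model.fit(X_init, y_init)
>>> for X_new, y_new in stream:
...     model = model.condition_on_observations(X_new, y_new)
...     # keeps at most 200 trials; refits on every 5th update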

discriminability

discriminability(
    params: Params, stimulus: Stimulus
) -> ndarray

Compute scalar discriminability d >= 0 for a (reference, probe) pair

WPPM (rectangular U design), if extra_dims > 0:

    d = sqrt( (probe - ref)^T Σ(ref)^{-1} (probe - ref) )

where Σ(ref) is directly computed in stimulus space (input_dim, input_dim) via U(x) @ U(x)^T with U rectangular.

The discrimination task only depends on observable stimulus dimensions. The rectangular U design means local_covariance() already returns the stimulus covariance - no block extraction needed.

WPPM: d is implicit via Monte Carlo simulation of internal noisy responses under the task's decision rule (no closed form). In that case, tasks will directly implement predict/loglik with MC, and this method may be used only for diagnostics.

Parameters:

Name Type Description Default
params dict

Model parameters.

required
stimulus tuple

(reference, probe) arrays of shape (input_dim,).

required

Returns:

Name Type Description
d ndarray

Nonnegative scalar discriminability.

Source code in src/psyphy/model/wppm.py
def discriminability(self, params: Params, stimulus: Stimulus) -> jnp.ndarray:
    """
    Compute scalar discriminability d >= 0 for a (reference, probe) pair


    WPPM (rectangular U design) if extra_dims > 0:
        d = sqrt( (probe - ref)^T Σ(ref)^{-1} (probe - ref) )
        where Σ(ref) is directly computed in stimulus space (input_dim, input_dim)
        via U(x) @ U(x)^T with U rectangular.

    The discrimination task only depends on observable stimulus dimensions.
    The rectangular U design means local_covariance() already returns
    the stimulus covariance - no block extraction needed.

    WPPM:
        d is implicit via Monte Carlo simulation of internal noisy responses
        under the task's decision rule (no closed form). In that case, tasks
        will directly implement predict/loglik with MC, and this method may be
        used only for diagnostics.

    Parameters
    ----------
    params : dict
        Model parameters.
    stimulus : tuple
        (reference, probe) arrays of shape (input_dim,).

    Returns
    -------
    d : jnp.ndarray
        Nonnegative scalar discriminability.
    """
    ref, probe = stimulus

    # Delta is in stimulus space (input_dim)
    delta = probe - ref

    # Get stimulus covariance at reference
    # (rectangular U design: already returns (input_dim, input_dim))
    Sigma = self.local_covariance(params, ref)

    # Add jitter for stable solve; diag_term is configurable
    jitter = self.diag_term * jnp.eye(self.input_dim)

    # Solve (Σ + jitter)^{-1} delta using a PD-aware solver
    x = jax.scipy.linalg.solve(Sigma + jitter, delta, assume_a="pos")
    d2 = jnp.dot(delta, x)  # quadratic form

    # Guard against tiny negative values from numerical error
    return jnp.sqrt(jnp.maximum(d2, 0.0))
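
A diagnostic sketch of calling discriminability on a 2D model, with parameters sampled from the prior (continuing the construction sketch above):

>>> import jax
>>> import jax.numpy as jnp
>>> params = model.init_params(jax.random.PRNGKey(0))
>>> ref = jnp.array([0.5, 0.5])
>>> probe = jnp.array([0.55, 0.5])
>>> d = model.discriminability(params, (ref, probe))  # nonnegative scalar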

fit

fit(
    X: ndarray,
    y: ndarray,
    *,
    inference: InferenceEngine | str = "laplace",
    inference_config: dict | None = None,
) -> Model

Fit model to data.

Parameters:

Name Type Description Default
X ndarray

Stimuli, shape (n_trials, 2, input_dim) for (ref, probe) pairs or (n_trials, input_dim) for references only

required
y ndarray

Responses, shape (n_trials,)

required
inference InferenceEngine | str

Inference engine or string key ("map", "laplace", "langevin")

"laplace"
inference_config dict | None

Hyperparameters for string-based inference. Examples: {"steps": 500, "lr": 1e-3} for MAP

None

Returns:

Type Description
Model

Self for method chaining

Examples:

>>> # Simple: use defaults
>>> model.fit(X, y)
>>> # Explicit optimizer
>>> from psyphy.inference import MAPOptimizer
>>> model.fit(X, y, inference=MAPOptimizer(steps=500))
>>> # String + config (for experiment tracking)
>>> model.fit(X, y, inference="map", inference_config={"steps": 500})
Source code in src/psyphy/model/base.py
def fit(
    self,
    X: jnp.ndarray,
    y: jnp.ndarray,
    *,
    inference: InferenceEngine | str = "laplace",
    inference_config: dict | None = None,
) -> Model:
    """
    Fit model to data.

    Parameters
    ----------
    X : jnp.ndarray
        Stimuli, shape (n_trials, 2, input_dim) for (ref, probe) pairs
        or (n_trials, input_dim) for references only
    y : jnp.ndarray
        Responses, shape (n_trials,)
    inference : InferenceEngine | str, default="laplace"
        Inference engine or string key ("map", "laplace", "langevin")
    inference_config : dict | None
        Hyperparameters for string-based inference.
        Examples: {"steps": 500, "lr": 1e-3} for MAP

    Returns
    -------
    Model
        Self for method chaining

    Examples
    --------
    >>> # Simple: use defaults
    >>> model.fit(X, y)

    >>> # Explicit optimizer
    >>> from psyphy.inference import MAPOptimizer
    >>> model.fit(X, y, inference=MAPOptimizer(steps=500))

    >>> # String + config (for experiment tracking)
    >>> model.fit(X, y, inference="map", inference_config={"steps": 500})
    """
    from psyphy.data import ResponseData
    from psyphy.inference import INFERENCE_ENGINES, InferenceEngine

    # Resolve inference engine
    is_string_inference = isinstance(inference, str)

    if is_string_inference:
        config = inference_config or {}
        inference_key: str = inference  # type: ignore[assignment]
        if inference_key not in INFERENCE_ENGINES:
            available = ", ".join(INFERENCE_ENGINES.keys())
            raise ValueError(
                f"Unknown inference: '{inference}'. Available: {available}"
            )
        inference_engine: InferenceEngine = INFERENCE_ENGINES[inference_key](
            **config
        )
    elif isinstance(inference, InferenceEngine):
        inference_engine = inference
    else:
        raise TypeError(
            f"inference must be InferenceEngine or str, got {type(inference)}"
        )

    if inference_config is not None and not is_string_inference:
        raise ValueError(
            "Cannot pass inference_config with InferenceEngine instance"
        )

    # Convert data
    data = ResponseData.from_arrays(X, y)

    # Fit
    self._posterior = inference_engine.fit(self, data)
    self._inference_engine = inference_engine
    self._data_buffer = data  # Initialize buffer
    return self

init_params

init_params(key: Array) -> Params

Sample initial parameters from the prior.

Returns:

Name Type Description
params dict[str, ndarray]
Source code in src/psyphy/model/wppm.py
def init_params(self, key: jax.Array) -> Params:
    """Sample initial parameters from the prior.

    Returns
    -------
    params : dict[str, jnp.ndarray]
    """
    return self.prior.sample_params(key)

local_covariance

local_covariance(params: Params, x: ndarray) -> ndarray

Return local covariance Σ(x) at stimulus location x.

Wishart mode (basis_degree set):

    Σ(x) = U(x) @ U(x)^T + diag_term * I

where U(x) is rectangular (input_dim, embedding_dim) if extra_dims > 0.

- Varies smoothly with x
- Guaranteed positive-definite
- Returns stimulus covariance directly (input_dim, input_dim)

Parameters:

Name Type Description Default
params dict

Model parameters:

- WPPM: {"W": (degree+1, ..., input_dim, embedding_dim)}

required
x (ndarray, shape(input_dim))

Stimulus location

required

Returns:

Type Description
Σ : jnp.ndarray, shape (input_dim, input_dim)

Covariance matrix in stimulus space.

Source code in src/psyphy/model/wppm.py
def local_covariance(self, params: Params, x: jnp.ndarray) -> jnp.ndarray:
    """
    Return local covariance Σ(x) at stimulus location x.


    Wishart mode (basis_degree set):
        Σ(x) = U(x) @ U(x)^T + diag_term * I
        where U(x) is rectangular (input_dim, embedding_dim) if extra_dims > 0.
        - Varies smoothly with x
        - Guaranteed positive-definite
        - Returns stimulus covariance directly (input_dim, input_dim)

    Parameters
    ----------
    params : dict
        Model parameters:
        - WPPM: {"W": (degree+1, ..., input_dim, embedding_dim)}
    x : jnp.ndarray, shape (input_dim,)
        Stimulus location

    Returns
    -------
    Σ : jnp.ndarray, shape (input_dim, input_dim)
        Covariance matrix in stimulus space.
    """
    # WPPM: spatially-varying covariance
    if "W" in params:
        U = self._compute_sqrt(params, x)  # (input_dim, embedding_dim)
        # Σ(x) = U(x) @ U(x)^T + diag_term * I
        # Result is (input_dim, input_dim)
        Sigma = U @ U.T + self.diag_term * jnp.eye(self.input_dim)
        return Sigma

    raise ValueError("params must contain 'W' (weights of WPPM)")
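
A short sketch evaluating Σ(x) at one stimulus location and checking positive-definiteness (params as in the discriminability sketch above):

>>> x = jnp.array([0.5, 0.5])
>>> Sigma = model.local_covariance(params, x)
>>> Sigma.shape
(2, 2)
>>> bool(jnp.all(jnp.linalg.eigvalsh(Sigma) > 0))  # PD by construction: U U^T + diag_term * I
True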

log_likelihood

log_likelihood(
    params: Params,
    refs: ndarray,
    probes: ndarray,
    responses: ndarray,
) -> ndarray

Compute the log-likelihood for arrays of trials.

IMPORTANT: We delegate to the TaskLikelihood to avoid duplicating Bernoulli (MPV) or MC likelihood logic in multiple places.

Parameters:

Name Type Description Default
params dict

Model parameters.

required
refs (ndarray, shape(N, input_dim))
required
probes (ndarray, shape(N, input_dim))
required
responses (ndarray, shape(N))

Typically 0/1; task may support richer encodings.

required

Returns:

Name Type Description
loglik ndarray

Scalar log-likelihood (task-only; add prior outside if needed)

Notes

This method is intentionally strict and does not accept task-specific runtime kwargs.

- Configure task behavior (e.g. MC fidelity/smoothing for ``OddityTask``) via the task instance passed to the model.
- If you need reproducible randomness, pass a ``key`` when calling the task directly.
Source code in src/psyphy/model/wppm.py
def log_likelihood(
    self,
    params: Params,
    refs: jnp.ndarray,
    probes: jnp.ndarray,
    responses: jnp.ndarray,
) -> jnp.ndarray:
    """
    Compute the log-likelihood for arrays of trials.

    IMPORTANT:
        We delegate to the TaskLikelihood to avoid duplicating Bernoulli (MPV)
        or MC likelihood logic in multiple places.

    Parameters
    ----------
    params : dict
        Model parameters.
    refs : jnp.ndarray, shape (N, input_dim)
    probes : jnp.ndarray, shape (N, input_dim)
    responses : jnp.ndarray, shape (N,)
        Typically 0/1; task may support richer encodings.

    Returns
    -------
    loglik : jnp.ndarray
        Scalar log-likelihood (task-only; add prior outside if needed)

    Notes
    -----
    This method is intentionally strict and does not accept task-specific
    runtime kwargs.

    - Configure task behavior (e.g. MC fidelity/smoothing for ``OddityTask``)
      via the task instance passed to the model.
    - If you need reproducible randomness, pass a ``key`` when calling the
      task directly.
    """
    # We need a ResponseData-like object. To keep this method usable from
    # array inputs, we construct one on the fly. If you already have a
    # ResponseData instance, prefer `log_likelihood_from_data`.
    from psyphy.data.dataset import ResponseData  # local import to avoid cycles

    data = ResponseData()
    # ResponseData.add_trial(ref, probe, resp)
    for r, p, y in zip(refs, probes, responses):
        data.add_trial(r, p, int(y))
    return self.task.loglik(params, data, self, self.noise)

log_likelihood_from_data

log_likelihood_from_data(
    params: Params, data: Any
) -> ndarray

Compute log-likelihood directly from a ResponseData object.

Why delegate to the task?

- The task knows the decision rule (oddity, 2AFC, ...).
- The task can use the model (this WPPM) to fetch discriminabilities.
- The task can use the noise model if it needs MC simulation.

Parameters:

Name Type Description Default
params dict

Model parameters.

required
data ResponseData

Collected trial data.

required

Returns:

Name Type Description
loglik ndarray

Scalar log-likelihood (task-only; add prior outside if needed).

Source code in src/psyphy/model/wppm.py
def log_likelihood_from_data(self, params: Params, data: Any) -> jnp.ndarray:
    """Compute log-likelihood directly from a ResponseData object.

    Why delegate to the task?
        - The task knows the decision rule (oddity, 2AFC, ...).
        - The task can use the model (this WPPM) to fetch discriminabilities.
        - The task can use the noise model if it needs MC simulation.

    Parameters
    ----------
    params : dict
        Model parameters.
    data : ResponseData
        Collected trial data.

    Returns
    -------
    loglik : jnp.ndarray
        Scalar log-likelihood (task-only; add prior outside if needed).
    """
    return self.task.loglik(params, data, self, self.noise)
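
A sketch of evaluating the likelihood on a small hand-built dataset; the ResponseData import path and the add_trial(ref, probe, resp) call follow the log_likelihood source above.

>>> from psyphy.data.dataset import ResponseData
>>> data = ResponseData()
>>> data.add_trial(jnp.array([0.5, 0.5]), jnp.array([0.55, 0.5]), 1)
>>> data.add_trial(jnp.array([0.5, 0.5]), jnp.array([0.50, 0.5]), 0)
>>> ll = model.log_likelihood_from_data(params, data)  # scalar, task-only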

log_posterior_from_data

log_posterior_from_data(
    params: Params, data: Any
) -> ndarray

Compute log posterior from data.

This simply adds the prior log-probability to the task log-likelihood. Inference engines (e.g., MAP optimizer) typically optimize this quantity.

Returns:

Type Description
ndarray

Scalar log posterior = loglik(params | data) + log_prior(params).

Source code in src/psyphy/model/wppm.py
def log_posterior_from_data(self, params: Params, data: Any) -> jnp.ndarray:
    """Compute log posterior from data.

    This simply adds the prior log-probability to the task log-likelihood.
    Inference engines (e.g., MAP optimizer) typically optimize this quantity.

    Returns
    -------
    jnp.ndarray
        Scalar log posterior = loglik(params | data) + log_prior(params).
    """
    return self.log_likelihood_from_data(params, data) + self.prior.log_prob(params)
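
Because the log posterior is a scalar function of the parameter PyTree, an inference engine can differentiate it directly with jax.grad. Below is a minimal gradient-ascent sketch for illustration only; it is not the library's MAPOptimizer, and with an MC-based task likelihood the gradient is stochastic.

>>> import jax
>>> objective = lambda p: model.log_posterior_from_data(p, data)
>>> grad_fn = jax.grad(objective)
>>> params_hat = model.init_params(jax.random.PRNGKey(0))
>>> for _ in range(200):
...     g = grad_fn(params_hat)
...     # gradient ascent on log_posterior_from_data
...     params_hat = jax.tree_util.tree_map(lambda p, gi: p + 1e-2 * gi, params_hat, g)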

posterior

posterior(
    X: ndarray | None = None,
    *,
    probes: ndarray | None = None,
    kind: str = "predictive",
) -> PredictivePosterior | ParameterPosterior

Return posterior distribution.

Parameters:

Name Type Description Default
X ndarray | None

Test stimuli (references), shape (n_test, input_dim). Required for predictive posteriors, optional for parameter posteriors.

None
probes ndarray | None

Test probes, shape (n_test, input_dim). Required for predictive posteriors.

None
kind ('predictive', 'parameter')

Type of posterior to return:

- "predictive": PredictivePosterior over f(X*) [for acquisitions]
- "parameter": ParameterPosterior over θ [for diagnostics]

"predictive"

Returns:

Type Description
PredictivePosterior | ParameterPosterior

Posterior distribution

Raises:

Type Description
RuntimeError

If model has not been fit yet

Examples:

>>> # For acquisition functions
>>> pred_post = model.posterior(X_candidates, probes=X_probes)
>>> mean = pred_post.mean
>>> var = pred_post.variance
>>> # For diagnostics
>>> param_post = model.posterior(kind="parameter")
>>> samples = param_post.sample(100, key=jr.PRNGKey(42))
Source code in src/psyphy/model/base.py
def posterior(
    self,
    X: jnp.ndarray | None = None,
    *,
    probes: jnp.ndarray | None = None,
    kind: str = "predictive",
) -> PredictivePosterior | ParameterPosterior:
    """
    Return posterior distribution.

    Parameters
    ----------
    X : jnp.ndarray | None
        Test stimuli (references), shape (n_test, input_dim).
        Required for predictive posteriors, optional for parameter posteriors.
    probes : jnp.ndarray | None
        Test probes, shape (n_test, input_dim).
        Required for predictive posteriors.
    kind : {"predictive", "parameter"}
        Type of posterior to return:
        - "predictive": PredictivePosterior over f(X*) [for acquisitions]
        - "parameter": ParameterPosterior over θ [for diagnostics]

    Returns
    -------
    PredictivePosterior | ParameterPosterior
        Posterior distribution

    Raises
    ------
    RuntimeError
        If model has not been fit yet

    Examples
    --------
    >>> # For acquisition functions
    >>> pred_post = model.posterior(X_candidates, probes=X_probes)
    >>> mean = pred_post.mean
    >>> var = pred_post.variance

    >>> # For diagnostics
    >>> param_post = model.posterior(kind="parameter")
    >>> samples = param_post.sample(100, key=jr.PRNGKey(42))
    """
    if self._posterior is None:
        raise RuntimeError("Must call fit() before posterior()")

    if kind == "parameter":
        return self._posterior
    elif kind == "predictive":
        if X is None:
            raise ValueError("X is required for predictive posteriors")
        from psyphy.posterior import WPPMPredictivePosterior

        return WPPMPredictivePosterior(self._posterior, X, probes=probes)
    else:
        raise ValueError(
            f"Unknown kind: '{kind}'. Use 'predictive' or 'parameter'."
        )

predict_prob

predict_prob(
    params: Params, stimulus: Stimulus, **task_kwargs: Any
) -> ndarray

Predict probability of a correct response for a single stimulus.

Design choice: WPPM computes discriminability & covariance; the TASK defines how that translates to performance. We therefore delegate to:

    task.predict(params, stimulus, model=self, noise=self.noise)

Parameters:

Name Type Description Default
params dict
required
stimulus (reference, probe)
required

Returns:

Name Type Description
p_correct ndarray
Source code in src/psyphy/model/wppm.py
def predict_prob(
    self, params: Params, stimulus: Stimulus, **task_kwargs: Any
) -> jnp.ndarray:
    """
    Predict probability of a correct response for a single stimulus.

    Design choice:
        WPPM computes discriminability & covariance; the TASK defines how
        that translates to performance. We therefore delegate to:
            task.predict(params, stimulus, model=self, noise=self.noise)

    Parameters
    ----------
    params : dict
    stimulus : (reference, probe)

    Returns
    -------
    p_correct : jnp.ndarray
    """
    # Strict task-owned configuration:
    # - MC control knobs (e.g. num_samples/bandwidth) live in the task config.
    # - WPPM.predict_prob therefore does not accept task-specific kwargs.
    if task_kwargs:
        unexpected = ", ".join(sorted(task_kwargs.keys()))
        raise TypeError(
            "WPPM.predict_prob does not accept task-specific kwargs. "
            "Configure task behavior via the task instance (e.g. OddityTaskConfig). "
            f"Unexpected: {unexpected}"
        )

    return self.task.predict(params, stimulus, self, self.noise)
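
A sketch of predicting p(correct) for one (reference, probe) pair at fixed parameters (continuing the sketches above); task-specific knobs must be set on the task instance, not passed here.

>>> ref = jnp.array([0.5, 0.5])
>>> probe = jnp.array([0.6, 0.5])
>>> p_correct = model.predict_prob(params, (ref, probe))  # delegates to task.predict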

predict_with_params

predict_with_params(
    X: ndarray,
    probes: ndarray | None,
    params: dict[str, ndarray],
) -> ndarray

Evaluate model at specific parameter values (no marginalization).

This is useful for:

- Threshold uncertainty estimation (evaluate at sampled parameters)
- Parameter sensitivity analysis
- Debugging and diagnostics

NOT for making predictions (use .posterior() instead, which marginalizes over parameter uncertainty).

Parameters:

Name Type Description Default
X (ndarray, shape(n_test, input_dim))

Test stimuli (references)

required
probes (ndarray, shape(n_test, input_dim))

Probe stimuli (for discrimination tasks)

required
params dict[str, ndarray]

Specific parameter values to evaluate at. Keys and shapes depend on the model (e.g., WPPM has "W", "noise_scale", etc.)

required

Returns:

Name Type Description
predictions (ndarray, shape(n_test))

Predicted probabilities at each test point, given these parameters

Examples:

>>> # Sample parameters and evaluate
>>> param_post = model.posterior(kind="parameter")
>>> samples = param_post.sample(100, key=jr.PRNGKey(0))
>>>
>>> # Evaluate at first parameter sample
>>> params_0 = {k: v[0] for k, v in samples.items()}
>>> predictions = model.predict_with_params(X_test, probes, params_0)
>>>
>>> # Use for threshold uncertainty estimation
>>> threshold_locs = []
>>> for i in range(100):
...     params_i = {k: v[i] for k, v in samples.items()}
...     preds_i = model.predict_with_params(X_grid, probes_grid, params_i)
...     threshold_idx = jnp.argmin(jnp.abs(preds_i - 0.75))
...     threshold_locs.append(threshold_idx)
Notes

This bypasses the posterior marginalization. For acquisition functions, always use .posterior() which properly accounts for parameter uncertainty.

Source code in src/psyphy/model/base.py
def predict_with_params(
    self,
    X: jnp.ndarray,
    probes: jnp.ndarray | None,
    params: dict[str, jnp.ndarray],
) -> jnp.ndarray:
    """
    Evaluate model at specific parameter values (no marginalization).

    This is useful for:
    - Threshold uncertainty estimation (evaluate at sampled parameters)
    - Parameter sensitivity analysis
    - Debugging and diagnostics

    NOT for making predictions (use .posterior() instead, which
    marginalizes over parameter uncertainty).

    Parameters
    ----------
    X : jnp.ndarray, shape (n_test, input_dim)
        Test stimuli (references)
    probes : jnp.ndarray, shape (n_test, input_dim), optional
        Probe stimuli (for discrimination tasks)
    params : dict[str, jnp.ndarray]
        Specific parameter values to evaluate at.
        Keys and shapes depend on the model (e.g., WPPM has "W", "noise_scale", etc.)

    Returns
    -------
    predictions : jnp.ndarray, shape (n_test,)
        Predicted probabilities at each test point, given these parameters

    Examples
    --------
    >>> # Sample parameters and evaluate
    >>> param_post = model.posterior(kind="parameter")
    >>> samples = param_post.sample(100, key=jr.PRNGKey(0))
    >>>
    >>> # Evaluate at first parameter sample
    >>> params_0 = {k: v[0] for k, v in samples.items()}
    >>> predictions = model.predict_with_params(X_test, probes, params_0)
    >>>
    >>> # Use for threshold uncertainty estimation
    >>> threshold_locs = []
    >>> for i in range(100):
    ...     params_i = {k: v[i] for k, v in samples.items()}
    ...     preds_i = model.predict_with_params(X_grid, probes_grid, params_i)
    ...     threshold_idx = jnp.argmin(jnp.abs(preds_i - 0.75))
    ...     threshold_locs.append(threshold_idx)

    Notes
    -----
    This bypasses the posterior marginalization. For acquisition functions,
    always use .posterior() which properly accounts for parameter uncertainty.
    """
    return self._forward(X, probes, params)