
psyphy.model

Model-layer API: everything model-related in one place.

Includes
  • WPPM (core model)
  • Priors (Prior)
  • Tasks (TaskLikelihood base, OddityTask)
  • Noise models (GaussianNoise, StudentTNoise)

All functions/classes use JAX arrays (jax.numpy as jnp) for autodiff and optimization with Optax.

Typical usage
from psyphy.model import WPPM, Prior, OddityTask, GaussianNoise

Classes:

Name            Description
GaussianNoise
Model           Abstract base class for psychophysical models.
OddityTask      Three-alternative forced-choice oddity task (MC-based only).
OnlineConfig    Configuration for online learning and memory management.
Prior           Prior distribution over WPPM parameters
StudentTNoise
TaskLikelihood  Abstract base class for task likelihoods
WPPM            Wishart Process Psychophysical Model (WPPM).

GaussianNoise

GaussianNoise(sigma: float = 1.0)

Methods:

Name             Description
log_prob
sample_standard  Sample from standard Gaussian (mean=0, var=1).

Attributes:

Name Type Description
sigma float

sigma

sigma: float = 1.0

log_prob

log_prob(residual: float) -> float
Source code in src/psyphy/model/noise.py
def log_prob(self, residual: float) -> float:
    _ = residual
    return -0.5
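
The listing above ignores `residual` and returns a constant. For orientation only, here is a minimal sketch of what a zero-mean Gaussian log-density with scale `sigma` computes. The function name `gaussian_log_prob` is hypothetical and is not part of psyphy.

```python
import math


def gaussian_log_prob(residual: float, sigma: float = 1.0) -> float:
    """Log-density of N(0, sigma^2) evaluated at `residual` (illustrative only)."""
    if sigma <= 0:
        raise ValueError(f"sigma must be > 0, got {sigma}")
    z = residual / sigma
    # Quadratic term, scale normalizer, and the constant 0.5 * log(2 * pi).
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


at_zero = gaussian_log_prob(0.0)
at_one = gaussian_log_prob(1.0)
```

The density is symmetric in `residual`, and moving one `sigma` away from zero lowers the log-density by exactly 0.5.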

sample_standard

sample_standard(
    key: Array, shape: tuple[int, ...]
) -> Array

Sample from standard Gaussian (mean=0, var=1).

Source code in src/psyphy/model/noise.py
def sample_standard(self, key: jax.Array, shape: tuple[int, ...]) -> jax.Array:
    """Sample from standard Gaussian (mean=0, var=1)."""
    return jr.normal(key, shape)

Model

Model(*, online_config: OnlineConfig | None = None)

Bases: ABC

Abstract base class for psychophysical models.

Provides API that mimics BoTorch style:
- fit(X, y) --> train model
- posterior(X) --> get predictions
- condition_on_observations(X, y) --> online updates

Subclasses must implement:
- init_params(key) --> sample initial parameters
- log_likelihood_from_data(params, data) --> compute likelihood
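
The two required methods can be illustrated with a toy subclass. This is a sketch under stated assumptions: `SketchModel` and `CoinModel` are hypothetical stand-ins written in plain Python, an integer seed replaces the JAX PRNG key, and the data is a bare list of 0/1 responses rather than a `ResponseData` object.

```python
import math
import random
from abc import ABC, abstractmethod


class SketchModel(ABC):
    """Hypothetical stand-in for the abstract interface described above."""

    @abstractmethod
    def init_params(self, key: int) -> dict: ...

    @abstractmethod
    def log_likelihood_from_data(self, params: dict, data: list) -> float: ...


class CoinModel(SketchModel):
    """Toy Bernoulli model with one parameter: the probability of a correct response."""

    def init_params(self, key: int) -> dict:
        rng = random.Random(key)  # integer seed stands in for a JAX PRNG key
        return {"p": rng.uniform(0.25, 0.75)}

    def log_likelihood_from_data(self, params: dict, data: list) -> float:
        p = params["p"]
        # Sum of per-trial Bernoulli log-probabilities.
        return sum(math.log(p) if y == 1 else math.log(1.0 - p) for y in data)


params = CoinModel().init_params(key=0)
ll = CoinModel().log_likelihood_from_data({"p": 0.5}, [1, 0, 1])
```

Instantiating `SketchModel` directly raises `TypeError`, which is how the abstract base enforces that both methods are provided.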

Parameters:

Name           Type                 Default  Description
online_config  OnlineConfig | None  None     Configuration for online learning. If None, uses default (unbounded memory).

Attributes:

Name               Type                       Description
_posterior         ParameterPosterior | None  Cached parameter posterior from last fit
_inference_engine  InferenceEngine | None     Cached inference engine for warm-start refitting
_data_buffer       ResponseData | None        Data buffer managed according to online_config
_n_updates         int                        Number of condition_on_observations calls
online_config      OnlineConfig               Online learning configuration

Initialize model.

Parameters:

Name           Type                 Default  Description
online_config  OnlineConfig | None  None     Online learning configuration. If None, uses default settings.

Methods:

Name                       Description
condition_on_observations  Update model with new observations (online learning).
fit                        Fit model to data.
init_params                Sample initial parameters from prior.
log_likelihood_from_data   Compute log p(data | params).
posterior                  Return posterior distribution.
predict_with_params        Evaluate model at specific parameter values (no marginalization).

Source code in src/psyphy/model/base.py
def __init__(self, *, online_config: OnlineConfig | None = None):
    """
    Initialize model.

    Parameters
    ----------
    online_config : OnlineConfig | None
        Online learning configuration. If None, uses default settings.
    """
    self._posterior: ParameterPosterior | None = None
    self._inference_engine: InferenceEngine | None = None
    self._data_buffer: ResponseData | None = None
    self._n_updates: int = 0
    self.online_config = online_config or OnlineConfig()

online_config

online_config = online_config or OnlineConfig()

condition_on_observations

condition_on_observations(X: ndarray, y: ndarray) -> Model

Update model with new observations (online learning).

Behavior depends on self.online_config.strategy:
- "full": Accumulate all data, refit periodically
- "sliding_window": Keep only recent window_size trials
- "reservoir": Random sampling of window_size trials
- "none": Refit from scratch (no caching)

Returns a NEW model instance (immutable update).
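
The two bounded-memory strategies can be sketched on plain Python lists. The helper names are hypothetical, and the reservoir variant shown is the classic Algorithm R, which is an assumption: this page does not show how psyphy's `_reservoir_update` is implemented. Both helpers assume `window_size > 0` and return new lists instead of mutating their inputs.

```python
import random


def sliding_window_update(buffer: list, new_trials: list, window_size: int) -> list:
    """Append new trials, then keep only the most recent `window_size`."""
    merged = buffer + new_trials
    return merged[-window_size:]


def reservoir_update(buffer, new_trials, window_size, n_seen, rng):
    """Algorithm R: every trial offered so far is retained with equal probability.

    `n_seen` counts the trials offered before this call; the updated count is
    returned alongside the new buffer.
    """
    buffer = list(buffer)  # copy so the caller's buffer is left untouched
    for trial in new_trials:
        n_seen += 1
        if len(buffer) < window_size:
            buffer.append(trial)
        else:
            j = rng.randrange(n_seen)  # uniform integer in [0, n_seen)
            if j < window_size:
                buffer[j] = trial
    return buffer, n_seen


window = sliding_window_update([1, 2, 3], [4, 5], window_size=3)
reservoir, n_seen = reservoir_update([], list(range(100)), 10, 0, random.Random(0))
```

The sliding window favors recent trials, while the reservoir keeps a uniform random subset of the whole history at the same memory cost.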

Parameters:

Name  Type     Default   Description
X     ndarray  required  New stimuli
y     ndarray  required  New responses

Returns:

Type   Description
Model  Updated model (new instance)

Examples:

>>> # Online learning loop
>>> model = WPPM(...).fit(X_init, y_init)
>>> for X_new, y_new in stream:
...     model = model.condition_on_observations(X_new, y_new)
...     # Model automatically manages memory and refitting
Source code in src/psyphy/model/base.py
def condition_on_observations(self, X: jnp.ndarray, y: jnp.ndarray) -> Model:
    """
    Update model with new observations (online learning).

    Behavior depends on self.online_config.strategy:
    - "full": Accumulate all data, refit periodically
    - "sliding_window": Keep only recent window_size trials
    - "reservoir": Random sampling of window_size trials
    - "none": Refit from scratch (no caching)

    Returns a NEW model instance (immutable update).

    Parameters
    ----------
    X : jnp.ndarray
        New stimuli
    y : jnp.ndarray
        New responses

    Returns
    -------
    Model
        Updated model (new instance)

    Examples
    --------
    >>> # Online learning loop
    >>> model = WPPM(...).fit(X_init, y_init)
    >>> for X_new, y_new in stream:
    ...     model = model.condition_on_observations(X_new, y_new)
    ...     # Model automatically manages memory and refitting
    """
    from psyphy.data import ResponseData

    # Convert new data
    new_data = ResponseData.from_arrays(X, y)

    # Update data buffer according to strategy
    if self.online_config.strategy == "none":
        data_to_fit = new_data

    elif self.online_config.strategy == "full":
        if self._data_buffer is None:
            self._data_buffer = ResponseData()
        self._data_buffer.merge(new_data)
        data_to_fit = self._data_buffer

    elif self.online_config.strategy == "sliding_window":
        if self._data_buffer is None:
            self._data_buffer = ResponseData()
        self._data_buffer.merge(new_data)

        # Keep only last N trials
        window_size = self.online_config.window_size
        assert window_size is not None, "window_size must be set for sliding_window"
        if len(self._data_buffer) > window_size:
            self._data_buffer = self._data_buffer.tail(window_size)
        data_to_fit = self._data_buffer

    elif self.online_config.strategy == "reservoir":
        if self._data_buffer is None:
            self._data_buffer = ResponseData()

        # Reservoir sampling
        window_size = self.online_config.window_size
        assert window_size is not None, "window_size must be set for reservoir"
        self._data_buffer = self._reservoir_update(
            self._data_buffer,
            new_data,
            window_size,
        )
        data_to_fit = self._data_buffer
    else:
        raise ValueError(f"Unknown strategy: {self.online_config.strategy}")

    # Decide whether to refit
    self._n_updates += 1
    should_refit = self._n_updates % self.online_config.refit_interval == 0

    if not should_refit:
        # Return clone with updated buffer but old posterior
        new_model = self._clone()
        new_model._data_buffer = data_to_fit
        return new_model

    # Refit with optional warm start
    inference = self._inference_engine
    assert inference is not None, (
        "Model must be fit before condition_on_observations"
    )

    if self.online_config.warm_start and self._posterior is not None:
        # TODO: Add warm-start support to inference engines
        # For now, just refit from cached params
        pass

    new_model = self._clone()
    new_model._data_buffer = data_to_fit
    new_model._posterior = inference.fit(new_model, data_to_fit)
    new_model._inference_engine = inference

    return new_model
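
The refit decision in the listing above reduces to a modulo check on the update counter. A minimal sketch, with `refit_schedule` as a hypothetical helper that is not part of psyphy, shows which calls would trigger a refit:

```python
def refit_schedule(n_calls: int, refit_interval: int) -> list:
    """Return, for each call in order, whether the modulo rule triggers a refit."""
    n_updates = 0
    decisions = []
    for _ in range(n_calls):
        n_updates += 1  # the counter is incremented before the check
        decisions.append(n_updates % refit_interval == 0)
    return decisions


schedule = refit_schedule(n_calls=6, refit_interval=3)
```

With `refit_interval=3`, only the third and sixth calls refit; the calls in between return a clone that carries the updated buffer and the previous posterior.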

fit

fit(
    X: ndarray,
    y: ndarray,
    *,
    inference: InferenceEngine | str = "laplace",
    inference_config: dict | None = None,
) -> Model

Fit model to data.

Parameters:

Name              Type                   Default    Description
X                 ndarray                required   Stimuli, shape (n_trials, 2, input_dim) for (ref, probe) pairs or (n_trials, input_dim) for references only
y                 ndarray                required   Responses, shape (n_trials,)
inference         InferenceEngine | str  "laplace"  Inference engine or string key ("map", "laplace", "langevin")
inference_config  dict | None            None       Hyperparameters for string-based inference. Examples: {"steps": 500, "lr": 1e-3} for MAP

Returns:

Type   Description
Model  Self for method chaining

Examples:

>>> # Simple: use defaults
>>> model.fit(X, y)

>>> # Explicit optimizer
>>> from psyphy.inference import MAPOptimizer
>>> model.fit(X, y, inference=MAPOptimizer(steps=500))

>>> # String + config (for experiment tracking)
>>> model.fit(X, y, inference="map", inference_config={"steps": 500})
Source code in src/psyphy/model/base.py
def fit(
    self,
    X: jnp.ndarray,
    y: jnp.ndarray,
    *,
    inference: InferenceEngine | str = "laplace",
    inference_config: dict | None = None,
) -> Model:
    """
    Fit model to data.

    Parameters
    ----------
    X : jnp.ndarray
        Stimuli, shape (n_trials, 2, input_dim) for (ref, probe) pairs
        or (n_trials, input_dim) for references only
    y : jnp.ndarray
        Responses, shape (n_trials,)
    inference : InferenceEngine | str, default="laplace"
        Inference engine or string key ("map", "laplace", "langevin")
    inference_config : dict | None
        Hyperparameters for string-based inference.
        Examples: {"steps": 500, "lr": 1e-3} for MAP

    Returns
    -------
    Model
        Self for method chaining

    Examples
    --------
    >>> # Simple: use defaults
    >>> model.fit(X, y)

    >>> # Explicit optimizer
    >>> from psyphy.inference import MAPOptimizer
    >>> model.fit(X, y, inference=MAPOptimizer(steps=500))

    >>> # String + config (for experiment tracking)
    >>> model.fit(X, y, inference="map", inference_config={"steps": 500})
    """
    from psyphy.data import ResponseData
    from psyphy.inference import INFERENCE_ENGINES, InferenceEngine

    # Resolve inference engine
    is_string_inference = isinstance(inference, str)

    if is_string_inference:
        config = inference_config or {}
        inference_key: str = inference  # type: ignore[assignment]
        if inference_key not in INFERENCE_ENGINES:
            available = ", ".join(INFERENCE_ENGINES.keys())
            raise ValueError(
                f"Unknown inference: '{inference}'. Available: {available}"
            )
        inference_engine: InferenceEngine = INFERENCE_ENGINES[inference_key](
            **config
        )
    elif isinstance(inference, InferenceEngine):
        inference_engine = inference
    else:
        raise TypeError(
            f"inference must be InferenceEngine or str, got {type(inference)}"
        )

    if inference_config is not None and not is_string_inference:
        raise ValueError(
            "Cannot pass inference_config with InferenceEngine instance"
        )

    # Convert data
    data = ResponseData.from_arrays(X, y)

    # Fit
    self._posterior = inference_engine.fit(self, data)
    self._inference_engine = inference_engine
    self._data_buffer = data  # Initialize buffer
    return self
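
The engine-resolution step of `fit` can be isolated as a small registry lookup. The sketch below is self-contained and uses hypothetical names (`SketchMAP`, `SKETCH_ENGINES`, `resolve_engine`); the contents of the real `INFERENCE_ENGINES` mapping are not shown on this page and may differ.

```python
from dataclasses import dataclass


@dataclass
class SketchMAP:
    """Hypothetical engine; field names mirror the {"steps", "lr"} example above."""

    steps: int = 100
    lr: float = 1e-2


# Hypothetical registry standing in for INFERENCE_ENGINES.
SKETCH_ENGINES = {"map": SketchMAP}


def resolve_engine(inference, inference_config=None):
    """Turn a string key plus config, or a ready-made engine, into an engine."""
    if isinstance(inference, str):
        if inference not in SKETCH_ENGINES:
            available = ", ".join(SKETCH_ENGINES)
            raise ValueError(f"Unknown inference: '{inference}'. Available: {available}")
        # The config dict becomes keyword arguments of the engine constructor.
        return SKETCH_ENGINES[inference](**(inference_config or {}))
    if inference_config is not None:
        raise ValueError("Cannot pass inference_config with an engine instance")
    return inference


engine = resolve_engine("map", {"steps": 500})
```

Keeping the string form alongside the instance form lets an experiment be described entirely by serializable values, which is the use case the `inference_config` example points at.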

init_params

init_params(key: Any) -> dict

Sample initial parameters from prior.

Parameters:

Name  Type      Default   Description
key   KeyArray  required  PRNG key

Returns:

Type  Description
dict  Parameter PyTree

Source code in src/psyphy/model/base.py
@abstractmethod
def init_params(self, key: Any) -> dict:  # jax.random.KeyArray
    """
    Sample initial parameters from prior.

    Parameters
    ----------
    key : jax.random.KeyArray
        PRNG key

    Returns
    -------
    dict
        Parameter PyTree
    """
    ...

log_likelihood_from_data

log_likelihood_from_data(
    params: dict, data: ResponseData
) -> ndarray

Compute log p(data | params).

Parameters:

Name    Type          Default   Description
params  dict          required  Model parameters
data    ResponseData  required  Observed trials

Returns:

Type     Description
ndarray  Log-likelihood (scalar)

Source code in src/psyphy/model/base.py
@abstractmethod
def log_likelihood_from_data(self, params: dict, data: ResponseData) -> jnp.ndarray:
    """
    Compute log p(data | params).

    Parameters
    ----------
    params : dict
        Model parameters
    data : ResponseData
        Observed trials

    Returns
    -------
    jnp.ndarray
        Log-likelihood (scalar)
    """
    ...

posterior

posterior(
    X: ndarray | None = None,
    *,
    probes: ndarray | None = None,
    kind: str = "predictive",
) -> PredictivePosterior | ParameterPosterior

Return posterior distribution.

Parameters:

Name    Type                         Default       Description
X       ndarray | None               None          Test stimuli (references), shape (n_test, input_dim). Required for predictive posteriors, optional for parameter posteriors.
probes  ndarray | None               None          Test probes, shape (n_test, input_dim). Required for predictive posteriors.
kind    {"predictive", "parameter"}  "predictive"  Type of posterior to return: "predictive": PredictivePosterior over f(X*) [for acquisitions]; "parameter": ParameterPosterior over θ [for diagnostics]

Returns:

Type                                      Description
PredictivePosterior | ParameterPosterior  Posterior distribution

Raises:

Type          Description
RuntimeError  If model has not been fit yet

Examples:

>>> # For acquisition functions
>>> pred_post = model.posterior(X_candidates, probes=X_probes)
>>> mean = pred_post.mean
>>> var = pred_post.variance

>>> # For diagnostics
>>> param_post = model.posterior(kind="parameter")
>>> samples = param_post.sample(100, key=jr.PRNGKey(42))
Source code in src/psyphy/model/base.py
def posterior(
    self,
    X: jnp.ndarray | None = None,
    *,
    probes: jnp.ndarray | None = None,
    kind: str = "predictive",
) -> PredictivePosterior | ParameterPosterior:
    """
    Return posterior distribution.

    Parameters
    ----------
    X : jnp.ndarray | None
        Test stimuli (references), shape (n_test, input_dim).
        Required for predictive posteriors, optional for parameter posteriors.
    probes : jnp.ndarray | None
        Test probes, shape (n_test, input_dim).
        Required for predictive posteriors.
    kind : {"predictive", "parameter"}
        Type of posterior to return:
        - "predictive": PredictivePosterior over f(X*) [for acquisitions]
        - "parameter": ParameterPosterior over θ [for diagnostics]

    Returns
    -------
    PredictivePosterior | ParameterPosterior
        Posterior distribution

    Raises
    ------
    RuntimeError
        If model has not been fit yet

    Examples
    --------
    >>> # For acquisition functions
    >>> pred_post = model.posterior(X_candidates, probes=X_probes)
    >>> mean = pred_post.mean
    >>> var = pred_post.variance

    >>> # For diagnostics
    >>> param_post = model.posterior(kind="parameter")
    >>> samples = param_post.sample(100, key=jr.PRNGKey(42))
    """
    if self._posterior is None:
        raise RuntimeError("Must call fit() before posterior()")

    if kind == "parameter":
        return self._posterior
    elif kind == "predictive":
        if X is None:
            raise ValueError("X is required for predictive posteriors")
        from psyphy.posterior import WPPMPredictivePosterior

        return WPPMPredictivePosterior(self._posterior, X, probes=probes)
    else:
        raise ValueError(
            f"Unknown kind: '{kind}'. Use 'predictive' or 'parameter'."
        )

predict_with_params

predict_with_params(
    X: ndarray,
    probes: ndarray | None,
    params: dict[str, ndarray],
) -> ndarray

Evaluate model at specific parameter values (no marginalization).

This is useful for:
- Threshold uncertainty estimation (evaluate at sampled parameters)
- Parameter sensitivity analysis
- Debugging and diagnostics

NOT for making predictions (use .posterior() instead, which marginalizes over parameter uncertainty).

Parameters:

Name    Type                                Default   Description
X       ndarray, shape (n_test, input_dim)  required  Test stimuli (references)
probes  ndarray, shape (n_test, input_dim)  required  Probe stimuli (for discrimination tasks)
params  dict[str, ndarray]                  required  Specific parameter values to evaluate at. Keys and shapes depend on the model (e.g., WPPM has "W", "noise_scale", etc.)

Returns:

Name         Type                      Description
predictions  ndarray, shape (n_test,)  Predicted probabilities at each test point, given these parameters

Examples:

>>> # Sample parameters and evaluate
>>> param_post = model.posterior(kind="parameter")
>>> samples = param_post.sample(100, key=jr.PRNGKey(0))
>>>
>>> # Evaluate at first parameter sample
>>> params_0 = {k: v[0] for k, v in samples.items()}
>>> predictions = model.predict_with_params(X_test, probes, params_0)
>>>
>>> # Use for threshold uncertainty estimation
>>> threshold_locs = []
>>> for i in range(100):
...     params_i = {k: v[i] for k, v in samples.items()}
...     preds_i = model.predict_with_params(X_grid, probes_grid, params_i)
...     threshold_idx = jnp.argmin(jnp.abs(preds_i - 0.75))
...     threshold_locs.append(threshold_idx)
Notes

This bypasses the posterior marginalization. For acquisition functions, always use .posterior() which properly accounts for parameter uncertainty.
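
The difference between a point evaluation and a marginalized prediction can be shown with a toy model. This is a sketch under assumptions: the one-parameter logistic function and the helper `marginal_prediction` are hypothetical, and averaging over parameter samples is one common way to marginalize, not necessarily what psyphy's predictive posterior does internally.

```python
import math


def predict_with_params(x: float, params: dict) -> float:
    """Toy psychometric function: logistic in x with slope params["slope"]."""
    return 1.0 / (1.0 + math.exp(-params["slope"] * x))


def marginal_prediction(x: float, samples: list) -> tuple:
    """Mean and variance of the prediction across parameter samples."""
    preds = [predict_with_params(x, p) for p in samples]
    mean = sum(preds) / len(preds)
    var = sum((p - mean) ** 2 for p in preds) / len(preds)
    return mean, var


samples = [{"slope": s} for s in (0.5, 1.0, 2.0)]
point = predict_with_params(1.0, samples[0])  # one parameter draw, no averaging
mean, var = marginal_prediction(1.0, samples)  # averaged over all draws
```

A single draw gives one plausible curve; the spread across draws is what carries the parameter uncertainty that acquisition functions rely on.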

Source code in src/psyphy/model/base.py
def predict_with_params(
    self,
    X: jnp.ndarray,
    probes: jnp.ndarray | None,
    params: dict[str, jnp.ndarray],
) -> jnp.ndarray:
    """
    Evaluate model at specific parameter values (no marginalization).

    This is useful for:
    - Threshold uncertainty estimation (evaluate at sampled parameters)
    - Parameter sensitivity analysis
    - Debugging and diagnostics

    NOT for making predictions (use .posterior() instead, which
    marginalizes over parameter uncertainty).

    Parameters
    ----------
    X : jnp.ndarray, shape (n_test, input_dim)
        Test stimuli (references)
    probes : jnp.ndarray, shape (n_test, input_dim), optional
        Probe stimuli (for discrimination tasks)
    params : dict[str, jnp.ndarray]
        Specific parameter values to evaluate at.
        Keys and shapes depend on the model (e.g., WPPM has "W", "noise_scale", etc.)

    Returns
    -------
    predictions : jnp.ndarray, shape (n_test,)
        Predicted probabilities at each test point, given these parameters

    Examples
    --------
    >>> # Sample parameters and evaluate
    >>> param_post = model.posterior(kind="parameter")
    >>> samples = param_post.sample(100, key=jr.PRNGKey(0))
    >>>
    >>> # Evaluate at first parameter sample
    >>> params_0 = {k: v[0] for k, v in samples.items()}
    >>> predictions = model.predict_with_params(X_test, probes, params_0)
    >>>
    >>> # Use for threshold uncertainty estimation
    >>> threshold_locs = []
    >>> for i in range(100):
    ...     params_i = {k: v[i] for k, v in samples.items()}
    ...     preds_i = model.predict_with_params(X_grid, probes_grid, params_i)
    ...     threshold_idx = jnp.argmin(jnp.abs(preds_i - 0.75))
    ...     threshold_locs.append(threshold_idx)

    Notes
    -----
    This bypasses the posterior marginalization. For acquisition functions,
    always use .posterior() which properly accounts for parameter uncertainty.
    """
    return self._forward(X, probes, params)

OddityTask

OddityTask(config: OddityTaskConfig | None = None)

Bases: TaskLikelihood

Three-alternative forced-choice oddity task (MC-based only).

Implements the full 3-stimulus oddity task using Monte Carlo simulation:
- Samples three internal representations per trial (z0, z1, z2)
- Uses proper oddity decision rule with three pairwise distances
- Suitable for complex covariance structures

Notes

MC simulation in loglik() (full 3-stimulus oddity):
1. Sample three internal representations: z_ref, z_refprime ~ N(ref, Σ_ref), z_comparison ~ N(comparison, Σ_comparison)
2. Compute average covariance: Σ_avg = (2/3) Σ_ref + (1/3) Σ_comparison
3. Compute three pairwise Mahalanobis distances:
   - d^2(z_ref, z_refprime) = distance between two reference samples
   - d^2(z_ref, z_comparison) = distance from ref to comparison
   - d^2(z_refprime, z_comparison) = distance from reference_prime to comparison
4. Apply oddity decision rule: delta = min(d^2(z_ref,z_comparison), d^2(z_refprime,z_comparison)) - d^2(z_ref,z_refprime)
5. Logistic smoothing: P(correct) ≈ logistic.cdf(delta / bandwidth)
6. Average over samples
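
The six steps above can be sketched for a single trial as follows. This is an illustration under stated assumptions, not psyphy's implementation: it uses NumPy rather than JAX, takes explicit covariance matrices instead of deriving them from model parameters, and replaces the PRNG key with an integer seed. The name `oddity_p_correct` is hypothetical.

```python
import numpy as np


def oddity_p_correct(ref, comparison, cov_ref, cov_cmp, num_samples, bandwidth, seed=0):
    """Monte Carlo estimate of P(correct) for one (ref, comparison) trial."""
    rng = np.random.default_rng(seed)
    # Step 1: two draws around the reference, one around the comparison.
    z_ref = rng.multivariate_normal(ref, cov_ref, size=num_samples)
    z_refprime = rng.multivariate_normal(ref, cov_ref, size=num_samples)
    z_cmp = rng.multivariate_normal(comparison, cov_cmp, size=num_samples)
    # Step 2: average covariance, weighted by how often each stimulus appears.
    cov_avg_inv = np.linalg.inv((2.0 / 3.0) * cov_ref + (1.0 / 3.0) * cov_cmp)

    def d2(a, b):
        # Squared Mahalanobis distance for each of the num_samples rows.
        diff = a - b
        return np.einsum("ni,ij,nj->n", diff, cov_avg_inv, diff)

    # Steps 3-4: pairwise distances and the oddity decision variable.
    delta = np.minimum(d2(z_ref, z_cmp), d2(z_refprime, z_cmp)) - d2(z_ref, z_refprime)
    # Steps 5-6: logistic smoothing (written via tanh to avoid overflow), then averaging.
    return float(np.mean(0.5 * (1.0 + np.tanh(0.5 * delta / bandwidth))))


cov = 0.01 * np.eye(2)
p_far = oddity_p_correct(np.zeros(2), np.array([1.0, 1.0]), cov, cov, 2000, 1e-2)
p_same = oddity_p_correct(np.zeros(2), np.zeros(2), cov, cov, 2000, 1e-2)
```

With identical reference and comparison the estimate should sit near the 1/3 chance level of a three-alternative task, and it approaches 1 as the comparison moves away from the reference.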

Examples:

>>> from psyphy.model.task import OddityTask
>>> from psyphy.model.task import OddityTaskConfig
>>> from psyphy.model import WPPM, Prior
>>> from psyphy.model.noise import GaussianNoise
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>>
>>> # Create task and model (task-owned MC controls)
>>> task = OddityTask(config=OddityTaskConfig(num_samples=1000, bandwidth=1e-2))
>>> model = WPPM(
...     input_dim=2, prior=Prior(input_dim=2), task=task, noise=GaussianNoise()
... )
>>> params = model.init_params(jr.PRNGKey(0))
>>> # MC simulation
>>> from psyphy.data.dataset import ResponseData
>>> ref, comparison = jnp.array([0.0, 0.0]), jnp.array([0.3, 0.2])
>>> data = ResponseData()
>>> data.add_trial(ref, comparison, resp=1)
>>> ll_mc = task.loglik(params, data, model, model.noise, key=jr.PRNGKey(42))
>>> print(f"Log-likelihood (MC): {ll_mc:.4f}")

Methods:

Name     Description
loglik   Compute log-likelihood via Monte Carlo observer simulation.
predict  Predict p(correct) for a single (ref, comparison) stimulus.

Attributes:

Name Type Description
config
Source code in src/psyphy/model/task.py
def __init__(self, config: OddityTaskConfig | None = None) -> None:
    # No analytical parameters in MC-only mode.
    self.config = config or OddityTaskConfig()

config

config = config or OddityTaskConfig()

loglik

loglik(
    params: Any,
    data: Any,
    model: Any,
    noise: Any,
    **kwargs: Any,
) -> ndarray
Compute log-likelihood via Monte Carlo observer simulation.

This method implements the FULL 3-stimulus oddity task. Instead of using
an analytical approximation, we:
1. Sample three internal noisy representations per trial:
   - z_ref, z_refprime ~ N(ref, Σ_ref)  [two samples from reference]
   - z_comparison ~ N(comparison, Σ_comparison)           [one sample from comparison]
2. Compute three pairwise Mahalanobis distances
3. Apply oddity decision rule: comparison is odd if it's farther from BOTH ref and reference_prime
4. Apply logistic smoothing to approximate P(correct)
5. Average over MC samples
Parameters
params : Any
    Model parameters as expected by ``model._compute_sqrt``.
data : ResponseData
    Trial data with refs, comparisons, and responses
model : WPPM
    Model instance providing ``_compute_sqrt`` for covariance computation.
noise : NoiseModel
    Observer noise model (provides ``sample_standard``).
key : jax.random.PRNGKey, optional
    Random key for reproducible sampling.
    If None, uses ``OddityTaskConfig.default_key_seed``.
Returns
jnp.ndarray
    Scalar sum of log-likelihoods over all trials.
    Same shape and interpretation as ``loglik``.
Raises
TypeError
    If ``num_samples`` or ``bandwidth`` are provided as kwargs.
ValueError
    If the task configuration is invalid (e.g. ``num_samples <= 0``).
Notes
Monte Carlo controls (``num_samples``, ``bandwidth``) are owned by the
task configuration:

- Create the task with ``OddityTask(config=OddityTaskConfig(...))``.
- Pass only the PRNG ``key`` at call time when you want to control
  randomness.
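
The ownership rule above, where the config holds the Monte Carlo controls and only `key` varies per call, can be sketched in isolation. `SketchTaskConfig` and `sketch_loglik` are hypothetical names, and the default values are made up for illustration; only the field names follow the text.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SketchTaskConfig:
    """Hypothetical config; field names follow the text, defaults are made up."""

    num_samples: int = 1000
    bandwidth: float = 1e-2


def sketch_loglik(config: SketchTaskConfig, **kwargs) -> dict:
    """Accept only `key` at call time and reject per-call MC overrides."""
    key = kwargs.pop("key", None)
    if "num_samples" in kwargs or "bandwidth" in kwargs:
        raise TypeError("Configure num_samples/bandwidth on the task config instead")
    if kwargs:
        raise TypeError(f"Unexpected keyword arguments: {', '.join(sorted(kwargs))}")
    if config.num_samples <= 0 or config.bandwidth <= 0:
        raise ValueError("num_samples and bandwidth must be > 0")
    # A real implementation would run the simulation; here we echo the settings.
    return {"num_samples": config.num_samples, "bandwidth": config.bandwidth, "key": key}


settings = sketch_loglik(SketchTaskConfig(num_samples=5000), key=42)
```

Rejecting overrides at call time means every likelihood evaluation made with a given task uses the same fidelity and smoothing.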
Notes
**Full 3-stimulus oddity task algorithm:**

For each trial (ref, comparison, response):
1. Compute covariances:
   - Σ_ref = U_ref @ U_ref.T + σ^2 I
   - Σ_comparison = U_comparison @ U_comparison.T + σ^2 I
   - Σ_avg = (2/3) Σ_ref + (1/3) Σ_comparison  [weighted by stimulus frequency]

2. Sample three internal representations:
   - z_ref, z_refprime ~ N(ref, Σ_ref)  [2 samples from reference, num_samples times each]
   - z_comparison ~ N(comparison, Σ_comparison)           [1 sample from comparison, num_samples times]

3. Compute three pairwise Mahalanobis distances:
   - d^2(z_ref, z_refprime) = (z_ref - z_refprime).T @ Σ_avg^{-1} @ (z_ref - z_refprime)  [ref vs reference_prime]
   - d^2(z_ref, z_comparison) = (z_ref - z_comparison).T @ Σ_avg^{-1} @ (z_ref - z_comparison)  [ref vs comparison]
   - d^2(z_refprime, z_comparison) = (z_refprime - z_comparison).T @ Σ_avg^{-1} @ (z_refprime - z_comparison)  [reference_prime vs comparison]

4. Apply oddity decision rule:
   - delta = min(d^2(z_ref,z_comparison), d^2(z_refprime,z_comparison)) - d^2(z_ref,z_refprime)
   - delta > 0 means comparison is farther from BOTH ref and reference_prime -> correct identification

5. Apply logistic smoothing:
   - P(correct) ≈ mean(logistic.cdf(delta / bandwidth))

6. Bernoulli log-likelihood:
   - LL = Σ [y * log(p) + (1-y) * log(1-p)]

Performance:
- Memory: O(num_samples * input_dim) per trial
- Vectorized across trials using jax.vmap for GPU acceleration
- Can be JIT-compiled for additional speed (future optimization)
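
Step 6 can be sketched as a small standalone function. The clipping bound `eps` is a hypothetical value chosen for illustration; the source only states that probabilities are clipped to [eps, 1-eps] before the logarithm is taken.

```python
import math


def bernoulli_loglik(probs, responses, eps=1e-6):
    """Sum of y*log(p) + (1-y)*log(1-p), with p clipped to [eps, 1-eps]."""
    total = 0.0
    for p, y in zip(probs, responses):
        p = min(max(p, eps), 1.0 - eps)  # keeps both logarithms finite
        total += math.log(p) if y == 1 else math.log(1.0 - p)
    return total


ll = bernoulli_loglik([0.9, 0.4, 1.0], [1, 0, 1])
worst_case = bernoulli_loglik([0.0], [1])  # finite because of the clipping
```

Without the clipping, a Monte Carlo estimate of exactly 0 or 1 paired with the opposite response would yield a log-likelihood of negative infinity.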
Examples
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> from psyphy.model import WPPM, Prior
>>> from psyphy.model.task import OddityTask, OddityTaskConfig
>>> from psyphy.model.noise import GaussianNoise
>>> from psyphy.data.dataset import ResponseData
>>>
>>> # Setup: MC controls are set on the task config, not at call time
>>> model = WPPM(
...     input_dim=2,
...     prior=Prior(input_dim=2, basis_degree=3),
...     task=OddityTask(config=OddityTaskConfig(num_samples=5000, bandwidth=1e-3)),
...     noise=GaussianNoise(sigma=0.03),
... )
>>> params = model.init_params(jr.PRNGKey(0))
>>>
>>> # Create trial data
>>> data = ResponseData()
>>> data.add_trial(
...     ref=jnp.array([0.0, 0.0]), comparison=jnp.array([0.3, 0.2]), resp=1
... )
>>>
>>> # Only the PRNG key is passed at call time
>>> loglik = model.task.loglik(
...     params,
...     data,
...     model,
...     model.noise,
...     key=jr.PRNGKey(42),
... )
>>> print(f"MC (N=5000): {loglik:.4f}")
Source code in src/psyphy/model/task.py
def loglik(
    self, params: Any, data: Any, model: Any, noise: Any, **kwargs: Any
) -> jnp.ndarray:
    """
        Compute log-likelihood via Monte Carlo observer simulation.

        This method implements the FULL 3-stimulus oddity task. Instead of using
        an analytical approximation, we:
        1. Sample three internal noisy representations per trial:
           - z_ref, z_refprime ~ N(ref, Σ_ref)  [two samples from reference]
           - z_comparison ~ N(comparison, Σ_comparison)           [one sample from comparison]
        2. Compute three pairwise Mahalanobis distances
        3. Apply oddity decision rule: comparison is odd if it's farther from BOTH ref and reference_prime
        4. Apply logistic smoothing to approximate P(correct)
        5. Average over MC samples

        Parameters
        ----------
        params : Any
            Model parameters as expected by ``model._compute_sqrt``.
        data : ResponseData
            Trial data with refs, comparisons, and responses
        model : WPPM
            Model instance providing ``_compute_sqrt`` for covariance computation.
        noise : NoiseModel
            Observer noise model (provides ``sample_standard``).
        key : jax.random.PRNGKey, optional
            Random key for reproducible sampling.
            If None, uses ``OddityTaskConfig.default_key_seed``.

        Returns
        -------
        jnp.ndarray
            Scalar sum of log-likelihoods over all trials.
            Same shape and interpretation as ``loglik``.

        Raises
        ------
        TypeError
            If ``num_samples`` or ``bandwidth`` are provided as kwargs.
        ValueError
            If the task configuration is invalid (e.g. ``num_samples <= 0``).

        Notes
        -----
        Monte Carlo controls (``num_samples``, ``bandwidth``) are owned by the
        task configuration:

        - Create the task with ``OddityTask(config=OddityTaskConfig(...))``.
        - Pass only the PRNG ``key`` at call time when you want to control
          randomness.

        Notes
        -----
        **Full 3-stimulus oddity task algorithm:**

        For each trial (ref, comparison, response):
        1. Compute covariances:
           - Σ_ref = U_ref @ U_ref.T + σ^2 I
           - Σ_comparison = U_comparison @ U_comparison.T + σ^2 I
           - Σ_avg = (2/3) Σ_ref + (1/3) Σ_comparison  [weighted by stimulus frequency]

        2. Sample three internal representations:
           - z_ref, z_refprime ~ N(ref, Σ_ref)  [2 samples from reference, num_samples times each]
           - z_comparison ~ N(comparison, Σ_comparison)           [1 sample from comparison, num_samples times]

        3. Compute three pairwise Mahalanobis distances:
           - d^2(z_ref, z_refprime) = (z_ref - z_refprime).T @ Σ_avg^{-1} @ (z_ref - z_refprime)  [ref vs reference_prime]
           - d^2(z_ref, z_comparison) = (z_ref - z_comparison).T @ Σ_avg^{-1} @ (z_ref - z_comparison)  [ref vs comparison]
           - d^2(z_refprime, z_comparison) = (z_refprime - z_comparison).T @ Σ_avg^{-1} @ (z_refprime - z_comparison)  [reference_prime vs comparison]

        4. Apply oddity decision rule:
           - delta = min(d^2(z_ref,z_comparison), d^2(z_refprime,z_comparison)) - d^2(z_ref,z_refprime)
           - delta > 0 means comparison is farther from BOTH ref and reference_prime -> correct identification

        5. Apply logistic smoothing:
           - P(correct) ≈ mean(logistic.cdf(delta / bandwidth))

        6. Bernoulli log-likelihood:
           - LL = Σ [y * log(p) + (1-y) * log(1-p)]

        Performance:
        - Memory: O(num_samples * input_dim) per trial
        - Vectorized across trials using jax.vmap for GPU acceleration
        - Can be JIT-compiled for additional speed (future optimization)

        Examples
        --------
        >>> import jax.numpy as jnp
        >>> import jax.random as jr
        >>> from psyphy.model import WPPM, Prior
        >>> from psyphy.model.task import OddityTask, OddityTaskConfig
        >>> from psyphy.model.noise import GaussianNoise
        >>> from psyphy.data.dataset import ResponseData
        >>>
        >>> # Setup: MC controls are set on the task config, not at call time
        >>> model = WPPM(
        ...     input_dim=2,
        ...     prior=Prior(input_dim=2, basis_degree=3),
        ...     task=OddityTask(config=OddityTaskConfig(num_samples=5000, bandwidth=1e-3)),
        ...     noise=GaussianNoise(sigma=0.03),
        ... )
        >>> params = model.init_params(jr.PRNGKey(0))
        >>>
        >>> # Create trial data
        >>> data = ResponseData()
        >>> data.add_trial(
        ...     ref=jnp.array([0.0, 0.0]), comparison=jnp.array([0.3, 0.2]), resp=1
        ... )
        >>>
        >>> # Only the PRNG key is passed at call time
        >>> loglik = model.task.loglik(
        ...     params,
        ...     data,
        ...     model,
        ...     model.noise,
        ...     key=jr.PRNGKey(42),
        ... )
        >>> print(f"MC (N=5000): {loglik:.4f}")


    """
    # Task is the single source of truth for MC controls.
    num_samples = int(self.config.num_samples)
    bandwidth = float(self.config.bandwidth)

    # Only PRNG key is accepted dynamically.
    key = kwargs.pop("key", None)
    if "num_samples" in kwargs or "bandwidth" in kwargs:
        raise TypeError(
            "OddityTask.loglik does not accept 'num_samples'/'bandwidth' overrides. "
            "Configure them via OddityTaskConfig when constructing the task."
        )
    if kwargs:
        unexpected = ", ".join(sorted(kwargs.keys()))
        raise TypeError(
            f"Unexpected keyword arguments for OddityTask.loglik: {unexpected}"
        )

    if num_samples <= 0:
        raise ValueError(f"num_samples must be > 0, got {num_samples}")
    if bandwidth <= 0:
        raise ValueError(f"bandwidth must be > 0, got {bandwidth}")

    if key is None:
        key = jr.PRNGKey(int(self.config.default_key_seed))

    # Unpack trial data
    refs, comparisons, responses = data.to_numpy()
    n_trials = len(refs)

    # Split keys for each trial (ensures independent sampling)
    trial_keys = jr.split(key, n_trials)

    # Vectorized computation of P(correct) for all trials
    # This processes all trials in parallel using jax.vmap
    # Note: probabilities are already clipped in _simulate_trial_mc()
    probs = self._simulate_trials_mc_vectorized(
        params=params,
        refs=refs,
        comparisons=comparisons,
        model=model,
        noise=noise,
        num_samples=num_samples,
        bandwidth=bandwidth,
        trial_keys=trial_keys,
    )

    # Bernoulli log-likelihood: LL = Σ [y log(p) + (1-y) log(1-p)]
    # Probabilities are already clipped to [eps, 1-eps] so log is safe
    log_likelihoods = jnp.where(
        responses == 1,
        jnp.log(probs),  # Correct response
        jnp.log(1.0 - probs),  # Incorrect response
    )

    return jnp.sum(log_likelihoods)
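
The last step of the listing above is an ordinary Bernoulli log-likelihood over per-trial probabilities. The following is a minimal standalone sketch in plain Python, not the library's implementation: the helper name `bernoulli_loglik` and the `eps` value are illustrative assumptions, and the clipping is done explicitly here, whereas the source states that clipping happens inside the simulator.

```python
import math


def bernoulli_loglik(probs, responses, eps=1e-6):
    """Sum of y*log(p) + (1-y)*log(1-p), with p clipped to [eps, 1-eps]."""
    total = 0.0
    for p, y in zip(probs, responses):
        p = min(max(p, eps), 1.0 - eps)  # keeps both log() calls finite
        total += math.log(p) if y == 1 else math.log(1.0 - p)
    return total


probs = [0.9, 0.6, 1.0]  # hypothetical P(correct) per trial
responses = [1, 0, 1]    # 1 = correct, 0 = incorrect
ll = bernoulli_loglik(probs, responses)
```

Without the clipping, the third trial (p = 1.0) would be harmless, but a trial with p = 0 and a correct response would give log(0).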

predict

predict(
    params: Any, stimuli: Stimulus, model: Any, noise: Any
) -> ndarray

Predict p(correct) for a single (ref, comparison) stimulus.

Even though OddityTask is MC-only, we still implement predict. Reason: large parts of the library (posterior predictive, acquisition functions, diagnostics, etc.) need a forward model that returns p(correct) at candidate stimuli. Historically this used an analytical approximation, but in MC-only mode we compute it via simulation.

Notes
  • This method is intentionally lightweight: it performs the same single-trial Monte Carlo simulation used by loglik.
  • If you need to control MC fidelity/smoothing, set OddityTaskConfig(num_samples=..., bandwidth=...) when you construct the task.
  • If you need reproducible randomness, pass key=... to loglik.
Source code in src/psyphy/model/task.py
def predict(
    self, params: Any, stimuli: Stimulus, model: Any, noise: Any
) -> jnp.ndarray:
    """Predict p(correct) for a single (ref, comparison) stimulus.

    Even though OddityTask is *MC-only*, we still implement ``predict``.
    Reason: large parts of the library (posterior predictive, acquisition
    functions, diagnostics, etc.) need a forward model that returns
    p(correct) at candidate stimuli. Historically this used an analytical
    approximation, but in MC-only mode we compute it via simulation.

    Notes
    -----
    - This method is intentionally lightweight: it performs the same
      single-trial Monte Carlo simulation used by ``loglik``.
    - If you need to control MC fidelity/smoothing, set
      ``OddityTaskConfig(num_samples=..., bandwidth=...)`` when you
      construct the task.
    - If you need reproducible randomness, pass ``key=...`` to ``loglik``.
    """

    num_samples = int(self.config.num_samples)
    bandwidth = float(self.config.bandwidth)
    key = jr.PRNGKey(int(self.config.default_key_seed))

    ref, comparison = stimuli
    return self._simulate_trial_mc(
        params=params,
        ref=ref,
        comparison=comparison,
        model=model,
        noise=noise,
        num_samples=num_samples,
        bandwidth=bandwidth,
        key=key,
    )

OnlineConfig

OnlineConfig(
    strategy: Literal[
        "full", "sliding_window", "reservoir", "none"
    ] = "full",
    window_size: int | None = None,
    refit_interval: int = 1,
    warm_start: bool = True,
)

Configuration for online learning and memory management.

Attributes:

Name Type Description
strategy {'full', 'sliding_window', 'reservoir', 'none'}

Data retention strategy:
  • "full": Keep all data (unbounded memory)
  • "sliding_window": Keep only last N trials (FIFO)
  • "reservoir": Reservoir sampling for uniform coverage
  • "none": No caching, refit from scratch each time

window_size int | None

Maximum number of trials to retain. Required for the sliding_window and reservoir strategies.

refit_interval int

Refit model every N updates (1=always, 10=batch every 10 trials). Trades off accuracy vs. computational cost.

warm_start bool

If True, initialize refitting from cached parameters. Speeds up convergence for small updates.

Examples:

>>> # Unbounded memory (default)
>>> config = OnlineConfig(strategy="full")
>>> # Sliding window: keep last 10K trials
>>> config = OnlineConfig(
...     strategy="sliding_window",
...     window_size=10_000,
...     refit_interval=10,  # Batch every 10 trials
... )
>>> # Reservoir sampling: uniform coverage with 5K trials
>>> config = OnlineConfig(
...     strategy="reservoir",
...     window_size=5_000,
... )
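
To make the two bounded-memory strategies concrete, here is a small sketch in plain Python. It is illustrative only: the function names are hypothetical, and using the classic "Algorithm R" for the reservoir is an assumption, since the source does not show how psyphy's reservoir update is implemented.

```python
import random
from collections import deque


def sliding_window(trials, window_size):
    """Keep only the most recent `window_size` trials (FIFO eviction)."""
    buf = deque(maxlen=window_size)
    for t in trials:
        buf.append(t)  # the oldest entry is dropped automatically when full
    return list(buf)


def reservoir(trials, window_size, seed=0):
    """Algorithm R: every trial seen so far is retained with equal probability."""
    rng = random.Random(seed)
    buf = []
    for n, t in enumerate(trials, start=1):
        if len(buf) < window_size:
            buf.append(t)
        else:
            j = rng.randrange(n)  # uniform integer in [0, n)
            if j < window_size:
                buf[j] = t
    return buf


stream = list(range(100))
recent = sliding_window(stream, 5)  # only the newest trials survive
sample = reservoir(stream, 5)       # a uniform subsample of the whole stream
```

The sliding window tracks recent behavior at the cost of forgetting early trials, while the reservoir keeps coverage of the whole session at the cost of fewer recent trials.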

refit_interval

refit_interval: int = 1

strategy

strategy: Literal[
    "full", "sliding_window", "reservoir", "none"
] = "full"

warm_start

warm_start: bool = True

window_size

window_size: int | None = None

Prior

Prior(
    input_dim: int,
    basis_degree: int | None = None,
    variance_scale: float = 1.0,
    decay_rate: float = 0.5,
    extra_embedding_dims: int = 0,
)

Prior distribution over WPPM parameters

Parameters:

Name Type Description Default
input_dim int

Dimensionality of the model space (same as WPPM.input_dim)

required
basis_degree int | None

Degree of Chebyshev basis for Wishart process. If set, uses Wishart mode with W coefficients.

None
variance_scale float

Prior variance for degree-0 (constant) coefficient in Wishart mode. Controls overall scale of covariances.

1.0
decay_rate float

Geometric decay rate for prior variance over higher-degree coefficients. Prior variance for degree-d coefficient = variance_scale * (decay_rate^d). Smaller decay_rate -> stronger smoothness prior.

0.5
extra_embedding_dims int

Additional latent dimensions in U matrices beyond input dimensions. Allows richer ellipsoid shapes in Wishart mode.

0
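
The geometric decay described for decay_rate can be tabulated directly. This is a hedged sketch in plain Python: the function name is hypothetical, and treating the degree of the (i, j) coefficient as i + j is an assumption, since the source only states the rule variance_scale * decay_rate^d.

```python
def prior_variances_2d(basis_degree, variance_scale=1.0, decay_rate=0.5):
    """variance[i][j] = variance_scale * decay_rate ** (i + j).

    Assumption: the degree d of the (i, j) coefficient is i + j.
    """
    n = basis_degree + 1  # basis functions T_0, ..., T_degree
    return [
        [variance_scale * decay_rate ** (i + j) for j in range(n)]
        for i in range(n)
    ]


v = prior_variances_2d(basis_degree=2)
```

With the defaults, the constant term has variance 1.0 and each additional degree halves it, so higher-order (more wiggly) terms are shrunk harder toward zero.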

Methods:

Name Description
log_prob

Compute log prior density (up to a constant)

sample_params

Sample initial parameters from the prior.

Attributes:

Name Type Description
basis_degree int | None
decay_rate float
extra_embedding_dims int
input_dim int
variance_scale float

basis_degree

basis_degree: int | None = None

decay_rate

decay_rate: float = 0.5

extra_embedding_dims

extra_embedding_dims: int = 0

input_dim

input_dim: int

variance_scale

variance_scale: float = 1.0

log_prob

log_prob(params: Params) -> ndarray

Compute log prior density (up to a constant)

Gaussian prior on W with smoothness via decay_rate:

    log p(W) = Σ_ij log N(W_ij | 0, σ_ij^2), where σ_ij^2 is the prior variance

Parameters:

Name Type Description Default
params dict

Parameter dictionary

required

Returns:

Name Type Description
log_prob float

Log prior probability (up to normalizing constant)

Source code in src/psyphy/model/prior.py
def log_prob(self, params: Params) -> jnp.ndarray:
    """
    Compute log prior density (up to a constant)

    Gaussian prior on W with smoothness via decay_rate
        log p(W) = Σ_ij log N(W_ij | 0, σ_ij^2) where σ_ij^2 = prior variance

    Parameters
    ----------
    params : dict
        Parameter dictionary

    Returns
    -------
    log_prob : float
        Log prior probability (up to normalizing constant)
    """

    if "W" in params:
        # Wishart mode
        W = params["W"]
        variances = self._compute_W_prior_variances()

        # Gaussian log probability for each entry
        # log N(x | 0, σ^2) = -0.5 * (x^2/σ^2 + log(2πσ^2))
        # Up to constant: -0.5 * x^2/σ^2

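        if self.input_dim == 2:
            # Each W[i,j,:,:] ~ Normal(0, variance[i,j] * I)
            return -0.5 * jnp.sum((W**2) / (variances[:, :, None, None] + 1e-10))
        elif self.input_dim == 3:
            # Each W[i,j,k,:,:] ~ Normal(0, variance[i,j,k] * I)
            return -0.5 * jnp.sum((W**2) / (variances[:, :, :, None, None] + 1e-10))
        else:
            # Unsupported dimensionality: fail explicitly rather than
            # reporting a missing 'W'.
            raise NotImplementedError(
                f"Wishart process only supports 2D and 3D. Got input_dim={self.input_dim}"
            )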

    raise ValueError("params must contain weights 'W'")
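
The phrase "up to a constant" can be checked numerically. The sketch below uses NumPy rather than JAX for portability; the function names and the variance pattern are illustrative assumptions. It compares the unnormalized form with the full Gaussian log density and shows that the gap between them does not depend on W.

```python
import numpy as np


def log_prior_unnormalized(W, variances):
    """-0.5 * sum(W^2 / sigma^2); variances broadcast over the trailing matrix axes."""
    return -0.5 * np.sum(W**2 / variances[:, :, None, None])


def log_prior_full(W, variances):
    """Full Gaussian log density, including the log(2*pi*sigma^2) terms."""
    var = np.broadcast_to(variances[:, :, None, None], W.shape)
    return -0.5 * np.sum(W**2 / var + np.log(2.0 * np.pi * var))


rng = np.random.default_rng(0)
# Hypothetical decay pattern: variance 0.5 ** (i + j) for coefficient (i, j).
variances = 0.5 ** np.add.outer(np.arange(3), np.arange(3))
W_a = rng.normal(size=(3, 3, 2, 2))
W_b = rng.normal(size=(3, 3, 2, 2))

# The two versions differ by the same constant for every W.
gap_a = log_prior_full(W_a, variances) - log_prior_unnormalized(W_a, variances)
gap_b = log_prior_full(W_b, variances) - log_prior_unnormalized(W_b, variances)
```

Because the gap is constant in W, dropping it leaves gradients and the location of the MAP estimate unchanged.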

sample_params

sample_params(key: Any) -> Params

Sample initial parameters from the prior.

Returns {"W": shape (degree+1, degree+1, input_dim, embedding_dim)} for 2D, where embedding_dim = input_dim + extra_embedding_dims

Note: The 3rd dimension is input_dim (output space dimension). This matches the einsum in _compute_sqrt: U = einsum("ijde,ij->de", W, phi) where d indexes input_dim.

Parameters:

Name Type Description Default
key JAX random key
required

Returns:

Name Type Description
params dict

Parameter dictionary

Source code in src/psyphy/model/prior.py
def sample_params(self, key: Any) -> Params:
    """
    Sample initial parameters from the prior.


    Returns {"W": shape (degree+1, degree+1, input_dim, embedding_dim)}
    for 2D, where embedding_dim = input_dim + extra_embedding_dims

    Note: The 3rd dimension is input_dim (output space dimension).
    This matches the einsum in _compute_sqrt:
    U = einsum("ijde,ij->de", W, phi) where d indexes input_dim.

    Parameters
    ----------
    key : JAX random key

    Returns
    -------
    params : dict
        Parameter dictionary
    """
    if self.basis_degree is None:
        raise ValueError(
            "'basis_degree' is None; please set "
            "`Prior.basis_degree` to an integer >0."
        )

    # Basis function coefficients W
    variances = self._compute_W_prior_variances()
    embedding_dim = self.input_dim + self.extra_embedding_dims

    if self.input_dim == 2:
        # Sample W ~ Normal(0, variances) for each matrix entry
        # Shape: (degree+1, degree+1, input_dim, embedding_dim)
        # Note: degree+1 to match number of basis functions [T_0, ..., T_degree]
        W = jnp.sqrt(variances)[:, :, None, None] * jr.normal(
            key,
            shape=(
                self.basis_degree + 1,
                self.basis_degree + 1,
                self.input_dim,
                embedding_dim,
            ),
        )
    elif self.input_dim == 3:
        # Shape: (degree+1, degree+1, degree+1, input_dim, embedding_dim)
        W = jnp.sqrt(variances)[:, :, :, None, None] * jr.normal(
            key,
            shape=(
                self.basis_degree + 1,
                self.basis_degree + 1,
                self.basis_degree + 1,
                self.input_dim,
                embedding_dim,
            ),
        )
    else:
        raise NotImplementedError(
            f"Wishart process only supports 2D and 3D. Got input_dim={self.input_dim}"
        )

    return {"W": W}
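
The shape convention and the einsum contraction mentioned in the docstring can be illustrated as follows. This is a NumPy sketch under stated assumptions: the tensor-product Chebyshev basis phi[i, j] = T_i(x[0]) * T_j(x[1]) and the helper name are assumptions, since `_compute_sqrt` is not shown in this section.

```python
import numpy as np


def chebyshev_basis_2d(x, degree):
    """phi[i, j] = T_i(x[0]) * T_j(x[1]) (assumed tensor-product basis)."""
    def cheb(t):
        T = [1.0, t]
        for _ in range(2, degree + 1):
            T.append(2.0 * t * T[-1] - T[-2])  # T_{n+1} = 2 t T_n - T_{n-1}
        return np.array(T[: degree + 1])
    return np.outer(cheb(x[0]), cheb(x[1]))


degree, input_dim, extra_embedding_dims = 3, 2, 1
embedding_dim = input_dim + extra_embedding_dims
rng = np.random.default_rng(0)
W = rng.normal(size=(degree + 1, degree + 1, input_dim, embedding_dim))

phi = chebyshev_basis_2d(np.array([0.3, -0.2]), degree)  # (degree+1, degree+1)
U = np.einsum("ijde,ij->de", W, phi)                     # (input_dim, embedding_dim)
```

The contraction sums the coefficient matrices W[i, j] weighted by the basis values, so U is rectangular whenever extra_embedding_dims > 0.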

StudentTNoise

StudentTNoise(df: float = 3.0, scale: float = 1.0)

Methods:

Name Description
log_prob
sample_standard

Sample from standard Student-t (df=self.df).

Attributes:

Name Type Description
df float
scale float

df

df: float = 3.0

scale

scale: float = 1.0

log_prob

log_prob(residual: float) -> float
Source code in src/psyphy/model/noise.py
def log_prob(self, residual: float) -> float:
    _ = residual
    return -0.5

sample_standard

sample_standard(
    key: Array, shape: tuple[int, ...]
) -> Array

Sample from standard Student-t (df=self.df).

Source code in src/psyphy/model/noise.py
def sample_standard(self, key: jax.Array, shape: tuple[int, ...]) -> jax.Array:
    """Sample from standard Student-t (df=self.df)."""
    return jr.t(key, self.df, shape)

TaskLikelihood

Bases: ABC

Abstract base class for task likelihoods

Methods:

Name Description
loglik

Compute log-likelihood of observed responses under this task.

predict

Predict probability of correct response for a stimulus.

loglik

loglik(
    params: Any,
    data: Any,
    model: Any,
    noise: Any,
    **kwargs: Any,
) -> ndarray

Compute log-likelihood of observed responses under this task.

Why **kwargs?
  • Different tasks may need different optional runtime controls.
  • MC-based tasks may need parameters such as a PRNG key. In particular, OddityTask takes Monte Carlo controls (num_samples and bandwidth) exclusively from OddityTaskConfig to avoid silent mismatch bugs.

Notes
  • Task implementations should document which kwargs they accept.
  • Callers should not assume arbitrary kwargs are supported.
Source code in src/psyphy/model/task.py
@abstractmethod
def loglik(
    self, params: Any, data: Any, model: Any, noise: Any, **kwargs: Any
) -> jnp.ndarray:
    """Compute log-likelihood of observed responses under this task.

    Why ``**kwargs``?
    - Different tasks may need different optional runtime controls.
    - MC-based tasks may need parameters such as a PRNG ``key``.
        In particular, :class:`OddityTask` takes Monte Carlo controls
        (``num_samples`` and ``bandwidth``) exclusively from
        :class:`OddityTaskConfig` to avoid silent mismatch bugs.

    Notes
    -----
    - Task implementations should document which kwargs they accept.
    - Callers should not assume arbitrary kwargs are supported.
    """
    ...
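
One way a task implementation might enforce the "document which kwargs you accept" rule is to validate them up front. The helper below is a hypothetical sketch in plain Python, loosely modeled on the checks in OddityTask.loglik; its name and signature are assumptions and it is not part of psyphy.

```python
def pop_runtime_kwargs(kwargs, allowed=("key",), config_only=("num_samples", "bandwidth")):
    """Return the accepted runtime options and reject everything else loudly."""
    kwargs = dict(kwargs)  # do not mutate the caller's dict
    blocked = sorted(k for k in config_only if k in kwargs)
    if blocked:
        raise TypeError(f"{blocked} must be set in the task config, not per call")
    accepted = {k: kwargs.pop(k) for k in allowed if k in kwargs}
    if kwargs:
        raise TypeError(f"Unexpected keyword arguments: {', '.join(sorted(kwargs))}")
    return accepted


ok = pop_runtime_kwargs({"key": 42})
```

Rejecting unknown names, rather than ignoring them, is what prevents a misspelled or unsupported option from silently having no effect.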

predict

predict(
    params: Any, stimuli: Stimulus, model: Any, noise: Any
) -> ndarray

Predict probability of correct response for a stimulus.

Source code in src/psyphy/model/task.py
@abstractmethod
def predict(
    self, params: Any, stimuli: Stimulus, model: Any, noise: Any
) -> jnp.ndarray:
    """Predict probability of correct response for a stimulus."""
    ...

WPPM

WPPM(
    input_dim: int,
    prior: Prior,
    task: TaskLikelihood,
    noise: Any | None = None,
    *,
    online_config: OnlineConfig | None = None,
    extra_dims: int = 0,
    variance_scale: float = 1.0,
    decay_rate: float = 1.0,
    diag_term: float = 1e-06,
    **model_kwargs: Any,
)

Bases: Model

Wishart Process Psychophysical Model (WPPM).

Parameters:

Name Type Description Default
input_dim int

Dimensionality of the input stimulus space (e.g., 2 for isoluminant plane, 3 for RGB). Both reference and probe live in R^{input_dim}.

required
prior Prior

Prior distribution over model parameters. Controls basis_degree in WPPM (basis expansion). The WPPM delegates to prior.basis_degree to ensure consistency between parameter sampling and basis evaluation.

required
task TaskLikelihood

Psychophysical task mapping (e.g., OddityTask) that defines how discriminability translates to p(correct) and how the log-likelihood of responses is computed.

required
noise Any

Noise model describing internal representation noise (e.g., GaussianNoise).

None
Forward-compatible hyperparameters

extra_dims : int, default=0
    Additional embedding dimensions for basis expansions (beyond input_dim). embedding_dim = input_dim + extra_dims.
variance_scale : float, default=1.0
    Global scaling factor for covariance magnitude.
decay_rate : float, default=1.0
    Smoothness/length-scale for spatial covariance variation.
diag_term : float, default=1e-6
    Small positive value added to the covariance diagonal for numerical stability.
online_config : OnlineConfig | None, optional (keyword-only)
    Base-model lifecycle / online-learning policy. This is the supported way to configure buffering and refit scheduling via Model.condition_on_observations.
model_kwargs : Any
    Reserved for future keyword arguments accepted by the base Model.__init__. Do not pass WPPM math knobs or task/likelihood knobs here.

Methods:

Name Description
condition_on_observations

Update model with new observations (online learning).

discriminability

Compute scalar discriminability d >= 0 for a (reference, probe) pair

fit

Fit model to data.

init_params

Sample initial parameters from the prior.

local_covariance

Return local covariance Σ(x) at stimulus location x.

log_likelihood

Compute the log-likelihood for arrays of trials.

log_likelihood_from_data

Compute log-likelihood directly from a ResponseData object.

log_posterior_from_data

Compute log posterior from data.

posterior

Return posterior distribution.

predict_prob

Predict probability of a correct response for a single stimulus.

predict_with_params

Evaluate model at specific parameter values (no marginalization).

Attributes:

Name Type Description
basis_degree int | None

Chebyshev polynomial degree for Wishart process basis expansion.

decay_rate
diag_term
embedding_dim int

Dimension of the embedding space (perceptual space).

extra_dims
input_dim
noise
online_config
prior
task
variance_scale
Source code in src/psyphy/model/wppm.py
def __init__(
    self,
    input_dim: int,
    prior: Prior,
    task: TaskLikelihood,
    noise: Any | None = None,
    *,  # everything after here is keyword-only
    online_config: OnlineConfig | None = None,
    extra_dims: int = 0,
    variance_scale: float = 1.0,
    decay_rate: float = 1.0,
    diag_term: float = 1e-6,
    **model_kwargs: Any,
) -> None:
    # Base-model configuration (lifecycle / online learning).
    #
    # `online_config` is the explicit, user-facing knob for online learning
    # and data retention (see `psyphy.model.base.OnlineConfig`).
    #
    # `model_kwargs` is reserved for *future* base `Model.__init__` kwargs.
    # It should NOT be used for WPPM-specific math (e.g. alternative covariance
    # parameterizations) or for task-specific likelihood knobs.
    if model_kwargs:
        known_misuses = {"num_samples", "bandwidth"}
        bad = sorted(known_misuses.intersection(model_kwargs.keys()))
        if bad:
            raise TypeError(
                "Do not pass task-specific kwargs via WPPM(..., **model_kwargs). "
                f"Move {bad} into the task config (e.g. OddityTaskConfig)."
            )

    super().__init__(online_config=online_config, **model_kwargs)

    # --- core components ---
    self.input_dim = int(input_dim)  # stimulus-space dimensionality
    self.prior = prior  # prior over parameter PyTree

    if self.prior.input_dim != self.input_dim:
        raise ValueError(
            f"Dimension mismatch: Model initialized with input_dim={self.input_dim}, "
            f"but Prior expects input_dim={self.prior.input_dim}."
        )

    self.task = task  # task mapping and likelihood
    self.noise = noise  # noise model

    self.extra_dims = int(extra_dims)
    self.variance_scale = float(variance_scale)
    self.decay_rate = float(decay_rate)
    self.diag_term = float(diag_term)

basis_degree

basis_degree: int | None

Chebyshev polynomial degree for Wishart process basis expansion.

This property delegates to self.prior.basis_degree to ensure consistency between parameter sampling and basis evaluation.

Returns:

Type Description
int | None

Degree of Chebyshev polynomial basis (0 = constant, 1 = linear, etc.)

Notes

WPPM gets its basis_degree parameter from Prior.basis_degree.

decay_rate

decay_rate = float(decay_rate)

diag_term

diag_term = float(diag_term)

embedding_dim

embedding_dim: int

Dimension of the embedding space (perceptual space).

embedding_dim = input_dim + extra_dims. This represents the full perceptual space, where:
  • The first input_dim dimensions correspond to observable stimulus features
  • The remaining extra_dims dimensions are latent

Returns:

Type Description
int

input_dim + extra_dims

Notes

This is a computed property, not a constructor parameter.

extra_dims

extra_dims = int(extra_dims)

input_dim

input_dim = int(input_dim)

noise

noise = noise

online_config

online_config = online_config or OnlineConfig()

prior

prior = prior

task

task = task

variance_scale

variance_scale = float(variance_scale)

condition_on_observations

condition_on_observations(X: ndarray, y: ndarray) -> Model

Update model with new observations (online learning).

Behavior depends on self.online_config.strategy:
  • "full": Accumulate all data, refit periodically
  • "sliding_window": Keep only recent window_size trials
  • "reservoir": Random sampling of window_size trials
  • "none": Refit from scratch (no caching)

Returns a NEW model instance (immutable update).

Parameters:

Name Type Description Default
X ndarray

New stimuli

required
y ndarray

New responses

required

Returns:

Type Description
Model

Updated model (new instance)

Examples:

>>> # Online learning loop
>>> model = WPPM(...).fit(X_init, y_init)
>>> for X_new, y_new in stream:
...     model = model.condition_on_observations(X_new, y_new)
...     # Model automatically manages memory and refitting
Source code in src/psyphy/model/base.py
def condition_on_observations(self, X: jnp.ndarray, y: jnp.ndarray) -> Model:
    """
    Update model with new observations (online learning).

    Behavior depends on self.online_config.strategy:
    - "full": Accumulate all data, refit periodically
    - "sliding_window": Keep only recent window_size trials
    - "reservoir": Random sampling of window_size trials
    - "none": Refit from scratch (no caching)

    Returns a NEW model instance (immutable update).

    Parameters
    ----------
    X : jnp.ndarray
        New stimuli
    y : jnp.ndarray
        New responses

    Returns
    -------
    Model
        Updated model (new instance)

    Examples
    --------
    >>> # Online learning loop
    >>> model = WPPM(...).fit(X_init, y_init)
    >>> for X_new, y_new in stream:
    ...     model = model.condition_on_observations(X_new, y_new)
    ...     # Model automatically manages memory and refitting
    """
    from psyphy.data import ResponseData

    # Convert new data
    new_data = ResponseData.from_arrays(X, y)

    # Update data buffer according to strategy
    if self.online_config.strategy == "none":
        data_to_fit = new_data

    elif self.online_config.strategy == "full":
        if self._data_buffer is None:
            self._data_buffer = ResponseData()
        self._data_buffer.merge(new_data)
        data_to_fit = self._data_buffer

    elif self.online_config.strategy == "sliding_window":
        if self._data_buffer is None:
            self._data_buffer = ResponseData()
        self._data_buffer.merge(new_data)

        # Keep only last N trials
        window_size = self.online_config.window_size
        assert window_size is not None, "window_size must be set for sliding_window"
        if len(self._data_buffer) > window_size:
            self._data_buffer = self._data_buffer.tail(window_size)
        data_to_fit = self._data_buffer

    elif self.online_config.strategy == "reservoir":
        if self._data_buffer is None:
            self._data_buffer = ResponseData()

        # Reservoir sampling
        window_size = self.online_config.window_size
        assert window_size is not None, "window_size must be set for reservoir"
        self._data_buffer = self._reservoir_update(
            self._data_buffer,
            new_data,
            window_size,
        )
        data_to_fit = self._data_buffer
    else:
        raise ValueError(f"Unknown strategy: {self.online_config.strategy}")

    # Decide whether to refit
    self._n_updates += 1
    should_refit = self._n_updates % self.online_config.refit_interval == 0

    if not should_refit:
        # Return clone with updated buffer but old posterior
        new_model = self._clone()
        new_model._data_buffer = data_to_fit
        return new_model

    # Refit with optional warm start
    inference = self._inference_engine
    assert inference is not None, (
        "Model must be fit before condition_on_observations"
    )

    if self.online_config.warm_start and self._posterior is not None:
        # TODO: Add warm-start support to inference engines.
        # Until then this branch is a no-op: warm_start does not change the refit.
        pass

    new_model = self._clone()
    new_model._data_buffer = data_to_fit
    new_model._posterior = inference.fit(new_model, data_to_fit)
    new_model._inference_engine = inference

    return new_model
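
The refit decision in the listing above reduces to a modulo test on the update counter. A tiny sketch in plain Python, assuming the counter starts at zero before the first update (the helper name is hypothetical):

```python
def refit_schedule(n_updates, refit_interval):
    """Return the 1-based update counts at which a refit would be triggered."""
    return [n for n in range(1, n_updates + 1) if n % refit_interval == 0]


every_update = refit_schedule(5, refit_interval=1)  # refit on each call
batched = refit_schedule(25, refit_interval=10)     # refit on the 10th and 20th call
```

Between refits, new trials are still added to the buffer; only the posterior stays stale until the next scheduled refit.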

discriminability

discriminability(
    params: Params, stimulus: Stimulus
) -> ndarray

Compute scalar discriminability d >= 0 for a (reference, probe) pair

WPPM (rectangular U design) if extra_dims > 0:

    d = sqrt( (probe - ref)^T Σ(ref)^{-1} (probe - ref) )

where Σ(ref) is directly computed in stimulus space (input_dim, input_dim) via U(x) @ U(x)^T with U rectangular.

The discrimination task only depends on observable stimulus dimensions. The rectangular U design means local_covariance() already returns the stimulus covariance - no block extraction needed.

WPPM: d is implicit via Monte Carlo simulation of internal noisy responses under the task's decision rule (no closed form). In that case, tasks will directly implement predict/loglik with MC, and this method may be used only for diagnostics.

Parameters:

Name Type Description Default
params dict

Model parameters.

required
stimulus tuple

(reference, probe) arrays of shape (input_dim,).

required

Returns:

Name Type Description
d ndarray

Nonnegative scalar discriminability.

Source code in src/psyphy/model/wppm.py
def discriminability(self, params: Params, stimulus: Stimulus) -> jnp.ndarray:
    """
    Compute scalar discriminability d >= 0 for a (reference, probe) pair


    WPPM (rectangular U design) if extra_dims > 0:
        d = sqrt( (probe - ref)^T Σ(ref)^{-1} (probe - ref) )
        where Σ(ref) is directly computed in stimulus space (input_dim, input_dim)
        via U(x) @ U(x)^T with U rectangular.

    The discrimination task only depends on observable stimulus dimensions.
    The rectangular U design means local_covariance() already returns
    the stimulus covariance - no block extraction needed.

    WPPM:
        d is implicit via Monte Carlo simulation of internal noisy responses
        under the task's decision rule (no closed form). In that case, tasks
        will directly implement predict/loglik with MC, and this method may be
        used only for diagnostics.

    Parameters
    ----------
    params : dict
        Model parameters.
    stimulus : tuple
        (reference, probe) arrays of shape (input_dim,).

    Returns
    -------
    d : jnp.ndarray
        Nonnegative scalar discriminability.
    """
    ref, probe = stimulus

    # Delta is in stimulus space (input_dim)
    delta = probe - ref

    # Get stimulus covariance at reference
    # (rectangular U design: already returns (input_dim, input_dim))
    Sigma = self.local_covariance(params, ref)

    # Add jitter for stable solve; diag_term is configurable
    jitter = self.diag_term * jnp.eye(self.input_dim)

    # Solve (Σ + jitter)^{-1} delta using a PD-aware solver
    x = jax.scipy.linalg.solve(Sigma + jitter, delta, assume_a="pos")
    d2 = jnp.dot(delta, x)  # quadratic form

    # Guard against tiny negative values from numerical error
    return jnp.sqrt(jnp.maximum(d2, 0.0))
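
The quadratic form computed above is a Mahalanobis distance. The following NumPy sketch (the function name and test values are illustrative, and it uses a general solver rather than the positive-definite one in the listing) shows the same computation on inputs where the answer is easy to verify by hand.

```python
import numpy as np


def mahalanobis(ref, probe, Sigma, diag_term=1e-6):
    """d = sqrt(delta^T (Sigma + diag_term*I)^{-1} delta), computed via a linear solve."""
    delta = probe - ref
    A = Sigma + diag_term * np.eye(len(delta))
    d2 = delta @ np.linalg.solve(A, delta)  # avoids forming the explicit inverse
    return np.sqrt(max(d2, 0.0))            # guard against tiny negative round-off


ref = np.array([0.0, 0.0])
probe = np.array([0.3, 0.4])
d_iso = mahalanobis(ref, probe, np.eye(2), diag_term=0.0)         # identity covariance
d_wide = mahalanobis(ref, probe, 4.0 * np.eye(2), diag_term=0.0)  # four times the variance
```

With identity covariance the result is the Euclidean distance; quadrupling the variance halves the discriminability of the same stimulus pair.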

fit

fit(
    X: ndarray,
    y: ndarray,
    *,
    inference: InferenceEngine | str = "laplace",
    inference_config: dict | None = None,
) -> Model

Fit model to data.

Parameters:

Name Type Description Default
X ndarray

Stimuli, shape (n_trials, 2, input_dim) for (ref, probe) pairs or (n_trials, input_dim) for references only

required
y ndarray

Responses, shape (n_trials,)

required
inference InferenceEngine | str

Inference engine or string key ("map", "laplace", "langevin")

"laplace"
inference_config dict | None

Hyperparameters for string-based inference. Examples: {"steps": 500, "lr": 1e-3} for MAP

None

Returns:

Type Description
Model

Self for method chaining

Examples:

>>> # Simple: use defaults
>>> model.fit(X, y)
>>> # Explicit optimizer
>>> from psyphy.inference import MAPOptimizer
>>> model.fit(X, y, inference=MAPOptimizer(steps=500))
>>> # String + config (for experiment tracking)
>>> model.fit(X, y, inference="map", inference_config={"steps": 500})
Source code in src/psyphy/model/base.py
def fit(
    self,
    X: jnp.ndarray,
    y: jnp.ndarray,
    *,
    inference: InferenceEngine | str = "laplace",
    inference_config: dict | None = None,
) -> Model:
    """
    Fit model to data.

    Parameters
    ----------
    X : jnp.ndarray
        Stimuli, shape (n_trials, 2, input_dim) for (ref, probe) pairs
        or (n_trials, input_dim) for references only
    y : jnp.ndarray
        Responses, shape (n_trials,)
    inference : InferenceEngine | str, default="laplace"
        Inference engine or string key ("map", "laplace", "langevin")
    inference_config : dict | None
        Hyperparameters for string-based inference.
        Examples: {"steps": 500, "lr": 1e-3} for MAP

    Returns
    -------
    Model
        Self for method chaining

    Examples
    --------
    >>> # Simple: use defaults
    >>> model.fit(X, y)

    >>> # Explicit optimizer
    >>> from psyphy.inference import MAPOptimizer
    >>> model.fit(X, y, inference=MAPOptimizer(steps=500))

    >>> # String + config (for experiment tracking)
    >>> model.fit(X, y, inference="map", inference_config={"steps": 500})
    """
    from psyphy.data import ResponseData
    from psyphy.inference import INFERENCE_ENGINES, InferenceEngine

    # Resolve inference engine
    is_string_inference = isinstance(inference, str)

    if is_string_inference:
        config = inference_config or {}
        inference_key: str = inference  # type: ignore[assignment]
        if inference_key not in INFERENCE_ENGINES:
            available = ", ".join(INFERENCE_ENGINES.keys())
            raise ValueError(
                f"Unknown inference: '{inference}'. Available: {available}"
            )
        inference_engine: InferenceEngine = INFERENCE_ENGINES[inference_key](
            **config
        )
    elif isinstance(inference, InferenceEngine):
        inference_engine = inference
    else:
        raise TypeError(
            f"inference must be InferenceEngine or str, got {type(inference)}"
        )

    if inference_config is not None and not is_string_inference:
        raise ValueError(
            "Cannot pass inference_config with InferenceEngine instance"
        )

    # Convert data
    data = ResponseData.from_arrays(X, y)

    # Fit
    self._posterior = inference_engine.fit(self, data)
    self._inference_engine = inference_engine
    self._data_buffer = data  # Initialize buffer
    return self
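
The engine-resolution logic in fit follows a common registry pattern. Below is a self-contained sketch in plain Python; `ToyMAP`, `TOY_ENGINES`, and `resolve_engine` are hypothetical stand-ins, not psyphy's INFERENCE_ENGINES or its engine classes, and the sketch omits the isinstance check against the engine base class.

```python
from dataclasses import dataclass


@dataclass
class ToyMAP:
    """Hypothetical engine that only stores its hyperparameters."""
    steps: int = 100
    lr: float = 1e-2


TOY_ENGINES = {"map": ToyMAP}  # stand-in for a string-keyed registry


def resolve_engine(inference, inference_config=None):
    """String key plus optional config, or an already constructed engine."""
    if isinstance(inference, str):
        if inference not in TOY_ENGINES:
            available = ", ".join(TOY_ENGINES)
            raise ValueError(f"Unknown inference: '{inference}'. Available: {available}")
        return TOY_ENGINES[inference](**(inference_config or {}))
    if inference_config is not None:
        raise ValueError("Cannot pass inference_config with an engine instance")
    return inference


engine = resolve_engine("map", {"steps": 500})
```

The string form keeps experiment configurations serializable, while the instance form gives full control; rejecting a config alongside an instance avoids ambiguity about which settings win.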

init_params

init_params(key: Array) -> Params

Sample initial parameters from the prior.

Returns:

Name Type Description
params dict[str, ndarray]
Source code in src/psyphy/model/wppm.py
def init_params(self, key: jax.Array) -> Params:
    """Sample initial parameters from the prior.

    Returns
    -------
    params : dict[str, jnp.ndarray]
    """
    return self.prior.sample_params(key)

local_covariance

local_covariance(params: Params, x: ndarray) -> ndarray

Return local covariance Σ(x) at stimulus location x.

Wishart mode (basis_degree set):

    Σ(x) = U(x) @ U(x)^T + diag_term * I

where U(x) is rectangular (input_dim, embedding_dim) if extra_dims > 0.
  • Varies smoothly with x
  • Guaranteed positive-definite
  • Returns stimulus covariance directly (input_dim, input_dim)

Parameters:

Name Type Description Default
params dict

Model parameters: - WPPM: {"W": (degree+1, ..., input_dim, embedding_dim)}

required
x (ndarray, shape(input_dim))

Stimulus location

required

Returns:

Type Description
Σ : jnp.ndarray, shape (input_dim, input_dim)

Covariance matrix in stimulus space.

Source code in src/psyphy/model/wppm.py
def local_covariance(self, params: Params, x: jnp.ndarray) -> jnp.ndarray:
    """
    Return local covariance Σ(x) at stimulus location x.


    Wishart mode (basis_degree set):
        Σ(x) = U(x) @ U(x)^T + diag_term * I
        where U(x) is rectangular (input_dim, embedding_dim) if extra_dims > 0.
        - Varies smoothly with x
        - Guaranteed positive-definite
        - Returns stimulus covariance directly (input_dim, input_dim)

    Parameters
    ----------
    params : dict
        Model parameters:
        - WPPM: {"W": (degree+1, ..., input_dim, embedding_dim)}
    x : jnp.ndarray, shape (input_dim,)
        Stimulus location

    Returns
    -------
    Σ : jnp.ndarray, shape (input_dim, input_dim)
        Covariance matrix in stimulus space.
    """
    # WPPM: spatially-varying covariance
    if "W" in params:
        U = self._compute_sqrt(params, x)  # (input_dim, embedding_dim)
        # Σ(x) = U(x) @ U(x)^T + diag_term * I
        # Result is (input_dim, input_dim)
        Sigma = U @ U.T + self.diag_term * jnp.eye(self.input_dim)
        return Sigma

    raise ValueError("params must contain 'W' (weights of WPPM)")

log_likelihood

log_likelihood(
    params: Params,
    refs: ndarray,
    probes: ndarray,
    responses: ndarray,
) -> ndarray

Compute the log-likelihood for arrays of trials.

IMPORTANT: We delegate to the TaskLikelihood to avoid duplicating Bernoulli (MPV) or MC likelihood logic in multiple places.

Parameters:

Name Type Description Default
params dict

Model parameters.

required
refs (ndarray, shape(N, input_dim))
required
probes (ndarray, shape(N, input_dim))
required
responses (ndarray, shape(N))

Typically 0/1; task may support richer encodings.

required

Returns:

Name Type Description
loglik ndarray

Scalar log-likelihood (task-only; add prior outside if needed)

Notes
-----
This method is intentionally strict and does not accept task-specific
runtime kwargs.

- Configure task behavior (e.g. MC fidelity/smoothing for ``OddityTask``)
    via the task instance passed to the model.
- If you need reproducible randomness, pass a ``key`` when calling the
    task directly.
Source code in src/psyphy/model/wppm.py
def log_likelihood(
    self,
    params: Params,
    refs: jnp.ndarray,
    probes: jnp.ndarray,
    responses: jnp.ndarray,
) -> jnp.ndarray:
    """
    Compute the log-likelihood for arrays of trials.

    IMPORTANT:
        We delegate to the TaskLikelihood to avoid duplicating Bernoulli (MPV)
        or MC likelihood logic in multiple places.

    Parameters
    ----------
    params : dict
        Model parameters.
    refs : jnp.ndarray, shape (N, input_dim)
    probes : jnp.ndarray, shape (N, input_dim)
    responses : jnp.ndarray, shape (N,)
        Typically 0/1; task may support richer encodings.

    Returns
    -------
    loglik : jnp.ndarray
        Scalar log-likelihood (task-only; add prior outside if needed)

    Notes
    -----
    This method is intentionally strict and does not accept task-specific
    runtime kwargs.

    - Configure task behavior (e.g. MC fidelity/smoothing for ``OddityTask``)
      via the task instance passed to the model.
    - If you need reproducible randomness, pass a ``key`` when calling the
      task directly.
    """
    # We need a ResponseData-like object. To keep this method usable from
    # array inputs, we construct one on the fly. If you already have a
    # ResponseData instance, prefer `log_likelihood_from_data`.
    from psyphy.data.dataset import ResponseData  # local import to avoid cycles

    data = ResponseData()
    # ResponseData.add_trial(ref, probe, resp)
    for r, p, y in zip(refs, probes, responses):
        data.add_trial(r, p, int(y))
    return self.task.loglik(params, data, self, self.noise)

log_likelihood_from_data

log_likelihood_from_data(
    params: Params, data: Any
) -> ndarray

Compute log-likelihood directly from a ResponseData object.

Why delegate to the task?

- The task knows the decision rule (oddity, 2AFC, ...).
- The task can use the model (this WPPM) to fetch discriminabilities.
- The task can use the noise model if it needs MC simulation.

Parameters:

Name Type Description Default
params dict

Model parameters.

required
data ResponseData

Collected trial data.

required

Returns:

Name Type Description
loglik ndarray

Scalar log-likelihood (task-only; add prior outside if needed).

Source code in src/psyphy/model/wppm.py
def log_likelihood_from_data(self, params: Params, data: Any) -> jnp.ndarray:
    """Compute log-likelihood directly from a ResponseData object.

    Why delegate to the task?
        - The task knows the decision rule (oddity, 2AFC, ...).
        - The task can use the model (this WPPM) to fetch discriminabilities.
        - The task can use the noise model if it needs MC simulation.

    Parameters
    ----------
    params : dict
        Model parameters.
    data : ResponseData
        Collected trial data.

    Returns
    -------
    loglik : jnp.ndarray
        Scalar log-likelihood (task-only; add prior outside if needed).
    """
    return self.task.loglik(params, data, self, self.noise)

log_posterior_from_data

log_posterior_from_data(
    params: Params, data: Any
) -> ndarray

Compute log posterior from data.

This simply adds the prior log-probability to the task log-likelihood. Inference engines (e.g., MAP optimizer) typically optimize this quantity.

Returns:

Type Description
ndarray

Scalar log posterior = loglik(params | data) + log_prior(params).
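
As a rough illustration of that sum, the sketch below combines a Bernoulli log-likelihood with an independent standard-normal log-prior. Both are stand-ins chosen for illustration: in psyphy the likelihood comes from the task and the prior term from `self.prior.log_prob`. The helper names and the NumPy (rather than JAX) arrays are assumptions of this example.

```python
import numpy as np


def bernoulli_loglik(p_correct: np.ndarray, responses: np.ndarray) -> float:
    """Sum of log p for correct trials and log(1 - p) for incorrect ones."""
    p = np.clip(p_correct, 1e-12, 1.0 - 1e-12)  # avoid log(0)
    return float(np.sum(responses * np.log(p) + (1 - responses) * np.log1p(-p)))


def standard_normal_log_prior(params: dict) -> float:
    """Independent N(0, 1) log-density summed over every parameter entry."""
    total = 0.0
    for value in params.values():
        v = np.asarray(value, dtype=float)
        total += float(np.sum(-0.5 * v**2 - 0.5 * np.log(2.0 * np.pi)))
    return total


params = {"W": np.zeros((2, 3))}          # hypothetical parameter dict
p_correct = np.array([0.9, 0.6, 0.75])    # made-up predicted probabilities
responses = np.array([1, 0, 1])           # 1 = correct, 0 = incorrect

loglik = bernoulli_loglik(p_correct, responses)
log_prior = standard_normal_log_prior(params)

# Log posterior up to the (parameter-independent) normalizing constant.
log_posterior = loglik + log_prior
```

Because the normalizing constant does not depend on the parameters, maximizing this sum is enough for MAP estimation.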

Source code in src/psyphy/model/wppm.py
def log_posterior_from_data(self, params: Params, data: Any) -> jnp.ndarray:
    """Compute log posterior from data.

    This simply adds the prior log-probability to the task log-likelihood.
    Inference engines (e.g., MAP optimizer) typically optimize this quantity.

    Returns
    -------
    jnp.ndarray
        Scalar log posterior = loglik(params | data) + log_prior(params).
    """
    return self.log_likelihood_from_data(params, data) + self.prior.log_prob(params)

posterior

posterior(
    X: ndarray | None = None,
    *,
    probes: ndarray | None = None,
    kind: str = "predictive",
) -> PredictivePosterior | ParameterPosterior

Return posterior distribution.

Parameters:

Name Type Description Default
X ndarray | None

Test stimuli (references), shape (n_test, input_dim). Required for predictive posteriors, optional for parameter posteriors.

None
probes ndarray | None

Test probes, shape (n_test, input_dim). Required for predictive posteriors.

None
kind ('predictive', 'parameter')

Type of posterior to return: - "predictive": PredictivePosterior over f(X*) [for acquisitions] - "parameter": ParameterPosterior over θ [for diagnostics]

"predictive"

Returns:

Type Description
PredictivePosterior | ParameterPosterior

Posterior distribution

Raises:

Type Description
RuntimeError

If model has not been fit yet

Examples:

>>> # For acquisition functions
>>> pred_post = model.posterior(X_candidates, probes=X_probes)
>>> mean = pred_post.mean
>>> var = pred_post.variance
>>> # For diagnostics
>>> param_post = model.posterior(kind="parameter")
>>> samples = param_post.sample(100, key=jr.PRNGKey(42))
Source code in src/psyphy/model/base.py
def posterior(
    self,
    X: jnp.ndarray | None = None,
    *,
    probes: jnp.ndarray | None = None,
    kind: str = "predictive",
) -> PredictivePosterior | ParameterPosterior:
    """
    Return posterior distribution.

    Parameters
    ----------
    X : jnp.ndarray | None
        Test stimuli (references), shape (n_test, input_dim).
        Required for predictive posteriors, optional for parameter posteriors.
    probes : jnp.ndarray | None
        Test probes, shape (n_test, input_dim).
        Required for predictive posteriors.
    kind : {"predictive", "parameter"}
        Type of posterior to return:
        - "predictive": PredictivePosterior over f(X*) [for acquisitions]
        - "parameter": ParameterPosterior over θ [for diagnostics]

    Returns
    -------
    PredictivePosterior | ParameterPosterior
        Posterior distribution

    Raises
    ------
    RuntimeError
        If model has not been fit yet

    Examples
    --------
    >>> # For acquisition functions
    >>> pred_post = model.posterior(X_candidates, probes=X_probes)
    >>> mean = pred_post.mean
    >>> var = pred_post.variance

    >>> # For diagnostics
    >>> param_post = model.posterior(kind="parameter")
    >>> samples = param_post.sample(100, key=jr.PRNGKey(42))
    """
    if self._posterior is None:
        raise RuntimeError("Must call fit() before posterior()")

    if kind == "parameter":
        return self._posterior
    elif kind == "predictive":
        if X is None:
            raise ValueError("X is required for predictive posteriors")
        from psyphy.posterior import WPPMPredictivePosterior

        return WPPMPredictivePosterior(self._posterior, X, probes=probes)
    else:
        raise ValueError(
            f"Unknown kind: '{kind}'. Use 'predictive' or 'parameter'."
        )

predict_prob

predict_prob(
    params: Params, stimulus: Stimulus, **task_kwargs: Any
) -> ndarray

Predict probability of a correct response for a single stimulus.

Design choice: WPPM computes discriminability & covariance; the TASK defines how that translates to performance. We therefore delegate to: task.predict(params, stimulus, model=self, noise=self.noise)
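
The delegation pattern can be sketched with two toy classes. Everything here is hypothetical and only mirrors the call shape `task.predict(params, stimulus, model, noise)` described above: `ToyModel`, `ToyTask`, and the mapping from discriminability to p(correct) are invented for illustration and are not psyphy's implementation.

```python
import math


class ToyTask:
    """Hypothetical task: owns the rule mapping discriminability to p(correct)."""

    chance = 1.0 / 3.0  # chance level for a three-alternative task

    def predict(self, params, stimulus, model, noise):
        d = model.discriminability(params, stimulus)
        # Illustrative saturating map from d to [chance, 1); not the library's rule.
        return self.chance + (1.0 - self.chance) * (1.0 - math.exp(-0.5 * d * d))


class ToyModel:
    """Hypothetical model: computes discriminability, delegates the decision rule."""

    def __init__(self, task, noise=None):
        self.task = task
        self.noise = noise

    def discriminability(self, params, stimulus):
        ref, probe = stimulus
        # Euclidean distance scaled by a single parameter (illustration only).
        return math.dist(ref, probe) / params["scale"]

    def predict_prob(self, params, stimulus, **task_kwargs):
        # Task configuration lives on the task, so extra kwargs are rejected.
        if task_kwargs:
            raise TypeError(f"Unexpected: {', '.join(sorted(task_kwargs))}")
        return self.task.predict(params, stimulus, self, self.noise)


model = ToyModel(ToyTask())
params = {"scale": 1.0}
p_same = model.predict_prob(params, ((0.0, 0.0), (0.0, 0.0)))  # identical stimuli
p_far = model.predict_prob(params, ((0.0, 0.0), (3.0, 4.0)))   # well-separated stimuli
```

Swapping in a different task object changes how discriminability maps to performance without touching the model class, which is the point of the split.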

Parameters:

Name Type Description Default
params dict
required
stimulus (reference, probe)
required

Returns:

Name Type Description
p_correct ndarray
Source code in src/psyphy/model/wppm.py
def predict_prob(
    self, params: Params, stimulus: Stimulus, **task_kwargs: Any
) -> jnp.ndarray:
    """
    Predict probability of a correct response for a single stimulus.

    Design choice:
        WPPM computes discriminability & covariance; the TASK defines how
        that translates to performance. We therefore delegate to:
            task.predict(params, stimulus, model=self, noise=self.noise)

    Parameters
    ----------
    params : dict
    stimulus : (reference, probe)

    Returns
    -------
    p_correct : jnp.ndarray
    """
    # Strict task-owned configuration:
    # - MC control knobs (e.g. num_samples/bandwidth) live in the task config.
    # - WPPM.predict_prob therefore does not accept task-specific kwargs.
    if task_kwargs:
        unexpected = ", ".join(sorted(task_kwargs.keys()))
        raise TypeError(
            "WPPM.predict_prob does not accept task-specific kwargs. "
            "Configure task behavior via the task instance (e.g. OddityTaskConfig). "
            f"Unexpected: {unexpected}"
        )

    return self.task.predict(params, stimulus, self, self.noise)

predict_with_params

predict_with_params(
    X: ndarray,
    probes: ndarray | None,
    params: dict[str, ndarray],
) -> ndarray

Evaluate model at specific parameter values (no marginalization).

This is useful for:

- Threshold uncertainty estimation (evaluate at sampled parameters)
- Parameter sensitivity analysis
- Debugging and diagnostics

NOT for making predictions (use .posterior() instead, which marginalizes over parameter uncertainty).

Parameters:

Name Type Description Default
X (ndarray, shape(n_test, input_dim))

Test stimuli (references)

required
probes (ndarray, shape(n_test, input_dim))

Probe stimuli (for discrimination tasks)

required
params dict[str, ndarray]

Specific parameter values to evaluate at. Keys and shapes depend on the model (e.g., WPPM has "W", "noise_scale", etc.)

required

Returns:

Name Type Description
predictions (ndarray, shape(n_test))

Predicted probabilities at each test point, given these parameters

Examples:

>>> # Sample parameters and evaluate
>>> param_post = model.posterior(kind="parameter")
>>> samples = param_post.sample(100, key=jr.PRNGKey(0))
>>>
>>> # Evaluate at first parameter sample
>>> params_0 = {k: v[0] for k, v in samples.items()}
>>> predictions = model.predict_with_params(X_test, probes, params_0)
>>>
>>> # Use for threshold uncertainty estimation
>>> threshold_locs = []
>>> for i in range(100):
...     params_i = {k: v[i] for k, v in samples.items()}
...     preds_i = model.predict_with_params(X_grid, probes_grid, params_i)
...     threshold_idx = jnp.argmin(jnp.abs(preds_i - 0.75))
...     threshold_locs.append(threshold_idx)
Notes

This bypasses the posterior marginalization. For acquisition functions, always use .posterior() which properly accounts for parameter uncertainty.
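
One way to turn per-sample threshold indices, like those collected in the example above, into an uncertainty summary is to map them to grid locations and take percentiles. The sketch below does this with NumPy on synthetic data: `toy_predict` is a hypothetical stand-in for `model.predict_with_params`, and the sampled `slopes` stand in for posterior parameter samples.

```python
import numpy as np


def toy_predict(probe_offsets: np.ndarray, slope: float) -> np.ndarray:
    """Hypothetical psychometric curve rising from 1/3 (chance) towards 1."""
    return 1.0 / 3.0 + (2.0 / 3.0) * (1.0 - np.exp(-slope * probe_offsets))


rng = np.random.default_rng(0)
probe_offsets = np.linspace(0.0, 2.0, 201)  # grid of probe distances
slopes = rng.uniform(1.5, 2.5, size=100)    # stand-in for posterior samples

thresholds = np.empty(len(slopes))
for i, slope in enumerate(slopes):
    preds = toy_predict(probe_offsets, slope)
    idx = np.argmin(np.abs(preds - 0.75))   # grid point closest to 75% correct
    thresholds[i] = probe_offsets[idx]      # convert index to a stimulus value

threshold_mean = thresholds.mean()
lower, upper = np.percentile(thresholds, [2.5, 97.5])  # central 95% interval
```

The resolution of the interval is limited by the grid spacing, so a finer grid (or interpolation between grid points) gives smoother estimates.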

Source code in src/psyphy/model/base.py
def predict_with_params(
    self,
    X: jnp.ndarray,
    probes: jnp.ndarray | None,
    params: dict[str, jnp.ndarray],
) -> jnp.ndarray:
    """
    Evaluate model at specific parameter values (no marginalization).

    This is useful for:
    - Threshold uncertainty estimation (evaluate at sampled parameters)
    - Parameter sensitivity analysis
    - Debugging and diagnostics

    NOT for making predictions (use .posterior() instead, which
    marginalizes over parameter uncertainty).

    Parameters
    ----------
    X : jnp.ndarray, shape (n_test, input_dim)
        Test stimuli (references)
    probes : jnp.ndarray, shape (n_test, input_dim), optional
        Probe stimuli (for discrimination tasks)
    params : dict[str, jnp.ndarray]
        Specific parameter values to evaluate at.
        Keys and shapes depend on the model (e.g., WPPM has "W", "noise_scale", etc.)

    Returns
    -------
    predictions : jnp.ndarray, shape (n_test,)
        Predicted probabilities at each test point, given these parameters

    Examples
    --------
    >>> # Sample parameters and evaluate
    >>> param_post = model.posterior(kind="parameter")
    >>> samples = param_post.sample(100, key=jr.PRNGKey(0))
    >>>
    >>> # Evaluate at first parameter sample
    >>> params_0 = {k: v[0] for k, v in samples.items()}
    >>> predictions = model.predict_with_params(X_test, probes, params_0)
    >>>
    >>> # Use for threshold uncertainty estimation
    >>> threshold_locs = []
    >>> for i in range(100):
    ...     params_i = {k: v[i] for k, v in samples.items()}
    ...     preds_i = model.predict_with_params(X_grid, probes_grid, params_i)
    ...     threshold_idx = jnp.argmin(jnp.abs(preds_i - 0.75))
    ...     threshold_locs.append(threshold_idx)

    Notes
    -----
    This bypasses the posterior marginalization. For acquisition functions,
    always use .posterior() which properly accounts for parameter uncertainty.
    """
    return self._forward(X, probes, params)

Wishart Process Psychophysical Model (WPPM)


wppm

wppm.py

Wishart Process Psychophysical Model (WPPM)

Goals

Wishart Process Psychophysical Model (WPPM):

- Expose the hyperparameters needed to use, for example, the model configuration from Hong et al.:
    * extra_dims: embedding size for basis expansions
    * variance_scale: global covariance scale
    * decay_rate: smoothness/length-scale for covariance field
    * diag_term: numerical stabilizer added to covariance diagonals
- Later, replace local_covariance with a basis-expansion Wishart process and swap discriminability/likelihood with MC observer simulation.

All numerics use JAX (jax.numpy as jnp) to support autodiff and Optax optimizers.

Classes:

Name Description
WPPM

Wishart Process Psychophysical Model (WPPM).

Attributes:

Name Type Description
Params
Stimulus

Params

Params = dict[str, ndarray]

Stimulus

Stimulus = tuple[ndarray, ndarray]

WPPM

WPPM(
    input_dim: int,
    prior: Prior,
    task: TaskLikelihood,
    noise: Any | None = None,
    *,
    online_config: OnlineConfig | None = None,
    extra_dims: int = 0,
    variance_scale: float = 1.0,
    decay_rate: float = 1.0,
    diag_term: float = 1e-06,
    **model_kwargs: Any,
)

Bases: Model

Wishart Process Psychophysical Model (WPPM).

Parameters:

Name Type Description Default
input_dim int

Dimensionality of the input stimulus space (e.g., 2 for isoluminant plane, 3 for RGB). Both reference and probe live in R^{input_dim}.

required
prior Prior

Prior distribution over model parameters. Controls basis_degree in WPPM (basis expansion). The WPPM delegates to prior.basis_degree to ensure consistency between parameter sampling and basis evaluation.

required
task TaskLikelihood

Psychophysical task mapping (e.g., OddityTask) that defines how discriminability translates to p(correct) and how the log-likelihood of responses is computed.

required
noise Any

Noise model describing internal representation noise (e.g., GaussianNoise).

None
Forward-compatible hyperparameters

extra_dims : int, default=0
    Additional embedding dimensions for basis expansions (beyond input_dim). embedding_dim = input_dim + extra_dims.
variance_scale : float, default=1.0
    Global scaling factor for covariance magnitude.
decay_rate : float, default=1.0
    Smoothness/length-scale for spatial covariance variation.
diag_term : float, default=1e-6
    Small positive value added to the covariance diagonal for numerical stability.
online_config : OnlineConfig | None, optional (keyword-only)
    Base-model lifecycle / online-learning policy. This is the supported way to configure buffering and refit scheduling via Model.condition_on_observations.
model_kwargs : Any
    Reserved for future keyword arguments accepted by the base Model.__init__. Do not pass WPPM math knobs or task/likelihood knobs here.

Methods:

Name Description
condition_on_observations

Update model with new observations (online learning).

discriminability

Compute scalar discriminability d >= 0 for a (reference, probe) pair

fit

Fit model to data.

init_params

Sample initial parameters from the prior.

local_covariance

Return local covariance Σ(x) at stimulus location x.

log_likelihood

Compute the log-likelihood for arrays of trials.

log_likelihood_from_data

Compute log-likelihood directly from a ResponseData object.

log_posterior_from_data

Compute log posterior from data.

posterior

Return posterior distribution.

predict_prob

Predict probability of a correct response for a single stimulus.

predict_with_params

Evaluate model at specific parameter values (no marginalization).

Attributes:

Name Type Description
basis_degree int | None

Chebyshev polynomial degree for Wishart process basis expansion.

decay_rate
diag_term
embedding_dim int

Dimension of the embedding space (perceptual space).

extra_dims
input_dim
noise
online_config
prior
task
variance_scale
Source code in src/psyphy/model/wppm.py
def __init__(
    self,
    input_dim: int,
    prior: Prior,
    task: TaskLikelihood,
    noise: Any | None = None,
    *,  # everything after here is keyword-only
    online_config: OnlineConfig | None = None,
    extra_dims: int = 0,
    variance_scale: float = 1.0,
    decay_rate: float = 1.0,
    diag_term: float = 1e-6,
    **model_kwargs: Any,
) -> None:
    # Base-model configuration (lifecycle / online learning).
    #
    # `online_config` is the explicit, user-facing knob for online learning
    # and data retention (see `psyphy.model.base.OnlineConfig`).
    #
    # `model_kwargs` is reserved for *future* base `Model.__init__` kwargs.
    # It should NOT be used for WPPM-specific math (e.g. alternative covariance
    # parameterizations) or for task-specific likelihood knobs.
    if model_kwargs:
        known_misuses = {"num_samples", "bandwidth"}
        bad = sorted(known_misuses.intersection(model_kwargs.keys()))
        if bad:
            raise TypeError(
                "Do not pass task-specific kwargs via WPPM(..., **model_kwargs). "
                f"Move {bad} into the task config (e.g. OddityTaskConfig)."
            )

    super().__init__(online_config=online_config, **model_kwargs)

    # --- core components ---
    self.input_dim = int(input_dim)  # stimulus-space dimensionality
    self.prior = prior  # prior over parameter PyTree

    if self.prior.input_dim != self.input_dim:
        raise ValueError(
            f"Dimension mismatch: Model initialized with input_dim={self.input_dim}, "
            f"but Prior expects input_dim={self.prior.input_dim}."
        )

    self.task = task  # task mapping and likelihood
    self.noise = noise  # noise model

    self.extra_dims = int(extra_dims)
    self.variance_scale = float(variance_scale)
    self.decay_rate = float(decay_rate)
    self.diag_term = float(diag_term)

basis_degree

basis_degree: int | None

Chebyshev polynomial degree for Wishart process basis expansion.

This property delegates to self.prior.basis_degree to ensure consistency between parameter sampling and basis evaluation.

Returns:

Type Description
int | None

Degree of Chebyshev polynomial basis (0 = constant, 1 = linear, etc.)

Notes

WPPM gets its basis_degree parameter from Prior.basis_degree.
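
To make the degree concrete, the sketch below evaluates a Chebyshev basis with NumPy's `chebvander` and contracts it with a weight tensor to obtain a rectangular matrix. The contraction is an assumption made for illustration (a one-dimensional stimulus and weights of shape `(degree + 1, input_dim, embedding_dim)`); it is not taken from psyphy's `_compute_sqrt`.

```python
import numpy as np
from numpy.polynomial import chebyshev

basis_degree = 2
x = 0.5  # stimulus coordinate, assumed to lie in [-1, 1]

# Basis values [T_0(x), T_1(x), T_2(x)] = [1, x, 2x^2 - 1].
phi = chebyshev.chebvander(np.array([x]), basis_degree)[0]

input_dim, embedding_dim = 1, 2
rng = np.random.default_rng(0)
# Hypothetical weights: one (input_dim, embedding_dim) matrix per basis function.
W = rng.normal(size=(basis_degree + 1, input_dim, embedding_dim))

# Weighted sum of the per-degree matrices: U(x) = sum_i phi_i(x) * W[i].
U = np.einsum("i,ide->de", phi, W)
```

With `basis_degree = 0` only the constant term remains, so `U` (and hence the covariance built from it) would not vary with `x`.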

decay_rate

decay_rate = float(decay_rate)

diag_term

diag_term = float(diag_term)

embedding_dim

embedding_dim: int

Dimension of the embedding space (perceptual space).

embedding_dim = input_dim + extra_dims. This represents the full perceptual space, where:

- The first input_dim dimensions correspond to observable stimulus features
- The remaining extra_dims dimensions are latent

Returns:

Type Description
int

input_dim + extra_dims

Notes

This is a computed property, not a constructor parameter.

extra_dims

extra_dims = int(extra_dims)

input_dim

input_dim = int(input_dim)

noise

noise = noise

online_config

online_config = online_config or OnlineConfig()

prior

prior = prior

task

task = task

variance_scale

variance_scale = float(variance_scale)

condition_on_observations

condition_on_observations(X: ndarray, y: ndarray) -> Model

Update model with new observations (online learning).

Behavior depends on self.online_config.strategy:

- "full": Accumulate all data, refit periodically
- "sliding_window": Keep only recent window_size trials
- "reservoir": Random sampling of window_size trials
- "none": Refit from scratch (no caching)

Returns a NEW model instance (immutable update).
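
The two bounded-memory strategies can be sketched in plain Python. This is an illustration under stated assumptions, not psyphy's code: the sliding window is a `deque` with `maxlen`, and the reservoir uses the classic Algorithm R, which may differ from the library's internal `_reservoir_update`. The helper names are hypothetical.

```python
import random
from collections import deque


def sliding_window_update(buffer: deque, new_trials) -> deque:
    """Append trials; a deque with maxlen drops the oldest ones automatically."""
    buffer.extend(new_trials)
    return buffer


def reservoir_update(buffer: list, n_seen: int, new_trials, window_size: int,
                     rng: random.Random):
    """Algorithm R: every trial seen so far is kept with equal probability."""
    for trial in new_trials:
        n_seen += 1
        if len(buffer) < window_size:
            buffer.append(trial)           # fill phase: keep everything
        else:
            j = rng.randrange(n_seen)      # uniform integer in [0, n_seen)
            if j < window_size:
                buffer[j] = trial          # replace a random slot
    return buffer, n_seen


window_size = 5
# Five batches of four trials each; integers stand in for trial records.
stream = [list(range(start, start + 4)) for start in range(0, 20, 4)]

window = deque(maxlen=window_size)
reservoir, n_seen = [], 0
rng = random.Random(0)
for batch in stream:
    window = sliding_window_update(window, batch)
    reservoir, n_seen = reservoir_update(reservoir, n_seen, batch, window_size, rng)
```

The sliding window always ends up holding the most recent trials, while the reservoir holds a uniform random subset of the whole history, which matters when early trials cover parts of stimulus space that later ones do not.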

Parameters:

Name Type Description Default
X ndarray

New stimuli

required
y ndarray

New responses

required

Returns:

Type Description
Model

Updated model (new instance)

Examples:

>>> # Online learning loop
>>> model = WPPM(...).fit(X_init, y_init)
>>> for X_new, y_new in stream:
...     model = model.condition_on_observations(X_new, y_new)
...     # Model automatically manages memory and refitting
Source code in src/psyphy/model/base.py
def condition_on_observations(self, X: jnp.ndarray, y: jnp.ndarray) -> Model:
    """
    Update model with new observations (online learning).

    Behavior depends on self.online_config.strategy:
    - "full": Accumulate all data, refit periodically
    - "sliding_window": Keep only recent window_size trials
    - "reservoir": Random sampling of window_size trials
    - "none": Refit from scratch (no caching)

    Returns a NEW model instance (immutable update).

    Parameters
    ----------
    X : jnp.ndarray
        New stimuli
    y : jnp.ndarray
        New responses

    Returns
    -------
    Model
        Updated model (new instance)

    Examples
    --------
    >>> # Online learning loop
    >>> model = WPPM(...).fit(X_init, y_init)
    >>> for X_new, y_new in stream:
    ...     model = model.condition_on_observations(X_new, y_new)
    ...     # Model automatically manages memory and refitting
    """
    from psyphy.data import ResponseData

    # Convert new data
    new_data = ResponseData.from_arrays(X, y)

    # Update data buffer according to strategy
    if self.online_config.strategy == "none":
        data_to_fit = new_data

    elif self.online_config.strategy == "full":
        if self._data_buffer is None:
            self._data_buffer = ResponseData()
        self._data_buffer.merge(new_data)
        data_to_fit = self._data_buffer

    elif self.online_config.strategy == "sliding_window":
        if self._data_buffer is None:
            self._data_buffer = ResponseData()
        self._data_buffer.merge(new_data)

        # Keep only last N trials
        window_size = self.online_config.window_size
        assert window_size is not None, "window_size must be set for sliding_window"
        if len(self._data_buffer) > window_size:
            self._data_buffer = self._data_buffer.tail(window_size)
        data_to_fit = self._data_buffer

    elif self.online_config.strategy == "reservoir":
        if self._data_buffer is None:
            self._data_buffer = ResponseData()

        # Reservoir sampling
        window_size = self.online_config.window_size
        assert window_size is not None, "window_size must be set for reservoir"
        self._data_buffer = self._reservoir_update(
            self._data_buffer,
            new_data,
            window_size,
        )
        data_to_fit = self._data_buffer
    else:
        raise ValueError(f"Unknown strategy: {self.online_config.strategy}")

    # Decide whether to refit
    self._n_updates += 1
    should_refit = self._n_updates % self.online_config.refit_interval == 0

    if not should_refit:
        # Return clone with updated buffer but old posterior
        new_model = self._clone()
        new_model._data_buffer = data_to_fit
        return new_model

    # Refit with optional warm start
    inference = self._inference_engine
    assert inference is not None, (
        "Model must be fit before condition_on_observations"
    )

    if self.online_config.warm_start and self._posterior is not None:
        # TODO: Add warm-start support to inference engines
        # For now, just refit from cached params
        pass

    new_model = self._clone()
    new_model._data_buffer = data_to_fit
    new_model._posterior = inference.fit(new_model, data_to_fit)
    new_model._inference_engine = inference

    return new_model

discriminability

discriminability(
    params: Params, stimulus: Stimulus
) -> ndarray

Compute scalar discriminability d >= 0 for a (reference, probe) pair

WPPM (rectangular U design, used if extra_dims > 0):

    d = sqrt( (probe - ref)^T Σ(ref)^{-1} (probe - ref) )

where Σ(ref) is computed directly in stimulus space (input_dim, input_dim) via U(x) @ U(x)^T with U rectangular.

The discrimination task only depends on observable stimulus dimensions. The rectangular U design means local_covariance() already returns the stimulus covariance - no block extraction needed.

WPPM: d is implicit via Monte Carlo simulation of internal noisy responses under the task's decision rule (no closed form). In that case, tasks will directly implement predict/loglik with MC, and this method may be used only for diagnostics.
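
The closed-form quadratic expression above is a Mahalanobis distance, and it can be evaluated with a linear solve instead of an explicit matrix inverse. The sketch below uses NumPy (not JAX) and a hand-picked covariance; `discriminability_sketch` is a hypothetical helper, not the library method.

```python
import numpy as np


def discriminability_sketch(ref, probe, Sigma) -> float:
    """Return sqrt(delta^T Sigma^{-1} delta) with delta = probe - ref."""
    delta = np.asarray(probe, dtype=float) - np.asarray(ref, dtype=float)
    solved = np.linalg.solve(Sigma, delta)  # Sigma^{-1} delta without forming the inverse
    d2 = float(delta @ solved)              # quadratic form
    return float(np.sqrt(max(d2, 0.0)))     # guard against tiny negative round-off


ref = np.array([0.0, 0.0])
probe = np.array([3.0, 4.0])

# Identity covariance: d reduces to the Euclidean distance.
d_unit = discriminability_sketch(ref, probe, np.eye(2))

# Covariance scaled by 4 (noise standard deviation doubled): d is halved.
d_wide = discriminability_sketch(ref, probe, 4.0 * np.eye(2))
```

Larger internal noise at the reference makes the same stimulus difference harder to detect, which is why scaling the covariance up scales d down.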

Parameters:

Name Type Description Default
params dict

Model parameters.

required
stimulus tuple

(reference, probe) arrays of shape (input_dim,).

required

Returns:

Name Type Description
d ndarray

Nonnegative scalar discriminability.

Source code in src/psyphy/model/wppm.py
def discriminability(self, params: Params, stimulus: Stimulus) -> jnp.ndarray:
    """
    Compute scalar discriminability d >= 0 for a (reference, probe) pair


    WPPM (rectangular U design) if extra_dims > 0:
        d = sqrt( (probe - ref)^T Σ(ref)^{-1} (probe - ref) )
        where Σ(ref) is directly computed in stimulus space (input_dim, input_dim)
        via U(x) @ U(x)^T with U rectangular.

    The discrimination task only depends on observable stimulus dimensions.
    The rectangular U design means local_covariance() already returns
    the stimulus covariance - no block extraction needed.

    WPPM:
        d is implicit via Monte Carlo simulation of internal noisy responses
        under the task's decision rule (no closed form). In that case, tasks
        will directly implement predict/loglik with MC, and this method may be
        used only for diagnostics.

    Parameters
    ----------
    params : dict
        Model parameters.
    stimulus : tuple
        (reference, probe) arrays of shape (input_dim,).

    Returns
    -------
    d : jnp.ndarray
        Nonnegative scalar discriminability.
    """
    ref, probe = stimulus

    # Delta is in stimulus space (input_dim)
    delta = probe - ref

    # Get stimulus covariance at reference
    # (rectangular U design: already returns (input_dim, input_dim))
    Sigma = self.local_covariance(params, ref)

    # Add jitter for stable solve; diag_term is configurable
    jitter = self.diag_term * jnp.eye(self.input_dim)

    # Solve (Σ + jitter)^{-1} delta using a PD-aware solver
    x = jax.scipy.linalg.solve(Sigma + jitter, delta, assume_a="pos")
    d2 = jnp.dot(delta, x)  # quadratic form

    # Guard against tiny negative values from numerical error
    return jnp.sqrt(jnp.maximum(d2, 0.0))

fit

fit(
    X: ndarray,
    y: ndarray,
    *,
    inference: InferenceEngine | str = "laplace",
    inference_config: dict | None = None,
) -> Model

Fit model to data.

Parameters:

Name Type Description Default
X ndarray

Stimuli, shape (n_trials, 2, input_dim) for (ref, probe) pairs or (n_trials, input_dim) for references only

required
y ndarray

Responses, shape (n_trials,)

required
inference InferenceEngine | str

Inference engine or string key ("map", "laplace", "langevin")

"laplace"
inference_config dict | None

Hyperparameters for string-based inference. Examples: {"steps": 500, "lr": 1e-3} for MAP

None

Returns:

Type Description
Model

Self for method chaining
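
For the paired `(n_trials, 2, input_dim)` layout of `X` described in the parameters above, the array can be assembled by stacking reference and probe arrays along a new axis. A short sketch, using NumPy for brevity (with JAX, `jnp.stack` takes the same arguments); the array values are made up.

```python
import numpy as np

n_trials, input_dim = 4, 2
rng = np.random.default_rng(0)

refs = rng.uniform(-1.0, 1.0, size=(n_trials, input_dim))     # reference stimuli
probes = refs + 0.1 * rng.normal(size=(n_trials, input_dim))  # nearby probes
y = np.array([1, 0, 1, 1])                                    # responses, shape (n_trials,)

# Axis 1 indexes the pair: X[:, 0] holds references, X[:, 1] holds probes.
X = np.stack([refs, probes], axis=1)
```

The resulting `X` and `y` have the shapes that `fit(X, y)` documents for (ref, probe) pairs.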

Examples:

>>> # Simple: use defaults
>>> model.fit(X, y)
>>> # Explicit optimizer
>>> from psyphy.inference import MAPOptimizer
>>> model.fit(X, y, inference=MAPOptimizer(steps=500))
>>> # String + config (for experiment tracking)
>>> model.fit(X, y, inference="map", inference_config={"steps": 500})
Source code in src/psyphy/model/base.py
def fit(
    self,
    X: jnp.ndarray,
    y: jnp.ndarray,
    *,
    inference: InferenceEngine | str = "laplace",
    inference_config: dict | None = None,
) -> Model:
    """
    Fit model to data.

    Parameters
    ----------
    X : jnp.ndarray
        Stimuli, shape (n_trials, 2, input_dim) for (ref, probe) pairs
        or (n_trials, input_dim) for references only
    y : jnp.ndarray
        Responses, shape (n_trials,)
    inference : InferenceEngine | str, default="laplace"
        Inference engine or string key ("map", "laplace", "langevin")
    inference_config : dict | None
        Hyperparameters for string-based inference.
        Examples: {"steps": 500, "lr": 1e-3} for MAP

    Returns
    -------
    Model
        Self for method chaining

    Examples
    --------
    >>> # Simple: use defaults
    >>> model.fit(X, y)

    >>> # Explicit optimizer
    >>> from psyphy.inference import MAPOptimizer
    >>> model.fit(X, y, inference=MAPOptimizer(steps=500))

    >>> # String + config (for experiment tracking)
    >>> model.fit(X, y, inference="map", inference_config={"steps": 500})
    """
    from psyphy.data import ResponseData
    from psyphy.inference import INFERENCE_ENGINES, InferenceEngine

    # Resolve inference engine
    is_string_inference = isinstance(inference, str)

    if is_string_inference:
        config = inference_config or {}
        inference_key: str = inference  # type: ignore[assignment]
        if inference_key not in INFERENCE_ENGINES:
            available = ", ".join(INFERENCE_ENGINES.keys())
            raise ValueError(
                f"Unknown inference: '{inference}'. Available: {available}"
            )
        inference_engine: InferenceEngine = INFERENCE_ENGINES[inference_key](
            **config
        )
    elif isinstance(inference, InferenceEngine):
        inference_engine = inference
    else:
        raise TypeError(
            f"inference must be InferenceEngine or str, got {type(inference)}"
        )

    if inference_config is not None and not is_string_inference:
        raise ValueError(
            "Cannot pass inference_config with InferenceEngine instance"
        )

    # Convert data
    data = ResponseData.from_arrays(X, y)

    # Fit
    self._posterior = inference_engine.fit(self, data)
    self._inference_engine = inference_engine
    self._data_buffer = data  # Initialize buffer
    return self

init_params

init_params(key: Array) -> Params

Sample initial parameters from the prior.

Returns:

Name Type Description
params dict[str, ndarray]
Source code in src/psyphy/model/wppm.py
def init_params(self, key: jax.Array) -> Params:
    """Sample initial parameters from the prior.

    Returns
    -------
    params : dict[str, jnp.ndarray]
    """
    return self.prior.sample_params(key)

local_covariance

local_covariance(params: Params, x: ndarray) -> ndarray

Return local covariance Σ(x) at stimulus location x.

Wishart mode (basis_degree set): Σ(x) = U(x) @ U(x)^T + diag_term * I, where U(x) is rectangular (input_dim, embedding_dim) if extra_dims > 0.

- Varies smoothly with x
- Guaranteed positive-definite
- Returns stimulus covariance directly (input_dim, input_dim)

Parameters:

Name Type Description Default
params dict

Model parameters: - WPPM: {"W": (degree+1, ..., input_dim, embedding_dim)}

required
x (ndarray, shape(input_dim))

Stimulus location

required

Returns:

Type Description
Σ : jnp.ndarray, shape (input_dim, input_dim)

Covariance matrix in stimulus space.

Source code in src/psyphy/model/wppm.py
def local_covariance(self, params: Params, x: jnp.ndarray) -> jnp.ndarray:
    """
    Return local covariance Σ(x) at stimulus location x.


    Wishart mode (basis_degree set):
        Σ(x) = U(x) @ U(x)^T + diag_term * I
        where U(x) is rectangular (input_dim, embedding_dim) if extra_dims > 0.
        - Varies smoothly with x
        - Guaranteed positive-definite
        - Returns stimulus covariance directly (input_dim, input_dim)

    Parameters
    ----------
    params : dict
        Model parameters:
        - WPPM: {"W": (degree+1, ..., input_dim, embedding_dim)}
    x : jnp.ndarray, shape (input_dim,)
        Stimulus location

    Returns
    -------
    Σ : jnp.ndarray, shape (input_dim, input_dim)
        Covariance matrix in stimulus space.
    """
    # WPPM: spatially-varying covariance
    if "W" in params:
        U = self._compute_sqrt(params, x)  # (input_dim, embedding_dim)
        # Σ(x) = U(x) @ U(x)^T + diag_term * I
        # Result is (input_dim, input_dim)
        Sigma = U @ U.T + self.diag_term * jnp.eye(self.input_dim)
        return Sigma

    raise ValueError("params must contain 'W' (weights of WPPM)")
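
To see why Σ(x) = U(x) @ U(x)^T + diag_term * I is positive-definite even when U is rectangular, the following sketch builds one such matrix and inspects its eigenvalues. It uses NumPy as a stand-in for `jax.numpy`; the random `U` and the value of `diag_term` are assumptions chosen for illustration, since the real U(x) comes from the model's basis expansion.

```python
import numpy as np

rng = np.random.default_rng(0)
input_dim, extra_dims = 2, 1
embedding_dim = input_dim + extra_dims
diag_term = 1e-3  # hypothetical value for the diagonal regularizer

# Stand-in for U(x): a rectangular (input_dim, embedding_dim) matrix.
U = rng.standard_normal((input_dim, embedding_dim))

# U @ U.T is positive semi-definite; adding diag_term * I lifts every
# eigenvalue by diag_term, so the sum is strictly positive-definite.
Sigma = U @ U.T + diag_term * np.eye(input_dim)

eigvals = np.linalg.eigvalsh(Sigma)
print(Sigma.shape, eigvals)
```

The result lives in stimulus space, shape (input_dim, input_dim), regardless of how many extra embedding dimensions are used.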

log_likelihood

log_likelihood(
    params: Params,
    refs: ndarray,
    probes: ndarray,
    responses: ndarray,
) -> ndarray

Compute the log-likelihood for arrays of trials.

IMPORTANT: We delegate to the TaskLikelihood to avoid duplicating Bernoulli (MPV) or MC likelihood logic in multiple places.

Parameters:

Name Type Description Default
params dict

Model parameters.

required
refs (ndarray, shape(N, input_dim))
required
probes (ndarray, shape(N, input_dim))
required
responses (ndarray, shape(N))

Typically 0/1; task may support richer encodings.

required

Returns:

Name Type Description
loglik ndarray

Scalar log-likelihood (task-only; add prior outside if needed)

Notes
-----
This method is intentionally strict and does not accept task-specific
runtime kwargs.

- Configure task behavior (e.g. MC fidelity/smoothing for ``OddityTask``)
    via the task instance passed to the model.
- If you need reproducible randomness, pass a ``key`` when calling the
    task directly.
Source code in src/psyphy/model/wppm.py
def log_likelihood(
    self,
    params: Params,
    refs: jnp.ndarray,
    probes: jnp.ndarray,
    responses: jnp.ndarray,
) -> jnp.ndarray:
    """
    Compute the log-likelihood for arrays of trials.

    IMPORTANT:
        We delegate to the TaskLikelihood to avoid duplicating Bernoulli (MPV)
        or MC likelihood logic in multiple places.

    Parameters
    ----------
    params : dict
        Model parameters.
    refs : jnp.ndarray, shape (N, input_dim)
    probes : jnp.ndarray, shape (N, input_dim)
    responses : jnp.ndarray, shape (N,)
        Typically 0/1; task may support richer encodings.

    Returns
    -------
    loglik : jnp.ndarray
        Scalar log-likelihood (task-only; add prior outside if needed)

    Notes
    -----
    This method is intentionally strict and does not accept task-specific
    runtime kwargs.

    - Configure task behavior (e.g. MC fidelity/smoothing for ``OddityTask``)
      via the task instance passed to the model.
    - If you need reproducible randomness, pass a ``key`` when calling the
      task directly.
    """
    # We need a ResponseData-like object. To keep this method usable from
    # array inputs, we construct one on the fly. If you already have a
    # ResponseData instance, prefer `log_likelihood_from_data`.
    from psyphy.data.dataset import ResponseData  # local import to avoid cycles

    data = ResponseData()
    # ResponseData.add_trial(ref, probe, resp)
    for r, p, y in zip(refs, probes, responses):
        data.add_trial(r, p, int(y))
    return self.task.loglik(params, data, self, self.noise)

log_likelihood_from_data

log_likelihood_from_data(
    params: Params, data: Any
) -> ndarray

Compute log-likelihood directly from a ResponseData object.

Why delegate to the task?

  • The task knows the decision rule (oddity, 2AFC, ...).
  • The task can use the model (this WPPM) to fetch discriminabilities.
  • The task can use the noise model if it needs MC simulation.

Parameters:

Name Type Description Default
params dict

Model parameters.

required
data ResponseData

Collected trial data.

required

Returns:

Name Type Description
loglik ndarray

Scalar log-likelihood (task-only; add prior outside if needed).

Source code in src/psyphy/model/wppm.py
def log_likelihood_from_data(self, params: Params, data: Any) -> jnp.ndarray:
    """Compute log-likelihood directly from a ResponseData object.

    Why delegate to the task?
        - The task knows the decision rule (oddity, 2AFC, ...).
        - The task can use the model (this WPPM) to fetch discriminabilities.
        - The task can use the noise model if it needs MC simulation.

    Parameters
    ----------
    params : dict
        Model parameters.
    data : ResponseData
        Collected trial data.

    Returns
    -------
    loglik : jnp.ndarray
        Scalar log-likelihood (task-only; add prior outside if needed).
    """
    return self.task.loglik(params, data, self, self.noise)
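
The delegation described above can be illustrated with a toy pair of classes: the model knows the geometry, and the task owns the decision rule and the likelihood. Everything in this sketch (`ToyTask`, `ToyModel`, the `discriminability` method, the one-dimensional stimuli, and the psychometric mapping) is hypothetical and only mirrors the call shape `task.loglik(params, data, model, noise)`.

```python
import math


class ToyTask:
    """Hypothetical task: owns the decision rule and the likelihood."""

    def predict(self, params, stimulus, model, noise):
        ref, probe = stimulus
        # Ask the model for a discriminability, then map it to p(correct).
        d = model.discriminability(params, ref, probe)
        return 0.5 + 0.5 * (1.0 - math.exp(-d))

    def loglik(self, params, data, model, noise):
        total = 0.0
        for ref, probe, y in data:
            p = self.predict(params, (ref, probe), model, noise)
            total += y * math.log(p) + (1 - y) * math.log(1.0 - p)
        return total


class ToyModel:
    """Hypothetical model: knows geometry, delegates likelihood to the task."""

    def __init__(self, task, noise=None):
        self.task = task
        self.noise = noise

    def discriminability(self, params, ref, probe):
        return abs(probe - ref) / params["scale"]

    def log_likelihood_from_data(self, params, data):
        # Same call shape as in the listing above.
        return self.task.loglik(params, data, self, self.noise)


toy_model = ToyModel(task=ToyTask())
toy_params = {"scale": 0.5}
toy_data = [(0.0, 0.4, 1), (0.0, 0.1, 0), (0.0, 0.8, 1)]
toy_ll = toy_model.log_likelihood_from_data(toy_params, toy_data)
```

Swapping in a different task changes the decision rule without touching the model class.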

log_posterior_from_data

log_posterior_from_data(
    params: Params, data: Any
) -> ndarray

Compute log posterior from data.

This simply adds the prior log-probability to the task log-likelihood. Inference engines (e.g., MAP optimizer) typically optimize this quantity.

Returns:

Type Description
ndarray

Scalar log posterior = loglik(params | data) + log_prior(params).

Source code in src/psyphy/model/wppm.py
def log_posterior_from_data(self, params: Params, data: Any) -> jnp.ndarray:
    """Compute log posterior from data.

    This simply adds the prior log-probability to the task log-likelihood.
    Inference engines (e.g., MAP optimizer) typically optimize this quantity.

    Returns
    -------
    jnp.ndarray
        Scalar log posterior = loglik(params | data) + log_prior(params).
    """
    return self.log_likelihood_from_data(params, data) + self.prior.log_prob(params)
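
As a worked illustration of "log posterior = log-likelihood + log prior", the sketch below uses a one-parameter toy model (not the WPPM): responses are Bernoulli with p = sigmoid(theta), and theta has a standard Gaussian prior. The grid search is only a stand-in for whatever optimizer an inference engine would use.

```python
import math


def log_likelihood(theta, responses):
    p = 1.0 / (1.0 + math.exp(-theta))
    return sum(y * math.log(p) + (1 - y) * math.log(1.0 - p) for y in responses)


def log_prior(theta, variance=1.0):
    # Gaussian prior, dropping the additive normalizing constant.
    return -0.5 * theta**2 / variance


def log_posterior(theta, responses):
    return log_likelihood(theta, responses) + log_prior(theta)


responses = [1, 1, 1, 0, 1, 1, 1, 1]  # 7 of 8 correct
grid = [i / 100.0 for i in range(-300, 501)]

theta_mle = max(grid, key=lambda t: log_likelihood(t, responses))
theta_map = max(grid, key=lambda t: log_posterior(t, responses))
print(theta_mle, theta_map)
```

The MAP estimate is pulled toward the prior mean (0) relative to the maximum-likelihood estimate, which is the regularizing effect the prior term contributes.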

posterior

posterior(
    X: ndarray | None = None,
    *,
    probes: ndarray | None = None,
    kind: str = "predictive",
) -> PredictivePosterior | ParameterPosterior

Return posterior distribution.

Parameters:

Name Type Description Default
X ndarray | None

Test stimuli (references), shape (n_test, input_dim). Required for predictive posteriors, optional for parameter posteriors.

None
probes ndarray | None

Test probes, shape (n_test, input_dim). Required for predictive posteriors.

None
kind ('predictive', 'parameter')

Type of posterior to return: - "predictive": PredictivePosterior over f(X*) [for acquisitions] - "parameter": ParameterPosterior over θ [for diagnostics]

"predictive"

Returns:

Type Description
PredictivePosterior | ParameterPosterior

Posterior distribution

Raises:

Type Description
RuntimeError

If model has not been fit yet

Examples:

>>> # For acquisition functions
>>> pred_post = model.posterior(X_candidates, probes=X_probes)
>>> mean = pred_post.mean
>>> var = pred_post.variance
>>> # For diagnostics
>>> param_post = model.posterior(kind="parameter")
>>> samples = param_post.sample(100, key=jr.PRNGKey(42))
Source code in src/psyphy/model/base.py
def posterior(
    self,
    X: jnp.ndarray | None = None,
    *,
    probes: jnp.ndarray | None = None,
    kind: str = "predictive",
) -> PredictivePosterior | ParameterPosterior:
    """
    Return posterior distribution.

    Parameters
    ----------
    X : jnp.ndarray | None
        Test stimuli (references), shape (n_test, input_dim).
        Required for predictive posteriors, optional for parameter posteriors.
    probes : jnp.ndarray | None
        Test probes, shape (n_test, input_dim).
        Required for predictive posteriors.
    kind : {"predictive", "parameter"}
        Type of posterior to return:
        - "predictive": PredictivePosterior over f(X*) [for acquisitions]
        - "parameter": ParameterPosterior over θ [for diagnostics]

    Returns
    -------
    PredictivePosterior | ParameterPosterior
        Posterior distribution

    Raises
    ------
    RuntimeError
        If model has not been fit yet

    Examples
    --------
    >>> # For acquisition functions
    >>> pred_post = model.posterior(X_candidates, probes=X_probes)
    >>> mean = pred_post.mean
    >>> var = pred_post.variance

    >>> # For diagnostics
    >>> param_post = model.posterior(kind="parameter")
    >>> samples = param_post.sample(100, key=jr.PRNGKey(42))
    """
    if self._posterior is None:
        raise RuntimeError("Must call fit() before posterior()")

    if kind == "parameter":
        return self._posterior
    elif kind == "predictive":
        if X is None:
            raise ValueError("X is required for predictive posteriors")
        from psyphy.posterior import WPPMPredictivePosterior

        return WPPMPredictivePosterior(self._posterior, X, probes=probes)
    else:
        raise ValueError(
            f"Unknown kind: '{kind}'. Use 'predictive' or 'parameter'."
        )

predict_prob

predict_prob(
    params: Params, stimulus: Stimulus, **task_kwargs: Any
) -> ndarray

Predict probability of a correct response for a single stimulus.

Design choice: WPPM computes discriminability & covariance; the TASK defines how that translates to performance. We therefore delegate to: task.predict(params, stimulus, model=self, noise=self.noise)

Parameters:

Name Type Description Default
params dict
required
stimulus (reference, probe)
required

Returns:

Name Type Description
p_correct ndarray
Source code in src/psyphy/model/wppm.py
def predict_prob(
    self, params: Params, stimulus: Stimulus, **task_kwargs: Any
) -> jnp.ndarray:
    """
    Predict probability of a correct response for a single stimulus.

    Design choice:
        WPPM computes discriminability & covariance; the TASK defines how
        that translates to performance. We therefore delegate to:
            task.predict(params, stimulus, model=self, noise=self.noise)

    Parameters
    ----------
    params : dict
    stimulus : (reference, probe)

    Returns
    -------
    p_correct : jnp.ndarray
    """
    # Strict task-owned configuration:
    # - MC control knobs (e.g. num_samples/bandwidth) live in the task config.
    # - WPPM.predict_prob therefore does not accept task-specific kwargs.
    if task_kwargs:
        unexpected = ", ".join(sorted(task_kwargs.keys()))
        raise TypeError(
            "WPPM.predict_prob does not accept task-specific kwargs. "
            "Configure task behavior via the task instance (e.g. OddityTaskConfig). "
            f"Unexpected: {unexpected}"
        )

    return self.task.predict(params, stimulus, self, self.noise)
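
The strict-kwargs design can be reduced to a few lines. The sketch below is a hypothetical function (`predict_strict` is not a psyphy name) that accepts `**task_kwargs` only in order to reject them with a clear message, so misplaced call-time options fail loudly instead of being ignored.

```python
from typing import Any


def predict_strict(stimulus: tuple, **task_kwargs: Any) -> float:
    """Hypothetical predictor that refuses call-time task options."""
    if task_kwargs:
        unexpected = ", ".join(sorted(task_kwargs))
        raise TypeError(f"Unexpected task kwargs: {unexpected}")
    return 0.5  # placeholder; a real task would compute p(correct)


try:
    predict_strict((0.0, 0.1), num_samples=10, bandwidth=0.1)
except TypeError as err:
    rejected_message = str(err)

print(rejected_message)
```

Sorting the offending names keeps the error message deterministic, which helps when matching it in tests.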

predict_with_params

predict_with_params(
    X: ndarray,
    probes: ndarray | None,
    params: dict[str, ndarray],
) -> ndarray

Evaluate model at specific parameter values (no marginalization).

This is useful for:

  • Threshold uncertainty estimation (evaluate at sampled parameters)
  • Parameter sensitivity analysis
  • Debugging and diagnostics

NOT for making predictions (use .posterior() instead, which marginalizes over parameter uncertainty).

Parameters:

Name Type Description Default
X (ndarray, shape(n_test, input_dim))

Test stimuli (references)

required
probes (ndarray, shape(n_test, input_dim))

Probe stimuli (for discrimination tasks)

required
params dict[str, ndarray]

Specific parameter values to evaluate at. Keys and shapes depend on the model (e.g., WPPM has "W", "noise_scale", etc.)

required

Returns:

Name Type Description
predictions (ndarray, shape(n_test))

Predicted probabilities at each test point, given these parameters

Examples:

>>> # Sample parameters and evaluate
>>> param_post = model.posterior(kind="parameter")
>>> samples = param_post.sample(100, key=jr.PRNGKey(0))
>>>
>>> # Evaluate at first parameter sample
>>> params_0 = {k: v[0] for k, v in samples.items()}
>>> predictions = model.predict_with_params(X_test, probes, params_0)
>>>
>>> # Use for threshold uncertainty estimation
>>> threshold_locs = []
>>> for i in range(100):
...     params_i = {k: v[i] for k, v in samples.items()}
...     preds_i = model.predict_with_params(X_grid, probes_grid, params_i)
...     threshold_idx = jnp.argmin(jnp.abs(preds_i - 0.75))
...     threshold_locs.append(threshold_idx)
Notes

This bypasses the posterior marginalization. For acquisition functions, always use .posterior() which properly accounts for parameter uncertainty.

Source code in src/psyphy/model/base.py
def predict_with_params(
    self,
    X: jnp.ndarray,
    probes: jnp.ndarray | None,
    params: dict[str, jnp.ndarray],
) -> jnp.ndarray:
    """
    Evaluate model at specific parameter values (no marginalization).

    This is useful for:
    - Threshold uncertainty estimation (evaluate at sampled parameters)
    - Parameter sensitivity analysis
    - Debugging and diagnostics

    NOT for making predictions (use .posterior() instead, which
    marginalizes over parameter uncertainty).

    Parameters
    ----------
    X : jnp.ndarray, shape (n_test, input_dim)
        Test stimuli (references)
    probes : jnp.ndarray, shape (n_test, input_dim), optional
        Probe stimuli (for discrimination tasks)
    params : dict[str, jnp.ndarray]
        Specific parameter values to evaluate at.
        Keys and shapes depend on the model (e.g., WPPM has "W", "noise_scale", etc.)

    Returns
    -------
    predictions : jnp.ndarray, shape (n_test,)
        Predicted probabilities at each test point, given these parameters

    Examples
    --------
    >>> # Sample parameters and evaluate
    >>> param_post = model.posterior(kind="parameter")
    >>> samples = param_post.sample(100, key=jr.PRNGKey(0))
    >>>
    >>> # Evaluate at first parameter sample
    >>> params_0 = {k: v[0] for k, v in samples.items()}
    >>> predictions = model.predict_with_params(X_test, probes, params_0)
    >>>
    >>> # Use for threshold uncertainty estimation
    >>> threshold_locs = []
    >>> for i in range(100):
    ...     params_i = {k: v[i] for k, v in samples.items()}
    ...     preds_i = model.predict_with_params(X_grid, probes_grid, params_i)
    ...     threshold_idx = jnp.argmin(jnp.abs(preds_i - 0.75))
    ...     threshold_locs.append(threshold_idx)

    Notes
    -----
    This bypasses the posterior marginalization. For acquisition functions,
    always use .posterior() which properly accounts for parameter uncertainty.
    """
    return self._forward(X, probes, params)
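
The threshold-uncertainty loop in the docstring can be tried end to end with a toy stand-in for the model. In this sketch, `toy_predict`, its `scale` parameter, and the uniform "posterior samples" are all assumptions made for illustration; only the per-sample indexing and the `argmin(abs(preds - 0.75))` step mirror the example above. NumPy is used in place of `jax.numpy`.

```python
import numpy as np


def toy_predict(x_grid: np.ndarray, params: dict) -> np.ndarray:
    """Hypothetical psychometric curve rising from chance (1/3) toward 1."""
    return 1.0 / 3.0 + (2.0 / 3.0) * (1.0 - np.exp(-((x_grid / params["scale"]) ** 2)))


rng = np.random.default_rng(0)
n_samples = 200
# Stand-in for posterior samples: dict of arrays with a leading sample axis.
samples = {"scale": rng.uniform(0.5, 1.5, size=n_samples)}

x_grid = np.linspace(0.0, 3.0, 601)
thresholds = np.empty(n_samples)
for i in range(n_samples):
    params_i = {k: v[i] for k, v in samples.items()}
    preds_i = toy_predict(x_grid, params_i)
    # Grid location whose predicted p(correct) is closest to 0.75.
    thresholds[i] = x_grid[np.argmin(np.abs(preds_i - 0.75))]

lo, hi = np.percentile(thresholds, [2.5, 97.5])
print(f"95% interval for the 75%-correct threshold: [{lo:.3f}, {hi:.3f}]")
```

Storing the grid location rather than the index makes the resulting interval directly interpretable in stimulus units.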

Priors


prior

prior.py

Prior distributions for WPPM parameters

Hyperparameters
  • variance_scale: global scaling factor for covariance magnitude
  • decay_rate: smoothness controlling spatial variation
  • extra_embedding_dims: embedding dimension for basis expansions

Connections
  • WPPM calls Prior.sample_params() to initialize model parameters
  • WPPM adds Prior.log_prob(params) to task log-likelihoods to form the log posterior
  • Prior will generate structured parameters for basis expansions and decay_rate-controlled smooth covariance fields

Classes:

Name Description
Prior

Prior distribution over WPPM parameters

Attributes:

Name Type Description
Params

Params

Params = dict[str, ndarray]

Prior

Prior(
    input_dim: int,
    basis_degree: int | None = None,
    variance_scale: float = 1.0,
    decay_rate: float = 0.5,
    extra_embedding_dims: int = 0,
)

Prior distribution over WPPM parameters

Parameters:

Name Type Description Default
input_dim int

Dimensionality of the model space (same as WPPM.input_dim)

required
basis_degree int | None

Degree of Chebyshev basis for Wishart process. If set, uses Wishart mode with W coefficients.

None
variance_scale float

Prior variance for degree-0 (constant) coefficient in Wishart mode. Controls overall scale of covariances.

1.0
decay_rate float

Geometric decay rate for prior variance over higher-degree coefficients. Prior variance for degree-d coefficient = variance_scale * (decay_rate^d). Smaller decay_rate -> stronger smoothness prior.

0.5
extra_embedding_dims int

Additional latent dimensions in U matrices beyond input dimensions. Allows richer ellipsoid shapes in Wishart mode.

0
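
The geometric decay stated for decay_rate, variance_scale * decay_rate^d, can be tabulated directly. This sketch is a one-dimensional illustration only; how the library combines degrees across several input dimensions is not shown in this section and is not assumed here.

```python
variance_scale = 1.0
decay_rate = 0.5
basis_degree = 3

# One prior variance per polynomial degree d = 0..basis_degree.
degree_variances = [variance_scale * decay_rate**d for d in range(basis_degree + 1)]

# A decay_rate closer to 1 keeps more variance in high-degree terms,
# i.e. a weaker smoothness prior.
slow_decay_variances = [variance_scale * 0.9**d for d in range(basis_degree + 1)]

print(degree_variances)
print(slow_decay_variances)
```

With decay_rate = 0.5 each additional degree halves the prior variance, so high-degree (rapidly varying) terms are shrunk more strongly toward zero.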

Methods:

Name Description
log_prob

Compute log prior density (up to a constant)

sample_params

Sample initial parameters from the prior.

Attributes:

Name Type Description
basis_degree int | None
decay_rate float
extra_embedding_dims int
input_dim int
variance_scale float

basis_degree

basis_degree: int | None = None

decay_rate

decay_rate: float = 0.5

extra_embedding_dims

extra_embedding_dims: int = 0

input_dim

input_dim: int

variance_scale

variance_scale: float = 1.0

log_prob

log_prob(params: Params) -> ndarray

Compute log prior density (up to a constant)

Gaussian prior on W with smoothness via decay_rate:

log p(W) = Σ_ij log N(W_ij | 0, σ_ij^2), where σ_ij^2 is the prior variance.

Parameters:

Name Type Description Default
params dict

Parameter dictionary

required

Returns:

Name Type Description
log_prob float

Log prior probability (up to normalizing constant)

Source code in src/psyphy/model/prior.py
def log_prob(self, params: Params) -> jnp.ndarray:
    """
    Compute log prior density (up to a constant)

    Gaussian prior on W with smoothness via decay_rate
        log p(W) = Σ_ij log N(W_ij | 0, σ_ij^2) where σ_ij^2 = prior variance

    Parameters
    ----------
    params : dict
        Parameter dictionary

    Returns
    -------
    log_prob : float
        Log prior probability (up to normalizing constant)
    """

    if "W" in params:
        # Wishart mode
        W = params["W"]
        variances = self._compute_W_prior_variances()

        # Gaussian log probability for each entry
        # log N(x | 0, σ^2) = -0.5 * (x^2/σ^2 + log(2πσ^2))
        # Up to constant: -0.5 * x^2/σ^2

        if self.input_dim == 2:
            # Each W[i,j,:,:] ~ Normal(0, variance[i,j] * I)
            return -0.5 * jnp.sum((W**2) / (variances[:, :, None, None] + 1e-10))
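        elif self.input_dim == 3:
            return -0.5 * jnp.sum((W**2) / (variances[:, :, :, None, None] + 1e-10))
        # Mirror sample_params: only 2D and 3D inputs are supported.
        raise NotImplementedError(
            f"Wishart process only supports 2D and 3D. Got input_dim={self.input_dim}"
        )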

    raise ValueError("params must contain weights 'W'")
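
The phrase "up to a constant" can be checked numerically: the dropped term, -0.5 * Σ log(2πσ²), does not depend on W. The sketch below uses NumPy in place of `jax.numpy`, and the `variances` array is a hypothetical placeholder for whatever the prior computes internally.

```python
import numpy as np

rng = np.random.default_rng(1)
degree, input_dim, embedding_dim = 2, 2, 2
shape = (degree + 1, degree + 1, input_dim, embedding_dim)

# Hypothetical per-coefficient variances, shape (degree+1, degree+1),
# broadcast over the trailing (input_dim, embedding_dim) axes.
variances = rng.uniform(0.1, 1.0, size=(degree + 1, degree + 1))
var_full = np.broadcast_to(variances[:, :, None, None], shape)


def log_prior_unnormalized(W):
    return -0.5 * np.sum(W**2 / var_full)


def log_prior_full(W):
    return np.sum(-0.5 * (W**2 / var_full + np.log(2.0 * np.pi * var_full)))


W_a = rng.standard_normal(shape)
W_b = rng.standard_normal(shape)
offset_a = log_prior_full(W_a) - log_prior_unnormalized(W_a)
offset_b = log_prior_full(W_b) - log_prior_unnormalized(W_b)
print(offset_a, offset_b)
```

Because the offset is identical for any W, dropping it changes neither the location of the posterior mode nor gradients with respect to W.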

sample_params

sample_params(key: Any) -> Params

Sample initial parameters from the prior.

Returns {"W": shape (degree+1, degree+1, input_dim, embedding_dim)} for 2D, where embedding_dim = input_dim + extra_embedding_dims

Note: The 3rd dimension is input_dim (output space dimension). This matches the einsum in _compute_sqrt: U = einsum("ijde,ij->de", W, phi) where d indexes input_dim.

Parameters:

Name Type Description Default
key JAX random key
required

Returns:

Name Type Description
params dict

Parameter dictionary

Source code in src/psyphy/model/prior.py
def sample_params(self, key: Any) -> Params:
    """
    Sample initial parameters from the prior.


    Returns {"W": shape (degree+1, degree+1, input_dim, embedding_dim)}
    for 2D, where embedding_dim = input_dim + extra_embedding_dims

    Note: The 3rd dimension is input_dim (output space dimension).
    This matches the einsum in _compute_sqrt:
    U = einsum("ijde,ij->de", W, phi) where d indexes input_dim.

    Parameters
    ----------
    key : JAX random key

    Returns
    -------
    params : dict
        Parameter dictionary
    """
    if self.basis_degree is None:
        raise ValueError(
            "'basis_degree' is None; please set "
            "`Prior.basis_degree` to an integer >0."
        )

    # Basis function coefficients W
    variances = self._compute_W_prior_variances()
    embedding_dim = self.input_dim + self.extra_embedding_dims

    if self.input_dim == 2:
        # Sample W ~ Normal(0, variances) for each matrix entry
        # Shape: (degree+1, degree+1, input_dim, embedding_dim)
        # Note: degree+1 to match number of basis functions [T_0, ..., T_degree]
        W = jnp.sqrt(variances)[:, :, None, None] * jr.normal(
            key,
            shape=(
                self.basis_degree + 1,
                self.basis_degree + 1,
                self.input_dim,
                embedding_dim,
            ),
        )
    elif self.input_dim == 3:
        # Shape: (degree+1, degree+1, degree+1, input_dim, embedding_dim)
        W = jnp.sqrt(variances)[:, :, :, None, None] * jr.normal(
            key,
            shape=(
                self.basis_degree + 1,
                self.basis_degree + 1,
                self.basis_degree + 1,
                self.input_dim,
                embedding_dim,
            ),
        )
    else:
        raise NotImplementedError(
            f"Wishart process only supports 2D and 3D. Got input_dim={self.input_dim}"
        )

    return {"W": W}
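
The shapes involved in sampling W and contracting it with basis values can be traced in a short NumPy sketch. Two things here are assumptions for illustration: the constant `variances` array, and the construction of `phi` as an outer product of Chebyshev polynomials in each coordinate. Only the broadcast in the sampling step and the einsum signature `"ijde,ij->de"` are taken from the text above.

```python
import numpy as np
from numpy.polynomial import chebyshev

rng = np.random.default_rng(2)
degree, input_dim, extra_embedding_dims = 3, 2, 1
embedding_dim = input_dim + extra_embedding_dims

# Hypothetical prior variances, one per (i, j) basis coefficient.
variances = np.full((degree + 1, degree + 1), 0.25)

# Scale standard normal draws by the per-coefficient standard deviation.
W = np.sqrt(variances)[:, :, None, None] * rng.standard_normal(
    (degree + 1, degree + 1, input_dim, embedding_dim)
)

# Assumed tensor-product Chebyshev basis evaluated at x = (x1, x2).
x = np.array([0.3, -0.2])
t1 = chebyshev.chebvander(x[:1], degree)[0]  # [T_0(x1), ..., T_degree(x1)]
t2 = chebyshev.chebvander(x[1:], degree)[0]  # [T_0(x2), ..., T_degree(x2)]
phi = np.outer(t1, t2)  # shape (degree+1, degree+1)

# Contract the two basis axes; d indexes input_dim, e indexes embedding_dim.
U = np.einsum("ijde,ij->de", W, phi)
print(W.shape, phi.shape, U.shape)
```

The contraction leaves a rectangular (input_dim, embedding_dim) matrix, which is the shape `local_covariance` expects for U(x).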

Noise


noise

Classes:

Name Description
GaussianNoise
StudentTNoise

GaussianNoise

GaussianNoise(sigma: float = 1.0)

Methods:

Name Description
log_prob
sample_standard

Sample from standard Gaussian (mean=0, var=1).

Attributes:

Name Type Description
sigma float

sigma

sigma: float = 1.0

log_prob

log_prob(residual: float) -> float
Source code in src/psyphy/model/noise.py
def log_prob(self, residual: float) -> float:
    # The residual is currently ignored; a constant is returned.
    _ = residual
    return -0.5

sample_standard

sample_standard(
    key: Array, shape: tuple[int, ...]
) -> Array

Sample from standard Gaussian (mean=0, var=1).

Source code in src/psyphy/model/noise.py
def sample_standard(self, key: jax.Array, shape: tuple[int, ...]) -> jax.Array:
    """Sample from standard Gaussian (mean=0, var=1)."""
    return jr.normal(key, shape)

StudentTNoise

StudentTNoise(df: float = 3.0, scale: float = 1.0)

Methods:

Name Description
log_prob
sample_standard

Sample from standard Student-t (df=self.df).

Attributes:

Name Type Description
df float
scale float

df

df: float = 3.0

scale

scale: float = 1.0

log_prob

log_prob(residual: float) -> float
Source code in src/psyphy/model/noise.py
def log_prob(self, residual: float) -> float:
    # The residual is currently ignored; a constant is returned.
    _ = residual
    return -0.5

sample_standard

sample_standard(
    key: Array, shape: tuple[int, ...]
) -> Array

Sample from standard Student-t (df=self.df).

Source code in src/psyphy/model/noise.py
def sample_standard(self, key: jax.Array, shape: tuple[int, ...]) -> jax.Array:
    """Sample from standard Student-t (df=self.df)."""
    return jr.t(key, self.df, shape)
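
The practical difference between the two noise models is tail weight. The sketch below compares standard Gaussian and Student-t (df = 3) draws using NumPy's `Generator.standard_normal` and `Generator.standard_t` as stand-ins for `jr.normal` and `jr.t`; the sample size and the cutoff of 3 are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(3)
n = 100_000

z_gauss = rng.standard_normal(n)
z_student = rng.standard_t(df=3.0, size=n)

# Fraction of draws farther than 3 units from zero.
tail_gauss = np.mean(np.abs(z_gauss) > 3.0)
tail_student = np.mean(np.abs(z_student) > 3.0)
print(f"Gaussian: {tail_gauss:.4f}, Student-t (df=3): {tail_student:.4f}")
```

Heavier tails mean occasional large deviations are far more likely under the Student-t model, which is the usual motivation for using it when responses contain outliers.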

Tasks


task

psyphy.model.task

Task likelihoods for psychophysical experiments.

This module defines task-specific mappings from a model (e.g., WPPM) and stimuli to response likelihoods.

Current direction

OddityTask: the log-likelihood is computed via Monte Carlo observer simulation of the full 3-stimulus oddity decision rule (two identical references, one comparison).

The public API is:

  • TaskLikelihood.predict(params, stimuli, model, noise): optional fast predictor for p(correct). For MC-only tasks this may be unimplemented.

  • TaskLikelihood.loglik(params, data, model, noise, **kwargs): compute the log-likelihood of observed responses under this task.

Connections
  • WPPM delegates to the task to compute likelihood.
  • Noise models are passed through so tasks can simulate observer responses.

Classes:

Name Description
OddityTask

Three-alternative forced-choice oddity task (MC-based only).

OddityTaskConfig

Configuration for OddityTask.

TaskLikelihood

Abstract base class for task likelihoods

Attributes:

Name Type Description
Stimulus

Stimulus

Stimulus = tuple[ndarray, ndarray]

OddityTask

OddityTask(config: OddityTaskConfig | None = None)

Bases: TaskLikelihood

Three-alternative forced-choice oddity task (MC-based only).

Implements the full 3-stimulus oddity task using Monte Carlo simulation:

  • Samples three internal representations per trial (z0, z1, z2)
  • Uses proper oddity decision rule with three pairwise distances
  • Suitable for complex covariance structures

Notes

MC simulation in loglik() (full 3-stimulus oddity):

  1. Sample three internal representations: z_ref, z_refprime ~ N(ref, Σ_ref), z_comparison ~ N(comparison, Σ_comparison)
  2. Compute average covariance: Σ_avg = (2/3) Σ_ref + (1/3) Σ_comparison
  3. Compute three pairwise Mahalanobis distances:
     - d^2(z_ref, z_refprime) = distance between two reference samples
     - d^2(z_ref, z_comparison) = distance from ref to comparison
     - d^2(z_refprime, z_comparison) = distance from reference_prime to comparison
  4. Apply oddity decision rule: delta = min(d^2(z_ref,z_comparison), d^2(z_refprime,z_comparison)) - d^2(z_ref,z_refprime)
  5. Logistic smoothing: P(correct) ≈ logistic.cdf(delta / bandwidth)
  6. Average over samples

Examples:

>>> from psyphy.model.task import OddityTask
>>> from psyphy.model.task import OddityTaskConfig
>>> from psyphy.model import WPPM, Prior
>>> from psyphy.model.noise import GaussianNoise
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>>
>>> # Create task and model (task-owned MC controls)
>>> task = OddityTask(config=OddityTaskConfig(num_samples=1000, bandwidth=1e-2))
>>> model = WPPM(
...     input_dim=2,
...     prior=Prior(input_dim=2, basis_degree=3),  # basis_degree must be set to sample W
...     task=task,
...     noise=GaussianNoise(),
... )
>>> params = model.init_params(jr.PRNGKey(0))
>>> # MC simulation
>>> from psyphy.data.dataset import ResponseData
>>> ref, comparison = jnp.array([0.0, 0.0]), jnp.array([0.3, 0.2])
>>> data = ResponseData()
>>> data.add_trial(ref, comparison, resp=1)
>>> ll_mc = task.loglik(params, data, model, model.noise, key=jr.PRNGKey(42))
>>> print(f"Log-likelihood (MC): {ll_mc:.4f}")

Methods:

Name Description
loglik
Compute log-likelihood via Monte Carlo observer simulation.
predict

Predict p(correct) for a single (ref, comparison) stimulus.

Attributes:

Name Type Description
config
Source code in src/psyphy/model/task.py
def __init__(self, config: OddityTaskConfig | None = None) -> None:
    # No analytical parameters in MC-only mode.
    self.config = config or OddityTaskConfig()

config

config = config or OddityTaskConfig()

loglik

loglik(
    params: Any,
    data: Any,
    model: Any,
    noise: Any,
    **kwargs: Any,
) -> ndarray
Compute log-likelihood via Monte Carlo observer simulation.

This method implements the FULL 3-stimulus oddity task. Instead of using
an analytical approximation, we:
1. Sample three internal noisy representations per trial:
   - z_ref, z_refprime ~ N(ref, Σ_ref)  [two samples from reference]
   - z_comparison ~ N(comparison, Σ_comparison)           [one sample from comparison]
2. Compute three pairwise Mahalanobis distances
3. Apply oddity decision rule: comparison is odd if it's farther from BOTH ref and reference_prime
4. Apply logistic smoothing to approximate P(correct)
5. Average over MC samples
Parameters
params : Any
    Model parameters as expected by ``model._compute_sqrt``.
data : ResponseData
    Trial data with refs, comparisons, and responses
model : WPPM
    Model instance providing ``_compute_sqrt`` for covariance computation.
noise : NoiseModel
    Observer noise model (provides ``sample_standard``).
key : jax.random.PRNGKey, optional
    Random key for reproducible sampling.
    If None, uses ``OddityTaskConfig.default_key_seed``.
Returns
jnp.ndarray
    Scalar sum of log-likelihoods over all trials.
Raises
TypeError
    If ``num_samples`` or ``bandwidth`` are provided as kwargs.
ValueError
    If the task configuration is invalid (e.g. ``num_samples <= 0``).
Notes
Monte Carlo controls (``num_samples``, ``bandwidth``) are owned by the
task configuration:

- Create the task with ``OddityTask(config=OddityTaskConfig(...))``.
- Pass only the PRNG ``key`` at call time when you want to control
  randomness.
Notes
**Full 3-stimulus oddity task algorithm:**

For each trial (ref, comparison, response):
1. Compute covariances:
   - Σ_ref = U_ref @ U_ref.T + σ^2 I
   - Σ_comparison = U_comparison @ U_comparison.T + σ^2 I
   - Σ_avg = (2/3) Σ_ref + (1/3) Σ_comparison  [weighted by stimulus frequency]

2. Sample three internal representations:
   - z_ref, z_refprime ~ N(ref, Σ_ref)  [2 samples from reference, num_samples times each]
   - z_comparison ~ N(comparison, Σ_comparison)           [1 sample from comparison, num_samples times]

3. Compute three pairwise Mahalanobis distances:
   - d^2(z_ref, z_refprime) = (z_ref - z_refprime).T @ Σ_avg^{-1} @ (z_ref - z_refprime)  [ref vs reference_prime]
   - d^2(z_ref, z_comparison) = (z_ref - z_comparison).T @ Σ_avg^{-1} @ (z_ref - z_comparison)  [ref vs comparison]
   - d^2(z_refprime, z_comparison) = (z_refprime - z_comparison).T @ Σ_avg^{-1} @ (z_refprime - z_comparison)  [reference_prime vs comparison]

4. Apply oddity decision rule:
   - delta = min(d^2(z_ref,z_comparison), d^2(z_refprime,z_comparison)) - d^2(z_ref,z_refprime)
   - delta > 0 means comparison is farther from BOTH ref and reference_prime -> correct identification

5. Apply logistic smoothing:
   - P(correct) ≈ mean(logistic.cdf(delta / bandwidth))

6. Bernoulli log-likelihood:
   - LL = Σ [y * log(p) + (1-y) * log(1-p)]

Performance:
- Memory: O(num_samples * input_dim) per trial
- Vectorized across trials using jax.vmap for GPU acceleration
  • Can be JIT-compiled for additional speed (future optimization)
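
The six steps above can be sketched for a single trial as follows. This is a NumPy illustration under stated assumptions, not the library's implementation: `oddity_p_correct` and `bernoulli_loglik` are hypothetical names, noise is drawn with `Generator.multivariate_normal` (a Gaussian stand-in for the noise model), the square-root factor `U` is a made-up constant, and the clipping constant `eps` is an added safeguard.

```python
import numpy as np


def oddity_p_correct(ref, comparison, U_ref, U_cmp, sigma, num_samples, bandwidth, rng):
    """Monte Carlo estimate of p(correct) for one (ref, comparison) trial."""
    dim = ref.shape[0]
    # Step 1: covariances and their frequency-weighted average.
    Sigma_ref = U_ref @ U_ref.T + sigma**2 * np.eye(dim)
    Sigma_cmp = U_cmp @ U_cmp.T + sigma**2 * np.eye(dim)
    Sigma_avg = (2.0 / 3.0) * Sigma_ref + (1.0 / 3.0) * Sigma_cmp
    P = np.linalg.inv(Sigma_avg)

    # Step 2: three internal representations per simulated trial.
    z_ref = rng.multivariate_normal(ref, Sigma_ref, size=num_samples)
    z_refprime = rng.multivariate_normal(ref, Sigma_ref, size=num_samples)
    z_cmp = rng.multivariate_normal(comparison, Sigma_cmp, size=num_samples)

    # Step 3: squared Mahalanobis distances under Sigma_avg.
    def d2(a, b):
        diff = a - b
        return np.einsum("ni,ij,nj->n", diff, P, diff)

    # Step 4: oddity decision variable (positive means a correct choice).
    delta = np.minimum(d2(z_ref, z_cmp), d2(z_refprime, z_cmp)) - d2(z_ref, z_refprime)

    # Step 5: logistic smoothing; the tanh form avoids overflow for large |delta|.
    return float(np.mean(0.5 * (1.0 + np.tanh(0.5 * delta / bandwidth))))


def bernoulli_loglik(p, y, eps=1e-9):
    """Step 6: Bernoulli log-likelihood, clipped so the logs stay finite."""
    p = np.clip(p, eps, 1.0 - eps)
    return float(np.sum(y * np.log(p) + (1 - y) * np.log(1.0 - p)))


rng = np.random.default_rng(4)
U = 0.1 * np.eye(2)  # hypothetical square-root factor, same at both locations
ref = np.array([0.0, 0.0])
p_same = oddity_p_correct(ref, ref, U, U, 0.03, 20_000, 1e-2, rng)
p_far = oddity_p_correct(ref, np.array([1.0, 1.0]), U, U, 0.03, 20_000, 1e-2, rng)
ll = bernoulli_loglik(np.array([p_same, p_far]), np.array([0, 1]))
print(p_same, p_far, ll)
```

When the comparison coincides with the reference, the three samples are exchangeable and the estimate sits near chance (1/3); for a distant comparison it approaches 1.
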
Examples
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> from psyphy.model import WPPM, Prior
>>> from psyphy.model.task import OddityTask, OddityTaskConfig
>>> from psyphy.model.noise import GaussianNoise
>>> from psyphy.data.dataset import ResponseData
>>>
>>> # Setup: MC controls (num_samples, bandwidth) live in the task config
>>> task = OddityTask(config=OddityTaskConfig(num_samples=5000, bandwidth=1e-3))
>>> model = WPPM(
...     input_dim=2,
...     prior=Prior(input_dim=2, basis_degree=3),
...     task=task,
...     noise=GaussianNoise(sigma=0.03),
... )
>>> params = model.init_params(jr.PRNGKey(0))
>>>
>>> # Create trial data
>>> data = ResponseData()
>>> data.add_trial(
...     ref=jnp.array([0.0, 0.0]), comparison=jnp.array([0.3, 0.2]), resp=1
... )
>>>
>>> # Only the PRNG key is passed at call time
>>> loglik = model.task.loglik(
...     params,
...     data,
...     model,
...     model.noise,
...     key=jr.PRNGKey(42),
... )
>>> print(f"MC (N=5000): {loglik:.4f}")
Source code in src/psyphy/model/task.py
def loglik(
    self, params: Any, data: Any, model: Any, noise: Any, **kwargs: Any
) -> jnp.ndarray:
    """
        Compute log-likelihood via Monte Carlo observer simulation.

        This method implements the FULL 3-stimulus oddity task. Instead of using
        an analytical approximation, we:
        1. Sample three internal noisy representations per trial:
           - z_ref, z_refprime ~ N(ref, Σ_ref)  [two samples from reference]
           - z_comparison ~ N(comparison, Σ_comparison)           [one sample from comparison]
        2. Compute three pairwise Mahalanobis distances
        3. Apply oddity decision rule: comparison is odd if it's farther from BOTH ref and reference_prime
        4. Apply logistic smoothing to approximate P(correct)
        5. Average over MC samples

        Parameters
        ----------
        params : Any
            Model parameters as expected by ``model._compute_sqrt``.
        data : ResponseData
            Trial data with refs, comparisons, and responses
        model : WPPM
            Model instance providing ``_compute_sqrt`` for covariance computation.
        noise : NoiseModel
            Observer noise model (provides ``sample_standard``).
        key : jax.random.PRNGKey, optional
            Random key for reproducible sampling.
            If None, uses ``OddityTaskConfig.default_key_seed``.

        Returns
        -------
        jnp.ndarray
            Scalar sum of log-likelihoods over all trials.

        Raises
        ------
        TypeError
            If ``num_samples`` or ``bandwidth`` are provided as kwargs, or if
            any keyword other than ``key`` is passed.
        ValueError
            If the task configuration is invalid (e.g. ``num_samples <= 0``).

        Notes
        -----
        Monte Carlo controls (``num_samples``, ``bandwidth``) are owned by the
        task configuration:

        - Create the task with ``OddityTask(config=OddityTaskConfig(...))``.
        - Pass only the PRNG ``key`` at call time when you want to control
          randomness.

        **Full 3-stimulus oddity task algorithm:**

        For each trial (ref, comparison, response):
        1. Compute covariances:
           - Σ_ref = U_ref @ U_ref.T + σ^2 I
           - Σ_comparison = U_comparison @ U_comparison.T + σ^2 I
           - Σ_avg = (2/3) Σ_ref + (1/3) Σ_comparison  [weighted by stimulus frequency]

        2. Sample three internal representations:
           - z_ref, z_refprime ~ N(ref, Σ_ref)  [2 samples from reference, num_samples times each]
           - z_comparison ~ N(comparison, Σ_comparison)           [1 sample from comparison, num_samples times]

        3. Compute three pairwise Mahalanobis distances:
           - d^2(z_ref, z_refprime) = (z_ref - z_refprime).T @ Σ_avg^{-1} @ (z_ref - z_refprime)  [ref vs reference_prime]
           - d^2(z_ref, z_comparison) = (z_ref - z_comparison).T @ Σ_avg^{-1} @ (z_ref - z_comparison)  [ref vs comparison]
           - d^2(z_refprime, z_comparison) = (z_refprime - z_comparison).T @ Σ_avg^{-1} @ (z_refprime - z_comparison)  [reference_prime vs comparison]

        4. Apply oddity decision rule:
           - delta = min(d^2(z_ref,z_comparison), d^2(z_refprime,z_comparison)) - d^2(z_ref,z_refprime)
           - delta > 0 means both comparison distances exceed the ref-refprime distance, so the comparison is identified as the odd one (correct)

        5. Apply logistic smoothing:
           - P(correct) ≈ mean(logistic.cdf(delta / bandwidth))

        6. Bernoulli log-likelihood:
           - LL = Σ [y * log(p) + (1-y) * log(1-p)]

        Performance:
        - Memory: O(num_samples * input_dim) per trial
        - Vectorized across trials using jax.vmap for GPU acceleration
        - Can be JIT-compiled for additional speed (future optimization)

        Examples
        --------
        >>> import jax.numpy as jnp
        >>> import jax.random as jr
        >>> from psyphy.model import WPPM, Prior
        >>> from psyphy.model.task import OddityTask, OddityTaskConfig
        >>> from psyphy.model.noise import GaussianNoise
        >>> from psyphy.data.dataset import ResponseData
        >>>
        >>> # Setup: Monte Carlo controls are set on the task configuration
        >>> model = WPPM(
        ...     input_dim=2,
        ...     prior=Prior(input_dim=2, basis_degree=3),
        ...     task=OddityTask(
        ...         config=OddityTaskConfig(num_samples=5000, bandwidth=1e-3)
        ...     ),
        ...     noise=GaussianNoise(sigma=0.03),
        ... )
        >>> params = model.init_params(jr.PRNGKey(0))
        >>>
        >>> # Create trial data
        >>> data = ResponseData()
        >>> data.add_trial(
        ...     ref=jnp.array([0.0, 0.0]), comparison=jnp.array([0.3, 0.2]), resp=1
        ... )
        >>>
        >>> # Only the PRNG key is accepted at call time
        >>> loglik = model.task.loglik(
        ...     params,
        ...     data,
        ...     model,
        ...     model.noise,
        ...     key=jr.PRNGKey(42),
        ... )
        >>> print(f"MC (N=5000): {loglik:.4f}")


    """
    # Task is the single source of truth for MC controls.
    num_samples = int(self.config.num_samples)
    bandwidth = float(self.config.bandwidth)

    # Only PRNG key is accepted dynamically.
    key = kwargs.pop("key", None)
    if "num_samples" in kwargs or "bandwidth" in kwargs:
        raise TypeError(
            "OddityTask.loglik does not accept 'num_samples'/'bandwidth' overrides. "
            "Configure them via OddityTaskConfig when constructing the task."
        )
    if kwargs:
        unexpected = ", ".join(sorted(kwargs.keys()))
        raise TypeError(
            f"Unexpected keyword arguments for OddityTask.loglik: {unexpected}"
        )

    if num_samples <= 0:
        raise ValueError(f"num_samples must be > 0, got {num_samples}")
    if bandwidth <= 0:
        raise ValueError(f"bandwidth must be > 0, got {bandwidth}")

    if key is None:
        key = jr.PRNGKey(int(self.config.default_key_seed))

    # Unpack trial data
    refs, comparisons, responses = data.to_numpy()
    n_trials = len(refs)

    # Split keys for each trial (ensures independent sampling)
    trial_keys = jr.split(key, n_trials)

    # Vectorized computation of P(correct) for all trials
    # This processes all trials in parallel using jax.vmap
    # Note: probabilities are already clipped in _simulate_trial_mc()
    probs = self._simulate_trials_mc_vectorized(
        params=params,
        refs=refs,
        comparisons=comparisons,
        model=model,
        noise=noise,
        num_samples=num_samples,
        bandwidth=bandwidth,
        trial_keys=trial_keys,
    )

    # Bernoulli log-likelihood: LL = Σ [y log(p) + (1-y) log(1-p)]
    # Probabilities are already clipped to [eps, 1-eps] so log is safe
    log_likelihoods = jnp.where(
        responses == 1,
        jnp.log(probs),  # Correct response
        jnp.log(1.0 - probs),  # Incorrect response
    )

    return jnp.sum(log_likelihoods)
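
The algorithm described in the docstring above can be sketched in a few lines of JAX. This is a minimal illustration, not the library's `_simulate_trial_mc`: the function name `simulate_trial`, the clipping constant `eps`, and the choice to pass covariance matrices directly (instead of deriving them from `model._compute_sqrt` and the noise model) are all assumptions made for the example. `jax.nn.sigmoid` is used as the standard logistic CDF.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def simulate_trial(key, ref, comparison, cov_ref, cov_cmp,
                   num_samples=2000, bandwidth=1e-2, eps=1e-6):
    """Monte Carlo estimate of P(correct) for one oddity trial (sketch)."""
    k_ref, k_refprime, k_cmp = jr.split(key, 3)

    # Two draws around the reference, one around the comparison.
    z_ref = jr.multivariate_normal(k_ref, ref, cov_ref, shape=(num_samples,))
    z_refprime = jr.multivariate_normal(k_refprime, ref, cov_ref, shape=(num_samples,))
    z_cmp = jr.multivariate_normal(k_cmp, comparison, cov_cmp, shape=(num_samples,))

    # Average covariance weighted by stimulus frequency (2 refs, 1 comparison).
    cov_avg = (2.0 / 3.0) * cov_ref + (1.0 / 3.0) * cov_cmp
    precision = jnp.linalg.inv(cov_avg)

    def sq_dist(a, b):
        # Squared Mahalanobis distance for each of the num_samples rows.
        diff = a - b
        return jnp.einsum("ni,ij,nj->n", diff, precision, diff)

    # Oddity rule: both comparison distances must exceed the ref-refprime distance.
    delta = jnp.minimum(sq_dist(z_ref, z_cmp), sq_dist(z_refprime, z_cmp)) - sq_dist(
        z_ref, z_refprime
    )

    # Logistic smoothing of the indicator 1[delta > 0], then clip so log is safe.
    p = jnp.mean(jax.nn.sigmoid(delta / bandwidth))
    return jnp.clip(p, eps, 1.0 - eps)


cov = 0.01 * jnp.eye(2)
ref = jnp.zeros(2)

# Identical stimuli: the three draws are exchangeable, so p should sit near 1/3.
p_same = simulate_trial(jr.PRNGKey(0), ref, ref, cov, cov)
# Well-separated comparison: p should be close to 1.
p_far = simulate_trial(jr.PRNGKey(1), ref, jnp.array([3.0, 0.0]), cov, cov)

# Bernoulli log-likelihood over two toy trials (1 = correct, 0 = incorrect).
probs = jnp.stack([p_far, p_same])
responses = jnp.array([1, 0])
loglik = jnp.sum(jnp.where(responses == 1, jnp.log(probs), jnp.log(1.0 - probs)))
```

A smaller `bandwidth` makes the smoothed rule closer to the hard decision but gives noisier gradients, which is one reason to keep it as an explicit setting.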

predict

predict(
    params: Any, stimuli: Stimulus, model: Any, noise: Any
) -> ndarray

Predict p(correct) for a single (ref, comparison) stimulus.

Even though OddityTask is MC-only, we still implement predict. Reason: large parts of the library (posterior predictive, acquisition functions, diagnostics, etc.) need a forward model that returns p(correct) at candidate stimuli. Historically this used an analytical approximation, but in MC-only mode we compute it via simulation.

Notes
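  • This method is intentionally lightweight: it performs the same single-trial Monte Carlo simulation used by loglik.
  • To control MC fidelity or smoothing, set OddityTaskConfig(num_samples=..., bandwidth=...) when you construct the task.
  • predict takes no key argument and always seeds its simulation from OddityTaskConfig.default_key_seed, so repeated calls with the same inputs return the same estimate. To control randomness explicitly, pass key=... to loglik.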
Source code in src/psyphy/model/task.py
def predict(
    self, params: Any, stimuli: Stimulus, model: Any, noise: Any
) -> jnp.ndarray:
    """Predict p(correct) for a single (ref, comparison) stimulus.

    Even though OddityTask is *MC-only*, we still implement ``predict``.
    Reason: large parts of the library (posterior predictive, acquisition
    functions, diagnostics, etc.) need a forward model that returns
    p(correct) at candidate stimuli. Historically this used an analytical
    approximation, but in MC-only mode we compute it via simulation.

    Notes
    -----
    - This method is intentionally lightweight: it performs the same
      single-trial Monte Carlo simulation used by ``loglik``.
    - To control MC fidelity or smoothing, set
      ``OddityTaskConfig(num_samples=..., bandwidth=...)`` when you
      construct the task.
    - ``predict`` takes no ``key`` argument and always seeds its simulation
      from ``OddityTaskConfig.default_key_seed``, so repeated calls with the
      same inputs return the same estimate. To control randomness
      explicitly, pass ``key=...`` to ``loglik``.
    """

    num_samples = int(self.config.num_samples)
    bandwidth = float(self.config.bandwidth)
    key = jr.PRNGKey(int(self.config.default_key_seed))

    ref, comparison = stimuli
    return self._simulate_trial_mc(
        params=params,
        ref=ref,
        comparison=comparison,
        model=model,
        noise=noise,
        num_samples=num_samples,
        bandwidth=bandwidth,
        key=key,
    )
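
Because `predict` builds its key from `default_key_seed`, two calls with the same inputs see the same random draws. The sketch below illustrates that property of JAX keys with a toy estimator; `mc_estimate` is a hypothetical stand-in and not part of psyphy.

```python
import jax.numpy as jnp
import jax.random as jr


def mc_estimate(key, num_samples=1000):
    """Toy Monte Carlo estimate: fraction of standard normal draws above zero."""
    draws = jr.normal(key, (num_samples,))
    return jnp.mean(draws > 0.0)


seed = 0  # plays the role of default_key_seed in this sketch
first = mc_estimate(jr.PRNGKey(seed))
second = mc_estimate(jr.PRNGKey(seed))       # same key, so the same draws
other = mc_estimate(jr.PRNGKey(seed + 1))    # different key, independent draws
```

Here `first` and `second` are identical, while `other` is a separate estimate of the same quantity. The practical consequence is that repeated `predict` calls do not average out Monte Carlo error; increasing `num_samples` does.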

OddityTaskConfig

OddityTaskConfig(
    num_samples: int = 1000,
    bandwidth: float = 0.01,
    default_key_seed: int = 0,
)

Configuration for OddityTask.

This is the single source of truth for MC likelihood controls.

Attributes:

Name Type Description
num_samples int

Number of Monte Carlo samples per trial.

bandwidth float

Logistic CDF smoothing bandwidth.

default_key_seed int

Seed used when no key is provided (keeps behavior deterministic by default while allowing reproducibility control upstream).

bandwidth

bandwidth: float = 0.01

default_key_seed

default_key_seed: int = 0

num_samples

num_samples: int = 1000
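
To give a feel for what `bandwidth` does, the sketch below evaluates the logistic smoothing at a fixed decision margin for two bandwidths. `SketchConfig` is a hypothetical stand-in that mirrors the three documented fields and their defaults; validating at construction time is an assumption of this example (the source shown above validates inside `loglik`).

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class SketchConfig:
    """Hypothetical stand-in mirroring the documented fields and defaults."""

    num_samples: int = 1000
    bandwidth: float = 0.01
    default_key_seed: int = 0

    def __post_init__(self):
        # Assumption: reject invalid values early, at construction time.
        if self.num_samples <= 0:
            raise ValueError(f"num_samples must be > 0, got {self.num_samples}")
        if self.bandwidth <= 0:
            raise ValueError(f"bandwidth must be > 0, got {self.bandwidth}")


def smoothed_indicator(delta: float, bandwidth: float) -> float:
    """Logistic CDF of delta / bandwidth, a smooth version of 1[delta > 0]."""
    x = delta / bandwidth
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)  # numerically stable branch for negative arguments
    return e / (1.0 + e)


margin = 0.05
sharp = smoothed_indicator(margin, SketchConfig().bandwidth)             # default 0.01
soft = smoothed_indicator(margin, SketchConfig(bandwidth=0.5).bandwidth)
```

With the default bandwidth the margin is treated as an almost certain "correct" decision, while the larger bandwidth leaves it only slightly above 0.5.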

TaskLikelihood

Bases: ABC

Abstract base class for task likelihoods

Methods:

Name Description
loglik

Compute log-likelihood of observed responses under this task.

predict

Predict probability of correct response for a stimulus.

loglik

loglik(
    params: Any,
    data: Any,
    model: Any,
    noise: Any,
    **kwargs: Any,
) -> ndarray

Compute log-likelihood of observed responses under this task.

Why **kwargs?
  • Different tasks may need different optional runtime controls.
  • MC-based tasks may need parameters such as a PRNG key. In particular, OddityTask takes its Monte Carlo controls (num_samples and bandwidth) exclusively from OddityTaskConfig to avoid silent mismatch bugs.

Notes
  • Task implementations should document which kwargs they accept.
  • Callers should not assume arbitrary kwargs are supported.
Source code in src/psyphy/model/task.py
@abstractmethod
def loglik(
    self, params: Any, data: Any, model: Any, noise: Any, **kwargs: Any
) -> jnp.ndarray:
    """Compute log-likelihood of observed responses under this task.

    Why ``**kwargs``?
    - Different tasks may need different optional runtime controls.
    - MC-based tasks may need parameters such as a PRNG ``key``.
        In particular, :class:`OddityTask` takes Monte Carlo controls
        (``num_samples`` and ``bandwidth``) exclusively from
        :class:`OddityTaskConfig` to avoid silent mismatch bugs.

    Notes
    -----
    - Task implementations should document which kwargs they accept.
    - Callers should not assume arbitrary kwargs are supported.
    """
    ...
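
One way a task implementation can follow the two notes above is to declare its accepted keywords and reject everything else with a `TypeError`. The helper below is a sketch of that pattern; `accept_only` and `loglik_sketch` are hypothetical names, not psyphy functions.

```python
from typing import Any


def accept_only(kwargs: dict[str, Any], allowed: set[str], owner: str) -> dict[str, Any]:
    """Return the accepted kwargs, raising TypeError that names any others."""
    unexpected = sorted(set(kwargs) - allowed)
    if unexpected:
        raise TypeError(
            f"Unexpected keyword arguments for {owner}: {', '.join(unexpected)}"
        )
    return dict(kwargs)


def loglik_sketch(params: Any, data: Any, **kwargs: Any) -> Any:
    """Toy task method that accepts only an optional PRNG ``key``."""
    accepted = accept_only(kwargs, {"key"}, "loglik_sketch")
    # None signals "fall back to a default seed" in this sketch.
    return accepted.get("key")
```

Failing loudly on unknown keywords means a misspelled or unsupported option cannot be silently ignored.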

predict

predict(
    params: Any, stimuli: Stimulus, model: Any, noise: Any
) -> ndarray

Predict probability of correct response for a stimulus.

Source code in src/psyphy/model/task.py
@abstractmethod
def predict(
    self, params: Any, stimuli: Stimulus, model: Any, noise: Any
) -> jnp.ndarray:
    """Predict probability of correct response for a stimulus."""
    ...
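
The two abstract methods define the whole contract a task has to satisfy. The sketch below shows the shape of a concrete subclass using a stand-in base class; `TaskSketch`, `ConstantTask`, and the list-of-pairs data layout are assumptions for illustration and do not reflect psyphy's `ResponseData`.

```python
import math
from abc import ABC, abstractmethod
from typing import Any


class TaskSketch(ABC):
    """Stand-in for the abstract interface: one likelihood, one forward model."""

    @abstractmethod
    def loglik(self, params: Any, data: Any, model: Any, noise: Any, **kwargs: Any) -> float: ...

    @abstractmethod
    def predict(self, params: Any, stimuli: Any, model: Any, noise: Any) -> float: ...


class ConstantTask(TaskSketch):
    """Toy task that predicts the same p(correct) for every stimulus."""

    def __init__(self, p_correct: float = 0.75) -> None:
        self.p_correct = p_correct

    def predict(self, params, stimuli, model, noise) -> float:
        return self.p_correct

    def loglik(self, params, data, model, noise, **kwargs) -> float:
        if kwargs:
            raise TypeError(f"Unexpected keyword arguments: {sorted(kwargs)}")
        total = 0.0
        # Assumed layout: data is an iterable of (stimuli, response) pairs.
        for stimuli, response in data:
            p = self.predict(params, stimuli, model, noise)
            total += math.log(p) if response == 1 else math.log(1.0 - p)
        return total


task = ConstantTask(0.75)
trials = [((0.0, 0.1), 1), ((0.0, 0.2), 0)]
total_loglik = task.loglik(None, trials, None, None)
```

Building `loglik` on top of `predict` keeps the likelihood and the forward model consistent, which is the same relationship the documentation describes between `OddityTask.loglik` and `OddityTask.predict`.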