psyphy.model

Model-layer API: everything model-related in one place.

Includes
  • WPPM (core model)
  • Priors (Prior)
  • Tasks (TaskLikelihood base, OddityTask)
  • Noise models (GaussianNoise, StudentTNoise)

All functions/classes use JAX arrays (jax.numpy as jnp) for autodiff and optimization with Optax.

Typical usage
from psyphy.model import WPPM, Prior, OddityTask, GaussianNoise

Classes:

Name Description
CovarianceField

Protocol for spatially-varying covariance fields Σ(x).

GaussianNoise

Gaussian noise model for internal representations.

Model

Abstract base class for psychophysical models.

OddityTask

Three-alternative forced-choice oddity task (MC-based only).

OddityTaskConfig

Configuration for OddityTask.

Prior

Prior distribution over WPPM parameters

StudentTNoise

Student-t noise model for internal representations.

TaskLikelihood

Abstract base class for task likelihoods.

WPPM

Wishart Process Psychophysical Model (WPPM).

WPPMCovarianceField

Covariance field for WPPM with Wishart process

CovarianceField

Bases: Protocol

Protocol for spatially-varying covariance fields Σ(x).

A covariance field maps stimulus locations x ∈ R^d to covariance matrices Σ(x) ∈ R^{dxd}.

Methods:

Name Description
__call__

Evaluate field at one or more locations. Supports both single points and arbitrary batch dimensions.

cov

Evaluate Σ(x) at stimulus location x (deprecated; use __call__).

sqrt_cov

Evaluate U(x) such that Σ(x) = U(x) @ U(x)^T + λI.

cov_batch

Vectorized evaluation at multiple locations (deprecated; use __call__).

Notes

This protocol enables polymorphic use of covariance fields from different sources (prior samples, fitted posteriors, custom parameterizations).

The field is callable for mathematical elegance and JAX compatibility:

    Sigma = field(x)  # single point or batch
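Any callable object satisfying this shape contract implements the protocol. A minimal sketch (the ConstantField class below is illustrative, not part of psyphy; NumPy is used for a dependency-free example):

```python
import numpy as np

class ConstantField:
    """Toy covariance field: Sigma(x) is the same SPD matrix everywhere."""

    def __init__(self, sigma):
        self.sigma = np.asarray(sigma)

    def __call__(self, x):
        x = np.asarray(x)
        if x.ndim == 1:                  # single point -> (d, d)
            return self.sigma
        batch = x.shape[:-1]             # arbitrary batch dims -> (..., d, d)
        return np.broadcast_to(self.sigma, batch + self.sigma.shape)

field = ConstantField(0.01 * np.eye(2))
single = field(np.zeros(2))              # shape (2, 2)
batch = field(np.zeros((5, 2)))          # shape (5, 2, 2)
```

The same polymorphism lets downstream code accept prior samples, fitted posteriors, or custom parameterizations interchangeably.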

cov

cov(x: ndarray) -> ndarray

Evaluate covariance matrix Σ(x) at stimulus location x (deprecated).

Source code in src/psyphy/model/covariance_field.py
def cov(self, x: jnp.ndarray) -> jnp.ndarray:
    """Evaluate covariance matrix Σ(x) at stimulus location x (deprecated)."""
    ...

cov_batch

cov_batch(X: ndarray) -> ndarray

Evaluate covariance at multiple locations (deprecated).

Source code in src/psyphy/model/covariance_field.py
def cov_batch(self, X: jnp.ndarray) -> jnp.ndarray:
    """Evaluate covariance at multiple locations (deprecated)."""
    ...

sqrt_cov

sqrt_cov(x: ndarray) -> ndarray

Evaluate "square root" matrix U(x) such that Σ(x) = U(x) @ U(x)^T + λI.

Source code in src/psyphy/model/covariance_field.py
def sqrt_cov(self, x: jnp.ndarray) -> jnp.ndarray:
    """Evaluate "square root" matrix U(x) such that Σ(x) = U(x) @ U(x)^T + λI."""
    ...

GaussianNoise

GaussianNoise(sigma: float = 1.0)

Methods:

Name Description
log_prob
sample_standard

Sample from standard Gaussian (mean=0, var=1).

Attributes:

Name Type Description
sigma float

sigma

sigma: float = 1.0

log_prob

log_prob(residual: float) -> float
Source code in src/psyphy/model/noise.py
def log_prob(self, residual: float) -> float:
    _ = residual
    return -0.5

sample_standard

sample_standard(key: Array, shape: tuple[int, ...]) -> Array

Sample from standard Gaussian (mean=0, var=1).

Source code in src/psyphy/model/noise.py
def sample_standard(self, key: jax.Array, shape: tuple[int, ...]) -> jax.Array:
    """Sample from standard Gaussian (mean=0, var=1)."""
    return jr.normal(key, shape)

Model

Bases: ABC

Abstract base class for psychophysical models.

Subclasses must implement:

  • init_params(key) -> sample initial parameters from the Prior
  • log_likelihood_from_data(params, data) -> compute the data log-likelihood

Methods:

Name Description
init_params

Sample initial parameters from prior.

log_likelihood_from_data

Compute log p(data | params).

init_params

init_params(key: Any) -> dict

Sample initial parameters from prior.

Parameters:

Name Type Description Default
key KeyArray

PRNG key

required

Returns:

Type Description
dict

Parameter PyTree

Source code in src/psyphy/model/base.py
@abstractmethod
def init_params(self, key: Any) -> dict:  # jax.random.KeyArray
    """
    Sample initial parameters from prior.

    Parameters
    ----------
    key : jax.random.KeyArray
        PRNG key

    Returns
    -------
    dict
        Parameter PyTree
    """
    ...

log_likelihood_from_data

log_likelihood_from_data(params: dict, data: ResponseData) -> ndarray

Compute log p(data | params).

Parameters:

Name Type Description Default
params dict

Model parameters

required
data ResponseData

Observed trials

required

Returns:

Type Description
ndarray

Log-likelihood (scalar)

Source code in src/psyphy/model/base.py
@abstractmethod
def log_likelihood_from_data(self, params: dict, data: ResponseData) -> jnp.ndarray:
    """
    Compute log p(data | params).

    Parameters
    ----------
    params : dict
        Model parameters
    data : ResponseData
        Observed trials

    Returns
    -------
    jnp.ndarray
        Log-likelihood (scalar)
    """
    ...

OddityTask

OddityTask(config: OddityTaskConfig | None = None)

Bases: TaskLikelihood

Three-alternative forced-choice oddity task (MC-based only).

Implements the full 3-stimulus oddity task using Monte Carlo simulation:

  • samples three internal representations per trial (z0, z1, z2)
  • uses the proper oddity decision rule with three pairwise distances
  • suitable for complex covariance structures

Notes

MC simulation in loglik() (full 3-stimulus oddity):

1. Sample three internal representations:
   z_ref, z_refprime ~ N(ref, Σ_ref), z_comparison ~ N(comparison, Σ_comparison)
2. Compute the average covariance: Σ_avg = (2/3) Σ_ref + (1/3) Σ_comparison
3. Compute three pairwise Mahalanobis distances:
   - d²(z_ref, z_refprime): distance between the two reference samples
   - d²(z_ref, z_comparison): distance from ref to comparison
   - d²(z_refprime, z_comparison): distance from ref_prime to comparison
4. Apply the oddity decision rule:
   delta = min(d²(z_ref, z_comparison), d²(z_refprime, z_comparison)) - d²(z_ref, z_refprime)
5. Logistic smoothing: P(correct) ≈ logistic.cdf(delta / bandwidth)
6. Average over samples
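The steps above can be sketched with a dependency-free NumPy Monte Carlo. This is not psyphy's implementation (that lives inside OddityTask); the function name and the toy covariances below are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)

def oddity_p_correct(ref, comparison, cov_ref, cov_cmp, n=2000, bandwidth=1e-2):
    """MC estimate of P(correct) for the 3-stimulus oddity rule (steps 1-6)."""
    d = ref.shape[0]
    L_ref = np.linalg.cholesky(cov_ref)
    L_cmp = np.linalg.cholesky(cov_cmp)
    # 1. three internal representations per MC draw
    z_ref = ref + rng.standard_normal((n, d)) @ L_ref.T
    z_refprime = ref + rng.standard_normal((n, d)) @ L_ref.T
    z_cmp = comparison + rng.standard_normal((n, d)) @ L_cmp.T
    # 2. average covariance defines the Mahalanobis metric
    cov_avg = (2.0 / 3.0) * cov_ref + (1.0 / 3.0) * cov_cmp
    P = np.linalg.inv(cov_avg)
    def d2(a, b):  # squared Mahalanobis distance per draw
        r = a - b
        return np.einsum("ni,ij,nj->n", r, P, r)
    # 3.-4. oddity decision variable
    delta = np.minimum(d2(z_ref, z_cmp), d2(z_refprime, z_cmp)) - d2(z_ref, z_refprime)
    # 5.-6. logistic smoothing (clipped for numerical safety), averaged over draws
    z = np.clip(delta / bandwidth, -60.0, 60.0)
    return float(np.mean(1.0 / (1.0 + np.exp(-z))))

# well-separated stimuli -> high P(correct); identical stimuli -> near chance (1/3)
p_far = oddity_p_correct(np.zeros(2), np.array([0.5, 0.0]),
                         0.01 * np.eye(2), 0.01 * np.eye(2))
```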

Examples:

>>> from psyphy.model.likelihood import OddityTask
>>> from psyphy.model.likelihood import OddityTaskConfig
>>> from psyphy.model import WPPM, Prior
>>> from psyphy.model.noise import GaussianNoise
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>>
>>> # Create task and model (task-owned MC controls)
>>> likelihood = OddityTask(
...     config=OddityTaskConfig(num_samples=1000, bandwidth=1e-2)
... )
>>> model = WPPM(
...     input_dim=2,
...     prior=Prior(input_dim=2),
...     likelihood=likelihood,
...     noise=GaussianNoise(),
... )
>>> params = model.init_params(jr.PRNGKey(0))
>>>
>>> # MC simulation
>>> from psyphy.data.dataset import ResponseData
>>> ref = jnp.array([0.0, 0.0])
>>> comparison = jnp.array([0.1, 0.0])
>>> data = ResponseData()
>>> data.add_trial(ref, comparison, resp=1)
>>> ll_mc = likelihood.loglik(params, data, model, key=jr.PRNGKey(42))
>>> print(f"Log-likelihood (MC): {ll_mc:.4f}")

Methods:

Name Description
loglik

Compute Bernoulli log-likelihood over a batch of trials.

predict

Return p(correct) for a single (ref, comparison) trial via MC simulation.

simulate

Simulate observed binary responses for a batch of trials.

Attributes:

Name Type Description
config
Source code in src/psyphy/model/likelihood.py
def __init__(self, config: OddityTaskConfig | None = None) -> None:
    # No analytical parameters in MC-only mode.
    self.config = config or OddityTaskConfig()

config

config = config or OddityTaskConfig()

loglik

loglik(params: Any, data: Any, model: Any, *, key: Any = None) -> ndarray

Compute Bernoulli log-likelihood over a batch of trials.

This is a concrete base-class method: it vmaps predict over trials then applies the Bernoulli log-likelihood formula. Subclasses only need to implement predict.

Parameters:

Name Type Description Default
params Any

Model parameters.

required
data Any

Object with .refs, .comparisons, .responses array attributes.

required
model Any

Model instance.

required
key KeyArray

PRNG key. Passed as independent per-trial subkeys to predict. When None, falls back to key=jr.PRNGKey(0) (deterministic).

None

Returns:

Type Description
ndarray

Scalar sum of Bernoulli log-likelihoods over all trials.

Source code in src/psyphy/model/likelihood.py
def loglik(
    self,
    params: Any,
    data: Any,
    model: Any,
    *,
    key: Any = None,
) -> jnp.ndarray:
    """Compute Bernoulli log-likelihood over a batch of trials.

    This is a concrete base-class method: it vmaps ``predict`` over trials
    then applies the Bernoulli log-likelihood formula. Subclasses only need
    to implement ``predict``.

    Parameters
    ----------
    params : Any
        Model parameters.
    data : Any
        Object with ``.refs``, ``.comparisons``, ``.responses`` array attributes.
    model : Any
        Model instance.
    key : jax.random.KeyArray, optional
        PRNG key. Passed as independent per-trial subkeys to ``predict``.
        When None, falls back to ``key=jr.PRNGKey(0)`` (deterministic).

    Returns
    -------
    jnp.ndarray
        Scalar sum of Bernoulli log-likelihoods over all trials.
    """
    refs = jnp.asarray(data.refs)
    comparisons = jnp.asarray(data.comparisons)
    responses = jnp.asarray(data.responses)
    n_trials = int(refs.shape[0])

    base_key = key if key is not None else jr.PRNGKey(0)
    trial_keys = jr.split(base_key, n_trials)

    probs = jax.vmap(
        lambda ref, comparison, k: self.predict(
            params, ref, comparison, model, key=k
        )
    )(refs, comparisons, trial_keys)

    log_likelihoods = jnp.where(
        responses == 1,
        jnp.log(probs),
        jnp.log(1.0 - probs),
    )
    return jnp.sum(log_likelihoods)
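The Bernoulli reduction at the end can be checked in isolation. A NumPy sketch with illustrative values (in psyphy the per-trial probabilities come from the vmapped predict):

```python
import numpy as np

# per-trial P(correct) estimates and observed binary responses (illustrative)
probs = np.array([0.9, 0.8, 0.6])
responses = np.array([1, 0, 1])

# Bernoulli log-likelihood, matching the jnp.where pattern in the source
ll = np.sum(np.where(responses == 1, np.log(probs), np.log(1.0 - probs)))
```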

predict

predict(params: Any, ref: ndarray, comparison: ndarray, model: Any, *, key: Any = None) -> ndarray

Return p(correct) for a single (ref, comparison) trial via MC simulation.

MC controls (num_samples, bandwidth) are read from OddityTaskConfig. Pass key to control randomness; when None, config.default_key_seed is used.

Source code in src/psyphy/model/likelihood.py
def predict(
    self,
    params: Any,
    ref: jnp.ndarray,
    comparison: jnp.ndarray,
    model: Any,
    *,
    key: Any = None,
) -> jnp.ndarray:
    """Return p(correct) for a single (ref, comparison) trial via MC simulation.

    MC controls (``num_samples``, ``bandwidth``) are read from
    :class:`OddityTaskConfig`. Pass ``key`` to control randomness; when
    None, ``config.default_key_seed`` is used.
    """
    num_samples = int(self.config.num_samples)
    bandwidth = float(self.config.bandwidth)
    if key is None:
        key = jr.PRNGKey(int(self.config.default_key_seed))

    return self._simulate_trial_mc(
        params=params,
        ref=ref,
        comparison=comparison,
        model=model,
        num_samples=num_samples,
        bandwidth=bandwidth,
        key=key,
    )

simulate

simulate(params: Any, refs: ndarray, comparisons: ndarray, model: Any, *, key: Any) -> tuple[ndarray, ndarray]

Simulate observed binary responses for a batch of trials.

Parameters:

Name Type Description Default
params Any

Model parameters.

required
refs (ndarray, shape(n_trials, input_dim))

Reference stimuli.

required
comparisons (ndarray, shape(n_trials, input_dim))

Comparison stimuli.

required
model Any

Model instance.

required
key KeyArray

PRNG key (required; split internally for prediction and sampling).

required

Returns:

Name Type Description
responses jnp.ndarray, shape (n_trials,), dtype int32

Simulated binary responses (1 = correct, 0 = incorrect).

p_correct (ndarray, shape(n_trials))

Estimated P(correct) per trial used to draw the responses.

Source code in src/psyphy/model/likelihood.py
def simulate(
    self,
    params: Any,
    refs: jnp.ndarray,
    comparisons: jnp.ndarray,
    model: Any,
    *,
    key: Any,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Simulate observed binary responses for a batch of trials.

    Parameters
    ----------
    params : Any
        Model parameters.
    refs : jnp.ndarray, shape (n_trials, input_dim)
        Reference stimuli.
    comparisons : jnp.ndarray, shape (n_trials, input_dim)
        Comparison stimuli.
    model : Any
        Model instance.
    key : jax.random.KeyArray
        PRNG key (required; split internally for prediction and sampling).

    Returns
    -------
    responses : jnp.ndarray, shape (n_trials,), dtype int32
        Simulated binary responses (1 = correct, 0 = incorrect).
    p_correct : jnp.ndarray, shape (n_trials,)
        Estimated P(correct) per trial used to draw the responses.
    """
    refs = jnp.asarray(refs)
    comparisons = jnp.asarray(comparisons)
    n_trials = int(refs.shape[0])

    k_pred, k_bernoulli = jr.split(key)
    trial_keys = jr.split(k_pred, n_trials)

    p_correct = jax.vmap(
        lambda ref, comparison, k: self.predict(
            params, ref, comparison, model, key=k
        )
    )(refs, comparisons, trial_keys)

    responses = jr.bernoulli(k_bernoulli, p_correct).astype(jnp.int32)
    return responses, p_correct
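The final sampling step mirrors this NumPy sketch (toy p_correct values; psyphy draws with jr.bernoulli from a split key):

```python
import numpy as np

rng = np.random.default_rng(42)
p_correct = np.array([0.9, 0.5, 0.7])   # illustrative per-trial estimates

# Bernoulli draws: responses[i] = 1 (correct) with probability p_correct[i]
responses = (rng.random(p_correct.shape) < p_correct).astype(np.int32)
```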

OddityTaskConfig

OddityTaskConfig(num_samples: int = 1000, bandwidth: float = 0.01, default_key_seed: int = 0)

Configuration for OddityTask.

This is the single source of truth for MC likelihood controls.

Attributes:

Name Type Description
num_samples int

Number of Monte Carlo samples per trial.

bandwidth float

Logistic CDF smoothing bandwidth.

default_key_seed int

Seed used when no key is provided (keeps behavior deterministic by default while allowing reproducibility control upstream).

bandwidth

bandwidth: float = 0.01

default_key_seed

default_key_seed: int = 0

num_samples

num_samples: int = 1000

Prior

Prior(input_dim: int = 2, basis_degree: int = 4, variance_scale: float = 0.004, decay_rate: float = 0.4, extra_embedding_dims: int = 1)

Prior distribution over WPPM parameters

Parameters:

Name Type Description Default
input_dim int

Dimensionality of the model space (same as WPPM.input_dim)

2
basis_degree int | None

Degree of Chebyshev basis for Wishart process. If set, uses Wishart mode with W coefficients.

4
variance_scale float

Prior variance for degree-0 (constant) coefficient in Wishart mode. Controls overall scale of covariances.

0.004
decay_rate float

Geometric decay rate for prior variance over higher-degree coefficients. Prior variance for degree-d coefficient = variance_scale * (decay_rate^d). Smaller decay_rate -> stronger smoothness prior.

0.4
extra_embedding_dims int

Additional latent dimensions in U matrices beyond input dimensions. Allows richer ellipsoid shapes in Wishart mode.

1

Methods:

Name Description
log_prob

Compute log prior density (up to a constant)

sample_params

Sample initial parameters from the prior.

Attributes:

Name Type Description
basis_degree int
decay_rate float
extra_embedding_dims int
input_dim int
variance_scale float

basis_degree

basis_degree: int = 4

decay_rate

decay_rate: float = 0.4

extra_embedding_dims

extra_embedding_dims: int = 1

input_dim

input_dim: int = 2

variance_scale

variance_scale: float = 0.004

log_prob

log_prob(params: Params) -> ndarray

Compute log prior density (up to a constant)

Gaussian prior on W with smoothness via decay_rate:

    log p(W) = Σ_ij log N(W_ij | 0, σ_ij²), where σ_ij² is the prior variance

Parameters:

Name Type Description Default
params dict

Parameter dictionary

required

Returns:

Name Type Description
log_prob float

Log prior probability (up to normalizing constant)

Source code in src/psyphy/model/prior.py
def log_prob(self, params: Params) -> jnp.ndarray:
    """
    Compute log prior density (up to a constant)

    Gaussian prior on W with smoothness via decay_rate
        log p(W) = Σ_ij log N(W_ij | 0, σ_ij^2) where σ_ij^2 = prior variance

    Parameters
    ----------
    params : dict
        Parameter dictionary

    Returns
    -------
    log_prob : float
        Log prior probability (up to normalizing constant)
    """

    if "W" in params:
        # Wishart mode
        W = params["W"]
        variances = self._compute_W_prior_variances()

        # Gaussian log probability for each entry
        # log N(x | 0, σ^2) = -0.5 * (x^2/σ^2 + log(2πσ^2))
        # Up to constant: -0.5 * x^2/σ^2

        if self.input_dim == 2:
            # Each W[i,j,:,:] ~ Normal(0, variance[i,j] * I)
            return -0.5 * jnp.sum((W**2) / (variances[:, :, None, None] + 1e-10))
        elif self.input_dim == 3:
            return -0.5 * jnp.sum((W**2) / (variances[:, :, :, None, None] + 1e-10))

    raise ValueError("params must contain weights 'W'")

sample_params

sample_params(key: Any) -> Params

Sample initial parameters from the prior.

Returns {"W": shape (degree+1, degree+1, input_dim, embedding_dim)} for 2D, where embedding_dim = input_dim + extra_embedding_dims

Note: The 3rd dimension is input_dim (output space dimension). This matches the einsum in _compute_sqrt: U = einsum("ijde,ij->de", W, phi) where d indexes input_dim.

Parameters:

Name Type Description Default
key JAX random key
required

Returns:

Name Type Description
params dict

Parameter dictionary

Source code in src/psyphy/model/prior.py
def sample_params(self, key: Any) -> Params:
    """
    Sample initial parameters from the prior.


    Returns {"W": shape (degree+1, degree+1, input_dim, embedding_dim)}
    for 2D, where embedding_dim = input_dim + extra_embedding_dims

    Note: The 3rd dimension is input_dim (output space dimension).
    This matches the einsum in _compute_sqrt:
    U = einsum("ijde,ij->de", W, phi) where d indexes input_dim.

    Parameters
    ----------
    key : JAX random key

    Returns
    -------
    params : dict
        Parameter dictionary
    """
    if self.basis_degree is None:
        raise ValueError(
            "'basis_degree' is None; please set "
            "`Prior.basis_degree` to an integer >0."
        )

    # Basis function coefficients W
    variances = self._compute_W_prior_variances()
    embedding_dim = self.input_dim + self.extra_embedding_dims

    if self.input_dim == 2:
        # Sample W ~ Normal(0, variances) for each matrix entry
        # Shape: (degree+1, degree+1, input_dim, embedding_dim)
        # Note: degree+1 to match number of basis functions [T_0, ..., T_degree]
        W = jnp.sqrt(variances)[:, :, None, None] * jr.normal(
            key,
            shape=(
                self.basis_degree + 1,
                self.basis_degree + 1,
                self.input_dim,
                embedding_dim,
            ),
        )
    elif self.input_dim == 3:
        # Shape: (degree+1, degree+1, degree+1, input_dim, embedding_dim)
        W = jnp.sqrt(variances)[:, :, :, None, None] * jr.normal(
            key,
            shape=(
                self.basis_degree + 1,
                self.basis_degree + 1,
                self.basis_degree + 1,
                self.input_dim,
                embedding_dim,
            ),
        )
    else:
        raise NotImplementedError(
            f"Wishart process only supports 2D and 3D. Got input_dim={self.input_dim}"
        )

    return {"W": W}

StudentTNoise

StudentTNoise(df: float = 3.0, scale: float = 1.0)

Methods:

Name Description
log_prob
sample_standard

Sample from standard Student-t (df=self.df).

Attributes:

Name Type Description
df float
scale float

df

df: float = 3.0

scale

scale: float = 1.0

log_prob

log_prob(residual: float) -> float
Source code in src/psyphy/model/noise.py
def log_prob(self, residual: float) -> float:
    _ = residual
    return -0.5

sample_standard

sample_standard(key: Array, shape: tuple[int, ...]) -> Array

Sample from standard Student-t (df=self.df).

Source code in src/psyphy/model/noise.py
def sample_standard(self, key: jax.Array, shape: tuple[int, ...]) -> jax.Array:
    """Sample from standard Student-t (df=self.df)."""
    return jr.t(key, self.df, shape)

TaskLikelihood

Bases: ABC

Abstract base class for task likelihoods.

Subclasses must implement:

  • predict(params, ref, comparison, model, *, key) → p(correct) for one trial

The base class provides concrete implementations of:

  • loglik(params, data, model, *, key) → Bernoulli log-likelihood over a batch
  • simulate(params, refs, comparisons, model, *, key) → simulated responses

The Bernoulli log-likelihood step is identical for all binary-response tasks, so it lives here rather than being re-implemented in every subclass.
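This division of labor can be illustrated with a standalone toy; the classes and the decision rule below are illustrative, not psyphy's:

```python
import numpy as np

class ToyTask:
    """Toy task: supplies predict(); loglik() reuses it, as in TaskLikelihood."""

    def predict(self, ref, comparison):
        # illustrative rule: farther comparisons are easier; chance is 1/3
        d = np.linalg.norm(np.asarray(ref) - np.asarray(comparison))
        return 1.0 / 3.0 + (2.0 / 3.0) * (1.0 - np.exp(-d))

    def loglik(self, refs, comparisons, responses):
        # shared Bernoulli reduction over per-trial predictions
        probs = np.array([self.predict(r, c) for r, c in zip(refs, comparisons)])
        responses = np.asarray(responses)
        return np.sum(np.where(responses == 1, np.log(probs), np.log(1.0 - probs)))

toy = ToyTask()
p_same = toy.predict([0.0, 0.0], [0.0, 0.0])       # chance: exactly 1/3
ll = toy.loglik([[0.0, 0.0]], [[0.5, 0.0]], [1])
```

A new binary-response task only has to define predict; the batch log-likelihood comes for free.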

Methods:

Name Description
loglik

Compute Bernoulli log-likelihood over a batch of trials.

predict

Return p(correct) for a single (ref, comparison) trial.

simulate

Simulate observed binary responses for a batch of trials.

loglik

loglik(params: Any, data: Any, model: Any, *, key: Any = None) -> ndarray

Compute Bernoulli log-likelihood over a batch of trials.

This is a concrete base-class method: it vmaps predict over trials then applies the Bernoulli log-likelihood formula. Subclasses only need to implement predict.

Parameters:

Name Type Description Default
params Any

Model parameters.

required
data Any

Object with .refs, .comparisons, .responses array attributes.

required
model Any

Model instance.

required
key KeyArray

PRNG key. Passed as independent per-trial subkeys to predict. When None, falls back to key=jr.PRNGKey(0) (deterministic).

None

Returns:

Type Description
ndarray

Scalar sum of Bernoulli log-likelihoods over all trials.

Source code in src/psyphy/model/likelihood.py
def loglik(
    self,
    params: Any,
    data: Any,
    model: Any,
    *,
    key: Any = None,
) -> jnp.ndarray:
    """Compute Bernoulli log-likelihood over a batch of trials.

    This is a concrete base-class method: it vmaps ``predict`` over trials
    then applies the Bernoulli log-likelihood formula. Subclasses only need
    to implement ``predict``.

    Parameters
    ----------
    params : Any
        Model parameters.
    data : Any
        Object with ``.refs``, ``.comparisons``, ``.responses`` array attributes.
    model : Any
        Model instance.
    key : jax.random.KeyArray, optional
        PRNG key. Passed as independent per-trial subkeys to ``predict``.
        When None, falls back to ``key=jr.PRNGKey(0)`` (deterministic).

    Returns
    -------
    jnp.ndarray
        Scalar sum of Bernoulli log-likelihoods over all trials.
    """
    refs = jnp.asarray(data.refs)
    comparisons = jnp.asarray(data.comparisons)
    responses = jnp.asarray(data.responses)
    n_trials = int(refs.shape[0])

    base_key = key if key is not None else jr.PRNGKey(0)
    trial_keys = jr.split(base_key, n_trials)

    probs = jax.vmap(
        lambda ref, comparison, k: self.predict(
            params, ref, comparison, model, key=k
        )
    )(refs, comparisons, trial_keys)

    log_likelihoods = jnp.where(
        responses == 1,
        jnp.log(probs),
        jnp.log(1.0 - probs),
    )
    return jnp.sum(log_likelihoods)

predict

predict(params: Any, ref: ndarray, comparison: ndarray, model: Any, *, key: Any = None) -> ndarray

Return p(correct) for a single (ref, comparison) trial.

Parameters:

Name Type Description Default
params Any

Model parameters.

required
ref (ndarray, shape(input_dim))

Reference stimulus.

required
comparison (ndarray, shape(input_dim))

Comparison stimulus.

required
model Any

Model instance (provides covariance structure and model.noise).

required
key KeyArray

PRNG key for stochastic tasks. When None, the task falls back to its config.default_key_seed.

None

Returns:

Type Description
ndarray

Scalar p(correct) in (0, 1).

Source code in src/psyphy/model/likelihood.py
@abstractmethod
def predict(
    self,
    params: Any,
    ref: jnp.ndarray,
    comparison: jnp.ndarray,
    model: Any,
    *,
    key: Any = None,
) -> jnp.ndarray:
    """Return p(correct) for a single (ref, comparison) trial.

    Parameters
    ----------
    params : Any
        Model parameters.
    ref : jnp.ndarray, shape (input_dim,)
        Reference stimulus.
    comparison : jnp.ndarray, shape (input_dim,)
        Comparison stimulus.
    model : Any
        Model instance (provides covariance structure and ``model.noise``).
    key : jax.random.KeyArray, optional
        PRNG key for stochastic tasks. When None, the task falls back to
        its ``config.default_key_seed``.

    Returns
    -------
    jnp.ndarray
        Scalar p(correct) in (0, 1).
    """
    ...

simulate

simulate(params: Any, refs: ndarray, comparisons: ndarray, model: Any, *, key: Any) -> tuple[ndarray, ndarray]

Simulate observed binary responses for a batch of trials.

Parameters:

Name Type Description Default
params Any

Model parameters.

required
refs (ndarray, shape(n_trials, input_dim))

Reference stimuli.

required
comparisons (ndarray, shape(n_trials, input_dim))

Comparison stimuli.

required
model Any

Model instance.

required
key KeyArray

PRNG key (required; split internally for prediction and sampling).

required

Returns:

Name Type Description
responses jnp.ndarray, shape (n_trials,), dtype int32

Simulated binary responses (1 = correct, 0 = incorrect).

p_correct (ndarray, shape(n_trials))

Estimated P(correct) per trial used to draw the responses.

Source code in src/psyphy/model/likelihood.py
def simulate(
    self,
    params: Any,
    refs: jnp.ndarray,
    comparisons: jnp.ndarray,
    model: Any,
    *,
    key: Any,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Simulate observed binary responses for a batch of trials.

    Parameters
    ----------
    params : Any
        Model parameters.
    refs : jnp.ndarray, shape (n_trials, input_dim)
        Reference stimuli.
    comparisons : jnp.ndarray, shape (n_trials, input_dim)
        Comparison stimuli.
    model : Any
        Model instance.
    key : jax.random.KeyArray
        PRNG key (required; split internally for prediction and sampling).

    Returns
    -------
    responses : jnp.ndarray, shape (n_trials,), dtype int32
        Simulated binary responses (1 = correct, 0 = incorrect).
    p_correct : jnp.ndarray, shape (n_trials,)
        Estimated P(correct) per trial used to draw the responses.
    """
    refs = jnp.asarray(refs)
    comparisons = jnp.asarray(comparisons)
    n_trials = int(refs.shape[0])

    k_pred, k_bernoulli = jr.split(key)
    trial_keys = jr.split(k_pred, n_trials)

    p_correct = jax.vmap(
        lambda ref, comparison, k: self.predict(
            params, ref, comparison, model, key=k
        )
    )(refs, comparisons, trial_keys)

    responses = jr.bernoulli(k_bernoulli, p_correct).astype(jnp.int32)
    return responses, p_correct

WPPM

WPPM(prior: Prior, likelihood: TaskLikelihood, noise: Any | None = None, *, input_dim: int = 2, extra_dims: int = 1, variance_scale: float = 0.004, decay_rate: float = 0.4, diag_term: float = 1e-06, **model_kwargs: Any)

Bases: Model

Wishart Process Psychophysical Model (WPPM).

Parameters:

Name Type Description Default
input_dim int

Dimensionality of the input stimulus space (e.g., 2 for isoluminant plane, 3 for RGB). Both reference and comparison live in R^{input_dim}.

2
prior Prior

Prior distribution over model parameters. Controls basis_degree in WPPM (basis expansion). The WPPM delegates to prior.basis_degree to ensure consistency between parameter sampling and basis evaluation.

required
likelihood TaskLikelihood

Psychophysical task mapping that defines how discriminability translates to p(correct) and how log-likelihood of responses is computed. (e.g., OddityTask)

required
noise Any

Noise model describing internal representation noise (e.g., GaussianNoise).

None
extra_dims int

Additional embedding dimensions for basis expansions (beyond input_dim). embedding_dim = input_dim + extra_dims.

1

variance_scale float

Global scaling factor for covariance magnitude.

0.004

decay_rate float

Smoothness/length-scale for spatial covariance variation.

0.4

diag_term float

Small positive value added to the covariance diagonal for numerical stability.

1e-06

model_kwargs Any

Reserved for future keyword arguments accepted by the base Model.__init__. Do not pass WPPM math knobs or task/likelihood knobs here.

{}

Methods:

Name Description
init_params

Sample initial parameters from the prior.

local_covariance

Return local covariance Σ(x) at stimulus location x.

log_likelihood_from_data

Compute log-likelihood directly from a batched data object.

log_posterior_from_data

Compute log posterior from data.

predict_prob

Predict probability of a correct response for a single stimulus.

Attributes:

Name Type Description
basis_degree int | None

Chebyshev polynomial degree for Wishart process basis expansion.

decay_rate
diag_term
embedding_dim int

Dimension of the embedding space.

extra_dims
input_dim
likelihood
noise
prior
variance_scale
Source code in src/psyphy/model/wppm.py
def __init__(
    self,
    prior: Prior,
    likelihood: TaskLikelihood,
    noise: Any | None = None,
    *,  # everything after here is keyword-only
    input_dim: int = 2,
    extra_dims: int = 1,
    variance_scale: float = 4e-3,
    decay_rate: float = 0.4,
    diag_term: float = 1e-6,
    **model_kwargs: Any,
) -> None:
    # Base-model configuration.
    #
    # `model_kwargs` is reserved for *future* base `Model.__init__` kwargs.
    # It should NOT be used for WPPM-specific math (e.g. alternative covariance
    # parameterizations) or for task-specific likelihood knobs.
    if model_kwargs:
        known_misuses = {"num_samples", "bandwidth", "online_config"}
        bad = sorted(known_misuses.intersection(model_kwargs.keys()))
        if bad:
            raise TypeError(
                "Do not pass task-specific kwargs via WPPM(..., **model_kwargs). "
                f"Move {bad} into the task config (e.g. OddityTaskConfig)."
            )

    super().__init__(**model_kwargs)

    # --- core components ---
    self.input_dim = int(input_dim)  # stimulus-space dimensionality
    self.prior = prior  # prior over parameter PyTree

    if self.prior.input_dim != self.input_dim:
        raise ValueError(
            f"Dimension mismatch: Model initialized with input_dim={self.input_dim}, "
            f"but Prior expects input_dim={self.prior.input_dim}."
        )

    self.likelihood = likelihood  # task mapping and likelihood
    self.noise = noise  # noise model

    self.extra_dims = int(extra_dims)
    self.variance_scale = float(variance_scale)
    self.decay_rate = float(decay_rate)
    self.diag_term = float(diag_term)

basis_degree

basis_degree: int | None

Chebyshev polynomial degree for Wishart process basis expansion.

This property delegates to self.prior.basis_degree to ensure consistency between parameter sampling and basis evaluation.

Returns:

Type Description
int | None

Degree of Chebyshev polynomial basis (0 = constant, 1 = linear, etc.)

Notes

WPPM gets its basis_degree parameter from Prior.basis_degree.

decay_rate

decay_rate = float(decay_rate)

diag_term

diag_term = float(diag_term)

embedding_dim

embedding_dim: int

Dimension of the embedding space.

embedding_dim = input_dim + extra_dims. This represents the full perceptual space, where:

  • the first input_dim dimensions correspond to observable stimulus features
  • the remaining extra_dims are latent dimensions

Returns:

Type Description
int

input_dim + extra_dims

Notes

This is a computed property, not a constructor parameter.

extra_dims

extra_dims = int(extra_dims)

input_dim

input_dim = int(input_dim)

likelihood

likelihood = likelihood

noise

noise = noise

prior

prior = prior

variance_scale

variance_scale = float(variance_scale)

init_params

init_params(key: Array) -> Params

Sample initial parameters from the prior.

Returns:

Name Type Description
params dict[str, ndarray]
Source code in src/psyphy/model/wppm.py
def init_params(self, key: jax.Array) -> Params:
    """Sample initial parameters from the prior.

    Returns
    -------
    params : dict[str, jnp.ndarray]
    """
    return self.prior.sample_params(key)

local_covariance

local_covariance(params: Params, x: ndarray) -> ndarray

Return local covariance Σ(x) at stimulus location x.

Wishart mode (basis_degree set): Σ(x) = U(x) @ U(x)^T + diag_term * I, where U(x) is rectangular (input_dim, embedding_dim) if extra_dims > 0.

- Varies smoothly with x
- Guaranteed positive-definite
- Returns stimulus covariance directly (input_dim, input_dim)

Parameters:

Name Type Description Default
params dict

Model parameters: - WPPM: {"W": (degree+1, ..., input_dim, embedding_dim)}

required
x (ndarray, shape(input_dim))

Stimulus location

required

Returns:

Type Description
Σ : jnp.ndarray, shape (input_dim, input_dim)

Covariance matrix in stimulus space.

Source code in src/psyphy/model/wppm.py
def local_covariance(self, params: Params, x: jnp.ndarray) -> jnp.ndarray:
    """
    Return local covariance Σ(x) at stimulus location x.


    Wishart mode (basis_degree set):
        Σ(x) = U(x) @ U(x)^T + diag_term * I
        where U(x) is rectangular (input_dim, embedding_dim) if extra_dims > 0.
        - Varies smoothly with x
        - Guaranteed positive-definite
        - Returns stimulus covariance directly (input_dim, input_dim)

    Parameters
    ----------
    params : dict
        Model parameters:
        - WPPM: {"W": (degree+1, ..., input_dim, embedding_dim)}
    x : jnp.ndarray, shape (input_dim,)
        Stimulus location

    Returns
    -------
    Σ : jnp.ndarray, shape (input_dim, input_dim)
        Covariance matrix in stimulus space.
    """
    # WPPM: spatially-varying covariance
    if "W" in params:
        U = self._compute_sqrt(params, x)  # (input_dim, embedding_dim)
        # Σ(x) = U(x) @ U(x)^T + diag_term * I
        # Result is (input_dim, input_dim)
        Sigma = U @ U.T + self.diag_term * jnp.eye(self.input_dim)
        return Sigma

    raise ValueError("params must contain 'W' (weights of WPPM)")
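The construction Σ(x) = U(x) @ U(x)^T + diag_term * I always yields a symmetric positive-definite matrix. A minimal NumPy sketch (independent of psyphy, with arbitrary illustrative shapes) checks this property:

```python
import numpy as np

rng = np.random.default_rng(0)
input_dim, embedding_dim, diag_term = 2, 3, 1e-6

# Rectangular square root U(x): (input_dim, embedding_dim)
U = rng.normal(size=(input_dim, embedding_dim))

# Sigma(x) = U U^T + diag_term * I, shape (input_dim, input_dim)
Sigma = U @ U.T + diag_term * np.eye(input_dim)

assert np.allclose(Sigma, Sigma.T)            # symmetric
assert np.all(np.linalg.eigvalsh(Sigma) > 0)  # positive-definite
```

U U^T is positive semi-definite by construction; the diag_term ridge makes the sum strictly positive-definite even when U is rank-deficient.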

log_likelihood_from_data

log_likelihood_from_data(params: Params, data: Any, *, key: Array | None = None) -> ndarray

Compute log-likelihood directly from a batched data object.

Why delegate to the likelihood?

- The likelihood knows the decision rule (oddity, 2AFC, ...).
- The likelihood can use the model (this WPPM) to fetch discriminabilities.
- The likelihood can use the noise model if it needs MC simulation.

Parameters:

Name Type Description Default
params dict

Model parameters.

required
data TrialData (or any object with refs/comparisons/responses arrays)

Collected trial data.

required
key Array | None

JAX random key for MC likelihood evaluation. When provided, a fresh noise realization is drawn every call — required for correct stochastic gradient estimates during optimization. When None, the task falls back to OddityTaskConfig.default_key_seed (useful for fixed evaluation and testing, but should not be used during gradient-based optimization).

None

Returns:

Name Type Description
loglik ndarray

Scalar log-likelihood (task-only; add prior outside if needed).

Source code in src/psyphy/model/wppm.py
def log_likelihood_from_data(
    self, params: Params, data: Any, *, key: jax.Array | None = None
) -> jnp.ndarray:
    """Compute log-likelihood directly from a batched data object.

    Why delegate to the likelihood?
        - The likelihood knows the decision rule (oddity, 2AFC, ...).
        - The likelihood can use the model (this WPPM) to fetch discriminabilities.
        - The likelihood can use the noise model if it needs MC simulation.

    Parameters
    ----------
    params : dict
        Model parameters.
    data : TrialData (or any object with refs/comparisons/responses arrays)
        Collected trial data.
    key : jax.Array | None, optional
        JAX random key for MC likelihood evaluation. When provided, a fresh
        noise realization is drawn every call — required for correct stochastic
        gradient estimates during optimization. When None, the task falls back
        to ``OddityTaskConfig.default_key_seed`` (useful for fixed evaluation
        and testing, but should not be used during gradient-based optimization).

    Returns
    -------
    loglik : jnp.ndarray
        Scalar log-likelihood (task-only; add prior outside if needed).
    """
    return self.likelihood.loglik(params, data, self, key=key)

log_posterior_from_data

log_posterior_from_data(params: Params, data: Any, *, key: Array | None = None) -> ndarray

Compute log posterior from data.

This simply adds the prior log-probability to the task log-likelihood. Inference engines (e.g., MAP optimizer) typically optimize this quantity.

Parameters:

Name Type Description Default
params dict

Model parameters.

required
data TrialData

Collected trial data.

required
key Array | None

JAX random key for the MC likelihood. Must be provided during optimization so each gradient step uses a fresh noise realization. When None, falls back to OddityTaskConfig.default_key_seed.

None

Returns:

Type Description
ndarray

Scalar log posterior = loglik(params | data) + log_prior(params).

Source code in src/psyphy/model/wppm.py
def log_posterior_from_data(
    self, params: Params, data: Any, *, key: jax.Array | None = None
) -> jnp.ndarray:
    """Compute log posterior from data.

    This simply adds the prior log-probability to the task log-likelihood.
    Inference engines (e.g., MAP optimizer) typically optimize this quantity.

    Parameters
    ----------
    params : dict
        Model parameters.
    data : TrialData
        Collected trial data.
    key : jax.Array | None, optional
        JAX random key for the MC likelihood. Must be provided during
        optimization so each gradient step uses a fresh noise realization.
        When None, falls back to ``OddityTaskConfig.default_key_seed``.

    Returns
    -------
    jnp.ndarray
        Scalar log posterior = loglik(params | data) + log_prior(params).
    """
    return self.log_likelihood_from_data(
        params, data, key=key
    ) + self.prior.log_prob(params)
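The key-handling contract above (a fresh key per gradient step) can be illustrated without psyphy. The quadratic objective below is a toy stand-in for `log_posterior_from_data`; only the key-splitting pattern is the point:

```python
import jax
import jax.numpy as jnp

def log_posterior(params, key):
    # Toy stand-in for model.log_posterior_from_data(params, data, key=key):
    # a quadratic "log posterior" plus zero-mean MC noise drawn from `key`.
    noise = 0.01 * jax.random.normal(key)
    return -jnp.sum((params - 2.0) ** 2) + noise

params = jnp.zeros(3)
key = jax.random.PRNGKey(0)
for _ in range(200):
    key, subkey = jax.random.split(key)       # fresh noise realization per step
    grad = jax.grad(log_posterior)(params, subkey)
    params = params + 0.1 * grad              # gradient *ascent* on log posterior
# params ends up close to the optimum at 2.0
```

Reusing the same key at every step would freeze one noise realization and bias the stochastic gradients, which is why the docstrings insist on a fresh key during optimization.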

predict_prob

predict_prob(params: Params, stimulus: Stimulus, **likelihood_kwargs: Any) -> ndarray

Predict probability of a correct response for a single stimulus.

Design choice: WPPM computes discriminability & covariance; the LIKELIHOOD defines how that translates to performance. We therefore delegate to: likelihood.predict(params, stimulus, model=self, noise=self.noise)

Parameters:

Name Type Description Default
params dict
required
stimulus tuple[ndarray, ndarray]

(reference, comparison) pair in model space.

required

Returns:

Name Type Description
p_correct ndarray
Source code in src/psyphy/model/wppm.py
def predict_prob(
    self, params: Params, stimulus: Stimulus, **likelihood_kwargs: Any
) -> jnp.ndarray:
    """
    Predict probability of a correct response for a single stimulus.

    Design choice:
        WPPM computes discriminability & covariance; the LIKELIHOOD defines how
        that translates to performance. We therefore delegate to:
            likelihood.predict(params, stimulus, model=self, noise=self.noise)

    Parameters
    ----------
    params : dict
    stimulus : tuple[jnp.ndarray, jnp.ndarray]
         (reference, comparison) pair in model space.

    Returns
    -------
    p_correct : jnp.ndarray
    """
    # Strict task-owned configuration:
    # - MC control knobs (e.g. num_samples/bandwidth) live in the task config.
    # - WPPM.predict_prob therefore does not accept task-specific kwargs.
    if likelihood_kwargs:
        raise TypeError(
            f"WPPM.predict_prob got unexpected kwargs: {likelihood_kwargs}. "
            "Configure likelihood behavior via the TaskLikelihood object itself."
        )

    ref, comparison = stimulus
    return self.likelihood.predict(params, ref, comparison, self)

WPPMCovarianceField

WPPMCovarianceField(model, params: dict)

Covariance field for a WPPM with a Wishart process. Encapsulates the model and its parameters to provide a clean evaluation interface for Σ(x) and U(x).

Parameters:

Name Type Description Default
model WPPM

Model providing evaluation logic (local_covariance, _compute_sqrt)

required
params dict

Model parameters:

- MVP: {"log_diag": (input_dim,)}
- Wishart: {"W": (degree+1, degree+1, input_dim, embedding_dim)} where embedding_dim = input_dim + extra_dims

Note: The 3rd dimension is input_dim (output/stimulus space), not embedding_dim. This matches the einsum in _compute_sqrt where U(x) has shape (input_dim, embedding_dim).

required

Attributes:

Name Type Description
model WPPM

Associated model instance

params dict

Parameter dictionary

Examples:

>>> # From prior
>>> model = WPPM(input_dim=2, prior=Prior(basis_degree=5), ...)
>>> field = WPPMCovarianceField.from_prior(model, key)
>>> Sigma = field.cov(jnp.array([0.5, 0.3]))
>>>
>>> # From posterior
>>> posterior = model.fit(data, optimizer=MAPOptimizer())
>>> field = posterior.get_covariance_field()
>>> Sigmas = field.cov_batch(X_grid)
>>>
>>> # Access square root (Wishart mode only)
>>> U = field.sqrt_cov(jnp.array([0.5, 0.3]))
Notes

Implements the CovarianceField protocol for polymorphic use.

Construct covariance field from model and parameters.

Parameters:

Name Type Description Default
model WPPM

Model providing evaluation logic

required
params dict

Parameter dictionary

required

Methods:

Name Description
cov

Evaluate covariance matrix Σ(x) at stimulus location x.

cov_batch

Evaluate covariance at multiple locations (vectorized).

from_params

Create field from arbitrary parameters.

from_posterior

Create covariance field from fitted posterior.

from_prior

Sample a covariance field from the prior.

sqrt_cov

Evaluate U(x) such that Σ(x) = U(x) @ U(x)^T + diag_term*I.

sqrt_cov_batch

Vectorized evaluation of U(x) at multiple locations.

Source code in src/psyphy/model/covariance_field.py
def __init__(self, model, params: dict):
    """
    Construct covariance field from model and parameters.

    Parameters
    ----------
    model : WPPM
        Model providing evaluation logic
    params : dict
        Parameter dictionary
    """
    self.model = model
    self.params = params
    # Pre-compile JIT batch evaluation path for performance
    self._eval_batch_jitted = jax.jit(self._eval_batch_impl)

model

model = model

params

params = params

cov

cov(x: ndarray) -> ndarray

Evaluate covariance matrix Σ(x) at stimulus location x.

.. deprecated:: Use field(x) instead for unified single/batch API.

Parameters:

Name Type Description Default
x (ndarray, shape(input_dim))

Stimulus location in [0, 1]^d

required

Returns:

Type Description
(ndarray, shape(input_dim, input_dim))

Covariance matrix Σ(x) in stimulus space

Notes

With the rectangular U design, this always returns stimulus-space covariance (input_dim, input_dim), regardless of extra_dims.

Source code in src/psyphy/model/covariance_field.py
def cov(self, x: jnp.ndarray) -> jnp.ndarray:
    """
    Evaluate covariance matrix Σ(x) at stimulus location x.

    .. deprecated::
        Use `field(x)` instead for unified single/batch API.

    Parameters
    ----------
    x : jnp.ndarray, shape (input_dim,)
        Stimulus location in [0, 1]^d

    Returns
    -------
    jnp.ndarray, shape (input_dim, input_dim)
        Covariance matrix Σ(x) in stimulus space

    Notes
    -----
    With the rectangular U design, this always returns stimulus-space
    covariance (input_dim, input_dim), regardless of extra_dims.
    """
    warnings.warn(
        "cov() is deprecated. Use field(x) instead for unified single/batch API.",
        DeprecationWarning,
        stacklevel=2,
    )
    if x.ndim != 1:
        raise ValueError(
            f"cov() only accepts single points with shape (input_dim,), got {x.shape}. "
            f"Use field(X) for batches."
        )
    return self._eval_single(x)

cov_batch

cov_batch(X: ndarray) -> ndarray

Evaluate covariance at multiple locations (vectorized).

.. deprecated:: Use field(X) instead for unified single/batch API.

Parameters:

Name Type Description Default
X (ndarray, shape(n_points, input_dim))

Multiple stimulus locations

required

Returns:

Type Description
(ndarray, shape(n_points, dim, dim))

Covariance matrices at each location

Source code in src/psyphy/model/covariance_field.py
def cov_batch(self, X: jnp.ndarray) -> jnp.ndarray:
    """
    Evaluate covariance at multiple locations (vectorized).

    .. deprecated::
        Use `field(X)` instead for unified single/batch API.

    Parameters
    ----------
    X : jnp.ndarray, shape (n_points, input_dim)
        Multiple stimulus locations

    Returns
    -------
    jnp.ndarray, shape (n_points, dim, dim)
        Covariance matrices at each location
    """
    warnings.warn(
        "cov_batch() is deprecated. Use field(X) instead for unified single/batch API.",
        DeprecationWarning,
        stacklevel=2,
    )
    if X.ndim < 2:
        raise ValueError(
            f"cov_batch() expects batch with shape (n_points, input_dim), got {X.shape}. "
            f"Use field(x) for single points."
        )
    return self._eval_batch_jitted(X)
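Both deprecated methods funnel into the unified callable API. A toy stand-in class (ConstantField is hypothetical, not part of psyphy) sketches the single-point/batch dispatch that `field(x)` performs:

```python
import numpy as np

class ConstantField:
    """Toy stand-in for the CovarianceField protocol: same Sigma everywhere."""
    def __init__(self, dim: int):
        self.sigma = np.eye(dim)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        x = np.asarray(x)
        if x.ndim == 1:                        # single point -> (dim, dim)
            return self.sigma
        # arbitrary leading batch dims -> (*batch, dim, dim)
        return np.broadcast_to(self.sigma, x.shape[:-1] + self.sigma.shape)

field = ConstantField(2)
Sigma = field(np.array([0.5, 0.3]))            # replaces field.cov(x)
Sigmas = field(np.zeros((5, 2)))               # replaces field.cov_batch(X)
```

Dispatching on `x.ndim` is what lets one callable replace the separate `cov` and `cov_batch` entry points.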

from_params

from_params(model, params: dict) -> WPPMCovarianceField

Create field from arbitrary parameters.

Useful for:

- Custom initialization
- Posterior samples
- Intermediate optimization checkpoints
- Testing

Parameters:

Name Type Description Default
model WPPM

Model providing evaluation logic

required
params dict

Parameter dictionary

required

Returns:

Type Description
WPPMCovarianceField

Examples:

>>> params = {"log_diag": jnp.array([0.1, 0.2])}
>>> field = WPPMCovarianceField.from_params(model, params)
Source code in src/psyphy/model/covariance_field.py
@classmethod
def from_params(cls, model, params: dict) -> WPPMCovarianceField:
    """
    Create field from arbitrary parameters.

    Useful for:
    - Custom initialization
    - Posterior samples
    - Intermediate optimization checkpoints
    - Testing

    Parameters
    ----------
    model : WPPM
        Model providing evaluation logic
    params : dict
        Parameter dictionary

    Returns
    -------
    WPPMCovarianceField

    Examples
    --------
    >>> params = {"log_diag": jnp.array([0.1, 0.2])}
    >>> field = WPPMCovarianceField.from_params(model, params)
    """
    return cls(model, params)

from_posterior

from_posterior(posterior) -> WPPMCovarianceField

Create covariance field from fitted posterior.

Parameters:

Name Type Description Default
posterior ParameterPosterior

Fitted posterior (e.g., from model.fit())

required

Returns:

Type Description
WPPMCovarianceField

Field representing posterior estimate of Σ(x)

Notes

For MAP posteriors, uses θ_MAP. For variational posteriors, could use posterior mean or sample.

Examples:

>>> posterior = model.fit(data, optimizer=MAPOptimizer())
>>> field = WPPMCovarianceField.from_posterior(posterior)
>>> Sigma = field.cov(x)
Source code in src/psyphy/model/covariance_field.py
@classmethod
def from_posterior(cls, posterior) -> WPPMCovarianceField:
    """
    Create covariance field from fitted posterior.

    Parameters
    ----------
    posterior : ParameterPosterior
        Fitted posterior (e.g., from model.fit())

    Returns
    -------
    WPPMCovarianceField
        Field representing posterior estimate of Σ(x)

    Notes
    -----
    For MAP posteriors, uses θ_MAP.
    For variational posteriors, could use posterior mean or sample.

    Examples
    --------
    >>> posterior = model.fit(data, optimizer=MAPOptimizer())
    >>> field = WPPMCovarianceField.from_posterior(posterior)
    >>> Sigma = field.cov(x)
    """
    return cls(posterior._model, posterior.params)

from_prior

from_prior(model, key: KeyArray) -> WPPMCovarianceField

Sample a covariance field from the prior.

Parameters:

Name Type Description Default
model WPPM

Model defining prior distribution

required
key KeyArray

PRNG key for sampling

required

Returns:

Type Description
WPPMCovarianceField

Field sampled from p(Σ(x))

Examples:

>>> model = WPPM(input_dim=2, prior=Prior(basis_degree=5), ...)
>>> field = WPPMCovarianceField.from_prior(model, jr.PRNGKey(42))
>>> Sigma = field.cov(jnp.array([0.5, 0.5]))
Source code in src/psyphy/model/covariance_field.py
@classmethod
def from_prior(cls, model, key: jr.KeyArray) -> WPPMCovarianceField:
    """
    Sample a covariance field from the prior.

    Parameters
    ----------
    model : WPPM
        Model defining prior distribution
    key : jax.random.KeyArray
        PRNG key for sampling

    Returns
    -------
    WPPMCovarianceField
        Field sampled from p(Σ(x))

    Examples
    --------
    >>> model = WPPM(input_dim=2, prior=Prior(basis_degree=5), ...)
    >>> field = WPPMCovarianceField.from_prior(model, jr.PRNGKey(42))
    >>> Sigma = field.cov(jnp.array([0.5, 0.5]))
    """
    params = model.init_params(key)
    return cls(model, params)

sqrt_cov

sqrt_cov(x: ndarray) -> ndarray

Evaluate U(x) such that Σ(x) = U(x) @ U(x)^T + diag_term*I.

Parameters:

Name Type Description Default
x (ndarray, shape(input_dim))

Stimulus location

required

Returns:

Type Description
(ndarray, shape(input_dim, embedding_dim))

Rectangular square root matrix U(x). embedding_dim = input_dim + extra_dims

Notes

Only available in Wishart mode. MVP mode uses diagonal parameterization without explicit U matrices.

In the rectangular design (Hong et al.), U is (input_dim, embedding_dim).

Examples:

>>> # Wishart mode
>>> model = WPPM(input_dim=2, basis_degree=5, extra_dims=1, ...)
>>> field = WPPMCovarianceField.from_prior(model, key)
>>> U = field.sqrt_cov(jnp.array([0.5, 0.3]))
>>> print(U.shape)  # (2, 3) for input_dim=2, extra_dims=1
>>>
>>> # Verify: Σ = U @ U^T + λI
>>> Sigma_from_U = U @ U.T + model.diag_term * jnp.eye(2)
>>> Sigma_direct = field.cov(x)
>>> assert jnp.allclose(Sigma_from_U, Sigma_direct)
Source code in src/psyphy/model/covariance_field.py
def sqrt_cov(self, x: jnp.ndarray) -> jnp.ndarray:
    """
    Evaluate U(x) such that Σ(x) = U(x) @ U(x)^T + diag_term*I.

    Parameters
    ----------
    x : jnp.ndarray, shape (input_dim,)
        Stimulus location

    Returns
    -------
    jnp.ndarray, shape (input_dim, embedding_dim)
        Rectangular square root matrix U(x).
        embedding_dim = input_dim + extra_dims


    Notes
    -----
    Only available in Wishart mode. MVP mode uses diagonal parameterization
    without explicit U matrices.

    In the rectangular design (Hong et al.), U is (input_dim, embedding_dim).

    Examples
    --------
    >>> # Wishart mode
    >>> model = WPPM(input_dim=2, basis_degree=5, extra_dims=1, ...)
    >>> field = WPPMCovarianceField.from_prior(model, key)
    >>> U = field.sqrt_cov(jnp.array([0.5, 0.3]))
    >>> print(U.shape)  # (2, 3) for input_dim=2, extra_dims=1
    >>>
    >>> # Verify: Σ = U @ U^T + λI
    >>> Sigma_from_U = U @ U.T + model.diag_term * jnp.eye(2)
    >>> Sigma_direct = field.cov(x)
    >>> assert jnp.allclose(Sigma_from_U, Sigma_direct)
    """
    if "W" not in self.params:
        raise ValueError(
            "sqrt_cov only available in Wishart mode. "
            "Set basis_degree when creating WPPM to use Wishart process."
        )
    return self.model._compute_sqrt(self.params, x)

sqrt_cov_batch

sqrt_cov_batch(X: ndarray) -> ndarray

Vectorized evaluation of U(x) at multiple locations.

Parameters:

Name Type Description Default
X (ndarray, shape(n_points, input_dim))

Multiple stimulus locations

required

Returns:

Type Description
(ndarray, shape(n_points, input_dim, embedding_dim))

Rectangular square root matrices at each location. embedding_dim = input_dim + extra_dims

Raises:

Type Description
ValueError

If in MVP mode.

Notes

In the rectangular design (Hong et al.), U is (input_dim, embedding_dim).

Examples:

>>> X_grid = jnp.array([[0.1, 0.2], [0.5, 0.5], [0.9, 0.8]])
>>> U_batch = field.sqrt_cov_batch(X_grid)
>>> print(U_batch.shape)  # (3, 2, 3) for input_dim=2, extra_dims=1
Source code in src/psyphy/model/covariance_field.py
def sqrt_cov_batch(self, X: jnp.ndarray) -> jnp.ndarray:
    """
    Vectorized evaluation of U(x) at multiple locations.

    Parameters
    ----------
    X : jnp.ndarray, shape (n_points, input_dim)
        Multiple stimulus locations

    Returns
    -------
    jnp.ndarray, shape (n_points, input_dim, embedding_dim)
        Rectangular square root matrices at each location.
        embedding_dim = input_dim + extra_dims

    Raises
    ------
    ValueError
        If in MVP mode.

    Notes
    -----
    In the rectangular design (Hong et al.), U is (input_dim, embedding_dim).

    Examples
    --------
    >>> X_grid = jnp.array([[0.1, 0.2], [0.5, 0.5], [0.9, 0.8]])
    >>> U_batch = field.sqrt_cov_batch(X_grid)
    >>> print(U_batch.shape)  # (3, 2, 3) for input_dim=2, extra_dims=1
    """
    if "W" not in self.params:
        raise ValueError("sqrt_cov_batch only available in Wishart mode")
    return jax.vmap(self.sqrt_cov)(X)

Wishart Process Psychophysical Model (WPPM)


wppm

wppm.py

Wishart Process Psychophysical Model (WPPM)

Goals

Wishart Process Psychophysical Model (WPPM):

- Expose the hyperparameters needed to reproduce the model configuration used in Hong et al.:
    * extra_dims: embedding size for basis expansions
    * variance_scale: global covariance scale
    * decay_rate: smoothness/length-scale for the covariance field
    * diag_term: numerical stabilizer added to covariance diagonals

All numerics use JAX (jax.numpy as jnp) to support autodiff and Optax optimizers.

Classes:

Name Description
WPPM

Wishart Process Psychophysical Model (WPPM).

Attributes:

Name Type Description
Params
Stimulus

Params

Params = dict[str, ndarray]

Stimulus

Stimulus = tuple[ndarray, ndarray]
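These aliases are plain containers. A sketch of conforming values (the shapes assume input_dim=2, extra_dims=1, and basis degree 1, and are illustrative only):

```python
import jax.numpy as jnp

# Params: a dict of arrays. In Wishart mode, "W" has shape
# (degree+1, degree+1, input_dim, embedding_dim); here degree=1, input_dim=2,
# and embedding_dim = input_dim + extra_dims = 3.
params = {"W": jnp.zeros((2, 2, 2, 3))}

# Stimulus: a (reference, comparison) pair in model space.
stimulus = (jnp.array([0.50, 0.50]), jnp.array([0.55, 0.50]))
```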

WPPM

WPPM(prior: Prior, likelihood: TaskLikelihood, noise: Any | None = None, *, input_dim: int = 2, extra_dims: int = 1, variance_scale: float = 0.004, decay_rate: float = 0.4, diag_term: float = 1e-06, **model_kwargs: Any)

Bases: Model

Wishart Process Psychophysical Model (WPPM).

Parameters:

Name Type Description Default
input_dim int

Dimensionality of the input stimulus space (e.g., 2 for isoluminant plane, 3 for RGB). Both reference and comparison live in R^{input_dim}.

2
prior Prior

Prior distribution over model parameters. Controls basis_degree in WPPM (basis expansion). The WPPM delegates to prior.basis_degree to ensure consistency between parameter sampling and basis evaluation.

required
likelihood TaskLikelihood

Psychophysical task mapping that defines how discriminability translates to p(correct) and how log-likelihood of responses is computed. (e.g., OddityTask)

required
noise Any

Noise model describing internal representation noise (e.g., GaussianNoise).

None
hyperparameters

extra_dims : int, default=1
    Additional embedding dimensions for basis expansions (beyond input_dim). embedding_dim = input_dim + extra_dims.
variance_scale : float, default=0.004
    Global scaling factor for covariance magnitude.
decay_rate : float, default=0.4
    Smoothness/length-scale for spatial covariance variation.
diag_term : float, default=1e-6
    Small positive value added to the covariance diagonal for numerical stability.
model_kwargs : Any
    Reserved for future keyword arguments accepted by the base Model.__init__. Do not pass WPPM math knobs or task/likelihood knobs here.

Methods:

Name Description
init_params

Sample initial parameters from the prior.

local_covariance

Return local covariance Σ(x) at stimulus location x.

log_likelihood_from_data

Compute log-likelihood directly from a batched data object.

log_posterior_from_data

Compute log posterior from data.

predict_prob

Predict probability of a correct response for a single stimulus.

Attributes:

Name Type Description
basis_degree int | None

Chebyshev polynomial degree for Wishart process basis expansion.

decay_rate
diag_term
embedding_dim int

Dimension of the embedding space.

extra_dims
input_dim
likelihood
noise
prior
variance_scale
Source code in src/psyphy/model/wppm.py
def __init__(
    self,
    prior: Prior,
    likelihood: TaskLikelihood,
    noise: Any | None = None,
    *,  # everything after here is keyword-only
    input_dim: int = 2,
    extra_dims: int = 1,
    variance_scale: float = 4e-3,
    decay_rate: float = 0.4,
    diag_term: float = 1e-6,
    **model_kwargs: Any,
) -> None:
    # Base-model configuration.
    #
    # `model_kwargs` is reserved for *future* base `Model.__init__` kwargs.
    # It should NOT be used for WPPM-specific math (e.g. alternative covariance
    # parameterizations) or for task-specific likelihood knobs.
    if model_kwargs:
        known_misuses = {"num_samples", "bandwidth", "online_config"}
        bad = sorted(known_misuses.intersection(model_kwargs.keys()))
        if bad:
            raise TypeError(
                "Do not pass task-specific kwargs via WPPM(..., **model_kwargs). "
                f"Move {bad} into the task config (e.g. OddityTaskConfig)."
            )

    super().__init__(**model_kwargs)

    # --- core components ---
    self.input_dim = int(input_dim)  # stimulus-space dimensionality
    self.prior = prior  # prior over parameter PyTree

    if self.prior.input_dim != self.input_dim:
        raise ValueError(
            f"Dimension mismatch: Model initialized with input_dim={self.input_dim}, "
            f"but Prior expects input_dim={self.prior.input_dim}."
        )

    self.likelihood = likelihood  # task mapping and likelihood
    self.noise = noise  # noise model

    self.extra_dims = int(extra_dims)
    self.variance_scale = float(variance_scale)
    self.decay_rate = float(decay_rate)
    self.diag_term = float(diag_term)

basis_degree

basis_degree: int | None

Chebyshev polynomial degree for Wishart process basis expansion.

This property delegates to self.prior.basis_degree to ensure consistency between parameter sampling and basis evaluation.

Returns:

Type Description
int | None

Degree of Chebyshev polynomial basis (0 = constant, 1 = linear, etc.)

Notes

WPPM gets its basis_degree parameter from Prior.basis_degree.

decay_rate

decay_rate = float(decay_rate)

diag_term

diag_term = float(diag_term)

embedding_dim

embedding_dim: int

Dimension of the embedding space.

embedding_dim = input_dim + extra_dims. This represents the full perceptual space where:

- First input_dim dimensions correspond to observable stimulus features
- Remaining extra_dims are latent dimensions

Returns:

Type Description
int

input_dim + extra_dims

Notes

This is a computed property, not a constructor parameter.

extra_dims

extra_dims = int(extra_dims)

input_dim

input_dim = int(input_dim)

likelihood

likelihood = likelihood

noise

noise = noise

prior

prior = prior

variance_scale

variance_scale = float(variance_scale)

init_params

init_params(key: Array) -> Params

Sample initial parameters from the prior.

Returns:

Name Type Description
params dict[str, ndarray]
Source code in src/psyphy/model/wppm.py
def init_params(self, key: jax.Array) -> Params:
    """Sample initial parameters from the prior.

    Returns
    -------
    params : dict[str, jnp.ndarray]
    """
    return self.prior.sample_params(key)

local_covariance

local_covariance(params: Params, x: ndarray) -> ndarray

Return local covariance Σ(x) at stimulus location x.

Wishart mode (basis_degree set): Σ(x) = U(x) @ U(x)^T + diag_term * I, where U(x) is rectangular (input_dim, embedding_dim) if extra_dims > 0.

- Varies smoothly with x
- Guaranteed positive-definite
- Returns stimulus covariance directly (input_dim, input_dim)

Parameters:

Name Type Description Default
params dict

Model parameters: - WPPM: {"W": (degree+1, ..., input_dim, embedding_dim)}

required
x (ndarray, shape(input_dim))

Stimulus location

required

Returns:

Type Description
Σ : jnp.ndarray, shape (input_dim, input_dim)

Covariance matrix in stimulus space.

Source code in src/psyphy/model/wppm.py
def local_covariance(self, params: Params, x: jnp.ndarray) -> jnp.ndarray:
    """
    Return local covariance Σ(x) at stimulus location x.


    Wishart mode (basis_degree set):
        Σ(x) = U(x) @ U(x)^T + diag_term * I
        where U(x) is rectangular (input_dim, embedding_dim) if extra_dims > 0.
        - Varies smoothly with x
        - Guaranteed positive-definite
        - Returns stimulus covariance directly (input_dim, input_dim)

    Parameters
    ----------
    params : dict
        Model parameters:
        - WPPM: {"W": (degree+1, ..., input_dim, embedding_dim)}
    x : jnp.ndarray, shape (input_dim,)
        Stimulus location

    Returns
    -------
    Σ : jnp.ndarray, shape (input_dim, input_dim)
        Covariance matrix in stimulus space.
    """
    # WPPM: spatially-varying covariance
    if "W" in params:
        U = self._compute_sqrt(params, x)  # (input_dim, embedding_dim)
        # Σ(x) = U(x) @ U(x)^T + diag_term * I
        # Result is (input_dim, input_dim)
        Sigma = U @ U.T + self.diag_term * jnp.eye(self.input_dim)
        return Sigma

    raise ValueError("params must contain 'W' (weights of WPPM)")

log_likelihood_from_data

log_likelihood_from_data(params: Params, data: Any, *, key: Array | None = None) -> ndarray

Compute log-likelihood directly from a batched data object.

Why delegate to the likelihood?

- The likelihood knows the decision rule (oddity, 2AFC, ...).
- The likelihood can use the model (this WPPM) to fetch discriminabilities.
- The likelihood can use the noise model if it needs MC simulation.

Parameters:

Name Type Description Default
params dict

Model parameters.

required
data TrialData (or any object with refs/comparisons/responses arrays)

Collected trial data.

required
key Array | None

JAX random key for MC likelihood evaluation. When provided, a fresh noise realization is drawn every call — required for correct stochastic gradient estimates during optimization. When None, the task falls back to OddityTaskConfig.default_key_seed (useful for fixed evaluation and testing, but should not be used during gradient-based optimization).

None

Returns:

Name Type Description
loglik ndarray

Scalar log-likelihood (task-only; add prior outside if needed).

Source code in src/psyphy/model/wppm.py
def log_likelihood_from_data(
    self, params: Params, data: Any, *, key: jax.Array | None = None
) -> jnp.ndarray:
    """Compute log-likelihood directly from a batched data object.

    Why delegate to the likelihood?
        - The likelihood knows the decision rule (oddity, 2AFC, ...).
        - The likelihood can use the model (this WPPM) to fetch discriminabilities.
        - The likelihood can use the noise model if it needs MC simulation.

    Parameters
    ----------
    params : dict
        Model parameters.
    data : TrialData (or any object with refs/comparisons/responses arrays)
        Collected trial data.
    key : jax.Array | None, optional
        JAX random key for MC likelihood evaluation. When provided, a fresh
        noise realization is drawn every call — required for correct stochastic
        gradient estimates during optimization. When None, the task falls back
        to ``OddityTaskConfig.default_key_seed`` (useful for fixed evaluation
        and testing, but should not be used during gradient-based optimization).

    Returns
    -------
    loglik : jnp.ndarray
        Scalar log-likelihood (task-only; add prior outside if needed).
    """
    return self.likelihood.loglik(params, data, self, key=key)

log_posterior_from_data

log_posterior_from_data(params: Params, data: Any, *, key: Array | None = None) -> ndarray

Compute log posterior from data.

This simply adds the prior log-probability to the task log-likelihood. Inference engines (e.g., MAP optimizer) typically optimize this quantity.

Parameters:

Name Type Description Default
params dict

Model parameters.

required
data TrialData

Collected trial data.

required
key Array | None

JAX random key for the MC likelihood. Must be provided during optimization so each gradient step uses a fresh noise realization. When None, falls back to OddityTaskConfig.default_key_seed.

None

Returns:

Type Description
ndarray

Scalar log posterior = loglik(params | data) + log_prior(params).

Source code in src/psyphy/model/wppm.py
def log_posterior_from_data(
    self, params: Params, data: Any, *, key: jax.Array | None = None
) -> jnp.ndarray:
    """Compute log posterior from data.

    This simply adds the prior log-probability to the task log-likelihood.
    Inference engines (e.g., MAP optimizer) typically optimize this quantity.

    Parameters
    ----------
    params : dict
        Model parameters.
    data : TrialData
        Collected trial data.
    key : jax.Array | None, optional
        JAX random key for the MC likelihood. Must be provided during
        optimization so each gradient step uses a fresh noise realization.
        When None, falls back to ``OddityTaskConfig.default_key_seed``.

    Returns
    -------
    jnp.ndarray
        Scalar log posterior = loglik(params | data) + log_prior(params).
    """
    return self.log_likelihood_from_data(
        params, data, key=key
    ) + self.prior.log_prob(params)
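The additive composition above (task log-likelihood plus prior log-density) can be illustrated with toy stand-ins. The functions `toy_loglik`, `toy_log_prior`, and the data layout below are invented for this sketch and are not psyphy API:

```python
import numpy as np

def toy_loglik(params, data):
    # Bernoulli log-likelihood with an illustrative per-trial probability model
    p = 1.0 / (1.0 + np.exp(-params["w"] * data["x"]))
    return np.sum(np.where(data["r"] == 1, np.log(p), np.log(1.0 - p)))

def toy_log_prior(params):
    # Standard Gaussian prior, up to an additive constant
    return -0.5 * params["w"] ** 2

def toy_log_posterior(params, data):
    # Same composition as WPPM.log_posterior_from_data:
    # log posterior = task log-likelihood + prior log-density
    return toy_loglik(params, data) + toy_log_prior(params)

params = {"w": 1.0}
data = {"x": np.array([0.5, -0.5]), "r": np.array([1, 0])}
lp = toy_log_posterior(params, data)
```

A MAP optimizer would maximize this scalar with respect to the parameters; the prior term is what distinguishes it from a pure maximum-likelihood objective.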

predict_prob

predict_prob(params: Params, stimulus: Stimulus, **likelihood_kwargs: Any) -> ndarray

Predict probability of a correct response for a single stimulus.

Design choice: WPPM computes discriminability & covariance; the LIKELIHOOD defines how that translates to performance. We therefore delegate to: likelihood.predict(params, ref, comparison, self), where the likelihood can reach the noise model through the model instance.

Parameters:

Name Type Description Default
params dict
required
stimulus tuple[ndarray, ndarray]

(reference, comparison) pair in model space.

required

Returns:

Name Type Description
p_correct ndarray
Source code in src/psyphy/model/wppm.py
def predict_prob(
    self, params: Params, stimulus: Stimulus, **likelihood_kwargs: Any
) -> jnp.ndarray:
    """
    Predict probability of a correct response for a single stimulus.

    Design choice:
        WPPM computes discriminability & covariance; the LIKELIHOOD defines how
        that translates to performance. We therefore delegate to:
            likelihood.predict(params, stimulus, model=self, noise=self.noise)

    Parameters
    ----------
    params : dict
    stimulus : tuple[jnp.ndarray, jnp.ndarray]
         (reference, comparison) pair in model space.

    Returns
    -------
    p_correct : jnp.ndarray
    """
    # Strict task-owned configuration:
    # - MC control knobs (e.g. num_samples/bandwidth) live in the task config.
    # - WPPM.predict_prob therefore does not accept task-specific kwargs.
    if likelihood_kwargs:
        raise TypeError(
            f"WPPM.predict_prob got unexpected kwargs: {likelihood_kwargs}. "
            "Configure likelihood behavior via the TaskLikelihood object itself."
        )

    ref, comparison = stimulus
    return self.likelihood.predict(params, ref, comparison, self)

Priors


prior

prior.py

Prior distributions for WPPM parameters

Hyperparameters
  • variance_scale : global scaling factor for covariance magnitude
  • decay_rate : smoothness controlling spatial variation
  • extra_embedding_dims : embedding dimension for basis expansions

Connections
  • WPPM calls Prior.sample_params() to initialize model parameters
  • WPPM adds Prior.log_prob(params) to task log-likelihoods to form the log posterior
  • Prior will generate structured parameters for basis expansions and decay_rate-controlled smooth covariance fields

Classes:

Name Description
Prior

Prior distribution over WPPM parameters

Attributes:

Name Type Description
Params

Params

Params = dict[str, ndarray]

Prior

Prior(input_dim: int = 2, basis_degree: int = 4, variance_scale: float = 0.004, decay_rate: float = 0.4, extra_embedding_dims: int = 1)

Prior distribution over WPPM parameters

Parameters:

Name Type Description Default
input_dim int

Dimensionality of the model space (same as WPPM.input_dim)

2
basis_degree int | None

Degree of Chebyshev basis for Wishart process. If set, uses Wishart mode with W coefficients.

4
variance_scale float

Prior variance for degree-0 (constant) coefficient in Wishart mode. Controls overall scale of covariances.

0.004
decay_rate float

Geometric decay rate for prior variance over higher-degree coefficients. Prior variance for degree-d coefficient = variance_scale * (decay_rate^d). Smaller decay_rate -> stronger smoothness prior.

0.4
extra_embedding_dims int

Additional latent dimensions in U matrices beyond input dimensions. Allows richer ellipsoid shapes in Wishart mode.

1
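The geometric decay schedule is easy to write down directly. This standalone sketch assumes the 2D prior variance for coefficient (i, j) decays with total degree d = i + j (the exact layout inside Prior may differ):

```python
import numpy as np

variance_scale = 0.004
decay_rate = 0.4
basis_degree = 4

# Prior variance for the coefficient of the (i, j) Chebyshev term,
# assuming total degree d = i + j for the 2D tensor-product basis.
d = np.add.outer(np.arange(basis_degree + 1), np.arange(basis_degree + 1))
variances = variance_scale * decay_rate ** d

# The constant (degree-0) term keeps the full variance_scale ...
assert variances[0, 0] == variance_scale
# ... and higher-degree terms shrink geometrically, enforcing smoothness.
```

Smaller decay_rate concentrates prior mass on low-degree terms, so sampled covariance fields vary more slowly across stimulus space.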

Methods:

Name Description
log_prob

Compute log prior density (up to a constant)

sample_params

Sample initial parameters from the prior.

Attributes:

Name Type Description
basis_degree int
decay_rate float
extra_embedding_dims int
input_dim int
variance_scale float

basis_degree

basis_degree: int = 4

decay_rate

decay_rate: float = 0.4

extra_embedding_dims

extra_embedding_dims: int = 1

input_dim

input_dim: int = 2

variance_scale

variance_scale: float = 0.004

log_prob

log_prob(params: Params) -> ndarray

Compute log prior density (up to a constant)

Gaussian prior on W with smoothness via decay_rate log p(W) = Σ_ij log N(W_ij | 0, σ_ij^2) where σ_ij^2 = prior variance

Parameters:

Name Type Description Default
params dict

Parameter dictionary

required

Returns:

Name Type Description
log_prob float

Log prior probability (up to normalizing constant)

Source code in src/psyphy/model/prior.py
def log_prob(self, params: Params) -> jnp.ndarray:
    """
    Compute log prior density (up to a constant)

    Gaussian prior on W with smoothness via decay_rate
        log p(W) = Σ_ij log N(W_ij | 0, σ_ij^2) where σ_ij^2 = prior variance

    Parameters
    ----------
    params : dict
        Parameter dictionary

    Returns
    -------
    log_prob : float
        Log prior probability (up to normalizing constant)
    """

    if "W" in params:
        # Wishart mode
        W = params["W"]
        variances = self._compute_W_prior_variances()

        # Gaussian log probability for each entry
        # log N(x | 0, σ^2) = -0.5 * (x^2/σ^2 + log(2πσ^2))
        # Up to constant: -0.5 * x^2/σ^2

        if self.input_dim == 2:
            # Each W[i,j,:,:] ~ Normal(0, variance[i,j] * I)
            return -0.5 * jnp.sum((W**2) / (variances[:, :, None, None] + 1e-10))
        elif self.input_dim == 3:
            return -0.5 * jnp.sum((W**2) / (variances[:, :, :, None, None] + 1e-10))

    raise ValueError("params must contain weights 'W'")
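Up to a constant, the entrywise Gaussian log-density reduces to a weighted sum of squares. A NumPy sketch of the same expression for the 2D case (the total-degree variance layout here is an assumption for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
# W shaped (degree+1, degree+1, input_dim, embedding_dim), as in Wishart mode
W = rng.normal(size=(5, 5, 2, 3))
# Assumed per-degree variances: variance_scale * decay_rate^(i + j)
variances = 0.004 * 0.4 ** np.add.outer(np.arange(5), np.arange(5))

# log N(x | 0, sigma^2) up to a constant is -0.5 * x^2 / sigma^2;
# the small epsilon guards against division by ~0, as in Prior.log_prob.
log_prob = -0.5 * np.sum(W**2 / (variances[:, :, None, None] + 1e-10))
```

Because the normalizing constant is dropped, this quantity is only meaningful up to an additive constant, which is all a MAP or MCMC objective needs.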

sample_params

sample_params(key: Any) -> Params

Sample initial parameters from the prior.

Returns {"W": shape (degree+1, degree+1, input_dim, embedding_dim)} for 2D, where embedding_dim = input_dim + extra_embedding_dims

Note: The 3rd dimension is input_dim (output space dimension). This matches the einsum in _compute_sqrt: U = einsum("ijde,ij->de", W, phi) where d indexes input_dim.

Parameters:

Name Type Description Default
key JAX random key
required

Returns:

Name Type Description
params dict

Parameter dictionary

Source code in src/psyphy/model/prior.py
def sample_params(self, key: Any) -> Params:
    """
    Sample initial parameters from the prior.


    Returns {"W": shape (degree+1, degree+1, input_dim, embedding_dim)}
    for 2D, where embedding_dim = input_dim + extra_embedding_dims

    Note: The 3rd dimension is input_dim (output space dimension).
    This matches the einsum in _compute_sqrt:
    U = einsum("ijde,ij->de", W, phi) where d indexes input_dim.

    Parameters
    ----------
    key : JAX random key

    Returns
    -------
    params : dict
        Parameter dictionary
    """
    if self.basis_degree is None:
        raise ValueError(
            "'basis_degree' is None; please set "
            "`Prior.basis_degree` to an integer >0."
        )

    # Basis function coefficients W
    variances = self._compute_W_prior_variances()
    embedding_dim = self.input_dim + self.extra_embedding_dims

    if self.input_dim == 2:
        # Sample W ~ Normal(0, variances) for each matrix entry
        # Shape: (degree+1, degree+1, input_dim, embedding_dim)
        # Note: degree+1 to match number of basis functions [T_0, ..., T_degree]
        W = jnp.sqrt(variances)[:, :, None, None] * jr.normal(
            key,
            shape=(
                self.basis_degree + 1,
                self.basis_degree + 1,
                self.input_dim,
                embedding_dim,
            ),
        )
    elif self.input_dim == 3:
        # Shape: (degree+1, degree+1, degree+1, input_dim, embedding_dim)
        W = jnp.sqrt(variances)[:, :, :, None, None] * jr.normal(
            key,
            shape=(
                self.basis_degree + 1,
                self.basis_degree + 1,
                self.basis_degree + 1,
                self.input_dim,
                embedding_dim,
            ),
        )
    else:
        raise NotImplementedError(
            f"Wishart process only supports 2D and 3D. Got input_dim={self.input_dim}"
        )

    return {"W": W}
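Sampling W amounts to scaling standard normal draws by the per-degree standard deviations. A standalone sketch of the 2D case (NumPy in place of jax.numpy, variance layout assumed as above):

```python
import numpy as np

rng = np.random.default_rng(42)
basis_degree, input_dim, extra_embedding_dims = 4, 2, 1
embedding_dim = input_dim + extra_embedding_dims

# Assumed per-degree prior variances, decaying with total degree i + j
variances = 0.004 * 0.4 ** np.add.outer(np.arange(basis_degree + 1),
                                        np.arange(basis_degree + 1))

# W ~ Normal(0, variances) per entry, broadcast over the matrix dimensions;
# shape (degree+1, degree+1, input_dim, embedding_dim)
W = np.sqrt(variances)[:, :, None, None] * rng.normal(
    size=(basis_degree + 1, basis_degree + 1, input_dim, embedding_dim)
)
```

The first two axes index the basis functions [T_0, ..., T_degree]; the last two give each coefficient a full input_dim x embedding_dim matrix, so U(x) can produce non-axis-aligned ellipsoids.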

Noise


noise

Classes:

Name Description
GaussianNoise
StudentTNoise

GaussianNoise

GaussianNoise(sigma: float = 1.0)

Methods:

Name Description
log_prob
sample_standard

Sample from standard Gaussian (mean=0, var=1).

Attributes:

Name Type Description
sigma float

sigma

sigma: float = 1.0

log_prob

log_prob(residual: float) -> float
Source code in src/psyphy/model/noise.py
def log_prob(self, residual: float) -> float:
    _ = residual
    return -0.5

sample_standard

sample_standard(key: Array, shape: tuple[int, ...]) -> Array

Sample from standard Gaussian (mean=0, var=1).

Source code in src/psyphy/model/noise.py
def sample_standard(self, key: jax.Array, shape: tuple[int, ...]) -> jax.Array:
    """Sample from standard Gaussian (mean=0, var=1)."""
    return jr.normal(key, shape)

StudentTNoise

StudentTNoise(df: float = 3.0, scale: float = 1.0)

Methods:

Name Description
log_prob
sample_standard

Sample from standard Student-t (df=self.df).

Attributes:

Name Type Description
df float
scale float

df

df: float = 3.0

scale

scale: float = 1.0

log_prob

log_prob(residual: float) -> float
Source code in src/psyphy/model/noise.py
def log_prob(self, residual: float) -> float:
    _ = residual
    return -0.5

sample_standard

sample_standard(key: Array, shape: tuple[int, ...]) -> Array

Sample from standard Student-t (df=self.df).

Source code in src/psyphy/model/noise.py
def sample_standard(self, key: jax.Array, shape: tuple[int, ...]) -> jax.Array:
    """Sample from standard Student-t (df=self.df)."""
    return jr.t(key, self.df, shape)
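A quick NumPy comparison (illustrative draws, not psyphy code) shows why a Student-t noise model admits more extreme internal-noise samples than a Gaussian:

```python
import numpy as np

rng = np.random.default_rng(0)
shape = (100_000,)

gauss = rng.standard_normal(shape)           # mean 0, variance 1
student = rng.standard_t(df=3.0, size=shape)  # heavy-tailed, df = 3

# With df = 3 the t distribution has variance df / (df - 2) = 3 and much
# heavier tails, so extreme draws beyond 4 are far more frequent.
tail_t = np.mean(np.abs(student) > 4.0)
tail_g = np.mean(np.abs(gauss) > 4.0)
```

In an observer model, heavier tails make occasional large perceptual errors (and hence lapses) more likely without changing the typical-trial behavior much.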

Likelihood (defined by Tasks)


likelihood

psyphy.model.likelihood

Task likelihoods for psychophysical experiments.

This module defines task-specific mappings from a model (e.g., WPPM) and stimuli to response likelihoods.

Current direction

OddityTask: the log-likelihood is computed via Monte Carlo observer simulation of the full 3-stimulus oddity decision rule (two identical references, one comparison).

The public API is:

  • TaskLikelihood.predict(params, ref, comparison, model, *, key) Per-trial predictor for p(correct). For MC-only tasks this runs the Monte Carlo simulation.

  • TaskLikelihood.loglik(params, data, model, *, key) Compute log-likelihood of observed responses under this task.

Connections
  • WPPM delegates to the task to compute likelihood.
  • Noise models are passed through so likelihoods can simulate observer responses.

Classes:

Name Description
OddityTask

Three-alternative forced-choice oddity task (MC-based only).

OddityTaskConfig

Configuration for :class:OddityTask.

TaskLikelihood

Abstract base class for task likelihoods.

OddityTask

OddityTask(config: OddityTaskConfig | None = None)

Bases: TaskLikelihood

Three-alternative forced-choice oddity task (MC-based only).

Implements the full 3-stimulus oddity task using Monte Carlo simulation:
  • Samples three internal representations per trial (z0, z1, z2)
  • Uses the proper oddity decision rule with three pairwise distances
  • Suitable for complex covariance structures

Notes

MC simulation in loglik() (full 3-stimulus oddity):
  1. Sample three internal representations: z_ref, z_refprime ~ N(ref, Σ_ref), z_comparison ~ N(comparison, Σ_comparison)
  2. Compute the average covariance: Σ_avg = (2/3) Σ_ref + (1/3) Σ_comparison
  3. Compute three pairwise Mahalanobis distances under Σ_avg:
     - d^2(z_ref, z_refprime): distance between the two reference samples
     - d^2(z_ref, z_comparison): distance from ref to comparison
     - d^2(z_refprime, z_comparison): distance from reference_prime to comparison
  4. Apply the oddity decision rule: delta = min(d^2(z_ref, z_comparison), d^2(z_refprime, z_comparison)) - d^2(z_ref, z_refprime)
  5. Logistic smoothing: P(correct) ≈ logistic.cdf(delta / bandwidth)
  6. Average over samples
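The six steps above can be sketched standalone in NumPy (the library itself uses jax.numpy; the function name and test stimuli here are illustrative, not psyphy API):

```python
import numpy as np

def oddity_p_correct_mc(ref, comparison, cov_ref, cov_cmp,
                        num_samples=10_000, bandwidth=1e-2, seed=0):
    """Smoothed Monte Carlo oddity rule, following steps 1-6 above."""
    rng = np.random.default_rng(seed)
    # 1. Three internal representations per MC sample.
    z_ref = rng.multivariate_normal(ref, cov_ref, size=num_samples)
    z_refp = rng.multivariate_normal(ref, cov_ref, size=num_samples)
    z_cmp = rng.multivariate_normal(comparison, cov_cmp, size=num_samples)
    # 2. Average covariance defines the Mahalanobis metric.
    prec = np.linalg.inv((2.0 / 3.0) * cov_ref + (1.0 / 3.0) * cov_cmp)
    maha2 = lambda a, b: np.einsum("ni,ij,nj->n", a - b, prec, a - b)
    # 3. Three pairwise squared distances.
    d_rr = maha2(z_ref, z_refp)
    d_rc = maha2(z_ref, z_cmp)
    d_pc = maha2(z_refp, z_cmp)
    # 4. Decision variable: positive when the comparison is the odd one out.
    delta = np.minimum(d_rc, d_pc) - d_rr
    # 5.-6. Logistic smoothing (clipped for numerical safety), MC average.
    t = np.clip(delta / bandwidth, -60.0, 60.0)
    return float(np.mean(1.0 / (1.0 + np.exp(-t))))

cov = 0.01 * np.eye(2)
# Far-apart comparison: near-ceiling performance.
p_far = oddity_p_correct_mc(np.zeros(2), np.array([1.0, 0.0]), cov, cov)
# Identical stimuli: performance near the 1/3 chance level of a 3AFC task.
p_same = oddity_p_correct_mc(np.zeros(2), np.zeros(2), cov, cov)
```

With identical stimuli the three samples are exchangeable, so each pairwise distance is equally likely to be smallest and the smoothed rule recovers chance performance of about 1/3.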

Examples:

>>> from psyphy.model.likelihood import OddityTask
>>> from psyphy.model.likelihood import OddityTaskConfig
>>> from psyphy.model import WPPM, Prior
>>> from psyphy.model.noise import GaussianNoise
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>>
>>> # Create task and model (task-owned MC controls)
>>> likelihood = OddityTask(
...     config=OddityTaskConfig(num_samples=1000, bandwidth=1e-2)
... )
>>> model = WPPM(
...     input_dim=2,
...     prior=Prior(input_dim=2),
...     likelihood=likelihood,
...     noise=GaussianNoise(),
... )
>>> params = model.init_params(jr.PRNGKey(0))
>>>
>>> # MC simulation
>>> from psyphy.data.dataset import ResponseData
>>> data = ResponseData()
>>> ref = jnp.array([0.0, 0.0])
>>> comparison = jnp.array([0.1, 0.0])
>>> data.add_trial(ref, comparison, resp=1)
>>> ll_mc = likelihood.loglik(params, data, model, key=jr.PRNGKey(42))
>>> print(f"Log-likelihood (MC): {ll_mc:.4f}")

Methods:

Name Description
loglik

Compute Bernoulli log-likelihood over a batch of trials.

predict

Return p(correct) for a single (ref, comparison) trial via MC simulation.

simulate

Simulate observed binary responses for a batch of trials.

Attributes:

Name Type Description
config
Source code in src/psyphy/model/likelihood.py
def __init__(self, config: OddityTaskConfig | None = None) -> None:
    # No analytical parameters in MC-only mode.
    self.config = config or OddityTaskConfig()

config

config = config or OddityTaskConfig()

loglik

loglik(params: Any, data: Any, model: Any, *, key: Any = None) -> ndarray

Compute Bernoulli log-likelihood over a batch of trials.

This is a concrete base-class method: it vmaps predict over trials then applies the Bernoulli log-likelihood formula. Subclasses only need to implement predict.

Parameters:

Name Type Description Default
params Any

Model parameters.

required
data Any

Object with .refs, .comparisons, .responses array attributes.

required
model Any

Model instance.

required
key KeyArray

PRNG key. Passed as independent per-trial subkeys to predict. When None, falls back to key=jr.PRNGKey(0) (deterministic).

None

Returns:

Type Description
ndarray

Scalar sum of Bernoulli log-likelihoods over all trials.

Source code in src/psyphy/model/likelihood.py
def loglik(
    self,
    params: Any,
    data: Any,
    model: Any,
    *,
    key: Any = None,
) -> jnp.ndarray:
    """Compute Bernoulli log-likelihood over a batch of trials.

    This is a concrete base-class method: it vmaps ``predict`` over trials
    then applies the Bernoulli log-likelihood formula. Subclasses only need
    to implement ``predict``.

    Parameters
    ----------
    params : Any
        Model parameters.
    data : Any
        Object with ``.refs``, ``.comparisons``, ``.responses`` array attributes.
    model : Any
        Model instance.
    key : jax.random.KeyArray, optional
        PRNG key. Passed as independent per-trial subkeys to ``predict``.
        When None, falls back to ``key=jr.PRNGKey(0)`` (deterministic).

    Returns
    -------
    jnp.ndarray
        Scalar sum of Bernoulli log-likelihoods over all trials.
    """
    refs = jnp.asarray(data.refs)
    comparisons = jnp.asarray(data.comparisons)
    responses = jnp.asarray(data.responses)
    n_trials = int(refs.shape[0])

    base_key = key if key is not None else jr.PRNGKey(0)
    trial_keys = jr.split(base_key, n_trials)

    probs = jax.vmap(
        lambda ref, comparison, k: self.predict(
            params, ref, comparison, model, key=k
        )
    )(refs, comparisons, trial_keys)

    log_likelihoods = jnp.where(
        responses == 1,
        jnp.log(probs),
        jnp.log(1.0 - probs),
    )
    return jnp.sum(log_likelihoods)
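The Bernoulli accumulation itself is easy to check by hand. A NumPy sketch with made-up per-trial probabilities and responses:

```python
import numpy as np

# Per-trial p(correct) from some predictor, plus observed binary responses.
probs = np.array([0.9, 0.7, 0.4])
responses = np.array([1, 1, 0])

# Bernoulli log-likelihood per trial, then summed, as in loglik():
# log p for a correct response, log (1 - p) for an incorrect one.
log_likelihoods = np.where(responses == 1, np.log(probs), np.log(1.0 - probs))
total = np.sum(log_likelihoods)
```

Here total = log 0.9 + log 0.7 + log 0.6, roughly -0.97; each trial contributes more negatively the more surprising its response was under the model.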

predict

predict(params: Any, ref: ndarray, comparison: ndarray, model: Any, *, key: Any = None) -> ndarray

Return p(correct) for a single (ref, comparison) trial via MC simulation.

MC controls (num_samples, bandwidth) are read from :class:OddityTaskConfig. Pass key to control randomness; when None, config.default_key_seed is used.

Source code in src/psyphy/model/likelihood.py
def predict(
    self,
    params: Any,
    ref: jnp.ndarray,
    comparison: jnp.ndarray,
    model: Any,
    *,
    key: Any = None,
) -> jnp.ndarray:
    """Return p(correct) for a single (ref, comparison) trial via MC simulation.

    MC controls (``num_samples``, ``bandwidth``) are read from
    :class:`OddityTaskConfig`. Pass ``key`` to control randomness; when
    None, ``config.default_key_seed`` is used.
    """
    num_samples = int(self.config.num_samples)
    bandwidth = float(self.config.bandwidth)
    if key is None:
        key = jr.PRNGKey(int(self.config.default_key_seed))

    return self._simulate_trial_mc(
        params=params,
        ref=ref,
        comparison=comparison,
        model=model,
        num_samples=num_samples,
        bandwidth=bandwidth,
        key=key,
    )

simulate

simulate(params: Any, refs: ndarray, comparisons: ndarray, model: Any, *, key: Any) -> tuple[ndarray, ndarray]

Simulate observed binary responses for a batch of trials.

Parameters:

Name Type Description Default
params Any

Model parameters.

required
refs (ndarray, shape(n_trials, input_dim))

Reference stimuli.

required
comparisons (ndarray, shape(n_trials, input_dim))

Comparison stimuli.

required
model Any

Model instance.

required
key KeyArray

PRNG key (required; split internally for prediction and sampling).

required

Returns:

Name Type Description
responses jnp.ndarray, shape (n_trials,), dtype int32

Simulated binary responses (1 = correct, 0 = incorrect).

p_correct (ndarray, shape(n_trials))

Estimated P(correct) per trial used to draw the responses.

Source code in src/psyphy/model/likelihood.py
def simulate(
    self,
    params: Any,
    refs: jnp.ndarray,
    comparisons: jnp.ndarray,
    model: Any,
    *,
    key: Any,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Simulate observed binary responses for a batch of trials.

    Parameters
    ----------
    params : Any
        Model parameters.
    refs : jnp.ndarray, shape (n_trials, input_dim)
        Reference stimuli.
    comparisons : jnp.ndarray, shape (n_trials, input_dim)
        Comparison stimuli.
    model : Any
        Model instance.
    key : jax.random.KeyArray
        PRNG key (required; split internally for prediction and sampling).

    Returns
    -------
    responses : jnp.ndarray, shape (n_trials,), dtype int32
        Simulated binary responses (1 = correct, 0 = incorrect).
    p_correct : jnp.ndarray, shape (n_trials,)
        Estimated P(correct) per trial used to draw the responses.
    """
    refs = jnp.asarray(refs)
    comparisons = jnp.asarray(comparisons)
    n_trials = int(refs.shape[0])

    k_pred, k_bernoulli = jr.split(key)
    trial_keys = jr.split(k_pred, n_trials)

    p_correct = jax.vmap(
        lambda ref, comparison, k: self.predict(
            params, ref, comparison, model, key=k
        )
    )(refs, comparisons, trial_keys)

    responses = jr.bernoulli(k_bernoulli, p_correct).astype(jnp.int32)
    return responses, p_correct
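The final Bernoulli draw can be mimicked without JAX. This NumPy sketch mirrors the jr.bernoulli step in simulate() (shapes and probabilities are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
p_correct = np.array([0.95, 0.5, 0.1])

# One Bernoulli response per trial from its predicted p(correct):
# a uniform draw below p_correct counts as a correct (1) response.
responses = (rng.random(p_correct.shape) < p_correct).astype(np.int32)
```

Splitting the key into separate prediction and Bernoulli streams, as simulate() does, keeps the MC probability estimates independent of the response draws.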

OddityTaskConfig

OddityTaskConfig(num_samples: int = 1000, bandwidth: float = 0.01, default_key_seed: int = 0)

Configuration for :class:OddityTask.

This is the single source of truth for MC likelihood controls.

Attributes:

Name Type Description
num_samples int

Number of Monte Carlo samples per trial.

bandwidth float

Logistic CDF smoothing bandwidth.

default_key_seed int

Seed used when no key is provided (keeps behavior deterministic by default while allowing reproducibility control upstream).

bandwidth

bandwidth: float = 0.01

default_key_seed

default_key_seed: int = 0

num_samples

num_samples: int = 1000

TaskLikelihood

Bases: ABC

Abstract base class for task likelihoods.

Subclasses must implement:
  • predict(params, ref, comparison, model, *, key) → p(correct) for one trial

The base class provides concrete implementations of:
  • loglik(params, data, model, *, key) → Bernoulli log-likelihood over a batch
  • simulate(params, refs, comparisons, model, *, key) → simulated responses

The Bernoulli log-likelihood step is identical for all binary-response tasks, so it lives here rather than being re-implemented in every subclass.
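The division of labor (abstract per-trial predict, concrete batched loglik) can be sketched with toy classes; ToyTaskLikelihood and ToyTask below are illustrative stand-ins, not psyphy classes:

```python
from abc import ABC, abstractmethod
import numpy as np

class ToyTaskLikelihood(ABC):
    """Sketch of the TaskLikelihood pattern: predict is abstract,
    the Bernoulli accumulation is shared by all subclasses."""

    @abstractmethod
    def predict(self, params, ref, comparison):
        """Return p(correct) for one trial."""

    def loglik(self, params, refs, comparisons, responses):
        # Concrete in the base class: identical for all binary-response tasks.
        probs = np.array([self.predict(params, r, c)
                          for r, c in zip(refs, comparisons)])
        return np.sum(np.where(responses == 1,
                               np.log(probs), np.log(1.0 - probs)))

class ToyTask(ToyTaskLikelihood):
    def predict(self, params, ref, comparison):
        # Larger separation -> closer to ceiling; chance level is 1/3.
        d = np.linalg.norm(comparison - ref) / params["scale"]
        return 1.0 / 3.0 + (2.0 / 3.0) * (1.0 - np.exp(-d))

task = ToyTask()
ll = task.loglik({"scale": 0.5},
                 refs=[np.zeros(2)] * 2,
                 comparisons=[np.array([0.5, 0.0]), np.array([0.0, 0.0])],
                 responses=np.array([1, 0]))
```

A new task only has to supply its decision rule through predict; batching and the Bernoulli bookkeeping come for free from the base class (psyphy additionally vmaps over per-trial PRNG keys).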

Methods:

Name Description
loglik

Compute Bernoulli log-likelihood over a batch of trials.

predict

Return p(correct) for a single (ref, comparison) trial.

simulate

Simulate observed binary responses for a batch of trials.

loglik

loglik(params: Any, data: Any, model: Any, *, key: Any = None) -> ndarray

Compute Bernoulli log-likelihood over a batch of trials.

This is a concrete base-class method: it vmaps predict over trials then applies the Bernoulli log-likelihood formula. Subclasses only need to implement predict.

Parameters:

Name Type Description Default
params Any

Model parameters.

required
data Any

Object with .refs, .comparisons, .responses array attributes.

required
model Any

Model instance.

required
key KeyArray

PRNG key. Passed as independent per-trial subkeys to predict. When None, falls back to key=jr.PRNGKey(0) (deterministic).

None

Returns:

Type Description
ndarray

Scalar sum of Bernoulli log-likelihoods over all trials.

Source code in src/psyphy/model/likelihood.py
def loglik(
    self,
    params: Any,
    data: Any,
    model: Any,
    *,
    key: Any = None,
) -> jnp.ndarray:
    """Compute Bernoulli log-likelihood over a batch of trials.

    This is a concrete base-class method: it vmaps ``predict`` over trials
    then applies the Bernoulli log-likelihood formula. Subclasses only need
    to implement ``predict``.

    Parameters
    ----------
    params : Any
        Model parameters.
    data : Any
        Object with ``.refs``, ``.comparisons``, ``.responses`` array attributes.
    model : Any
        Model instance.
    key : jax.random.KeyArray, optional
        PRNG key. Passed as independent per-trial subkeys to ``predict``.
        When None, falls back to ``key=jr.PRNGKey(0)`` (deterministic).

    Returns
    -------
    jnp.ndarray
        Scalar sum of Bernoulli log-likelihoods over all trials.
    """
    refs = jnp.asarray(data.refs)
    comparisons = jnp.asarray(data.comparisons)
    responses = jnp.asarray(data.responses)
    n_trials = int(refs.shape[0])

    base_key = key if key is not None else jr.PRNGKey(0)
    trial_keys = jr.split(base_key, n_trials)

    probs = jax.vmap(
        lambda ref, comparison, k: self.predict(
            params, ref, comparison, model, key=k
        )
    )(refs, comparisons, trial_keys)

    log_likelihoods = jnp.where(
        responses == 1,
        jnp.log(probs),
        jnp.log(1.0 - probs),
    )
    return jnp.sum(log_likelihoods)

predict

predict(params: Any, ref: ndarray, comparison: ndarray, model: Any, *, key: Any = None) -> ndarray

Return p(correct) for a single (ref, comparison) trial.

Parameters:

Name Type Description Default
params Any

Model parameters.

required
ref (ndarray, shape(input_dim))

Reference stimulus.

required
comparison (ndarray, shape(input_dim))

Comparison stimulus.

required
model Any

Model instance (provides covariance structure and model.noise).

required
key KeyArray

PRNG key for stochastic tasks. When None, the task falls back to its config.default_key_seed.

None

Returns:

Type Description
ndarray

Scalar p(correct) in (0, 1).

Source code in src/psyphy/model/likelihood.py
@abstractmethod
def predict(
    self,
    params: Any,
    ref: jnp.ndarray,
    comparison: jnp.ndarray,
    model: Any,
    *,
    key: Any = None,
) -> jnp.ndarray:
    """Return p(correct) for a single (ref, comparison) trial.

    Parameters
    ----------
    params : Any
        Model parameters.
    ref : jnp.ndarray, shape (input_dim,)
        Reference stimulus.
    comparison : jnp.ndarray, shape (input_dim,)
        Comparison stimulus.
    model : Any
        Model instance (provides covariance structure and ``model.noise``).
    key : jax.random.KeyArray, optional
        PRNG key for stochastic tasks. When None, the task falls back to
        its ``config.default_key_seed``.

    Returns
    -------
    jnp.ndarray
        Scalar p(correct) in (0, 1).
    """
    ...

simulate

simulate(params: Any, refs: ndarray, comparisons: ndarray, model: Any, *, key: Any) -> tuple[ndarray, ndarray]

Simulate observed binary responses for a batch of trials.

Parameters:

Name Type Description Default
params Any

Model parameters.

required
refs (ndarray, shape(n_trials, input_dim))

Reference stimuli.

required
comparisons (ndarray, shape(n_trials, input_dim))

Comparison stimuli.

required
model Any

Model instance.

required
key KeyArray

PRNG key (required; split internally for prediction and sampling).

required

Returns:

Name Type Description
responses jnp.ndarray, shape (n_trials,), dtype int32

Simulated binary responses (1 = correct, 0 = incorrect).

p_correct (ndarray, shape(n_trials))

Estimated P(correct) per trial used to draw the responses.

Source code in src/psyphy/model/likelihood.py
def simulate(
    self,
    params: Any,
    refs: jnp.ndarray,
    comparisons: jnp.ndarray,
    model: Any,
    *,
    key: Any,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Simulate observed binary responses for a batch of trials.

    Parameters
    ----------
    params : Any
        Model parameters.
    refs : jnp.ndarray, shape (n_trials, input_dim)
        Reference stimuli.
    comparisons : jnp.ndarray, shape (n_trials, input_dim)
        Comparison stimuli.
    model : Any
        Model instance.
    key : jax.random.KeyArray
        PRNG key (required; split internally for prediction and sampling).

    Returns
    -------
    responses : jnp.ndarray, shape (n_trials,), dtype int32
        Simulated binary responses (1 = correct, 0 = incorrect).
    p_correct : jnp.ndarray, shape (n_trials,)
        Estimated P(correct) per trial used to draw the responses.
    """
    refs = jnp.asarray(refs)
    comparisons = jnp.asarray(comparisons)
    n_trials = int(refs.shape[0])

    k_pred, k_bernoulli = jr.split(key)
    trial_keys = jr.split(k_pred, n_trials)

    p_correct = jax.vmap(
        lambda ref, comparison, k: self.predict(
            params, ref, comparison, model, key=k
        )
    )(refs, comparisons, trial_keys)

    responses = jr.bernoulli(k_bernoulli, p_correct).astype(jnp.int32)
    return responses, p_correct