
Posterior

Package Overview


posterior

Posterior representations.

This subpackage provides:
  • ParameterPosterior: protocol for posteriors over model parameters p(θ | data)
  • MAPPosterior: delta distribution at θ_MAP (point estimate)

Two-tier design
  • ParameterPosterior: represents p(θ | data)
  • PredictivePosterior: represents p(f(X*) | data)
Future extensions
  • LaplacePosterior: Gaussian approximation N(θ_MAP, Σ)
  • NumpyroPosterior/BlackjaxPosterior: MCMC samples

Classes:

  • MAPPosterior: MAP (Maximum A Posteriori) posterior, a delta distribution at θ_MAP.
  • ParameterPosterior: Protocol for posterior distributions over model parameters p(θ | data).
  • PredictivePosterior: Protocol for predictive distributions p(f(X*) | data) at test stimuli.
  • WPPMPredictivePosterior: Predictive posterior for WPPM models.

MAPPosterior

MAPPosterior(params, model)

MAP (Maximum A Posteriori) posterior: a delta distribution at θ_MAP.

Represents a point estimate with no uncertainty.

Parameters:

  • params (dict, required): MAP parameter dictionary (θ_MAP).
  • model (WPPM, required): Model instance used for predictions.

Notes

This implements the ParameterPosterior protocol.

Methods:

  • sample: Sample from the delta distribution (returns repeated θ_MAP).

Attributes:

  • model: Return the associated model.
  • params: Return the MAP parameters (θ_MAP).

Source code in src/psyphy/posterior/posterior.py
def __init__(self, params, model):
    self._params = params
    self._model = model

model

Return the associated model.

params

Return the MAP parameters (θ_MAP).

sample

sample(n: int = 1, *, key=None)

Sample from delta distribution (returns repeated θ_MAP).

Parameters:

  • n (int, default 1): Number of samples.
  • key (KeyArray, default None): PRNG key (unused for a delta distribution).

Returns:

  • dict: Parameter PyTree with leading dimension n. Each array has shape (n, ...) with identical values.

Notes

A delta distribution has no randomness, so this method returns the MAP estimate repeated n times.

Source code in src/psyphy/posterior/posterior.py
def sample(self, n: int = 1, *, key=None):
    """
    Sample from delta distribution (returns repeated θ_MAP).

    Parameters
    ----------
    n : int, default=1
        Number of samples
    key : jax.random.KeyArray, optional
        PRNG key (unused for delta distribution)

    Returns
    -------
    dict
        Parameter PyTree with leading dimension n.
        Each array has shape (n, ...) with identical values.

    Notes
    -----
    Delta distribution has no randomness - returns repeated MAP estimate.
    """
    return jax.tree.map(
        lambda x: jnp.tile(x[None, ...], (n,) + (1,) * x.ndim), self._params
    )
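
The tiling logic above can be exercised in isolation. Below is a minimal, self-contained sketch. `ToyMAPPosterior` is a hypothetical stand-in that copies only the `sample` logic and is not imported from psyphy. The parameter names `log_diag` and `W` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


class ToyMAPPosterior:
    """Hypothetical stand-in that mirrors MAPPosterior.sample (not imported from psyphy)."""

    def __init__(self, params, model=None):
        self._params = params
        self._model = model

    @property
    def params(self):
        return self._params

    def sample(self, n: int = 1, *, key=None):
        # Prepend a length-n axis to every leaf; `key` is ignored because a
        # delta distribution has no randomness.
        return jax.tree_util.tree_map(
            lambda x: jnp.tile(x[None, ...], (n,) + (1,) * x.ndim), self._params
        )


# Illustrative parameter names; a real WPPM PyTree may differ.
theta_map = {"log_diag": jnp.array([0.1, -0.2]), "W": jnp.ones((2, 3))}
post = ToyMAPPosterior(theta_map)
draws = post.sample(5)
print(draws["log_diag"].shape, draws["W"].shape)  # (5, 2) (5, 2, 3)
```

Every leaf gains the same leading axis, so downstream code can `vmap` over MAP draws exactly as it would over MCMC draws. `jnp.broadcast_to(x, (n,) + x.shape)` would give the same values.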

ParameterPosterior

Bases: Protocol

Protocol for posterior distributions over model parameters p(θ | data).

Returned by InferenceEngine.fit(model, data). Used for research workflows: diagnostics, parameter sampling, uncertainty.

Methods:

  • sample: Sample parameter vectors from p(θ | data).

Attributes:

  • model: Associated generative model.
  • params (dict): Point estimate or posterior mean parameters.

model

Associated generative model.

Returns:

  • Model: The WPPM or other model instance used for predictions.

params

params: dict

Point estimate or posterior mean parameters.

Returns:

  • dict: Parameter PyTree (e.g., {"log_diag": jnp.ndarray, ...}).

Notes
  • MAP: θ_MAP
  • MCMC: posterior mean of samples

sample

sample(n: int, *, key: KeyArray) -> dict

Sample parameter vectors from p(θ | data).

Parameters:

  • n (int, required): Number of samples.
  • key (KeyArray, required): PRNG key for randomness.

Returns:

  • dict: Parameter PyTree with leading dimension n. Example: {"log_diag": jnp.ndarray with shape (n, input_dim)}.

Notes
  • MAP: returns repeated θ_MAP
  • Numpyro/Blackjax: returns stored samples (may subsample if n differs)
Source code in src/psyphy/posterior/parameter_posterior.py
def sample(self, n: int, *, key: jr.KeyArray) -> dict:
    """
    Sample parameter vectors from p(θ | data).

    Parameters
    ----------
    n : int
        Number of samples
    key : jax.random.KeyArray
        PRNG key for randomness

    Returns
    -------
    dict
        Parameter PyTree with leading dimension n.
        Example: {"log_diag": jnp.ndarray with shape (n, input_dim)}

    Notes
    -----
    - MAP: returns repeated θ_MAP
    - Numpyro/Blackjax: returns stored samples (may subsample if n differs)
    """
    ...
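
The protocol notes that a sample-based posterior returns stored samples and may subsample them. Below is a minimal sketch of one way a collection of draws could satisfy the protocol. `ToySamplesPosterior` is hypothetical and is not a psyphy class. Resampling with replacement via `jax.random.choice` is an assumed policy, and the actual Numpyro/Blackjax posteriors may behave differently. `params` returns the posterior mean, as the MCMC note above describes.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


class ToySamplesPosterior:
    """Hypothetical sample-based posterior (e.g. wrapping stored MCMC draws)."""

    def __init__(self, samples: dict, model=None):
        self._samples = samples  # every leaf has leading dimension S
        self._model = model

    @property
    def model(self):
        return self._model

    @property
    def params(self) -> dict:
        # Posterior mean of the stored draws, computed leaf by leaf.
        return jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=0), self._samples)

    def sample(self, n: int, *, key: jax.Array) -> dict:
        # Draw n indices (with replacement) into the S stored samples and apply
        # the same indices to every leaf so that draws stay jointly consistent.
        s = jax.tree_util.tree_leaves(self._samples)[0].shape[0]
        idx = jr.choice(key, s, shape=(n,), replace=True)
        return jax.tree_util.tree_map(lambda x: x[idx], self._samples)


stored = {"log_diag": jnp.arange(12.0).reshape(4, 3)}  # S=4 draws, input_dim=3
post = ToySamplesPosterior(stored)
print(post.params["log_diag"])                                # [4.5 5.5 6.5]
print(post.sample(10, key=jr.PRNGKey(0))["log_diag"].shape)   # (10, 3)
```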

PredictivePosterior

Bases: Protocol

Protocol for predictive distributions p(f(X*) | data) at test stimuli.

Returned by Model.posterior(X) for use in acquisition functions.

Methods:

  • cov_field: Posterior over the perceptual covariance field Σ(X).
  • rsample: Reparameterized samples from p(f(X*) | data).

Attributes:

  • mean (ndarray): Posterior predictive mean E[f(X*) | data].
  • variance (ndarray): Posterior predictive marginal variances Var[f(X*) | data].

mean

mean: ndarray

Posterior predictive mean E[f(X*) | data].

Returns:

  • ndarray: shape (n_test,) for scalar outputs; shape (n_test, output_dim) for vector outputs (future).

Notes

Computed via Monte Carlo integration over the parameter posterior.

variance

variance: ndarray

Posterior predictive marginal variances Var[f(X*) | data].

Returns:

  • ndarray: shape (n_test,) for scalar outputs; shape (n_test, output_dim) for vector outputs (future).

Notes

Captures both aleatoric (model) and epistemic (parameter) uncertainty.
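
One way to make the aleatoric/epistemic split concrete is the law of total variance for a binary response y ~ Bernoulli(p(θ)): Var[y] = E_θ[p(1 − p)] + Var_θ[p]. Below is a minimal sketch of the Monte Carlo moments under stated assumptions. The Gaussian parameter draws and the toy predictor sigmoid(θ·x) are illustrative. The internal computation of psyphy's `mean`/`variance` is not shown in this section.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

key_theta, key_x = jr.split(jr.PRNGKey(0))
n_samples, n_test, input_dim = 2000, 4, 2

# Hypothetical parameter draws θ_i ~ p(θ | data) and test stimuli X*.
thetas = 1.0 + 0.3 * jr.normal(key_theta, (n_samples, input_dim))
X_test = jr.normal(key_x, (n_test, input_dim))

# Toy predictor: response probability p(x; θ) = sigmoid(θ·x).
probs = jax.nn.sigmoid(thetas @ X_test.T)   # (n_samples, n_test)

mean = probs.mean(axis=0)                   # E[p | data], shape (n_test,)
epistemic = probs.var(axis=0)               # Var_θ[p]: parameter uncertainty
aleatoric = (probs * (1 - probs)).mean(0)   # E_θ[p(1 - p)]: Bernoulli noise
total = aleatoric + epistemic               # Var[y | data] for y ~ Bernoulli(p)
```

For these sample statistics, `total` equals `mean * (1 - mean)` up to floating-point error, which makes a convenient sanity check.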

cov_field

cov_field(X: ndarray) -> ndarray

Posterior over perceptual covariance field Σ(X).

Parameters:

  • X (ndarray, required): Test stimuli, shape (n_test, input_dim).

Returns:

  • ndarray: Posterior mean covariance E[Σ(X) | data], shape (n_test, input_dim, input_dim).

Notes

WPPM-specific method for visualizing perceptual noise structure. This is NOT the predictive covariance; it is the model's internal representation of perceptual uncertainty.

Source code in src/psyphy/posterior/predictive_posterior.py
def cov_field(self, X: jnp.ndarray) -> jnp.ndarray:
    """
    Posterior over perceptual covariance field Σ(X).

    Parameters
    ----------
    X : jnp.ndarray
        Test stimuli, shape (n_test, input_dim)

    Returns
    -------
    jnp.ndarray
        Posterior mean covariance E[Σ(X) | data],
        shape (n_test, input_dim, input_dim)

    Notes
    -----
    WPPM-specific method for visualizing perceptual noise structure.
    This is NOT the predictive covariance - it's the model's
    internal representation of perceptual uncertainty.
    """
    ...

rsample

rsample(
    sample_shape: tuple = (), *, key: KeyArray
) -> ndarray

Reparameterized samples from p(f(X*) | data).

Parameters:

  • sample_shape (tuple, default ()): Shape of the sample batch.
  • key (KeyArray, required): PRNG key.

Returns:

  • ndarray: shape (*sample_shape, n_test) for scalar outputs; shape (*sample_shape, n_test, output_dim) for vector outputs.

Notes

Enables gradient-based acquisition optimization via the reparameterization trick.

Source code in src/psyphy/posterior/predictive_posterior.py
def rsample(self, sample_shape: tuple = (), *, key: jr.KeyArray) -> jnp.ndarray:
    """
    Reparameterized samples from p(f(X*) | data).

    Parameters
    ----------
    sample_shape : tuple, default=()
        Shape of sample batch
    key : jax.random.KeyArray
        PRNG key

    Returns
    -------
    jnp.ndarray
        Shape (*sample_shape, n_test) for scalar outputs
        Shape (*sample_shape, n_test, output_dim) for vector outputs

    Notes
    -----
    Enables gradient-based acquisition optimization via reparameterization trick.
    """
    ...
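
A reparameterized draw is a deterministic function of the PRNG key and the inputs. As a result, a Monte Carlo acquisition value built from such draws can be differentiated with `jax.grad`. Below is a minimal sketch under these assumptions: a toy Gaussian parameter posterior θ = μ + σ·ε and a toy predictor sigmoid(θ·x). `toy_rsample` and `acquisition` are hypothetical names and are not psyphy functions.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

MU = jnp.array([1.0, -0.5])
SIGMA = 0.2


def toy_rsample(x_star, sample_shape, key):
    # Reparameterization: draw base noise ε, then transform it deterministically.
    eps = jr.normal(key, sample_shape + MU.shape)
    theta = MU + SIGMA * eps                  # (*sample_shape, input_dim)
    return jax.nn.sigmoid(theta @ x_star)     # (*sample_shape,) for one stimulus


def acquisition(x_star, key):
    # Hypothetical acquisition: MC estimate of predictive variance at x_star.
    draws = toy_rsample(x_star, (256,), key)
    return jnp.var(draws)


key = jr.PRNGKey(42)
x0 = jnp.array([0.5, 0.5])
grad_x = jax.grad(acquisition)(x0, key)  # gradient w.r.t. the stimulus, shape (2,)
```

Fixing the key holds ε constant, so gradient ascent on `x_star` optimizes a smooth, deterministic objective.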

WPPMPredictivePosterior

WPPMPredictivePosterior(
    param_posterior: ParameterPosterior,
    X: ndarray,
    probes: ndarray | None = None,
    n_samples: int = 100,
)

Predictive posterior for WPPM models.

Computes p(f(X*) | data) via Monte Carlo integration over the parameter posterior p(θ | data).

Parameters:

  • param_posterior (ParameterPosterior, required): Posterior over model parameters.
  • X (ndarray, shape (n_test, input_dim), required): Test reference stimuli.
  • probes (ndarray, shape (n_test, input_dim), default None): Test probe stimuli. If None, predictions are meant to be over thresholds. Threshold sampling is not yet implemented, so rsample raises NotImplementedError in that case.
  • n_samples (int, default 100): Number of posterior samples for MC integration.

Attributes:

  • param_posterior (ParameterPosterior): Wrapped parameter posterior.
  • X (ndarray): Test stimuli.
  • probes (ndarray | None): Test probes.
  • n_samples (int): MC sample count.

Notes

Uses lazy evaluation: moments are computed on first access.

Methods:

  • cov_field: Posterior mean covariance field E[Σ(X) | data].
  • rsample: Sample predictions from p(f(X*) | data).

Source code in src/psyphy/posterior/predictive_posterior.py
def __init__(
    self,
    param_posterior: ParameterPosterior,
    X: jnp.ndarray,
    probes: jnp.ndarray | None = None,
    n_samples: int = 100,
):
    self.param_posterior = param_posterior
    self.X = X
    self.probes = probes
    self.n_samples = n_samples

    # Lazy evaluation cache
    self._mean = None
    self._variance = None
    self._computed = False
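
The constructor only initializes the cache, and the `mean`/`variance` property bodies are not shown in this section. Below is a minimal sketch of how such lazy moments could work. `LazyMoments`, `_compute_moments`, and the `n_evaluations` counter are hypothetical, and `draw_fn` stands in for whatever produces prediction draws of shape (n_samples, n_test).

```python
import jax.numpy as jnp


class LazyMoments:
    """Hypothetical illustration of the lazy-cache pattern (not psyphy code)."""

    def __init__(self, draw_fn):
        self._draw_fn = draw_fn   # returns predictions, shape (n_samples, n_test)
        self._mean = None
        self._variance = None
        self._computed = False
        self.n_evaluations = 0    # counts how often the expensive path runs

    def _compute_moments(self):
        draws = self._draw_fn()
        self._mean = draws.mean(axis=0)
        self._variance = draws.var(axis=0)
        self._computed = True
        self.n_evaluations += 1

    @property
    def mean(self):
        if not self._computed:
            self._compute_moments()
        return self._mean

    @property
    def variance(self):
        if not self._computed:
            self._compute_moments()
        return self._variance


lazy = LazyMoments(lambda: jnp.array([[0.2, 0.8], [0.4, 0.6]]))
m, v = lazy.mean, lazy.variance   # both served from a single evaluation
print(lazy.n_evaluations)         # 1
```

A shared `_computed` flag means that whichever moment is accessed first pays for both, so mean and variance always come from the same set of draws.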

mean

mean: ndarray

E[f(X*) | data], shape (n_test,).

variance

variance: ndarray

Var[f(X*) | data], shape (n_test,).

cov_field

cov_field(X: ndarray) -> ndarray

Posterior mean covariance field E[Σ(X) | data].

Parameters:

  • X (ndarray, required): Test stimuli, shape (n_test, input_dim).

Returns:

  • ndarray: Covariance matrices, shape (n_test, input_dim, input_dim).

Notes

Averages local_covariance(x) over parameter posterior samples.

Source code in src/psyphy/posterior/predictive_posterior.py
def cov_field(self, X: jnp.ndarray) -> jnp.ndarray:
    """
    Posterior mean covariance field E[Σ(X) | data].

    Parameters
    ----------
    X : jnp.ndarray
        Test stimuli, shape (n_test, input_dim)

    Returns
    -------
    jnp.ndarray
        Covariance matrices, shape (n_test, input_dim, input_dim)

    Notes
    -----
    Averages local_covariance(x) over parameter posterior samples.
    """
    # Fixed key: repeated calls draw the same parameter samples and return the same estimate.
    key = jr.PRNGKey(0)
    param_samples = self.param_posterior.sample(self.n_samples, key=key)

    model = self.param_posterior.model

    def cov_at_x(params, x):
        """Evaluate Σ(x) with given parameters."""
        return model.local_covariance(params, x)

    # Vectorized evaluation: (n_samples, n_test, input_dim, input_dim)
    cov_samples = jax.vmap(
        lambda params: jax.vmap(lambda x: cov_at_x(params, x))(X)
    )(param_samples)

    # Return posterior mean
    return jnp.mean(cov_samples, axis=0)
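
Below is a self-contained sketch of the same nested-`vmap` averaging. A toy `local_covariance` stands in for the WPPM method: a diagonal covariance whose log-variances are linear in x. The parameter names `log_diag` and `A` and the random parameter draws are hypothetical.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def local_covariance(params, x):
    # Toy Σ(x) = diag(exp(log_diag + A @ x)); stands in for WPPM.local_covariance.
    return jnp.diag(jnp.exp(params["log_diag"] + params["A"] @ x))


n_samples, n_test, d = 50, 3, 2
k1, k2 = jr.split(jr.PRNGKey(0))
param_samples = {
    "log_diag": 0.1 * jr.normal(k1, (n_samples, d)),
    "A": 0.1 * jr.normal(k2, (n_samples, d, d)),
}
X = jnp.linspace(-1.0, 1.0, n_test * d).reshape(n_test, d)

# Outer vmap over parameter samples, inner vmap over test stimuli.
cov_samples = jax.vmap(
    lambda params: jax.vmap(lambda x: local_covariance(params, x))(X)
)(param_samples)                           # (n_samples, n_test, d, d)
cov_mean = jnp.mean(cov_samples, axis=0)   # (n_test, d, d)
```

With a MAPPosterior every sample is identical, so this average reduces to Σ(x) evaluated at θ_MAP.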

rsample

rsample(
    sample_shape: tuple = (), *, key: KeyArray
) -> ndarray

Sample predictions from p(f(X*) | data).

Parameters:

  • sample_shape (tuple, default ()): Batch shape.
  • key (KeyArray, required): PRNG key.

Returns:

  • ndarray: shape (*sample_shape, n_test).

Source code in src/psyphy/posterior/predictive_posterior.py
def rsample(self, sample_shape: tuple = (), *, key: jr.KeyArray) -> jnp.ndarray:
    """
    Sample predictions from p(f(X*) | data).

    Parameters
    ----------
    sample_shape : tuple
        Batch shape
    key : jax.random.KeyArray
        PRNG key

    Returns
    -------
    jnp.ndarray
        Shape (*sample_shape, n_test)
    """
    if self.probes is None:
        raise NotImplementedError("Threshold sampling not yet implemented")

    # Flatten sample_shape into a single draw count; () means one draw.
    n = int(jnp.prod(jnp.array(sample_shape))) if sample_shape else 1
    param_samples = self.param_posterior.sample(n, key=key)

    model = self.param_posterior.model

    def predict_one(params):
        """Predict for all test points with given params."""
        return jax.vmap(lambda r, p: model.predict_prob(params, (r, p)))(
            self.X, self.probes
        )

    samples = jax.vmap(predict_one)(param_samples)  # (n, n_test)

    # Restore the batch shape. For sample_shape == () this also drops the
    # size-1 leading axis, so the result is (n_test,) as documented.
    return samples.reshape(*sample_shape, -1)
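
The (*sample_shape, n_test) contract can be checked with a self-contained toy that follows the same flatten, sample, predict, and reshape steps. `toy_predict_prob` and the single-parameter `toy_param_sample` are hypothetical stand-ins for `model.predict_prob` and `param_posterior.sample`.

```python
import math

import jax
import jax.numpy as jnp
import jax.random as jr


def toy_param_sample(n, key):
    # Hypothetical posterior over one scalar parameter "scale"; leading dim n.
    return {"scale": 3.0 + 0.5 * jr.normal(key, (n,))}


def toy_predict_prob(params, stimulus):
    # Hypothetical stand-in for model.predict_prob: larger reference-probe
    # distances give higher probabilities.
    ref, probe = stimulus
    dist = jnp.linalg.norm(probe - ref)
    return jax.nn.sigmoid(params["scale"] * dist - 1.0)


def toy_rsample(X, probes, sample_shape=(), *, key):
    n = math.prod(sample_shape)  # math.prod(()) == 1, so () means one draw
    params = toy_param_sample(n, key)

    def predict_one(p):
        # Vectorize over the (reference, probe) pairs for one parameter draw.
        return jax.vmap(lambda r, q: toy_predict_prob(p, (r, q)))(X, probes)

    samples = jax.vmap(predict_one)(params)     # (n, n_test)
    return samples.reshape(*sample_shape, -1)   # (*sample_shape, n_test)


X = jnp.zeros((4, 2))
probes = jnp.array([[0.1, 0.0], [0.0, 0.2], [0.3, 0.3], [1.0, 0.0]])
key = jr.PRNGKey(0)
print(toy_rsample(X, probes, (2, 3), key=key).shape)  # (2, 3, 4)
print(toy_rsample(X, probes, key=key).shape)          # (4,)
```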

Parameter Posterior Protocol


parameter_posterior

parameter_posterior.py

Protocol and implementations for posterior distributions over model parameters.

This module defines the ParameterPosterior interface representing p(θ | data), used for research workflows: diagnostics, parameter uncertainty, sampling.

Design

Different inference engines produce different posterior representations:
  • MAP: delta distribution at θ_MAP
  • MCMC: collection of samples

All implement a common protocol for polymorphic use.
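
A `typing.Protocol` is matched structurally: any object exposing `model`, `params`, and `sample` can be used where the protocol is expected, with no subclassing required. Below is a minimal sketch. `ParameterPosteriorLike`, `DeltaPosterior`, and `summarize` are hypothetical names. The `@runtime_checkable` decorator is added here only to enable the `isinstance` check and may not be present on psyphy's protocol.

```python
from typing import Any, Protocol, runtime_checkable

import jax
import jax.numpy as jnp


@runtime_checkable
class ParameterPosteriorLike(Protocol):
    @property
    def model(self) -> Any: ...

    @property
    def params(self) -> dict: ...

    def sample(self, n: int, *, key: jax.Array) -> dict: ...


class DeltaPosterior:
    # No inheritance: conforms to the protocol by shape alone.
    def __init__(self, params):
        self._params = params

    @property
    def model(self):
        return None

    @property
    def params(self):
        return self._params

    def sample(self, n, *, key=None):
        return jax.tree_util.tree_map(
            lambda x: jnp.broadcast_to(x, (n,) + x.shape), self._params
        )


def summarize(post: ParameterPosteriorLike, key: jax.Array) -> dict:
    # Works for any conforming posterior: per-leaf spread of 100 draws.
    draws = post.sample(100, key=key)
    return jax.tree_util.tree_map(lambda x: jnp.std(x, axis=0), draws)


post = DeltaPosterior({"log_diag": jnp.array([0.0, 1.0])})
print(isinstance(post, ParameterPosteriorLike))  # True
spread = summarize(post, jax.random.PRNGKey(0))  # zero spread for a delta posterior
```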

Classes:

  • ParameterPosterior: Protocol for posterior distributions over model parameters p(θ | data).


Parameter Posterior Implementations


posterior

posterior.py

Concrete ParameterPosterior implementations.

This module provides:
  • MAPPosterior: delta distribution at θ_MAP (point estimate)

Classes:

  • MAPPosterior: MAP (Maximum A Posteriori) posterior, a delta distribution at θ_MAP.


Predictive Posterior


predictive_posterior

predictive_posterior.py

Predictive posterior distributions p(f(X*) | data) at test stimuli.

This module defines posteriors over predictions (not parameters), used by acquisition functions for Bayesian optimization.

Design

PredictivePosterior wraps a ParameterPosterior and computes predictions via

    E[f(X*) | data] ≈ (1/N) Σᵢ f(X*; θᵢ),  where θᵢ ~ p(θ | data), i = 1, …, N.

This separates concerns:
  • ParameterPosterior: represents uncertainty over θ (research)
  • PredictivePosterior: represents uncertainty over f(X*) (decision-making)
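
The Monte Carlo average above can be checked against a case with a closed form. Below is a minimal worked example with these assumptions: a linear predictor f(X*; θ) = X* θ and Gaussian draws θᵢ ~ N(μ, σ²I). For these choices the exact predictive mean is X* μ, so the sample average should approach it as N grows.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

mu = jnp.array([0.5, -1.0])
sigma = 0.3
X_star = jnp.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])  # (n_test, input_dim)

N = 20_000
thetas = mu + sigma * jr.normal(jr.PRNGKey(0), (N, 2))    # θ_i ~ p(θ | data)


def f(X, theta):
    # Toy linear predictor f(X*; θ) = X* θ.
    return X @ theta


preds = jax.vmap(lambda th: f(X_star, th))(thetas)  # (N, n_test)
mc_mean = jnp.mean(preds, axis=0)                   # (1/N) Σᵢ f(X*; θᵢ)
exact = X_star @ mu                                 # [0.5, -2.0, -0.5]
```

The Monte Carlo error shrinks like 1/√N. Here it is a few thousandths per test point.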

Classes:

  • PredictivePosterior: Protocol for predictive distributions p(f(X*) | data) at test stimuli.
  • WPPMPredictivePosterior: Predictive posterior for WPPM models.
