
Utils


Package


utils


Shared utility functions and helpers for psyphy.

This subpackage provides:
  • math: mathematical utilities (currently basis functions, which may later get their own module).

Functions:

chebyshev_basis: Construct the Chebyshev polynomial basis matrix T_0..T_degree evaluated at x.

chebyshev_basis

chebyshev_basis(x: ndarray, degree: int) -> ndarray

Construct the Chebyshev polynomial basis matrix T_0..T_degree evaluated at x.

Parameters:

  x (ndarray, required): Input points of shape (N,). For best numerical properties, values should lie in [-1, 1].
  degree (int, required): Maximum polynomial degree (>= 0). The output includes columns for T_0 through T_degree.
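The [-1, 1] recommendation matters because |T_n(x)| <= 1 on that interval, while outside it T_n(x) grows roughly like 2^(n-1) x^n. A minimal sketch, assuming raw inputs live on a known range [lo, hi]: rescale them affinely before building the basis. The helper `to_unit_interval` and the 0 to 255 range are illustrative assumptions, not part of psyphy. NumPy's `numpy.polynomial.chebyshev.chebvander` builds the same (N, degree + 1) matrix of T_0..T_degree, so it stands in for `chebyshev_basis` here and the block runs without psyphy installed.

```python
import numpy as np
from numpy.polynomial import chebyshev as C


def to_unit_interval(x, lo, hi):
    """Hypothetical helper: affinely map values on [lo, hi] onto [-1, 1]."""
    return 2.0 * (np.asarray(x, dtype=float) - lo) / (hi - lo) - 1.0


raw = np.array([0.0, 64.0, 128.0, 255.0])  # illustrative raw values on [0, 255]
u = to_unit_interval(raw, lo=0.0, hi=255.0)

# chebvander(x, deg) has column j equal to T_j(x), the same layout as chebyshev_basis.
print(np.abs(C.chebvander(u, 5)).max())    # 1.0, since |T_n| <= 1 on [-1, 1]
print(np.abs(C.chebvander(raw, 5)).max())  # roughly 1.7e13, since T_5(255) explodes
```

Without the rescaling, higher-degree columns dominate any fit by many orders of magnitude, which is what the docstring's "numerical properties" caveat is guarding against.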

Returns:

  ndarray: Array of shape (N, degree + 1) where column j contains T_j(x).
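Because column j holds T_j(x), multiplying the basis matrix by a weight vector evaluates a Chebyshev series, and a least-squares solve against it fits one. The sketch below shows both directions; the weights are illustrative, and `numpy.polynomial.chebyshev.chebvander` again stands in for `chebyshev_basis` (same column layout) so the block is self-contained.

```python
import numpy as np
from numpy.polynomial import chebyshev as C

x = np.linspace(-1.0, 1.0, 9)
coeffs = np.array([0.5, -1.0, 0.25, 2.0])  # illustrative weights w_0..w_3

B = C.chebvander(x, 3)  # stand-in for chebyshev_basis(x, degree=3), shape (9, 4)
f = B @ coeffs          # f(x_i) = sum_j w_j T_j(x_i)

# Matches NumPy's direct Chebyshev-series evaluation.
print(np.allclose(f, C.chebval(x, coeffs)))  # True

# Going the other way: recover the weights from the samples by least squares.
fitted, *_ = np.linalg.lstsq(B, f, rcond=None)
print(np.allclose(fitted, coeffs))  # True
```

This is the basis-expansion pattern that the math module's notes mention for Full WPPM mode, where a field is represented by weights on these columns.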

Raises:

  ValueError: If degree is negative or x is not 1-D.

Notes

Uses the three-term recurrence:

$$T_0(x) = 1, \qquad T_1(x) = x, \qquad T_{n+1}(x) = 2x\,T_n(x) - T_{n-1}(x).$$

The Chebyshev polynomials are orthogonal on [-1, 1] with respect to the weight $w(x) = 1/\sqrt{1 - x^2}$.
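Both claims in these notes can be checked numerically. The sketch below uses a hypothetical NumPy reference `chebyshev_basis_np` (not psyphy's code) that applies the same recurrence, then (1) compares it with the identity T_n(cos t) = cos(n t) and with NumPy's `chebvander`, and (2) checks weighted orthogonality through Gauss-Chebyshev quadrature: with M nodes x_k = cos(pi (k + 1/2) / M), the weighted integral of T_i T_j equals (pi / M) times the sum of T_i(x_k) T_j(x_k), exactly whenever i + j <= 2M - 1, so orthogonality shows up as a diagonal Gram matrix.

```python
import numpy as np
from numpy.polynomial import chebyshev as C


def chebyshev_basis_np(x, degree):
    """Hypothetical NumPy reference: columns T_0..T_degree via the three-term recurrence."""
    x = np.asarray(x, dtype=float)
    B = np.empty((x.shape[0], degree + 1))
    B[:, 0] = 1.0  # T_0(x) = 1
    if degree >= 1:
        B[:, 1] = x  # T_1(x) = x
    for n in range(1, degree):
        B[:, n + 1] = 2.0 * x * B[:, n] - B[:, n - 1]  # T_{n+1} = 2x T_n - T_{n-1}
    return B


# (1) The recurrence reproduces T_n(cos t) = cos(n t) and NumPy's chebvander.
x = np.linspace(-1.0, 1.0, 7)
B = chebyshev_basis_np(x, degree=4)
t = np.arccos(x)
print(np.allclose(B, np.cos(np.outer(t, np.arange(5)))))  # True
print(np.allclose(B, C.chebvander(x, 4)))                 # True

# (2) Discrete orthogonality at M Chebyshev-Gauss nodes:
#     sum_k T_i T_j = 0 for i != j, M for i = j = 0, M / 2 for 0 < i = j < M.
M = 8
nodes = np.cos(np.pi * (np.arange(M) + 0.5) / M)
Bn = chebyshev_basis_np(nodes, degree=5)
G = Bn.T @ Bn
print(np.allclose(G, np.diag([M] + [M / 2] * 5)))  # True
```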

Examples:

>>> import jax.numpy as jnp
>>> from psyphy.utils.math import chebyshev_basis
>>> x = jnp.linspace(-1, 1, 5)
>>> B = chebyshev_basis(x, degree=3)  # columns: T0, T1, T2, T3
>>> B.shape
(5, 4)
Source code in src/psyphy/utils/math.py
# Module-level imports this function relies on.
import jax.numpy as jnp
from jax import lax


def chebyshev_basis(x: jnp.ndarray, degree: int) -> jnp.ndarray:
    """
    Construct the Chebyshev polynomial basis matrix T_0..T_degree evaluated at x.

    Parameters
    ----------
    x : jnp.ndarray
        Input points of shape (N,). For best numerical properties, values should lie in [-1, 1].
    degree : int
        Maximum polynomial degree (>= 0). The output includes columns for T_0 through T_degree.

    Returns
    -------
    jnp.ndarray
        Array of shape (N, degree + 1) where column j contains T_j(x).

    Raises
    ------
    ValueError
        If `degree` is negative or `x` is not 1-D.

    Notes
    -----
    Uses the three-term recurrence:
        T_0(x) = 1
        T_1(x) = x
        T_{n+1}(x) = 2 x T_n(x) - T_{n-1}(x)
    The Chebyshev polynomials are orthogonal on [-1, 1] with weight (1 / sqrt(1 - x^2)).

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> x = jnp.linspace(-1, 1, 5)
    >>> B = chebyshev_basis(x, degree=3)  # columns: T0, T1, T2, T3
    """
    if degree < 0:
        raise ValueError("degree must be >= 0")
    if x.ndim != 1:
        raise ValueError("x must be 1-D (shape (N,))")

    # Promote integer inputs to floating point; otherwise the scan carry would change dtype (int -> float) after the first step.
    x = x.astype(jnp.result_type(x, 0.0))

    N = x.shape[0]

    # Handle small degrees explicitly.
    if degree == 0:
        return jnp.ones((N, 1), dtype=x.dtype)
    if degree == 1:
        return jnp.stack([jnp.ones_like(x), x], axis=1)

    # Initialize T0 and T1 columns.
    T0 = jnp.ones_like(x)
    T1 = x

    # Scan to generate T2..T_degree in a JIT-friendly way (avoids Python-side loops).
    def step(carry, _):
        # compute next Chebyshev polynomial
        Tm1, Tm = carry
        Tnext = 2.0 * x * Tm - Tm1
        return (Tm, Tnext), Tnext  # new carry, plus an output to collect

    # JAX-friendly loop: degree - 1 scan steps produce T2..T_degree; the final carry is unused.
    _, Ts = lax.scan(step, (T0, T1), xs=None, length=degree - 1)
    # Ts has shape (degree-1, N) and holds [T2, T3, ..., T_degree]
    B = jnp.concatenate([T0[:, None], T1[:, None], jnp.swapaxes(Ts, 0, 1)], axis=1)
    return B
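The listing shows no `jax.jit` decorator, and `degree` drives Python-level `if` branches and the `lax.scan` length, so both must be concrete at trace time. A sketch, assuming you want to compile it: mark `degree` as static, which for the real function would presumably look like `jax.jit(chebyshev_basis, static_argnames="degree")`. So that the block runs without psyphy, it uses a hypothetical stand-in `cheb_basis` that mirrors the listing, and it also checks that gradients flow through the scan.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical stand-in mirroring the listing above. degree is static because it
# selects Python branches and sets the scan length, which cannot be traced values.
@partial(jax.jit, static_argnames="degree")
def cheb_basis(x, degree):
    x = x.astype(jnp.result_type(x, 0.0))
    T0 = jnp.ones_like(x)
    if degree == 0:
        return T0[:, None]
    if degree == 1:
        return jnp.stack([T0, x], axis=1)

    def step(carry, _):
        Tm1, Tm = carry
        Tnext = 2.0 * x * Tm - Tm1  # T_{n+1} = 2x T_n - T_{n-1}
        return (Tm, Tnext), Tnext

    _, Ts = lax.scan(step, (T0, x), xs=None, length=degree - 1)
    return jnp.concatenate([T0[:, None], x[:, None], Ts.T], axis=1)


x = jnp.linspace(-1.0, 1.0, 5)
B = cheb_basis(x, degree=3)
print(B.shape)  # (5, 4)

# Autodiff through the scan: d/dx T_3(x) = d/dx (4x^3 - 3x) = 12x^2 - 3.
dT3 = jax.grad(lambda s: cheb_basis(jnp.atleast_1d(s), degree=3)[0, 3])
print(dT3(0.3))  # close to 12 * 0.09 - 3 = -1.92
```

Each distinct `degree` value triggers a separate compilation, which is usually acceptable because a model typically fixes its basis degree once.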

Math


math

math.py

Math utilities for psyphy.

Includes:
  • chebyshev_basis: compute a Chebyshev polynomial basis.
  • mahalanobis_distance: discriminability metric used in WPPM MVP.
  • rbf_kernel: kernel function, useful for Full WPPM mode covariance priors.
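The Mahalanobis distance d = sqrt((x - mu)^T Sigma^{-1} (x - mu)) measures an offset in units of the covariance, so equal Euclidean offsets can be more or less discriminable depending on direction. psyphy's `mahalanobis_distance` signature is not shown on this page; the sketch below assumes a single point, a mean, and a symmetric positive-definite covariance, and the name `mahalanobis_sketch` is hypothetical. It uses a Cholesky factor and a triangular solve rather than an explicit matrix inverse.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def mahalanobis_sketch(x, mean, cov):
    """Hypothetical: sqrt((x - mean)^T cov^{-1} (x - mean)) for a single point."""
    diff = x - mean
    L = jnp.linalg.cholesky(cov)               # cov = L L^T (cov must be SPD)
    z = solve_triangular(L, diff, lower=True)  # z = L^{-1} diff, no explicit inverse
    return jnp.sqrt(jnp.dot(z, z))             # ||z||^2 = diff^T cov^{-1} diff


mean = jnp.zeros(2)
cov = jnp.diag(jnp.array([4.0, 1.0]))  # illustrative: wider spread along axis 0

# The same Euclidean offset of 2 gives different distances along the two axes.
print(mahalanobis_sketch(jnp.array([2.0, 0.0]), mean, cov))  # 1.0
print(mahalanobis_sketch(jnp.array([0.0, 2.0]), mean, cov))  # 2.0
```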

All functions use JAX (jax.numpy) for compatibility with autodiff.

Notes
  • math.chebyshev_basis is relevant when implementing Full WPPM mode, where covariance fields are expressed in a basis expansion.
  • math.mahalanobis_distance is directly used in WPPM MVP discriminability.
  • math.rbf_kernel is a placeholder for Gaussian-process-style covariance priors.
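For the Gaussian-process-style priors mentioned in the last note, a common choice is the squared-exponential (RBF) kernel k(x, y) = variance * exp(-||x - y||^2 / (2 * lengthscale^2)). This is a sketch under that assumption only: the name `rbf_kernel_sketch` and its `lengthscale`/`variance` parameters are hypothetical, not psyphy's `rbf_kernel` signature. It builds a covariance matrix on a few 1-D points and draws one sample from the corresponding zero-mean prior.

```python
import jax
import jax.numpy as jnp


def rbf_kernel_sketch(X, Y, lengthscale=1.0, variance=1.0):
    """Hypothetical squared-exponential kernel between rows of X (N, D) and Y (M, D)."""
    sq_dists = jnp.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=-1)  # (N, M)
    return variance * jnp.exp(-0.5 * sq_dists / lengthscale**2)


X = jnp.linspace(-1.0, 1.0, 4)[:, None]  # 4 points in 1-D, shape (4, 1)
K = rbf_kernel_sketch(X, X, lengthscale=1.0)

# One draw from a zero-mean GP prior at X; a small jitter keeps the Cholesky stable.
key = jax.random.PRNGKey(0)
f = jax.random.multivariate_normal(key, jnp.zeros(4), K + 1e-6 * jnp.eye(4))
```

Larger `lengthscale` values make neighbouring points more correlated and the sampled function smoother.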

Examples:

>>> import jax.numpy as jnp
>>> from psyphy.utils import math
>>> x = jnp.linspace(-1, 1, 5)
>>> math.chebyshev_basis(x, degree=3).shape
(5, 4)

Functions:

chebyshev_basis: Construct the Chebyshev polynomial basis matrix T_0..T_degree evaluated at x (full reference under the utils package above).
