Evaluation Protocols

Interfaces for loss functions and fitting algorithms

jaxcmr uses a protocol-based design for its evaluation components: loss functions and fitting algorithms can be swapped freely, as long as they satisfy the interfaces described here, without changes to the rest of the library.

Why Protocols?

Protocols define interfaces that components must satisfy, enabling:

  • Modularity: Swap loss functions without changing fitting code
  • Extensibility: Implement custom losses or optimizers
  • Type safety: IDE support and static analysis
  • Documentation: Clear contracts for each component
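
As a minimal, generic illustration of the pattern (the Scorer protocol here is hypothetical, not part of jaxcmr), a runtime-checkable protocol lets components be verified structurally, with no inheritance required:

Code
from typing import Protocol, runtime_checkable

@runtime_checkable
class Scorer(Protocol):
    """Hypothetical interface: anything exposing a score() method."""

    def score(self, x: float) -> float: ...

class SquaredScorer:
    """Satisfies Scorer purely by matching its structure."""

    def score(self, x: float) -> float:
        return x * x

assert issubclass(SquaredScorer, Scorer)  # structural check, no subclassing
assert isinstance(SquaredScorer(), Scorer)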

LossFnGenerator Protocol

The LossFnGenerator protocol defines how loss functions are created for model fitting.

Definition

Code
# Defined in jaxcmr.typing; reproduced here for reference.
from typing import Callable, Iterable, Mapping, Optional, Protocol, runtime_checkable

import numpy as np
from jaxtyping import Array, Float, Integer

from jaxcmr.typing import Float_, MemorySearchCreateFn, RecallDataset

@runtime_checkable
class LossFnGenerator(Protocol):
    """Generates loss functions for model fitting."""

    def __init__(
        self,
        model_create_fn: MemorySearchCreateFn,
        dataset: RecallDataset,
        features: Optional[Float[Array, " word_pool_items features_count"]],
    ) -> None:
        """Initialize the factory with the specified trials and trial data."""

    def __call__(
        self,
        trial_indices: Integer[Array, " trials"],
        base_params: Mapping[str, Float_],
        free_param_names: Iterable[str],
    ) -> Callable[[np.ndarray], Float[Array, ""]]:
        """Return a loss function for the specified trials and free parameters."""

Initialization

The generator is initialized with:

Parameter        Type                  Description
model_create_fn  MemorySearchCreateFn  Factory function to create model instances
dataset          RecallDataset         Trial data (presentations, recalls, subjects)
features         Optional[Array]       Semantic embeddings or None

Generating Loss Functions

When called, the generator returns a specialized loss function:

Input parameters:

  • trial_indices: Which trials to include in the loss
  • base_params: Fixed parameter values
  • free_param_names: Names of parameters being optimized

Returns: A callable loss_fn(params: np.ndarray) -> scalar

The returned function:

  1. Takes a 1D array of parameter values (in the order of free_param_names)
  2. Returns a scalar loss value (lower is better)
  3. Supports vectorized evaluation for population-based optimizers
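
Putting the two steps together, usage looks roughly like this sketch; it uses MemorySearchLikelihoodFnGenerator (listed under Available Implementations below) and assumes factory, dataset, and base_params are already in scope, with hypothetical trial indices and parameter values:

Code
import jax.numpy as jnp
import numpy as np

from jaxcmr.loss.sequence_likelihood import MemorySearchLikelihoodFnGenerator

# Step 1: build the generator from a model-creation function and the data.
loss_gen = MemorySearchLikelihoodFnGenerator(
    factory.create_trial_model, dataset, None
)

# Step 2: specialize it to three trials, holding base_params fixed and
# freeing only encoding_drift_rate.
loss_fn = loss_gen(jnp.array([0, 1, 2]), base_params, ["encoding_drift_rate"])

loss = loss_fn(np.array([0.5]))  # scalar loss; lower is better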

Mathematical Role

The loss function maps parameters \(\theta\) to a scalar objective:

\[L : \Theta \to \mathbb{R}, \qquad \theta \mapsto L(\theta)\]

For likelihood-based losses: \[L(\theta) = -\sum_t \log P(\text{data}_t \mid \theta)\]

For MSE-based losses: \[L(\theta) = \frac{1}{n}\sum_i \left(y_i^{\text{obs}} - y_i^{\text{sim}}(\theta)\right)^2\]
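
As a purely illustrative numeric sketch of the two objectives (the probabilities and curves below are made up):

Code
import numpy as np

# Likelihood-based: per-trial probabilities the model assigns to the
# observed recall sequences at some parameter setting.
trial_probs = np.array([0.12, 0.05, 0.30])
nll = -np.sum(np.log(trial_probs))

# MSE-based: observed vs. simulated serial position curves.
spc_obs = np.array([0.90, 0.70, 0.60, 0.80])
spc_sim = np.array([0.85, 0.72, 0.55, 0.75])
mse = np.mean((spc_obs - spc_sim) ** 2)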

FitResult TypedDict

Fitting algorithms return a FitResult containing the optimization results.

Definition

Code
# Defined in jaxcmr.typing; reproduced here for reference.
from typing import Any, TypedDict

class FitResult(TypedDict):
    fixed: dict[str, float]
    """Dictionary of fixed parameters and their values."""

    free: dict[str, list[float]]
    """Dictionary of free parameters and their [lower_bound, upper_bound]."""

    fitness: list[float]
    """List of fitness values (one per subject or single global fit)."""

    fits: dict[str, list[float]]
    """Dictionary of parameter names -> optimized values."""

    hyperparameters: dict[str, Any]
    """Dictionary of hyperparameters used during fitting."""

    fit_time: float
    """Total time (in seconds) taken to perform the fitting."""

Structure

Field            Type                    Description
fixed            dict[str, float]        Parameters held constant
free             dict[str, list[float]]  Parameter bounds [lo, hi]
fitness          list[float]             Loss values at optimum
fits             dict[str, list[float]]  Fitted parameter values
hyperparameters  dict                    Optimizer settings used
fit_time         float                   Wall-clock time in seconds

Per-Subject vs Pooled

When using fit_subjects():

  • fitness has one value per subject
  • fits["param"] has one value per subject
  • fits["subject"] lists subject IDs

When using fit() (pooled or single-subject):

  • fitness has a single value
  • fits["param"] has a single value
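
For example, a result from fit_subjects() for two subjects might look like the following (all values are purely illustrative) and can be unpacked subject by subject:

Code
from jaxcmr.typing import FitResult

results: FitResult = {
    "fixed": {"allow_repeated_recalls": 0.0},
    "free": {"encoding_drift_rate": [0.0, 1.0]},
    "fitness": [512.3, 498.7],                    # one loss per subject
    "fits": {
        "subject": [1, 2],
        "encoding_drift_rate": [0.62, 0.55],      # one value per subject
    },
    "hyperparameters": {"num_steps": 1000},
    "fit_time": 42.0,
}

for subject, fitness in zip(results["fits"]["subject"], results["fitness"]):
    print(f"subject {subject}: loss = {fitness}")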

FittingAlgorithm Protocol

The FittingAlgorithm protocol defines how parameter optimization is performed.

Definition

Code
# Defined in jaxcmr.typing; reproduced here for reference.
from typing import Any, Mapping, Optional, Protocol, Type, runtime_checkable

from jaxtyping import Array, Bool, Float

from jaxcmr.typing import (
    Float_,
    LossFnGenerator,
    MemorySearchModelFactory,
    RecallDataset,
)

@runtime_checkable
class FittingAlgorithm(Protocol):
    """Protocol describing a fitting algorithm for memory search models."""

    def __init__(
        self,
        dataset: RecallDataset,
        features: Optional[Float[Array, " word_pool_items features_count"]],
        base_params: Mapping[str, Float_],
        model_factory: Type[MemorySearchModelFactory],
        loss_fn_generator: Type[LossFnGenerator],
        hyperparams: Optional[dict[str, Any]] = None,
    ):
        """Configure the fitting algorithm."""

    def fit(
        self,
        trial_mask: Bool[Array, " trials"],
        subject_id: int = -1,
    ) -> FitResult:
        """Fit one parameter set to the trials selected by the mask."""

    def fit_subjects(
        self,
        trial_mask: Bool[Array, " trials"],
    ) -> FitResult:
        """Fit each subject independently and accumulate results."""

Initialization

Parameter          Type                            Description
dataset            RecallDataset                   Trial data including subject IDs
features           Optional[Array]                 Semantic embeddings or None
base_params        Mapping[str, float]             Fixed parameter values
model_factory      Type[MemorySearchModelFactory]  Model factory class
loss_fn_generator  Type[LossFnGenerator]           Loss generator class
hyperparams        Optional[dict]                  Optimizer-specific settings

Method Signatures

Code
def fit(
    self,
    trial_mask: Bool[Array, " trials"],
    subject_id: int = -1,
) -> FitResult:

def fit_subjects(
    self,
    trial_mask: Bool[Array, " trials"],
) -> FitResult:

fit()

Parameter   Type         Description
trial_mask  Bool[Array]  Which trials to include
subject_id  int          Label stored in result (default -1)

Returns: FitResult with fitted parameters and diagnostics.

fit_subjects()

Parameter   Type         Description
trial_mask  Bool[Array]  Which trials to include

Returns: FitResult with per-subject fitted parameters.
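
A sketch of how the two methods are typically invoked; it assumes fitter is a configured FittingAlgorithm (see the Usage Example below) and that n_trials and a per-trial subject_labels array are already in scope:

Code
import jax.numpy as jnp

all_trials = jnp.ones(n_trials, dtype=bool)

# Pooled fit over every trial:
pooled = fitter.fit(all_trials)

# Single-subject fit: restrict the mask to one subject's trials.
single = fitter.fit(all_trials & (subject_labels == 1), subject_id=1)

# Independent per-subject fits accumulated into one FitResult:
per_subject = fitter.fit_subjects(all_trials)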

MemorySearchModelFactory Protocol

Model factories create model instances, either generic or configured for a specific trial.

Definition

Code
# Defined in jaxcmr.typing; reproduced here for reference.
from typing import Mapping, Optional, Protocol, runtime_checkable

from jaxtyping import Array, Float, Integer

from jaxcmr.typing import Float_, MemorySearch, RecallDataset

@runtime_checkable
class MemorySearchModelFactory(Protocol):
    def __init__(
        self,
        dataset: RecallDataset,
        features: Optional[Float[Array, " word_pool_items features_count"]],
    ) -> None:
        """Initialize the factory with trial data."""

    def create_model(
        self,
        parameters: Mapping[str, Float_],
    ) -> MemorySearch:
        """Create a generic model instance."""

    def create_trial_model(
        self,
        trial_index: Integer[Array, ""],
        parameters: Mapping[str, Float_],
    ) -> MemorySearch:
        """Create a model configured for a specific trial."""

The create_trial_model method allows trial-specific configuration (e.g., different list lengths, trial-specific features like EEG data).
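
A sketch of how a loss generator or fitter might use a factory; it assumes model_factory (a class satisfying the protocol, such as the one built in the Usage Example below), dataset, and a full params mapping are already in scope:

Code
import jax.numpy as jnp

factory = model_factory(dataset, None)

# A generic model, e.g. for simulation:
model = factory.create_model(params)

# Trial-specific models, e.g. for per-trial likelihood evaluation:
for i in range(3):
    trial_model = factory.create_trial_model(jnp.asarray(i), params)
    # ... evaluate trial i with trial_model ...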

Implementing Custom Components

Custom Loss Function

Code
from jaxcmr.typing import LossFnGenerator, MemorySearchCreateFn, RecallDataset
from typing import Callable, Mapping, Iterable, Optional
import numpy as np
from jaxtyping import Float, Integer, Array

class CustomLossFnGenerator:
    """Example custom loss function generator."""

    def __init__(
        self,
        model_create_fn: MemorySearchCreateFn,
        dataset: RecallDataset,
        features: Optional[Float[Array, " items features"]],
    ):
        self.model_create_fn = model_create_fn
        self.dataset = dataset
        self.features = features

    def __call__(
        self,
        trial_indices: Integer[Array, " trials"],
        base_params: Mapping[str, float],
        free_param_names: Iterable[str],
    ) -> Callable[[np.ndarray], float]:
        free_names = list(free_param_names)

        def loss_fn(param_values: np.ndarray) -> float:
            # Build full parameter dict
            params = dict(base_params)
            for name, value in zip(free_names, param_values):
                params[name] = value

            # Compute loss over trials
            total_loss = 0.0
            for idx in trial_indices:
                model = self.model_create_fn(idx, params)
                # Placeholder: compute a real trial-specific loss from
                # `model` (e.g. the negative log-likelihood of the recalls).
                trial_loss = 0.0
                total_loss += trial_loss

            return total_loss

        return loss_fn
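
Because LossFnGenerator is declared @runtime_checkable (see its definition above), conformance of the custom class can be verified structurally:

Code
from jaxcmr.typing import LossFnGenerator

# Passes because CustomLossFnGenerator defines compatible __init__ and
# __call__ methods; it never subclasses the protocol.
assert issubclass(CustomLossFnGenerator, LossFnGenerator)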

Custom Fitting Algorithm

Code
from jaxcmr.typing import FittingAlgorithm, FitResult, RecallDataset
from typing import Type, Mapping, Optional, Any
import time

class CustomFitter:
    """Example custom fitting algorithm."""

    def __init__(
        self,
        dataset: RecallDataset,
        features: Optional[Any],  # loosened from the protocol's array type for brevity
        base_params: Mapping[str, float],
        model_factory: Type[Any],  # e.g. Type[MemorySearchModelFactory]
        loss_fn_generator: Type[Any],  # e.g. Type[LossFnGenerator]
        hyperparams: Optional[dict[str, Any]] = None,
    ):
        self.dataset = dataset
        self.features = features
        self.base_params = base_params
        self.model_factory = model_factory
        self.loss_fn_generator = loss_fn_generator
        self.hyperparams = hyperparams or {}

    def fit(
        self,
        trial_mask,
        subject_id: int = -1,
    ) -> FitResult:
        start_time = time.time()

        # Initialize loss generator
        factory = self.model_factory(self.dataset, self.features)
        loss_gen = self.loss_fn_generator(
            factory.create_trial_model,
            self.dataset,
            self.features,
        )

        # Run optimization (placeholder: a real implementation would minimize
        # the functions produced by `loss_gen` over the masked trials).
        final_loss = float("inf")
        optimal_value = 0.0

        return FitResult(
            fixed=dict(self.base_params),
            free=self.hyperparams.get("bounds", {}),
            fitness=[final_loss],
            fits={"param_name": [optimal_value], "subject": [subject_id]},
            hyperparameters=self.hyperparams,
            fit_time=time.time() - start_time,
        )

    def fit_subjects(
        self,
        trial_mask,
    ) -> FitResult:
        # Loop over subjects, call self.fit() per subject, accumulate
        ...
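
One way the elided fit_subjects() could be implemented in terms of fit() (a sketch only; it assumes the dataset exposes a per-trial "subject" array, and the actual field name may differ):

Code
import numpy as np

def fit_subjects(self, trial_mask) -> FitResult:
    mask = np.asarray(trial_mask)
    subjects = np.asarray(self.dataset["subject"]).flatten()

    # One fit per subject appearing in the masked trials.
    results = [
        self.fit(mask & (subjects == s), subject_id=int(s))
        for s in np.unique(subjects[mask])
    ]

    # Accumulate per-subject scalars into lists.
    merged = dict(results[0])
    merged["fitness"] = [r["fitness"][0] for r in results]
    merged["fits"] = {
        name: [r["fits"][name][0] for r in results] for name in results[0]["fits"]
    }
    merged["fit_time"] = sum(r["fit_time"] for r in results)
    return merged  # structurally a FitResult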

Usage Example

Code
from jaxcmr.fitting import ScipyDE
from jaxcmr.loss.sequence_likelihood import MemorySearchLikelihoodFnGenerator
from jaxcmr.models.cmr import make_factory
from jaxcmr.components.linear_memory import init_mfc, init_mcf
from jaxcmr.components.context import init as init_context
from jaxcmr.components.termination import PositionalTermination

# Create model factory
model_factory = make_factory(
    init_mfc, init_mcf, init_context, PositionalTermination
)

# Configure fitter with loss generator
fitter = ScipyDE(
    dataset=data,
    features=None,
    base_params={"allow_repeated_recalls": False},
    model_factory=model_factory,
    loss_fn_generator=MemorySearchLikelihoodFnGenerator,  # Protocol implementation
    hyperparams={
        "bounds": {"encoding_drift_rate": [0.0, 1.0]},  # plus bounds for every other free parameter
        "num_steps": 1000,
    },
)

# Run fitting
results: FitResult = fitter.fit_subjects(trial_mask)

Available Implementations

Loss Functions

Implementation                           Module                              Description
MemorySearchLikelihoodFnGenerator        loss.sequence_likelihood            Standard sequence likelihood
MemorySearchLikelihoodFnGenerator        loss.set_permutation_likelihood     Monte Carlo for unordered recall
ExcludeFirstRecallLikelihoodFnGenerator  loss.transform_sequence_likelihood  Masks first recall event
ExcludeTerminationLikelihoodFnGenerator  loss.transform_sequence_likelihood  Masks stop decisions
MemorySearchSpcMseFnGenerator            loss.spc_mse                        Serial position curve MSE
MemorySearchCatSpcMseFnGenerator         loss.cat_spc_mse                    Category-filtered SPC MSE

Fitting Algorithms

Implementation  Module          Description
ScipyDE         jaxcmr.fitting  Differential evolution via SciPy