# Interfaces for loss functions and fitting algorithms

jaxcmr uses protocol-based design for evaluation components. This allows different loss functions and fitting algorithms to be swapped while maintaining compatibility with the rest of the library.
## Why Protocols?

Protocols define interfaces that components must satisfy, enabling:

- **Modularity**: swap loss functions without changing fitting code
- **Extensibility**: implement custom losses or optimizers
- **Type safety**: IDE support and static analysis
- **Documentation**: clear contracts for each component
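As a minimal illustration of how this works in Python (a generic sketch using only the standard library, not jaxcmr's actual definitions), a `@runtime_checkable` protocol lets any class with a structurally matching method be used interchangeably, with no inheritance required:

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class Loss(Protocol):
    """Any object with a matching __call__ satisfies this protocol."""

    def __call__(self, params: list) -> float: ...


class SumOfSquares:
    """Structurally conforms to Loss without inheriting from it."""

    def __call__(self, params: list) -> float:
        return sum(p * p for p in params)


loss: Loss = SumOfSquares()
print(isinstance(loss, Loss))  # True: runtime structural check
print(loss([3.0, 4.0]))        # 25.0
```

Note that `isinstance` checks against a runtime-checkable protocol only verify that the named methods exist, not that their signatures match; full signature checking is left to static analyzers.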
## LossFnGenerator Protocol
The LossFnGenerator protocol defines how loss functions are created for model fitting.
### Definition

```python
from jaxcmr.typing import LossFnGenerator


@runtime_checkable
class LossFnGenerator(Protocol):
    """Generates loss function for model fitting."""

    def __init__(
        self,
        model_create_fn: MemorySearchCreateFn,
        dataset: RecallDataset,
        features: Optional[Float[Array, " word_pool_items features_count"]],
    ) -> None:
        """Initialize the factory with the specified trials and trial data."""

    def __call__(
        self,
        trial_indices: Integer[Array, " trials"],
        base_params: Mapping[str, Float_],
        free_param_names: Iterable[str],
    ) -> Callable[[np.ndarray], Float[Array, ""]]:
        """Return the loss value for the specified model parameters."""
```
### Initialization
The generator is initialized with:
| Parameter | Type | Description |
| --- | --- | --- |
| `model_create_fn` | `MemorySearchCreateFn` | Factory function to create model instances |
| `dataset` | `RecallDataset` | Trial data (presentations, recalls, subjects) |
| `features` | `Optional[Array]` | Semantic embeddings or `None` |
### Generating Loss Functions
When called, the generator returns a specialized loss function:
Input parameters:

- `trial_indices`: which trials to include in the loss
- `base_params`: fixed parameter values
- `free_param_names`: names of the parameters being optimized

Returns: a callable `loss_fn(params: np.ndarray) -> scalar`.

The returned function:

1. Takes a 1D array of parameter values (in the order of `free_param_names`)
2. Returns a scalar loss value (lower is better)
3. Supports vectorized evaluation for population-based optimizers
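The parameter-packing contract described above can be sketched without any jaxcmr internals. In this illustrative stand-in, `make_loss_fn` and `evaluate` are hypothetical names, and plain Python lists stand in for arrays:

```python
# Sketch of the loss-function contract: merge free parameter values
# (in order of free_param_names) into the base parameter dict, then
# score the resulting parameter set. `evaluate` is a hypothetical
# stand-in for a real model evaluation.
def make_loss_fn(base_params, free_param_names, evaluate):
    names = list(free_param_names)

    def loss_fn(values):
        params = dict(base_params)
        params.update(zip(names, values))
        return evaluate(params)

    return loss_fn


# Example: a quadratic "loss" over two free parameters plus a fixed term.
loss_fn = make_loss_fn(
    base_params={"beta": 0.5},
    free_param_names=["alpha", "gamma"],
    evaluate=lambda p: (p["alpha"] - 1.0) ** 2 + p["gamma"] ** 2 + p["beta"],
)
print(loss_fn([1.0, 0.0]))  # only the fixed beta term remains: 0.5
```

The key design point is the closure: `base_params` and the free-parameter ordering are captured once, so the optimizer only ever sees a function of a flat parameter vector.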
### Mathematical Role
The loss function maps a parameter vector \(\theta\) from the parameter space \(\Theta\) to a scalar objective:

\[L: \Theta \to \mathbb{R}\]
For likelihood-based losses: \[L(\theta) = -\sum_t \log P(\text{data}_t \mid \theta)\]
For MSE-based losses: \[L(\theta) = \frac{1}{n}\sum_i \left(y_i^{\text{obs}} - y_i^{\text{sim}}(\theta)\right)^2\]
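As a concrete arithmetic check of both objectives (toy numbers only, no jaxcmr involved), each reduces to a simple sum:

```python
import math

# Toy per-trial probabilities of the observed data under some theta.
probs = [0.5, 0.25, 0.125]
nll = -sum(math.log(p) for p in probs)  # negative log-likelihood
print(round(nll, 4))  # 6 * ln(2) ~ 4.1589

# Toy observed vs. simulated values for an MSE loss.
obs = [1.0, 2.0, 3.0]
sim = [1.5, 2.0, 2.5]
mse = sum((o - s) ** 2 for o, s in zip(obs, sim)) / len(obs)
print(round(mse, 4))  # 0.5 / 3 ~ 0.1667
```

In both cases lower is better, which is the convention the optimizers in the next section assume.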
## FitResult TypedDict
Fitting algorithms return a FitResult containing the optimization results.
### Definition

```python
from jaxcmr.typing import FitResult


class FitResult(TypedDict):
    fixed: dict[str, float]
    """Dictionary of fixed parameters and their values."""

    free: dict[str, list[float]]
    """Dictionary of free parameters and their [lower_bound, upper_bound]."""

    fitness: list[float]
    """List of fitness values (one per subject or single global fit)."""

    fits: dict[str, list[float]]
    """Dictionary of parameter names -> optimized values."""

    hyperparameters: dict[str, Any]
    """Dictionary of hyperparameters used during fitting."""

    fit_time: float
    """Total time (in seconds) taken to perform the fitting."""
```
### Structure
| Field | Type | Description |
| --- | --- | --- |
| `fixed` | `dict[str, float]` | Parameters held constant |
| `free` | `dict[str, list[float]]` | Parameter bounds `[lo, hi]` |
| `fitness` | `list[float]` | Loss values at optimum |
| `fits` | `dict[str, list[float]]` | Fitted parameter values |
| `hyperparameters` | `dict` | Optimizer settings used |
| `fit_time` | `float` | Wall-clock time in seconds |
### Per-Subject vs Pooled
When using `fit_subjects()`:

- `fitness` has one value per subject
- `fits["param"]` has one value per subject
- `fits["subject"]` lists subject IDs

When using `fit()` (pooled or single-subject):

- `fitness` has a single value
- `fits["param"]` has a single value
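A per-subject result might therefore look like the following. The values are made up for illustration; only the shape follows the `FitResult` definition above:

```python
# Hypothetical per-subject result matching the FitResult shape.
result = {
    "fixed": {"beta": 0.5},               # parameters held constant
    "free": {"alpha": [0.0, 1.0]},        # [lower_bound, upper_bound]
    "fitness": [123.4, 98.7],             # one loss value per subject
    "fits": {
        "subject": [1.0, 2.0],            # subject IDs
        "alpha": [0.31, 0.47],            # one fitted value per subject
    },
    "hyperparameters": {"popsize": 15},   # illustrative optimizer setting
    "fit_time": 42.0,
}

# Pair each subject with its fitted alpha and its loss at the optimum.
for subj, alpha, loss in zip(
    result["fits"]["subject"], result["fits"]["alpha"], result["fitness"]
):
    print(int(subj), alpha, loss)
```

Because every per-subject list is index-aligned, `zip`-style iteration like this is the natural way to tabulate or plot per-subject fits.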
## FittingAlgorithm Protocol
The FittingAlgorithm protocol defines how parameter optimization is performed.
### Definition

```python
from jaxcmr.typing import FittingAlgorithm


@runtime_checkable
class FittingAlgorithm(Protocol):
    """Protocol describing a fitting algorithm for memory search models."""

    def __init__(
        self,
        dataset: RecallDataset,
        features: Optional[Float[Array, " word_pool_items features_count"]],
        base_params: Mapping[str, Float_],
        model_factory: Type[MemorySearchModelFactory],
        loss_fn_generator: Type[LossFnGenerator],
        hyperparams: Optional[dict[str, Any]] = None,
    ):
        """Configure the fitting algorithm."""

    def fit(
        self,
        trial_mask: Bool[Array, " trials"],
        subject_id: int = -1,
    ) -> FitResult:
        """Fit one parameter set to the trials selected by the mask."""

    def fit_subjects(
        self,
        trial_mask: Bool[Array, " trials"],
    ) -> FitResult:
        """Fit each subject independently and accumulate results."""
```
### fit()

| Parameter | Type | Description |
| --- | --- | --- |
| `trial_mask` | `Bool[Array]` | Which trials to include |
| `subject_id` | `int` | Optional subject identifier (default `-1`) |

Returns: `FitResult` with fitted parameters and diagnostics.
### fit_subjects()

| Parameter | Type | Description |
| --- | --- | --- |
| `trial_mask` | `Bool[Array]` | Which trials to include |

Returns: `FitResult` with per-subject fitted parameters.
## MemorySearchModelFactory Protocol
Model factories create model instances for specific trials.
### Definition

```python
@runtime_checkable
class MemorySearchModelFactory(Protocol):
    def __init__(
        self,
        dataset: RecallDataset,
        features: Optional[Float[Array, " word_pool_items features_count"]],
    ) -> None:
        """Initialize the factory with trial data."""

    def create_model(
        self,
        parameters: Mapping[str, Float_],
    ) -> MemorySearch:
        """Create a generic model instance."""

    def create_trial_model(
        self,
        trial_index: Integer[Array, ""],
        parameters: Mapping[str, Float_],
    ) -> MemorySearch:
        """Create a model configured for a specific trial."""
```
The create_trial_model method allows trial-specific configuration (e.g., different list lengths, trial-specific features like EEG data).
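The factory pattern can be sketched in plain Python. The classes and fields here (`ToyModel`, `ToyModelFactory`, `list_lengths`) are illustrative stand-ins, not jaxcmr implementations; list length is used as one example of trial-specific configuration:

```python
# Illustrative stand-ins for the factory pattern; not jaxcmr classes.
class ToyModel:
    def __init__(self, list_length, parameters):
        self.list_length = list_length
        self.parameters = parameters


class ToyModelFactory:
    """Creates models configured per trial, e.g. with varying list lengths."""

    def __init__(self, list_lengths):
        self.list_lengths = list_lengths  # one entry per trial

    def create_model(self, parameters):
        # Generic model: sized for the longest list in the dataset.
        return ToyModel(max(self.list_lengths), parameters)

    def create_trial_model(self, trial_index, parameters):
        # Trial-specific model: sized for that trial's list.
        return ToyModel(self.list_lengths[trial_index], parameters)


factory = ToyModelFactory(list_lengths=[16, 24, 16])
model = factory.create_trial_model(1, {"alpha": 0.3})
print(model.list_length)  # 24
```

Separating "generic" from "per-trial" construction this way lets fitting code build one factory up front and then cheaply instantiate a correctly configured model inside the per-trial loss loop.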
## Implementing Custom Components

### Custom Loss Function
```python
from typing import Callable, Iterable, Mapping, Optional

import numpy as np
from jaxtyping import Array, Float, Integer

from jaxcmr.typing import LossFnGenerator, MemorySearchCreateFn, RecallDataset


class CustomLossFnGenerator:
    """Example custom loss function generator."""

    def __init__(
        self,
        model_create_fn: MemorySearchCreateFn,
        dataset: RecallDataset,
        features: Optional[Float[Array, " items features"]],
    ):
        self.model_create_fn = model_create_fn
        self.dataset = dataset
        self.features = features

    def __call__(
        self,
        trial_indices: Integer[Array, " trials"],
        base_params: Mapping[str, float],
        free_param_names: Iterable[str],
    ) -> Callable[[np.ndarray], float]:
        free_names = list(free_param_names)

        def loss_fn(param_values: np.ndarray) -> float:
            # Build full parameter dict from fixed and free values
            params = dict(base_params)
            for name, value in zip(free_names, param_values):
                params[name] = value

            # Accumulate loss over the selected trials
            total_loss = 0.0
            for idx in trial_indices:
                model = self.model_create_fn(idx, params)
                trial_loss = 0.0  # placeholder: compute trial-specific loss from `model`
                total_loss += trial_loss
            return total_loss

        return loss_fn
```