Loss functions measure how well model parameters explain observed data. jaxcmr provides several loss function families, each suited to different fitting goals.
What Is a Loss Function?
A loss function maps a vector of model parameters \(\theta \in \Theta\) to a scalar that measures the discrepancy between model predictions and observed data:
\[L: \Theta \to \mathbb{R}, \qquad \theta \mapsto L(\theta)\]
Lower values indicate better fit. Optimization algorithms search for parameters that minimize the loss.
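The definition can be illustrated with a toy example (not a jaxcmr loss). The sketch below defines a negative log-likelihood for a hypothetical one-parameter model in which every item is recalled with probability \(\sigma(\theta_0)\), then stands in for an optimizer with a brute-force grid search that keeps the parameter with the lowest loss.

```python
import jax
import jax.numpy as jnp

# Hypothetical observed data: 1 = item recalled, 0 = not recalled.
observed = jnp.array([1.0, 1.0, 0.0, 1.0])


def loss_fn(theta: jnp.ndarray) -> jnp.ndarray:
    """Negative log-likelihood of `observed` when each item is recalled with p = sigmoid(theta[0])."""
    p = jax.nn.sigmoid(theta[0])
    log_lik = observed * jnp.log(p) + (1.0 - observed) * jnp.log1p(-p)
    return -jnp.sum(log_lik)  # a single scalar; lower means better fit


# Brute-force "optimizer": evaluate candidate parameters on a grid and keep the lowest loss.
grid = [i / 20 for i in range(-60, 61)]  # theta_0 from -3.0 to 3.0 in steps of 0.05
best_theta = min(grid, key=lambda t: float(loss_fn(jnp.array([t]))))
```

The minimizer lands on the grid point nearest \(\log 3 \approx 1.10\), the log-odds of the observed recall rate 0.75, which is where a Bernoulli likelihood is maximized.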
The LossFnGenerator Pattern
In jaxcmr, loss functions are created via generators that implement the LossFnGenerator protocol:
```python
# Initialize with dataset and model factory
generator = MemorySearchLikelihoodFnGenerator(model_create_fn, dataset, features)

# Generate specialized loss function for specific trials
loss_fn = generator(trial_indices, base_params, free_param_names)

# Evaluate loss
loss_value = loss_fn(parameter_array)
```
This pattern allows:
- Trial-specific loss computation
- Efficient JIT compilation
- Vectorized evaluation for population-based optimizers
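A minimal sketch of the pattern, assuming only the call signature used in the example above. `LossFnGenerator` here is a hypothetical `typing.Protocol` restatement, not jaxcmr's own definition, and `ToyRateLossFnGenerator` is a hypothetical generator whose "model" predicts each trial's recall rate as `rate * scale`. It shows trial-specific data being sliced once when the loss is generated, fixed base parameters merged with free ones, and the returned loss wrapped in `jax.jit`.

```python
from typing import Callable, Mapping, Protocol, Sequence

import jax
import jax.numpy as jnp

LossFn = Callable[[jnp.ndarray], jnp.ndarray]


class LossFnGenerator(Protocol):
    """Hypothetical restatement of the generator call signature (not jaxcmr's definition)."""

    def __call__(
        self,
        trial_indices: jnp.ndarray,
        base_params: Mapping[str, float],
        free_param_names: Sequence[str],
    ) -> LossFn: ...


class ToyRateLossFnGenerator:
    """Hypothetical generator: each trial has one observed recall rate, predicted as rate * scale."""

    def __init__(self, observed_rates: jnp.ndarray):
        self.observed_rates = observed_rates

    def __call__(self, trial_indices, base_params, free_param_names) -> LossFn:
        targets = self.observed_rates[trial_indices]  # trial-specific data, sliced once
        fixed = dict(base_params)
        names = tuple(free_param_names)

        @jax.jit  # compiled on first call, then reused for same-shaped inputs
        def loss_fn(x: jnp.ndarray) -> jnp.ndarray:
            # Overwrite base values with the free parameters, read from x in `names` order.
            params = {**fixed, **{name: x[i] for i, name in enumerate(names)}}
            prediction = params["rate"] * params["scale"]
            return jnp.sum((targets - prediction) ** 2)

        return loss_fn


generator: LossFnGenerator = ToyRateLossFnGenerator(jnp.array([0.2, 0.4, 0.6, 0.8]))
base_params = {"rate": 0.0, "scale": 1.0}
loss_middle = generator(jnp.array([1, 2]), base_params, ["rate"])  # trials 1 and 2 only
loss_outer = generator(jnp.array([0, 3]), base_params, ["rate"])  # trials 0 and 3 only
```

Because the generated function closes over its own trial subset, the same parameter vector yields different losses for different subsets (for example, per-subject fits), while `scale` stays fixed at its base value.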
For Ordered Recall Data
Use Sequence Likelihood (loss.sequence_likelihood):
- Captures full temporal dynamics
- Fits temporal contiguity, primacy, and recency effects
- Most informative when recall order is recorded
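A sketch of why order carries information in a sequence likelihood. It uses a hypothetical toy model (fixed item weights plus a stop weight, with sampling without replacement), not CMR or the jaxcmr implementation: each recall is scored conditional on the recalls before it, and the final decision to stop is scored as well. A sequence-likelihood loss would then be the negative of this quantity summed over trials.

```python
import jax.numpy as jnp


def sequence_log_likelihood(weights, stop_weight, recalls):
    """Log-probability of recalling `recalls` in order and then stopping,
    under a hypothetical sampling-without-replacement model (not CMR)."""
    available = jnp.ones_like(weights)  # 1 = item not yet recalled
    log_lik = 0.0
    for item in recalls:
        support = weights * available
        total = jnp.sum(support) + stop_weight  # competitors: remaining items plus stopping
        log_lik = log_lik + jnp.log(support[item] / total)
        available = available.at[item].set(0.0)  # recalled items cannot be recalled again
    total = jnp.sum(weights * available) + stop_weight
    return log_lik + jnp.log(stop_weight / total)  # score the termination event


weights = jnp.array([3.0, 2.0, 1.0])
ll_forward = sequence_log_likelihood(weights, 1.0, [0, 1])
ll_backward = sequence_log_likelihood(weights, 1.0, [1, 0])
```

The two orders of the same recalled set receive different probabilities (3/28 versus 3/35 here), which is the information a sequence likelihood exploits and a set-based loss discards.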
For Unordered Recall Data
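Use Set Permutation Likelihood (loss.set_permutation_likelihood):
- When only the set of recalled items is recorded, not their order
- Final free recall (FFR) paradigms
- Uses Monte Carlo estimation to avoid summing over every possible ordering of the recalled items, whose number grows factorially with set size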
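For a recalled set \(S\) of size \(k\), the set probability is the sum of sequence probabilities over all \(k!\) orders, \(P(S) = \sum_{\pi} P(\pi) = k!\,\mathbb{E}_{\pi \sim \mathrm{Unif}}[P(\pi)]\), so averaging over uniformly sampled orders gives an unbiased estimate. The sketch below checks that estimator against exact enumeration, assuming a hypothetical toy model (fixed item weights plus a stop weight, sampling without replacement); the estimator jaxcmr actually uses may differ.

```python
import math
from itertools import permutations

import jax
import jax.numpy as jnp


def sequence_prob(weights, stop_weight, order):
    """P(recall `order` in sequence, then stop) under a hypothetical toy model."""
    def step(carry, item):
        available, prob = carry
        support = weights * available
        prob = prob * support[item] / (jnp.sum(support) + stop_weight)
        return (available.at[item].set(0.0), prob), None

    init = (jnp.ones_like(weights), jnp.ones((), dtype=weights.dtype))
    (available, prob), _ = jax.lax.scan(step, init, order)
    return prob * stop_weight / (jnp.sum(weights * available) + stop_weight)


def exact_set_prob(weights, stop_weight, recalled):
    """Sum of sequence probabilities over all k! orders of the recalled set."""
    return sum(sequence_prob(weights, stop_weight, jnp.array(p)) for p in permutations(recalled))


def mc_set_prob(key, weights, stop_weight, recalled, n_samples=2000):
    """Monte Carlo estimate: k! times the mean probability of uniformly random orders."""
    recalled = jnp.asarray(recalled)
    keys = jax.random.split(key, n_samples)
    orders = jax.vmap(lambda k: jax.random.permutation(k, recalled))(keys)
    probs = jax.vmap(lambda order: sequence_prob(weights, stop_weight, order))(orders)
    return math.factorial(recalled.shape[0]) * jnp.mean(probs)


weights = jnp.array([3.0, 2.0, 1.0, 0.5])
recalled = [0, 2, 3]
exact = exact_set_prob(weights, 1.0, recalled)
estimate = mc_set_prob(jax.random.PRNGKey(0), weights, 1.0, recalled)
```

Enumeration is feasible for three items, but the sampled estimate keeps a fixed cost as the set grows.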
For Summary Statistics
Use SPC MSE (loss.spc_mse) or Category SPC MSE (loss.cat_spc_mse), which compare serial position curves (SPCs):
- When fitting to aggregate curves
- When the likelihood is computationally prohibitive
- When matching the shape of the curve matters more than the exact probability of each recall sequence
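A minimal sketch of an SPC mean squared error, assuming recalls are stored as a `(trials, output_positions)` integer array of 1-indexed study positions with 0 as padding; that layout and the function names are assumptions for illustration, not the jaxcmr API. In a fit, the simulated recalls would come from running the model with candidate parameters.

```python
import jax
import jax.numpy as jnp


def spc(recalls, list_length):
    """Serial position curve: fraction of trials in which each study position was recalled.
    Assumes 1-indexed study positions, with 0 marking an empty output slot."""
    one_hot = jax.nn.one_hot(recalls - 1, list_length)  # padding (-1) encodes as all zeros
    recalled = jnp.clip(one_hot.sum(axis=1), 0.0, 1.0)  # (trials, list_length); ignore repeats
    return recalled.mean(axis=0)


def spc_mse(observed_recalls, simulated_recalls, list_length):
    """Mean squared error between observed and simulated serial position curves."""
    diff = spc(observed_recalls, list_length) - spc(simulated_recalls, list_length)
    return jnp.mean(diff ** 2)


observed = jnp.array([[1, 2, 0], [3, 1, 0]])
simulated = jnp.array([[1, 0, 0], [1, 3, 0]])
loss = spc_mse(observed, simulated, list_length=3)
```

Here the observed curve is `[1.0, 0.5, 0.5]` and the simulated curve is `[1.0, 0.0, 0.5]`, so only position 2 contributes to the error.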
For Selective Fitting
Use Transform variants, which mask out part of the likelihood:
- ExcludeFirstRecallLikelihoodFnGenerator: focuses on transitions by excluding the first recall (initiation)
- ExcludeTerminationLikelihoodFnGenerator: ignores stopping behavior (recall termination)
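Masking can be pictured as zeroing selected terms of a summed log-likelihood. The sketch assumes a hypothetical layout in which each trial's row holds the log-probability of the first recall, then each later recall, then the termination event, padded after `n_events` entries; jaxcmr's internal representation and masking logic may differ.

```python
import jax.numpy as jnp


def masked_nll(event_log_probs, n_events, exclude_first=False, exclude_termination=False):
    """Negative log-likelihood over selected recall events (hypothetical data layout).

    Row t of `event_log_probs` holds [first recall, later recalls..., termination]
    log-probabilities; entries at positions >= n_events[t] are padding.
    """
    positions = jnp.arange(event_log_probs.shape[1])[None, :]
    mask = positions < n_events[:, None]  # drop padding
    if exclude_first:
        mask = mask & (positions != 0)  # drop recall initiation
    if exclude_termination:
        mask = mask & (positions != n_events[:, None] - 1)  # drop the stop event
    return -jnp.sum(jnp.where(mask, event_log_probs, 0.0))


event_log_probs = jnp.array([[-1.0, -2.0, -3.0], [-0.5, -1.5, 0.0]])
n_events = jnp.array([3, 2])
```

With both flags set, only the middle transition of the first trial (log-probability -2.0) remains in the loss.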
Computational Considerations
JAX Compilation
Loss functions are designed for JAX's JIT compilation:
- The first evaluation triggers compilation, which is slow
- Subsequent evaluations with inputs of the same shape and dtype reuse the compiled code and are fast
- Vectorized evaluation supports population-based optimizers
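A minimal sketch of the compile-once behavior, using a toy jitted loss rather than a jaxcmr function. `block_until_ready()` waits for JAX's asynchronous dispatch so the timings cover the actual work; calling the function with a new input shape or dtype would trigger a fresh compilation.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def loss_fn(x: jnp.ndarray) -> jnp.ndarray:
    """Toy quadratic loss standing in for a generated loss function."""
    return jnp.sum((x - 0.5) ** 2)


x0 = jnp.zeros(3)
x1 = x0 + 0.1  # same shape and dtype as x0

start = time.perf_counter()
first = loss_fn(x0).block_until_ready()  # traces, compiles, then runs
compile_and_run = time.perf_counter() - start

start = time.perf_counter()
second = loss_fn(x1).block_until_ready()  # reuses the cached compiled executable
run_only = time.perf_counter() - start

print(f"first call: {compile_and_run:.4f}s, second call: {run_only:.6f}s")
```

On typical hardware the first timing is dominated by compilation, which is why a single warm-up call before an optimization loop is a common pattern.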
Vectorized Evaluation
For differential evolution, loss functions accept batched parameters:
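`jax.vmap` is one standard way to obtain this behavior in JAX. In the sketch below, the quadratic `loss_fn` is a hypothetical stand-in for a generated loss, and the population is laid out as rows of shape `(population_size, n_params)`; the layout jaxcmr expects may differ.

```python
import jax
import jax.numpy as jnp

TARGET = jnp.array([0.3, 0.7])


def loss_fn(x: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical stand-in for a generated loss: one parameter vector -> one scalar."""
    return jnp.sum((x - TARGET) ** 2)


# vmap maps over the leading axis: (population_size, n_params) -> (population_size,)
batched_loss = jax.jit(jax.vmap(loss_fn))

population = jnp.array([[0.3, 0.7], [0.0, 0.0], [1.0, 1.0]])
losses = batched_loss(population)
best = population[jnp.argmin(losses)]
```

One batched call replaces a Python loop over population members, and the batched function is compiled once for a given population size.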