Loss Functions

Objective functions for model fitting

Loss functions measure how well model parameters explain observed data. jaxcmr provides several loss function families, each suited to different fitting goals.

What is a Loss Function?

A loss function maps model parameters \(\theta\) to a scalar measuring the discrepancy between model predictions and observed data:

\[L : \theta \mapsto L(\theta) \in \mathbb{R}\]

Lower values indicate better fit. Optimization algorithms search for parameters that minimize the loss.
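As a minimal sketch of this idea (plain Python, not jaxcmr code), consider a one-parameter squared-error loss over toy data: the parameter value closer to the data yields the lower loss.

```python
# Toy illustration: a loss maps a parameter value to a scalar, and lower
# values mean the parameter explains the observations better.
def squared_error_loss(theta, observations):
    """Sum of squared deviations between a constant prediction and the data."""
    return sum((y - theta) ** 2 for y in observations)

data = [1.0, 2.0, 3.0]
print(squared_error_loss(2.0, data))  # parameter at the data mean: 2.0
print(squared_error_loss(0.0, data))  # worse parameter: 14.0
```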

The LossFnGenerator Pattern

In jaxcmr, loss functions are created via generators that implement the LossFnGenerator protocol:

Code
# Initialize with dataset and model factory
generator = MemorySearchLikelihoodFnGenerator(
    model_create_fn, dataset, features
)

# Generate specialized loss function for specific trials
loss_fn = generator(trial_indices, base_params, free_param_names)

# Evaluate loss
loss_value = loss_fn(parameter_array)

This pattern allows:

  • Trial-specific loss computation
  • Efficient JIT compilation
  • Vectorized evaluation for population-based optimizers
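The shape of the pattern can be sketched in plain Python (class and argument names here are illustrative, not the jaxcmr API): a generator is built once around the dataset, then called to produce a loss function specialized to a subset of trials.

```python
# Hypothetical sketch of the generator pattern: initialize once with data,
# then call to get a closure that computes loss over selected trials only.
class ToyLossFnGenerator:
    def __init__(self, dataset):
        self.dataset = dataset  # list of per-trial observations

    def __call__(self, trial_indices):
        trials = [self.dataset[i] for i in trial_indices]

        def loss_fn(theta):
            # Sum squared error across just the selected trials.
            return sum((y - theta) ** 2 for trial in trials for y in trial)

        return loss_fn

generator = ToyLossFnGenerator([[1.0, 2.0], [3.0, 4.0]])
loss_fn = generator([0])   # specialize to trial 0
print(loss_fn(1.5))        # (1 - 1.5)^2 + (2 - 1.5)^2 = 0.5
```

Specializing the closure to fixed trial indices is what lets a JIT compiler trace and cache a single computation graph per fit.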

Loss Function Families

| Family | Mathematical Form | When to Use |
|--------|-------------------|-------------|
| Likelihood | \(-\sum_t \log P(r_t \mid \theta)\) | Full sequence information, MLE |
| Permutation | \(-\sum_t \log \mathbb{E}_{\pi}[P(\pi(r_t) \mid \theta)]\) | Order-agnostic recall |
| MSE | \(\sum_i (y_i^{obs} - y_i^{sim})^2\) | Summary statistics, shape matching |

Likelihood-Based

Sequence Likelihood treats each recall event as a probabilistic choice:

\[L(\theta) = -\sum_t \sum_{j=1}^{|R_t|} \log P(r_{t,j} \mid r_{t,1:j-1}, \theta)\]

  • Uses full temporal information (order, transitions)
  • Maximum likelihood estimation (MLE)
  • Captures contiguity, primacy, recency
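The inner sum over one trial can be sketched in plain Python (this is illustrative, not the jaxcmr implementation): each recall event contributes the negative log of its conditional probability, so one trial's loss is the negative log of the product.

```python
import math

# Sequence negative log-likelihood for a single trial.
def sequence_nll(conditional_probs):
    """conditional_probs[j] = P(r_j | r_1..r_{j-1}, theta) for one trial."""
    return -sum(math.log(p) for p in conditional_probs)

# Three recall events with model-assigned conditional probabilities:
print(round(sequence_nll([0.5, 0.25, 0.1]), 4))  # -log(0.5 * 0.25 * 0.1) ≈ 4.382
```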

Permutation-Based

Set Permutation Likelihood marginalizes over recall orderings:

\[L(\theta) = -\sum_t \log \frac{1}{N} \sum_{i=1}^{N} \prod_{j} P(\pi_i(r_{t,j}) \mid \pi_i(r_{t,1:j-1}), \theta)\]

  • Treats recall as an unordered set
  • Monte Carlo estimation via permutation sampling
  • For final free recall (FFR) or when order is not recorded
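The marginalization can be sketched in plain Python. For clarity this version enumerates all orderings exhaustively, whereas jaxcmr samples N permutations Monte Carlo style; `prob_fn` is a hypothetical stand-in for the model's conditional recall probability.

```python
import itertools
import math

# Negative log of the ordering-averaged likelihood for one recalled set.
def set_nll(recalls, prob_fn):
    orderings = list(itertools.permutations(recalls))
    total = 0.0
    for order in orderings:
        p = 1.0
        for j, item in enumerate(order):
            p *= prob_fn(item, order[:j])  # P(item | earlier recalls)
        total += p
    return -math.log(total / len(orderings))

# Uniform toy model: every item equally likely regardless of history.
uniform = lambda item, history: 1.0 / 3.0
print(round(set_nll([4, 7, 9], uniform), 4))  # -log(1/27) ≈ 3.2958
```

With list lengths beyond a handful of items the factorial number of orderings makes exhaustive enumeration infeasible, which is why the real loss averages over sampled permutations instead.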

MSE-Based

Summary Statistic MSE compares observed vs simulated analyses:

\[L(\theta) = \frac{1}{n} \sum_i \left( y_i^{obs} - y_i^{sim}(\theta) \right)^2\]

  • Fits to aggregate patterns (SPC, CRP)
  • Requires stochastic simulation
  • Useful for shape matching without full likelihood
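The comparison reduces to a mean squared difference between two curves, sketched here in plain Python with made-up numbers (not jaxcmr internals):

```python
# Mean squared error between an observed summary curve and one averaged
# over simulated recall chains.
def curve_mse(observed, simulated):
    assert len(observed) == len(simulated)
    return sum((o - s) ** 2 for o, s in zip(observed, simulated)) / len(observed)

observed_spc = [0.9, 0.6, 0.5, 0.7]   # recall probability by serial position
simulated_spc = [0.8, 0.6, 0.4, 0.8]  # mean over simulated trials
print(round(curve_mse(observed_spc, simulated_spc), 6))  # (0.01 + 0 + 0.01 + 0.01) / 4 = 0.0075
```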

All Generators

| Generator Class | Module | Description |
|-----------------|--------|-------------|
| MemorySearchLikelihoodFnGenerator | loss.sequence_likelihood | Standard sequence likelihood |
| MemorySearchLikelihoodFnGenerator | loss.set_permutation_likelihood | Monte Carlo permutation likelihood |
| MemorySearchLikelihoodFnGenerator | loss.transform_sequence_likelihood | Likelihood with custom masking |
| ExcludeFirstRecallLikelihoodFnGenerator | loss.transform_sequence_likelihood | Masks first recall event |
| ExcludeTerminationLikelihoodFnGenerator | loss.transform_sequence_likelihood | Masks stop decisions |
| MemorySearchLikelihoodFnGenerator | loss.base_sequence_likelihood | Legacy: reuses single context |
| MemorySearchSpcMseFnGenerator | loss.spc_mse | Serial position curve MSE |
| MemorySearchCatSpcMseFnGenerator | loss.cat_spc_mse | Category-filtered SPC MSE |
| MemorySearchMseFnGenerator | experimental.mse_loss | General MSE with pluggable analysis |

Choosing a Loss Function

For Standard Free Recall

Use Sequence Likelihood (loss.sequence_likelihood):

  • Captures full temporal dynamics
  • Fits to contiguity, primacy, recency
  • Most informative when order is recorded

For Unordered Recall Data

Use Set Permutation Likelihood (loss.set_permutation_likelihood):

  • When only the recall set is recorded (not order)
  • Final free recall (FFR) paradigms
  • Monte Carlo estimation handles factorial complexity

For Summary Statistics

Use SPC MSE (loss.spc_mse) or Category SPC MSE (loss.cat_spc_mse):

  • When fitting to aggregate curves
  • When likelihood is computationally prohibitive
  • When shape matching is more important than exact probability

For Selective Fitting

Use Transform variants with masking:

  • ExcludeFirstRecallLikelihoodFnGenerator: Focus on transitions (not initiation)
  • ExcludeTerminationLikelihoodFnGenerator: Ignore stopping behavior
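The masking idea itself is simple, sketched here in plain Python (illustrative, not jaxcmr internals): multiplying each event's log-probability by a 0/1 mask drops that event from the loss without changing the rest of the computation.

```python
import math

# Masked sequence negative log-likelihood: mask[j] = 0 excludes event j.
def masked_sequence_nll(conditional_probs, mask):
    return -sum(m * math.log(p) for p, m in zip(conditional_probs, mask))

probs = [0.5, 0.25, 0.1]
print(masked_sequence_nll(probs, [0, 1, 1]))  # excludes the first recall event
print(masked_sequence_nll(probs, [1, 1, 0]))  # excludes the final (stop) event
```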

Computational Considerations

JAX Compilation

Loss functions are designed for JAX’s JIT compilation:

  • First evaluation triggers compilation (slow)
  • Subsequent evaluations are fast
  • Vectorized evaluation supports population-based optimizers

Vectorized Evaluation

For differential evolution, loss functions accept batched parameters:

Code
# Single evaluation
loss = loss_fn(params)  # params: (n_params,) -> scalar

# Batched evaluation
losses = loss_fn(params_batch)  # params_batch: (n_params, pop_size) -> (pop_size,)
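Conceptually, the batched path maps the scalar loss over every member of the optimizer's population; in jaxcmr this is vectorized (e.g. via jax.vmap) rather than looped, but the semantics can be sketched in plain Python:

```python
# Scalar loss over toy data (one parameter for illustration).
def loss_fn(theta):
    data = [1.0, 2.0, 3.0]
    return sum((y - theta) ** 2 for y in data)

# Batched evaluation: one loss value per population member.
def batched_loss_fn(theta_batch):
    return [loss_fn(theta) for theta in theta_batch]

print(batched_loss_fn([0.0, 2.0]))  # [14.0, 2.0]
```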

Simulation Count

MSE-based losses require multiple simulations per evaluation:

| Loss | simulation_count | Notes |
|------|------------------|-------|
| Set Permutation | 50 | Permutation samples |
| SPC MSE | 20 | Recall chain samples |
| Category SPC MSE | 10 | Category-filtered samples |

Higher counts reduce variance but increase computation time.

Usage Example

Code
from jaxcmr.fitting import ScipyDE
from jaxcmr.loss.sequence_likelihood import MemorySearchLikelihoodFnGenerator
from jaxcmr.models.cmr import make_factory

# Standard likelihood fitting
fitter = ScipyDE(
    dataset=data,
    features=None,
    base_params=fixed_params,
    model_factory=make_factory(...),
    loss_fn_generator=MemorySearchLikelihoodFnGenerator,
    hyperparams={"bounds": free_params, "num_steps": 1000},
)

results = fitter.fit(trial_mask)
Code
from jaxcmr.loss.spc_mse import MemorySearchSpcMseFnGenerator

# SPC MSE fitting
fitter = ScipyDE(
    dataset=data,
    features=None,
    base_params=fixed_params,
    model_factory=make_factory(...),
    loss_fn_generator=MemorySearchSpcMseFnGenerator,
    hyperparams={"bounds": free_params, "num_steps": 1000},
)

results = fitter.fit(trial_mask)