Evaluation Pipeline

Fitting and simulation flow

Model evaluation connects loss functions, fitting algorithms, and simulation.

Pipeline Overview

Dataset → LossFnGenerator → Loss Function → FittingAlgorithm → FitResult
                                ↑
                          TrialSimulator

The evaluation pipeline consists of:

  1. Data: RecallDataset with study/recall sequences

  2. Loss Function: Quantifies model-data mismatch as a single scalar to minimize (e.g., negative log-likelihood or mean squared error)

  3. Fitting Algorithm: Optimizes parameters to minimize loss

  4. Simulation: Generates model predictions for comparison
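The sketch below shows the four stages in miniature. It is not jaxcmr code. A hypothetical one-parameter model, in which recall probability decays geometrically with study position, stands in for CMR. The helpers `simulate_trials`, `make_loss_fn`, and `grid_search_fit` are illustrative stand-ins for a TrialSimulator, a LossFnGenerator, and a FittingAlgorithm. Each trial is reduced to a list of recalled study positions, so recall order plays no role here.

```python
import math
import random
from typing import Callable, Dict, List, Sequence

# Toy representation (hypothetical): a trial is the list of 0-based study
# positions that were recalled. Unlike a real RecallDataset, order is ignored.
Trial = List[int]
Params = Dict[str, float]


def recall_prob(position: int, params: Params) -> float:
    """Hypothetical model: recall probability decays geometrically with position."""
    p = params["decay"] ** position
    return min(max(p, 1e-9), 1.0 - 1e-9)  # clamp so log() stays finite


def simulate_trials(params: Params, list_length: int, n_trials: int, seed: int = 0) -> List[Trial]:
    """Simulation stage: sample which positions each simulated trial recalls."""
    rng = random.Random(seed)
    return [
        [pos for pos in range(list_length) if rng.random() < recall_prob(pos, params)]
        for _ in range(n_trials)
    ]


def make_loss_fn(trials: Sequence[Trial], list_length: int) -> Callable[[Params], float]:
    """Loss-generator stage: bind the data once, return a function of params only."""
    def loss(params: Params) -> float:
        nll = 0.0
        for trial in trials:
            recalled = set(trial)
            for pos in range(list_length):
                p = recall_prob(pos, params)
                nll -= math.log(p) if pos in recalled else math.log(1.0 - p)
        return nll  # negative log-likelihood, one Bernoulli event per position
    return loss


def grid_search_fit(loss_fn: Callable[[Params], float], name: str,
                    low: float, high: float, num_points: int = 101) -> dict:
    """Fitting stage: evaluate the loss on an even grid and keep the best value."""
    best_value, best_loss = low, math.inf
    for i in range(num_points):
        value = low + (high - low) * i / (num_points - 1)
        current = loss_fn({name: value})
        if current < best_loss:
            best_value, best_loss = value, current
    return {"fits": {name: best_value}, "fitness": best_loss}


# Parameter recovery: simulate with a known value, then fit it back.
data = simulate_trials({"decay": 0.7}, list_length=6, n_trials=500)
result = grid_search_fit(make_loss_fn(data, list_length=6), "decay", 0.05, 0.95)
print(result["fits"], round(result["fitness"], 1))
```

Simulating data from known parameters and fitting them back (parameter recovery) is a common sanity check before fitting real data. Here the recovered `decay` should land close to the generating value of 0.7. `make_loss_fn` binds the data once and returns a function of the parameters alone, which is the form a generic optimizer needs.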

End-to-End Example

Here’s a complete fitting workflow using real data:

import h5py
import numpy as np
from jax import numpy as jnp

# Load data
with h5py.File("../../data/HealeyKahana2014.h5", "r") as f:
    dataset = {key: f[key][:] for key in f.keys()}

print(f"Loaded {dataset['subject'].shape[0]} trials")
print(f"Subjects: {np.unique(dataset['subject'])}")
from jaxcmr.models.cmr import CMR, make_factory
from jaxcmr.loss.sequence_likelihood import MemorySearchLikelihoodFnGenerator
from jaxcmr.fitting import ScipyDE

# Define base parameters (fixed during fitting)
base_params = {
    "encoding_drift_rate": 0.5,
    "start_drift_rate": 0.5,
    "recall_drift_rate": 0.5,
    "learning_rate": 0.5,
    "primacy_scale": 2.0,
    "primacy_decay": 0.8,
    "shared_support": 0.05,
    "item_support": 0.25,
    "choice_sensitivity": 0.6,
    "stop_probability_scale": 0.05,
    "stop_probability_growth": 0.2,
    "learn_after_context_update": True,
    "allow_repeated_recalls": False,
}

# Define bounds for free parameters
param_bounds = {
    "encoding_drift_rate": [0.1, 0.9],
    "recall_drift_rate": [0.1, 0.9],
}

print("Parameters configured")
print(f"Free parameters: {list(param_bounds.keys())}")
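The comments above split parameters into fixed values (`base_params`) and searched values (`param_bounds`). One way a fitter could combine them is to copy `base_params` and overwrite each bounded entry with the candidate value under evaluation. This is an assumption for illustration, not jaxcmr's documented behavior. `merge_params`, `example_base`, and `example_bounds` are hypothetical names.

```python
from typing import Dict, List, Sequence, Union

ParamValue = Union[float, bool]


def merge_params(base: Dict[str, ParamValue],
                 bounds: Dict[str, List[float]],
                 candidate: Sequence[float]) -> Dict[str, ParamValue]:
    """Hypothetical helper: overlay one candidate vector of free values onto base."""
    if len(candidate) != len(bounds):
        raise ValueError("candidate needs exactly one value per bounded parameter")
    merged = dict(base)  # copy, so the fixed parameters are never mutated
    for (name, (low, high)), value in zip(bounds.items(), candidate):
        if not low <= value <= high:
            raise ValueError(f"{name}={value} is outside [{low}, {high}]")
        merged[name] = float(value)
    return merged


# Hypothetical, trimmed-down versions of the dictionaries defined above.
example_base = {
    "encoding_drift_rate": 0.5,
    "recall_drift_rate": 0.5,
    "learning_rate": 0.5,
    "allow_repeated_recalls": False,
}
example_bounds = {"encoding_drift_rate": [0.1, 0.9], "recall_drift_rate": [0.1, 0.9]}

params = merge_params(example_base, example_bounds, [0.7, 0.3])
print(params)
```

Copying instead of mutating keeps the fixed parameters reusable across candidates. The bounds check catches any candidate that strays outside its search interval.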
# Create fitting algorithm
fitter = ScipyDE(
    dataset=dataset,
    features=None,  # No semantic features for this example
    base_params=base_params,
    model_create_fn=make_factory(CMR),
    loss_fn_generator=MemorySearchLikelihoodFnGenerator,
    hyperparams={
        "bounds": param_bounds,
        "num_steps": 50,  # Reduced for demo
        "pop_size": 10,
        "progress_bar": True,
    },
)

print("Fitter configured")
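`ScipyDE` is named for differential evolution (DE). The hyperparameters `bounds`, `num_steps`, and `pop_size` match the usual settings of that family of optimizers. The sketch below is a generic textbook DE/rand/1/bin loop in plain Python, included for intuition only. It is not ScipyDE's implementation, and this page does not document how ScipyDE passes these hyperparameters to its underlying optimizer. The toy quadratic loss is a hypothetical stand-in for a model likelihood.

```python
import random
from typing import Callable, List, Sequence, Tuple


def differential_evolution(
    loss_fn: Callable[[List[float]], float],
    bounds: Sequence[Tuple[float, float]],
    num_steps: int = 50,
    pop_size: int = 10,
    mutation: float = 0.8,
    crossover: float = 0.7,
    seed: int = 0,
) -> Tuple[List[float], float]:
    """Generic DE/rand/1/bin; members are replaced in place as soon as they improve."""
    if pop_size < 4:
        raise ValueError("DE/rand/1 needs at least 4 population members")
    rng = random.Random(seed)
    dim = len(bounds)
    # Initial population: uniform samples inside the bounds.
    pop = [[rng.uniform(lo, hi) for lo, hi in bounds] for _ in range(pop_size)]
    scores = [loss_fn(x) for x in pop]
    for _ in range(num_steps):
        for i in range(pop_size):
            # Three distinct members other than i build the mutant vector.
            a, b, c = rng.sample([j for j in range(pop_size) if j != i], 3)
            forced = rng.randrange(dim)  # at least one coordinate comes from the mutant
            trial = []
            for d, (lo, hi) in enumerate(bounds):
                if d == forced or rng.random() < crossover:
                    value = pop[a][d] + mutation * (pop[b][d] - pop[c][d])
                    value = min(max(value, lo), hi)  # clip back into bounds
                else:
                    value = pop[i][d]
                trial.append(value)
            score = loss_fn(trial)
            if score <= scores[i]:  # greedy selection: keep the better of the two
                pop[i], scores[i] = trial, score
    best = min(range(pop_size), key=scores.__getitem__)
    return pop[best], scores[best]


# Toy quadratic loss with its minimum at (0.3, 0.6), searched inside [0.1, 0.9]^2.
best_x, best_loss = differential_evolution(
    lambda x: (x[0] - 0.3) ** 2 + (x[1] - 0.6) ** 2,
    bounds=[(0.1, 0.9), (0.1, 0.9)],
)
print([round(v, 3) for v in best_x], best_loss)
```

At each step, every member gets a mutant built from the scaled difference of two other members. The mutant is mixed coordinate-wise with the current member, and the lower-scoring of the two survives. This sketch makes `pop_size` × (`num_steps` + 1) loss evaluations, so larger populations explore more broadly but cost more per step. That cost is why the demo above keeps both values small.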
# Fit to first subject only (for speed)
trial_mask = dataset["subject"].flatten() == 1
print(f"Fitting to {trial_mask.sum()} trials...")
result = fitter.fit(trial_mask)

print(f"\nFitting complete in {result['fit_time']:.1f}s")
print(f"Final loss: {result['fitness'][0]:.3f}")
print("Best parameters:")
for name, values in result["fits"].items():
    print(f"  {name}: {values[0]:.3f}")
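The mask above selects subject 1 only, and the same pattern extends to one mask per subject. In the sketch below, a small hypothetical `subject` array stands in for `dataset["subject"]`. It is a 2-D column, as the `.flatten()` call above suggests. The actual `fitter.fit` call is left commented out. The `[0]` indexing when printing `fitness` and `fits` suggests that FitResult stores one entry per fitted unit. That reading is inferred from this example, not documented.

```python
import numpy as np

# Hypothetical stand-in for dataset["subject"]: one row per trial, one column.
subject = np.array([[1], [1], [2], [2], [2], [3]])

# One boolean trial mask per subject, built like `dataset["subject"].flatten() == 1`.
masks = {int(s): subject.flatten() == s for s in np.unique(subject)}

for s, mask in masks.items():
    print(f"subject {s}: {int(mask.sum())} trials")
    # result = fitter.fit(mask)  # one fit per subject with the fitter configured above
```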

Protocol Reference

LossFnGenerator

from nbdev.showdoc import show_doc
from jaxcmr.typing import LossFnGenerator

show_doc(LossFnGenerator)

FittingAlgorithm

from jaxcmr.typing import FittingAlgorithm

show_doc(FittingAlgorithm)

FitResult

from jaxcmr.typing import FitResult

show_doc(FitResult)
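The section title suggests these are structural interfaces in the style of `typing.Protocol`. Under that assumption, the sketch below declares two hypothetical look-alikes, `FitResultSketch` and `FittingAlgorithmSketch`. They cover only what the end-to-end example uses: a `fit(trial_mask)` method and the `fits`, `fitness`, and `fit_time` keys. The authoritative definitions are the ones rendered by `show_doc`, and they may have more fields or different types.

```python
from typing import Dict, List, Protocol, Sequence, TypedDict, runtime_checkable


class FitResultSketch(TypedDict):
    """Only the keys read in the end-to-end example; the real FitResult may hold more."""
    fits: Dict[str, List[float]]  # parameter name -> one value per fitted unit (assumed)
    fitness: List[float]          # final loss per fitted unit (assumed)
    fit_time: float               # wall-clock seconds


@runtime_checkable
class FittingAlgorithmSketch(Protocol):
    """Structural interface: any object with a compatible fit(trial_mask) conforms."""

    def fit(self, trial_mask: Sequence[bool]) -> FitResultSketch: ...


class ConstantFitter:
    """Hypothetical dummy that satisfies the protocol without optimizing anything."""

    def fit(self, trial_mask: Sequence[bool]) -> FitResultSketch:
        return {"fits": {"encoding_drift_rate": [0.5]}, "fitness": [0.0], "fit_time": 0.0}


fitter_like = ConstantFitter()
print(isinstance(fitter_like, FittingAlgorithmSketch))
```

With a structural protocol, a class such as `ConstantFitter` conforms just by providing a compatible `fit` method, and no inheritance is required. `runtime_checkable` enables `isinstance` checks, but those checks only confirm that the method exists. They do not verify its signature or return type.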

Available Loss Functions

Generator                            Description
MemorySearchLikelihoodFnGenerator    Negative log-likelihood of recall sequences
SetPermutationLikelihoodFnGenerator  Likelihood of recalls, ignoring recall order
SPCMSEFnGenerator                    Mean squared error on the serial position curve (SPC)
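The sketch below shows what each kind of loss in the table is sensitive to. It uses a hypothetical toy model in plain Python: items are drawn without replacement in proportion to fixed weights, with no stopping rule. None of these functions are the jaxcmr generators.

- The ordered-sequence likelihood distinguishes [0, 1] from [1, 0].
- The order-free set likelihood does not. Here it is computed by summing over every ordering of the recalled items.
- The SPC error compares only per-position recall rates.

```python
import itertools
import math
from typing import List, Sequence


def sequence_prob(recalls: Sequence[int], weights: Sequence[float]) -> float:
    """P(ordered recall sequence) under toy sampling without replacement (no stop rule)."""
    remaining = list(range(len(weights)))
    prob = 1.0
    for item in recalls:
        prob *= weights[item] / sum(weights[j] for j in remaining)
        remaining.remove(item)
    return prob


def sequence_nll(recalls: Sequence[int], weights: Sequence[float]) -> float:
    """Order-sensitive loss: negative log-probability of the exact sequence."""
    return -math.log(sequence_prob(recalls, weights))


def set_nll(recalls: Sequence[int], weights: Sequence[float]) -> float:
    """Order-free loss: sum the probability over every ordering of the same items."""
    total = sum(sequence_prob(perm, weights) for perm in itertools.permutations(recalls))
    return -math.log(total)


def spc(trials: Sequence[Sequence[int]], list_length: int) -> List[float]:
    """Serial position curve: fraction of trials recalling each study position."""
    return [sum(pos in trial for trial in trials) / len(trials) for pos in range(list_length)]


def spc_mse(data_trials, model_trials, list_length: int) -> float:
    """Summary-statistic loss: mean squared difference between the two SPCs."""
    data_curve, model_curve = spc(data_trials, list_length), spc(model_trials, list_length)
    return sum((d - m) ** 2 for d, m in zip(data_curve, model_curve)) / list_length


weights = [3.0, 1.0, 1.0]  # toy support for study positions 0, 1, 2
print(sequence_nll([0, 1], weights), sequence_nll([1, 0], weights))  # differ: order matters
print(set_nll([0, 1], weights), set_nll([1, 0], weights))            # equal: order ignored
print(spc_mse([[0, 1], [0]], [[0], [1, 2]], list_length=3))
```

The number of permutations grows factorially with the number of recalls, so this brute-force set likelihood is practical only for very short toy sequences.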

See Loss Functions for details.