Fitting Algorithms

Parameter optimization strategies

Fitting algorithms search for model parameters that minimize a loss function. jaxcmr provides implementations that work with any LossFnGenerator to find optimal parameters.

What is Fitting?

Fitting (or parameter optimization) finds parameters \(\theta^*\) that minimize a loss function:

\[\theta^* = \arg\min_\theta L(\theta)\]

For memory models, this typically means finding parameters that make the observed recall data most probable under the model — that is, maximum-likelihood estimation, implemented by minimizing the negative log-likelihood.
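The arg-min idea can be illustrated without any of the jaxcmr machinery. Below is a toy grid search over a made-up one-dimensional quadratic loss; both the loss and the grid are invented for illustration only.

```python
# Toy illustration of the arg-min above (not the jaxcmr API):
# grid search for the theta that minimizes a simple quadratic loss.
def loss(theta):
    return (theta - 0.3) ** 2

candidates = [i / 100 for i in range(101)]  # theta values in [0, 1]
theta_star = min(candidates, key=loss)      # arg min over the grid
print(theta_star)  # 0.3
```

Real fitting algorithms replace the grid with a smarter search strategy (e.g., differential evolution) and the quadratic with a model-based loss, but the objective has exactly this shape.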

The FittingAlgorithm Protocol

All fitting algorithms implement the FittingAlgorithm protocol:

Code
class FittingAlgorithm(Protocol):
    def __init__(
        self,
        dataset: RecallDataset,
        features: Optional[...],
        base_params: Mapping[str, float],
        model_factory: Type[MemorySearchModelFactory],
        loss_fn_generator: Type[LossFnGenerator],
        hyperparams: Optional[dict[str, Any]] = None,
    ): ...

    def fit(
        self,
        trial_mask: Bool[Array, " trials"],
        subject_id: int = -1,
    ) -> FitResult: ...

    def fit_subjects(
        self,
        trial_mask: Bool[Array, " trials"],
    ) -> FitResult: ...

See Protocols for full details.

Per-Subject vs Global Fitting

Per-Subject Fitting

Code
results = fitter.fit_subjects(trial_mask)
  • Fits separate parameters for each subject
  • Captures individual differences
  • Returns one parameter set per subject
  • Default behavior

When to use:

  • Individual differences are expected
  • Enough trials per subject for stable estimates
  • Want to examine parameter distributions

Pooled Fitting
Code
results = fitter.fit(trial_mask)
  • Fits single parameters to all data
  • Pools information across subjects
  • Returns one parameter set total

When to use:

  • Limited trials per subject
  • Testing group-level predictions
  • Preliminary exploration
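A hypothetical numeric sketch of the trade-off (these loss values are invented, not real jaxcmr output): with a perfect optimizer, per-subject fitting never increases total in-sample loss relative to a pooled fit, because each subject's fit has at least as much freedom — but it estimates many more parameters, so it needs more data per subject.

```python
# Invented loss values illustrating per-subject vs pooled fitting.
per_subject_losses = [120.5, 98.2, 143.7]  # one loss per subject (fit_subjects)
pooled_loss = 375.0                         # one loss for all data (fit)

total_per_subject = sum(per_subject_losses)
print(total_per_subject < pooled_loss)  # True
```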

Available Algorithms

| Algorithm | Module | Optimizer | Description |
|-----------|--------|-----------|-------------|
| ScipyDE | jaxcmr.fitting | Differential Evolution | Global optimizer, robust |

Currently jaxcmr provides ScipyDE (Differential Evolution via SciPy). Additional algorithms can be implemented following the FittingAlgorithm protocol.
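To sketch what implementing a new algorithm involves, here is a deliberately simplified random-search optimizer. It mirrors only the shape of the protocol — it takes a plain loss function and bounds rather than the dataset/factory/generator arguments, so it is an illustrative toy, not a drop-in FittingAlgorithm implementation.

```python
# Simplified sketch of a custom optimizer (not a real FittingAlgorithm):
# uniform random search within parameter bounds.
import random

class RandomSearch:
    def __init__(self, loss_fn, bounds, num_samples=1000, seed=0):
        self.loss_fn = loss_fn
        self.bounds = bounds            # {"param": [lo, hi], ...}
        self.num_samples = num_samples
        self.rng = random.Random(seed)

    def fit(self):
        best_params, best_loss = None, float("inf")
        for _ in range(self.num_samples):
            # Sample each free parameter uniformly within its bounds.
            params = {
                name: self.rng.uniform(lo, hi)
                for name, (lo, hi) in self.bounds.items()
            }
            loss = self.loss_fn(params)
            if loss < best_loss:
                best_params, best_loss = params, loss
        return {"fits": best_params, "fitness": best_loss}

fitter = RandomSearch(
    loss_fn=lambda p: (p["encoding_drift_rate"] - 0.5) ** 2,
    bounds={"encoding_drift_rate": [0.0, 1.0]},
)
result = fitter.fit()
```

A real implementation would additionally build the loss function from the dataset, model factory, and loss-function generator, and expose both fit and fit_subjects.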

Common Hyperparameters

Most fitting algorithms accept these hyperparameters:

| Parameter | Description | Typical Value |
|-----------|-------------|---------------|
| bounds | Parameter search bounds | {"param": [lo, hi], ...} |
| num_steps | Maximum iterations | 500-2000 |
| best_of | Number of restarts | 1-5 |
| progress_bar | Show progress | True |

Basic Usage

Code
from jaxcmr.fitting import ScipyDE
from jaxcmr.loss.sequence_likelihood import MemorySearchLikelihoodFnGenerator
from jaxcmr.models.cmr import make_factory
from jaxcmr.components.linear_memory import init_mfc, init_mcf
from jaxcmr.components.context import init as init_context
from jaxcmr.components.termination import PositionalTermination

# 1. Create model factory
model_factory = make_factory(
    init_mfc, init_mcf, init_context, PositionalTermination
)

# 2. Define parameters
fixed_params = {
    "allow_repeated_recalls": False,
    "learn_after_context_update": True,
}
free_params = {
    "encoding_drift_rate": [0.0, 1.0],
    "start_drift_rate": [0.0, 1.0],
    "recall_drift_rate": [0.0, 1.0],
    "shared_support": [0.0, 10.0],
    "item_support": [0.0, 10.0],
    "learning_rate": [0.0, 1.0],
    "primacy_scale": [0.0, 10.0],
    "primacy_decay": [0.0, 5.0],
    "choice_sensitivity": [0.0, 10.0],
    "stop_probability_scale": [0.0, 1.0],
    "stop_probability_growth": [0.0, 1.0],
}

# 3. Create fitter
fitter = ScipyDE(
    dataset=data,
    features=None,
    base_params=fixed_params,
    model_factory=model_factory,
    loss_fn_generator=MemorySearchLikelihoodFnGenerator,
    hyperparams={
        "bounds": free_params,
        "num_steps": 1000,
        "pop_size": 15,
        "best_of": 3,
        "progress_bar": True,
    },
)

# 4. Run fitting
results = fitter.fit_subjects(trial_mask)

Interpreting Results

The FitResult dictionary contains:

Code
results["fixed"]      # Fixed parameters: {"param": value}
results["free"]       # Bounds: {"param": [lo, hi]}
results["fitness"]    # Loss values: [subj1_loss, subj2_loss, ...]
results["fits"]       # Fitted values: {"param": [subj1_val, ...], "subject": [...]}
results["hyperparameters"]  # Settings used
results["fit_time"]   # Total seconds
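Because FitResult is a plain dictionary, it can be persisted with the standard library. The dictionary below is a minimal stand-in with invented values; array-valued fields from a real fit may first need conversion to lists (e.g., via .tolist()).

```python
import json

# Minimal stand-in for a FitResult dict; all values here are invented.
results = {
    "fixed": {"allow_repeated_recalls": False},
    "free": {"encoding_drift_rate": [0.0, 1.0]},
    "fitness": [120.5, 98.2],
    "fits": {"encoding_drift_rate": [0.41, 0.55], "subject": [1, 2]},
    "fit_time": 12.3,
}

# Save to disk, then reload for later analysis.
with open("fit_results.json", "w") as f:
    json.dump(results, f, indent=2)

with open("fit_results.json") as f:
    loaded = json.load(f)
```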

Accessing Fitted Parameters

Code
# Per-subject values for a parameter
encoding_rates = results["fits"]["encoding_drift_rate"]

# Subject IDs
subjects = results["fits"]["subject"]

# Create DataFrame
import pandas as pd
df = pd.DataFrame({
    "subject": subjects,
    "encoding_drift_rate": encoding_rates,
    # ... other parameters
})

Choosing Parameters to Fit

Fixed vs Free

Fix parameters when:

  • Value is known or constrained by design
  • Limited data for estimation
  • Reducing model complexity

Free parameters when:

  • Value is theoretically important
  • Adequate data for estimation
  • Individual differences expected

Bounds Selection

Parameter bounds should:

  • Cover the theoretically plausible range
  • Not be unnecessarily wide (slows optimization)
  • Reflect prior knowledge

Code
# Too wide - inefficient
"encoding_drift_rate": [0.0, 1000.0]

# Too narrow - may miss optimum
"encoding_drift_rate": [0.4, 0.6]

# Appropriate
"encoding_drift_rate": [0.0, 1.0]

Optimization Tips

Multiple Restarts

Use best_of > 1 for global optimizers:

Code
hyperparams = {
    "best_of": 3,  # Run 3 times, keep best
    ...
}

This helps avoid local minima; agreement across restarts also serves as a rough convergence check.

Sufficient Iterations

For complex models, increase num_steps:

Code
# Simple model
"num_steps": 500

# Complex model or MSE fitting
"num_steps": 2000

Monitor Progress

Enable progress bar to track optimization:

Code
hyperparams = {
    "progress_bar": True,
    "display_iterations": True,  # Show per-iteration details
}

Trial Masks

Filter which trials to include:

Code
from jaxcmr.helpers import generate_trial_mask

# All trials
trial_mask = generate_trial_mask(data, "True")

# Filter by condition
trial_mask = generate_trial_mask(data, "data['listtype'] == 1")

# Specific subjects
trial_mask = generate_trial_mask(data, "data['subject'] < 10")

Next Steps

  • See ScipyDE for detailed algorithm documentation
  • See Loss Functions for available loss functions
  • See Simulation for generating predictions from fitted parameters
  • See Comparison for comparing fitted models