Fitting algorithms search for model parameters that minimize a loss function. jaxcmr provides implementations that work with any LossFnGenerator to find optimal parameters.
## What is Fitting?
Fitting (or parameter optimization) finds parameters \(\theta^*\) that minimize a loss function:
\[\theta^* = \arg\min_\theta L(\theta)\]
For memory models, this typically means finding parameters that make the observed recall data most probable under the model.
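As a toy illustration of this idea (using scipy directly, not jaxcmr's API), we can fit a single Bernoulli recall-probability parameter by minimizing the negative log-likelihood of observed recall outcomes:

```python
import numpy as np
from scipy.optimize import minimize_scalar

# Toy data: 1 = item recalled, 0 = not recalled
recalls = np.array([1, 1, 0, 1, 0, 1, 1])

def loss(theta):
    """Negative log-likelihood of a one-parameter Bernoulli recall model."""
    eps = 1e-12  # guard against log(0) at the parameter bounds
    return -np.sum(recalls * np.log(theta + eps)
                   + (1 - recalls) * np.log(1 - theta + eps))

result = minimize_scalar(loss, bounds=(0.0, 1.0), method="bounded")
# The minimizer recovers the empirical recall rate, 5/7
```

Minimizing negative log-likelihood is equivalent to maximizing the probability of the observed data under the model, which is exactly the objective described above.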
## The FittingAlgorithm Protocol
All fitting algorithms implement the FittingAlgorithm protocol:
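The full protocol is not reproduced here; the following is a minimal sketch of what an implementation must provide, assuming the `fit(trial_mask)` method used in the examples below (the exact signature and return type are assumptions, not jaxcmr's definition):

```python
from typing import Any, Protocol, runtime_checkable

@runtime_checkable
class FittingAlgorithm(Protocol):
    """Sketch of the protocol; method details are assumptions for illustration."""

    def fit(self, trial_mask) -> dict[str, Any]:
        """Optimize model parameters over the trials selected by trial_mask."""
        ...
```

Any class with a compatible `fit` method satisfies the protocol structurally, so new algorithms can be added without subclassing.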
### Per-Subject Fitting

- Fits separate parameters for each subject
- Captures individual differences
- Returns one parameter set per subject
- Default behavior

When to use:

- Individual differences are expected
- Enough trials per subject for stable estimates
- Want to examine parameter distributions

### Pooled Fitting
```python
results = fitter.fit(trial_mask)
```
- Fits single parameters to all data
- Pools information across subjects
- Returns one parameter set total

When to use:

- Limited trials per subject
- Testing group-level predictions
- Preliminary exploration
Currently, jaxcmr provides ScipyDE (differential evolution via SciPy). Additional algorithms can be implemented by following the FittingAlgorithm protocol.
## Common Hyperparameters
Most fitting algorithms accept these hyperparameters:
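Collected into one dictionary, the hyperparameter names used in the snippets on this page look like this (the values here are illustrative only, not documented defaults):

```python
# Hyperparameter names drawn from examples elsewhere on this page;
# values are placeholders for illustration.
hyperparams = {
    "num_steps": 1000,            # optimizer iterations
    "best_of": 3,                 # independent runs; keep the best result
    "progress_bar": True,         # display optimization progress
    "display_iterations": False,  # show per-iteration details
}
```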
After a per-subject fit, the estimates can be collected into a DataFrame:

```python
import pandas as pd

# Per-subject values for a parameter
encoding_rates = results["fits"]["encoding_drift_rate"]

# Subject IDs
subjects = results["fits"]["subject"]

# Create DataFrame
df = pd.DataFrame({
    "subject": subjects,
    "encoding_drift_rate": encoding_rates,
    # ... other parameters
})
```
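With per-subject estimates in a DataFrame, the parameter distribution across subjects is easy to summarize. A self-contained sketch with made-up values (not real fits):

```python
import pandas as pd

# Illustrative per-subject estimates (made-up values, not real fits)
df = pd.DataFrame({
    "subject": [1, 2, 3, 4],
    "encoding_drift_rate": [0.55, 0.62, 0.48, 0.71],
})

# Distribution of the fitted parameter across subjects
summary = df["encoding_drift_rate"].agg(["mean", "std", "min", "max"])
```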
## Choosing Parameters to Fit
### Fixed vs Free
Fix parameters when:

- Value is known or constrained by design
- Limited data for estimation
- Reducing model complexity

Leave parameters free when:

- Value is theoretically important
- Adequate data for estimation
- Individual differences expected
### Bounds Selection
Parameter bounds should:

- Cover the theoretically plausible range
- Not be unnecessarily wide (overly wide bounds slow optimization)
- Reflect prior knowledge
```python
# Too wide - inefficient
"encoding_drift_rate": [0.0, 1000.0]

# Too narrow - may miss the optimum
"encoding_drift_rate": [0.4, 0.6]

# Appropriate
"encoding_drift_rate": [0.0, 1.0]
```
## Optimization Tips
### Multiple Restarts

Use `best_of` > 1 for global optimizers:
```python
hyperparams = {
    "best_of": 3,  # Run 3 times, keep best
    ...
}
```
This helps avoid local minima and verifies convergence.
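To see why restarts matter, here is a standalone sketch of the best-of-N pattern (using scipy's differential evolution directly rather than jaxcmr's wrapper) on a loss with many local minima:

```python
import numpy as np
from scipy.optimize import differential_evolution

def loss(x):
    """Multimodal loss: many local minima, global minimum at x = 0."""
    return np.sin(3 * x[0]) ** 2 + 0.1 * x[0] ** 2

# "Best of 3" by hand: run the optimizer with different seeds, keep the best
runs = [differential_evolution(loss, bounds=[(-10, 10)], seed=s) for s in range(3)]
best = min(runs, key=lambda r: r.fun)
```

If all restarts converge to the same loss value, that is evidence the optimizer has found the global minimum rather than a local one.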
### Sufficient Iterations

For complex models, increase `num_steps`:
```python
# Simple model
"num_steps": 500

# Complex model or MSE fitting
"num_steps": 2000
```
### Monitor Progress

Enable the progress bar to track optimization:
```python
hyperparams = {
    "progress_bar": True,
    "display_iterations": True,  # Show per-iteration details
}
```
## Trial Masks
Filter which trials to include:
```python
from jaxcmr.helpers import generate_trial_mask

# All trials
trial_mask = generate_trial_mask(data, "True")

# Filter by condition
trial_mask = generate_trial_mask(data, "data['listtype'] == 1")

# Specific subjects
trial_mask = generate_trial_mask(data, "data['subject'] < 10")
```