Differential evolution fitting via SciPy

ScipyDE uses Differential Evolution (DE), a population-based global optimization algorithm, to find model parameters that minimize a loss function.

What is Differential Evolution?

Differential Evolution is an evolutionary algorithm that:

- Maintains a population of candidate solutions
- Evolves the population through mutation, crossover, and selection
- Searches for global optima without requiring gradients

This makes it well-suited for:

- Non-convex optimization landscapes
- Loss functions with local minima
- Black-box objective functions

Algorithm Overview

DE iteratively improves a population of parameter vectors:

Initialization

Create a population of \(N\) random parameter vectors: \[\mathbf{x}_i \sim \text{Uniform}(\mathbf{lb}, \mathbf{ub})\]

where \(\mathbf{lb}\) and \(\mathbf{ub}\) are the parameter bounds.
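For intuition, this step can be sketched in a few lines of NumPy (illustrative only; SciPy performs the initialization internally):

Code
import numpy as np

rng = np.random.default_rng(0)
lb = np.array([0.0, 0.0])   # lower bounds
ub = np.array([1.0, 10.0])  # upper bounds
N = 15 * len(lb)            # population size (pop_size × n_params)

# Each row is one candidate parameter vector drawn uniformly within bounds
population = rng.uniform(lb, ub, size=(N, len(lb)))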

Mutation

For each vector \(\mathbf{x}_i\), create a mutant: \[\mathbf{v}_i = \mathbf{x}_{r_1} + F \cdot (\mathbf{x}_{r_2} - \mathbf{x}_{r_3})\]

where:

- \(r_1, r_2, r_3\) are distinct random indices, all different from \(i\)
- \(F\) is the mutation scale factor (diff_w)
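A toy NumPy sketch of the mutation step (illustrative, not SciPy's implementation):

Code
import numpy as np

rng = np.random.default_rng(0)
population = rng.uniform(size=(30, 2))  # toy population of 30 vectors
F = 0.85                                # mutation scale factor (diff_w)

i = 0  # index of the vector being mutated
r1, r2, r3 = rng.choice(
    [k for k in range(len(population)) if k != i], size=3, replace=False
)
# Base vector plus a scaled difference of two other vectors
v_i = population[r1] + F * (population[r2] - population[r3])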

Crossover

Create a trial vector by mixing \(\mathbf{x}_i\) and \(\mathbf{v}_i\): \[u_{i,j} = \begin{cases} v_{i,j} & \text{if } \text{rand}() < CR \text{ or } j = j_{\text{rand}} \\ x_{i,j} & \text{otherwise} \end{cases}\]

where \(CR\) is the crossover rate (cross_over_rate) and \(j_{\text{rand}}\) is a randomly chosen index that guarantees the trial vector inherits at least one component from the mutant.
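A toy NumPy sketch of this binomial crossover (illustrative):

Code
import numpy as np

rng = np.random.default_rng(0)
n_params = 4
x_i = rng.uniform(size=n_params)  # current vector
v_i = rng.uniform(size=n_params)  # mutant vector
CR = 0.9                          # crossover rate (cross_over_rate)

# j_rand guarantees the trial inherits at least one mutant component
j_rand = rng.integers(n_params)
take_mutant = rng.uniform(size=n_params) < CR
take_mutant[j_rand] = True
u_i = np.where(take_mutant, v_i, x_i)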

Selection

Keep the better vector: \[\mathbf{x}_i' = \begin{cases} \mathbf{u}_i & \text{if } f(\mathbf{u}_i) < f(\mathbf{x}_i) \\ \mathbf{x}_i & \text{otherwise} \end{cases}\]
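The selection step is a simple greedy comparison (toy sketch with a stand-in loss):

Code
import numpy as np

def f(x):
    return np.sum(x ** 2)  # stand-in loss (sphere function)

x_i = np.array([0.4, 0.6])  # current vector
u_i = np.array([0.3, 0.5])  # trial vector

# Keep whichever vector achieves the lower loss
x_i_next = u_i if f(u_i) < f(x_i) else x_i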

Convergence

Repeat until:

- The maximum number of iterations is reached (num_steps)
- The population converges (relative tolerance relative_tolerance)

Hyperparameters

| Parameter | Symbol | Description | Default |
|-----------|--------|-------------|---------|
| bounds | | Parameter bounds {"param": [lo, hi]} | Required |
| num_steps | | Maximum iterations | 1000 |
| pop_size | \(N\) | Population size multiplier | 15 |
| best_of | | Independent restarts | 1 |
| relative_tolerance | \(tol\) | Convergence threshold | 0.001 |
| cross_over_rate | \(CR\) | Crossover probability | 0.9 |
| diff_w | \(F\) | Mutation scale factor | 0.85 |
| progress_bar | | Show progress bar | True |
| display_iterations | | Show per-iteration info | False |
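For reference, a hyperparams dictionary that spells out all of these defaults explicitly might look like the following (the single bound shown is illustrative):

Code
hyperparams = {
    "bounds": {"encoding_drift_rate": [0.0, 1.0]},  # required
    "num_steps": 1000,
    "pop_size": 15,
    "best_of": 1,
    "relative_tolerance": 0.001,
    "cross_over_rate": 0.9,
    "diff_w": 0.85,
    "progress_bar": True,
    "display_iterations": False,
}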

Population Size

The actual population size is pop_size × n_params. For 10 parameters with pop_size=15:

- Population = 150 candidate solutions
- Each iteration evaluates 150 loss values

Relative Tolerance

Optimization stops early when the standard deviation of the population's fitness values falls below: \[\text{std}(\text{fitness}) < tol \cdot |\text{mean}(\text{fitness})|\]
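In NumPy terms, the check is roughly (a sketch mirroring the criterion above, not SciPy's exact code):

Code
import numpy as np

def has_converged(fitness, tol=0.001):
    # fitness: one loss value per population member
    return np.std(fitness) < tol * abs(np.mean(fitness))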

Best Of

With best_of=3:

1. Run the optimization 3 independent times
2. Keep the result with the lowest fitness

This guards against unlucky initialization.
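Conceptually, best_of behaves like the following sketch (not the library's internals; run_once is a hypothetical callable returning a (loss, params) pair):

Code
def best_of_fit(run_once, n_restarts=3):
    # Run the optimizer several times and keep the lowest-loss result
    runs = [run_once() for _ in range(n_restarts)]
    return min(runs, key=lambda run: run[0])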

FitResult Structure

Code
results = fitter.fit_subjects(trial_mask)

# Fixed parameters
results["fixed"]
# {"allow_repeated_recalls": False, ...}

# Parameter bounds
results["free"]
# {"encoding_drift_rate": [0.0, 1.0], ...}

# Loss values (one per subject when using fit_subjects)
results["fitness"]
# [loss_subj1, loss_subj2, ...]

# Fitted parameter values
results["fits"]
# {
#     "encoding_drift_rate": [val_subj1, val_subj2, ...],
#     "recall_drift_rate": [val_subj1, val_subj2, ...],
#     ...,
#     "subject": [id_subj1, id_subj2, ...]
# }

# Hyperparameters used
results["hyperparameters"]
# {"bounds": {...}, "num_steps": 1000, ...}

# Total time
results["fit_time"]
# 123.45 (seconds)

Usage

Basic Per-Subject Fitting

Code
from jaxcmr.fitting import ScipyDE
from jaxcmr.loss.sequence_likelihood import MemorySearchLikelihoodFnGenerator
from jaxcmr.models.cmr import make_factory

# Create model factory
model_factory = make_factory(
    init_mfc, init_mcf, init_context, PositionalTermination
)

# Configure fitter
fitter = ScipyDE(
    dataset=data,
    features=None,
    base_params={
        "allow_repeated_recalls": False,
        "learn_after_context_update": True,
    },
    model_factory=model_factory,
    loss_fn_generator=MemorySearchLikelihoodFnGenerator,
    hyperparams={
        "bounds": {
            "encoding_drift_rate": [0.0, 1.0],
            "start_drift_rate": [0.0, 1.0],
            "recall_drift_rate": [0.0, 1.0],
            "shared_support": [0.0, 10.0],
            "item_support": [0.0, 10.0],
            "learning_rate": [0.0, 1.0],
            "primacy_scale": [0.0, 10.0],
            "primacy_decay": [0.0, 5.0],
            "choice_sensitivity": [0.0, 10.0],
            "stop_probability_scale": [0.0, 1.0],
            "stop_probability_growth": [0.0, 1.0],
        },
        "num_steps": 1000,
        "pop_size": 15,
        "best_of": 3,
        "progress_bar": True,
    },
)

# Fit all subjects
results = fitter.fit_subjects(trial_mask)

Pooled Fitting

Code
# Single fit across all subjects (pooled)
results = fitter.fit(trial_mask)

# Results have single values instead of lists
print(results["fitness"])  # [single_loss]
print(results["fits"]["encoding_drift_rate"])  # [single_value]

With Different Loss Functions

Code
from jaxcmr.loss.spc_mse import MemorySearchSpcMseFnGenerator

# MSE-based fitting (may need more iterations)
fitter = ScipyDE(
    ...,
    loss_fn_generator=MemorySearchSpcMseFnGenerator,
    hyperparams={
        ...,
        "num_steps": 1500,  # More iterations for MSE
    },
)

Helper Functions

make_subject_trial_masks

Extract per-subject trial masks from a global mask:

Code
from jaxcmr.fitting import make_subject_trial_masks

# Get masks for each subject
subject_masks, unique_subjects = make_subject_trial_masks(
    trial_mask, data["subject"].flatten()
)

# Manual per-subject fitting
for s, mask in enumerate(subject_masks):
    results = fitter.fit(mask, subject_id=int(unique_subjects[s]))
    # Process individual subject...

Computational Notes

Vectorized Evaluation

ScipyDE uses vectorized=True in SciPy’s differential_evolution:

Code
differential_evolution(
    loss_fn,
    bounds,
    vectorized=True,  # Evaluate entire population at once
    ...
)

With vectorized evaluation, the loss function:

- Receives parameters with shape (n_params, pop_size)
- Returns losses with shape (pop_size,)

This enables efficient parallel evaluation in JAX.
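As a minimal sketch of this convention (assuming a hypothetical per-candidate JAX loss single_loss that maps a shape-(n_params,) vector to a scalar):

Code
import jax
import jax.numpy as jnp

def make_vectorized_loss(single_loss):
    # Map single_loss over the pop_size axis of a (n_params, pop_size) array
    batched = jax.vmap(single_loss, in_axes=1)

    def loss_fn(population):
        # population: shape (n_params, pop_size); returns shape (pop_size,)
        return batched(jnp.asarray(population))

    return loss_fn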

Memory Usage

Memory scales with:

- Population size: pop_size × n_params
- Number of trials in the mask
- Model size

For large datasets, consider:

- Reducing pop_size
- Fitting subjects in batches
- Using global fitting with subsampled trials

Time Estimates

Fitting time depends on:

- Number of parameters (affects population size)
- Number of trials (affects loss evaluation)
- num_steps (affects iterations)
- best_of (multiplies total time)
- Loss function complexity

Tuning Recommendations

For Faster Convergence

Code
hyperparams = {
    "pop_size": 10,       # Smaller population
    "num_steps": 500,     # Fewer iterations
    "best_of": 1,         # Single run
    "relative_tolerance": 0.01,  # Looser convergence
}

For Better Optima

Code
hyperparams = {
    "pop_size": 20,       # Larger population
    "num_steps": 2000,    # More iterations
    "best_of": 5,         # Multiple restarts
    "relative_tolerance": 0.0001,  # Tighter convergence
}

For Difficult Problems

Code
hyperparams = {
    "cross_over_rate": 0.7,  # More exploration
    "diff_w": 0.5,           # Smaller mutations
    "num_steps": 3000,       # Many iterations
    "best_of": 10,           # Many restarts
}

Troubleshooting

Poor Convergence

Symptoms: High loss values, unstable parameters

Solutions:

- Increase num_steps
- Increase best_of
- Check that parameter bounds are not too wide
- Verify the loss function is correct

Slow Fitting

Symptoms: Fitting takes too long

Solutions:

- Reduce pop_size
- Reduce num_steps
- Use global fitting instead of per-subject fitting
- Reduce the number of free parameters

Different Results Each Run

Symptoms: Fitted parameters vary across runs

Solutions:

- Increase best_of
- Tighten relative_tolerance
- Check for a multimodal loss landscape
- Consider parameter identifiability issues

Example: Full Workflow

Code
from jaxcmr.fitting import ScipyDE
from jaxcmr.loss.sequence_likelihood import MemorySearchLikelihoodFnGenerator
from jaxcmr.models.cmr import make_factory
from jaxcmr.helpers import load_data, generate_trial_mask
from jaxcmr.summarize import summarize_parameters
from jaxcmr.simulation import simulate_h5_from_h5
from jax import random
import jax.numpy as jnp

# Load data
data = load_data("experiment.h5")
trial_mask = generate_trial_mask(data, "True")

# Create model
model_factory = make_factory(
    init_mfc, init_mcf, init_context, PositionalTermination
)

# Define parameters
fixed = {"allow_repeated_recalls": False}
free = {
    "encoding_drift_rate": [0.0, 1.0],
    "recall_drift_rate": [0.0, 1.0],
    "primacy_scale": [0.0, 10.0],
    "primacy_decay": [0.0, 5.0],
}

# Fit
fitter = ScipyDE(
    data, None, fixed, model_factory,
    MemorySearchLikelihoodFnGenerator,
    {"bounds": free, "num_steps": 1000, "best_of": 3},
)
results = fitter.fit(trial_mask)

# Summarize
print(summarize_parameters([results], list(free.keys())))

# Simulate
fitted_params = {k: jnp.array(v) for k, v in results["fits"].items()}
sim = simulate_h5_from_h5(
    model_factory, data, None, fitted_params,
    trial_mask, experiment_count=100,
    rng=random.PRNGKey(0),
)