ScipyDE uses Differential Evolution (DE), a population-based global optimization algorithm, to find model parameters that minimize a loss function.
What is Differential Evolution?
Differential Evolution is an evolutionary algorithm that:

- Maintains a population of candidate solutions
- Evolves the population through mutation, crossover, and selection
- Converges toward global optima without requiring gradients

This makes it well-suited for:

- Non-convex optimization landscapes
- Loss functions with many local minima
- Black-box objective functions
Algorithm Overview
DE iteratively improves a population of parameter vectors:
Initialization
Create a population of \(N\) random parameter vectors: \[\mathbf{x}_i \sim \text{Uniform}(\mathbf{lb}, \mathbf{ub})\]
where \(\mathbf{lb}\) and \(\mathbf{ub}\) are the parameter bounds.
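A minimal NumPy sketch of this initialization step (function and variable names here are illustrative, not part of the jaxcmr API):

```python
import numpy as np

def init_population(lb, ub, n_pop, rng=None):
    """Draw n_pop parameter vectors uniformly within [lb, ub]."""
    rng = np.random.default_rng() if rng is None else rng
    lb, ub = np.asarray(lb), np.asarray(ub)
    # Shape (n_pop, n_params): one row per candidate solution
    return rng.uniform(lb, ub, size=(n_pop, lb.size))

# Two parameters with bounds [0, 1] and [0, 10], population of 150
pop = init_population([0.0, 0.0], [1.0, 10.0], n_pop=150)
```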
Mutation
For each vector \(\mathbf{x}_i\), create a mutant: \[\mathbf{v}_i = \mathbf{x}_{r_1} + F \cdot (\mathbf{x}_{r_2} - \mathbf{x}_{r_3})\]
where:

- \(r_1, r_2, r_3\) are distinct random indices, all different from \(i\)
- \(F\) is the mutation scale factor (`diff_w`)
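The mutation rule can be sketched as follows (illustrative NumPy, not the jaxcmr or SciPy implementation):

```python
import numpy as np

def mutate(pop, i, F=0.85, rng=None):
    """Build a mutant for target index i from three distinct other members."""
    rng = np.random.default_rng() if rng is None else rng
    # r1, r2, r3: distinct indices, none equal to the target i
    candidates = [j for j in range(len(pop)) if j != i]
    r1, r2, r3 = rng.choice(candidates, size=3, replace=False)
    return pop[r1] + F * (pop[r2] - pop[r3])

pop = np.random.default_rng(0).uniform(0, 1, size=(10, 4))
v = mutate(pop, i=0)  # mutant vector with one entry per parameter
```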
Crossover
Create a trial vector by mixing \(\mathbf{x}_i\) and \(\mathbf{v}_i\): \[u_{i,j} = \begin{cases} v_{i,j} & \text{if } \text{rand}() < CR \text{ or } j = j_{\text{rand}} \\ x_{i,j} & \text{otherwise} \end{cases}\]
where \(CR\) is the crossover rate (`cross_over_rate`) and \(j_{\text{rand}}\) is a randomly chosen index that guarantees at least one component comes from the mutant.
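A NumPy sketch of this binomial crossover, with `j_rand` forcing at least one mutant component (names are illustrative, not jaxcmr API):

```python
import numpy as np

def crossover(x, v, CR=0.9, rng=None):
    """Mix target x and mutant v component-wise into a trial vector."""
    rng = np.random.default_rng() if rng is None else rng
    mask = rng.random(x.size) < CR          # take v[j] where rand() < CR
    mask[rng.integers(x.size)] = True       # j_rand: guarantee one mutant component
    return np.where(mask, v, x)

x = np.zeros(5)   # target vector
v = np.ones(5)    # mutant vector
u = crossover(x, v)
```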
Selection

The trial vector replaces the target only if it achieves an equal or lower loss: \[\mathbf{x}_i \leftarrow \begin{cases} \mathbf{u}_i & \text{if } f(\mathbf{u}_i) \le f(\mathbf{x}_i) \\ \mathbf{x}_i & \text{otherwise} \end{cases}\]

Repeat the mutation, crossover, and selection steps until:

- The maximum number of iterations is reached (`num_steps`)
- The population converges (relative tolerance `relative_tolerance`)
Hyperparameters
| Parameter | Symbol | Description | Default |
|---|---|---|---|
| `bounds` | — | Parameter bounds, `{"param": [lo, hi]}` | Required |
| `num_steps` | — | Maximum iterations | 1000 |
| `pop_size` | \(N\) | Population size multiplier | 15 |
| `best_of` | — | Independent restarts | 1 |
| `relative_tolerance` | \(tol\) | Convergence threshold | 0.001 |
| `cross_over_rate` | \(CR\) | Crossover probability | 0.9 |
| `diff_w` | \(F\) | Mutation scale factor | 0.85 |
| `progress_bar` | — | Show a progress bar | True |
| `display_iterations` | — | Show per-iteration info | False |
Population Size
The actual population is `pop_size` × `n_params`. For 10 parameters with `pop_size=15`:

- Population = 150 candidate solutions
- Each iteration evaluates 150 loss values
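As a quick sanity check of that arithmetic (plain Python; `pop_size` and `n_params` are just local variables here):

```python
pop_size = 15   # the hyperparameter: a multiplier, not the population itself
n_params = 10   # number of free model parameters

population = pop_size * n_params   # candidate solutions per generation
evals_per_iteration = population   # one loss value per candidate

print(population)  # 150
```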
Relative Tolerance
Optimization stops early if the standard deviation of the population's fitness values falls below a fraction of their mean: \[\text{std}(\text{fitness}) < tol \cdot |\text{mean}(\text{fitness})|\]
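A sketch of this stopping rule, assuming `fitness` holds the current population's loss values (this mirrors the criterion stated above; SciPy's internal check may differ in details such as an absolute-tolerance term):

```python
import numpy as np

def has_converged(fitness, tol=0.001):
    """Stop when population fitness spread is small relative to its mean."""
    fitness = np.asarray(fitness)
    return np.std(fitness) < tol * abs(np.mean(fitness))

print(has_converged([10.0, 10.001, 9.999]))  # tightly clustered population
print(has_converged([10.0, 5.0, 1.0]))       # population still spread out
```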
Best Of
With `best_of=3`:

1. Run the optimization 3 independent times
2. Keep the result with the lowest fitness
3. This guards against unlucky initialization
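The restart logic can be sketched generically (`run_once` is a stand-in for a single optimization run, not a jaxcmr function):

```python
def best_of_fit(run_once, n_restarts=3):
    """Run the optimizer n_restarts times and keep the lowest-fitness result."""
    results = [run_once(seed) for seed in range(n_restarts)]
    return min(results, key=lambda r: r["fitness"])

# Toy stand-in: pretend each seed yields a different final loss
fake_losses = {0: 1.2, 1: 0.7, 2: 0.9}
best = best_of_fit(lambda seed: {"fitness": fake_losses[seed], "seed": seed})
print(best)  # keeps the seed-1 run with fitness 0.7
```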
```python
# Single fit across all subjects (pooled)
results = fitter.fit(trial_mask)

# Results have single values instead of lists
print(results["fitness"])                      # [single_loss]
print(results["fits"]["encoding_drift_rate"])  # [single_value]
```
With Different Loss Functions
```python
from jaxcmr.loss.spc_mse import MemorySearchSpcMseFnGenerator

# MSE-based fitting (may need more iterations)
fitter = ScipyDE(
    ...,
    loss_fn_generator=MemorySearchSpcMseFnGenerator,
    hyperparams={
        ...,
        "num_steps": 1500,  # More iterations for MSE
    },
)
```
Helper Functions
make_subject_trial_masks
Extract per-subject trial masks from a global mask:
```python
from jaxcmr.fitting import make_subject_trial_masks

# Get masks for each subject
subject_masks, unique_subjects = make_subject_trial_masks(
    trial_mask, data["subject"].flatten()
)

# Manual per-subject fitting
for s, mask in enumerate(subject_masks):
    results = fitter.fit(mask, subject_id=int(unique_subjects[s]))
    # Process individual subject...
```
Computational Notes
Vectorized Evaluation
ScipyDE uses vectorized=True in SciPy’s differential_evolution:
```python
differential_evolution(
    loss_fn,
    bounds,
    vectorized=True,  # Evaluate entire population at once
    ...
)
```
The loss function:

- Receives parameters with shape `(n_params, pop_size)` for vectorized evaluation
- Returns losses with shape `(pop_size,)`
This enables efficient parallel evaluation in JAX.
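A toy vectorized loss illustrating this shape contract (a sphere function here; a real jaxcmr loss generator would follow the same convention):

```python
import numpy as np

def vectorized_loss(x):
    """x has shape (n_params, pop_size); return one loss per population member."""
    return np.sum(x**2, axis=0)  # sum over parameters -> shape (pop_size,)

pop = np.ones((4, 150))          # 4 parameters, population of 150
losses = vectorized_loss(pop)
print(losses.shape)  # (150,)
```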
Memory Usage
Memory scales with:

- Population size: `pop_size` × `n_params`
- The number of trials in the mask
- Model size

For large datasets, consider:

- Reducing `pop_size`
- Fitting subjects in batches
- Using global fitting with subsampled trials
Time Estimates
Fitting time depends on:

- Number of parameters (affects population size)
- Number of trials (affects the cost of each loss evaluation)
- `num_steps` (caps the number of iterations)
- `best_of` (multiplies total time)
- Loss function complexity