# Build the differential-evolution fitter for the CMR model.
# The search budget below is deliberately small so the demo runs quickly.
fitter = ScipyDE(
    dataset=dataset,
    features=None,  # No semantic features for this example
    base_params=base_params,
    model_create_fn=make_factory(CMR),
    loss_fn_generator=MemorySearchLikelihoodFnGenerator,
    hyperparams={
        "bounds": param_bounds,  # per-parameter search ranges
        "num_steps": 50,  # Reduced for demo
        "pop_size": 10,
        "progress_bar": True,
    },
)
print("Fitter configured")
Code
# Fit a single subject only (for speed) and summarize the result.
# NOTE(review): "first subject" assumes subject ids in dataset["subject"]
# start at 1 — confirm against the dataset before relying on this.
subject_id = 1
trial_mask = dataset["subject"].flatten() == subject_id

# Fail fast instead of launching a fit on zero trials, which would be
# meaningless (e.g. if subject ids are 0-based or the id is absent).
n_trials = int(trial_mask.sum())
if n_trials == 0:
    raise ValueError(f"No trials found for subject {subject_id}")

print(f"Fitting to {n_trials} trials...")
result = fitter.fit(trial_mask)

# Only the first entry of each result array is reported here.
print(f"\nFitting complete in {result['fit_time']:.1f}s")
print(f"Final loss: {result['fitness'][0]:.3f}")
print("Best parameters:")  # plain string: no placeholders needed (F541)
for name, values in result['fits'].items():
    print(f" {name}: {values[0]:.3f}")
Protocol Reference
LossFnGenerator
Code
# Render the API documentation for the LossFnGenerator protocol.
from nbdev.showdoc import show_doc
from jaxcmr.typing import LossFnGenerator

# Keep this as the cell's final expression so the notebook displays it.
show_doc(LossFnGenerator)
FittingAlgorithm
Code
# Render the API documentation for the FittingAlgorithm protocol.
# show_doc is imported here as well so this cell runs on its own
# and does not depend on an earlier cell having been executed.
from nbdev.showdoc import show_doc
from jaxcmr.typing import FittingAlgorithm

# Keep this as the cell's final expression so the notebook displays it.
show_doc(FittingAlgorithm)
FitResult
Code
# Render the API documentation for the FitResult type.
# show_doc is imported here as well so this cell runs on its own
# and does not depend on an earlier cell having been executed.
from nbdev.showdoc import show_doc
from jaxcmr.typing import FitResult

# Keep this as the cell's final expression so the notebook displays it.
show_doc(FitResult)