# Maximum likelihood estimation for recall sequences

Sequence likelihood loss treats each recall event as a probabilistic choice, computing the negative log-likelihood of observed recall sequences under the model.

Given a model with parameters \(\theta\), we compute the probability of each observed recall event:

\[P(r_j \mid r_{1:j-1}, \theta)\]

This is the probability of recalling item \(r_j\) given the previous recalls \(r_{1:j-1}\) and the current model state.

For a single trial \(t\) with recall sequence \(R_t = (r_{t,1}, r_{t,2}, \ldots, r_{t,k})\):

\[\mathcal{L}_t(\theta) = \prod_{j=1}^{k} P(r_{t,j} \mid r_{t,1:j-1}, \theta)\]

This includes both item recalls and the final stop decision.

The loss function is the sum of negative log-likelihoods across trials:

\[L(\theta) = -\sum_t \log \mathcal{L}_t(\theta) = -\sum_t \sum_{j=1}^{|R_t|} \log P(r_{t,j} \mid r_{t,1:j-1}, \theta)\]

At each retrieval step, the model provides a vector of outcome probabilities:

```python
probs = model.outcome_probabilities()
# probs[0] = P(stop)
# probs[i] = P(recall item i) for i > 0
```

The likelihood of observing recall \(r_j\) is `probs[r_j]`.
Module: `jaxcmr.loss.sequence_likelihood`

The default implementation creates per-trial models and computes recall probabilities sequentially.

Implementation details:

- Creates a fresh model for each trial
- Uses `lax.scan` for sequential probability computation
- Supports vectorized parameter evaluation
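The per-trial computation can be sketched with a toy stand-in model. Everything here (the `ToyModel` class, its `outcome_probabilities`/`retrieve` interface, and the update rule) is illustrative, not the actual jaxcmr API:

```python
import numpy as np

class ToyModel:
    """Stand-in for a memory-search model; not the real jaxcmr interface."""
    def __init__(self, n_items):
        self.weights = np.ones(n_items + 1)  # slot 0 = stop outcome

    def outcome_probabilities(self):
        return self.weights / self.weights.sum()

    def retrieve(self, item):
        self.weights[item] = 0.0  # toy rule: no repeated recalls
        return self

def trial_nll(model, recalls, lb=1e-7):
    """-sum_j log P(r_j | r_1..r_{j-1}); recalls end with 0 (stop)."""
    nll = 0.0
    for r in recalls:
        probs = model.outcome_probabilities()
        nll -= np.log(max(probs[r], lb))  # lb guards against log(0)
        if r > 0:
            model = model.retrieve(r)  # update state after each recall
    return nll

loss = trial_nll(ToyModel(4), [2, 3, 0])  # recall items 2 and 3, then stop
```

Note that the stop event (index 0) enters the product exactly like an item recall, matching the loss definition above.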
Module: `jaxcmr.loss.transform_sequence_likelihood`

Applies user-specified masks to per-trial likelihood arrays:

| Variant | Behavior | Use Case |
|---|---|---|
| `MemorySearchLikelihoodFnGenerator` | Custom mask function | Flexible masking |
| `ExcludeFirstRecallLikelihoodFnGenerator` | Masks \(P(r_1)\) | Focus on transitions |
| `ExcludeTerminationLikelihoodFnGenerator` | Masks \(P(\text{stop})\) events | Ignore stopping |
Masks replace selected probability entries with 1.0, neutralizing their contribution to the log-likelihood.
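Numerically, a masked entry contributes \(\log 1 = 0\) to the sum. A minimal numpy illustration of the idea (the array shapes and mask convention are hypothetical):

```python
import numpy as np

likelihoods = np.array([0.5, 0.3, 0.2])    # per-event probabilities for a trial
keep = np.array([False, True, True])       # e.g. exclude the first recall
masked = np.where(keep, likelihoods, 1.0)  # masked-out entries become 1.0
# log(1.0) = 0, so masked events contribute nothing to the summed loss
nll = -np.log(masked).sum()
```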
Built-in mask functions cover the variants in the table above.
Module: `jaxcmr.loss.base_sequence_likelihood`

Reuses a single study context across all trials.

When to use:

- All trials have identical presentation sequences
- Computational efficiency matters
- Legacy compatibility

Limitation: does not support trial-specific presentations.
| Variant | Per-Trial Models | Masking | Use Case |
|---|---|---|---|
| Standard | Yes | No | Default MLE |
| Transform | Yes | Yes (pluggable) | Selective fitting |
| Exclude First | Yes | First recall | Transitions only |
| Exclude Termination | Yes | Stop events | Item selection only |
| Base (Legacy) | No (shared) | Yes | Identical lists |
```python
from jaxcmr.fitting import ScipyDE
from jaxcmr.loss.sequence_likelihood import MemorySearchLikelihoodFnGenerator

fitter = ScipyDE(
    dataset=data,
    features=None,
    base_params={"allow_repeated_recalls": False, ...},
    model_factory=model_factory,
    loss_fn_generator=MemorySearchLikelihoodFnGenerator,
    hyperparams={
        "bounds": {
            "encoding_drift_rate": [0.0, 1.0],
            "recall_drift_rate": [0.0, 1.0],
            ...
        },
        "num_steps": 1000,
    },
)
results = fitter.fit(trial_mask)
```

Focus on transitions rather than initiation:
Why exclude first recall?

- First recall is heavily influenced by recency
- Transitions test contiguity more directly
- Removes the confound of start-of-retrieval dynamics
Focus on item selection, ignoring stopping decisions:

Why exclude termination?

- Stopping behavior may be noisy or poorly modeled
- Focuses on which items are retrieved
- Useful when stop parameters are unreliable
Create your own mask function:
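For instance, a hypothetical mask that neutralizes trials with fewer than two recall events might look like this. The contract (per-trial likelihood array in, masked array out) is assumed for illustration, not taken from the library's documentation:

```python
import numpy as np

def exclude_short_trials(likelihoods):
    """Hypothetical mask: neutralize trials with fewer than two recall events.

    Assumed contract (illustrative): takes a trial's per-event likelihood
    array (zero-padded) and returns it with neutralized entries set to 1.0.
    """
    n_events = np.count_nonzero(likelihoods)
    if n_events < 2:
        return np.ones_like(likelihoods)  # log(1) = 0 for every event
    return likelihoods

short = exclude_short_trials(np.array([0.4, 0.0, 0.0]))  # one event: masked
full = exclude_short_trials(np.array([0.4, 0.3, 0.1]))   # kept unchanged
```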
Computing the likelihood requires sequential model updates: each event's probability depends on the model state produced by all preceding recalls. This is implemented with `lax.scan`, which keeps the loop JAX-traceable (and therefore jit-compilable).
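A stripped-down sketch of the scan pattern, where the "model" is just a softmax over a toy context vector rather than the actual jaxcmr model:

```python
import jax
import jax.numpy as jnp

def step(context, recall_idx):
    """Carry the evolving context; emit the log-probability of each recall."""
    probs = jax.nn.softmax(context)                # toy outcome distribution
    logp = jnp.log(probs[recall_idx])
    new_context = context.at[recall_idx].add(1.0)  # toy state update
    return new_context, logp

context0 = jnp.zeros(5)
recalls = jnp.array([2, 4, 0])                  # observed recall indices (toy)
_, logps = jax.lax.scan(step, context0, recalls)
nll = -logps.sum()                              # per-trial negative log-likelihood
```

The carry (`context`) plays the role of the model state, and the stacked outputs (`logps`) are the per-event log-probabilities that get summed into the loss.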
Log-probabilities are summed to avoid underflow:
\[\log \mathcal{L} = \sum_j \log P(r_j \mid \ldots)\]
A small constant `lb` (lower bound) clips probabilities away from zero, preventing \(\log(0)\).
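The clamp can be as simple as the following (the value of `lb` here is illustrative; the library's actual default may differ):

```python
import numpy as np

lb = 1e-7                         # illustrative lower bound
probs = np.array([0.5, 0.0])      # a zero probability would yield -inf
safe_log = np.log(np.maximum(probs, lb))
```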
For population-based optimizers, the loss function accepts batched parameters, evaluating a whole population of candidate parameter sets at once.
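One natural way to get this in JAX is `jax.vmap` over the parameter axis; whether jaxcmr batches exactly this way is an assumption here, and the quadratic loss below is a stand-in for the sequence NLL:

```python
import jax
import jax.numpy as jnp

def loss(params):
    return jnp.sum(params ** 2)      # stand-in for the sequence NLL

population = jnp.ones((8, 3))        # 8 candidate parameter vectors
losses = jax.vmap(loss)(population)  # one loss per candidate, shape (8,)
```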
Sequence likelihood assumes recall order is meaningful. For unordered recall data, use Set Permutation Likelihood instead.
Sequential computation limits parallelism. Each trial requires \(O(k)\) model updates where \(k\) is recall count.
If the model assigns probability 0 to an observed event, the log-likelihood becomes \(-\infty\). The `lb` constant mitigates this but doesn't solve fundamental model misspecification.