Sequence Likelihood

Maximum likelihood estimation for recall sequences

Sequence likelihood loss treats each recall event as a probabilistic choice, computing the negative log-likelihood of observed recall sequences under the model.

The Core Idea

Given a model with parameters \(\theta\), we compute the probability of each observed recall event:

\[P(r_j \mid r_{1:j-1}, \theta)\]

This is the probability of recalling item \(r_j\) given the previous recalls \(r_{1:j-1}\) and model state.

Mathematical Specification

Trial-Level Likelihood

For a single trial \(t\) with recall sequence \(R_t = (r_{t,1}, r_{t,2}, \ldots, r_{t,k})\):

\[\mathcal{L}_t(\theta) = \prod_{j=1}^{k} P(r_{t,j} \mid r_{t,1:j-1}, \theta)\]

This includes both item recalls and the final stop decision.

Negative Log-Likelihood

The loss function is the sum of negative log-likelihoods across trials:

\[L(\theta) = -\sum_t \log \mathcal{L}_t(\theta) = -\sum_t \sum_{j=1}^{|R_t|} \log P(r_{t,j} \mid r_{t,1:j-1}, \theta)\]
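As a concrete sketch, the double sum can be computed directly from per-step conditional probabilities; the probability values below are made up for illustration:

```python
import jax.numpy as jnp

# Hypothetical per-step conditional probabilities for two trials
# (each entry is P(r_{t,j} | r_{t,1:j-1}, theta) for one recall event).
trial_probs = jnp.array([
    [0.4, 0.3, 0.5],
    [0.2, 0.6, 0.7],
])

# Negative log-likelihood: sum of -log P over all events and trials.
nll = -jnp.sum(jnp.log(trial_probs))
```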

Probability Computation

At each retrieval step, the model provides:

Code
probs = model.outcome_probabilities()
# probs[0] = P(stop)
# probs[i] = P(recall item i) for i > 0

The likelihood of observing recall \(r_j\) is probs[r_j].
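Under this indexing convention, the per-event likelihoods for a whole recall sequence can be gathered in one step. The probability table below is a toy stand-in, not the library's actual model output:

```python
import jax.numpy as jnp

# Toy outcome probabilities at three successive retrieval steps.
probs_per_step = jnp.array([
    [0.1, 0.2, 0.3, 0.4],    # step 1: P(stop), P(item 1), P(item 2), P(item 3)
    [0.3, 0.4, 0.2, 0.1],    # step 2
    [0.8, 0.1, 0.05, 0.05],  # step 3
])
recalls = jnp.array([3, 1, 0])  # item 3, then item 1, then stop

# probs[r_j] at each step j, via row-wise fancy indexing.
event_likelihoods = probs_per_step[jnp.arange(3), recalls]
```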

Variants

Standard Sequence Likelihood

Module: jaxcmr.loss.sequence_likelihood

The default implementation creates per-trial models and computes sequential probabilities:

Code
from jaxcmr.loss.sequence_likelihood import MemorySearchLikelihoodFnGenerator

fitter = ScipyDE(
    ...,
    loss_fn_generator=MemorySearchLikelihoodFnGenerator,
)

Implementation details:

  • Creates a fresh model for each trial
  • Uses lax.scan for sequential probability computation
  • Supports vectorized parameter evaluation

Transform Sequence Likelihood

Module: jaxcmr.loss.transform_sequence_likelihood

Applies user-specified masks to per-trial likelihood arrays:

Code
from jaxcmr.loss.transform_sequence_likelihood import (
    MemorySearchLikelihoodFnGenerator,
    ExcludeFirstRecallLikelihoodFnGenerator,
    ExcludeTerminationLikelihoodFnGenerator,
)
| Variant | Behavior | Use Case |
| --- | --- | --- |
| MemorySearchLikelihoodFnGenerator | Custom mask function | Flexible masking |
| ExcludeFirstRecallLikelihoodFnGenerator | Masks \(P(r_1)\) | Focus on transitions |
| ExcludeTerminationLikelihoodFnGenerator | Masks \(P(\text{stop})\) events | Ignore stopping |

Masking Mechanism

Masks replace selected probability entries with 1.0, neutralizing their contribution to the log-likelihood:

Code
import jax.numpy as jnp

def _apply_mask(probabilities, mask):
    """Replace masked entries with 1.0 so log(1.0) = 0."""
    return jnp.where(mask, probabilities, 1.0)

Built-in mask functions:

Code
def mask_first_recall(pres, recalls):
    """Returns mask that drops the first recall event."""
    return jnp.arange(len(recalls)) != 0

def mask_trailing_terminations(pres, recalls):
    """Returns mask retaining only nonzero (item) recall events."""
    return recalls != 0
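Putting the pieces together, a masked log-likelihood can be sketched with toy numbers, mirroring the first-recall mask above:

```python
import jax.numpy as jnp

# Toy per-event probabilities and the recalls that produced them;
# the trailing 0 marks the stop event (values are illustrative).
recalls = jnp.array([3, 1, 2, 0])
probs = jnp.array([0.4, 0.3, 0.5, 0.6])

# Drop the first-recall event, as the ExcludeFirstRecall variant does.
mask = jnp.arange(len(recalls)) != 0
masked = jnp.where(mask, probs, 1.0)  # log(1.0) = 0 contributes nothing

nll = -jnp.sum(jnp.log(masked))
```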

Base Sequence Likelihood (Legacy)

Module: jaxcmr.loss.base_sequence_likelihood

Reuses a single study context across all trials:

Code
from jaxcmr.loss.base_sequence_likelihood import (
    MemorySearchLikelihoodFnGenerator,
    ExcludeFirstRecallLikelihoodFnGenerator,
    ExcludeTerminationLikelihoodFnGenerator,
)

When to use:

  • All trials have identical presentation sequences
  • Computational efficiency matters
  • Legacy compatibility

Limitation: Does not support trial-specific presentations.

Comparison Table

| Variant | Per-Trial Models | Masking | Use Case |
| --- | --- | --- | --- |
| Standard | Yes | No | Default MLE |
| Transform | Yes | Yes (pluggable) | Selective fitting |
| Exclude First | Yes | First recall | Transitions only |
| Exclude Termination | Yes | Stop events | Item selection only |
| Base (Legacy) | No (shared) | Yes | Identical lists |

Usage Examples

Standard Likelihood

Code
from jaxcmr.fitting import ScipyDE
from jaxcmr.loss.sequence_likelihood import MemorySearchLikelihoodFnGenerator

fitter = ScipyDE(
    dataset=data,
    features=None,
    base_params={"allow_repeated_recalls": False, ...},
    model_factory=model_factory,
    loss_fn_generator=MemorySearchLikelihoodFnGenerator,
    hyperparams={
        "bounds": {
            "encoding_drift_rate": [0.0, 1.0],
            "recall_drift_rate": [0.0, 1.0],
            ...
        },
        "num_steps": 1000,
    },
)

results = fitter.fit(trial_mask)

Exclude First Recall

Focus on transitions rather than initiation:

Code
from jaxcmr.loss.transform_sequence_likelihood import ExcludeFirstRecallLikelihoodFnGenerator

fitter = ScipyDE(
    ...,
    loss_fn_generator=ExcludeFirstRecallLikelihoodFnGenerator,
)

Why exclude first recall?

  • First recall is heavily influenced by recency
  • Transitions test contiguity more directly
  • Removes the confound of start-of-retrieval dynamics

Exclude Termination

Focus on item selection, ignore stopping decisions:

Code
from jaxcmr.loss.transform_sequence_likelihood import ExcludeTerminationLikelihoodFnGenerator

fitter = ScipyDE(
    ...,
    loss_fn_generator=ExcludeTerminationLikelihoodFnGenerator,
)

Why exclude termination?

  • Stopping behavior may be noisy or poorly modeled
  • Focuses on which items are retrieved
  • Useful when stop parameters are unreliable

Custom Masking

Create your own mask function:

Code
import jax.numpy as jnp

from jaxcmr.loss.transform_sequence_likelihood import MemorySearchLikelihoodFnGenerator

def my_mask_fn(pres, recalls):
    """Example: keep only the first five recall events (True = keep)."""
    return jnp.arange(len(recalls)) < 5

# Use with the base generator that accepts mask functions
# (Implementation-specific)

Computational Notes

Sequential Computation

Likelihood requires sequential model updates:

Code
likelihood = 1.0
for r_j in recalls:
    probs = model.outcome_probabilities()  # P(stop), P(item 1), ...
    likelihood *= probs[r_j]
    model = model.retrieve(r_j)  # update state with the observed recall

This is implemented efficiently with lax.scan to maintain JAX traceability.
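A minimal sketch of the scan pattern, using a toy strength-based model (not the library's actual model interface) whose state is the scan carry:

```python
import jax
import jax.numpy as jnp

def step(strengths, recall):
    """One retrieval step of a toy model: outcome probabilities are the
    normalized strengths; retrieving an outcome zeroes its strength."""
    probs = strengths / jnp.sum(strengths)
    log_prob = jnp.log(probs[recall])
    new_strengths = strengths.at[recall].set(0.0)
    return new_strengths, log_prob

# Index 0 is the stop outcome; indices 1..3 are items (toy values).
strengths = jnp.array([1.0, 2.0, 3.0, 4.0])
recalls = jnp.array([3, 1, 0])  # recall item 3, then item 1, then stop

_, log_probs = jax.lax.scan(step, strengths, recalls)
log_likelihood = jnp.sum(log_probs)
```

The carry plays the role of the evolving model state, so the whole sequence stays traceable under jit.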

Numerical Stability

Log-probabilities are summed to avoid underflow:

\[\log \mathcal{L} = \sum_j \log P(r_j \mid \ldots)\]

A small constant lb (lower bound) prevents log(0):

Code
log_prob = jnp.log(prob + lb)

Vectorized Parameters

For population-based optimizers, the loss function accepts batched parameters:

Code
# Shape: (n_params, pop_size)
params_batch = ...
losses = loss_fn(params_batch)  # Shape: (pop_size,)
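One way to sketch this batching is jax.vmap over the population axis; the single-parameter-set loss below is a hypothetical stand-in, not the library's loss:

```python
import jax
import jax.numpy as jnp

# Hypothetical loss for one parameter vector (stand-in quadratic).
def loss_fn_single(params):
    return jnp.sum((params - 0.5) ** 2)

# Map over axis 1, matching the (n_params, pop_size) layout.
batched_loss = jax.vmap(loss_fn_single, in_axes=1)

params_batch = jnp.zeros((3, 8))       # (n_params, pop_size)
losses = batched_loss(params_batch)    # shape (pop_size,)
```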

What Likelihood Captures

Temporal Dynamics

  • Contiguity: Transitions to nearby items
  • Forward asymmetry: Preference for forward transitions
  • Recency: First recalls from end of list

Serial Position Effects

  • Primacy: Enhanced recall of early items
  • Recency: Enhanced recall of late items

Stopping Behavior

  • When subjects terminate recall
  • How stop probability changes with recalls

Limitations

Order Dependence

Likelihood assumes order is meaningful. For unordered recall, use Set Permutation Likelihood.

Computational Cost

Sequential computation limits parallelism. Each trial requires \(O(k)\) model updates where \(k\) is recall count.

Sensitivity to Zeros

If the model assigns probability 0 to an observed event, likelihood becomes \(-\infty\). The lb constant mitigates this but doesn’t solve fundamental model misspecification.
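The effect, in a minimal sketch (the lb value here is illustrative):

```python
import jax.numpy as jnp

lb = 1e-7  # small lower bound on probabilities (illustrative value)

hard_zero = jnp.log(0.0)     # -inf: one impossible event ruins the loss
bounded = jnp.log(0.0 + lb)  # finite, though still a large penalty
```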