```python
from jaxcmr.loss.set_permutation_likelihood import MemorySearchLikelihoodFnGenerator
```

A Monte Carlo likelihood for unordered recall data.
Set Permutation Likelihood computes the probability of observing a recall set (ignoring order) by averaging over all possible orderings. This is essential when recall order is not recorded or not meaningful.
Sometimes only the set of recalled items matters:

- Final free recall (FFR), where order isn’t recorded
- Tasks focusing on recall probability, not temporal organization
- Data where order information is unreliable
Standard sequence likelihood requires a specific order. For unordered data, we need to marginalize over all possible orderings.
For a recall set \(R_t = \{r_1, \ldots, r_k\}\), the true probability is:
\[P(R_t \mid \theta) = \sum_{\pi \in S_k} P(\pi(R_t) \mid \theta)\]
where \(S_k\) is the set of all \(k!\) permutations and \(\pi(R_t)\) is the ordered sequence under permutation \(\pi\).
Computing all \(k!\) permutations is infeasible for typical recall lengths:

- \(k = 5\): \(5! = 120\) permutations
- \(k = 10\): \(10! = 3{,}628{,}800\) permutations
- \(k = 15\): \(15! \approx 1.3 \times 10^{12}\) permutations
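The factorial blow-up is easy to verify with the standard library alone:

```python
import math

# Number of orderings of a recall set grows factorially with its size.
for k in (5, 10, 15):
    print(f"{k}! = {math.factorial(k):,}")
```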
Instead, we sample \(N\) random permutations and average:
\[\hat{P}(R_t \mid \theta) \approx \frac{k!}{N} \sum_{i=1}^{N} \prod_{j=1}^{k} P(\pi_i(r_j) \mid \pi_i(r_{1:j-1}), \theta)\]
where \(\pi_i \sim \text{Uniform}(S_k)\) are independent random permutations.
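A minimal NumPy sketch of this estimator under a toy conditional-recall model (the model, item weights, and function names are invented for illustration; jaxcmr computes the per-ordering probabilities from the fitted memory model instead):

```python
import math

import numpy as np

rng = np.random.default_rng(0)

def sequence_prob(order, weights):
    """Toy stand-in for P(ordered recall | theta): items are drawn without
    replacement in proportion to fixed weights (hypothetical model)."""
    p = 1.0
    remaining = weights.copy()
    for idx in order:
        p *= remaining[idx] / remaining.sum()
        remaining[idx] = 0.0
    return p

def permutation_averaged_prob(recall_set, weights, n_samples=50):
    """Monte Carlo estimate of P(recall set): average the sequence probability
    over random orderings, scaled by k! to match the sum over permutations."""
    k = len(recall_set)
    total = 0.0
    for _ in range(n_samples):
        order = [recall_set[i] for i in rng.permutation(k)]
        total += sequence_prob(order, weights)
    return math.factorial(k) * total / n_samples

weights = np.array([0.4, 0.3, 0.2, 0.1])
# With only two items, the exact sum over both orderings is tractable.
exact = sequence_prob([0, 2], weights) + sequence_prob([2, 0], weights)
estimate = permutation_averaged_prob([0, 2], weights, n_samples=500)
```

For small sets the estimate can be checked directly against the exact sum; for realistic recall lengths only the Monte Carlo average is feasible.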
The negative log-likelihood:
\[L(\theta) = -\sum_t \log \hat{P}(R_t \mid \theta)\]
Module: `jaxcmr.loss.set_permutation_likelihood`
| Parameter | Default | Description |
|---|---|---|
| `simulation_count` | 50 | Number of permutation samples per trial |
| Aspect | Sequence Likelihood | Set Permutation |
|---|---|---|
| Order information | Full | Ignored |
| Computation | \(O(k)\) per trial | \(O(N \cdot k)\) per trial |
| Captures CRP | Yes | No |
| Captures SPC | Yes | Yes |
| Contiguity | Yes | No |
| Recall probability | Yes | Yes |
By ignoring order, set permutation likelihood cannot capture:

- Lag-CRP (conditional response probability by lag)
- Forward asymmetry
- Temporal clustering
- Initiation effects (which item is recalled first)
Set permutation likelihood does capture:

- Serial position curve (which positions are recalled)
- Overall recall probability
- Item-level effects (primacy, recency)
```python
from jaxcmr.fitting import ScipyDE
from jaxcmr.loss.set_permutation_likelihood import MemorySearchLikelihoodFnGenerator
from jaxcmr.models.cmr import make_factory

# Create model factory
model_factory = make_factory(
    init_mfc, init_mcf, init_context, PositionalTermination
)

# Fit using set permutation likelihood
fitter = ScipyDE(
    dataset=data,
    features=None,
    base_params={"allow_repeated_recalls": False, ...},
    model_factory=model_factory,
    loss_fn_generator=MemorySearchLikelihoodFnGenerator,  # Same name, different module!
    hyperparams={
        "bounds": {
            "encoding_drift_rate": [0.0, 1.0],
            "start_drift_rate": [0.0, 1.0],
            ...
        },
        "num_steps": 1000,
    },
)

results = fitter.fit(trial_mask)
```

**Important:** The class name is the same (`MemorySearchLikelihoodFnGenerator`), but the module differs (`loss.set_permutation_likelihood` vs. `loss.sequence_likelihood`).
Increasing `simulation_count` reduces variance in the likelihood estimate but increases computation proportionally.
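A toy illustration of the tradeoff (not jaxcmr code; the per-permutation probabilities are invented stand-ins): the spread of the Monte Carlo average shrinks roughly as \(1/\sqrt{N}\):

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in per-permutation sequence probabilities (invented for illustration).
per_perm_probs = rng.uniform(0.01, 0.05, size=1000)

# Spread of the N-sample Monte Carlo average, measured over 500 replications.
stds = {}
for n in (10, 50, 250):
    estimates = [rng.choice(per_perm_probs, size=n).mean() for _ in range(500)]
    stds[n] = np.std(estimates)
    print(f"N={n:3d}: std of estimate = {stds[n]:.5f}")
```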
Permutations are pre-sampled for reproducibility. This means:

- Same permutations used across optimization iterations
- Gradient estimates are consistent
- Different random seeds give different results
Monte Carlo averaging produces an unbiased estimate:
\[\mathbb{E}[\hat{P}(R_t)] = P(R_t)\]
But the log of the average is not the average of the logs:
\[\mathbb{E}[\log \hat{P}(R_t)] \neq \log P(R_t)\]
This introduces a small bias, decreasing with sample count.
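A quick numerical check of this Jensen-gap bias (toy numbers, not model output): the average of \(\log \hat{P}\) sits below \(\log P\), and the gap narrows as the sample count grows:

```python
import numpy as np

rng = np.random.default_rng(0)

true_mean = 0.02
# Unbiased per-sample probabilities with mean 0.02 (invented for illustration).
samples = rng.uniform(0.0, 2 * true_mean, size=(10_000, 50))

results = {}
for n in (5, 50):
    # Log of the N-sample average, then averaged over 10,000 replications.
    results[n] = np.log(samples[:, :n].mean(axis=1)).mean()
    print(f"N={n:2d}: E[log P-hat] = {results[n]:.4f}  vs  log P = {np.log(true_mean):.4f}")
```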
If recall order were truly random (exchangeable), all permutations would have equal probability. The permutation-averaged likelihood tests whether the model’s predictions are consistent with the observed set, regardless of which order actually occurred.
An alternative is to compute the sum over item identities rather than permutations (sum-product algorithm). This is more efficient but requires model modifications. The permutation approach works with any model unchanged.
| Recall length \(k\) | Sequence likelihood | Set permutation (\(N=50\)) |
|---|---|---|
| 5 | 5 steps | 250 steps |
| 10 | 10 steps | 500 steps |
| 15 | 15 steps | 750 steps |
The overhead is a factor of \(N\) (sample count).
Permutation samples are independent and can be parallelized.
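A minimal JAX sketch of that parallelism (the scoring function and arrays are hypothetical, not jaxcmr internals): each pre-sampled permutation can be scored independently with `jax.vmap`:

```python
import jax
import jax.numpy as jnp

def sequence_logprob(perm, logits):
    """Toy order-dependent score for one permutation (hypothetical)."""
    logps = jnp.take(jax.nn.log_softmax(logits), perm)
    # Weight by output position so different orderings score differently.
    return (logps * jnp.arange(1, perm.shape[0] + 1)).sum()

perms = jnp.array([[0, 1, 2], [2, 0, 1], [1, 2, 0]])  # pre-sampled orderings
logits = jnp.array([0.5, 1.0, -0.2, 0.3])

# One call scores all permutations as a single batched computation.
logprobs = jax.vmap(sequence_logprob, in_axes=(0, None))(perms, logits)
```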
Pre-sampling permutations requires storage:

- \(T\) trials × \(N\) samples × \(k\) items
- Typically a few MB for standard datasets
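A back-of-envelope check of that storage estimate, using illustrative sizes (1000 trials, \(N = 50\) samples, \(k = 16\) items, 4-byte integer indices):

```python
trials, samples, items = 1000, 50, 16  # illustrative dataset sizes
bytes_needed = trials * samples * items * 4  # one int32 index per item
print(f"{bytes_needed / 1e6:.1f} MB")  # → 3.2 MB
```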