Set Permutation Likelihood

Monte Carlo likelihood for unordered recall data

Set Permutation Likelihood computes the probability of observing a recall set (ignoring order) by averaging over all possible orderings. This is essential when recall order is not recorded or not meaningful.

The Problem

Sometimes only the set of recalled items matters:

  • Final free recall (FFR) where order isn’t recorded
  • Tasks focusing on recall probability, not temporal organization
  • Data where order information is unreliable

Standard sequence likelihood requires a specific order. For unordered data, we need to marginalize over all possible orderings.

Mathematical Specification

True Marginal Probability

For a recall set \(R_t = \{r_1, \ldots, r_k\}\), the true probability is:

\[P(R_t \mid \theta) = \sum_{\pi \in S_k} P(\pi(R_t) \mid \theta)\]

where \(S_k\) is the set of all \(k!\) permutations and \(\pi(R_t)\) is the ordered sequence under permutation \(\pi\).

Computational Challenge

Computing all \(k!\) permutations is infeasible for typical recall lengths:

  • \(k = 5\): \(5! = 120\) permutations
  • \(k = 10\): \(10! = 3,628,800\) permutations
  • \(k = 15\): \(15! \approx 1.3 \times 10^{12}\) permutations
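The growth is easy to verify directly with Python's standard library:

```python
import math

# Exact ordered-sum cost for the recall lengths above: one sequence
# probability must be evaluated per ordering, so the cost is k!.
counts = {k: math.factorial(k) for k in (5, 10, 15)}
```

At \(k = 15\) the exact sum would require over a trillion sequence evaluations per trial, which motivates the sampling approach below.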

Monte Carlo Approximation

Instead, we sample \(N\) random permutations and average:

\[\hat{P}(R_t \mid \theta) = \frac{1}{N} \sum_{i=1}^{N} \prod_{j=1}^{k} P(\pi_i(r_j) \mid \pi_i(r_{1:j-1}), \theta)\]

where \(\pi_i \sim \text{Uniform}(S_k)\) are independent random permutations. Because the permutations are drawn uniformly, this average estimates the mean ordering probability \(P(R_t \mid \theta) / k!\); the \(k!\) factor does not depend on \(\theta\), so it shifts the log-likelihood by a constant and leaves the fitted parameters unchanged.
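The estimator can be sanity-checked outside the library with a toy ordered-recall model (everything below is illustrative, not the jaxcmr API): for a small set, the Monte Carlo average times \(k!\) converges to the exact sum over orderings.

```python
import random
from itertools import permutations

def toy_sequence_prob(seq):
    """Toy ordered-recall model over items {0, 1, 2, 3}: item j has weight
    j + 1, and the next recall is drawn in proportion to the weights of the
    items still available."""
    remaining, prob = {0, 1, 2, 3}, 1.0
    for item in seq:
        total = sum(j + 1 for j in remaining)
        prob *= (item + 1) / total
        remaining.remove(item)
    return prob

recall_set = [0, 2, 3]
# Exact marginal: sum the sequence probability over all k! = 6 orderings.
exact = sum(toy_sequence_prob(p) for p in permutations(recall_set))

# Monte Carlo: average over uniformly sampled orderings.
rng = random.Random(0)
N = 20000
draws = []
for _ in range(N):
    perm = recall_set[:]
    rng.shuffle(perm)
    draws.append(toy_sequence_prob(perm))
estimate = sum(draws) / N  # approximates exact / k!
```

Multiplying the sample mean by \(k! = 6\) recovers the exact marginal to within Monte Carlo noise.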

Loss Function

The negative log-likelihood:

\[L(\theta) = -\sum_t \log \hat{P}(R_t \mid \theta)\]

Implementation

Module: jaxcmr.loss.set_permutation_likelihood

Code
from jaxcmr.loss.set_permutation_likelihood import MemorySearchLikelihoodFnGenerator

Key Parameters

Parameter          Default   Description
simulation_count   50        Number of permutation samples per trial

How It Works

  1. Pre-sample random permutations for reproducibility
  2. For each trial:
    • Apply each permutation to the recall set
    • Compute sequence probability for each ordering
    • Average across permutations
  3. Sum negative log-probabilities across trials
Code
import jax.numpy as jnp

def predict_trial(self, model, present, recalls, permutations):
    """Compute average probability across permutation samples."""
    probs = []
    for perm in permutations:
        # Reorder the recall set under this sampled permutation.
        ordered_recalls = recalls[perm]
        probs.append(sequence_probability(model, present, ordered_recalls))
    # Stack the scalar probabilities into an array before averaging.
    return jnp.mean(jnp.stack(probs))

When to Use

Use Set Permutation When:

  • Order not recorded: Data only contains which items were recalled
  • Final free recall (FFR): Extended recall period without order tracking
  • Order is unreliable: Noisy or incomplete order data
  • Theoretical focus: Testing recall probability, not temporal organization

Don’t Use When:

  • Order is meaningful: Testing contiguity, forward asymmetry
  • Full sequence data: Standard likelihood is more informative
  • Computational constraints: Monte Carlo adds overhead

Comparison with Standard Likelihood

Aspect               Sequence likelihood   Set permutation
Order information    Full                  Ignored
Computation          \(O(k)\) per trial    \(O(N \cdot k)\) per trial
Captures CRP         Yes                   No
Captures SPC         Yes                   Yes
Contiguity           Yes                   No
Recall probability   Yes                   Yes

What’s Lost

By ignoring order, set permutation likelihood cannot capture:

  • Lag-CRP (conditional response probability by lag)
  • Forward asymmetry
  • Temporal clustering
  • Initiation effects (which item is recalled first)

What’s Preserved

Set permutation likelihood does capture:

  • Serial position curve (which positions are recalled)
  • Overall recall probability
  • Item-level effects (primacy, recency)

Usage Example

Code
from jaxcmr.fitting import ScipyDE
from jaxcmr.loss.set_permutation_likelihood import MemorySearchLikelihoodFnGenerator
from jaxcmr.models.cmr import make_factory

# Create model factory
model_factory = make_factory(
    init_mfc, init_mcf, init_context, PositionalTermination
)

# Fit using set permutation likelihood
fitter = ScipyDE(
    dataset=data,
    features=None,
    base_params={"allow_repeated_recalls": False, ...},
    model_factory=model_factory,
    loss_fn_generator=MemorySearchLikelihoodFnGenerator,  # Same name, different module!
    hyperparams={
        "bounds": {
            "encoding_drift_rate": [0.0, 1.0],
            "start_drift_rate": [0.0, 1.0],
            ...
        },
        "num_steps": 1000,
    },
)

results = fitter.fit(trial_mask)

Important: Note that the class name is the same (MemorySearchLikelihoodFnGenerator) but the module differs (loss.set_permutation_likelihood vs loss.sequence_likelihood).

Variance Reduction

More Samples

Increasing simulation_count reduces variance but increases computation:

Code
# Lower variance, slower
generator = MemorySearchLikelihoodFnGenerator(...)
generator.simulation_count = 100  # Default is 50
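The \(1/N\) variance reduction can be illustrated outside the library with plain NumPy and a made-up set of ordering probabilities (nothing below is jaxcmr code):

```python
import numpy as np

rng = np.random.default_rng(0)
# Made-up ordering probabilities for one recall set (k = 3, so 3! = 6 values).
ordering_probs = np.array([0.022, 0.027, 0.029, 0.057, 0.040, 0.067])

def estimate_var(num_samples, reps=2000):
    """Variance of the permutation-averaged estimate across replications."""
    draws = rng.choice(ordering_probs, size=(reps, num_samples))
    return draws.mean(axis=1).var()

# Monte Carlo variance shrinks roughly as 1 / num_samples.
var_small = estimate_var(10)
var_large = estimate_var(100)
```

Raising the sample count tenfold cuts the estimator's variance by roughly the same factor, at proportionally higher cost per loss evaluation.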

Pre-Sampling

Permutations are pre-sampled for reproducibility. This means:

  • The same permutations are used across optimization iterations
  • Gradient estimates are consistent
  • Different random seeds give different results
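The idea can be sketched in NumPy (jaxcmr itself would draw from JAX's PRNG; `presample_permutations` is a hypothetical helper, not a library function):

```python
import numpy as np

def presample_permutations(k, num_samples, seed=0):
    """Draw a fixed bank of permutations once, reused at every loss evaluation."""
    rng = np.random.default_rng(seed)
    return np.stack([rng.permutation(k) for _ in range(num_samples)])

# Same seed -> identical permutation bank, so the loss surface is
# deterministic across optimization iterations.
perms_a = presample_permutations(5, 50, seed=0)
perms_b = presample_permutations(5, 50, seed=0)
```

Because the bank is fixed, the optimizer sees a smooth, repeatable objective rather than a freshly noisy one at every iteration.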

Theoretical Notes

Unbiased Estimation

Monte Carlo averaging produces an unbiased estimate of the mean ordering probability:

\[\mathbb{E}[\hat{P}(R_t)] = \frac{P(R_t)}{k!}\]

But the log of the average is not the average of the logs. By Jensen’s inequality,

\[\mathbb{E}[\log \hat{P}(R_t)] \leq \log \mathbb{E}[\hat{P}(R_t)]\]

so the log-likelihood estimate is biased downward. The bias shrinks as the sample count grows.
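The direction of the bias follows from Jensen's inequality and can be checked on any positive sample:

```python
import numpy as np

# Four made-up Monte Carlo draws of one trial's ordering probability.
samples = np.array([0.02, 0.03, 0.06, 0.04])
avg_of_logs = np.log(samples).mean()  # what naive log-averaging would give
log_of_avg = np.log(samples.mean())   # log of the Monte Carlo estimate
# Jensen: avg_of_logs <= log_of_avg, with equality only if all draws agree.
```

The gap between the two quantities narrows as the draws concentrate around their mean, which is what increasing `simulation_count` accomplishes.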

Connection to Exchangeability

If recall order were truly random (exchangeable), all permutations would have equal probability. The permutation-averaged likelihood tests whether the model’s predictions are consistent with the observed set, regardless of which order actually occurred.

Relationship to Sum-Product

An alternative is to compute the sum over item identities rather than permutations (sum-product algorithm). This is more efficient but requires model modifications. The permutation approach works with any model unchanged.

Computational Considerations

Scaling

Recall length \(k\)   Sequence likelihood   Set permutation (\(N = 50\))
5                     5 steps               250 steps
10                    10 steps              500 steps
15                    15 steps              750 steps

The overhead is a factor of \(N\) (sample count).

Parallelization

Permutation samples are independent and can be parallelized:

Code
from jax import vmap

# Vectorize sequence_probability over the permutation axis (axis 0)
probs = vmap(sequence_probability, in_axes=(None, None, 0))(
    model, present, permuted_recalls
)

Memory

Pre-sampling permutations requires storage:

  • \(T\) trials × \(N\) samples × \(k\) items
  • Typically a few MB for standard datasets
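As a back-of-envelope check, with hypothetical dataset sizes (1,000 trials, the default 50 samples, recall length 15, 4-byte integer indices):

```python
# Storage for a pre-sampled permutation bank: T trials x N samples x k items,
# each entry a 4-byte position index. All sizes here are illustrative.
trials, samples, items, bytes_per_index = 1000, 50, 15, 4
total_bytes = trials * samples * items * bytes_per_index  # 3,000,000 bytes
```

That is about 3 MB, consistent with the "few MB" figure above.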