Custom Analyses

Implementing your own analyses

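Analyses in jaxcmr follow a consistent pattern: a function takes an array of recall sequences with shape (trials, max_recalls) plus the list length, and returns a JAX array of results. Writing custom analyses to the same pattern lets them integrate with the rest of the library.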

Analysis Pattern

Most analyses follow this structure:

Code
import jax.numpy as jnp
from jaxtyping import Array, Integer, Float

def my_analysis(
    recalls: Integer[Array, "trials max_recalls"],
    list_length: int,
    **kwargs
) -> Float[Array, "..."]:
    """
    Compute my custom analysis.

    Args:
        recalls: Recall sequences (1-indexed, 0 = padding)
        list_length: Number of items in each list
        **kwargs: Additional parameters

    Returns:
        Analysis results
    """
    # Your implementation here
    ...

Example: First Recall Probability

The probability that each serial position is recalled first, computed over trials with at least one recall:

Code
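import jax.numpy as jnp
from jaxtyping import Array, Float, Integer

def first_recall_probability(
    recalls: Integer[Array, "trials max_recalls"],
    list_length: int,
) -> Float[Array, "list_length"]:
    """Probability that each serial position is recalled first.

    Trials with no recall (first entry 0) are excluded from the denominator.
    """
    # First recall of each trial (1-indexed; 0 means nothing was recalled)
    first_recalls = recalls[:, 0]

    # Count how often each position is first, without a Python loop:
    # (trials, 1) == (list_length,) broadcasts to (trials, list_length)
    positions = jnp.arange(1, list_length + 1)
    counts = jnp.sum(first_recalls[:, None] == positions, axis=0)

    # Normalize; the max(..., 1) guard avoids 0/0 when no trial has a recall
    return counts / jnp.maximum(jnp.sum(counts), 1)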
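The function above normalizes by the number of trials that contain a recall, so empty trials drop out. To get probabilities over all trials instead, divide by the trial count; the leftover mass is then the probability of recalling nothing. A minimal sketch of that variant follows. The name first_recall_probability_all_trials is hypothetical, and it counts with jnp.bincount, whose fixed length argument keeps the output shape static under jit.

```python
import jax.numpy as jnp


def first_recall_probability_all_trials(recalls, list_length):
    """P(each serial position is recalled first), normalized over all trials.

    Hypothetical variant: trials with no recall stay in the denominator.
    """
    first_recalls = recalls[:, 0]
    # A fixed length keeps the shape static; bin 0 collects empty trials
    counts = jnp.bincount(first_recalls, length=list_length + 1)
    n_trials = recalls.shape[0]
    position_probs = counts[1:] / n_trials  # serial positions 1..list_length
    no_recall_prob = counts[0] / n_trials  # trials whose first entry is padding
    return position_probs, no_recall_prob


recalls = jnp.array([
    [3, 1, 0],
    [3, 2, 0],
    [0, 0, 0],  # nothing recalled
    [1, 0, 0],
])
probs, p_none = first_recall_probability_all_trials(recalls, list_length=3)
# probs -> [0.25, 0.0, 0.5], p_none -> 0.25
```

Here the position probabilities and the no-recall probability together sum to 1, whereas in the conditional version the position probabilities alone sum to 1.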

JIT Compatibility

For performance, ensure your analysis is JIT-compatible:

Code
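from functools import partial

import jax

# list_length sets array shapes (e.g. jnp.zeros(list_length)), so mark it static;
# JAX recompiles once for each distinct value
@partial(jax.jit, static_argnames="list_length")
def my_fast_analysis(recalls, list_length):
    ...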
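Arguments that determine array shapes, like list_length, must be static under jax.jit: a traced argument has no concrete value, so a call such as jnp.arange(1, list_length + 1) would fail. The sketch below compiles a complete analysis this way. The serial position curve function spc is hypothetical and not necessarily how jaxcmr implements its own; JAX compiles it once per distinct list_length and input shape.

```python
from functools import partial

import jax
import jax.numpy as jnp


# list_length sets the output shape, so it is a static (compile-time) argument
@partial(jax.jit, static_argnames="list_length")
def spc(recalls, list_length):
    """Hypothetical serial position curve: P(recall) for each serial position."""
    positions = jnp.arange(1, list_length + 1)
    # (trials, max_recalls, 1) == (list_length,) -> (trials, max_recalls, list_length)
    matches = recalls[:, :, None] == positions
    # A position counts as recalled if it appears at any output position
    recalled = jnp.any(matches, axis=1)  # (trials, list_length)
    return jnp.mean(recalled, axis=0)


recalls = jnp.array([
    [3, 1, 2, 0],
    [2, 0, 0, 0],
])
curve = spc(recalls, list_length=4)  # -> [0.5, 1.0, 0.5, 0.0]
```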

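Avoid Python loops and if statements over array data; use vectorized JAX operations (broadcasting, jnp.sum, jnp.where) instead. Under jit, a Python loop is unrolled into the compiled program, and an if on an array value raises an error because traced values have no concrete truth value.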
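For example, the sketch below computes the mean serial position of the first recall, skipping trials with no recall, in two ways. Both function names are hypothetical. The loop version runs eagerly but would fail under jax.jit because of the if on an array value; the vectorized version replaces the if with a boolean mask and jnp.where and compiles cleanly.

```python
import jax
import jax.numpy as jnp


def mean_first_position_loop(recalls):
    # Python loop + if on array values: works eagerly, raises under jax.jit
    total, n = 0.0, 0
    for trial in recalls:
        if trial[0] > 0:
            total += trial[0]
            n += 1
    return total / n


@jax.jit
def mean_first_position(recalls):
    first = recalls[:, 0]
    valid = first > 0  # mask out trials with no recall (first entry is padding)
    total = jnp.sum(jnp.where(valid, first, 0))
    # max(..., 1) avoids dividing by zero if no trial has a recall
    return total / jnp.maximum(jnp.sum(valid), 1)


recalls = jnp.array([[3, 1, 0], [0, 0, 0], [1, 2, 0]])
loop_result = mean_first_position_loop(recalls)  # -> 2.0
fast_result = mean_first_position(recalls)  # -> 2.0
```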

Vectorization

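Use jax.vmap to apply an analysis to each subject or condition. vmap maps over a leading array axis, so per-subject recall arrays must first be stacked into a single array of shape (subjects, trials, max_recalls), which requires every subject to have the same number of trials and recall slots: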

Code
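import jax

# subject_data: (subjects, trials, max_recalls). The lambda fixes list_length
# so vmap maps only over the leading subject axis.
subject_results = jax.vmap(lambda r: my_analysis(r, list_length))(subject_data)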
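Stacking requires equal shapes. When subjects have different numbers of trials, one option is to pad each subject with empty (all-zero) trials before stacking; subjects must already share the same max_recalls. Padding is only safe for analyses that ignore empty trials, such as a first-recall probability normalized over trials with a recall; an analysis averaged over all trials would be biased by the padding. A minimal sketch, with a hypothetical pad_trials helper and a compact loop-free restatement of first_recall_probability:

```python
import jax
import jax.numpy as jnp


def first_recall_probability(recalls, list_length):
    # Loop-free: compare each first recall against every serial position
    positions = jnp.arange(1, list_length + 1)
    counts = jnp.sum(recalls[:, 0][:, None] == positions, axis=0)
    return counts / jnp.maximum(jnp.sum(counts), 1)


def pad_trials(recalls, n_trials):
    """Hypothetical helper: append all-zero (empty) trials up to n_trials rows."""
    return jnp.pad(recalls, ((0, n_trials - recalls.shape[0]), (0, 0)))


subjects = [
    jnp.array([[2, 1, 0], [2, 3, 0]]),  # subject A: 2 trials
    jnp.array([[1, 0, 0]]),             # subject B: 1 trial
]
n_trials = max(s.shape[0] for s in subjects)
subject_data = jnp.stack([pad_trials(s, n_trials) for s in subjects])  # (2, 2, 3)

list_length = 3
subject_results = jax.vmap(lambda r: first_recall_probability(r, list_length))(subject_data)
# subject_results -> [[0, 1, 0], [1, 0, 0]] (one row per subject)
```

Subject B's padded trial has first entry 0, so it adds nothing to the counts and leaves that subject's probabilities unchanged.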