How models are composed from protocols

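Each jaxcmr model is composed of three core component types, and each type is defined by a protocol: an interface that any implementation can satisfy. Because model code depends only on these protocols, you can swap one implementation for another without changing the model code.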
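Here is the idea in plain Python. This is a hypothetical example, not jaxcmr code. A typing.Protocol lists the methods an object must have, and any class with matching methods satisfies it without inheriting from it. Integrator, SlowIntegrator, ReplaceIntegrator, and run are invented names used only for illustration.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class Integrator(Protocol):
    """Hypothetical protocol: mixes a new input into a running state."""

    def integrate(self, state: float, new_input: float) -> float: ...


class SlowIntegrator:
    # Satisfies Integrator structurally; it never inherits from it.
    def integrate(self, state: float, new_input: float) -> float:
        return 0.9 * state + 0.1 * new_input


class ReplaceIntegrator:
    # A drop-in alternative with completely different behavior.
    def integrate(self, state: float, new_input: float) -> float:
        return new_input


def run(integrator: Integrator, inputs: list[float]) -> float:
    """'Model code' that depends only on the protocol, not on a concrete class."""
    state = 0.0
    for x in inputs:
        state = integrator.integrate(state, x)
    return state


print(isinstance(SlowIntegrator(), Integrator))
print(run(SlowIntegrator(), [1.0, 1.0]))
print(run(ReplaceIntegrator(), [1.0, 1.0]))
```

Replacing SlowIntegrator with ReplaceIntegrator changes the result, and run itself does not change. The component protocols are designed to give you that same property.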

The MemorySearch Protocol

All memory search models implement the MemorySearch protocol:

```python
from nbdev.showdoc import show_doc
from jaxcmr.typing import MemorySearch

show_doc(MemorySearch)
```
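Later on this page, every model method returns a new model, and the call site rebinds the name (model = model.experience(item_id)). That pattern suits JAX's functional style. The sketch below reproduces it with a frozen dataclass. MemorySearchLike is a hypothetical, reduced stand-in for MemorySearch: it keeps only the experience, start_retrieving, and retrieve calls used later on this page, and the real protocol documented above has more members. ToyModel and study_then_recall are invented names.

```python
from dataclasses import dataclass, replace
from typing import Protocol


class MemorySearchLike(Protocol):
    """Hypothetical reduced stand-in for MemorySearch (the real protocol has more members)."""

    def experience(self, choice: int) -> "MemorySearchLike": ...
    def start_retrieving(self) -> "MemorySearchLike": ...
    def retrieve(self, choice: int) -> "MemorySearchLike": ...


@dataclass(frozen=True)
class ToyModel:
    # Immutable state: each method returns a NEW model instead of mutating self.
    studied: tuple[int, ...] = ()
    recalled: tuple[int, ...] = ()
    is_retrieving: bool = False

    def experience(self, choice: int) -> "ToyModel":
        return replace(self, studied=self.studied + (choice,))

    def start_retrieving(self) -> "ToyModel":
        return replace(self, is_retrieving=True)

    def retrieve(self, choice: int) -> "ToyModel":
        return replace(self, recalled=self.recalled + (choice,))


def study_then_recall(model: MemorySearchLike, items: list[int], recalls: list[int]) -> MemorySearchLike:
    for item in items:
        model = model.experience(item)  # rebind the name; the previous model is left unchanged
    model = model.start_retrieving()
    for item in recalls:
        model = model.retrieve(item)
    return model


start = ToyModel()
end = study_then_recall(start, [1, 2, 3], [3, 2])
print(start.studied, end.studied, end.recalled)  # () (1, 2, 3) (3, 2)
```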

Component Protocols

Models are built from three component types:

Context

Maintains temporal context that drifts during study and retrieval:

```python
from jaxcmr.typing import Context

show_doc(Context)
```
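To show what "drift" means, the sketch below uses the classic CMR context update. It mixes a new input c_in into the current context c and rescales so the vector stays unit length: c_new = rho * c + beta * c_in, where rho = sqrt(1 + beta^2 * ((c . c_in)^2 - 1)) - beta * (c . c_in). Two points are assumptions: that TemporalContext uses exactly this rule, and that beta corresponds to encoding_drift_rate during study. The helpers drift and norm are hypothetical.

```python
import math


def drift(context: list[float], context_input: list[float], beta: float) -> list[float]:
    """One classic CMR-style context update; both vectors are assumed to be unit length."""
    dot = sum(c * x for c, x in zip(context, context_input))
    # rho is chosen so the updated context keeps unit length.
    rho = math.sqrt(1 + beta**2 * (dot**2 - 1)) - beta * dot
    return [rho * c + beta * x for c, x in zip(context, context_input)]


def norm(v: list[float]) -> float:
    return math.sqrt(sum(x * x for x in v))


# Hypothetical 3-unit context: a start state, then two orthogonal item inputs.
context = [1.0, 0.0, 0.0]
for item_input in ([0.0, 1.0, 0.0], [0.0, 0.0, 1.0]):
    context = drift(context, item_input, beta=0.5)  # beta stands in for encoding_drift_rate
    print([round(x, 3) for x in context], round(norm(context), 6))
```

After the first study event the context is roughly [0.866, 0.5, 0.0]. Older inputs fade with each event while the norm stays at 1, which is why context can serve as a recency-weighted cue.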

Memory

Stores associations between item and context representations and retrieves them from a cue:

```python
from jaxcmr.typing import Memory

show_doc(Memory)
```
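A linear associative memory is usually a weight matrix. It is trained with Hebbian outer-product learning and read out by projecting a cue through the matrix. The sketch assumes LinearMemory works along these lines, since this page does not show its internals. The helper names and the 0.1 background weight are illustrative. learning_rate mirrors the parameter of that name in the CMR example further down.

```python
def outer_product_learn(memory, input_pattern, output_pattern, learning_rate):
    """Return a new weight matrix with learning_rate * outer(input, output) added."""
    return [
        [w + learning_rate * a * b for w, b in zip(row, output_pattern)]
        for row, a in zip(memory, input_pattern)
    ]


def probe(memory, cue):
    """Read out by projecting a cue through the matrix: output[j] = sum_i cue[i] * M[i][j]."""
    n_out = len(memory[0])
    return [sum(c * memory[i][j] for i, c in enumerate(cue)) for j in range(n_out)]


# Hypothetical MFC-style memory: 2 items (rows) x 2 context units (columns).
mfc = [[0.1, 0.1], [0.1, 0.1]]
mfc = outer_product_learn(mfc, input_pattern=[1.0, 0.0], output_pattern=[0.0, 1.0], learning_rate=0.5)
print(probe(mfc, [1.0, 0.0]))  # item 1 now retrieves more of context unit 2
```

The same kind of matrix can serve as MCF if you swap the roles of input and output (context as cue, items as output).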

TerminationPolicy

Determines when to stop recalling:

```python
from jaxcmr.typing import TerminationPolicy

show_doc(TerminationPolicy)
```
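The classic CMR positional stopping rule makes the probability of stopping grow exponentially with output position j: P(stop) = scale * exp(growth * j). The sketch below relies on three assumptions: that PositionalTermination follows this rule, that its parameters are stop_probability_scale and stop_probability_growth, and that NoStopTermination always returns 0. The interface is also simplified. In the example further down, the real stop_probability takes the model as its argument, not an integer position. StopRule, PositionalStop, and NeverStop are hypothetical names.

```python
import math
from typing import Protocol


class StopRule(Protocol):
    """Hypothetical reduced termination interface: P(stop) at a given output position."""

    def stop_probability(self, output_position: int) -> float: ...


class PositionalStop:
    """Assumed CMR-style positional rule: scale * exp(growth * position), capped at 1."""

    def __init__(self, scale: float, growth: float) -> None:
        self.scale = scale
        self.growth = growth

    def stop_probability(self, output_position: int) -> float:
        return min(1.0, self.scale * math.exp(self.growth * output_position))


class NeverStop:
    """Assumed no-stop rule: never chooses to stop, so recall ends only when items run out."""

    def stop_probability(self, output_position: int) -> float:
        return 0.0


rules: list[StopRule] = [PositionalStop(scale=0.05, growth=0.2), NeverStop()]
for rule in rules:
    print(type(rule).__name__, [round(rule.stop_probability(j), 3) for j in (0, 10)])
```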

Composition

A CMR model combines these components and uses one Memory instance for each direction of association:

MemorySearch
├── Context (temporal context state)
├── Memory (MFC: item → context)
├── Memory (MCF: context → item)
└── TerminationPolicy (when to stop)
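The sketch below shows how the pieces in the tree work together during one recall step. It follows the classic CMR retrieval rule, which is an assumption about jaxcmr's CMR and not a description of its code. The current context is projected through MCF to give each item a support value. The termination policy supplies P(stop). The remaining probability is divided among unrecalled items in proportion to support raised to the choice_sensitivity exponent. In the full cycle, the recalled item's MFC output would then drive the next context update; that step is omitted here. outcome_probabilities and its argument names are hypothetical.

```python
def outcome_probabilities(context, mcf, stop_probability, choice_sensitivity, recalled):
    """Hypothetical CMR-style recall step.

    Returns [P(stop), P(item 1), ..., P(item n)] with 1-based item numbering.
    """
    n_items = len(mcf[0])
    # MCF (context -> item): each item's support is the context projected through its column.
    supports = [sum(c * mcf[k][i] for k, c in enumerate(context)) for i in range(n_items)]
    # Exclude already-recalled items; sharpen the rest with the choice_sensitivity exponent.
    # (Assumes non-negative supports, as with non-negative context and weights.)
    powered = [
        0.0 if (i + 1) in recalled else s ** choice_sensitivity
        for i, s in enumerate(supports)
    ]
    total = sum(powered)
    if total == 0.0:
        # Nothing left to recall: stopping is the only possible outcome.
        return [1.0] + [0.0] * n_items
    # The termination policy's stop probability is split off first; items share the rest.
    return [stop_probability] + [(1.0 - stop_probability) * p / total for p in powered]


# Hypothetical 2-unit context and a 2 x 3 MCF matrix (rows: context units, columns: items).
context = [0.6, 0.8]
mcf = [[0.5, 0.1, 0.1], [0.1, 0.5, 0.1]]
probs = outcome_probabilities(context, mcf, stop_probability=0.05, choice_sensitivity=0.6, recalled={2})
print([round(p, 3) for p in probs])
```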

Swapping Components

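CMR accepts factory functions for creating each component. In this example only termination_policy_create_fn differs between the two models; the list length and parameters are identical: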

```python
from jaxcmr.models.cmr import CMR
from jaxcmr.components.termination import PositionalTermination, NoStopTermination

# Shared parameters: the two models below differ only in their termination policy
params = {
    "encoding_drift_rate": 0.5,
    "start_drift_rate": 0.5,
    "recall_drift_rate": 0.5,
    "learning_rate": 0.5,
    "primacy_scale": 2.0,
    "primacy_decay": 0.8,
    "shared_support": 0.05,
    "item_support": 0.25,
    "choice_sensitivity": 0.6,
    "stop_probability_scale": 0.05,
    "stop_probability_growth": 0.2,
}

# Standard CMR with positional termination
model_positional = CMR(
    list_length=16,
    parameters=params,
    termination_policy_create_fn=PositionalTermination,
)

# CMR with no-stop termination (recalls until exhausted)
model_no_stop = CMR(
    list_length=16,
    parameters=params,
    termination_policy_create_fn=NoStopTermination,
)

print("Positional termination model created")
print("No-stop termination model created")
```
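The constructor above receives a class (PositionalTermination) where it expects a factory function. A Python class is itself a callable that builds an instance, so it works as a factory. The sketch below shows the pattern with invented names (ToyCMR, ScaledStop, NeverStop). The argument name termination_policy_create_fn comes from the example, but the factory signature (list_length, parameters) is an assumption; this page does not show jaxcmr's actual factory signature.

```python
from typing import Callable, Mapping, Protocol


class StopPolicy(Protocol):
    def stop_probability(self, output_position: int) -> float: ...


# Hypothetical factory signature: (list_length, parameters) -> policy.
PolicyFactory = Callable[[int, Mapping[str, float]], StopPolicy]


class ScaledStop:
    def __init__(self, list_length: int, parameters: Mapping[str, float]) -> None:
        self.scale = parameters["stop_probability_scale"]

    def stop_probability(self, output_position: int) -> float:
        return self.scale


class NeverStop:
    def __init__(self, list_length: int, parameters: Mapping[str, float]) -> None:
        pass  # needs no parameters, but accepts them to match the factory signature

    def stop_probability(self, output_position: int) -> float:
        return 0.0


class ToyCMR:
    def __init__(self, list_length: int, parameters: Mapping[str, float],
                 termination_policy_create_fn: PolicyFactory = ScaledStop) -> None:
        # The model calls the factory itself; passing a class works because calling it builds an instance.
        self.termination_policy = termination_policy_create_fn(list_length, parameters)


params = {"stop_probability_scale": 0.05}
default = ToyCMR(16, params)
no_stop = ToyCMR(16, params, NeverStop)
# A lambda can act as a factory too, here overriding one parameter for this component only.
custom = ToyCMR(16, params, lambda n, p: ScaledStop(n, {**p, "stop_probability_scale": 0.2}))
print(default.termination_policy.stop_probability(0), no_stop.termination_policy.stop_probability(0))
```

Passing a factory instead of a ready-made policy lets the model build its components from its own list_length and parameters. This keeps construction in one place.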

Comparing Termination Policies

Both models study the same list and recall the same items, so any difference in their stop probabilities comes from the termination policy alone:

```python
# Study items 1 through 16
for item_id in range(1, 17):
    model_positional = model_positional.experience(item_id)
    model_no_stop = model_no_stop.experience(item_id)

# Switch both models from study to retrieval
model_positional = model_positional.start_retrieving()
model_no_stop = model_no_stop.start_retrieving()

# Compare stop probabilities before any item has been recalled
print("Stop probabilities at output position 0:")
print(f"  Positional: {model_positional.termination_policy.stop_probability(model_positional):.3f}")
print(f"  No-stop: {model_no_stop.termination_policy.stop_probability(model_no_stop):.3f}")

# Recall the last three studied items (16, 15, 14) and check how stop probability changes
for i in range(3):
    model_positional = model_positional.retrieve(16 - i)
    model_no_stop = model_no_stop.retrieve(16 - i)

print("\nAfter 3 recalls:")
print(f"  Positional: {model_positional.termination_policy.stop_probability(model_positional):.3f}")
print(f"  No-stop: {model_no_stop.termination_policy.stop_probability(model_no_stop):.3f}")
```
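Suppose PositionalTermination follows the classic CMR positional rule. This is an assumption; the page does not state its formula. In that case you can work out by hand the values the positional model should print. Take the parameters above: scale θ_s = stop_probability_scale = 0.05 and growth θ_r = stop_probability_growth = 0.2. Also assume that three recalls correspond to output position j = 3.

```latex
P(\text{stop} \mid j) = \theta_s \, e^{\theta_r j}

P(\text{stop} \mid 0) = 0.05 \, e^{0} = 0.050

P(\text{stop} \mid 3) = 0.05 \, e^{0.2 \times 3} = 0.05 \, e^{0.6} \approx 0.05 \times 1.822 \approx 0.091
```

Under the same assumptions, a policy that never chooses to stop would print 0.000 at both points. The stop probability rises only under the positional rule.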

Available Components

jaxcmr provides these built-in implementations:

| Component | Implementations |
|---|---|
| Context | TemporalContext |
| Memory | LinearMemory (MFC and MCF variants) |
| TerminationPolicy | PositionalTermination, NoStopTermination, SupportRatioTermination, RetrievalDependentTermination |

Model variants in jaxcmr.models often customize one or more of these components.