import matplotlib.pyplot as plt
from jax import numpy as jnp
from jaxcmr.components.context import TemporalContext
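from jaxcmr.typing import Context
# Illustrative parameters (assumed values; this page does not specify them)
item_count = 10
drift_rate = 0.5
Context representation for memory search models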
In retrieved-context models, temporal context is a vector of continuous values, one per context unit. TemporalContext initializes it with the start-of-list unit set to 1 and one unit per study item set to 0.
Context evolves as items are encoded and retrieved, integrating new input at each step:
\[c_i = \rho_i c_{i-1} + \beta \, c_{i}^{IN}\]
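where \(\beta\) (the drift rate, passed as drift_rate in the code on this page) controls how strongly the new input \(c_i^{IN}\) is integrated, and \(\rho_i\) rescales the previous context so that \(c_i\) keeps unit length (assuming \(c_{i-1}\) and \(c_i^{IN}\) are unit vectors):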
\[\rho_i = \sqrt{1 + \beta^2\bigl[(c_{i-1} \cdot c^{IN}_i)^2 - 1\bigr]} - \beta(c_{i-1} \cdot c^{IN}_i)\]
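The expression for \(\rho_i\) follows from requiring \(\lVert c_i \rVert = 1\), assuming both \(c_{i-1}\) and \(c_i^{IN}\) have unit length. Writing \(d = c_{i-1} \cdot c^{IN}_i\) and taking the positive root of the resulting quadratic:
\[
\begin{aligned}
\lVert c_i \rVert^2 &= \rho_i^2 + 2\rho_i \beta d + \beta^2 = 1 \\
\rho_i^2 + 2\beta d \, \rho_i + (\beta^2 - 1) &= 0 \\
\rho_i &= \sqrt{1 + \beta^2\bigl[d^2 - 1\bigr]} - \beta d
\end{aligned}
\]
When the input is orthogonal to the current context (\(d = 0\)), as with the one-hot item inputs in the examples on this page, this reduces to \(\rho_i = \sqrt{1 - \beta^2}\), so every earlier contribution shrinks by that factor at each step.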
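Because each step rescales the previous context by \(\rho_i\) before adding the new input, older inputs fade gradually, and the context vector carries a recency gradient that reflects the order in which items were presented.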
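The update rule and its recency gradient can be checked numerically. Below is a minimal sketch in plain NumPy rather than the jaxcmr API: integrate and study are hypothetical standalone helpers written for illustration, and they assume unit-length, one-hot item inputs like the study loop on this page. It confirms that context stays at unit length, that after \(n\) items the unit for item \(j\) equals \(\beta(1-\beta^2)^{(n-j)/2}\), and that larger drift rates retain less of the previous context per step.

```python
import numpy as np


def integrate(context, context_input, beta):
    """One drift step: c_i = rho_i * c_{i-1} + beta * c_i^IN.

    Assumes both vectors have unit length, so rho_i keeps c_i at unit length.
    """
    dot = float(context @ context_input)
    rho = np.sqrt(1 + beta**2 * (dot**2 - 1)) - beta * dot
    return rho * context + beta * context_input


def study(item_count, beta):
    """Present items 1..item_count; row 0 of the result is the initial context."""
    context = np.zeros(item_count + 1)
    context[0] = 1.0  # start-of-list unit
    history = [context]
    for i in range(item_count):
        context_input = np.zeros(item_count + 1)
        context_input[i + 1] = 1.0  # one-hot input for item i + 1
        context = integrate(context, context_input, beta)
        history.append(context)
    return np.stack(history)


n, beta = 5, 0.5
history = study(n, beta)
final = history[-1]
print(np.round(final, 3))     # item units grow from item 1 to item n
print(np.linalg.norm(final))  # 1.0 up to floating-point error

# Closed form for orthogonal one-hot inputs: start unit = (1 - beta^2)^(n/2),
# item j's unit = beta * (1 - beta^2)^((n - j)/2)
expected = np.array(
    [(1 - beta**2) ** (n / 2)]
    + [beta * (1 - beta**2) ** ((n - j) / 2) for j in range(1, n + 1)]
)
print(np.allclose(final, expected))  # True

# Per-step retention sqrt(1 - beta^2): higher drift rates forget faster
for b in (0.2, 0.5, 0.8):
    print(f"beta={b}: retention per step = {np.sqrt(1 - b**2):.3f}")
# prints 0.980, 0.866, 0.600
```

The most recent item always sits at exactly \(\beta\), while every older unit, including the start-of-list unit, has been multiplied by \(\sqrt{1-\beta^2}\) once per subsequent step.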
Initialize context and integrate a new input:
# Initialize context (start-of-list unit = 1, all others = 0)
context = TemporalContext.init(item_count)
print(f"Initial context: {context.state}")
print(f"Context size: {context.size}")
# Integrate input for first item
context_input = jnp.zeros(context.size).at[1].set(1.0) # Unit for item 1
new_context = context.integrate(context_input, drift_rate)
print(f"After item 1: {new_context.state}")
Watch how context evolves as items are studied. The start-of-list unit (index 0) decays at every step, while each item unit (indices 1 to item_count) activates when its item is studied and then decays.
# Track context state during study
context = TemporalContext.init(item_count)
context_history = [context.state.copy()]
for i in range(item_count):
    # Create input for item i (unit i+1, since 0 is start-of-list)
    context_input = jnp.zeros(context.size).at[i + 1].set(1.0)
    context = context.integrate(context_input, drift_rate)
    context_history.append(context.state.copy())
context_history = jnp.stack(context_history)
# Plot
fig, ax = plt.subplots(figsize=(10, 6))
im = ax.imshow(context_history.T, aspect='auto', cmap='viridis')
ax.set_xlabel('Study Position')
ax.set_ylabel('Context Unit')
ax.set_title(f'Context Drift During Study (β = {drift_rate})')
ax.set_xticks(range(item_count + 1))
ax.set_xticklabels(['Init'] + [str(i+1) for i in range(item_count)])
ax.set_yticks([0, item_count])
ax.set_yticklabels(['Start', f'Item {item_count}'])
plt.colorbar(im, label='Activation')
plt.tight_layout()
plt.show()
Compare context vectors after studying different numbers of items:
fig, axes = plt.subplots(1, 3, figsize=(12, 3))
positions = [0, item_count // 2, item_count]
labels = ['Initial', f'After {item_count // 2} items', f'After {item_count} items']
for ax, pos, label in zip(axes, positions, labels):
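    # Bar plot of every context unit at this study position
    ax.bar(range(context.size), context_history[pos])
    ax.set_xlabel('Context Unit')
    ax.set_ylabel('Activation')
    ax.set_title(label)
    ax.set_ylim(0, 1.1)
plt.tight_layout()
plt.show()
Higher drift rates make context change faster: each new input takes a larger share of the vector, and earlier units decay more quickly.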
drift_rates = [0.2, 0.5, 0.8]
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax, beta in zip(axes, drift_rates):
    # Re-run the study sequence with this drift rate
    context = TemporalContext.init(item_count)
    history = [context.state.copy()]
    for i in range(item_count):
        context_input = jnp.zeros(context.size).at[i + 1].set(1.0)
        context = context.integrate(context_input, beta)
        history.append(context.state.copy())
    history = jnp.stack(history)
    im = ax.imshow(history.T, aspect='auto', cmap='viridis', vmin=0, vmax=1)
    ax.set_xlabel('Study Position')
    ax.set_ylabel('Context Unit')
    ax.set_title(f'β = {beta}')
plt.suptitle('Context Drift at Different Rates')
plt.tight_layout()
plt.show()
TemporalContext implements the Context protocol.
CMR: How context drift fits into the full model
Linear Memory: Association matrices that interact with context