Temporal Context

Context representation for memory search models

In retrieved-context models, temporal context is a vector of continuous activations, one per context unit. TemporalContext initializes it with a start-of-list unit set to 1 and one unit per study item set to 0, so the vector holds item_count + 1 units and starts at unit length.

Context Drift Equation

Context evolves as items are encoded and retrieved, integrating new input at each step:

\[c_i = \rho_i c_{i-1} + \beta \, c_{i}^{IN}\]

where \(\beta\) is the drift rate that controls how strongly the new input \(c_i^{IN}\) is integrated, and \(\rho_i\) scales the previous context so that \(c_i\) keeps unit length:

\[\rho_i = \sqrt{1 + \beta^2\bigl[(c_{i-1} \cdot c^{IN}_i)^2 - 1\bigr]} - \beta(c_{i-1} \cdot c^{IN}_i)\]
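As a worked case, assume each input is a unit vector orthogonal to the current context, as happens when every study item activates its own previously inactive unit. Then \(c_{i-1} \cdot c^{IN}_i = 0\), and the expression reduces to

\[\rho_i = \sqrt{1 - \beta^2}, \qquad \|c_i\|^2 = \rho_i^2 \, \|c_{i-1}\|^2 + \beta^2 \, \|c^{IN}_i\|^2 = (1 - \beta^2) + \beta^2 = 1\]

For \(\beta = 0.5\) this gives \(\rho_i \approx 0.866\).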

Because every step scales earlier activations by \(\rho_i\), recently studied items stay more active than earlier ones. The result is a recency gradient that reflects the order in which items were presented.
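The drift equation above can be sketched in a few lines of JAX. This is a standalone illustration rather than the jaxcmr implementation: the function name drift is hypothetical, and the sketch assumes that both the previous context and the input have unit length.

```python
from jax import numpy as jnp


def drift(context, context_input, beta):
    """One step of c_i = rho_i * c_{i-1} + beta * c_i^IN (illustrative sketch).

    Assumes `context` and `context_input` are both unit-length vectors.
    """
    dot = jnp.dot(context, context_input)
    rho = jnp.sqrt(1 + beta**2 * (dot**2 - 1)) - beta * dot
    return rho * context + beta * context_input


# Start-of-list unit active, then present an input on unit 1
c0 = jnp.zeros(4).at[0].set(1.0)
c_in = jnp.zeros(4).at[1].set(1.0)

c1 = drift(c0, c_in, 0.5)  # input orthogonal to context
c2 = drift(c1, c_in, 0.5)  # same input again, now overlapping with context

norm_c1 = float(jnp.linalg.norm(c1))
norm_c2 = float(jnp.linalg.norm(c2))
print(c1, norm_c1)
print(c2, norm_c2)
```

Both norms should come out at 1 up to floating-point error, including the second step where the input overlaps with the existing context.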

Setup

Code
import matplotlib.pyplot as plt
from jax import numpy as jnp
from jaxcmr.components.context import TemporalContext
from jaxcmr.typing import Context

Parameters

Code
item_count = 10
drift_rate = 0.5

Basic Usage

Initialize context and integrate a new input:

Code
# Initialize context (start-of-list unit = 1, all others = 0)
context = TemporalContext.init(item_count)
print(f"Initial context: {context.state}")
print(f"Context size: {context.size}")

# Integrate input for first item
context_input = jnp.zeros(context.size).at[1].set(1.0)  # Unit for item 1
new_context = context.integrate(context_input, drift_rate)
print(f"After item 1: {new_context.state}")

Visualizing Context Drift

Track how context evolves as items are studied. The start-of-list unit (index 0) decays at every step, while each item unit (indices 1 to N) activates when its item is presented and then decays in turn.

Code
# Track context state during study
context = TemporalContext.init(item_count)
context_history = [context.state.copy()]

for i in range(item_count):
    # Create input for item i (unit i+1, since 0 is start-of-list)
    context_input = jnp.zeros(context.size).at[i + 1].set(1.0)
    context = context.integrate(context_input, drift_rate)
    context_history.append(context.state.copy())

context_history = jnp.stack(context_history)

# Plot
fig, ax = plt.subplots(figsize=(10, 6))
im = ax.imshow(context_history.T, aspect='auto', cmap='viridis')
ax.set_xlabel('Study Position')
ax.set_ylabel('Context Unit')
ax.set_title(f'Context Drift During Study (β = {drift_rate})')
ax.set_xticks(range(item_count + 1))
ax.set_xticklabels(['Init'] + [str(i+1) for i in range(item_count)])
ax.set_yticks([0, item_count])
ax.set_yticklabels(['Start', f'Item {item_count}'])
plt.colorbar(im, label='Activation')
plt.tight_layout()
plt.show()
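The decay pattern in this heatmap has a simple closed form when every input is a one-hot unit orthogonal to the current context, which is the assumption behind this sketch. After N items, the start-of-list unit should hold (1 - β²)^(N/2), and the unit for item k should hold β(1 - β²)^((N - k)/2). The code below is independent of jaxcmr and compares that prediction against a direct iteration of the drift equation:

```python
from jax import numpy as jnp

item_count, beta = 10, 0.5
rho = jnp.sqrt(1 - beta**2)  # per-step scaling when input is orthogonal to context

# Closed-form activations after all items have been studied
positions = jnp.arange(1, item_count + 1)
predicted = jnp.concatenate([
    rho**item_count * jnp.ones(1),         # start-of-list unit
    beta * rho ** (item_count - positions),  # item units 1..N
])

# Iterate the drift equation directly for comparison
state = jnp.zeros(item_count + 1).at[0].set(1.0)
for i in range(item_count):
    one_hot = jnp.zeros(item_count + 1).at[i + 1].set(1.0)
    state = rho * state + beta * one_hot

max_abs_error = float(jnp.max(jnp.abs(state - predicted)))
print(f"Largest deviation from closed form: {max_abs_error:.2e}")
```

The most recent item ends at activation β, and each earlier item is smaller by one further factor of ρ, which is the recency gradient visible along the last column of the heatmap.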

Context States at Different Study Positions

Compare context vectors after studying different numbers of items:

Code
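# One panel per snapshot: before study, midway, and after the last item
fig, axes = plt.subplots(1, 3, figsize=(12, 3))
positions = [0, item_count // 2, item_count]  # rows of context_history to plot
labels = ['Initial', f'After {item_count // 2} items', f'After {item_count} items']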

for ax, pos, label in zip(axes, positions, labels):
    ax.bar(range(context.size), context_history[pos])
    ax.set_xlabel('Context Unit')
    ax.set_ylabel('Activation')
    ax.set_title(label)
    ax.set_ylim(0, 1.1)

plt.tight_layout()
plt.show()

Effect of Drift Rate

Higher drift rates make context change faster: each new item enters more strongly, and earlier units decay more quickly.

Code
drift_rates = [0.2, 0.5, 0.8]
fig, axes = plt.subplots(1, 3, figsize=(12, 4))

for ax, beta in zip(axes, drift_rates):
    context = TemporalContext.init(item_count)
    history = [context.state.copy()]
    
    for i in range(item_count):
        context_input = jnp.zeros(context.size).at[i + 1].set(1.0)
        context = context.integrate(context_input, beta)
        history.append(context.state.copy())
    
    history = jnp.stack(history)
    # Fix the color scale to [0, 1] so panels are comparable across drift rates
    im = ax.imshow(history.T, aspect='auto', cmap='viridis', vmin=0, vmax=1)
    ax.set_xlabel('Study Position')
    ax.set_ylabel('Context Unit')
    ax.set_title(f'β = {beta}')

plt.suptitle('Context Drift at Different Rates')
plt.tight_layout()
plt.show()
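One way to put a number on "faster context change" is the similarity between the final context and each earlier context state. The sketch below is a standalone illustration that does not call jaxcmr: the helper name study_history is hypothetical, and it assumes one-hot inputs that are orthogonal to the current context, so the per-step scaling is sqrt(1 - β²).

```python
from jax import numpy as jnp


def study_history(item_count, beta):
    """Context states for the initial state plus each study position (sketch)."""
    rho = jnp.sqrt(1 - beta**2)  # valid because each input is orthogonal to context
    state = jnp.zeros(item_count + 1).at[0].set(1.0)
    history = [state]
    for i in range(item_count):
        one_hot = jnp.zeros(item_count + 1).at[i + 1].set(1.0)
        state = rho * state + beta * one_hot
        history.append(state)
    return jnp.stack(history)


similarity_to_final = {}
for beta in [0.2, 0.5, 0.8]:
    history = study_history(10, beta)
    # Dot product of every stored state with the end-of-list state
    similarity_to_final[beta] = history @ history[-1]
    print(beta, similarity_to_final[beta])
```

Under these assumptions the similarity falls off by one factor of sqrt(1 - β²) per intervening item, so a larger β leaves the end-of-list context less similar to the states that preceded it.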

Protocol Compliance

TemporalContext implements the Context protocol:

Code
from nbdev.showdoc import show_doc

show_doc(Context)
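For readers unfamiliar with structural typing, the sketch below shows what implementing a protocol means in Python. ContextLike and ToyContext are hypothetical names invented for this example; they are not the definitions in jaxcmr.typing, and the actual Context protocol may require different or additional members.

```python
from typing import Protocol, runtime_checkable

from jax import numpy as jnp


@runtime_checkable
class ContextLike(Protocol):
    """Hypothetical protocol for illustration; not jaxcmr.typing.Context."""

    def integrate(self, context_input, drift_rate): ...


class ToyContext:
    """Minimal class that satisfies ContextLike structurally, without inheritance."""

    def __init__(self, state):
        self.state = state

    def integrate(self, context_input, drift_rate):
        # Same drift equation as above; returns a new object instead of mutating
        dot = jnp.dot(self.state, context_input)
        rho = jnp.sqrt(1 + drift_rate**2 * (dot**2 - 1)) - drift_rate * dot
        return ToyContext(rho * self.state + drift_rate * context_input)


toy = ToyContext(jnp.zeros(3).at[0].set(1.0))
is_compliant = isinstance(toy, ContextLike)  # checks only that integrate() exists
updated = toy.integrate(jnp.zeros(3).at[1].set(1.0), 0.5)
print(is_compliant, updated.state)
```

The runtime check only confirms that the required method is present; signature and return-type agreement is left to a static type checker.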