Model Comparison

Code
import os
import json

import numpy as np
import pandas as pd

from jaxcmr.summarize import (
    calculate_aic_weights,
    calculate_aicc_weights,
    generate_t_p_matrices,
    summarize_parameters,
    winner_comparison_matrix,
    aicc_winner_comparison_matrix,
    raw_winner_comparison_matrix,
    pairwise_aic_differences,
    pairwise_aicc_differences,
    pairwise_median_aic_differences,
    pairwise_median_aicc_differences,
    bayesian_model_selection,
    floor_nested_fitness,
)
from jaxcmr.helpers import find_project_root, load_data, generate_trial_mask
Code
# Parameters
# Datasets to analyze; each name is the filename prefix of its fit JSONs in `fit_dir`.
data_names = ["LohnasKahana2014"]
fit_dir = "data/fits/"  # fit-result JSONs, relative to project root
table_dir = "results/tables"  # markdown table output directory, relative to project root
run_tag = "Model_Comparison"  # tag embedded in every output filename
use_aicc = True  # prefer AICc over AIC when per-subject observation counts are available

# One entry per model, index-aligned: class name (appears in fit filenames),
# human-readable display title, and the fit-run tag in the filename.
model_class_names = ["WeirdCMRNoStop", "BasePositionalCMRNoStop"]
model_display_titles = ["CMR", "Positional CMR"]
model_fit_tags = ["rerun_best_of_1", "rerun_best_of_1"]

# Optional pooled-analysis groupings ({group_name: [dataset_name, ...]}) and
# (nested, full) model pairs for fitness flooring; empty = feature disabled.
dataset_groups = {}
nesting_pairs = []
# Free parameters to report in the per-dataset parameter-summary tables.
query_parameters = [
    "encoding_drift_rate",
    "start_drift_rate",
    "recall_drift_rate",
    "shared_support",
    "item_support",
    "learning_rate",
    "primacy_scale",
    "primacy_decay",
    "choice_sensitivity",
]

# AICc-specific (single-dataset mode only)
# Source data file and trial-filter expression used to count per-subject observations.
data_path = "data/LohnasKahana2014.h5"
trial_query = "data['list_type'] > 0"
Code
project_root = find_project_root()
n_obs_map = {}  # {data_name: per-subject observation counts (np.ndarray)}; empty when AICc unused

# Derive data_names from data_path BEFORE any use of data_names[0].
# The original order populated n_obs_map under the empty-string key and only
# afterwards renamed the dataset, so the map key never matched data_names[0].
if len(data_names) == 1 and not data_names[0] and data_path:
    data_names = [os.path.splitext(os.path.basename(data_path))[0]]

# Single-dataset AICc mode: count masked trials per subject for the AICc penalty.
if use_aicc and len(data_names) == 1 and data_path:
    data = load_data(os.path.join(project_root, data_path))
    trial_mask = generate_trial_mask(data, trial_query)
    subjects = data["subject"].flatten()
    unique_subjects = np.unique(subjects)
    n_obs_map[data_names[0]] = np.array(
        [int(np.sum(trial_mask[subjects == s])) for s in unique_subjects]
    )

print(f"Datasets: {data_names}")
print(f"Models: {len(model_class_names)}")
Code
all_results = {}

for data_name in data_names:
    loaded = []
    for model_name, model_title, fit_tag in zip(
        model_class_names, model_display_titles, model_fit_tags
    ):
        fit_path = os.path.join(
            project_root, fit_dir, f"{data_name}_{model_name}_{fit_tag}.json"
        )
        with open(fit_path) as fh:
            result = json.load(fh)

        fits = result["fits"]
        # Backfill fields that older fit files may lack.
        if "subject" not in fits:
            fits["subject"] = result["subject"]
        if "allow_repeated_recalls" not in result["fixed"]:
            result["fixed"]["allow_repeated_recalls"] = False
            fits["allow_repeated_recalls"] = [False] * len(fits["subject"])

        # Migrate the legacy parameter key to its current name.
        if "mfc_trace_sensitivity" in result.get("free", {}):
            result["free"]["repetition_orthogonality"] = result["free"].pop(
                "mfc_trace_sensitivity"
            )
            fits["repetition_orthogonality"] = fits.pop("mfc_trace_sensitivity")

        result["name"] = model_title
        loaded.append(result)
    all_results[data_name] = loaded

# Optionally floor nested-model fitness so nested models never beat their parent.
if nesting_pairs:
    for data_name, loaded in all_results.items():
        n_floored = floor_nested_fitness(loaded, nesting_pairs)
        if n_floored:
            print(f"{data_name}: floored {n_floored} entries")

total_fits = sum(len(v) for v in all_results.values())
print(f"Loaded {total_fits} model fits across {len(data_names)} datasets")
Code
def save_table(content, metric, prefix=""):
    """Write a markdown table under `table_dir`, creating directories as needed.

    The filename is ``{prefix}_{run_tag}_{metric}.md`` when *prefix* is given,
    otherwise ``{run_tag}_{metric}.md``.
    """
    name_parts = [prefix, run_tag, metric] if prefix else [run_tag, metric]
    out_path = os.path.join(project_root, table_dir, "_".join(name_parts) + ".md")
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, "w") as fh:
        fh.write(content)

Per-Dataset Analyses

Code
def _format_delta_markdown(center, halfwidth):
    """Format paired mean/half-width matrices as 'm [lo, hi]' string cells.

    Diagonal cells and cells where either value is NaN are rendered blank.
    (The original duplicated this loop verbatim for mean and median deltas.)
    """
    table = center.copy().astype(object)
    for rl in table.index:
        for cl in table.columns:
            if rl == cl:
                table.loc[rl, cl] = ""
            else:
                m, ci = center.loc[rl, cl], halfwidth.loc[rl, cl]
                if m != m or ci != ci:  # NaN-safe check without math.isnan
                    table.loc[rl, cl] = ""
                else:
                    table.loc[rl, cl] = f"{m:.2f} [{m - ci:.2f}, {m + ci:.2f}]"
    return table


for data_name, results in all_results.items():
    n_obs = n_obs_map.get(data_name)
    names = [r["name"] for r in results]
    N = len(names)
    # Hoist the "AICc applicable?" condition the original repeated six times.
    aicc_mode = use_aicc and n_obs is not None

    print(f"\n{'=' * 60}")
    print(f"  {data_name}")
    print(f"{'=' * 60}")

    # Parameter summary
    summary = summarize_parameters(results, query_parameters, include_std=True, include_ci=True)
    save_table(summary, "parameters", data_name)
    print("\n## Parameter Summary")
    print(summary)

    # T-test p-values between models
    df_t, df_p = generate_t_p_matrices(results)
    save_table(df_p.to_markdown(), "t_test_p_values", data_name)
    print("\n## T-Test P-Values")
    print(df_p.to_markdown())

    # Winner ratios (IC-penalized and raw)
    if aicc_mode:
        df_winner = aicc_winner_comparison_matrix(results, n_obs)
    else:
        df_winner = winner_comparison_matrix(results)
    save_table(df_winner.to_markdown().replace(" nan ", "     "), "winner_ratios", data_name)
    print("\n## Winner Ratios")
    print(df_winner.to_markdown().replace(" nan ", "     "))

    df_raw = raw_winner_comparison_matrix(results)
    save_table(df_raw.to_markdown().replace(" nan ", "     "), "raw_winner_ratios", data_name)
    print("\n## Raw Winner Ratios")
    print(df_raw.to_markdown().replace(" nan ", "     "))

    # Mean pairwise ΔAIC(c) with CI half-widths
    ic_label = "AICc" if aicc_mode else "AIC"
    if aicc_mode:
        mean_delta, ci_hw, _ = pairwise_aicc_differences(results, n_obs)
    else:
        mean_delta, ci_hw, _ = pairwise_aic_differences(results)
    delta_table = _format_delta_markdown(mean_delta, ci_hw)
    save_table(delta_table.to_markdown(), f"delta_{ic_label.lower()}", data_name)
    print(f"\n## Pairwise Δ{ic_label}")
    print(delta_table.to_markdown())

    # Median pairwise ΔAIC(c)
    if aicc_mode:
        med_delta, med_ci = pairwise_median_aicc_differences(results, n_obs)
    else:
        med_delta, med_ci = pairwise_median_aic_differences(results)
    med_table = _format_delta_markdown(med_delta, med_ci)
    save_table(med_table.to_markdown(), f"median_delta_{ic_label.lower()}", data_name)
    print(f"\n## Median Δ{ic_label}")
    print(med_table.to_markdown())

    # Pairwise Akaike weights (each cell is the row model's weight in that pair)
    col_name = "AICcw" if aicc_mode else "AICw"
    aicw_fn = (lambda pair: calculate_aicc_weights(pair, n_obs)) if aicc_mode else calculate_aic_weights
    aicw_matrix = pd.DataFrame("", index=names, columns=names)
    for i in range(N):
        for j in range(i + 1, N):
            pair = [results[i], results[j]]
            w = aicw_fn(pair).set_index("Model")
            aicw_matrix.loc[names[i], names[j]] = f"{w.loc[names[i], col_name]:.4f}"
            aicw_matrix.loc[names[j], names[i]] = f"{w.loc[names[j], col_name]:.4f}"
    save_table(aicw_matrix.to_markdown(), "pairwise_aicw", data_name)
    print(f"\n## Pairwise {col_name}")
    print(aicw_matrix.to_markdown())

    # Pairwise Bayesian model selection: exceedance and protected exceedance probs
    xp_matrix = pd.DataFrame("", index=names, columns=names)
    pxp_matrix = pd.DataFrame("", index=names, columns=names)
    for i in range(N):
        for j in range(i + 1, N):
            pair = [results[i], results[j]]
            bms = bayesian_model_selection(pair, n_obs=n_obs).set_index("Model")
            xp_matrix.loc[names[i], names[j]] = f"{bms.loc[names[i], 'Exceedance Probability']:.4f}"
            xp_matrix.loc[names[j], names[i]] = f"{bms.loc[names[j], 'Exceedance Probability']:.4f}"
            pxp_matrix.loc[names[i], names[j]] = f"{bms.loc[names[i], 'Protected XP']:.4f}"
            pxp_matrix.loc[names[j], names[i]] = f"{bms.loc[names[j], 'Protected XP']:.4f}"
    save_table(xp_matrix.to_markdown(), "pairwise_xp", data_name)
    save_table(pxp_matrix.to_markdown(), "pairwise_pxp", data_name)
    print("\n## Pairwise Exceedance Probability")
    print(xp_matrix.to_markdown())
    print("\n## Pairwise Protected XP")
    print(pxp_matrix.to_markdown())

Pooled Analyses

Code
def get_aic_per_subject(result):
    """Per-subject AIC array: 2k + 2 * fitness.

    Assumes `result["fitness"]` holds per-subject negative log-likelihoods
    (standard AIC form) — TODO confirm against the fitting code.
    """
    penalty = 2 * len(result["free"])
    return penalty + 2 * np.array(result["fitness"])

def get_total_aic(result):
    """AIC summed across subjects."""
    return np.sum(get_aic_per_subject(result))

def get_mean_aic(result):
    """AIC averaged across subjects."""
    return np.mean(get_aic_per_subject(result))

def results_by_title(data_name):
    """Map model display title -> fit result for one dataset (reads global all_results)."""
    return dict((entry["name"], entry) for entry in all_results[data_name])

def compute_summed_aic(datasets):
    """Total AIC per model title, summed over the given datasets."""
    return {
        title: sum(get_total_aic(results_by_title(dn)[title]) for dn in datasets)
        for title in model_display_titles
    }

def summed_aic_table(datasets):
    """Summed-AIC ranking table with ΔAIC, relative likelihood, and Akaike weights."""
    totals = compute_summed_aic(datasets)
    df = pd.DataFrame({"Model": list(totals), "Summed AIC": list(totals.values())})
    df = df.sort_values("Summed AIC").reset_index(drop=True)
    df["ΔAIC"] = df["Summed AIC"] - df["Summed AIC"].min()
    rel = np.exp(-0.5 * df["ΔAIC"])
    df["Relative Likelihood"] = rel
    df["AIC Weight"] = rel / rel.sum()
    return df

def compute_ranks_per_dataset(datasets):
    """Rank models within each dataset by mean per-subject AIC (1 = best)."""
    ranks = {}
    for dn in datasets:
        mean_aics = {t: get_mean_aic(results_by_title(dn)[t]) for t in model_display_titles}
        ordering = sorted(mean_aics, key=mean_aics.get)
        ranks[dn] = {model: pos for pos, model in enumerate(ordering, start=1)}
    return ranks

def rank_aggregation_table(datasets, group_datasets=None):
    """Per-dataset rank table with overall (and optional per-group) mean ranks.

    Rows are sorted by overall mean rank, best first.
    """
    ranks = compute_ranks_per_dataset(datasets)

    def shorten(name):
        # Abbreviate long dataset names for compact column headers.
        return (name.replace("RepeatedRecalls", "")
                    .replace("Kahana", "K")
                    .replace("Gordon", "G")
                    .replace("Ranschburg", "R"))

    rows = []
    for title in model_display_titles:
        row = {"Model": title}
        for dn in datasets:
            row[shorten(dn)] = ranks[dn][title]
        row["Mean Rank"] = np.mean([ranks[dn][title] for dn in datasets])
        for gname, gds in (group_datasets or {}).items():
            members = [d for d in gds if d in datasets]
            if members:
                row[f"Mean Rank ({gname})"] = np.mean([ranks[dn][title] for dn in members])
        rows.append(row)
    return pd.DataFrame(rows).sort_values("Mean Rank").reset_index(drop=True)

def meta_analytic_pooling(datasets, ref_model):
    """Inverse-variance (fixed-effect) pooling of per-subject ΔAIC vs. *ref_model*.

    Returns {title: {"pooled_mean", "pooled_se", "ci_lower", "ci_upper"}}.
    The reference model itself gets all zeros; a model with no usable
    dataset (every SE zero or undefined) gets all NaNs.
    """
    pooled = {}
    for title in model_display_titles:
        if title == ref_model:
            pooled[title] = {"pooled_mean": 0, "pooled_se": 0, "ci_lower": 0, "ci_upper": 0}
            continue
        dataset_means, dataset_weights = [], []
        for dn in datasets:
            rbt = results_by_title(dn)
            delta = get_aic_per_subject(rbt[title]) - get_aic_per_subject(rbt[ref_model])
            se = np.std(delta, ddof=1) / np.sqrt(len(delta))
            # Skip degenerate datasets; a NaN SE (single subject) also fails `se > 0`.
            if se > 0:
                dataset_means.append(np.mean(delta))
                dataset_weights.append(1 / se**2)
        if dataset_weights:
            w = np.array(dataset_weights)
            m = np.array(dataset_means)
            center = np.sum(w * m) / np.sum(w)
            pooled_se = np.sqrt(1 / np.sum(w))
            pooled[title] = {
                "pooled_mean": center,
                "pooled_se": pooled_se,
                "ci_lower": center - 1.96 * pooled_se,
                "ci_upper": center + 1.96 * pooled_se,
            }
        else:
            pooled[title] = {
                "pooled_mean": np.nan,
                "pooled_se": np.nan,
                "ci_lower": np.nan,
                "ci_upper": np.nan,
            }
    return pooled

def meta_analysis_table(datasets):
    """Table of pooled ΔAIC vs. the best summed-AIC model, sorted by pooled mean.

    Prints the chosen reference model as a side effect.
    """
    summed = compute_summed_aic(datasets)
    best = min(summed, key=summed.get)
    pooled = meta_analytic_pooling(datasets, best)
    rows = []
    for title in model_display_titles:
        est = pooled[title]
        # "Yes" when the 95% CI excludes zero; the reference row shows "-".
        if est["ci_lower"] > 0 or est["ci_upper"] < 0:
            verdict = "Yes"
        elif title != best:
            verdict = "No"
        else:
            verdict = "-"
        rows.append({
            "Model": title,
            "Pooled ΔAIC": f"{est['pooled_mean']:.2f}",
            "95% CI": f"[{est['ci_lower']:.2f}, {est['ci_upper']:.2f}]",
            "Reliably Worse?": verdict,
        })
    df = pd.DataFrame(rows)
    df["_s"] = [pooled[m]["pooled_mean"] for m in df["Model"]]
    df = df.sort_values("_s").drop("_s", axis=1).reset_index(drop=True)
    print(f"Reference model (best by summed AIC): {best}")
    return df
Code
if len(data_names) > 1:
    # Every user-defined group plus an implicit "all" group over every dataset.
    groups = dict(dataset_groups) if dataset_groups else {}
    groups["all"] = list(data_names)

    divider = "=" * 60
    for group_name, group_datasets in groups.items():
        print(f"\n{divider}")
        print(f"  POOLED: {group_name} ({len(group_datasets)} datasets)")
        print(divider)

        df_summed = summed_aic_table(group_datasets)
        save_table(df_summed.to_markdown(index=False), "summed_aic", f"pooled_{group_name}")
        print("\n## Summed AIC")
        print(df_summed.to_string(index=False))

        df_meta = meta_analysis_table(group_datasets)
        save_table(df_meta.to_markdown(index=False), "meta_analysis", f"pooled_{group_name}")
        print("\n## Meta-Analytic ΔAIC")
        print(df_meta.to_string(index=False))

    # Rank aggregation across every dataset, with per-group breakdown columns.
    df_ranks = rank_aggregation_table(data_names, dataset_groups if dataset_groups else None)
    save_table(df_ranks.to_markdown(index=False), "ranks", "pooled_all")
    print(f"\n{divider}")
    print("  RANK AGGREGATION")
    print(divider)
    print(df_ranks.to_string(index=False))

    # One-line winner per group.
    print(f"\n{divider}")
    print("  SUMMARY")
    print(divider)
    for gname, gds in groups.items():
        winner = summed_aic_table(gds).iloc[0]["Model"]
        print(f"\n{gname}: Summed AIC winner = {winner}")
else:
    print("Single dataset — pooled analyses skipped.")