ModelBatch Implementation Summary

Status as of July 2025

🎯 Project Overview

ModelBatch is a library for training many independent PyTorch models simultaneously on a single GPU by grouping them into a single ModelBatch object. This enables efficient hyperparameter sweeps and maximizes GPU utilization.
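The core idea — evaluating N independent model instances as one vectorized computation over stacked parameters — can be sketched with PyTorch's `torch.func` utilities. This is a minimal illustration of the technique, not ModelBatch's actual API:

```python
import copy
import torch
from torch.func import functional_call, stack_module_state

# Three independent copies of one architecture (e.g. a hyperparameter sweep).
models = [torch.nn.Linear(4, 2) for _ in range(3)]

# Stack per-model parameters/buffers into single tensors with a leading model
# dimension, and keep one stateless "skeleton" module on the meta device.
params, buffers = stack_module_state(models)
base = copy.deepcopy(models[0]).to('meta')

def run_one(p, b, x):
    # Run the skeleton module with one model's parameters swapped in.
    return functional_call(base, (p, b), (x,))

# A single vmapped forward pass evaluates all three models on the GPU at once;
# x holds one mini-batch per model.
x = torch.randn(3, 8, 4)
out = torch.vmap(run_one)(params, buffers, x)  # shape: (3, 8, 2)
```

Because the stacked forward is one kernel launch sequence rather than N, small models that individually underutilize the GPU can share it efficiently.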

⚠️ Current Status: Active Development

Implementation Milestones

M1: Core ModelBatch + demo (stable for standard models)
- ModelBatch class with parameter stacking
- OptimizerFactory for per-model optimizer configs
- DataRouter for data filtering -- untested
- CallbackPack for monitoring -- untested
- Working demos with performance benchmarks

M2: OptimizerFactory + AMP (consistent APIs, AMP parity pending)
- OptimizerFactory for per-model optimizer configs -- tested & passing
- AMP support with GradScaler -- training works but differs from sequential runs
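The AMP parity gap comes down to scaling granularity: a single loss scale shared across the batch reacts to any one model's overflow, whereas sequential training gives each model its own GradScaler. A plain-Python sketch of the skip/back-off logic (simplified — the real GradScaler grows its scale only after `growth_interval` consecutive clean steps, not on every step):

```python
def step(scale, grads, backoff=0.5, growth=2.0):
    """Mimic GradScaler: on inf/nan grads, skip the step and shrink the
    scale; on a clean step, apply the update and grow the scale."""
    if any(g != g or abs(g) == float('inf') for g in grads):
        return scale * backoff, False   # overflow: optimizer step skipped
    return scale * growth, True         # clean step taken

# Per-model scalers: only the overflowing model (inf grad) skips its step.
per_model = [step(2.0, g) for g in ([1.0], [float('inf')], [0.5])]

# One batch-level scaler sees the pooled gradients, so a single model's
# overflow skips the step and shrinks the scale for every model.
shared = step(2.0, [1.0, float('inf'), 0.5])
```

This is why batched AMP runs can diverge from sequential ones even when the underlying models and data are identical.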

M3: HuggingFace integration
- HFModelBatch for transformer models
- ModelBatchTrainer wrapper for Trainer

Additional Work (not on original roadmap)
- logger.py provides structured logging and context managers
- optuna_integration.py enables batched hyperparameter search

🔄 M4: Lightning example + docs
🔄 M5: Benchmark suite
🔄 M6: v1.0 release

Known Issues

  1. Training Equivalence: Batched training now matches sequential training unless dropout is used; dropout randomness still differs despite identical seeds.
     See examples/cifar10_lenet_benchmark.py and examples/quick_consistency_test_dropout.py.
  2. AMP Training Equivalence: AMP training is supported, but it uses a batch-level GradScaler. This leads to different scaling behavior (and consequently different overflow handling) compared to sequential training.
     See tests/test_amp_optimizer.py::test_amp_overflow_handling.
  3. ~~LSTM Models~~ (dropped for now): RuntimeError: Batching rule not implemented for aten::lstm.input. The LSTM module is not supported by torch.vmap.
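The dropout mismatch is consistent with how `torch.vmap` handles RNG: it requires an explicit `randomness` mode, and even `randomness='different'` draws masks from a different stream than a sequential loop would. A small illustration (not ModelBatch code):

```python
import torch

drop = torch.nn.Dropout(p=0.5)  # fresh modules default to training mode
x = torch.ones(3, 4)            # pretend: 3 models, 4 activations each

# Sequential baseline: each model draws its dropout mask in turn.
torch.manual_seed(0)
sequential = torch.stack([drop(row) for row in x])

# Batched: vmap demands an explicit randomness mode. 'different' gives each
# vmapped instance its own mask, but the masks come from a different RNG
# stream than the loop above, so outputs generally diverge despite the seed.
torch.manual_seed(0)
batched = torch.vmap(drop, randomness='different')(x)
```

Both paths produce valid dropout (elements are either zeroed or scaled by 1/(1-p)), but the masks themselves are not reproducible across the two execution orders, which is what breaks bitwise training equivalence.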

🏗️ Project Structure

ModelBatch/
├── src/modelbatch/
│   ├── core.py          # ModelBatch class (main component)
│   ├── optimizer.py     # OptimizerFactory + AMP support
│   ├── data.py          # DataRouter for data filtering
│   ├── callbacks.py     # CallbackPack for monitoring
│   ├── huggingface_integration.py  # HuggingFace models & Trainer adapters
│   ├── optuna_integration.py      # Optuna study helpers
│   ├── logger.py        # Structured logging utilities
│   └── utils.py         # Training utilities
├── tests/               # Unit tests
├── examples/            # Demo scripts
└── docs/                # Design documentation