ModelBatch Implementation Summary¶
Status as of July 2025
🎯 Project Overview¶
ModelBatch is a library for training many independent PyTorch models simultaneously on a single GPU by grouping them into a single ModelBatch object. This enables efficient hyperparameter sweeps and maximizes GPU utilization.
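The core idea can be sketched with PyTorch's standard model-ensembling recipe: stack each model's parameters along a new leading "model" dimension and run one `torch.vmap`'d forward pass for all of them. This is an illustrative sketch of the technique, not ModelBatch's actual internals.

```python
import torch
from torch.func import stack_module_state, functional_call

# Sketch of parameter stacking: 4 independent models, evaluated in one pass.
models = [torch.nn.Linear(8, 2) for _ in range(4)]
params, buffers = stack_module_state(models)  # each tensor gains a leading model dim

base = torch.nn.Linear(8, 2)  # template module reused for every model's forward

def forward_one(p, b, x):
    # functional_call runs `base` with the given (single-model) parameters
    return functional_call(base, {**p, **b}, (x,))

x = torch.randn(16, 8)  # shared input batch
# in_dims=(0, 0, None): map over the model dim of params/buffers, broadcast x
out = torch.vmap(forward_one, in_dims=(0, 0, None))(params, buffers, x)
print(out.shape)  # one output per model: (4, 16, 2)
```

Because the models never exchange information, the vmap'd step is mathematically equivalent to running each model separately, while executing as a single batched kernel on the GPU.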
⚠️ Current Status: Active Development¶
Implementation Milestones¶
✅ M1: Core ModelBatch + demo (stable for standard models)

- ModelBatch class with parameter stacking
- OptimizerFactory for per-model optimizer configs
- DataRouter for data filtering -- untested
- CallbackPack for monitoring -- untested
- Working demos with performance benchmarks
✅ M2: OptimizerFactory + AMP (consistent APIs, AMP parity pending)

- OptimizerFactory for per-model optimizer configs -- tested & passing
- AMP support with GradScaler -- training works but differs from sequential runs
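One way per-model optimizer configs can work over stacked parameters (a sketch under assumptions, not necessarily OptimizerFactory's mechanism): take a single shared SGD step at a base learning rate, but pre-scale each model's gradient slice by its own `lr / base_lr`, which yields the per-model effective learning rate for plain SGD.

```python
import torch

# Assumption: parameters are stacked with dim 0 indexing the model.
n_models, base_lr = 3, 1.0
lrs = torch.tensor([0.1, 0.01, 0.001])            # one lr per model
w = torch.nn.Parameter(torch.ones(n_models, 5))   # stacked weights

opt = torch.optim.SGD([w], lr=base_lr)
loss = (w ** 2).sum()
loss.backward()
with torch.no_grad():
    # scale each model's gradient slice -> per-model effective lr under SGD
    w.grad *= (lrs / base_lr).view(-1, 1)
opt.step()
# row m becomes 1 - 2 * lr_m, i.e. 0.8, 0.98, 0.998
```

Note this equivalence only holds for vanilla SGD; stateful optimizers like Adam need their state tensors stacked per model as well.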
✅ M3: HuggingFace integration
- HFModelBatch for transformer models
- ModelBatchTrainer wrapper for Trainer
✨ Additional Work (not on original roadmap)
- logger.py provides structured logging and context managers
- optuna_integration.py enables batched hyperparameter search
🔄 M4: Lightning example + docs

🔄 M5: Benchmark suite

🔄 M6: v1.0 release
Known Issues¶
- Training Equivalence: Batched training now matches sequential training unless dropout is used; dropout randomness remains different despite setting seeds.
  - See `examples/cifar10_lenet_benchmark.py` and `examples/quick_consistency_test_dropout.py`
- AMP Training Equivalence: AMP training is supported, but uses a batch-level GradScaler. This leads to different scaling behavior (and consequently overflow handling) compared to sequential training.
  - See `tests/test_amp_optimizer.py::test_amp_overflow_handling`
- ~~LSTM Models~~ (dropped for now): `RuntimeError: Batching rule not implemented for aten::lstm.input`. The LSTM module is not supported by `torch.vmap`.
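The dropout discrepancy above is consistent with how `torch.vmap` handles randomness: inside a vmap'd function, RNG behavior is governed by vmap's `randomness` argument (`"error"`, `"same"`, or `"different"`), not by the global seed alone, so seeding both runs identically is not enough. A minimal illustration:

```python
import torch
import torch.nn.functional as F

x = torch.ones(4, 1000)  # pretend dim 0 indexes 4 models

def drop(t):
    return F.dropout(t, p=0.5, training=True)

# randomness="same": every "model" sees the identical dropout mask
same = torch.vmap(drop, randomness="same")(x)
# randomness="different": each "model" draws its own mask
diff = torch.vmap(drop, randomness="different")(x)

print(torch.equal(same[0], same[1]))   # shared mask across the model dim
print(torch.equal(diff[0], diff[1]))   # almost surely not equal over 1000 elements
```

Neither mode reproduces the exact mask sequence a sequential per-model loop would draw, which is why bitwise equivalence with dropout is hard to achieve.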
🏗️ Project Structure¶
```
ModelBatch/
├── src/modelbatch/
│   ├── core.py                     # ModelBatch class (main component)
│   ├── optimizer.py                # OptimizerFactory + AMP support
│   ├── data.py                     # DataRouter for data filtering
│   ├── callbacks.py                # CallbackPack for monitoring
│   ├── huggingface_integration.py  # HuggingFace models & Trainer adapters
│   ├── optuna_integration.py       # Optuna study helpers
│   ├── logger.py                   # Structured logging utilities
│   └── utils.py                    # Training utilities
├── tests/                          # Unit tests
├── examples/                       # Demo scripts
└── docs/                           # Design documentation
```