ModelBatch

Train many independent PyTorch models simultaneously on a single GPU using vectorized operations.

ModelBatch reduces idle GPU capacity, which is common when each model alone is too small to keep the GPU busy, by training multiple independent models in a single vectorized step using torch.vmap. Speedup is near-linear in the number of models until VRAM or compute saturates.
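
The sketch below illustrates one plausible way to implement the mechanism described above. It is not ModelBatch's own API. It uses PyTorch's torch.func utilities (stack_module_state, functional_call, vmap) to stack the parameters of several identically shaped models and update all of them in one vectorized forward/backward step. The architecture, sizes, learning rate, and synthetic data are assumptions chosen for illustration.

```python
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state, vmap

torch.manual_seed(0)

NUM_MODELS, BATCH, IN_DIM = 4, 32, 8


def make_model() -> nn.Module:
    # Hypothetical small model; every model in the batch must share this architecture.
    return nn.Sequential(nn.Linear(IN_DIM, 16), nn.ReLU(), nn.Linear(16, 1))


models = [make_model() for _ in range(NUM_MODELS)]

# Stack each parameter along a new leading dim: a (16, 8) weight becomes (NUM_MODELS, 16, 8).
# The stacked tensors are fresh leaf tensors that require grad.
params, buffers = stack_module_state(models)

# Structure-only template on the meta device; its real weights come from `params`.
template = copy.deepcopy(models[0]).to("meta")


def forward_one(p, b, x):
    # Forward pass for a single model, given that model's slice of parameters and buffers.
    return functional_call(template, (p, b), (x,))


# Map over dim 0 of params, buffers, and inputs, so model i sees x[i].
batched_forward = vmap(forward_one, in_dims=(0, 0, 0))

# One optimizer over the stacked tensors updates every model at once.
optimizer = torch.optim.Adam(params.values(), lr=1e-2)

x = torch.randn(NUM_MODELS, BATCH, IN_DIM)  # a separate synthetic batch per model
y = torch.randn(NUM_MODELS, BATCH, 1)


def train_step() -> torch.Tensor:
    optimizer.zero_grad()
    preds = batched_forward(params, buffers, x)  # (NUM_MODELS, BATCH, 1)
    per_model_loss = ((preds - y) ** 2).mean(dim=(1, 2))  # (NUM_MODELS,)
    # Summing is safe: each model's parameters appear only in its own loss term.
    per_model_loss.sum().backward()
    optimizer.step()
    return per_model_loss.detach()


history = [train_step() for _ in range(100)]
print("first:", history[0])
print("last: ", history[-1])
```

Because every model lives in the same stacked tensors, one kernel launch serves all of them. This is where the speedup comes from when the individual models are small.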

Key Features

  • Massive Speedups: Near-linear speedup in the number of models, given the right setup
  • Single-GPU Efficiency: Keep the GPU fully utilized by training many small models together
  • Drop-in Replacement: Minimal code changes to existing PyTorch workflows
  • Framework Integration: Aims to work with Hugging Face and PyTorch Lightning
  • Per-model Isolation: Separate parameters, optimizers, and metrics
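
Per-model isolation can hold even when parameters are physically stacked, for two reasons. First, summing per-model losses leaves each model's gradient unchanged, because the other terms do not depend on its parameters. Second, Adam's update is elementwise, so a single Adam over a stacked tensor behaves like N separate Adam optimizers with the same hyperparameters. The toy check below (each "model" is just one weight vector, a hypothetical simplification) compares the two setups. This is a sketch of the principle, not ModelBatch's implementation. Operations that couple models across the stack, such as clipping by a global gradient norm over the stacked tensor, would break this isolation.

```python
import torch

torch.manual_seed(0)
NUM_MODELS, DIM = 3, 5

# Stacked parameter: row i belongs to model i.
stacked = torch.randn(NUM_MODELS, DIM, requires_grad=True)
# Independent copies of each row, each trained by its own optimizer.
separate = [stacked[i].detach().clone().requires_grad_() for i in range(NUM_MODELS)]

stacked_opt = torch.optim.Adam([stacked], lr=0.1)
separate_opts = [torch.optim.Adam([p], lr=0.1) for p in separate]

targets = torch.randn(NUM_MODELS, DIM)

for _ in range(5):
    # Batched path: one backward over the sum of per-model losses.
    stacked_opt.zero_grad()
    per_model_loss = ((stacked - targets) ** 2).sum(dim=1)
    per_model_loss.sum().backward()
    stacked_opt.step()

    # Reference path: each model trained entirely on its own.
    for i, (p, opt) in enumerate(zip(separate, separate_opts)):
        opt.zero_grad()
        ((p - targets[i]) ** 2).sum().backward()
        opt.step()
```

Per-model metrics fall out of the same structure: per_model_loss already holds one value per model.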

Getting Started

New to ModelBatch? See the getting started guide for installation, a basic example, and common development commands.

API Reference

The API reference covers the core ModelBatch class, data routing helpers, and optimizer utilities.
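
The data routing helpers themselves are documented in the API reference and are not reproduced here. As an assumption about what "routing" involves, the sketch below shows the underlying torch.vmap choice it rests on: either broadcast one shared batch to every model (in_dims=None) or give each model its own batch (in_dims=0). It also shows one way to build per-model batches from a shared data pool by indexing. All names (linear, pool, idx) are hypothetical.

```python
import torch
from torch.func import vmap

torch.manual_seed(0)
NUM_MODELS, BATCH, IN, OUT, POOL = 4, 10, 3, 2, 100

weights = torch.randn(NUM_MODELS, OUT, IN)  # one weight matrix per model


def linear(w, x):
    # Single-model forward: x (BATCH, IN) -> (BATCH, OUT).
    return x @ w.T


# Shared data: every model sees the same batch (in_dims=None broadcasts x).
shared_x = torch.randn(BATCH, IN)
shared_out = vmap(linear, in_dims=(0, None))(weights, shared_x)  # (NUM_MODELS, BATCH, OUT)

# Routed data: draw a different subset of a shared pool for each model.
pool = torch.randn(POOL, IN)
idx = torch.randint(0, POOL, (NUM_MODELS, BATCH))  # row i = sample indices for model i
routed_x = pool[idx]  # (NUM_MODELS, BATCH, IN)
routed_out = vmap(linear, in_dims=(0, 0))(weights, routed_x)  # model i sees routed_x[i]
```

Broadcasting avoids copying the batch N times, while per-model routing lets each model train on different data (for example, different folds or seeds) in the same step.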

Additional Resources