API Reference¶
This reference covers the core ModelBatch container, data routing utilities, and
optimizer helpers. Shapes follow PyTorch conventions with batch dimensions shown
explicitly.
Core¶
ModelBatch¶
class modelbatch.core.ModelBatch(models: list[nn.Module], shared_input: bool = True)
Stacks parameters from identical models and executes them in parallel with
torch.vmap.
Parameters¶
- models (list[nn.Module]): Models with identical structure.
- shared_input (bool, optional): If True, all models receive the same input tensor. Default is True.
forward(inputs: torch.Tensor) -> torch.Tensor¶
Vectorized forward pass.
- inputs (Tensor):
  - [batch, ...] when shared_input=True
  - [num_models, batch, ...] when shared_input=False
- returns (Tensor): [num_models, batch, ...]
batch = torch.randn(64, 3, 32, 32)
mb = ModelBatch(models)    # `models`: a list of structurally identical nn.Modules
outputs = mb(batch)        # [num_models, 64, 10]
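For shared_input=False, each model receives its own slice along the leading dimension; a minimal sketch:
mb_indep = ModelBatch(models, shared_input=False)
per_model = torch.randn(len(models), 64, 3, 32, 32)  # [num_models, batch, ...]
outputs = mb_indep(per_model)                        # [num_models, 64, 10]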
zero_grad(set_to_none: bool = True) -> None¶
Clear gradients for all stacked parameters. Pass set_to_none=False to zero the
gradient tensors instead of setting them to None.
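A minimal training-step sketch, assuming an optimizer opt built with OptimizerFactory (see the Optimizer section):
mb.zero_grad()                         # clear stacked gradients
outputs = mb(batch)                    # [num_models, 64, 10]
loss = mb.compute_loss(outputs, targets, torch.nn.functional.cross_entropy)
loss.backward()
opt.step()                             # one step updates every model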
compute_loss(outputs: torch.Tensor, targets: torch.Tensor, loss_fn: Callable, reduction: str = "mean") -> torch.Tensor¶
Compute per-model losses.
- outputs (Tensor): [num_models, batch, ...]
- targets (Tensor):
  - [batch, ...] for shared targets
  - [num_models, batch, ...] for per-model targets
- returns (Tensor):
  - scalar when reduction="mean" or "sum"
  - [num_models] when reduction="none"
loss = mb.compute_loss(outputs, targets, torch.nn.functional.cross_entropy)
loss.backward()
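To keep per-model losses separate, pass reduction="none" and reduce before calling backward; a sketch:
per_model_loss = mb.compute_loss(outputs, targets,
                                 torch.nn.functional.cross_entropy, reduction="none")
per_model_loss.sum().backward()        # [num_models] -> scalar for autograd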
get_model_states() -> list[dict[str, torch.Tensor]]¶
Return one state_dict per model, as a list of length num_models.
load_model_states(states: list[dict[str, torch.Tensor]]) -> None¶
Load per-model states produced by get_model_states.
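A round-trip sketch; the list holds one state_dict per model, in model order:
states = mb.get_model_states()         # list of num_models state_dicts
mb.load_model_states(states)           # restore every model at once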
save_all(path: str) -> None¶
Persist all model states under path/model_{i}.pt.
load_all(path: str) -> None¶
Load states saved by save_all.
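A checkpointing sketch (the checkpoints/run1 path is illustrative):
mb.save_all("checkpoints/run1")        # writes checkpoints/run1/model_0.pt, ...
mb.load_all("checkpoints/run1")        # restores all model states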
enable_compile(**kwargs) -> None¶
Wrap the internal model with torch.compile.
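A sketch, assuming keyword arguments are forwarded to torch.compile (mode="reduce-overhead" is a standard torch.compile option):
mb.enable_compile(mode="reduce-overhead")
outputs = mb(batch)                    # subsequent passes use the compiled model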
metrics() -> dict[str, float]¶
Return the latest per-model metrics, such as {"loss_model_0": 0.1}.
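A logging sketch using the key pattern shown above:
for name, value in mb.metrics().items():
    print(f"{name}: {value:.4f}")      # e.g. loss_model_0: 0.1000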
Data¶
DataRouter¶
class modelbatch.data.DataRouter(mode: str = "passthrough")
Routes a batch to specific models using masks or index tensors.
- mode (str): "passthrough", "mask", or "indices".
route_batch(batch: torch.Tensor, masks: list[torch.Tensor] | None = None, indices: list[torch.Tensor] | None = None) -> torch.Tensor¶
- batch (Tensor): [batch, ...]
- masks (list[Tensor], optional): one boolean mask per model, [batch].
- indices (list[Tensor], optional): one index tensor per model.
- returns (Tensor):
  - passthrough: original batch [batch, ...]
  - mask/indices: [num_models, max_subset, ...]
router = DataRouter(mode="mask")
masked = router.route_batch(batch, masks=create_random_masks(len(batch), num_models))
StratifiedRouter¶
class modelbatch.data.StratifiedRouter(num_models: int, strategy: str = "round_robin")
Generates stratified index tensors to balance data across models.
- num_models (int): number of models.
- strategy (str): "round_robin", "random", or "class_based".
create_stratified_indices(labels: torch.Tensor, num_classes: int | None = None) -> list[torch.Tensor]¶
Return one index tensor per model, ensuring each model sees a balanced subset of
the batch; labels has shape [batch].
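A sketch combining StratifiedRouter with a DataRouter in indices mode (the num_models and num_classes values are illustrative):
strat = StratifiedRouter(num_models=4, strategy="class_based")
indices = strat.create_stratified_indices(labels, num_classes=10)
router = DataRouter(mode="indices")
routed = router.route_batch(batch, indices=indices)  # [4, max_subset, ...]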
create_random_masks(batch_size: int, num_models: int, subset_ratio: float = 0.8) -> list[torch.Tensor]¶
Generate boolean masks [batch] for each model selecting a random subset of the
batch.
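A standalone sketch; subset_ratio sets the fraction of the batch each mask selects:
masks = create_random_masks(batch_size=64, num_models=4, subset_ratio=0.5)
print(masks[0].shape, masks[0].dtype)  # torch.Size([64]) torch.bool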
Optimizer¶
OptimizerFactory¶
class modelbatch.optimizer.OptimizerFactory(optimizer_cls: type[Optimizer], base_config: dict[str, Any] | None = None)
Build optimizers and schedulers with per-model parameter groups.
- optimizer_cls (type[Optimizer]): class like torch.optim.SGD.
- base_config (dict[str, Any], optional): shared config for all groups.
create_optimizer(model_batch: ModelBatch, configs: list[dict[str, Any]]) -> Optimizer¶
Return an optimizer with one parameter group per model. configs should have
length num_models.
create_lr_scheduler(optimizer: Optimizer, scheduler_cls: type, configs: list[dict[str, Any]]) -> list¶
Return a list of schedulers, one per parameter group; see the sketch at the end
of this section.
Helper Functions¶
create_sgd_configs(learning_rates: list[float], momentum: float = 0.9, weight_decay: float = 1e-4) -> list[dict[str, Any]]
create_adam_configs(learning_rates: list[float], betas: tuple = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.0) -> list[dict[str, Any]]
create_lr_sweep_configs(min_lr: float, max_lr: float, num_models: int, scale: str = "log") -> list[dict[str, float]]
factory = OptimizerFactory(torch.optim.SGD)
opt = factory.create_optimizer(mb, create_sgd_configs([1e-2] * mb.num_models))
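Building on the example above, a sketch of a log-spaced LR sweep with per-group schedulers (StepLR and step_size=10 are illustrative):
sweep = create_lr_sweep_configs(1e-4, 1e-1, mb.num_models, scale="log")
opt = factory.create_optimizer(mb, sweep)
scheds = factory.create_lr_scheduler(
    opt, torch.optim.lr_scheduler.StepLR,
    [{"step_size": 10}] * mb.num_models,
)
for s in scheds:
    s.step()                           # advance every model's schedule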