High-performance, extensible, chainable optimizers for PyTorch.
- `foreach` operations deliver significant speedups on large models.
- Seamless integration with `torch.optim`.
- A wide range of optimizers: `ForeachAdamW`, `ForeachRMSprop`, `ForeachSFAdamW`, `Muon`, `ADOPT`, `MSAM`, and more.
- A benchmark suite (see `benchmark/`).

Install:
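Assuming the package is published on PyPI as `heavyball` (matching the import in the usage example below):

```bash
pip install heavyball
```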
Basic usage:
```python
import torch
from torch import nn

from heavyball import ForeachAdamW

model = nn.Sequential(
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)
optimizer = ForeachAdamW(model.parameters(), lr=1e-3)

# `dataloader` is assumed to yield (data, target) batches.
for data, target in dataloader:
    optimizer.zero_grad()
    output = model(data)
    loss = torch.nn.functional.cross_entropy(output, target)
    loss.backward()
    optimizer.step()
```
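The other optimizers listed above are intended to be used the same way. As a minimal sketch, assuming `ForeachRMSprop` shares the standard `(params, lr)` constructor signature, swapping optimizers is a one-line change:

```python
from heavyball import ForeachRMSprop

# Drop-in swap for the optimizer above; check the class signature
# in heavyball for optimizer-specific options.
optimizer = ForeachRMSprop(model.parameters(), lr=1e-3)
```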
Reproduce benchmarks with:
```bash
python3 -m benchmark.run_all_benchmarks \
  --opt ForeachSOAP --opt LaProp --opt AdamW --opt Muon --opt ForeachCachedNewtonPSGD \
  --opt RMSprop --opt OrthoLaProp --opt ForeachSFAdamW --opt ForeachADOPT --opt LaPropOrtho \
  --opt CachedPSGDKron --opt SignLaProp --opt ForeachSOLP --opt PSGDLRA --opt NewtonPSGDLRA \
  --opt NewtonHybrid2PSGDKron --opt NewtonHybrid2PSGDLRA --opt mars-NewtonHybrid2PSGDLRA \
  --opt MSAMLaProp --opt mars-adaptive-NewtonHybrid2PSGDKron --opt mars-ortho-NewtonHybrid2PSGDKron \
  --opt MuonLaProp --opt mars-unscaled-NewtonHybrid2PSGDKron --opt mars-NewtonHybrid2PSGDKron \
  --opt cautious-AdamW --opt unscaled_cautious-AdamW --opt mars-AdamW \
  --dtype float32 --steps 1000000 --trials 1000 --parallelism 256 --seeds 1 \
  --difficulties trivial --difficulties easy --difficulties medium --difficulties hard \
  --difficulties extreme --difficulties nightmare --timeout 2880
```
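Before launching the full sweep, a much smaller run of the same entry point can serve as a sanity check; the sketch below reuses flags from the full command above with illustrative, reduced values:

```bash
python3 -m benchmark.run_all_benchmarks --opt AdamW --opt Muon \
  --dtype float32 --steps 1000 --trials 10 --parallelism 4 --seeds 1 \
  --difficulties trivial --timeout 60
```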
We welcome contributions! Please check the issue tracker and follow these steps:
1. Install the development dependencies: `pip install -e .[dev]`
2. Run the test suite with `pytest`.

License: BSD 3-Clause; see the LICENSE file.
Made by the HeavyBall team.