## What's New

- Add `set_input_size()` method to EVA models, used by OpenCLIP 3.0.0 to allow resizing for timm based encoder models.
- `eva.py` based models (including EVA, EVA02, Meta PE ViT, timm SBB ViT w/ ROPE, and Naver ROPE-ViT) can now be loaded in NaFlexVit when `use_naflex=True` is passed at model creation time (see the sketch after this list).
- Updates to `eva.py`: add `RotaryEmbeddingMixed` module for mixed mode, weights on HuggingFace Hub.
- Fix `forward_intermediates` and some checkpointing bugs. Thanks https://github.com/brianhou0208
- `vision_transformer.py` models can be loaded into the NaFlexVit model by adding the `use_naflex=True` flag to `create_model`.
- `train.py` and `validate.py` add the `--naflex-loader` arg, which must be used with a NaFlexVit:
  ```
  python validate.py /imagenet --amp -j 8 --model vit_base_patch16_224 --model-kwargs use_naflex=True --naflex-loader --naflex-max-seq-len 256
  ```
- The `--naflex-train-seq-lens` argument specifies which sequence lengths to randomly pick from per batch during training.
- The `--naflex-max-seq-len` argument sets the target sequence length for validation.
- Adding `--model-kwargs enable_patch_interpolator=True --naflex-patch-sizes 12 16 24` will enable random patch size selection per-batch w/ interpolation.
- The `--naflex-loss-scale` arg changes the loss scaling mode per batch relative to the batch size, as timm NaFlex loading changes the batch size for each seq len.
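A minimal sketch of the NaFlexVit flag described above; the model name is an example and this assumes a `timm` version with NaFlexVit support:

```python
import torch
import timm

# Load a standard ViT architecture into the NaFlexVit model (use_naflex is
# forwarded through create_model as described above).
model = timm.create_model('vit_base_patch16_224', pretrained=True, use_naflex=True)

# A fixed-size batch still forwards as usual.
out = model(torch.randn(2, 3, 224, 224))
print(out.shape)  # (2, 1000) for an ImageNet-1k head
```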
- New `timm` weights, `forward_intermediates()` and some additional fixes, thanks to https://github.com/brianhou0208
- `forward_intermediates()` support, thanks to https://github.com/brianhou0208
- Add `local-dir:` pretrained schema, can use `local-dir:/path/to/model/folder` for the model name to source model / pretrained cfg & weights, Hugging Face Hub style (config.json + weights file), from a local folder (see the sketch after this list).
- New ViT (SBB recipe) weights:
  - `vit_so150m2_patch16_reg1_gap_448.sbb_e200_in12k_ft_in1k` - 88.1% top-1
  - `vit_so150m2_patch16_reg1_gap_384.sbb_e200_in12k_ft_in1k` - 87.9% top-1
  - `vit_so150m2_patch16_reg1_gap_256.sbb_e200_in12k_ft_in1k` - 87.3% top-1
  - `vit_so150m2_patch16_reg4_gap_256.sbb_e200_in12k`
  - `vit_so150m_patch16_reg4_gap_256.sbb_e250_in12k_ft_in1k` - 86.7% top-1
  - `vit_so150m_patch16_reg4_gap_384.sbb_e250_in12k_ft_in1k` - 87.4% top-1
  - `vit_so150m_patch16_reg4_gap_256.sbb_e250_in12k`
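A hedged sketch of the `local-dir:` schema; the folder path is a placeholder and should contain Hub-style files (`config.json` + weights):

```python
import timm

# Placeholder path: the folder holds config.json plus a weights file
# (e.g. model.safetensors) as laid out on the Hugging Face Hub.
model = timm.create_model('local-dir:/path/to/model/folder', pretrained=True)
```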
- Support training and validation in pure `bfloat16` or `float16`.
- `wandb` project name arg added by https://github.com/caojiaolong, use `args.experiment` for name.
- Add `torch.utils.checkpoint.checkpoint()` wrapper in `timm.models` that defaults `use_reentrant=False`, unless `TIMM_REENTRANT_CKPT=1` is set in env (see the gradient checkpointing sketch below).
- `convnext_nano` 384x384 ImageNet-12k pretrain & fine-tune: https://huggingface.co/models?search=convnext_nano%20r384
- `vit_large_patch14_clip_224.dfn2b_s39b` weights added.
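A brief sketch of enabling gradient checkpointing, which routes through the `torch.utils.checkpoint.checkpoint()` wrapper mentioned above; the model name is an example:

```python
import timm

model = timm.create_model('convnext_nano', pretrained=True)
# Recompute activations in the backward pass to trade compute for memory.
model.set_grad_checkpointing(True)
# Setting TIMM_REENTRANT_CKPT=1 in the environment opts back into the
# reentrant checkpoint implementation, per the entry above.
```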
- Fix `RmsNorm` layer & fn to match standard formulation, use PT 2.5 impl when possible. Move old impl to `SimpleNorm` layer, it's LN w/o centering or bias. There were only two `timm` models using it, and they have been updated.
- Add `cache_dir` arg for model creation.
- Pass `trust_remote_code` through to the HF datasets wrapper.
- `inception_next_atto` model added by creator.
- Models work with `hf-hub:` based loading, and thus will work with the new Transformers `TimmWrapperModel` (see the sketch below).
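A quick sketch of `hf-hub:` based loading; the repo id is an example of a `timm` weight hosted on the Hub:

```python
import timm

# Load config + weights directly from a Hugging Face Hub repo.
model = timm.create_model('hf-hub:timm/resnet50.a1_in1k', pretrained=True)
```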
- Add `list_optimizers`, `get_optimizer_class`, `get_optimizer_info` to reworked `create_optimizer_v2` fn to explore optimizers, get info or class (see the sketch below).
- Deprecate `optim.optim_factory`, move fns to `optim/_optim_factory.py` and `optim/_param_groups.py` and encourage import via `timm.optim`.
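A short sketch of the optimizer exploration helpers; assumes a `timm` version where they are exported from `timm.optim`:

```python
import timm.optim

names = timm.optim.list_optimizers()           # registered optimizer names
info = timm.optim.get_optimizer_info('adamw')  # metadata for one optimizer
cls = timm.optim.get_optimizer_class('adamw')  # the optimizer class itself
print(names[:5], info, cls)
```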
- Add a set of new very well trained ResNet & ResNet-V2 18/34 (basic block) weights. See https://huggingface.co/blog/rwightman/resnet-trick-or-treat
- Deprecate `timm.models.registry`, increased priority of existing deprecation warnings to be visible.
- InternViT-300M weights added to `timm` as `vit_intern300m_patch14_448`.
- `mobilenet_edgetpu_v2_m` weights w/ `ra4` mnv4-small based recipe. 80.1% top-1 @ 224 and 80.7 @ 256.
- `set_input_size()` added to vit and swin v1/v2 models to allow changing image size, patch size, window size after model creation (see the sketch below).
- `set_input_size`, `always_partition` and `strict_img_size` args have been added to `__init__` to allow more flexible input size constraints.
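A minimal sketch of `set_input_size()`; accepted kwargs vary per model family (swin also takes window args), so treat this as illustrative:

```python
import torch
import timm

model = timm.create_model('vit_base_patch16_224', pretrained=True)
# Change the expected input resolution after creation; position embeddings
# are resized to match.
model.set_input_size(img_size=(256, 256))
out = model(torch.randn(1, 3, 256, 256))
```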
- Add a set of `tiny` (< .5M param) models for testing that are actually trained on ImageNet-1k.
- New `timm` trained weights added.
- Add `normalize=` flag for transforms, return non-normalized `torch.Tensor` with original dtype (for `chug`).
- 'Searching for Better ViT Baselines (For the GPU Poor)' weights and vit variants released. Exploring model shapes between Tiny and Base.
- `AttentionExtract` helper added to extract attention maps from `timm` models. See example in #1232 (comment).
- `forward_intermediates()` API refined and added to more models including some ConvNets that have other extraction methods.
- Most architectures now support `features_only=True` feature extraction; the remaining 34 architectures can be supported based on priority requests.
- `features_only=True` support for ViT models with flat hidden states or non-std module layouts (so far covering `'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*'`).
- Add `forward_intermediates()` API that can be used with a feature wrapping module or directly:

```python
model = timm.create_model('vit_base_patch16_224')

final_feat, intermediates = model.forward_intermediates(input)
output = model.forward_head(final_feat)  # pooling + classifier head

print(final_feat.shape)
torch.Size([2, 197, 768])

for f in intermediates:
    print(f.shape)
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])

print(output.shape)
torch.Size([2, 1000])
```
```python
model = timm.create_model('eva02_base_patch16_clip_224', pretrained=True, img_size=512, features_only=True, out_indices=(-3, -2,))

output = model(torch.randn(2, 3, 512, 512))

for o in output:
    print(o.shape)
torch.Size([2, 768, 32, 32])
torch.Size([2, 768, 32, 32])
```
## Introduction

PyTorch Image Models (`timm`) is a collection of image models, layers, utilities, optimizers, schedulers, data-loaders / augmentations, and reference training / validation scripts that aim to pull together a wide variety of SOTA models with the ability to reproduce ImageNet training results.
The work of many others is present here. I've tried to make sure all source material is acknowledged via links to github, arxiv papers, etc in the README, documentation, and code docstrings. Please let me know if I missed anything.
All model architecture families include variants with pretrained weights. There are specific model variants without any weights; this is NOT a bug. Help training new or better weights is always appreciated.
To see the full list of optimizers w/ descriptions: `timm.optim.list_optimizers(with_description=True)`

Included optimizers available via the `timm.optim.create_optimizer_v2` factory method:
- `adabelief` an implementation of AdaBelief adapted from https://github.com/juntang-zhuang/Adabelief-Optimizer - https://arxiv.org/abs/2010.07468
- `adafactor` adapted from FAIRSeq impl - https://arxiv.org/abs/1804.04235
- `adafactorbv` adapted from Big Vision - https://arxiv.org/abs/2106.04560
- `adahessian` by David Samuel - https://arxiv.org/abs/2006.00719
- `adamp` and `sgdp` by Naver ClovAI - https://arxiv.org/abs/2006.08217
- `adan` an implementation of Adan adapted from https://github.com/sail-sg/Adan - https://arxiv.org/abs/2208.06677
- `adopt` ADOPT adapted from https://github.com/iShohei220/adopt - https://arxiv.org/abs/2411.02853
- `kron` PSGD w/ Kronecker-factored preconditioner from https://github.com/evanatyourservice/kron_torch - https://sites.google.com/site/lixilinx/home/psgd
- `lamb` an implementation of Lamb and LambC (w/ trust-clipping) cleaned up and modified to support use with XLA - https://arxiv.org/abs/1904.00962
- `laprop` optimizer from https://github.com/Z-T-WANG/LaProp-Optimizer - https://arxiv.org/abs/2002.04839
- `lars` an implementation of LARS and LARC (w/ trust-clipping) - https://arxiv.org/abs/1708.03888
- `lion` an implementation of Lion adapted from https://github.com/google/automl/tree/master/lion - https://arxiv.org/abs/2302.06675
- `lookahead` adapted from impl by Liam - https://arxiv.org/abs/1907.08610
- `madgrad` an implementation of MADGRAD adapted from https://github.com/facebookresearch/madgrad - https://arxiv.org/abs/2101.11075
- `mars` MARS optimizer from https://github.com/AGI-Arena/MARS - https://arxiv.org/abs/2411.10438
- `nadam` an implementation of Adam w/ Nesterov momentum
- `nadamw` an implementation of AdamW (Adam w/ decoupled weight-decay) w/ Nesterov momentum. A simplified impl based on https://github.com/mlcommons/algorithmic-efficiency
- `novograd` by Masashi Kimura - https://arxiv.org/abs/1905.11286
- `radam` by Liyuan Liu - https://arxiv.org/abs/1908.03265
- `rmsprop_tf` adapted from PyTorch RMSProp by myself. Reproduces much improved Tensorflow RMSProp behaviour
- `sgdw` an implementation of SGD w/ decoupled weight-decay
- `fused<name>` optimizers by name with NVIDIA Apex installed
- `bnb<name>` optimizers by name with BitsAndBytes installed
- `cadamw`, `clion`, and more 'Cautious' optimizers from https://github.com/kyleliang919/C-Optim - https://arxiv.org/abs/2411.16085
- `adam`, `adamw`, `rmsprop`, `adadelta`, `adagrad`, and `sgd` pass through to `torch.optim` implementations
- `c` suffix (eg `adamc`, `nadamc`) to implement 'corrected weight decay' per https://arxiv.org/abs/2506.02285
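A hedged sketch of creating one of the optimizers above via the factory; the model and hyper-parameters are illustrative:

```python
import timm
import timm.optim

model = timm.create_model('resnet50')

# create_optimizer_v2 accepts a model (for param group handling) or an
# iterable of parameters; 'opt' is any name from list_optimizers().
optimizer = timm.optim.create_optimizer_v2(
    model,
    opt='adamw',
    lr=1e-3,
    weight_decay=0.05,
)
```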
Several (less common) features that I often utilize in my projects are included. Many of their additions are the reason why I maintain my own set of models, instead of using others' via PIP:

- Accessing/changing the classifier via `get_classifier` and `reset_classifier`
- Doing a forward pass on just the features via `forward_features` (see documentation)
- Multi-scale feature map extraction via `create_model(name, features_only=True, out_indices=..., output_stride=...)` (see the sketch after this list)
- The `out_indices` creation arg specifies which feature maps to return; these indices are 0 based and generally correspond to the `C(i + 1)` feature level
- The `output_stride` creation arg controls output stride of the network by using dilated convolutions. Most networks are stride 32 by default. Not all networks support this.
- Feature map channel counts and reduction level (stride) can be queried after model creation via the `.feature_info` member
- LR schedulers including `step`, `cosine` w/ restarts, `tanh` w/ restarts, `plateau`
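A small sketch of the feature extraction interface above; the model name and indices are examples:

```python
import torch
import timm

m = timm.create_model('resnet50', features_only=True, out_indices=(2, 3, 4), output_stride=16)
print(m.feature_info.channels())   # channel count per returned feature map
print(m.feature_info.reduction())  # reduction (stride) per feature map

feats = m(torch.randn(1, 3, 224, 224))
for f in feats:
    print(f.shape)
```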
Model validation results can be found in the results tables.
## Getting Started (Documentation)

The official documentation can be found at https://huggingface.co/docs/hub/timm. Documentation contributions are welcome.
Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide by Chris Hughes is an extensive blog post covering many aspects of `timm` in detail.
timmdocs is an alternate set of documentation for `timm`. A big thanks to Aman Arora for his efforts creating timmdocs.
paperswithcode is a good resource for browsing the models within `timm`.
The root folder of the repository contains reference train, validation, and inference scripts that work with the included models and other features of this repository. They are adaptable for other datasets and use cases with a little hacking. See documentation.
## Awesome PyTorch Resources

One of the greatest assets of PyTorch is the community and their contributions. A few of my favourite resources that pair well with the models and components here are listed below.
### Object Detection, Instance and Semantic Segmentation

## Licenses

The code here is licensed Apache 2.0. I've taken care to make sure any third party code included or adapted has compatible (permissive) licenses such as MIT, BSD, etc. I've made an effort to avoid any GPL / LGPL conflicts. That said, it is your responsibility to ensure you comply with licenses here and conditions of any dependent licenses. Where applicable, I've linked the sources/references for various components in docstrings. If you think I've missed anything please create an issue.
So far all of the pretrained weights available here are pretrained on ImageNet with a select few that have some additional pretraining (see extra note below). ImageNet was released for non-commercial research purposes only (https://image-net.org/download). It's not clear what the implications of that are for the use of pretrained weights from that dataset. Any models I have trained with ImageNet are done for research purposes and one should assume that the original dataset license applies to the weights. It's best to seek legal advice if you intend to use the pretrained weights in a commercial product.
### Pretrained on more than ImageNet

Several weights included or referenced here were pretrained with proprietary datasets that I do not have access to. These include the Facebook WSL, SSL, SWSL ResNe(Xt) and the Google Noisy Student EfficientNet models. The Facebook models have an explicit non-commercial license (CC-BY-NC 4.0, https://github.com/facebookresearch/semi-supervised-ImageNet1K-models, https://github.com/facebookresearch/WSL-Images). The Google models do not appear to have any restriction beyond the Apache 2.0 license (and ImageNet concerns). In either case, you should contact Facebook or Google with any questions.
## Citing

```bibtex
@misc{rw2019timm,
  author = {Ross Wightman},
  title = {PyTorch Image Models},
  year = {2019},
  publisher = {GitHub},
  journal = {GitHub repository},
  doi = {10.5281/zenodo.4414861},
  howpublished = {\url{https://github.com/rwightman/pytorch-image-models}}
}
```