ViT-jax2tf


This repository hosts code for converting the original Vision Transformer models [1] (JAX) to TensorFlow.

The original models were fine-tuned on the ImageNet-1k dataset [2]. For more details on the training protocols, please refer to [3]. The authors of [3] open-sourced about 50k different variants of Vision Transformer models in JAX. With the conversion.ipynb notebook, you should be able to take any model from this pool, convert it to TensorFlow, and use it with TensorFlow Hub and Keras.

The original model classes and weights [4] were converted using the jax2tf tool [5].
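To make the jax2tf step more concrete, here is a minimal sketch of wrapping a JAX function and exporting it as a TensorFlow SavedModel. It is not the code from conversion.ipynb: the helper name `convert_jax_fn_to_saved_model` and its arguments are hypothetical, the input shape is assumed to be fully static, and the gradient is switched off, which suits inference only. A real Vision Transformer conversion also has to decide how the model weights are carried over.

```python
def convert_jax_fn_to_saved_model(jax_fn, input_shape, export_dir):
    """Wrap a JAX function with jax2tf and export it as a SavedModel.

    Hypothetical helper for illustration; `input_shape` must be fully
    static (no None dimensions) in this sketch.
    """
    # Imported inside the function so the helper can be defined even when
    # JAX or TensorFlow is not installed.
    import tensorflow as tf
    from jax.experimental import jax2tf

    # Turn the JAX function into a function that operates on TF tensors.
    # with_gradient=False means the exported model is inference-only.
    tf_fn = jax2tf.convert(jax_fn, with_gradient=False)

    # Attach the converted function to a tf.Module so it can be saved.
    module = tf.Module()
    module.f = tf.function(
        tf_fn,
        autograph=False,  # the converted function needs no AutoGraph pass
        input_signature=[tf.TensorSpec(input_shape, tf.float32)],
    )
    tf.saved_model.save(module, export_dir)
    return module
```

Assuming JAX and TensorFlow are installed, a toy call could look like `convert_jax_fn_to_saved_model(lambda x: jnp.tanh(x), [1, 4], "/tmp/toy_model")` with `import jax.numpy as jnp`; the exported function is then available as `module.f`.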

Note that the converted models require TensorFlow 2.6 or greater.
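The version requirement can be checked up front. The sketch below uses a hypothetical helper, `meets_minimum`, that compares the major and minor components numerically; comparing the raw strings would wrongly rank "2.10" below "2.6".

```python
import re


def meets_minimum(version, minimum=(2, 6)):
    """Return True if a version string such as "2.6.0" is at least `minimum`.

    Hypothetical helper; only the major and minor numbers are compared.
    """
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= minimum
```

With TensorFlow installed, this could be called as `meets_minimum(tf.__version__)` before loading a converted model.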

Vision Transformers on TensorFlow Hub

Find the model collection on TensorFlow Hub: https://tfhub.dev/sayakpaul/collections/vision_transformer/1.

The eight best-performing ImageNet-1k models are also available on TensorFlow Hub and can be used either for off-the-shelf image classification or for transfer learning. Please refer to the model-selector.ipynb notebook to understand how these models were chosen.

The table below provides a performance summary:

| Model | Top-1 Accuracy (%) | Checkpoint | Misc |
|---|---|---|---|
| B/8 | 85.948 | B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz | |
| L/16 | 85.716 | L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz | |
| B/16 | 84.018 | B_16-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz | |
| R50-L/32 | 83.784 | R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz | |
| R26-S/32 (light aug) | 80.944 | R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz | tb.dev run |
| R26-S/32 (medium aug) | 80.462 | R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz | |
| S/16 | 80.462 | S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz | tb.dev run |
| B/32 | 79.436 | B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz | |

Note that the top-1 accuracy is reported on the ImageNet-1k validation set. The checkpoints are stored at the following GCS location: gs://vit_models/augreg. More details on these can be found in [4].
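The checkpoint names in the table pack the training configuration into the filename: a pre-training part and a fine-tuning part separated by `--`. The sketch below splits a name into those fields and builds a GCS path. The helper names are hypothetical, the field layout is inferred from the eight names listed above, and the path assumes the files sit directly under gs://vit_models/augreg.

```python
GCS_PREFIX = "gs://vit_models/augreg"  # location given in this README


def parse_checkpoint_name(filename):
    """Split a checkpoint filename into pre-training and fine-tuning fields.

    Hypothetical helper; the layout is inferred from the names in the table.
    """
    stem = filename[: -len(".npz")] if filename.endswith(".npz") else filename
    pretrain, finetune = stem.split("--")
    pre_tokens = pretrain.split("-")
    fine_tokens = finetune.split("-")

    def key_values(tokens):
        # "lr_0.001" -> ("lr", "0.001"); values are kept as strings.
        return dict(token.split("_", 1) for token in tokens)

    return {
        "model": pre_tokens[0],             # e.g. "B_8" or "R50_L_32"
        "pretrain_dataset": pre_tokens[1],  # e.g. "i21k"
        "pretrain_epochs": pre_tokens[2],   # e.g. "300ep"
        "pretrain": key_values(pre_tokens[3:]),
        "finetune_dataset": fine_tokens[0],  # e.g. "imagenet2012"
        "finetune": key_values(fine_tokens[1:]),
    }


def checkpoint_path(filename):
    """Return the assumed GCS path for a checkpoint filename."""
    return f"{GCS_PREFIX}/{filename}"
```

For the B/8 entry, `parse_checkpoint_name` yields the model `B_8`, a pre-training weight-decay field `wd` of `0.1`, and a fine-tuning resolution field `res` of `224`.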

Image classifiers

Feature extractors

Other notebooks

  • classification.ipynb: Shows how to load a Vision Transformer model from TensorFlow Hub and run image classification.
  • fine_tune.ipynb: Shows how to fine-tune a Vision Transformer model from TensorFlow Hub on the tf_flowers dataset.

Additionally, i1k_eval contains files for running evaluation on the ImageNet-1k validation split.
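For reference, top-1 accuracy is the percentage of examples whose highest-scoring class matches the label. The sketch below is a plain-Python illustration of that metric, not the code in i1k_eval; `top1_accuracy` is a hypothetical helper that takes per-example score lists and integer labels.

```python
def top1_accuracy(logits, labels):
    """Return top-1 accuracy in percent.

    `logits` is a sequence of per-class score lists, one per example;
    `labels` holds the matching integer class indices.
    """
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length")
    if not labels:
        raise ValueError("at least one example is required")
    correct = sum(
        1
        for scores, label in zip(logits, labels)
        # Index of the largest score is the predicted class.
        if max(range(len(scores)), key=scores.__getitem__) == label
    )
    return 100.0 * correct / len(labels)
```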

References

[1] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale by Dosovitskiy et al.

[2] ImageNet-1k

[3] How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers by Steiner et al.

[4] Vision Transformer GitHub

[5] jax2tf tool

Acknowledgements

Thanks to the authors of Vision Transformers for their efforts put into open-sourcing the models.

Thanks to the ML-GDE program for providing GCP Credit support that helped me execute the experiments for this project.
