Paper: https://arxiv.org/abs/1512.03385
Repository: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
Images must be in the range [0, 1]. If the pretrained ImageNet weights are selected, the images are normalized internally with the ImageNet mean and standard deviation. To skip this normalization, pass normalize=False (see here for details).
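To make the normalization step concrete, here is a minimal sketch of what per-channel ImageNet normalization typically looks like. The mean and standard deviation are the commonly used ImageNet statistics for RGB inputs in [0, 1]; that the library applies exactly these values internally is an assumption, and `normalize_imagenet` is a hypothetical helper, not part of flaxmodels.

```python
import jax.numpy as jnp

# Commonly used ImageNet per-channel statistics (RGB order, inputs in [0, 1]).
# Assumption: the pretrained weights expect these values.
IMAGENET_MEAN = jnp.array([0.485, 0.456, 0.406])
IMAGENET_STD = jnp.array([0.229, 0.224, 0.225])


def normalize_imagenet(x):
    """Normalize a batch of shape [N, H, W, 3] with values in [0, 1].

    Hypothetical helper for illustration only.
    """
    # Broadcasting applies the statistics along the last (channel) axis.
    return (x - IMAGENET_MEAN) / IMAGENET_STD


# Dummy mid-gray batch standing in for real images.
x = jnp.full((1, 224, 224, 3), 0.5)
x_norm = normalize_imagenet(x)
```

If you normalize the images yourself in this way, passing normalize=False should avoid normalizing them a second time.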
For more usage examples, check out this Colab.
from PIL import Image
import jax
import jax.numpy as jnp
import flaxmodels as fm
key = jax.random.PRNGKey(0)
# Load image
img = Image.open('example.jpg')
# Image should be in range [0, 1]
x = jnp.array(img, dtype=jnp.float32) / 255.0
# Add batch dimension
x = jnp.expand_dims(x, axis=0)
resnet18 = fm.ResNet18(output='logits', pretrained='imagenet')
params = resnet18.init(key, x)
# Shape [1, 1000]
out = resnet18.apply(params, x, train=False)
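With output='logits', `out` holds unnormalized class scores. As a rough sketch, the scores can be turned into probabilities and the most likely classes picked as follows. `top_k_predictions` is a hypothetical helper, and random logits stand in for `out` so the snippet runs on its own; mapping class indices to label names is left out.

```python
import jax
import jax.numpy as jnp


def top_k_predictions(logits, k=5):
    """Return the top-k class probabilities and indices per image.

    Hypothetical helper; `logits` has shape [N, num_classes].
    """
    # Softmax over the class axis converts logits to probabilities.
    probs = jax.nn.softmax(logits, axis=-1)
    # top_k operates on the last axis and returns values sorted descending.
    top_probs, top_idx = jax.lax.top_k(probs, k)
    return top_probs, top_idx


# Stand-in for `out`: random logits of shape [1, 1000].
logits = jax.random.normal(jax.random.PRNGKey(0), (1, 1000))
top_probs, top_idx = top_k_predictions(logits)
```

Both returned arrays have shape [N, k], with the most probable class in the first column.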
Usage is equivalent for ResNet34, ResNet50, ResNet101, and ResNet152.
The documentation can be found here.
If you want to train ResNet in JAX/Flax, go here.