
Structure Difference between PyTorch ResNet and JAX resnet (at layer 4) #5

Open
sarahelsherif opened this issue Nov 25, 2021 · 6 comments

sarahelsherif commented Nov 25, 2021

Hello Nicholas, while using the pretrained ResNet-101,
I am comparing the output size of the ResNet model in PyTorch after layer no. 4 (taking the output there, before the average pooling),
after running it on an input batch holding one 224×224×3 image:
it was torch.Size([1, 2048, 28, 28]).
However, when I tried to get the corresponding output from your JAX/Flax ResNet model, I commented out these two lines in the ResNet function to obtain the output before the average pooling (the equivalent of layer4 in PyTorch):

def ResNet(
    block_cls: ModuleDef,
    *,
    stage_sizes: Sequence[int],
    n_classes: int,
    hidden_sizes: Sequence[int] = (64, 128, 256, 512),
    conv_cls: ModuleDef = nn.Conv,
    norm_cls: Optional[ModuleDef] = partial(nn.BatchNorm, momentum=0.9),
    conv_block_cls: ModuleDef = ConvBlock,
    stem_cls: ModuleDef = ResNetStem,
    pool_fn: Callable = partial(nn.max_pool,
                                window_shape=(3, 3),
                                strides=(2, 2),
                                padding=((1, 1), (1, 1))),
) -> Sequential:
    conv_block_cls = partial(conv_block_cls, conv_cls=conv_cls, norm_cls=norm_cls)
    stem_cls = partial(stem_cls, conv_block_cls=conv_block_cls)
    block_cls = partial(block_cls, conv_block_cls=conv_block_cls)
   
    layers = [stem_cls(), pool_fn] 

    for i, (hsize, n_blocks) in enumerate(zip(hidden_sizes, stage_sizes)):
        for b in range(n_blocks):
            strides = (1, 1) if i == 0 or b != 0 else (2, 2)
            layers.append(block_cls(n_hidden=hsize, strides=strides))
    # ----------------------------------------------------------------------
    # layers.append(partial(jnp.mean, axis=(1, 2)))  # global average pool
    # layers.append(nn.Dense(n_classes))
    # ----------------------------------------------------------------------
    return Sequential(layers)

It gives a different output shape for the same input size (inp_batch of shape (1, 224, 224, 3)):

RESNET100, variables = pretrained_resnet(101)
RESNET = RESNET100()
model_out = RESNET.apply(variables, jnp.ones((1, 224, 224, 3)), mutable=False)
print("pretrained resnet100 size:", jax.tree_map(lambda x: x.shape, model_out))

pretrained resnet100 size: (1, 7, 7, 2048)
So, what is happening at this stage in the ResNet layer structure?
Kindly reply if you have any explanation or recommendations.

n2cholas (Owner) commented

Hi @sarahelsherif, thanks for raising this issue! Could you also paste in the PyTorch code that gives you torch.Size([1, 2048, 28, 28]) for comparison?

sarahelsherif (Author) commented

Thank you @n2cholas, OK, here is the PyTorch code:

class RESNET_Layer_4(nn.Module):

    def __init__(self, backbone: nn.Module) -> None:
        super(RESNET_Layer_4, self).__init__()
        self.backbone = backbone

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        input_shape = x.shape[-2:]  # not used below
        backbone_resnet = self.backbone(x)
        print("backbone resnet output shape", backbone_resnet["out"].shape)
        return backbone_resnet


def resnet4(
    backbone: ResNet,
) -> RESNET_Layer_4:
    # extract the output of layer4 under the key "out"
    return_layers = {"layer4": "out"}
    backbone = create_feature_extractor(backbone, return_layers)
    return RESNET_Layer_4(backbone)

The input batch (inp_batch) shape is torch.Size([1, 3, 224, 224]).

pretrained_resnet = resnet101(pretrained=False, replace_stride_with_dilation=[False, True, True])
r4 = resnet4(pretrained_resnet)
r4 = r4.cuda()
out_resnet4 = r4(inp_batch)

The output is: backbone resnet output shape torch.Size([1, 2048, 28, 28])

n2cholas (Owner) commented

Hi @sarahelsherif, I wasn't able to directly use the code that you sent since I do not have create_feature_extractor. Instead, see this example of extracting the backbone in both JAX and PyTorch here. As you can see, they have the same output shape.

Does that help?

sarahelsherif (Author) commented

Hey @n2cholas, first of all, thank you so much for your help.
About create_feature_extractor: it is a utility from TorchVision which can be imported like this:

from torchvision.models.feature_extraction import create_feature_extractor
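
For example, here is a minimal usage sketch on a plain (non-dilated) torchvision ResNet-101, just to show the API; the shape in the comment is for the default stride layout:

import torch
from torchvision.models import resnet101
from torchvision.models.feature_extraction import create_feature_extractor

backbone = resnet101(pretrained=False)
# return the output of layer4 under the key "out"
extractor = create_feature_extractor(backbone, return_nodes={"layer4": "out"})
feats = extractor(torch.ones(1, 3, 224, 224))
print(feats["out"].shape)  # torch.Size([1, 2048, 7, 7]) without dilation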

And thank you for your example, it helped. I now understand why the output shape is different: it is because strides are replaced with dilation in the pretrained resnet:

pretrained_resnet = resnet101(pretrained=False, replace_stride_with_dilation=[False, True, True])
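
Just to sanity-check the numbers (standard ResNet stride layout):

# stem conv (/2) + max pool (/2) + three stride-2 stages (/2 each) -> /32
print(224 // 32)  # 7  -> the JAX output (1, 7, 7, 2048), NHWC layout
# with replace_stride_with_dilation=[False, True, True], the last two
# stride-2 stages keep stride 1 and use dilation instead, so only /8 total
print(224 // 8)   # 28 -> the PyTorch output (1, 2048, 28, 28), NCHW layout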

So, my issue about the different output shapes is solved now.
On the other hand, I would be grateful if you could suggest a way to apply the stride-with-dilation replacement in JAX.

n2cholas (Owner) commented Dec 1, 2021

This can definitely be supported; essentially, we would need to apply the logic in _make_layer to the ResNetBottleneckBlock. I won't have the bandwidth to work on this for a few weeks, but I would be happy to review a PR if you decide to implement this.
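
For anyone picking this up, a rough sketch of the idea (not the library's API: the dilation argument below is hypothetical, and Flax's nn.Conv kernel_dilation is what actually does the work, mirroring how torchvision's _make_layer trades a stride of 2 for a doubled dilation):

from typing import Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp


class DilatedBlock(nn.Module):
    # Illustrative block only; jax-resnet's ResNetBottleneckBlock would need
    # an equivalent (hypothetical) dilation argument on its 3x3 conv.
    n_hidden: int
    strides: Sequence[int] = (1, 1)
    dilation: int = 1

    @nn.compact
    def __call__(self, x):
        # The 3x3 conv is where stride is traded for dilation.
        y = nn.Conv(self.n_hidden, (3, 3), strides=self.strides,
                    kernel_dilation=(self.dilation, self.dilation),
                    padding=((self.dilation, self.dilation),) * 2,
                    use_bias=False)(x)
        y = nn.BatchNorm(use_running_average=True)(y)
        return nn.relu(y)


def make_stage(n_hidden, n_blocks, stride, dilate, dilation=1):
    # Mimics torchvision's _make_layer bookkeeping:
    # if dilate, keep stride 1 and grow the dilation instead.
    if dilate:
        dilation *= stride
        stride = 1
    blocks = [DilatedBlock(n_hidden, strides=(stride, stride), dilation=dilation)]
    blocks += [DilatedBlock(n_hidden, dilation=dilation) for _ in range(n_blocks - 1)]
    return blocks


x = jnp.ones((1, 28, 28, 1024))  # e.g. a layer3-sized feature map
model = nn.Sequential(make_stage(512, n_blocks=3, stride=2, dilate=True))
variables = model.init(jax.random.PRNGKey(0), x)
y = model.apply(variables, x)
print(y.shape)  # (1, 28, 28, 512): spatial size preserved, stride replaced by dilation

Something along these lines could then be threaded through the stage loop in ResNet() by passing a dilation value into block_cls whenever a stage's stride is replaced.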

sarahelsherif (Author) commented

Yes, sure, thank you so much for your help.
I will update you when I implement it.
