
Add support for BatchNorm2d (and more options for Slice) #559

Open
jbmaxwell opened this issue Mar 29, 2020 · 1 comment
Labels: feature request

@jbmaxwell

I'm trying to convert the ClusterGAN model, and the conversion fails in two places. The first error is a Slice in the encoder, which fails in the forward() function:

```python
def forward(self, in_feat):
    z_img = self.model(in_feat)
    # Reshape for output
    z = z_img.view(z_img.shape[0], -1)
    # Separate continuous and one-hot components
    zn = z[:, 0:self.latent_dim]
    zc_logits = z[:, self.latent_dim:]
    # Softmax on zc component
    zc = softmax(zc_logits)
    return zn, zc, zc_logits
```

I'm able to work around that error by moving the Slice out of the converted model and handling it in Swift. However, the second error, which I'm not sure I can work around, is a BatchNorm2d() in the generator. ClusterGAN is a great solution for my purposes, so I'd love to get this model converted.
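
For reference, a minimal sketch of that workaround (hypothetical code, not taken from the repo; the slicing boundary comes from the snippet above): the exported forward() returns the full flattened vector, so no Slice op ends up in the converted graph, and the zn/zc split is applied to the output tensor on the Swift side.

```python
def forward(self, in_feat):
    # Return the flattened features unsliced; the converted graph then
    # contains no Slice op. Slicing into zn / zc_logits (and the softmax
    # over zc_logits) moves downstream, into the Swift code.
    z_img = self.model(in_feat)
    return z_img.view(z_img.shape[0], -1)
```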

jbmaxwell added the feature request label on Mar 29, 2020
@jbmaxwell (Author)

It seems I was mistaken about the error here. I tried some reshaping to replace the BatchNorm2d with a BatchNorm1d, but I still get the error: `Error while converting op of type: BatchNormalization. Error message: provided number axes 2 not supported.`
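
One possible workaround (a minimal sketch, not from the thread; it assumes the converter only accepts 4-D inputs for BatchNormalization, which is what the "number axes 2" message suggests) is to wrap each BatchNorm1d so the normalization always sees a [N, C, 1, 1] tensor:

```python
import torch.nn as nn

class BatchNorm1dAs2d(nn.Module):
    """Hypothetical helper: batch-normalizes a [N, C] tensor by viewing it
    as [N, C, 1, 1], so the exported graph only contains 4-D batch norms."""
    def __init__(self, num_features):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features)

    def forward(self, x):
        # [N, C] -> [N, C, 1, 1] -> BatchNorm2d -> back to [N, C]
        return self.bn(x[:, :, None, None]).squeeze(-1).squeeze(-1)
```

In training mode this is numerically equivalent to BatchNorm1d, since the per-channel statistics are still computed over the batch dimension.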

The (original) model structure is:

```python
self.model = nn.Sequential(
    # Fully connected layers
    torch.nn.Linear(self.latent_dim + self.n_c, 1024),
    nn.BatchNorm1d(1024),
    # torch.nn.ReLU(True),
    nn.LeakyReLU(0.2, inplace=True),
    torch.nn.Linear(1024, self.iels),
    nn.BatchNorm1d(self.iels),
    # torch.nn.ReLU(True),
    nn.LeakyReLU(0.2, inplace=True),

    # Reshape to 128 x (7x7)
    Reshape(self.ishape),

    # Upconvolution layers
    nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1, bias=True),
    nn.BatchNorm2d(64),
    # torch.nn.ReLU(True),
    nn.LeakyReLU(0.2, inplace=True),

    nn.ConvTranspose2d(64, 1, 4, stride=2, padding=1, bias=True),
    nn.Sigmoid(),
)
```
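
With the hypothetical `BatchNorm1dAs2d` helper sketched above, the fully connected head of that model could then read (`latent_dim`, `n_c`, and `iels` standing in for the `self.*` attributes):

```python
head = nn.Sequential(
    nn.Linear(latent_dim + n_c, 1024),
    BatchNorm1dAs2d(1024),   # in place of nn.BatchNorm1d(1024)
    nn.LeakyReLU(0.2, inplace=True),
    nn.Linear(1024, iels),
    BatchNorm1dAs2d(iels),   # in place of nn.BatchNorm1d(self.iels)
    nn.LeakyReLU(0.2, inplace=True),
)
```

The nn.BatchNorm2d(64) after the first ConvTranspose2d already runs on a 4-D tensor, so whether it needs the same treatment depends on which layer the converter is actually rejecting.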
