
optimizer got an empty parameter list #23

Open
fafafafafafafa opened this issue Feb 6, 2023 · 5 comments

Comments

@fafafafafafafa

When I put CenterLoss's parameters into the optimizer, it raises ValueError("optimizer got an empty parameter list"):
optimizer_centloss = torch.optim.SGD(criterion_cent.parameters(), lr=args.lr_cent)

@mohamedr002

Can you provide a code snippet for the definition of criterion_cent? Also, you need to ensure that you are defining the centers like this:

self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
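With the centers defined that way, parameters() yields one entry and the SGD constructor accepts it. A minimal CPU-only sketch (CenterCriterion, num_classes, and feat_dim are illustrative stand-ins; .cuda() is omitted so it runs anywhere):

```python
import torch
import torch.nn as nn

class CenterCriterion(nn.Module):
    """Illustrative criterion holding learnable class centers."""
    def __init__(self, num_classes=10, feat_dim=1024):
        super().__init__()
        # nn.Parameter is the outermost call, so nn.Module registers
        # self.centers in _parameters and parameters() can find it.
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

criterion_cent = CenterCriterion()
print(len(list(criterion_cent.parameters())))  # 1
optimizer_centloss = torch.optim.SGD(criterion_cent.parameters(), lr=0.5)
```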

@fafafafafafafa
Author

fafafafafafafa commented Feb 7, 2023

SoftmaxLoss = torch.nn.CrossEntropyLoss()
centerLoss = center_losses.CenterLoss(classes=train_class, feature_dims=1024)
optimizer_centerloss = torch.optim.SGD(list(centerLoss.parameters()), lr=0.5)

class CenterLoss(nn.Module):
    def __init__(self, classes, feature_dims, use_gpu=True):
        super(CenterLoss, self).__init__()
        self.classes = classes
        self.feature_dims = feature_dims
        self.use_gpu = use_gpu
        if use_gpu:
            centers = nn.Parameter(torch.randn(self.classes, self.feature_dims)).cuda()
        else:
            centers = nn.Parameter(torch.randn(self.classes, self.feature_dims))
        self.centers = centers

    def forward(self, x, labels):

        # labels: [N_way, K_shot]
        batch_size = x.shape[0]    # x_shape: torch.Size([N_way*K_shot, 1024])
        # dist_mat: torch.Size([batch_size, classes])
        # print('x: ', x)
        print('centers: ', self.centers)
        dist_mat = torch.sum(torch.square(x), 1, keepdim=True).expand(batch_size, self.classes) + \
             torch.sum(torch.square(self.centers), 1, keepdim=True).expand(self.classes, batch_size).t()
        dist_mat = dist_mat - 2*torch.matmul(x, self.centers.t())

        # dist_mat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.classes) + \
        #    torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.classes, batch_size).t()
        # dist_mat.addmm_(x, self.centers.t(), beta=1, alpha=-2)

        labels = torch.reshape(labels, (-1, 1)).expand(batch_size, self.classes)
        classes_mat = torch.arange(self.classes).expand(batch_size, self.classes).long()
        if self.use_gpu:
            # print('use_gpu:', self.use_gpu)
            classes_mat = classes_mat.cuda()
        mask = labels.eq(classes_mat).float()
        dist_mat = dist_mat*mask
        center_loss = torch.sum(dist_mat.clamp(min=1e-12, max=1e+12))/(batch_size*self.feature_dims)
        # get support set centers
        mask1 = torch.sum(mask, 0).bool()
        support_centers = self.centers[mask1, :]
        # print('support_centers:', support_centers.shape)
        return center_loss, support_centers
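As an aside, the expand/transpose arithmetic in forward() computes the pairwise squared Euclidean distances between the rows of x and the centers; on CPU it can be checked against torch.cdist (shapes below are illustrative):

```python
import torch

x = torch.randn(6, 4)        # batch of 6 feature vectors
centers = torch.randn(3, 4)  # one center per class (3 classes)

# Same expand/transpose trick as in CenterLoss.forward():
# ||x||^2 + ||c||^2 - 2 x.c  per (sample, center) pair.
dist_mat = torch.sum(torch.square(x), 1, keepdim=True).expand(6, 3) + \
           torch.sum(torch.square(centers), 1, keepdim=True).expand(3, 6).t()
dist_mat = dist_mat - 2 * torch.matmul(x, centers.t())

# Matches squared Euclidean distances from torch.cdist.
assert torch.allclose(dist_mat, torch.cdist(x, centers) ** 2, atol=1e-5)
```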

@mohamedr002

It seems there are no issues with your code, but you can try removing list from the line below, changing
optimizer_centerloss = torch.optim.SGD(list(centerLoss.parameters()), lr=0.5)
to
optimizer_centerloss = torch.optim.SGD(centerLoss.parameters(), lr=0.5)
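For reference, wrapping the generator in list() should not make a difference to SGD; both forms see the same parameters. A quick check (nn.Linear is used here only as an arbitrary module with registered parameters):

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)  # has two registered parameters: weight and bias

# SGD accepts either an iterable (generator) or a materialized list.
opt_gen = torch.optim.SGD(layer.parameters(), lr=0.5)
opt_list = torch.optim.SGD(list(layer.parameters()), lr=0.5)

print(len(opt_gen.param_groups[0]["params"]))   # 2
print(len(opt_list.param_groups[0]["params"]))  # 2
```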

@fafafafafafafa
Author

I have removed list, but it still raises the same error.

@fafafafafafafa
Author

I found the problem. Changing
self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)).cuda()
to
self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
fixes it, but I don't know the difference between them.
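The difference is in what gets assigned to the attribute. nn.Module registers an attribute as a parameter only when the assigned object is an nn.Parameter. Calling .cuda() on a Parameter returns a plain Tensor (a copy on the GPU), so nn.Parameter(...).cuda() stores an ordinary Tensor attribute and parameters() stays empty; nn.Parameter(tensor.cuda()) wraps the already-moved tensor last, so a Parameter is assigned and registered. A minimal CPU sketch, using .detach() as a stand-in for .cuda() (both return a plain Tensor rather than a Parameter):

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self, wrap_last: bool):
        super().__init__()
        t = torch.randn(10, 1024)
        if wrap_last:
            # nn.Parameter is outermost: nn.Module.__setattr__ sees a
            # Parameter and registers it in self._parameters.
            self.centers = nn.Parameter(t)
        else:
            # .detach() (like .cuda()) returns a plain Tensor, so the
            # module stores it as an ordinary attribute, not a parameter.
            self.centers = nn.Parameter(t).detach()

registered = Demo(wrap_last=True)
unregistered = Demo(wrap_last=False)
print(len(list(registered.parameters())))    # 1
print(len(list(unregistered.parameters())))  # 0 -> SGD raises ValueError
```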
