
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn #55

littlepae opened this issue Aug 4, 2022 · 1 comment

@littlepae
Problem

Hello! This is a nice repo that made it easy for me to understand and implement FSL models in my project. I would like to ask how to implement and train the Transductive Fine-tuning model.

Since this model uses classical training (if I understand correctly), I followed the classical training tutorial and simply replaced PrototypicalNetworks with TransductiveFinetuning as the few_shot_classifier, roughly as sketched below.
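Concretely, the swap looked something like this (a minimal sketch; `backbone` and `DEVICE` are the objects from the tutorial, and the exact constructor arguments may differ):

```python
from easyfsl.methods import TransductiveFinetuning

# Previously: few_shot_classifier = PrototypicalNetworks(backbone)
few_shot_classifier = TransductiveFinetuning(backbone).to(DEVICE)
```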

But in the training stage, this error shows up:

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_16618/1151720753.py in <module>
      6 for epoch in range(n_epochs):
      7     print(f"Epoch {epoch}")
----> 8     average_loss = training_epoch(model, train_loader, train_optimizer)
      9 
     10     if epoch % validation_frequency == validation_frequency - 1:

/tmp/ipykernel_16618/3987028259.py in training_epoch(model_, data_loader, optimizer)
      7 
      8             loss = LOSS_FUNCTION(model_(images.to(DEVICE)), labels.to(DEVICE))
----> 9             loss.backward()
     10             optimizer.step()
     11 #             model_(images.to(DEVICE))

~/.local/lib/python3.7/site-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    361                 create_graph=create_graph,
    362                 inputs=inputs)
--> 363         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    364 
    365     def register_hook(self, hook):

~/.local/lib/python3.7/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    173     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    174         tensors, grad_tensors_, retain_graph, create_graph, inputs,
--> 175         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
    176 
    177 def grad(

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```

So I commented out the loss.backward() and optimizer.step() lines, since the forward function of the Transductive Fine-tuning model already performs them. But the next problem is that the model's loss does not decrease during training, and I still don't know why.

Can you show me how to fix this error, and how to implement Transductive Fine-tuning or other classically trained models in EasyFSL?

Thanks a lot

littlepae added the question label Aug 4, 2022
@ebennequin (Collaborator)

Hi @littlepae, thank you for appreciating EasyFSL!

In TransductiveFinetuning.__init__() we set the backbone to not require grad (here). So when loss.backward() tries to compute gradients during classical training, it raises this error.
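For context, the freezing happens roughly like this (a paraphrased sketch of the idea, not the exact EasyFSL source):

```python
from easyfsl.methods import FewShotClassifier

class TransductiveFinetuning(FewShotClassifier):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The backbone is silently frozen at construction time,
        # which is what breaks a later loss.backward():
        self.backbone.requires_grad_(False)
```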

To fix your issue

Before training, call model.requires_grad_(True); then, before evaluation, switch back with model.requires_grad_(False).
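For example (a sketch reusing the names from your notebook; training_epoch, train_loader, and train_optimizer are your own objects):

```python
model.requires_grad_(True)   # unfreeze everything for classical training
for epoch in range(n_epochs):
    average_loss = training_epoch(model, train_loader, train_optimizer)

model.requires_grad_(False)  # freeze again for transductive evaluation
```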

To improve EasyFSL

Silently disabling gradients for the backbone in the initialization of a transductive method is clearly bad practice: it causes errors like this one, which are not easy for users to debug. Either we freeze the backbone only during the forward method and unfreeze it afterwards, or we signal the freezing in the logs. Option 1 seems best; a possible sketch follows.
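A minimal sketch of what option 1 could look like (hypothetical, not the current EasyFSL code; fine_tune_and_predict stands in for the method's transductive loop):

```python
import torch
from torch import nn

class TransductiveMethodSketch(nn.Module):
    """Hypothetical: freeze the backbone only inside forward()."""

    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone

    def fine_tune_and_predict(self, query_images: torch.Tensor) -> torch.Tensor:
        # Placeholder for the transductive fine-tuning loop.
        return self.backbone(query_images)

    def forward(self, query_images: torch.Tensor) -> torch.Tensor:
        # Freeze only for the duration of the transductive step,
        # then unfreeze so classical training can still backpropagate.
        self.backbone.requires_grad_(False)
        try:
            predictions = self.fine_tune_and_predict(query_images)
        finally:
            self.backbone.requires_grad_(True)
        return predictions
```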

I'm marking this as enhancement, thank you for pointing this out to me!

ebennequin added the enhancement label Aug 16, 2022
ebennequin changed the title from "How to use Transductive Fine-tuning model?" to "RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn" Aug 16, 2022