
So 'RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn' #473

Open
ross-Hr opened this issue Dec 23, 2023 · 7 comments

Comments

@ross-Hr

ross-Hr commented Dec 23, 2023

I am using my own framework and ran into this error.

EigenCAM does not raise it, but every other CAM method that needs gradients (requires_grad=True) does.

I am quite confused and looking forward to the next release.

You could check the model's gradients before using it; that would be more user-friendly.

The problem I ran into is similar to https://github.com/jacobgil/pytorch-grad-cam/issues/323
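A rough sketch of such a pre-check, under stated assumptions: `output_has_grad` is a hypothetical helper (not part of pytorch-grad-cam), and the model is assumed to return a single tensor. The sketch also reproduces the error itself: calling `backward()` on an output that has no `grad_fn` raises the same RuntimeError reported in this issue.

```python
import torch
import torch.nn as nn


def output_has_grad(model: nn.Module, input_tensor: torch.Tensor) -> bool:
    """Hypothetical pre-check: True if backward() can run on the model output."""
    output = model(input_tensor)
    return output.grad_fn is not None


model = nn.Linear(4, 2)
x = torch.randn(1, 4)
print(output_has_grad(model, x))      # True: the parameters require grad

with torch.no_grad():                 # e.g. an inference-only code path
    print(output_has_grad(model, x))  # False: no autograd graph is recorded
    out = model(x)

try:
    out.sum().backward()
except RuntimeError as err:
    # The output is not part of any graph, so backward() fails.
    print(f"backward() failed: {err}")
```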

@ross-Hr
Author

ross-Hr commented Dec 23, 2023

I debugged the ActivationsAndGradients class, but I still can't solve the problem.

@Markson-Young

@ross-Hr Hello! This question has been bothering me for a few days, and I tried a lot of things but I couldn't solve it.
I carefully stepped through the tutorial code https://jacobgil.github.io/pytorch-gradcam-book/Class%20Activation%20Maps%20for%20Semantic%20Segmentation.html and compared it with my own to try to solve this problem.

Here is my error output. The activations list is empty (length 0), and I don't know what causes it:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
[Screenshot of the debugger output for my model, not preserved]

For comparison, here is the debugger output from the official tutorial:
[Screenshot of the debugger output for the tutorial, not preserved]

If you have any idea what might have caused this, please let me know.

@Markson-Young

I found that when visualizing my own model with grad-cam, the forward hook is registered on the target layer, but it never fires during `return self.model(x)` in ActivationsAndGradients.__call__, which means save_activation is never called.

class ActivationsAndGradients:
    """ Class for extracting activations and
    registering gradients from targeted intermediate layers """

    def __init__(self, model, target_layers, reshape_transform):
        self.model = model
        self.gradients = []
        self.activations = []
        self.reshape_transform = reshape_transform
        self.handles = []
        for target_layer in target_layers:
            self.handles.append(
                target_layer.register_forward_hook(self.save_activation))
            # Because of https://github.com/pytorch/pytorch/issues/61519,
            # we don't use backward hook to record gradients.
            self.handles.append(
                target_layer.register_forward_hook(self.save_gradient))

    def save_activation(self, module, input, output):
        activation = output

        if self.reshape_transform is not None:
            activation = self.reshape_transform(activation)
        self.activations.append(activation.cpu().detach())

    def save_gradient(self, module, input, output):
        if not hasattr(output, "requires_grad") or not output.requires_grad:
            # You can only register hooks on tensor requires grad.
            return

        # Gradients are computed in reverse order
        def _store_grad(grad):
            if self.reshape_transform is not None:
                grad = self.reshape_transform(grad)
            self.gradients = [grad.cpu().detach()] + self.gradients

        output.register_hook(_store_grad)

    def __call__(self, x):
        self.gradients = []
        self.activations = []
        return self.model(x)

    def release(self):
        for handle in self.handles:
            handle.remove()
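ActivationsAndGradients depends on forward hooks, and a forward hook only runs when its module is invoked as `module(...)` on the path that `forward` actually executes. A toy illustration; the two-branch model below is hypothetical and made up for this sketch:

```python
import torch
import torch.nn as nn

activations = []


def save_activation(module, inputs, output):
    # Same idea as ActivationsAndGradients.save_activation, minus reshaping.
    activations.append(output.detach())


class TwoPathModel(nn.Module):
    """Hypothetical toy model whose forward() runs only one of two branches."""

    def __init__(self):
        super().__init__()
        self.branch_a = nn.Linear(4, 4)
        self.branch_b = nn.Linear(4, 4)

    def forward(self, x, use_a=True):
        return self.branch_a(x) if use_a else self.branch_b(x)


model = TwoPathModel()
handle = model.branch_a.register_forward_hook(save_activation)
x = torch.randn(1, 4)

model(x, use_a=True)
print(len(activations))  # 1: branch_a ran, so the hook fired
model(x, use_a=False)
print(len(activations))  # 1: branch_a was skipped, so nothing new was stored
model.branch_a.forward(x)
print(len(activations))  # 1: calling .forward() directly bypasses hooks
handle.remove()
```

If the hooked layer sits on a branch that the wrapper's call path skips, `self.activations` stays empty even though the hook was registered.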

@jacobgil
Owner

Hi all, I will be looking into it.
Some context would help: what model are you using? Is it a custom model? Object detection, or something else?
Are you using a model wrapper?
Anything else you can share?

@Markson-Young

Thank you for your reply. I came across Grad-CAM in several papers and wanted reliable feature visualizations for my own models. So I tried to use grad-cam to visualize the features of my own semantic segmentation model, which is based on MaskDINO, a DETR-based segmentation model.
In addition, my model has two backbones, one for RGB and one for thermal input. It is implemented with mmdetection, a higher-level framework built on PyTorch. I am using pytorch 1.13.1 and grad-cam 1.4.8. I also ran the grad-cam semantic segmentation tutorial script, and its output is correct.

Here is my semantic segmentation feature visualization script based on pytorch-grad-cam:

import argparse
import copy
import os
import sys
import warnings
from pprint import pprint

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), ".")))

import mmcv
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision.transforms import Compose, Normalize, ToTensor
from pathlib import Path
from typing import Optional, Sequence, Union
from mmengine.config import Config
from mmengine.model.utils import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.runner import load_checkpoint
from mmdet.registry import DATASETS
from mmdet.evaluation import get_classes
from mmdet.registry import MODELS

from pytorch_grad_cam import GradCAM, GradCAMPlusPlus, LayerCAM, XGradCAM, EigenCAM, EigenGradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
from mmdet.apis import init_detector, inference_detector

##################################################################################################################################################################

# Supported grad-cam type map
METHOD_MAP = {
    'gradcam': GradCAM,
    'gradcam++': GradCAMPlusPlus,
    'xgradcam': XGradCAM,
    'eigencam': EigenCAM,
    'eigengradcam': EigenGradCAM,
    'layercam': LayerCAM,
}

DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device(
    'cpu')
image_file = './data/MF_dataset/images/val/'
rgb_file = './data/MF_dataset/seperated/'
image_name = "01007N"
IMAGE_FILE_PATH = os.path.join(image_file, image_name + (".png"))
RGB_FILE_PATH = os.path.join(rgb_file, image_name + ("_rgb.png"))
THER_FILE_PATH = os.path.join(rgb_file, image_name + ("_th.png"))
MEAN = [0.535, 0.520, 0.581]
STD = [0.149, 0.111, 0.104]
# MEAN = [123.675, 116.28, 103.53]
# STD = [58.395, 57.12, 57.375]
CONFIG = 'projects/TMM/configs/tmm_r152-MF.py'
CHECKPOINT = 'heatmap/checkpoints/MF/20231117_061123/best_mIoU_epoch_13.pth'
PREVIEW_MODEL = True
TARGET_LAYERS = ["model.model.backbone_r.layer3"]
METHOD = 'gradcam'

# for MFN dataset
SEM_CLASSES = [
    'unlabeled', 'car', 'person', 'bike', 'curve', 'car_stop', 'guardrail',
    'color_cone', 'bump'
]

TARGET_CATEGORY = 'person'
VIS_CAM_RESULTS = True
CAM_SAVE_PATH = "/heatmap/output"
LIKE_VIT = False
PRITN_MODEL_PRED_SEG = False


def parse_args():
    parser = argparse.ArgumentParser(description='Visualize CAM')
    parser.add_argument('--img', default=IMAGE_FILE_PATH, help='Image file')
    parser.add_argument('--config', default=CONFIG, help='Config file')
    parser.add_argument('--checkpoint',
                        default=CHECKPOINT,
                        help='Checkpoint file')
    parser.add_argument(
        '--target_layers',
        default=TARGET_LAYERS,
        nargs='+',
        type=str,
        help='The target layers to get CAM. If not set, the tool will '
        'use the norm layer in the last block. For custom backbones, '
        'it is recommended to specify the target layers manually on '
        'the command line.')
    parser.add_argument('--preview_model',
                        default=PREVIEW_MODEL,
                        help='To preview all the model layers')

    parser.add_argument('--method',
                        default=METHOD,
                        help='Type of method to use, supports '
                        f'{", ".join(list(METHOD_MAP.keys()))}.')

    parser.add_argument('--sem_classes',
                        default=SEM_CLASSES,
                        nargs='+',
                        type=str,
                        help='all classes that model predict.')
    parser.add_argument(
        '--target_category',
        default=TARGET_CATEGORY,
        type=str,
        help='The target category to get CAM, default to use result '
        'get from given model.')

    parser.add_argument('--aug_mean',
                        default=MEAN,
                        nargs='+',
                        type=float,
                        help='augmentation mean')

    parser.add_argument('--aug_std',
                        default=STD,
                        nargs='+',
                        type=float,
                        help='augmentation std')

    parser.add_argument(
        '--cam_save_path',
        default=CAM_SAVE_PATH,
        type=str,
        help='The path to save visualize cam image, default not to save.')
    parser.add_argument('--vis_cam_results', default=VIS_CAM_RESULTS)
    parser.add_argument('--device', default=DEVICE, help='Device to use, e.g. cpu or cuda:0')

    parser.add_argument('--like_vision_transformer',
                        default=LIKE_VIT,
                        help='Whether the target model is a ViT-like network.')

    parser.add_argument('--print_model_pred_seg',
                        default=PRITN_MODEL_PRED_SEG,
                        help='Whether to print the segmentation predicted by the model.')

    args = parser.parse_args()
    if args.method.lower() not in METHOD_MAP.keys():
        raise ValueError(f'invalid CAM type {args.method},'
                         f' supports {", ".join(list(METHOD_MAP.keys()))}.')

    return args


def norm_img(img, mean, std):
    image = img.copy()
    image = np.array(image)
    image = image.transpose(2, 0, 1)
    image = torch.tensor(image)
    data_rgb = image[:3, :, :]
    tmp = image[3:4, :, :]
    data_t = torch.cat([tmp] * 3, dim=0)
    preprocessing = Compose([
        # ToTensor(),
        Normalize(mean=mean, std=std)
    ])
    rgb = preprocessing(data_rgb.float()).unsqueeze(0)
    t = preprocessing(data_t.float()).unsqueeze(0)
    input_tensor = torch.cat((rgb, t), dim=1)
    return input_tensor


def make_input_tensor(image_file_path, mean, std, device):
    if not os.path.exists(image_file_path):
        raise FileNotFoundError(f"{image_file_path} does not exist!")
    image = np.asarray(Image.open(image_file_path))

    img = np.float32(image) / 255
    img = torch.tensor(img)
    input_tensor = norm_img(image, mean, std)
    # input_tensor = preprocess_image(rgb_img, mean=mean, std=std)
    # .to() is a no-op when the tensor is already on the target device.
    input_tensor = input_tensor.to(device)
    print(f"input_tensor has been moved to {device}")
    return input_tensor, img


def make_model(config_path, checkpoint_path, device):
    model = init_detector(config_path, checkpoint_path, device=device)
    print('Network setup complete: The trained weights were successfully loaded')
    return model


def init_detector(
    config: Union[str, Path, Config],
    checkpoint: Optional[str] = None,
    palette: str = 'none',
    device: str = 'cuda:0',
    cfg_options: Optional[dict] = None,
) -> nn.Module:
    """Initialize a detector from config file.

    Args:
        config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
            :obj:`Path`, or the config object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        palette (str): Color palette used for visualization. If palette
            is stored in checkpoint, use checkpoint's palette first, otherwise
            use externally passed palette. Currently, supports 'coco', 'voc',
            'citys' and 'random'. Defaults to none.
        device (str): The device where the anchors will be put on.
            Defaults to cuda:0.
        cfg_options (dict, optional): Options to override some settings in
            the used config.

    Returns:
        nn.Module: The constructed detector.
    """
    if isinstance(config, (str, Path)):
        config = Config.fromfile(config)
    elif not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if cfg_options is not None:
        config.merge_from_dict(cfg_options)
    else:
        # Skip pretrained initialization for both backbones; the checkpoint
        # loaded below provides the weights.
        if 'init_cfg' in config.model.backbone_r:
            config.model.backbone_r.init_cfg = None
        if 'init_cfg' in config.model.backbone_i:
            config.model.backbone_i.init_cfg = None

    scope = config.get('default_scope', 'mmdet')
    if scope is not None:
        init_default_scope(config.get('default_scope', 'mmdet'))

    model = MODELS.build(config.model)
    model = revert_sync_batchnorm(model)
    if checkpoint is None:
        warnings.simplefilter('once')
        warnings.warn('checkpoint is None, use COCO classes by default.')
        model.dataset_meta = {'classes': get_classes('coco')}
    else:
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        # Weights converted from elsewhere may not have meta fields.
        checkpoint_meta = checkpoint.get('meta', {})

        # save the dataset_meta in the model for convenience
        if 'dataset_meta' in checkpoint_meta:
            # mmdet 3.x, all keys should be lowercase
            model.dataset_meta = {
                k.lower(): v
                for k, v in checkpoint_meta['dataset_meta'].items()
            }
        elif 'CLASSES' in checkpoint_meta:
            # < mmdet 3.x
            classes = checkpoint_meta['CLASSES']
            model.dataset_meta = {'classes': classes}
        else:
            warnings.simplefilter('once')
            warnings.warn(
                'dataset_meta or class names are not saved in the '
                'checkpoint\'s meta data, use COCO classes by default.')
            model.dataset_meta = {'classes': get_classes('coco')}

    # Priority:  args.palette -> config -> checkpoint
    if palette != 'none':
        model.dataset_meta['palette'] = palette
    else:
        test_dataset_cfg = copy.deepcopy(config.test_dataloader.dataset)
        # lazy init. We only need the metainfo.
        test_dataset_cfg['lazy_init'] = True
        metainfo = DATASETS.build(test_dataset_cfg).metainfo
        cfg_palette = metainfo.get('palette', None)
        if cfg_palette is not None:
            model.dataset_meta['palette'] = cfg_palette
        else:
            if 'palette' not in model.dataset_meta:
                warnings.warn(
                    'palette does not exist, random is used by default. '
                    'You can also set the palette to customize.')
                model.dataset_meta['palette'] = 'random'

    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model


from torch.nn import functional as F


class SegmentationModelOutputWrapper(torch.nn.Module):
    def __init__(self, model):
        super(SegmentationModelOutputWrapper, self).__init__()
        self.model = model

    def reshape_output(self, sem_seg, num_classes):
        multi_channel_mask = torch.zeros(
            (1, num_classes, sem_seg.shape[1], sem_seg.shape[2]),
            dtype=sem_seg.dtype)
        for i in range(num_classes):
            multi_channel_mask[0, i, :, :] = (sem_seg == i)
        return multi_channel_mask

    def forward(self, x):
        num_classes = self.model.num_classes
        result = self.model(x.squeeze(0))
        pred_sem_seg = self.reshape_output(result.sem_seg, num_classes)
        # out = pred_sem_seg
        out = F.interpolate(pred_sem_seg.float(),
                            size=x.shape[-2:],
                            mode='bilinear',
                            align_corners=False)
        return out

class SemanticSegmentationTarget:
    def __init__(self, category, mask):
        self.category = category
        self.mask = torch.from_numpy(mask)
        if torch.cuda.is_available():
            self.mask = self.mask.cuda()

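    def __call__(self, model_output):
        # Move the mask to the output's device instead of forcing CUDA,
        # so the target also works on CPU-only machines.
        mask = self.mask.to(model_output.device)
        return (model_output[self.category, :, :] * mask).sum()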


def reshape_transform_fc(in_tensor):
    result = in_tensor.reshape(in_tensor.size(0),
                               int(np.sqrt(in_tensor.size(1))),
                               int(np.sqrt(in_tensor.size(1))),
                               in_tensor.size(2))

    result = result.transpose(2, 3).transpose(1, 2)
    return result

def main():
    args = parse_args()

    input_tensor, img = make_input_tensor(args.img,
                                          args.aug_mean,
                                          args.aug_std,
                                          device=args.device)
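    # show_cam_on_image expects an RGB float32 image in [0, 1],
    # while mmcv.imread returns a BGR uint8 array by default.
    rgb_img = np.float32(mmcv.imread(RGB_FILE_PATH, channel_order='rgb')) / 255
    th_img = mmcv.imread(THER_FILE_PATH)  # thermal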

    cfg = args.config
    checkpoint = args.checkpoint
    model = make_model(cfg, checkpoint, device=args.device)

    results = inference_detector(model, args.img)

    if args.print_model_pred_seg:
        pprint(results)

    if args.preview_model:
        pprint([name for name, _ in model.named_modules()])
    model = SegmentationModelOutputWrapper(model)
    output = model(input_tensor)

    sem_classes = args.sem_classes
    sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(sem_classes)}

    if len(sem_classes) == 1:
        output = torch.sigmoid(output).cpu()
        pred_mask = torch.where(output > 0.3, torch.ones_like(output),
                                torch.zeros_like(output))
        pred_mask = pred_mask.detach().cpu().numpy()

    else:
        output = torch.nn.functional.softmax(output, dim=1).cpu()
        pred_mask = output[0, :, :, :].argmax(dim=0).detach().cpu().numpy()

    category = sem_class_to_idx[args.target_category]
    mask_float = np.float32(pred_mask == category)

    # # visual
    # car_mask_uint8 = 255 * np.uint8(pred_mask == category)
    #
    # both_images = np.hstack((rgb_img, np.repeat(car_mask_uint8[:, :, None], 3, axis=-1)))
    # image = Image.fromarray(both_images)
    # image.save('output.png')

    reshape_transform = reshape_transform_fc if args.like_vision_transformer else None

    ##########################################################################################################################################################################

    target_layers = [model.model.backbone_r.layer4]

    ##########################################################################################################################################################################
    targets = [SemanticSegmentationTarget(category, mask_float)]
    GradCAM_Class = METHOD_MAP[args.method.lower()]


    with GradCAM_Class(model=model,
                       target_layers=target_layers,
                       use_cuda=torch.cuda.is_available(),
                       reshape_transform=reshape_transform) as cam:
        grayscale_cam = cam(input_tensor=input_tensor,
                            targets=targets,
                            aug_smooth=True,
                            eigen_smooth=True)[0, :]
        cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

    vir_image = Image.fromarray(cam_image)

    if args.vis_cam_results:
        vir_image.show()
    cam_save_path = f"{args.cam_save_path}/{os.path.basename(args.config).split('.')[0]}"
    if not os.path.exists(cam_save_path):
        os.makedirs(cam_save_path)
    vir_image.save(
        os.path.join(cam_save_path,
                     f"{os.path.basename(args.img).split('.')[0]}.png"))

if __name__ == '__main__':
    main()
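Separately from the hooks, one thing that may be worth checking (an assumption, not confirmed in this thread): `reshape_output` builds the wrapper's output from `result.sem_seg` using `==` comparisons. If `sem_seg` is a hard label map, that output is not differentiable and has no `grad_fn`, even when the hooks do fire. The toy tensors below are made up; `logits` stands in for whatever raw per-class scores the model computes before its argmax.

```python
import torch
import torch.nn.functional as F

num_classes = 3
# Hypothetical stand-in for raw per-class scores that are part of the graph.
logits = torch.randn(1, num_classes, 4, 4, requires_grad=True)

# Hard label map, like an argmax-ed sem_seg prediction: integer, no grad_fn.
sem_seg = logits.argmax(dim=1)
mask = torch.zeros(1, num_classes, 4, 4)
for i in range(num_classes):
    mask[0, i] = (sem_seg[0] == i)     # comparisons are not differentiable
print(mask.requires_grad)              # False: backward() would fail here

# Soft scores derived from the logits keep the autograd graph intact.
probs = F.softmax(logits, dim=1)
print(probs.requires_grad)             # True
probs[0, 1].sum().backward()           # a CAM-style target on class 1
print(logits.grad is not None)         # True: gradients reached the scores
```

In the official semantic segmentation tutorial, the wrapper returns the model's raw per-class output (`self.model(x)["out"]`) rather than a mask built from labels, which keeps the graph connected.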

@Markson-Young

Before using grad-cam, I extracted features from intermediate layers of my model myself with register_forward_hook, and that worked:

conv_features_r_3, conv_features_t_3 = [], []  # features collected by the hooks
hooks = [
    model.backbone_r.layer3.register_forward_hook(
        lambda module, input, output: conv_features_r_3.append(output)
    ),
    model.backbone_t.layer3.register_forward_hook(
        lambda module, input, output: conv_features_t_3.append(output)
    )
]
result = inference_detector(model, img)
for hook in hooks:
    hook.remove()  # detach the hooks once the features are captured

With grad-cam, however, the hook registered by ActivationsAndGradients is never called. Does pytorch-grad-cam place any restrictions on the model? Thanks for looking into this.
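One way to narrow this down, sketched with a hypothetical helper `executed_modules`: record which submodules actually run during a given call, then compare the call that worked (`inference_detector`) with the one grad-cam makes through the wrapper. Whether these two calls take different code paths in this model is an assumption to verify, not something confirmed in this thread.

```python
import torch
import torch.nn as nn


def executed_modules(model: nn.Module, run) -> list:
    """Hypothetical debugging helper: names of submodules whose forward
    hooks fire while run() executes, in the order they finish."""
    fired = []
    handles = [
        module.register_forward_hook(
            # Bind `name` now; a plain closure would only see the last name.
            lambda mod, inp, out, name=name: fired.append(name))
        for name, module in model.named_modules()
        if name  # skip the root module, whose name is ''
    ]
    try:
        run()
    finally:
        for handle in handles:
            handle.remove()  # never leave debugging hooks behind
    return fired


toy = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
print(executed_modules(toy, lambda: toy(torch.randn(1, 4))))  # ['0', '1']
```

For the script above, comparing `executed_modules(model, lambda: inference_detector(model, args.img))` with `executed_modules(model, lambda: wrapper(input_tensor))`, where `model` is the mmdet model before wrapping and `wrapper` is the SegmentationModelOutputWrapper instance, would show whether `backbone_r.layer4` runs in both cases.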

@ross-Hr
Author

ross-Hr commented Jan 6, 2024

I'm using a custom model based on YOLOv7.
This works:
cam = EigenCAM(model, target_layers)
But when I change the following line in eigen_cam.py to pass uses_gradients=True, it raises the error:
super(EigenCAM, self).__init__(model, target_layers, reshape_transform, uses_gradients=True)

Where can I manually set requires_grad=True for a custom model?
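A minimal sketch of re-enabling gradients, assuming the weights were loaded with `requires_grad=False` (the small convolutional model is a hypothetical stand-in, not YOLOv7). Note that `model.eval()` does not turn autograd off; common culprits are frozen parameters, an enclosing `torch.no_grad()` or `torch.inference_mode()` context, or a `detach()` inside `forward`.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a custom detector backbone.
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 2, 1))
for p in model.parameters():
    p.requires_grad_(False)          # simulate weights loaded as frozen

x = torch.randn(1, 3, 16, 16)
print(model(x).grad_fn is None)      # True: gradient-based CAMs would fail

for p in model.parameters():
    p.requires_grad_(True)           # re-enable gradients on every parameter
model.eval()                         # eval() alone does not disable autograd
with torch.enable_grad():            # override any outer no_grad() context
    out = model(x)
print(out.grad_fn is not None)       # True: backward() is possible again
```

Setting `requires_grad_(True)` on the input tensor instead would also give downstream activations a `grad_fn`, as long as nothing in `forward` detaches them or passes them through a non-differentiable op such as argmax.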
