vlue-c/Visual-Explanation-Methods-PyTorch

torchvex

Visual Explanation Methods for PyTorch

  • PyTorch friendly
  • Better GPU utilization
  • Higher-order derivative friendly
  • Batch processing

Usage

pip install -r requirements.txt
python setup.py install

requirements

  • pytorch >= 1.7 (the first release that provides torch.quantile)
  • tqdm
  • TBD

optional

  • captum (currently needed only for DeepLift)

Class Activation Mapping based Methods

CAM

functionality progress
Higher order derivative ✔️
Batch processing ✔️
Post processing ✔️
Pre processing ✔️

Example:

from torchvision.models import resnet50
from torchvex import CAM

# For torchvision.models.resnet, the output of layer4 is equal to
# the output of the last conv-bn-relu layer.
resnet = resnet50(pretrained=True).eval()
cam_generator = CAM(resnet, target_layer=resnet.layer4, fc_layer=resnet.fc)

# If no target is passed,
# the predicted class is used as the target.
cam = cam_generator(image)
cam = cam_generator(image, target)

Result:

Input image: torchvex/cam/cat_dog.jpg (original source unknown)

[Figure: cam_example]
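
For readers who want to see what a CAM is numerically, the following is a minimal sketch of the underlying computation, not torchvex's implementation. It assumes a network that ends in global average pooling followed by a single linear layer (as torchvision's ResNet does); `class_activation_map` and the toy tensors are hypothetical names used only for illustration.

```python
import torch
import torch.nn.functional as F


def class_activation_map(features, fc_weight, target, out_size):
    """Hypothetical helper: CAM for a GAP + linear classifier head.

    features:  (N, C, h, w) output of the last conv block
    fc_weight: (num_classes, C) weight of the final linear layer
    target:    (N,) class index per sample
    out_size:  (H, W) size of the input image
    """
    weights = fc_weight[target]                      # (N, C)
    # Weighted sum of the feature maps over the channel dimension.
    cam = torch.einsum('nc,nchw->nhw', weights, features).unsqueeze(1)
    # Upsample the coarse map to the input resolution.
    return F.interpolate(cam, size=out_size, mode='bilinear',
                         align_corners=False)


# Toy data standing in for real activations and classifier weights.
torch.manual_seed(0)
features = torch.rand(2, 8, 7, 7)
fc_weight = torch.rand(10, 8)
target = torch.tensor([3, 5])
cam = class_activation_map(features, fc_weight, target, out_size=(224, 224))
print(cam.shape)  # torch.Size([2, 1, 224, 224])
```

In this sketch the spatial mean of the low-resolution map equals the target logit (without the bias term), which is why the fc layer has to be passed alongside the target layer.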

Grad-CAM

functionality progress
Higher order derivative ✔️
Batch processing ✔️
Post processing ✔️
Pre processing ✔️

Example:

from torchvex import GradCAM

model = ...
grad_cam_generator = GradCAM(model, target_layer=model.layer4)
# A list of layers can be passed to get one map per target layer
# (see the shape comment below).
multi_layer_gcamgen = GradCAM(model, target_layer=[model.layer3, model.layer4])

grad_cam = grad_cam_generator(image)
grad_cam = grad_cam_generator(image, target)

multiple_grad_cam = multi_layer_gcamgen(image)
# multiple_grad_cam.shape: torch.Size([2, 1, image.size(-2), image.size(-1)])

Results:

Input image: torchvex/cam/cat_dog.jpg (original source unknown)

[Figure: gradcam_example]
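
To illustrate what "higher-order derivative friendly" can mean for Grad-CAM, here is a minimal sketch that computes the map with `torch.autograd.grad(..., create_graph=True)` so that the map itself stays differentiable. It is an illustration under stated assumptions, not torchvex's code: the toy `backbone`/`head` split and the `grad_cam` helper are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy CNN standing in for a real classifier (hypothetical, for illustration).
backbone = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.ReLU())
head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(4, 5))


def grad_cam(image, target, create_graph=False):
    """Hypothetical helper: Grad-CAM of `target` w.r.t. backbone features."""
    features = backbone(image)                       # (N, C, h, w)
    logits = head(features)                          # (N, num_classes)
    score = logits.gather(1, target.view(-1, 1)).sum()
    # create_graph=True keeps the graph so the map is differentiable.
    grads, = torch.autograd.grad(score, features, create_graph=create_graph)
    weights = grads.mean(dim=(2, 3), keepdim=True)   # channel importance
    cam = F.relu((weights * features).sum(dim=1, keepdim=True))
    return F.interpolate(cam, size=image.shape[-2:], mode='bilinear',
                         align_corners=False)


image = torch.rand(2, 3, 16, 16, requires_grad=True)
target = torch.tensor([1, 4])
cam = grad_cam(image, target, create_graph=True)
print(cam.shape)  # torch.Size([2, 1, 16, 16])

# Because the graph was kept, the map can be differentiated again,
# for example to use an explanation inside a training loss.
cam.sum().backward()
print(image.grad.shape)  # torch.Size([2, 3, 16, 16])
```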


DeepLift


RISE

functionality progress
Higher order derivative 🙅
Batch processing ✔️
Post processing ✔️
Pre processing ✔️

Example:

from torchvex import RISE

model = ...
rise_generator = RISE(model, num_masks=8000, cell_size=7,
                      probability=0.5, batch_size=256)

rise = rise_generator(image)
rise = rise_generator(image, target)
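
As a rough picture of what the parameters above control, the following is a simplified sketch of the RISE idea: sample random low-resolution masks, score the masked images, and average the masks weighted by those scores. It is not torchvex's implementation. `rise_saliency` and `grid_size` are hypothetical names, the random shifts from the original paper are omitted, and how torchvex interprets `cell_size` is not assumed here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.no_grad()
def rise_saliency(model, image, target, num_masks=64, grid_size=7,
                  probability=0.5, batch_size=16):
    """Hypothetical helper: simplified RISE for one image of shape (1, C, H, W)."""
    _, _, H, W = image.shape
    # Low-resolution binary grids; each cell is kept with `probability`.
    grid = (torch.rand(num_masks, 1, grid_size, grid_size) < probability).float()
    # Bilinear upsampling turns the grids into soft masks in [0, 1].
    masks = F.interpolate(grid, size=(H, W), mode='bilinear',
                          align_corners=False)
    saliency = torch.zeros(1, 1, H, W)
    for start in range(0, num_masks, batch_size):
        m = masks[start:start + batch_size]                  # (B, 1, H, W)
        probs = model(image * m).softmax(dim=1)[:, target]   # (B,)
        # Accumulate masks weighted by the score of the masked image.
        saliency += (probs.view(-1, 1, 1, 1) * m).sum(dim=0, keepdim=True)
    return saliency / (num_masks * probability)


torch.manual_seed(0)
# Toy classifier standing in for a real model.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 28 * 28, 10)).eval()
image = torch.rand(1, 3, 28, 28)
rise = rise_saliency(model, image, target=2)
print(rise.shape)  # torch.Size([1, 1, 28, 28])
```

The method only needs forward passes, which is consistent with higher-order derivatives not being supported for RISE in the table above.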

Meaningful Perturbation

  • Paper: Interpretable Explanations of Black Boxes by Meaningful Perturbation

functionality progress
Higher order derivative 🙅
Batch processing
Post processing ✔️
Pre processing ✔️

Example:

from torchvision import transforms as T
from torchvex import MeaningfulPerturbation

# If normalization is needed, build it once and pass the same
# transform to MeaningfulPerturbation.
normalization = T.Normalize(MEAN, STD)
transform = T.Compose([
    TransformA(),
    TransformB(),
    ...,
    normalization
])
# Otherwise, use this instead:
# normalization = None

dataset = Dataset(..., transform=transform)
model = ...
mp_generator = MeaningfulPerturbation(model, normalization)

mp = mp_generator(image)
mp = mp_generator(image, target)

Result:

Input image: {ImageNet}/train/n03372029/n03372029_42103.JPEG

[Figure: meanpert_example]


Simple Gradient

  • Saliency = ∂M_c(x) / ∂x, where M is the model, c is the target class, and x is the input image.

functionality progress
Higher order derivative ✔️
Batch processing ✔️
Post processing ✔️
Pre processing ✔️

Example:

from torchvex import SimpleGradient
from torchvex import clamp_quantile

def clip_gradient(gradient):
    # Collapse the channel dimension, then clamp with the 0.99 quantile.
    gradient = gradient.abs().sum(1, keepdim=True)
    return clamp_quantile(gradient, q=0.99)

def normalize_gradient(gradient):
    # Alternative post-process: rescale each map so that its mean is 1.
    gradient = gradient.abs().sum(1, keepdim=True)
    nbatches, nchannels, h, w = gradient.shape
    # Sum per sample so that batched inputs are normalized independently.
    return h * w * gradient / gradient.sum(dim=(1, 2, 3), keepdim=True)

model = ...
simgrad_generator = SimpleGradient(model, post_process=clip_gradient)

simgrad = simgrad_generator(image)
simgrad = simgrad_generator(image, target)

Result:

Input image: {ImageNet}/val/ILSVRC2012_val_00046413.JPEG or
{ImageNet}/val/n02423022/ILSVRC2012_val_00046413.JPEG

[Figure: simplegrad_example]
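
The gradient itself and the quantile clamp used in the post-process can be sketched with plain PyTorch as follows. This is an illustration only: `simple_gradient` and `clamp_at_quantile` are hypothetical helpers, and the actual behavior of torchvex's `SimpleGradient` and `clamp_quantile` may differ (for example in whether the clamp is applied per sample).

```python
import torch
import torch.nn as nn


def simple_gradient(model, image, target=None):
    """Hypothetical helper: gradient of the target logit w.r.t. the input."""
    image = image.detach().clone().requires_grad_(True)
    logits = model(image)                            # (N, num_classes)
    if target is None:
        target = logits.argmax(dim=1)                # predicted class
    score = logits.gather(1, target.view(-1, 1)).sum()
    grad, = torch.autograd.grad(score, image)
    return grad


def clamp_at_quantile(x, q=0.99):
    """Hypothetical stand-in for a quantile clamp: cap each sample's
    values at its own q-th quantile (torch.quantile needs PyTorch >= 1.7)."""
    flat = x.flatten(start_dim=1)
    upper = torch.quantile(flat, q, dim=1).view(-1, 1, 1, 1)
    return torch.minimum(x, upper)


torch.manual_seed(0)
# Toy linear classifier standing in for a real model.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 4)).eval()
image = torch.rand(2, 3, 8, 8)
grad = simple_gradient(model, image)                 # (2, 3, 8, 8)
saliency = clamp_at_quantile(grad.abs().sum(1, keepdim=True))
print(saliency.shape)  # torch.Size([2, 1, 8, 8])
```

Clamping at a high quantile keeps a few extreme pixels from dominating the color scale when the map is visualized.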


SmoothGrad

functionality progress
Higher order derivative 💥💻💥
Batch processing ✔️
Post processing ✔️
Pre processing ✔️

Example:

from torchvex import SmoothGradient
from torchvex import clamp_quantile

def clip_gradient(gradient):
    # Collapse the channel dimension, then clamp with the 0.99 quantile.
    gradient = gradient.abs().sum(1, keepdim=True)
    return clamp_quantile(gradient, q=0.99)

model = ...
smoothgrad_gen = SmoothGradient(
    model, num_samples=50, stdev_spread=0.1,
    magnitude=True, postprocess=clip_gradient
)

smoothg = smoothgrad_gen(image)
smoothg = smoothgrad_gen(image, target)

Result:

Input image: {ImageNet}/val/ILSVRC2012_val_00046413.JPEG or
{ImageNet}/val/n02423022/ILSVRC2012_val_00046413.JPEG

magnitude = True: [Figure: smooth_grad_example_mag_True]

magnitude = False: [Figure: smooth_grad_example_mag_False]
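
The role of `num_samples`, `stdev_spread`, and `magnitude` can be sketched as below. This is a minimal illustration, not torchvex's code: `smooth_gradient` is a hypothetical helper, and it assumes the noise standard deviation is `stdev_spread * (max - min)` of the input and that `magnitude=True` averages squared gradients, as in the reference SmoothGrad code.

```python
import torch
import torch.nn as nn


def smooth_gradient(model, image, target, num_samples=50,
                    stdev_spread=0.1, magnitude=True):
    """Hypothetical helper: average input gradients over noisy copies."""
    image = image.detach()
    # Assumed noise scale: a fraction of the input's value range.
    stdev = stdev_spread * (image.max() - image.min())
    total = torch.zeros_like(image)
    for _ in range(num_samples):
        noisy = (image + stdev * torch.randn_like(image)).requires_grad_(True)
        score = model(noisy).gather(1, target.view(-1, 1)).sum()
        grad, = torch.autograd.grad(score, noisy)
        # magnitude=True averages squared gradients instead of signed ones.
        total += grad * grad if magnitude else grad
    return total / num_samples


torch.manual_seed(0)
# Toy linear classifier standing in for a real model.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 4)).eval()
image = torch.rand(2, 3, 8, 8)
target = torch.tensor([0, 3])
smooth = smooth_gradient(model, image, target, num_samples=10)
print(smooth.shape)  # torch.Size([2, 3, 8, 8])
```

With squared gradients the result is non-negative everywhere, which matches the visual difference between the two figures above.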

