Optim-wip: Composable loss improvements #828

Open · wants to merge 10 commits into base: optim-wip
Conversation

@ProGamerGov (Contributor) commented Dec 20, 2021

This PR adds a few simple improvements to the CompositeLoss class and its features.

  • Added support for the __pos__ and __abs__ unary operators to CompositeLoss. These appear to be the only other basic operators that make sense to support.
  • Added operator.floordiv support. The current operator is being deprecated, but the operator symbol itself is unlikely to be removed; instead, its functionality will change: Python Array API Compatibility Tracker pytorch#58743
  • Made CompositeLoss's reduction operation a global variable that users can change. This should improve the generality of the optim module, and it makes it possible to disable this aspect of CompositeLoss.
  • Added composable torch.mean and torch.sum reduction operations to CompositeLoss. These are common operations, so there are likely use cases that can benefit from them. Example usage: loss_fn.mean() & loss_fn.sum() (see the usage sketch after the commit list below).
  • Added the custom_composable_op function, which should allow for the composability of many Python & PyTorch operations, as well as custom user operations. This should let users cover any operations that aren't supported by default in Captum.
  • Added the rmodule_op function for handling the 3 "r" versions of math operations. This helps simplify the code.
  • Added tests for the changes listed above.
  • 2.0 ** loss_obj, 2.0 / loss_obj, & 2.0 // loss_obj all work without the reduction op, so I've removed it for those cases.

* Added `operator.floordiv` support.
* Added the `basic_torch_module_op` function that should allow for the composability of many common torch operations.
* Added `rmodule_op` function for handling the 3 "r" versions of math operations.
* Improved documentation.
* Renamed `basic_torch_module_op` to `custom_composable_op`.
* Removed the reduction OP from 'r' module calls as it's not required.
* Custom loss objectives can support any number of batch dimension values.
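
A rough usage sketch of the features above (the model, its layer handles, and the exact import path are placeholders; only the operators and reductions come from this PR):

# Sketch only: assumes DeepDream / LayerActivation loss classes from
# captum/optim/_core/loss.py and a model exposing the referenced layers;
# model.mixed4a and model.mixed5b are placeholder layer handles.
from captum.optim._core import loss as opt_loss

dream = opt_loss.DeepDream(model.mixed4a)
activ = opt_loss.LayerActivation(model.mixed5b)

combined = abs(dream) + activ.mean()   # __abs__ support + composable torch.mean
combined = +dream - activ.sum()        # __pos__ support + composable torch.sum
combined = 2.0 ** dream                # "r" ops such as __rpow__ now skip the reduction op
combined = (dream + activ) // 2.0      # operator.floordiv support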
@aobo-y (Contributor) commented Apr 5, 2022

Hi, thank you for making this, but I may be missing some context/history here. Why do we need the "composable loss" in Captum?

PyTorch already provides a convention for losses: a function or callable module wrapping some tensor operations. For example, if I need a loss made up of others:

def new_loss(output_tensor):
    return nn.SomeLoss()(output_tensor) + some_other_loss(output_tensor) + torch.linalg.norm(output_tensor)

PyTorch tensors already support these basic arithmetic operations for modifying & combining loss tensors. Why are we composing functions/callable modules with arithmetic operations instead of composing tensors? PyTorch supports more operations than we do, and the "composable loss" cannot be composed with any existing PyTorch losses.

I think our optim losses can work the same way without "composable loss" and even be more flexible. For example:

deepdream = DeepDream(target)
layeractivation = LayerActivation(target)

def new_loss(targets_to_values: ModuleOutputMapping):
    loss = deepdream(targets_to_values) + layeractivation(targets_to_values)
    # can also use a pytorch loss
    return loss + nn.SomeLoss()(targets_to_values[target])

(Quoted diff context from captum/optim/_core/loss.py:)

    target = loss.target
    return CompositeLoss(loss_fn, name=name, target=target)


class BaseLoss(Loss):
    def __init__(
        self,
        target: Union[nn.Module, List[nn.Module]] = [],
@aobo-y (Contributor) commented Apr 5, 2022:
If the target can be List[nn.Module], many losses below cannot directly use it as a dict key in targets_to_values[self.target]. Did I miss anything?
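
A minimal sketch of the issue being raised (names here are illustrative, not from the PR):

import torch.nn as nn

conv = nn.Conv2d(3, 8, 3)
targets_to_values = {conv: "activations"}  # keys are individual nn.Module objects

target = [conv]                            # a List[nn.Module] target
targets_to_values[target]                  # TypeError: unhashable type: 'list'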

@ProGamerGov (Contributor, Author) replied:
@aobo-y Losses like ActivationInterpolation have multiple targets (Faceted loss as well in an upcoming PR), but BaseLoss works off a single target variable.

The BaseLoss class is called in the __init__ functions of loss classes like so:

# Single target
BaseLoss.__init__(self, target, batch_index)

# Multiple targets
BaseLoss.__init__(self, [target1, target2])

The loss class itself will indicate via a target: List[nn.Module] type hint that multiple targets are supported / required, or it handles things internally by passing the targets as a list to BaseLoss, as in ActivationInterpolation.

The ActivationInterpolation loss class can be found here: https://github.com/ProGamerGov/captum/blob/optim-wip-composable-loss-improvements/captum/optim/_core/loss.py#L506
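
A rough sketch of that pattern (the class and parameter names here are illustrative, not the actual Captum code):

import torch.nn as nn

class InterpolationLikeLoss(BaseLoss):
    """Illustrative only: a loss that passes its multiple targets to BaseLoss as a list."""

    def __init__(self, target1: nn.Module, target2: nn.Module) -> None:
        # BaseLoss just stores the list; this class handles the two targets internally.
        BaseLoss.__init__(self, [target1, target2])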

@aobo-y (Contributor) replied:

Sure, but cases like DeepDream and some others directly inherit BaseLoss's __init__ definition, where target can be a list even though it actually should not be: https://github.com/ProGamerGov/captum/blob/optim-wip-composable-loss-improvements/captum/optim/_core/loss.py#L393-L407

If these losses have different assumptions about what their targets should be, why do we abstract the target into the base class? The base class BaseLoss does not need target anyway. Each class can define its own target in __init__. Or we could have 2 other intermediate abstract classes, SingleTargetLoss and MultiTargetsLoss.

But anyway, this is just for discussion; it has nothing to do with this PR. We can leave it to future updates if needed.
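
For reference, a rough sketch of the SingleTargetLoss / MultiTargetsLoss idea (illustrative only; the BaseLoss signatures follow the calls shown earlier in this thread):

from abc import ABC
from typing import List, Optional
import torch.nn as nn

class SingleTargetLoss(BaseLoss, ABC):
    """Illustrative: losses that look up targets_to_values with a single module."""

    def __init__(self, target: nn.Module, batch_index: Optional[int] = None) -> None:
        BaseLoss.__init__(self, target, batch_index)

class MultiTargetsLoss(BaseLoss, ABC):
    """Illustrative: losses that operate on a list of modules."""

    def __init__(self, targets: List[nn.Module]) -> None:
        BaseLoss.__init__(self, targets)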

@ProGamerGov (Contributor, Author) replied:

Oh, yeah, I see what you mean now. In the original code, I think Ludwig had SingleTargetObjective & MultiObjective for handling these cases: https://github.com/ludwigschubert/captum/blob/f1fd0729dece59564a7c10b7b397617d8a09a247/captum/optim/optim/objectives.py#L108

It'd probably be best to leave this to a future PR if we decide on the changes.

@ProGamerGov (Contributor, Author) commented Apr 6, 2022

@aobo-y Originally, Captum's loss functions were set up similarly to the simple class-like functions that Lucid uses. Upon review, we then changed the losses to use classes instead.

Ludwig (one of the main Lucid developers) designed the initial optim module to utilize a Lucid-like composable loss system. One of the main benefits of the composable loss system is ease of use and built-in target tracking (the list of targets has to be created regardless of whether or not we use composable losses, and doing it this way means the user doesn't have to repeat the loss targets in multiple locations). It also allows for easy-to-use handling of things like batch-specific targeting.
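
A rough illustration of the target tracking (the layer handles are placeholders; the behavior reflects the CompositeLoss construction described above):

# Illustrative only: composing two losses yields a CompositeLoss, and its
# target attribute tracks the layers the composed loss needs, so the user
# doesn't have to re-list them for the optimization step.
dream = DeepDream(model.mixed4a)        # placeholder layer handle
activ = LayerActivation(model.mixed5b)  # placeholder layer handle

combined = dream + activ                # CompositeLoss
hooked_layers = combined.target         # built-in target tracking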

@ProGamerGov (Contributor, Author) commented:
This PR can be skipped for now.
