Implement class to crop pulsemaps to maximum length #648

Open
AMHermansen opened this issue Dec 19, 2023 · 1 comment
Labels
feature New feature or request

Comments

@AMHermansen
Collaborator

AMHermansen commented Dec 19, 2023

Is your feature request related to a problem? Please describe.
With #558 we now have better control over how a pulsemap is processed. The Kaggle competition showed that many of the top-scoring models simply cropped each event to a fixed maximum number of pulses, to reduce the impact of the n^2 scaling of their self-attention components.
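Here is a quick way to see why cropping pays off. A full self-attention head scores every pulse against every pulse, so its work grows with the square of the event size. The sketch below just counts those scores. The event size of 1000 pulses and the cap of 256 are illustrative numbers only, not values taken from the competition.

```python
def attention_scores(n_pulses: int) -> int:
    """Number of query-key scores one full self-attention head computes."""
    # Every pulse attends to every pulse (including itself): n * n scores.
    return n_pulses * n_pulses


def cropped_cost_ratio(n_pulses: int, max_pulses: int) -> float:
    """How many times fewer scores are needed after cropping to `max_pulses`."""
    kept = min(n_pulses, max_pulses)  # events below the cap are left untouched
    return attention_scores(n_pulses) / attention_scores(kept)


# Illustrative sizes (hypothetical): a 1000-pulse event cropped to 256 pulses.
print(attention_scores(1000))                   # 1000000
print(attention_scores(256))                    # 65536
print(round(cropped_cost_ratio(1000, 256), 1))  # 15.3
```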

Their primary approach was to keep the first n pulses. I believe it would be interesting to look into other ways of selecting pulses as well (randomly, sorted by charge, sorted by the probability of being a signal rather than a noise pulse, farthest point sampling, etc.).
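Several of these selection strategies can be written as plain NumPy functions that share one signature, `(pulses, max_pulses) -> cropped pulses`. This is a minimal illustration, not graphnet code. It makes the following assumptions:

- x, y, z are in columns 0 to 2 and the charge is in column 3.
- Rows arrive in time order.
- All function names are hypothetical.

Each strategy returns the kept rows in their original order. That way downstream layers always see a time-ordered sequence.

```python
from typing import Optional, Sequence

import numpy as np

CHARGE_COL = 3  # hypothetical: column holding the pulse charge


def crop_first_n(x: np.ndarray, max_pulses: int) -> np.ndarray:
    """Keep the first `max_pulses` rows (the simple baseline)."""
    return x[:max_pulses]


def crop_random(
    x: np.ndarray, max_pulses: int, rng: Optional[np.random.Generator] = None
) -> np.ndarray:
    """Keep a uniformly random subset of rows, in their original order."""
    if len(x) <= max_pulses:
        return x
    rng = rng if rng is not None else np.random.default_rng()
    idx = rng.choice(len(x), size=max_pulses, replace=False)  # no duplicates
    return x[np.sort(idx)]


def crop_by_charge(
    x: np.ndarray, max_pulses: int, charge_col: int = CHARGE_COL
) -> np.ndarray:
    """Keep the `max_pulses` highest-charge rows, in their original order."""
    if len(x) <= max_pulses:
        return x
    # Negating the charge gives a descending sort; "stable" breaks ties by time.
    idx = np.argsort(-x[:, charge_col], kind="stable")[:max_pulses]
    return x[np.sort(idx)]


def crop_farthest_point(
    x: np.ndarray, max_pulses: int, xyz_cols: Sequence[int] = (0, 1, 2)
) -> np.ndarray:
    """Greedy farthest point sampling on the position columns."""
    if len(x) <= max_pulses:
        return x
    pos = x[:, list(xyz_cols)]
    selected = [0]  # deterministic start: the first pulse
    # Distance from each row to its nearest already-selected row.
    min_dist = np.linalg.norm(pos - pos[0], axis=1)
    min_dist[0] = -np.inf  # selected rows can never be picked again
    for _ in range(max_pulses - 1):
        nxt = int(np.argmax(min_dist))
        selected.append(nxt)
        min_dist = np.minimum(min_dist, np.linalg.norm(pos - pos[nxt], axis=1))
        min_dist[nxt] = -np.inf
    return x[np.sort(selected)]
```

Pulses recorded on the same DOM share a position. Farthest point sampling therefore naturally spreads the kept pulses over different DOMs before it picks a second pulse from any one of them.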

Describe the solution you'd like

To avoid having to implement many node definitions, I think it might make sense to create a common class for all cropped nodes:

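from typing import Callable

import torch
from torch_geometric.data import Data
from graphnet.models.graphs.nodes import NodeDefinition


class CroppedNodes(NodeDefinition):
    def __init__(self, max_pulses: int, cropping_method: Callable) -> None:
        super().__init__()
        self.max_pulses = max_pulses
        # Any callable with signature (pulses, max_pulses) -> kept pulses
        self._cropping_method = cropping_method

    def _construct_nodes(self, x: torch.Tensor) -> Data:
        x = self._cropping_method(x, self.max_pulses)
        return Data(x=x)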
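`CroppedNodes` calls `cropping_method(x, self.max_pulses)` with exactly two arguments. Any method that needs extra settings, such as which column to rank by, must therefore have them bound beforehand. Below is a framework-free sketch of that pattern using `functools.partial`. `CroppedNodesSketch` and `crop_by_column` are hypothetical stand-ins that operate on NumPy arrays; they are not the proposed graphnet class.

```python
from functools import partial
from typing import Callable

import numpy as np

# Signature the proposal expects: (pulses, max_pulses) -> cropped pulses.
CroppingMethod = Callable[[np.ndarray, int], np.ndarray]


class CroppedNodesSketch:
    """Hypothetical stand-in for the proposed `CroppedNodes`."""

    def __init__(self, max_pulses: int, cropping_method: CroppingMethod) -> None:
        self.max_pulses = max_pulses
        self._cropping_method = cropping_method

    def __call__(self, x: np.ndarray) -> np.ndarray:
        return self._cropping_method(x, self.max_pulses)


def crop_by_column(x: np.ndarray, max_pulses: int, column: int) -> np.ndarray:
    """Keep the rows with the largest values in `column`, in original order."""
    if len(x) <= max_pulses:
        return x
    idx = np.argsort(-x[:, column], kind="stable")[:max_pulses]
    return x[np.sort(idx)]


# Extra parameters are bound up front, so the method still takes (x, max_pulses).
by_charge = CroppedNodesSketch(
    max_pulses=2, cropping_method=partial(crop_by_column, column=1)
)
x = np.array([[0.0, 5.0], [1.0, 1.0], [2.0, 9.0], [3.0, 2.0]])
print(by_charge(x))  # keeps rows 0 and 2, the two largest values in column 1
```

The cropping logic stays outside the node definition. The same callable could therefore be handed to any other component that accepts a `(pulses, max_pulses)` function.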

Such a structure would also make it easier to reuse the cropping methods in other node definitions. For example, you might want to crop after calculating summary nodes per DOM, to make sure you do not get an event that triggered 5k DOMs.
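Cropping at the DOM level could look like the following NumPy sketch:

1. Collapse the pulses to one summary row per DOM (total charge and pulse count).
2. Keep only the `max_doms` DOMs with the largest total charge.

The function name, the summary features, and the ranking by total charge are all assumptions chosen for illustration.

```python
import numpy as np


def summarize_and_crop_doms(
    pulses: np.ndarray, dom_col: int, charge_col: int, max_doms: int
) -> np.ndarray:
    """Collapse pulses to one row per DOM, then keep the `max_doms` brightest.

    Returns rows of [dom_id, total_charge, n_pulses], ordered by DOM id.
    """
    dom_ids, inverse = np.unique(pulses[:, dom_col], return_inverse=True)
    inverse = inverse.ravel()  # one group index per pulse
    total_charge = np.bincount(inverse, weights=pulses[:, charge_col])
    n_pulses = np.bincount(inverse)
    summary = np.column_stack([dom_ids, total_charge, n_pulses])
    if len(summary) <= max_doms:
        return summary
    keep = np.argsort(-total_charge, kind="stable")[:max_doms]
    return summary[np.sort(keep)]


# Hypothetical event: column 0 is a DOM id, column 1 is the pulse charge.
pulses = np.array([[10, 1.0], [11, 5.0], [10, 2.0], [12, 0.5]])
# Keeps DOM 10 (charge 3.0, 2 pulses) and DOM 11 (charge 5.0, 1 pulse).
print(summarize_and_crop_doms(pulses, dom_col=0, charge_col=1, max_doms=2))
```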

Describe alternatives you've considered
We could, of course, implement each cropping algorithm as a subclass of a common CroppedNodes class and keep the logic restricted to each subclass. But I think the cropping logic is general enough that there is merit to having it as a separate component.

AMHermansen added the feature (New feature or request) label on Dec 19, 2023
@RasmusOrsoe
Collaborator

Hey @AMHermansen!

I think it is a great idea to allow for such functionality in GraphDefinition. "Cropping" pulsemaps is essentially just sub-sampling of the available pulses. I would suggest adding this as an independent submodule of GraphDefinition, so that on the user side it could look like:

from graphnet.models.graphs import GraphDefinition

graph_definition = GraphDefinition(
    detector=detector,
    node_definition=node_definition,
    edge_definition=edge_definition,
    sampler=sampler,  # proposed new, optional argument
)

In the forward pass of GraphDefinition, we could then add something like the following early on (perhaps just after the basic checks):

if self.sampler is not None:
    subsample_idx = self.sampler(
        input_features=input_features,
        input_feature_names=input_feature_names,
    )
    input_features = input_features[subsample_idx, :]
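The indexing step `input_features[subsample_idx, :]` accepts two kinds of selector from NumPy:

- a boolean mask, which must have one entry per row;
- an array of integer row indices.

A sampler can therefore return either, as long as it documents which. The standalone sketch below reproduces this step with a hypothetical `apply_sampler` helper and two toy samplers.

```python
from typing import Callable, List, Optional

import numpy as np

# Hypothetical: a sampler maps (input_features, input_feature_names) to a row selector.
SamplerFn = Callable[..., np.ndarray]


def apply_sampler(
    input_features: np.ndarray,
    input_feature_names: List[str],
    sampler: Optional[SamplerFn] = None,
) -> np.ndarray:
    """Mirror of the proposed sub-sampling step in the forward pass."""
    if sampler is None:
        return input_features
    subsample_idx = sampler(
        input_features=input_features, input_feature_names=input_feature_names
    )
    return input_features[subsample_idx, :]


def mask_sampler(input_features, input_feature_names):
    return input_features[:, 2] > 4  # boolean mask: one entry per row


def index_sampler(input_features, input_feature_names):
    return np.array([0, 1])  # integer row indices


features = np.arange(12.0).reshape(4, 3)
names = ["dom_x", "dom_y", "dom_z"]
by_mask = apply_sampler(features, names, mask_sampler)
by_index = apply_sampler(features, names, index_sampler)
print(by_mask.shape, by_index.shape)  # (3, 3) (2, 3)
```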

That way, the sampling would be independent of whatever users want to do with the pulses afterwards.

Here's a quick take on what the sampling module could look like:

from abc import abstractmethod
from typing import List

import numpy as np

from graphnet.models import Model
from graphnet.utilities.decorators import final


class Sampler(Model):
    """Base class for sub-sampling rows in single events."""

    def __init__(self) -> None:
        """Construct `Sampler`."""
        # Base class constructor
        super().__init__(name=__name__, class_name=self.__class__.__name__)

    @final
    def forward(
        self, input_features: np.ndarray, input_feature_names: List[str]
    ) -> np.ndarray:
        """Produce the integer indices of the rows to keep."""
        indices = np.asarray(
            self._create_subsample_indices(
                input_features=input_features,
                input_feature_names=input_feature_names,
            )
        )
        self._validate_indices(indices=indices, input_features=input_features)
        return indices

    def _validate_indices(
        self, indices: np.ndarray, input_features: np.ndarray
    ) -> None:
        """Check that the custom method returned valid row indices."""
        if not np.issubdtype(indices.dtype, np.integer):
            self.error(
                f"Subsampling indices must be integers. Got {indices.dtype}."
            )
            raise TypeError(f"Expected integer indices, got {indices.dtype}.")
        if indices.size > 0 and (
            indices.min() < 0 or indices.max() >= len(input_features)
        ):
            self.error("Subsampling method returned an out-of-range row index.")
            raise IndexError("Subsampling index out of range.")

    @abstractmethod
    def _create_subsample_indices(
        self, input_features: np.ndarray, input_feature_names: List[str]
    ) -> np.ndarray:
        """Create the integer indices of the rows in `input_features` to keep.

        Example:
            input_features = [[1, 2, 3],
                              [5, 5, 5],
                              [0, 0, 1]]
            input_feature_names = ['dom_x', 'dom_y', 'dom_z']

            Suppose the sampling logic produced the indices [0, 1].
            The corresponding subsampled rows would then be:

            input_features = [[1, 2, 3],
                              [5, 5, 5]]
        """
        raise NotImplementedError
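Passing `input_feature_names` lets a sampler find its columns by name instead of by position. The sketch below illustrates this without graphnet. It mirrors the template-method structure of `Sampler`, using `abc` in place of graphnet's `Model`. `SamplerSketch`, `TopChargeSampler`, and the column name `"charge"` are all hypothetical.

```python
from abc import ABC, abstractmethod
from typing import List

import numpy as np


class SamplerSketch(ABC):
    """Framework-free mirror of the proposed `Sampler` (hypothetical)."""

    def __call__(
        self, input_features: np.ndarray, input_feature_names: List[str]
    ) -> np.ndarray:
        indices = np.asarray(
            self._create_subsample_indices(input_features, input_feature_names)
        )
        # Same spirit as the validation step: integer row indices within bounds.
        if not np.issubdtype(indices.dtype, np.integer):
            raise TypeError(f"Expected integer indices, got {indices.dtype}.")
        if indices.size and (indices.min() < 0 or indices.max() >= len(input_features)):
            raise IndexError("Subsampling index out of range.")
        return indices

    @abstractmethod
    def _create_subsample_indices(
        self, input_features: np.ndarray, input_feature_names: List[str]
    ) -> np.ndarray: ...


class TopChargeSampler(SamplerSketch):
    """Keep the `max_event_size` highest-charge rows (hypothetical example)."""

    def __init__(self, max_event_size: int, charge_name: str = "charge") -> None:
        self._max_size = max_event_size
        self._charge_name = charge_name

    def _create_subsample_indices(self, input_features, input_feature_names):
        # The charge column is looked up by name, not by a hard-coded position.
        col = input_feature_names.index(self._charge_name)
        if len(input_features) <= self._max_size:
            return np.arange(len(input_features))
        top = np.argsort(-input_features[:, col], kind="stable")[: self._max_size]
        return np.sort(top)  # keep the selected rows in their original order


features = np.array([[0.0, 0.0, 0.0, 1.0], [1.0, 1.0, 1.0, 3.0], [2.0, 2.0, 2.0, 2.0]])
names = ["dom_x", "dom_y", "dom_z", "charge"]
print(TopChargeSampler(max_event_size=2)(features, names))  # [1 2]
```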

So a Sampler that randomly subsamples events exceeding some size limit could look like this:

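class RandomMaxSampler(Sampler):
    """Randomly sample events exceeding a maximum length."""

    def __init__(self, max_event_size: int, seed: int = 42) -> None:
        """Randomly sample available pulses if event is larger than `max_event_size`.

        Args:
            max_event_size: The maximum number of pulses in the event.
                Events with more pulses than this will be randomly sampled.
            seed: Seed used for random sampling. Defaults to 42.
        """
        # Initialise the `Sampler` base class before setting attributes.
        super().__init__()
        self._max_size = max_event_size
        # One generator per sampler: reproducible for a given seed, without
        # repeating the same selection for every event.
        self._rng = np.random.default_rng(seed)

    def _create_subsample_indices(
        self, input_features: np.ndarray, input_feature_names: List[str]
    ) -> np.ndarray:
        n_rows = input_features.shape[0]
        if n_rows > self._max_size:
            # Draw distinct row indices (np.random.choice has no `seed`
            # argument and cannot sample the rows of a 2D array directly).
            indices = self._rng.choice(
                n_rows, size=self._max_size, replace=False
            )
            # Keep the selected pulses in their original (time) order.
            return np.sort(indices)
        return np.arange(n_rows)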
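The random cropping step can be checked in isolation. The sketch below uses a hypothetical helper, `random_max_indices`, that draws distinct row indices from a `numpy.random.Generator`. Two samplers built from the same seed produce the same crops, and an index is never repeated within one event.

```python
import numpy as np


def random_max_indices(n_rows: int, max_size: int, rng: np.random.Generator) -> np.ndarray:
    """Row indices to keep: all rows if n_rows <= max_size, else a random subset."""
    if n_rows <= max_size:
        return np.arange(n_rows)
    # replace=False avoids picking the same pulse twice; sorting keeps time order.
    return np.sort(rng.choice(n_rows, size=max_size, replace=False))


# Two generators built from the same seed yield the same sequence of crops.
a = np.random.default_rng(42)
b = np.random.default_rng(42)
first_a = random_max_indices(1000, 256, a)
first_b = random_max_indices(1000, 256, b)
print(np.array_equal(first_a, first_b))  # True
```

Seeding one generator per sampler makes a training run reproducible. However, the crop applied to a given event then depends on the order in which events are processed. If order-independent crops matter, seeding per event (for example from an event identifier) would be an alternative.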
