
PyTorch utilities sampler #424

Open
glemaitre opened this issue May 13, 2018 · 5 comments
Labels
Type: Enhancement

Comments

@glemaitre
Member

We could add utilities for PyTorch.
Basically it should be inheriting from torch.utils.data.Sampler.

The implementation could look something like this:

# uses the imblearn API of that era (return_indices / fit_sample)
from imblearn.under_sampling import RandomUnderSampler
from sklearn.base import clone
from sklearn.utils import check_random_state
# set_random_state lives in sklearn's testing utilities
# (sklearn.utils._testing in recent releases, sklearn.utils.testing before)
from sklearn.utils._testing import set_random_state
from torch.utils.data import Sampler


class BalancedSampler(Sampler):

    def __init__(self, X, y, sampler=None, random_state=None):
        self.X = X
        self.y = y
        self.sampler = sampler
        self.random_state = random_state
        self._sample()

    def _sample(self):
        random_state = check_random_state(self.random_state)
        if self.sampler is None:
            self.sampler_ = RandomUnderSampler(return_indices=True,
                                               random_state=random_state)
        else:
            if not hasattr(self.sampler, 'return_indices'):
                raise ValueError("'sampler' needs to return the indices of "
                                 "the samples selected. Provide a sampler "
                                 "which has an attribute 'return_indices'.")
            self.sampler_ = clone(self.sampler)
            self.sampler_.set_params(return_indices=True)
            set_random_state(self.sampler_, random_state)

        # keep only the selected indices; the resampled arrays are not needed
        _, _, self.indices_ = self.sampler_.fit_sample(self.X, self.y)
        # shuffle the indices since the sampler packs them by class
        random_state.shuffle(self.indices_)

    def __iter__(self):
        return iter(self.indices_.tolist())

    def __len__(self):
        # a Sampler's length must match the number of indices it yields
        return len(self.indices_)
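
For context, such a sampler would be passed to a torch.utils.data.DataLoader so that each epoch iterates only over the balanced subset of indices. A minimal usage sketch (the toy X, y and parameter values are assumptions for illustration, not part of the proposal):

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# toy imbalanced data: 900 majority samples, 100 minority samples
X = np.random.randn(1000, 20).astype(np.float32)
y = np.array([0] * 900 + [1] * 100)

dataset = TensorDataset(torch.from_numpy(X), torch.from_numpy(y))
sampler = BalancedSampler(X, y, random_state=0)

# the DataLoader draws indices from the sampler instead of iterating
# sequentially, so every batch comes from the class-balanced subset
loader = DataLoader(dataset, batch_size=32, sampler=sampler)
for batch_X, batch_y in loader:
    pass  # feed the balanced batches to the network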
@glemaitre added the Type: Enhancement and new features labels on May 13, 2018
@chkoar
Member

chkoar commented May 24, 2018

I can't help with this. I have never had the chance to play with PyTorch.

@kaihhe

kaihhe commented Aug 16, 2019

Is there any difference between resampling the data with the samplers before feeding it into the neural network and using a generator/sampler during training?
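
For reference, the two workflows the question contrasts could be sketched as follows (illustrative toy data, reusing the BalancedSampler from the first comment; a sketch only, not the maintainers' recommendation):

import numpy as np
import torch
from imblearn.under_sampling import RandomUnderSampler
from torch.utils.data import DataLoader, TensorDataset

# toy imbalanced data (illustrative assumption)
X = np.random.randn(1000, 20).astype(np.float32)
y = np.array([0] * 900 + [1] * 100)

# (a) resample once up front, then train on the static, reduced arrays
X_res, y_res = RandomUnderSampler(random_state=0).fit_resample(X, y)
loader_resampled = DataLoader(
    TensorDataset(torch.from_numpy(X_res), torch.from_numpy(y_res)),
    batch_size=32, shuffle=True)

# (b) keep the full arrays and let the sampler pick a balanced subset
# of indices when the DataLoader iterates
loader_sampler = DataLoader(
    TensorDataset(torch.from_numpy(X), torch.from_numpy(y)),
    batch_size=32, sampler=BalancedSampler(X, y, random_state=0))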

@glemaitre
Member Author

glemaitre commented Aug 19, 2019 via email

@mattbev

mattbev commented Jun 16, 2023

@glemaitre has any progress been made on this?

@tuhinsharma121

@jnothman @glemaitre Can I take this up if nobody is working on it?
