
GMM with Mini-Batches #51

Open
justuswill opened this issue Mar 9, 2023 · 1 comment

@justuswill

Hi,

Like #7 and #19, I am trying to fit a GMM to a large dataset of shape [10^10, 50] and want to (need to) use mini-batching.

However, in contrast to the previous answers, gmm.fit only accepts a TensorLike and won't work with my data, which is a torch.utils.data.DataLoader. Even if I pass a torch.utils.data.Dataset, it only computes a GMM on the first batch.

What is the preferred way to do what I want to do?

Ideally, I would want my code to work like this:

from pycave.bayes import GaussianMixture as GMM
from torch.utils.data import Dataset, DataLoader

data = Data(DATA_PATH).dataloader(batch_size=256)
assert isinstance(data, DataLoader)

gmm = GMM(num_components=3, batch_size=256, trainer_params=dict(accelerator='gpu', devices=1))
class_labels = gmm.fit_predict(data)
means, stds = gmm.model_.means, gmm.model_.covariances
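
(As a stopgap that doesn't touch the library, I can also point fit_predict at a memory-mapped numpy array so the full dataset never sits in RAM. This is only a sketch: the path and array layout below are placeholders, and I haven't checked whether pycave keeps the memmap lazy once it wraps it in its own DataLoader.)

import numpy as np
from pycave.bayes import GaussianMixture as GMM

# Placeholder path; assumes the data was saved once as a single .npy of shape [N, 50].
features = np.load("data.npy", mmap_mode="r")  # stays on disk, indexed lazily

gmm = GMM(num_components=3, batch_size=256, trainer_params=dict(accelerator='gpu', devices=1))
class_labels = gmm.fit_predict(features)       # unverified: pycave may still materialize everything
means, stds = gmm.model_.means, gmm.model_.covariances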

Manually changing the code in gmm/estimator.py (among other changes) from

num_features = len(data[0])
...
loader = DataLoader(
    dataset_from_tensors(data),
    batch_size=self.batch_size or len(data),
    collate_fn=collate_tensor,
)
is_batch_training = self._num_batches_per_epoch(loader) == 1          # Also, shouldn't this be > anyway?

to

num_features = data.dataset[0].shape[1]
...
loader = data
is_batch_training = True

allows for error-free fitting and prediction, but I am not sure if the output is trustworthy.
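
For what it's worth, to get a feel for whether the patched output is sane, I have been comparing it against sklearn's GaussianMixture fitted on a random subsample that fits in memory. This is only a rough check (the subsample size and path below are arbitrary placeholders), not a proof of correctness:

import numpy as np
from sklearn.mixture import GaussianMixture as SkGMM

# Placeholder path; draw a subsample that comfortably fits in memory.
features = np.load("data.npy", mmap_mode="r")
rng = np.random.default_rng()
idx = np.sort(rng.integers(len(features), size=100_000))  # duplicates are fine for a rough check
subsample = np.asarray(features[idx])                     # materialize only the subsample

sk = SkGMM(n_components=3).fit(subsample)
print(sk.means_)         # compare against gmm.model_.means
print(sk.covariances_)   # and gmm.model_.covariances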

@hashim19

hashim19 commented Jan 2, 2024

Hi, were you able to solve this issue? I am also trying to do GMM training with mini-batches. My dataset is huge and I cannot load all the data into memory.
