Scalable Logistic Regression with PyTorch Lightning

This logistic regression model lets you scale to much bigger datasets by training on multiple GPUs and TPUs. I implemented it in the PyTorch Lightning Bolts library, where it has been rigorously tested and documented.

I've also implemented the SklearnDataModule, a class that conveniently wraps any NumPy array dataset in PyTorch DataLoaders.
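For illustration, here's a minimal sketch of the idea behind it (not the actual Bolts implementation, and numpy_to_loader is a hypothetical helper name): wrap the arrays in a TensorDataset and return a DataLoader.

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

# simplified illustration of the concept; the real SklearnDataModule
# also handles train/val/test splitting, shuffling and workers for you
def numpy_to_loader(X: np.ndarray, y: np.ndarray, batch_size: int = 16) -> DataLoader:
    dataset = TensorDataset(torch.from_numpy(X).float(), torch.from_numpy(y).long())
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)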

Check out the Bolts documentation if you have any questions about how to use this model.

I've also written a blog post explaining the relationship between logistic regression and neural networks, and how this lets us use frameworks such as PyTorch to scale our training. Read it here.
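The core observation behind the post: multinomial logistic regression is just a single linear layer whose outputs are passed through a softmax, so it can be trained with exactly the same machinery as any neural network. A minimal PyTorch sketch of that equivalence (shapes chosen to match the Iris example below):

import torch
from torch import nn
import torch.nn.functional as F

# logistic regression == one linear layer producing raw logits;
# cross_entropy applies log-softmax internally
linear = nn.Linear(in_features=4, out_features=3)
optimizer = torch.optim.SGD(linear.parameters(), lr=0.1)

x = torch.randn(32, 4)          # a batch of 32 samples with 4 features
y = torch.randint(0, 3, (32,))  # integer class labels in {0, 1, 2}

loss = F.cross_entropy(linear(x), y)
loss.backward()
optimizer.step()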

An example

Train this model on any NumPy dataset as follows (here I'm demonstrating with the scikit-learn Iris dataset):

from pl_bolts.models.regression import LogisticRegression
from pl_bolts.datamodules import SklearnDataModule
import pytorch_lightning as pl

from sklearn.datasets import load_iris

# use any NumPy or sklearn dataset
X, y = load_iris(return_X_y=True)
dm = SklearnDataModule(X, y)

# build the model: Iris has 4 features and 3 classes
model = LogisticRegression(input_dim=4, num_classes=3)

# fit on 2 GPUs with 16-bit precision
trainer = pl.Trainer(gpus=2, precision=16)
trainer.fit(model, dm.train_dataloader(), dm.val_dataloader())

# evaluate on the held-out test split
trainer.test(test_dataloaders=dm.test_dataloader(batch_size=12))
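Once trained, you can use the model for predictions directly. A minimal sketch, assuming the model's forward pass returns per-class scores for a batch of features (logits or probabilities; argmax of either gives the same predicted class):

import torch

model.eval()
x_new = torch.from_numpy(X[:5]).float()
with torch.no_grad():
    preds = model(x_new).argmax(dim=1)
print(preds)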

To choose the number of GPUs or TPUs, just set the corresponding flag on the Trainer. You can also enable 16-bit precision there.

# 1 GPU
trainer = pl.Trainer(gpus=1)

# 8 TPUs
trainer = pl.Trainer(tpu_cores=8)

# 16 GPUs and 16-bit precision
trainer = pl.Trainer(gpus=16, precision=16)
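Going beyond a single machine follows the same pattern; the Trainer of this Lightning generation also takes a num_nodes flag (a sketch, assuming your cluster environment is already configured):

# 4 nodes x 8 GPUs = 32 GPUs total
trainer = pl.Trainer(gpus=8, num_nodes=4, precision=16)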
