
Script hangs on .gather_for_metrics() #2785

Open · 2 of 4 tasks

priyammaz opened this issue May 15, 2024 · 4 comments

@priyammaz

System Info

- `Accelerate` version: 0.28.0
- Platform: Linux-6.5.0-28-generic-x86_64-with-glibc2.35
- Python version: 3.12.2
- Numpy version: 1.26.4
- PyTorch version (GPU?): 2.2.1 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- System RAM: 125.27 GB
- GPU type: NVIDIA RTX A6000
- `Accelerate` default config:
	- compute_environment: LOCAL_MACHINE
	- distributed_type: MULTI_GPU
	- mixed_precision: bf16
	- use_cpu: False
	- debug: False
	- num_processes: 2
	- machine_rank: 0
	- num_machines: 1
	- gpu_ids: 0,1
	- rdzv_backend: static
	- same_network: True
	- main_training_function: main
	- downcast_bf16: no
	- tpu_use_cluster: False
	- tpu_use_sudo: False
	- tpu_env: []

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from torchvision.models import resnet18
from accelerate import Accelerator
from transformers import get_cosine_schedule_with_warmup

EPOCHS = 50
GRADIENT_ACCUM_STEPS = 2
BATCH_SIZE = 128
LEARNING_RATE = 0.001

### Init Accelerator ###
accelerator = Accelerator(gradient_accumulation_steps=GRADIENT_ACCUM_STEPS)

### Load Model ###
model = resnet18()
model = model.to(accelerator.device)

### Load Dataset ###
transform = transforms.Compose(
    [transforms.Resize((64,64)),
    transforms.ToTensor()]
)

mini_batchsize = BATCH_SIZE // GRADIENT_ACCUM_STEPS 
trainset = datasets.CIFAR10(root=".", train=True, transform=transform, download=True)
testset = datasets.CIFAR10(root=".", train=False, transform=transform, download=True)
trainloader = DataLoader(trainset, batch_size=mini_batchsize, shuffle=True, num_workers=8, pin_memory=True)
testloader = DataLoader(testset, batch_size=mini_batchsize, shuffle=True, num_workers=8, pin_memory=True)

### Define Loss Function ###
loss_fn = nn.CrossEntropyLoss()

### Define Optimizer ###
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

### Define Scheduler ###
total_training_steps = len(trainloader) * EPOCHS
scheduler = get_cosine_schedule_with_warmup(optimizer, 
                                            num_warmup_steps=500, 
                                            num_training_steps=total_training_steps)

### Prepare Everything ###
model, optimizer, trainloader, testloader, scheduler = accelerator.prepare(
    model, optimizer, trainloader, testloader, scheduler
)

losses = []
accs = []

for epoch in range(EPOCHS):

        model.train()

        accumulated_loss = 0 
        accumulated_accuracy = 0
        for images, targets in trainloader:

            ### Move Data to Correct GPU ###
            images, targets = images.to(accelerator.device), targets.to(accelerator.device)
            
            with accelerator.accumulate(model):
                
                ### Pass Through Model ###
                pred = model(images)

                ### Compute and Store Loss ###
                loss = loss_fn(pred, targets)
                accumulated_loss += loss / GRADIENT_ACCUM_STEPS

                ### Compute and Store Accuracy ###
                predicted = pred.argmax(axis=1)
                accuracy = (predicted == targets).sum() / len(predicted)

                ### Update Model ###
                accelerator.backward(loss)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            ### Gather Training Loss Metrics ###
            if accelerator.sync_gradients and accelerator.is_main_process:
                print("gathering")
                gathered_loss = accelerator.gather_for_metrics(accumulated_loss)
                print(gathered_loss)
                gathered_acc = accelerator.gather_for_metrics(accuracy)
                print(gathered_acc)
                accumulated_loss = 0

Expected behavior

This is a quick script that replicates a problem I have been having with gather() or gather_for_metrics(). When I run it on my machine, the forward pass through the model works just fine, but as soon as gradients are about to sync and we go to gather our metrics, everything stalls. I print "gathering" right before the call, and the entire script hangs at "gathering" with nothing else happening. I could totally be doing this wrong, I'm just not sure what it could be.

What I want to be able to do is use Accelerate with its gradient accumulation on my own custom models, and grab the training loss from across the GPUs, averaging it so I can plot it.
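
Concretely, the bookkeeping I am after looks something like this rough sketch (reusing the variables from the script above; my understanding is that gather_for_metrics on a scalar returns one value per process):

### Hypothetical goal: pool the per-GPU running loss and keep one number to plot ###
gathered_loss = accelerator.gather_for_metrics(accumulated_loss)  # tensor with one entry per GPU
losses.append(gathered_loss.mean().item())                        # averaged across GPUs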

@priyammaz
Author

As a baseline, I also ran the official example script for Gradient Accumulation, and its gather_for_metrics() runs just fine; but there the gather is done inside the evaluation loop, not the training one.

@priyammaz
Author

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from torchvision.models import resnet18
from accelerate import Accelerator
from transformers import get_cosine_schedule_with_warmup

EPOCHS = 50
GRADIENT_ACCUM_STEPS = 2
BATCH_SIZE = 128
LEARNING_RATE = 0.001

### Init Accelerator ###
accelerator = Accelerator(gradient_accumulation_steps=GRADIENT_ACCUM_STEPS)

### Load Model ###
model = resnet18()
model = model.to(accelerator.device)

### Load Dataset ###
transform = transforms.Compose(
    [transforms.Resize((64,64)),
    transforms.ToTensor()]
)

mini_batchsize = BATCH_SIZE // GRADIENT_ACCUM_STEPS 
trainset = datasets.CIFAR10(root=".", train=True, transform=transform, download=True)
testset = datasets.CIFAR10(root=".", train=False, transform=transform, download=True)
trainloader = DataLoader(trainset, batch_size=mini_batchsize, shuffle=True, num_workers=8, pin_memory=True)
testloader = DataLoader(testset, batch_size=mini_batchsize, shuffle=True, num_workers=8, pin_memory=True)

### Define Loss Function ###
loss_fn = nn.CrossEntropyLoss()

### Define Optimizer ###
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

### Define Scheduler ###
total_training_steps = len(trainloader) * EPOCHS
scheduler = get_cosine_schedule_with_warmup(optimizer, 
                                            num_warmup_steps=500, 
                                            num_training_steps=total_training_steps)

### Prepare Everything ###
model, optimizer, trainloader, testloader, scheduler = accelerator.prepare(
    model, optimizer, trainloader, testloader, scheduler
)

losses = []
accs = []

for epoch in range(EPOCHS):

        model.train()

        accumulated_loss = 0 
        accumulated_accuracy = 0
        for images, targets in trainloader:

            ### Move Data to Correct GPU ###
            images, targets = images.to(accelerator.device), targets.to(accelerator.device)
            
            with accelerator.accumulate(model):
                
                ### Pass Through Model ###
                pred = model(images)

                ### Compute and Store Loss ###
                loss = loss_fn(pred, targets)
                accumulated_loss += loss / GRADIENT_ACCUM_STEPS

                ### Compute and Store Accuracy ###
                predicted = pred.argmax(axis=1)
                accuracy = (predicted == targets).sum() / len(predicted)

                ### Update Model ###
                accelerator.backward(loss)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
           
        model.eval()

        for images, targets in testloader:
            images, targets = images.to(accelerator.device), targets.to(accelerator.device)
            with torch.no_grad():
                pred = model(images)

                ### Compute Loss ###
                loss = loss_fn(pred, targets)

                ### Compute Accuracy ###
                predicted = pred.argmax(axis=1)
                accuracy = (predicted == targets).sum() / len(predicted)

                ### Gather across GPUs ###
                loss_gathered = accelerator.gather_for_metrics(loss)
                accuracy_gathered = accelerator.gather_for_metrics(accuracy)

As a test, I just threw in a validation loop that also gathers the loss and accuracy on the eval data, and it works like a charm. I am probably just doing the gather in the training loop wrong, perhaps because of the accumulate context manager. Any advice would be really helpful!

@priyammaz
Author

I think I answered my own question... It looks like the problem was that gather_for_metrics() isn't happy being called only under an accelerator.is_main_process guard; every process needs to make the call. So this script then works as far as I can tell:

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from torchvision.models import resnet18
from accelerate import Accelerator
from transformers import get_cosine_schedule_with_warmup

EPOCHS = 50
GRADIENT_ACCUM_STEPS = 2
BATCH_SIZE = 128
LEARNING_RATE = 0.001

### Init Accelerator ###
accelerator = Accelerator(gradient_accumulation_steps=GRADIENT_ACCUM_STEPS)

### Load Model ###
model = resnet18()
model = model.to(accelerator.device)

### Load Dataset ###
transform = transforms.Compose(
    [transforms.Resize((64,64)),
    transforms.ToTensor()]
)

mini_batchsize = BATCH_SIZE // GRADIENT_ACCUM_STEPS 
trainset = datasets.CIFAR10(root=".", train=True, transform=transform, download=True)
testset = datasets.CIFAR10(root=".", train=False, transform=transform, download=True)
trainloader = DataLoader(trainset, batch_size=mini_batchsize, shuffle=True, num_workers=8, pin_memory=True)
testloader = DataLoader(testset, batch_size=mini_batchsize, shuffle=True, num_workers=8, pin_memory=True)

### Define Loss Function ###
loss_fn = nn.CrossEntropyLoss()

### Define Optimizer ###
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

### Define Scheduler ###
total_training_steps = len(trainloader) * EPOCHS
scheduler = get_cosine_schedule_with_warmup(optimizer, 
                                            num_warmup_steps=500, 
                                            num_training_steps=total_training_steps)

### Prepare Everything ###
model, optimizer, trainloader, testloader, scheduler = accelerator.prepare(
    model, optimizer, trainloader, testloader, scheduler
)

losses = []
accs = []

for epoch in range(EPOCHS):

        model.train()

        accumulated_loss = 0 
        accumulated_accuracy = 0
        for images, targets in trainloader:

            ### Move Data to Correct GPU ###
            images, targets = images.to(accelerator.device), targets.to(accelerator.device)
            
            with accelerator.accumulate(model):
                
                ### Pass Through Model ###
                pred = model(images)

                ### Compute and Store Loss ###
                loss = loss_fn(pred, targets)
                accumulated_loss += loss / GRADIENT_ACCUM_STEPS

                ### Compute and Store Accuracy ###
                predicted = pred.argmax(axis=1)
                accuracy = (predicted == targets).sum() / len(predicted)
                accumulated_accuracy += accuracy / GRADIENT_ACCUM_STEPS

                ### Update Model ###
                accelerator.backward(loss)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
        
            # if accelerator.sync_gradients and accelerator.is_main_process:
            if accelerator.sync_gradients:
                loss_gathered = accelerator.gather_for_metrics(accumulated_loss)
                accuracy_gathered = accelerator.gather_for_metrics(accumulated_accuracy)

                accumulated_loss, accumulated_accuracy = 0, 0

                
        model.eval()

        for images, targets in testloader:
            images, targets = images.to(accelerator.device), targets.to(accelerator.device)
            with torch.no_grad():
                pred = model(images)

                ### Compute Loss ###
                loss = loss_fn(pred, targets)

                ### Compute Accuracy ###
                predicted = pred.argmax(axis=1)
                accuracy = (predicted == targets).sum() / len(predicted)

                ### Gather across GPUs ###
                loss_gathered = accelerator.gather_for_metrics(loss)
                accuracy_gathered = accelerator.gather_for_metrics(accuracy)

If someone can verify that this is the correct way to do this, that would be really helpful!
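
The rule of thumb I took away, as a minimal sketch (my own distillation, so please correct me if I have it wrong): the gather runs on every process, and only the logging gets the rank guard.

if accelerator.sync_gradients:
    ### Collective call: every process must reach this line, or the others hang ###
    loss_gathered = accelerator.gather_for_metrics(accumulated_loss)
    accuracy_gathered = accelerator.gather_for_metrics(accumulated_accuracy)
    accumulated_loss, accumulated_accuracy = 0, 0

    ### The rank check wraps only the logging, never the gather ###
    if accelerator.is_main_process:
        losses.append(loss_gathered.mean().item())
        accs.append(accuracy_gathered.mean().item())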

@muellerzr
Collaborator

Yes indeed, it calls gather (hence the name), so it was waiting to hear from the other processes :)
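
The same thing happens with any collective op if only one rank calls it. As an illustrative sketch in raw torch.distributed (not Accelerate's actual code, just the underlying idea):

import torch
import torch.distributed as dist

def bad_gather(rank, world_size, tensor):
    # all_gather is a collective: every rank must call it.
    # Guarding it with `if rank == 0` deadlocks, since rank 0 waits
    # for the other ranks, which never enter the call.
    if rank == 0:
        out = [torch.zeros_like(tensor) for _ in range(world_size)]
        dist.all_gather(out, tensor)
        return out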
