-
Notifications
You must be signed in to change notification settings - Fork 31
/
chief.py
21 lines (18 loc) · 758 Bytes
/
chief.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
import torch.optim as optim
import torch.multiprocessing as mp
from torch.autograd import Variable
import time
def chief(rank, params, traffic_light, counter, shared_model, shared_grad_buffers, optimizer):
    """Central parameter-update loop for a chief/worker training setup.

    Polls a shared counter once per second; when more than
    ``params.update_treshold`` workers have reported, it copies the
    gradients held in ``shared_grad_buffers`` into ``shared_model``'s
    parameters, takes one optimizer step, resets the shared counter and
    gradient buffers, and flips ``traffic_light`` so workers start the
    next round. Runs forever; intended to be launched as its own process.

    Args:
        rank: process rank. Unused in this body; presumably kept to match
            the ``mp.Process`` target signature used for workers.
        params: config object; only ``params.update_treshold`` is read
            here (note the "treshold" spelling is part of the attribute
            name on the caller's params object).
        traffic_light: shared signal object; ``switch()`` releases the
            waiting workers — assumes workers block on it (TODO confirm
            against worker code).
        counter: shared counter with ``get()`` and ``reset()``; workers
            presumably increment it after their last loss computation.
        shared_model: model whose parameters receive the shared gradients
            and are updated in place by ``optimizer.step()``.
        shared_grad_buffers: object with a ``grads`` dict keyed by
            ``'<parameter name>_grad'`` and a ``reset()`` method.
        optimizer: optimizer bound to ``shared_model``'s parameters.
    """
    while True:
        time.sleep(1)  # coarse 1 s polling; workers will wait after last loss computation
        if counter.get() > params.update_treshold:
            #print(shared_grad_buffers.grads['mu.weight_grad'])
            for n,p in shared_model.named_parameters():
                # Inject the shared gradient for parameter `n` directly into its
                # grad slot so optimizer.step() consumes it. NOTE(review):
                # Variable(...) and the private `_grad` attribute are legacy
                # (pre-0.4) PyTorch idioms; modern code would assign `p.grad`.
                p._grad = Variable(shared_grad_buffers.grads[n+'_grad'])
            optimizer.step()
            # Order matters: reset shared state BEFORE releasing workers,
            # so the next round starts from a clean counter and zeroed buffers.
            counter.reset()
            shared_grad_buffers.reset()
            traffic_light.switch() # workers start new loss computation
            #print('update')