Skip to content

Commit

Permalink
Merge branch 'chaichontat-master' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed May 1, 2022
2 parents c0bc337 + c3b3838 commit cc58f21
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
3 changes: 3 additions & 0 deletions cellpose/core.py
Expand Up @@ -295,6 +295,7 @@ def network(self, x, return_conv=False):
self.net = mkldnn_utils.to_mkldnn(self.net)
with torch.no_grad():
y, style = self.net(X)
del X
y = self._from_device(y)
style = self._from_device(style)
if return_conv:
Expand Down Expand Up @@ -724,6 +725,7 @@ def _train_step(self, x, lbl):
#else:
self.net.train()
y = self.net(X)[0]
del X
loss = self.loss_fn(lbl,y)
loss.backward()
train_loss = loss.item()
Expand All @@ -736,6 +738,7 @@ def _test_eval(self, x, lbl):
self.net.eval()
with torch.no_grad():
y, style = self.net(X)
del X
loss = self.loss_fn(lbl,y)
test_loss = loss.item()
test_loss *= len(x)
Expand Down
5 changes: 3 additions & 2 deletions cellpose/dynamics.py
Expand Up @@ -77,13 +77,14 @@ def _extend_centers_gpu(neighbors, centers, isneighbor, Ly, Lx, n_iter=200, devi
Tneigh = T[:, pt[:,:,0], pt[:,:,1]]
Tneigh *= isneigh
T[:, pt[0,:,0], pt[0,:,1]] = Tneigh.mean(axis=1)

del meds, isneigh, Tneigh
T = torch.log(1.+ T)
# gradient positions
grads = T[:, pt[[2,1,4,3],:,0], pt[[2,1,4,3],:,1]]
del pt
dy = grads[:,0] - grads[:,1]
dx = grads[:,2] - grads[:,3]

del grads
mu_torch = np.stack((dy.cpu().squeeze(), dx.cpu().squeeze()), axis=-2)
return mu_torch

Expand Down

0 comments on commit cc58f21

Please sign in to comment.