Skip to content

Commit

Permalink
add BCHW support #7
Browse files Browse the repository at this point in the history
to use, set `layout` argument to be 'BCHW' when initializing STN
  • Loading branch information
fxia22 committed Jun 12, 2017
1 parent 4543489 commit b6e3958
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
14 changes: 7 additions & 7 deletions script/functions/stn.py
Expand Up @@ -48,10 +48,10 @@ def forward(self, input1, input2):
my_lib.BilinearSamplerBCHW_updateOutput(input1, input2, output)
else:

output = output.transpose(1,2).transpose(2,3)
input1 = input1.transpose(1,2).transpose(2,3)
input2 = input2.transpose(1,2).transpose(2,3)

output = output.transpose(1,2).transpose(2,3).contiguous()
input1 = input1.transpose(1,2).transpose(2,3).contiguous()
input2 = input2.transpose(1,2).transpose(2,3).contiguous()
#print(output.size(), input1.size(), input2.size())
output = output.cuda(self.device)
my_lib.BilinearSamplerBHWD_updateOutput_cuda(input1, input2, output, self.device_c)
output = output.transpose(2,3).transpose(1,2)
Expand All @@ -65,9 +65,9 @@ def backward(self, grad_output):
if not grad_output.is_cuda:
my_lib.BilinearSamplerBCHW_updateGradInput(self.input1, self.input2, grad_input1, grad_input2, grad_output)
else:
grad_input1 = grad_input1.transpose(1,2).transpose(2,3)
grad_input2 = grad_input2.transpose(1,2).transpose(2,3)
grad_output = grad_output.transpose(1,2).transpose(2,3)
grad_input1 = grad_input1.transpose(1,2).transpose(2,3).contiguous()
grad_input2 = grad_input2.transpose(1,2).transpose(2,3).contiguous()
grad_output = grad_output.transpose(1,2).transpose(2,3).contiguous()

grad_input1 = grad_input1.cuda(self.device)
grad_input2 = grad_input2.cuda(self.device)
Expand Down
2 changes: 1 addition & 1 deletion script/test.py
Expand Up @@ -62,7 +62,7 @@
out.backward(input1.data)
print(input1.grad.size(), 'time:', time.time() - start)

with torch.cuda.device(3):
with torch.cuda.device(1):
input1 = input1.cuda()
input2 = input2.cuda()
start = time.time()
Expand Down

0 comments on commit b6e3958

Please sign in to comment.