Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementations of the custom layers in C++ and Native Pytorch for CPU support. #212

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ _ext
*.o
work
work/*
_ext/
_ext/
model_weights
4 changes: 4 additions & 0 deletions install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,8 @@ cd ../channelnorm_package
rm -rf *_cuda.egg-info build dist __pycache__
python3 setup.py install --user

cd ../correlation_cpp_package
rm -rf *_cuda.egg-info build dist __pycache__
python3 setup.py install --user

cd ..
12 changes: 11 additions & 1 deletion models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
try:
from networks.resample2d_package.resample2d import Resample2d
from networks.channelnorm_package.channelnorm import ChannelNorm
# PyTorch versions
# To use the CPU implementation of Resample2D and Channelnorm uncomment the
# two lines below and comment the two lines above.
# from networks.channelnorm import ChannelNorm
# from networks.resample2d import Resample2d

from networks import FlowNetC
from networks import FlowNetS
Expand All @@ -18,7 +23,12 @@
except:
from .networks.resample2d_package.resample2d import Resample2d
from .networks.channelnorm_package.channelnorm import ChannelNorm

# PyTorch versions
# To use the CPU implementation of Resample2D and Channelnorm uncomment the
# two lines below and comment the two lines above.
# from .networks.channelnorm import ChannelNorm
# from .networks.resample2d import Resample2d

from .networks import FlowNetC
from .networks import FlowNetS
from .networks import FlowNetSD
Expand Down
5 changes: 5 additions & 0 deletions networks/FlowNetC.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
import numpy as np

from .correlation_package.correlation import Correlation
# To use CPU implementation of correlation in C++ comment line above and uncomment
# the line below.
# from .correlation_cpp_package.correlation import Correlation
# PyTorch Version
# from .correlation import Correlation

from .submodules import *
'Parameter count , 39,175,298 '
Expand Down
39 changes: 39 additions & 0 deletions networks/channelnorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import torch
from torch.autograd import Function, Variable
from torch.nn.modules.module import Module
# import channelnorm_cpp

class ChannelNormFunction(Function):

@staticmethod
def forward(ctx, input1, norm_deg=2):
assert input1.is_contiguous()
b, c, h, w = input1.size()
output = input1.new(b, 1, h, w).zero_()
output = torch.pow(input1, norm_deg)
output = torch.sqrt( torch.sum(output, dim=1))
ctx.save_for_backward(input1, output)
ctx.norm_deg = norm_deg
return output.unsqueeze(0)

@staticmethod
def backward(ctx, grad_output):
"""
Not Implemented!
"""
input1, output = ctx.saved_tensors

grad_input1 = Variable(input1.new(input1.size()).zero_())

return grad_input1, None


class ChannelNorm(Module):

def __init__(self, norm_deg=2):
super(ChannelNorm, self).__init__()
self.norm_deg = norm_deg

def forward(self, input1):
return ChannelNormFunction.apply(input1, self.norm_deg)

121 changes: 121 additions & 0 deletions networks/correlation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""
Copyright 2020 Samim Taray

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0
"""

import torch
from torch.nn.modules.module import Module
from torch.autograd import Function
from torch.nn import ZeroPad2d
# import correlation_cuda
import code

class CorrelationFunction(Function):

def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1):
super(CorrelationFunction, self).__init__()
self.pad_size = pad_size
self.kernel_size = kernel_size
self.max_displacement = max_displacement
self.stride1 = stride1
self.stride2 = stride2
self.corr_multiply = corr_multiply

def extractwindow(self, f2pad, i, j):
hindex = torch.tensor( range(i,i+(2*self.max_displacement)+1, self.stride2) )
windex = torch.tensor( range(j,j+(2*self.max_displacement)+1, self.stride2) )
# Advanced indexing logic. Ref: https://github.com/pytorch/pytorch/issues/1080
# the way advance indexing works:
# ---> f2pad[:, :, hindex] chose value at f2pad at hindex location, then
# ---> appending [:, :, :, windex] to it only choses values at windex.
# ---> Thus it choses value at the alternative location of f2pad
# win = f2pad[:,:, i:i+(2*self.max_displacement)+1, j:j+(2*self.max_displacement)+1]

win = f2pad[:, :, hindex][:, :, :, windex]
return win

def forward(self, f1, f2):
self.save_for_backward(f1, f2)
f1b = f1.shape[0] #batch
f1c = f1.shape[1] #channel
f1h = f1.shape[2] #height
f1w = f1.shape[3] #width

f2b = f2.shape[0] #batch
f2c = f2.shape[1] #channel
f2h = f2.shape[2] #height
f2w = f2.shape[3] #width

# generate padded f2
padder = ZeroPad2d(self.pad_size)
f2pad = padder(f2)

# Define output shape and initialize it
outc = (2*(self.max_displacement/self.stride2)+1) * (2*(self.max_displacement/self.stride2)+1)
outc = int(outc) # number of output channel
outb = f1b # size of output batch
outh = f1h # size of output height
outw = f1w # size of output width
output = torch.ones((outb, outc, outh, outw))
# this gives device type
output = output.to(f1.device)

for i in range(f1h):
for j in range(f1w):
# Extract window W around i,j from f2pad of size (1X256X21X21)
win = self.extractwindow(f2pad, i, j)
# Extract kernel: size [1, 256, 1, 1]
k = f1[:, :, i, j].unsqueeze(2).unsqueeze(3)
# boradcasting multiplication along channel dimension
# it multiplies all the 256 element of k to win and keep the result as it is
# size of mult: 1, 256, 21, 21
mult = win * k
# Sum along channel dimension to get dot product. size 1X21X21
inner_prod = torch.sum(mult, dim = 1)

# Flatten last 2 dimensions h,w to one dimension of h*w = no of channels in output
# size 1X1X1X441
inner_prod = inner_prod.flatten(-2, -1).unsqueeze(1).unsqueeze(1)
output[:, :, i, j] = inner_prod
# return the average
return output/f1c

def backward(self, grad_output):
"""
Not Implemented!
"""
input1, input2 = self.saved_tensors
with torch.cuda.device_of(input1):
rbot1 = input1.new()
rbot2 = input2.new()

grad_input1 = input1.new()
grad_input2 = input2.new()

correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2,
self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)

return grad_input1, grad_input2


class Correlation(Module):
def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1):
super(Correlation, self).__init__()
self.pad_size = pad_size
self.kernel_size = kernel_size
self.max_displacement = max_displacement
self.stride1 = stride1
self.stride2 = stride2
self.corr_multiply = corr_multiply

def forward(self, input1, input2):

result = CorrelationFunction(self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)(input1, input2)

return result

145 changes: 145 additions & 0 deletions networks/correlation_cpp_package/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
*.pyc
.torch
_ext
*.o
work
work/*
_ext/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/