Skip to content

Commit

Permalink
[ONNX] Fix pow op export [1.5.1] (#39791)
Browse files Browse the repository at this point in the history
* [ONNX] Fix pow op export (#38065)

Summary:
Fix pow type cast for opset 9 and update opset 12
Pull Request resolved: #38065

Differential Revision: D21485353

Pulled By: malfet

fbshipit-source-id: 3993e835ffad07b2e6585eb5cf1cb7c8474de2ec

* Update ort-nighly version as suggested in #39685 (comment)

* Apply changes from #37846 to  `test_topk_smallest_unsorted`

Co-authored-by: neginraoof <neginmr@utexas.edu>
  • Loading branch information
malfet and neginraoof committed Jun 11, 2020
1 parent dfe8cdf commit 3c31d73
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .jenkins/caffe2/test.sh
Expand Up @@ -144,7 +144,7 @@ if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then
# default pip version is too old(9.0.2), unable to support tag `manylinux2010`.
# Fix the pip error: Couldn't find a version that satisfies the requirement
sudo pip install --upgrade pip
pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.1.0.dev1228
pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.3.0.dev202005123
fi
"$ROOT_DIR/scripts/onnx/test.sh"
fi
27 changes: 26 additions & 1 deletion test/onnx/test_pytorch_onnx_onnxruntime.py
Expand Up @@ -1342,6 +1342,27 @@ def forward(self, input):
model = StandardDeviation()
self.run_test(model, x)

def test_pow(self):
class PowModule(torch.nn.Module):
def forward(self, x, y):
return x.pow(y)

x = torch.randn(2, 3, 4)
y = torch.randn(2, 3, 4)
self.run_test(PowModule(), (x, y))

x = torch.randint(10, (2, 3, 4))
y = torch.randint(10, (2, 3, 4)).to(dtype=torch.int32)
self.run_test(PowModule(), (x, y))

x = torch.randint(10, (2, 3, 4))
y = torch.randint(10, (2, 3, 4))
self.run_test(PowModule(), (x, y))

x = torch.randn(2, 3, 4).to(dtype=torch.float64)
y = torch.randint(10, (2, 3, 4))
self.run_test(PowModule(), (x, y))

def test_std_along_dims(self):
class StandardDeviation(torch.nn.Module):
def forward(self, input):
Expand Down Expand Up @@ -1463,7 +1484,11 @@ def forward(self, x):
def test_topk_smallest_unsorted(self):
class MyModule(torch.nn.Module):
def forward(self, x, k):
return torch.topk(x, k, largest=False, sorted=False)
# When sorted=False, order of elements in the outout tensors
# are not expected to match between PyTorch and ORT
topk_unsorted = torch.topk(x, k, largest=False, sorted=False)
topk_sorted = torch.topk(x, k, largest=False, sorted=True)
return topk_sorted, torch.sort(topk_unsorted.values).values

x = torch.arange(1., 6., requires_grad=True)
k = torch.tensor(3)
Expand Down
5 changes: 5 additions & 0 deletions torch/onnx/symbolic_opset12.py
Expand Up @@ -69,5 +69,10 @@ def nll_loss(g, self, target, weight, reduction, ignore_index):
nllloss = g.op("Div", nllloss, denominator)
return nllloss


def nll_loss2d(g, self, target, weight, reduction, ignore_index):
return nll_loss(g, self, target, weight, reduction, ignore_index)


def pow(g, self, exponent):
return g.op("Pow", self, exponent)
11 changes: 10 additions & 1 deletion torch/onnx/symbolic_opset9.py
Expand Up @@ -1230,7 +1230,16 @@ def log1p(g, self):


def pow(g, self, exponent):
return g.op("Pow", self, exponent)
f_dtype = self_dtype = self.type().scalarType()
if not sym_help._is_fp(self):
f_dtype = 'Float'
self = g.op("Cast", self, to_i=sym_help.cast_pytorch_to_onnx[f_dtype])
if not sym_help._is_fp(exponent):
exponent = g.op("Cast", exponent, to_i=sym_help.cast_pytorch_to_onnx[f_dtype])
pow = g.op("Pow", self, exponent)
if self_dtype and self_dtype != f_dtype:
pow = g.op("Cast", pow, to_i=sym_help.cast_pytorch_to_onnx[self_dtype])
return pow


def clamp(g, self, min, max):
Expand Down

0 comments on commit 3c31d73

Please sign in to comment.