Skip to content

Commit

Permalink
Add missing Type Hints to torch. issue pytorch#7318
Browse files Browse the repository at this point in the history
  • Loading branch information
kimdwkimdw committed Jun 25, 2018
1 parent e31ab99 commit 6cebcdb
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 179 deletions.
2 changes: 1 addition & 1 deletion tools/autograd/gen_python_functions.py
Expand Up @@ -186,7 +186,7 @@ def should_bind(declaration):

py_torch_functions = group_declarations_by_name(declarations, should_bind)

env = create_python_bindings(py_torch_functions, has_self=False)
env = create_python_bindings(py_torch_functions, has_self=False, is_module=True)
write(out, 'python_torch_functions.cpp', PY_TORCH_FUNCTIONS_CPP, env)
write(out, 'python_torch_functions_dispatch.h', PY_TORCH_DISPATCH_H, env)

Expand Down
28 changes: 17 additions & 11 deletions tools/autograd/templates/python_torch_functions.cpp
Expand Up @@ -269,18 +269,23 @@ static PyObject * THPVariable_tensor(PyObject* self, PyObject* args, PyObject* k
${py_methods}

static PyMethodDef torch_functions[] = {
{"arange", (PyCFunction)THPVariable_arange, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"as_tensor", (PyCFunction)THPVariable_as_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"clamp", (PyCFunction)THPVariable_clamp, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"dsmm", (PyCFunction)THPVariable_mm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"from_numpy", (PyCFunction)THPVariable_from_numpy, METH_STATIC | METH_O, NULL},
{"hsmm", (PyCFunction)THPVariable_hspmm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"_promote_types", (PyCFunction)THPVariable__promote_types, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"range", (PyCFunction)THPVariable_range, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"saddmm", (PyCFunction)THPVariable_sspaddmm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"sparse_coo_tensor", (PyCFunction)THPVariable_sparse_coo_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"spmm", (PyCFunction)THPVariable_mm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"tensor", (PyCFunction)THPVariable_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{NULL}
};

PyMethodDef torch_functions_additional[] = {
{"arange", (PyCFunction)THPVariable_arange, METH_VARARGS | METH_KEYWORDS, NULL},
{"as_tensor", (PyCFunction)THPVariable_as_tensor, METH_VARARGS | METH_KEYWORDS, NULL},
{"clamp", (PyCFunction)THPVariable_clamp, METH_VARARGS | METH_KEYWORDS, NULL},
{"dsmm", (PyCFunction)THPVariable_mm, METH_VARARGS | METH_KEYWORDS, NULL},
// {"from_numpy", (PyCFunction)THPVariable_from_numpy, METH_STATIC | METH_O, NULL},
{"hsmm", (PyCFunction)THPVariable_hspmm, METH_VARARGS | METH_KEYWORDS, NULL},
{"_promote_types", (PyCFunction)THPVariable__promote_types, METH_VARARGS | METH_KEYWORDS, NULL},
{"range", (PyCFunction)THPVariable_range, METH_VARARGS | METH_KEYWORDS, NULL},
{"saddmm", (PyCFunction)THPVariable_sspaddmm, METH_VARARGS | METH_KEYWORDS, NULL},
{"sparse_coo_tensor", (PyCFunction)THPVariable_sparse_coo_tensor, METH_VARARGS | METH_KEYWORDS, NULL},
{"spmm", (PyCFunction)THPVariable_mm, METH_VARARGS | METH_KEYWORDS, NULL},
{"tensor", (PyCFunction)THPVariable_tensor, METH_VARARGS | METH_KEYWORDS, NULL},
${py_method_defs}
{NULL}
};
Expand Down Expand Up @@ -331,6 +336,7 @@ void initTorchFunctions(PyObject* module) {
throw python_error();
}
Py_INCREF(&THPVariableFunctions);
PyModule_AddFunctions(module, torch_functions_additional);
if (PyModule_AddObject(module, "_VariableFunctions", (PyObject*)&THPVariableFunctions) < 0) {
throw python_error();
}
Expand Down

0 comments on commit 6cebcdb

Please sign in to comment.