Skip to content

Commit

Permalink
Fix Categorify inference and testing (#1874)
Browse files Browse the repository at this point in the history
  • Loading branch information
rjzamora committed Apr 29, 2024
1 parent d43e9fe commit 0b58cc4
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
4 changes: 2 additions & 2 deletions cpp/nvtabular/inference/categorify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -337,12 +337,12 @@ namespace nvtabular
// this operator currently only supports CPU arrays
.def_property_readonly("supports", [](py::object self)
{
py::object supports = py::module_::import("nvtabular").attr("graph").attr("base_operator").attr("Supports");
py::object supports = py::module_::import("nvtabular").attr("graph").attr("operator").attr("Supports");
return supports.attr("CPU_DICT_ARRAY");
})
.def_property_readonly("supported_formats", [](py::object self)
{
py::object supported = py::module_::import("nvtabular").attr("graph").attr("base_operator").attr("DataFormats");
py::object supported = py::module_::import("nvtabular").attr("graph").attr("operator").attr("DataFormats");
return supported.attr("NUMPY_DICT_ARRAY");
});
}
Expand Down
13 changes: 11 additions & 2 deletions nvtabular/ops/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from merlin.dag import BaseOperator, ColumnSelector # noqa pylint: disable=unused-import
from merlin.dag import ( # noqa pylint: disable=unused-import
BaseOperator,
ColumnSelector,
DataFormats,
)

Operator = BaseOperator

# Avoid TENSOR_TABLE by default (for now)
class Operator(BaseOperator):
@property
def supported_formats(self):
return DataFormats.PANDAS_DATAFRAME | DataFormats.CUDF_DATAFRAME
5 changes: 5 additions & 0 deletions tests/unit/ops/test_categorify.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,3 +734,8 @@ def test_categorify_inference():
output_tensors = inference_op.transform(cats.input_columns, input_tensors)
for key in input_tensors:
assert output_tensors[key].dtype == np.dtype("int64")

# Check results are consistent with python code path
expect = workflow.transform(df)
got = pd.DataFrame(output_tensors)
assert_eq(expect, got)

0 comments on commit 0b58cc4

Please sign in to comment.