handle copying of partitioned stat files when saving workflow (#1838)
* handle copying of partitioned stat files when saving workflow

* undo changes to tox
nv-alaiacano committed Jun 7, 2023
1 parent 66c6e3a commit 4b7957a
Showing 3 changed files with 61 additions and 7 deletions.
16 changes: 14 additions & 2 deletions nvtabular/ops/categorify.py
@@ -1861,12 +1861,24 @@ def _copy_storage(existing_stats, existing_path, new_path, copy):
     existing_fs = get_fs_token_paths(existing_path)[0]
     new_fs = get_fs_token_paths(new_path)[0]
     new_locations = {}
+
     for column, existing_file in existing_stats.items():
         new_file = existing_file.replace(str(existing_path), str(new_path))
         if copy and new_file != existing_file:
             new_fs.makedirs(os.path.dirname(new_file), exist_ok=True)
-            with new_fs.open(new_file, "wb") as output:
-                output.write(existing_fs.open(existing_file, "rb").read())
+
+            # For some ops, the existing "file" is a directory containing `part.N.parquet` files.
+            # In that case, new_file is actually a directory and we will iterate through the "part"
+            # files and copy them individually
+            if os.path.isdir(existing_file):
+                new_fs.makedirs(new_file, exist_ok=True)
+                for existing_file_part in existing_fs.ls(existing_file):
+                    new_file_part = os.path.join(new_file, os.path.basename(existing_file_part))
+                    with new_fs.open(new_file_part, "wb") as output:
+                        output.write(existing_fs.open(existing_file_part, "rb").read())
+            else:
+                with new_fs.open(new_file, "wb") as output:
+                    output.write(existing_fs.open(existing_file, "rb").read())
 
         new_locations[column] = new_file

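For context (not part of the commit): a stats artifact ends up as a directory rather than a single file when the underlying Dask computation writes its result in multiple partitions. A minimal sketch of that layout, assuming a local path and dask.dataframe; the column names here are illustrative only:

import dask.dataframe as dd
import pandas as pd

# A multi-partition Dask DataFrame written to parquet produces a directory of
# part.N.parquet files -- the case the new os.path.isdir() branch handles.
stats = pd.DataFrame({"user_id": [1, 2, 3, 4], "click_mean": [0.1, 0.5, 0.2, 0.9]})
ddf = dd.from_pandas(stats, npartitions=2)
ddf.to_parquet("encoded_stats.parquet")
# encoded_stats.parquet/
#     part.0.parquet
#     part.1.parquet
# Opening "encoded_stats.parquet" directly with fs.open(..., "rb") would fail,
# which is why each part file is now copied individually.
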
15 changes: 10 additions & 5 deletions nvtabular/workflow/workflow.py
@@ -17,12 +17,13 @@
 import inspect
 import json
 import logging
+import os
 import sys
 import time
 import types
 import warnings
 from functools import singledispatchmethod
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Union
 
 import cloudpickle
 import fsspec
@@ -295,12 +296,12 @@ def _getmodules(cls, fs):
 
         return [mod for mod in result if mod.__name__ not in exclusions]
 
-    def save(self, path, modules_byvalue=None):
+    def save(self, path: Union[str, os.PathLike], modules_byvalue=None):
         """Save this workflow to disk
         Parameters
         ----------
-        path: str
+        path: Union[str, os.PathLike]
             The path to save the workflow to
         modules_byvalue:
             A list of modules that should be serialized by value. This
@@ -314,6 +315,8 @@ def save(self, path, modules_byvalue=None):
         # avoid a circular import getting the version
         from nvtabular import __version__ as nvt_version
 
+        path = str(path)
+
         fs = fsspec.get_fs_token_paths(path)[0]
 
         fs.makedirs(path, exist_ok=True)
@@ -385,12 +388,12 @@ def save(self, path, modules_byvalue=None):
                 cloudpickle.unregister_pickle_by_value(sys.modules[m])
 
     @classmethod
-    def load(cls, path, client=None) -> "Workflow":
+    def load(cls, path: Union[str, os.PathLike], client=None) -> "Workflow":
         """Load up a saved workflow object from disk
         Parameters
         ----------
-        path: str
+        path: Union[str, os.PathLike]
             The path to load the workflow from
         client: distributed.Client, optional
             The Dask distributed client to use for multi-gpu processing and multi-node processing
@@ -403,6 +406,8 @@ def load(cls, path, client=None) -> "Workflow":
         # avoid a circular import getting the version
         from nvtabular import __version__ as nvt_version
 
+        path = str(path)
+
         fs = fsspec.get_fs_token_paths(path)[0]
 
         # check version information from the metadata blob, and warn if we have a mismatch
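Illustrative usage enabled by the Union[str, os.PathLike] signatures (a sketch, not taken from the commit; the path and column are made up). Because both methods now normalize the argument with str(path) before handing it to fsspec, a pathlib.Path works the same as a plain string:

from pathlib import Path

import numpy as np
import pandas as pd

import nvtabular as nvt
from merlin.io import Dataset

# Fit a tiny workflow so there are stats to persist alongside the graph.
ds = Dataset(pd.DataFrame({"a": np.random.randint(0, 10, 100)}))
wf = nvt.Workflow(["a"] >> nvt.ops.Categorify())
wf.fit(ds)

workflow_dir = Path("/tmp/example_workflow")  # hypothetical location
wf.save(workflow_dir)                  # previously required a str
wf2 = nvt.Workflow.load(workflow_dir)  # same for loading
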
37 changes: 37 additions & 0 deletions tests/unit/workflow/test_workflow.py
@@ -671,6 +671,43 @@ def test_workflow_saved_schema(tmpdir):
         assert node.output_schema is not None
 
 
+def test_stat_op_workflow_roundtrip(tmpdir):
+    """
+    Categorify and TargetEncoding produce intermediate stats files that must be properly
+    saved and re-loaded.
+    """
+    N = 100
+
+    df = Dataset(
+        make_df(
+            {
+                "a": np.random.randint(0, 100000, N),
+                "item_id": np.random.randint(0, 100, N),
+                "user_id": np.random.randint(0, 100, N),
+                "click": np.random.randint(0, 2, N),
+            }
+        ),
+    )
+
+    outputs = ["a"] >> nvt.ops.Categorify()
+
+    continuous = (
+        ["user_id", "item_id"]
+        >> nvt.ops.TargetEncoding(["click"], kfold=1, p_smooth=20)
+        >> nvt.ops.Normalize()
+    )
+    outputs += continuous
+    wf = nvt.Workflow(outputs)
+
+    wf.fit(df)
+    expected = wf.transform(df).compute()
+    wf.save(tmpdir)
+
+    wf2 = nvt.Workflow.load(tmpdir)
+    transformed = wf2.transform(df).compute()
+    assert_eq(transformed, expected)
+
+
 def test_workflow_infer_modules_byvalue(tmp_path):
     module_fn = tmp_path / "not_a_real_module.py"
     sys.path.append(str(tmp_path))
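A quick way to see what the roundtrip exercises (a sketch, not part of the commit): dropping the lines below into the test right after wf.save(tmpdir) prints the saved layout. The exact file and directory names depend on the ops and the NVTabular version, so treat them as illustrative:

import os

# Walk the saved workflow directory and print what was copied there.
for root, dirs, files in os.walk(str(tmpdir)):
    print(root, dirs, files)
# Expect Categorify's unique-value parquet files plus TargetEncoding stats,
# some of which may be directories of part.N.parquet files rather than single
# files -- the case the categorify.py change now copies part by part.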
