This repository has been archived by the owner on Jul 31, 2023. It is now read-only.

Merge pull request #25 from google/write-if-non-empty
Generate TFRecords only if data exists in a split.
mbernico committed Sep 18, 2020
2 parents f5d5c2d + a912554 commit 7779bee
Showing 7 changed files with 218 additions and 54 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -4,7 +4,7 @@ init:
pip install -r requirements.txt

test:
nosetests --with-coverage --nocapture -v --cover-package=tfrecorder
nosetests --with-coverage -v --cover-package=tfrecorder

pylint:
pylint tfrecorder
1 change: 1 addition & 0 deletions tfrecorder/beam_image.py
@@ -109,6 +109,7 @@ def process(
logging.warning('Could not load image: %s', image_uri)
logging.error('Exception was: %s', str(e))
self.image_bad_counter.inc()
d['split'] = 'DISCARD'

element.update(d)
yield element
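
The new `d['split'] = 'DISCARD'` line above ties failed image loads into the pipeline's split partitioning: a row whose split value is 'DISCARD' ends up in the discard output instead of a train/validation/test TFRecord. A minimal, hypothetical sketch of that routing (not the repository's actual `_partition_fn`; it assumes `constants.SPLIT_VALUES` is ordered TRAIN, VALIDATION, TEST, DISCARD, matching the `Partition` outputs further down):

# Hypothetical illustration only -- not the actual tfrecorder partition function.
SPLIT_VALUES = ['TRAIN', 'VALIDATION', 'TEST', 'DISCARD']

def partition_by_split(element, num_partitions):
  """Routes an element to the partition matching its 'split' value."""
  split = element.get('split', 'DISCARD')
  if split not in SPLIT_VALUES:
    split = 'DISCARD'  # Unknown labels fall through to the discard bucket.
  return SPLIT_VALUES.index(split)

# Used with Beam's Partition transform, mirroring the 'SplitDataset' step:
# beam.Partition(partition_by_split, len(SPLIT_VALUES))
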
119 changes: 68 additions & 51 deletions tfrecorder/beam_pipeline.py
@@ -19,20 +19,22 @@
This file implements the full Beam pipeline for TFRecorder.
"""

from typing import Any, Dict, Generator, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Union

import functools
import logging
import os

import apache_beam as beam
from apache_beam import pvalue
import pandas as pd
import tensorflow_transform as tft
from tensorflow_transform import beam as tft_beam

from tfrecorder import beam_image
from tfrecorder import common
from tfrecorder import constants
from tfrecorder import types


def _get_job_name(job_label: str = None) -> str:
@@ -138,7 +140,7 @@ def _get_write_to_tfrecord(output_dir: str,
num_shards=num_shards,
)

def _preprocessing_fn(inputs, integer_label: bool = False):
def _preprocessing_fn(inputs: Dict[str, Any], integer_label: bool = False):
"""TensorFlow Transform preprocessing function."""

outputs = inputs.copy()
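
The rest of `_preprocessing_fn` is collapsed in this diff. Purely for orientation, an illustrative sketch (not the repository's implementation) of what a TF Transform preprocessing function with an `integer_label` switch can look like:

import tensorflow_transform as tft

def example_preprocessing_fn(inputs, integer_label=False):
  """Illustrative only: pass inputs through, mapping string labels to ids."""
  outputs = inputs.copy()
  if not integer_label:
    # Build a vocabulary over the analyzed (training) data and map each
    # string label to its integer id.
    outputs['label'] = tft.compute_and_apply_vocabulary(outputs['label'])
  return outputs
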
Expand Down Expand Up @@ -166,7 +168,7 @@ def __init__(self):
# pylint: disable=arguments-differ
def process(
self,
element: Dict[str, Any]
element: List[str],
) -> Generator[Dict[str, Any], None, None]:
"""Loads image and creates image features.
@@ -178,6 +180,43 @@ def process(
yield element


def get_split_counts(df: pd.DataFrame):
"""Returns number of rows for each data split type given dataframe."""
assert constants.SPLIT_KEY in df.columns
return df[constants.SPLIT_KEY].value_counts().to_dict()
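
For example (illustrative usage; it assumes `constants.SPLIT_KEY` is the `'split'` column used in the test data):

import pandas as pd

df = pd.DataFrame({'split': ['TRAIN', 'TRAIN', 'VALIDATION', 'TEST']})
get_split_counts(df)  # -> {'TRAIN': 2, 'VALIDATION': 1, 'TEST': 1}
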


def _transform_and_write_tfr(
dataset: pvalue.PCollection,
tfr_writer: Callable = None,
preprocessing_fn: Optional[Callable] = None,
transform_fn: Optional[types.TransformFn] = None,
label: str = 'data'):
"""Applies TF Transform to dataset and outputs it as TFRecords."""

dataset_metadata = (dataset, constants.RAW_METADATA)

if transform_fn:
transformed_dataset, transformed_metadata = (
(dataset_metadata, transform_fn)
| f'Transform{label}' >> tft_beam.TransformDataset())
else:
if not preprocessing_fn:
preprocessing_fn = lambda x: x
(transformed_dataset, transformed_metadata), transform_fn = (
dataset_metadata
| f'AnalyzeAndTransform{label}' >>
tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))

transformed_data_coder = tft.coders.ExampleProtoCoder(
transformed_metadata.schema)
_ = (
transformed_dataset
| f'Encode{label}' >> beam.Map(transformed_data_coder.encode)
| f'Write{label}' >> tfr_writer(prefix=label.lower()))

return transform_fn


# pylint: disable=too-many-arguments
# pylint: disable=too-many-locals
@@ -246,71 +285,49 @@ def build_pipeline(
| 'ReadImage' >> beam.ParDo(extract_images_fn)
)

# Split dataset into train and validation.
# Note: The split counts computed below will not always reflect the actual
# number of samples written as TFRecords per split. The succeeding
# `Partition` operation may mark additional samples from these splits as
# discarded. If a split has all of its samples discarded, the pipeline will
# still generate a TFRecord file for that split, albeit an empty one.
split_counts = get_split_counts(df)

# Require a training set to be present in the input data. The transform_fn
# and transformed_metadata will be generated from the training set and
# applied to the other datasets, if any.
assert 'TRAIN' in split_counts

train_data, val_data, test_data, discard_data = (
image_csv_data | 'SplitDataset' >> beam.Partition(
_partition_fn, len(constants.SPLIT_VALUES))
)

train_dataset = (train_data, constants.RAW_METADATA)
val_dataset = (val_data, constants.RAW_METADATA)
test_dataset = (test_data, constants.RAW_METADATA)

# TensorFlow Transform applied to all datasets.
preprocessing_fn = functools.partial(
_preprocessing_fn,
integer_label=integer_label)
transformed_train_dataset, transform_fn = (
train_dataset
| 'AnalyzeAndTransformTrain' >> tft_beam.AnalyzeAndTransformDataset(
preprocessing_fn))

transformed_train_data, transformed_metadata = transformed_train_dataset
transformed_data_coder = tft.coders.ExampleProtoCoder(
transformed_metadata.schema)

transformed_val_data, _ = (
(val_dataset, transform_fn)
| 'TransformVal' >> tft_beam.TransformDataset()
)

transformed_test_data, _ = (
(test_dataset, transform_fn)
| 'TransformTest' >> tft_beam.TransformDataset()
)
tfr_writer = functools.partial(
_get_write_to_tfrecord, output_dir=job_dir, compress=compression,
num_shards=num_shards)
transform_fn = _transform_and_write_tfr(
train_data, tfr_writer, preprocessing_fn=preprocessing_fn,
label='Train')

# Sinks for TFRecords and metadata.
tfr_writer = functools.partial(_get_write_to_tfrecord,
output_dir=job_dir,
compress=compression,
num_shards=num_shards)
if 'VALIDATION' in split_counts:
_transform_and_write_tfr(
val_data, tfr_writer, transform_fn=transform_fn, label='Validation')

_ = (
transformed_train_data
| 'EncodeTrainData' >> beam.Map(transformed_data_coder.encode)
| 'WriteTrainData' >> tfr_writer(prefix='train'))

_ = (
transformed_val_data
| 'EncodeValData' >> beam.Map(transformed_data_coder.encode)
| 'WriteValData' >> tfr_writer(prefix='val'))

_ = (
transformed_test_data
| 'EncodeTestData' >> beam.Map(transformed_data_coder.encode)
| 'WriteTestData' >> tfr_writer(prefix='test'))
if 'TEST' in split_counts:
_transform_and_write_tfr(
test_data, tfr_writer, transform_fn=transform_fn, label='Test')

_ = (
discard_data
| 'DiscardDataWriter' >> beam.io.WriteToText(
| 'WriteDiscardedData' >> beam.io.WriteToText(
os.path.join(job_dir, 'discarded-data')))

# Output transform function and metadata
# Note: `transform_fn` already contains the transformed metadata
_ = (transform_fn | 'WriteTransformFn' >> tft_beam.WriteTransformFn(
job_dir))

# Output metadata schema
_ = (transformed_metadata | 'WriteMetadata' >> tft_beam.WriteMetadata(
job_dir, pipeline=p))

return p
97 changes: 97 additions & 0 deletions tfrecorder/beam_pipeline_test.py
@@ -16,13 +16,21 @@

"""Tests for beam_pipeline."""

import functools
import glob
import os
import tempfile
import unittest
from unittest import mock

import apache_beam as beam
import tensorflow as tf
import tensorflow_transform as tft
from tensorflow_transform import beam as tft_beam

from tfrecorder import beam_pipeline
from tfrecorder import constants
from tfrecorder import test_utils


# pylint: disable=protected-access
@@ -78,5 +86,94 @@ def test_partition_fn(self):
'{} should be index {} but was index {}'.format(part, i, index))


class GetSplitCountsTest(unittest.TestCase):
"""Tests `get_split_counts` function."""

def setUp(self):
self.df = test_utils.get_test_df()

def test_all_splits(self):
"""Tests case where train, validation and test data exists"""
expected = {'TRAIN': 2, 'VALIDATION': 2, 'TEST': 2}
actual = beam_pipeline.get_split_counts(self.df)
self.assertEqual(actual, expected)

def test_one_split(self):
"""Tests case where only one split (train) exists."""
df = self.df[self.df.split == 'TRAIN']
expected = {'TRAIN': 2}
actual = beam_pipeline.get_split_counts(df)
self.assertEqual(actual, expected)

def test_error_no_split_key(self):
"""Tests case no split key/column exists."""
df = self.df.drop(constants.SPLIT_KEY, axis=1)
with self.assertRaises(AssertionError):
beam_pipeline.get_split_counts(df)


class TransformAndWriteTfrTest(unittest.TestCase):
"""Tests `_transform_and_write_tfr` function."""

def setUp(self):
self.pipeline = test_utils.get_test_pipeline()
self.raw_df = test_utils.get_raw_feature_df()
self.temp_dir_obj = tempfile.TemporaryDirectory(dir='/tmp', prefix='test-')
self.test_dir = self.temp_dir_obj.name
self.tfr_writer = functools.partial(
beam_pipeline._get_write_to_tfrecord, output_dir=self.test_dir,
compress='gzip', num_shards=2)
self.converter = tft.coders.CsvCoder(
constants.RAW_FEATURE_SPEC.keys(), constants.RAW_METADATA.schema)
self.transform_fn_path = ('./tfrecorder/test_data/sample_tfrecords')

def tearDown(self):
self.temp_dir_obj.cleanup()

def _get_dataset(self, pipeline, df):
"""Returns dataset `PCollection`."""
return (pipeline
| beam.Create(df.values.tolist())
| beam.ParDo(beam_pipeline.ToCSVRows())
| beam.Map(self.converter.decode))

def test_train(self):
"""Tests case where training data is passed."""

with self.pipeline as p:
with tft_beam.Context(temp_dir=os.path.join(self.test_dir, 'tmp')):
df = self.raw_df[self.raw_df.split == 'TRAIN']
dataset = self._get_dataset(p, df)
transform_fn = (
beam_pipeline._transform_and_write_tfr(
dataset, self.tfr_writer, label='Train'))
_ = transform_fn | tft_beam.WriteTransformFn(self.test_dir)

self.assertTrue(
os.path.isdir(os.path.join(self.test_dir, 'transform_fn')))
self.assertTrue(
os.path.isdir(os.path.join(self.test_dir, 'transformed_metadata')))
self.assertTrue(glob.glob(os.path.join(self.test_dir, 'train*.gz')))
self.assertFalse(glob.glob(os.path.join(self.test_dir, 'validation*.gz')))
self.assertFalse(glob.glob(os.path.join(self.test_dir, 'test*.gz')))

def test_non_training(self):
"""Tests case where dataset contains non-training (e.g. test) data."""

with self.pipeline as p:
with tft_beam.Context(temp_dir=os.path.join(self.test_dir, 'tmp')):

df = self.raw_df[self.raw_df.split == 'TEST']
dataset = self._get_dataset(p, df)
transform_fn = p | tft_beam.ReadTransformFn(self.transform_fn_path)
beam_pipeline._transform_and_write_tfr(
dataset, self.tfr_writer, transform_fn=transform_fn,
label='Test')

self.assertFalse(glob.glob(os.path.join(self.test_dir, 'train*.gz')))
self.assertFalse(glob.glob(os.path.join(self.test_dir, 'validation*.gz')))
self.assertTrue(glob.glob(os.path.join(self.test_dir, 'test*.gz')))


if __name__ == '__main__':
unittest.main()
4 changes: 3 additions & 1 deletion tfrecorder/client.py
@@ -73,7 +73,9 @@ def _validate_runner(

if (runner == 'DataflowRunner') & (not tfrecorder_wheel):
raise AttributeError(
'DataflowRunner requires a tfrecorder whl file for remote execution.')
'DataflowRunner requires a tfrecorder whl file for remote execution.')


# def read_image_directory(dirpath) -> pd.DataFrame:
# """Reads image data from a directory into a Pandas DataFrame."""
#
22 changes: 21 additions & 1 deletion tfrecorder/test_utils.py
@@ -26,11 +26,13 @@
from apache_beam.testing import test_pipeline
import pandas as pd

from tfrecorder import constants


TEST_DIR = 'tfrecorder/test_data'


def get_test_df():
def get_test_df() -> pd.DataFrame:
"""Gets a test dataframe that works with the data in test_data/."""
return pd.read_csv(os.path.join(TEST_DIR, 'data.csv'))

@@ -41,6 +43,24 @@ def get_test_data() -> Dict[str, List[Any]]:
return get_test_df().to_dict(orient='list')


def get_raw_feature_df() -> pd.DataFrame:
"""Returns test dataframe having raw feature spec schema."""

df = get_test_df()
df.drop(constants.IMAGE_URI_KEY, axis=1, inplace=True)
df['image_name'] = 'image_name'
df['image'] = 'image'
# Note: The TF Transform parser expects string values in the input; they
# will be parsed according to the raw feature spec that is passed together
# with the data.
df['image_height'] = '48'
df['image_width'] = '48'
df['image_channels'] = '3'
df = df[constants.RAW_FEATURE_SPEC.keys()]

return df
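
To make the note above concrete: the raw feature spec maps each column name to a TensorFlow feature spec, and the CSV coder parses the string values accordingly. A hypothetical illustration of what such a spec can look like (the real one lives in `tfrecorder/constants.py` and may differ):

import tensorflow as tf

# Illustrative only -- not the actual constants.RAW_FEATURE_SPEC.
RAW_FEATURE_SPEC_EXAMPLE = {
    'split': tf.io.FixedLenFeature([], tf.string),
    'image_name': tf.io.FixedLenFeature([], tf.string),
    'image': tf.io.FixedLenFeature([], tf.string),
    'image_height': tf.io.FixedLenFeature([], tf.int64),   # '48' parses to 48
    'image_width': tf.io.FixedLenFeature([], tf.int64),
    'image_channels': tf.io.FixedLenFeature([], tf.int64),
}
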


def get_test_pipeline():
"""Gets a test pipeline."""
return test_pipeline.TestPipeline(runner='DirectRunner')
27 changes: 27 additions & 0 deletions tfrecorder/types.py
@@ -0,0 +1,27 @@
# Lint as: python3

# Copyright 2020 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom types."""

from typing import Tuple

from apache_beam.pvalue import PCollection
from tensorflow_transform import beam as tft_beam


BeamDatasetMetadata = tft_beam.tft_beam_io.beam_metadata_io.BeamDatasetMetadata
TransformedMetadata = BeamDatasetMetadata
TransformFn = Tuple[PCollection, TransformedMetadata]
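
A hedged note on this alias: in TF Transform's Beam implementation the transform function is handled as a pair of the analyze-phase output (a `PCollection`) plus the transformed-dataset metadata, which is why the `build_pipeline` comment above says `transform_fn` already contains the transformed metadata. A small sketch of how a value of this type is consumed, mirroring the `WriteTransformFn` calls shown earlier:

from tensorflow_transform import beam as tft_beam

def write_transform_artifacts(transform_fn: TransformFn, output_dir: str):
  """Illustrative helper: persists the transform graph and its metadata."""
  return transform_fn | tft_beam.WriteTransformFn(output_dir)
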
