This repository has been archived by the owner on Jul 31, 2023. It is now read-only.

Added lowercase split value and passing tests #42 (#67)

Open · wants to merge 12 commits into base: dev
6 changes: 3 additions & 3 deletions README.md
@@ -165,7 +165,7 @@ on your local machine.
import tfrecorder

dataset_dict = tfrecorder.load('/path/to/tfrecord_dir')
-train = dataset_dict['TRAIN']
+train = dataset_dict['train']
```
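For reviewers, a minimal usage sketch of the new lowercase keys, assuming `load` returns a plain `dict` as the `dataset_loader` docstring later in this diff describes:

```python
import tfrecorder

dataset_dict = tfrecorder.load('/path/to/tfrecord_dir')

# 'train' must exist for a valid TFRecorder output; 'validation' and
# 'test' are present only if those splits were in the input data.
train = dataset_dict['train']
validation = dataset_dict.get('validation')
test = dataset_dict.get('test')
```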

### Verifying data in TFRecords generated by TFRecorder
@@ -176,7 +176,7 @@ import tfrecorder

tfrecorder.inspect(
    tfrecord_dir='/path/to/tfrecords/',
-    split='TRAIN',
+    split='train',
    num_records=5,
    output_dir='/tmp/output')
```
@@ -189,7 +189,7 @@ Using the command line:
```bash
tfrecorder inspect \
    --tfrecord-dir=/path/to/tfrecords/ \
-    --split='TRAIN' \
+    --split='train' \
    --num_records=5 \
    --output_dir=/tmp/output
```
8 changes: 4 additions & 4 deletions tfrecorder/beam_pipeline.py
@@ -276,8 +276,8 @@ def build_pipeline(
# Require training set to be available in the input data. The transform_fn
# and transformed_metadata will be generated from the training set and
# applied to the other datasets, if any
-if 'TRAIN' not in split_counts:
-  raise AttributeError('`TRAIN` set expected to be present in splits')
+if 'train' not in split_counts:
+  raise AttributeError('`train` set expected to be present in splits')

# Split dataset into train, validation, test sets.
partition_fn = functools.partial(_partition_fn, split_key=split_key)
@@ -300,13 +300,13 @@
    metadata=pre_tft_metadata,
    label='Train')

-if 'VALIDATION' in split_counts:
+if 'validation' in split_counts:
  _transform_and_write_tfr(
      val_data, tfr_writer, transform_fn=transform_fn,
      metadata=pre_tft_metadata,
      label='Validation')

-if 'TEST' in split_counts:
+if 'test' in split_counts:
  _transform_and_write_tfr(
      test_data, tfr_writer, transform_fn=transform_fn,
      metadata=pre_tft_metadata,
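For context on the `train` guard above: a hypothetical sketch of the `split_counts` dict it checks, assuming `get_split_counts` (exercised in the tests below) simply tallies rows per split value; the real implementation may differ in detail:

```python
import pandas as pd

def get_split_counts(df: pd.DataFrame, split_key: str) -> dict:
  # Hypothetical sketch: tally rows per split value, producing a dict
  # like {'train': 2, 'validation': 2, 'test': 2}.
  return df[split_key].value_counts().to_dict()

df = pd.DataFrame({'split': ['train', 'train', 'validation']})
split_counts = get_split_counts(df, 'split')
assert 'train' in split_counts  # otherwise build_pipeline raises AttributeError
```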
16 changes: 8 additions & 8 deletions tfrecorder/beam_pipeline_test.py
@@ -43,7 +43,7 @@ class BeamPipelineTests(unittest.TestCase):
def test_processing_fn_with_int_label(self):
  'Test preprocessing fn with integer label.'
  element = {
-      'split': 'TRAIN',
+      'split': 'train',
      'image_uri': 'gs://foo/bar.jpg',
      'label': 1}
  my_schema = frozendict.FrozenOrderedDict({
@@ -60,7 +60,7 @@ def test_processing_fn_with_string_label(self, mock_transform):
  mock_transform.compute_and_apply_vocabulary.return_value = tf.constant(
      0, dtype=tf.int64)
  element = {
-      'split': 'TRAIN',
+      'split': 'train',
      'image_uri': 'gs://foo/bar.jpg',
      'label': tf.constant('cat', dtype=tf.string)}
  result = beam_pipeline._preprocessing_fn(
@@ -85,7 +85,7 @@ def test_partition_fn(self):
    'image_uri': 'gs://foo/bar0.jpg',
    'label': 1}

-for i, part in enumerate(['TRAIN', 'VALIDATION', 'TEST', 'FOO']):
+for i, part in enumerate(['train', 'validation', 'test', 'FOO']):
  test_data['split'] = part.encode('utf-8')
  index = beam_pipeline._partition_fn(test_data, split_key='split')
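A hypothetical reconstruction of the partition logic this test exercises, assuming unknown values such as 'FOO' fall into the trailing 'discard' bucket (consistent with `SplitKey.allowed_values` in types.py below); the real `_partition_fn` may differ:

```python
_SPLITS = ['train', 'validation', 'test', 'discard']

def _partition_fn(element, num_partitions=len(_SPLITS), split_key='split'):
  # Beam's Partition transform passes each element plus the partition count.
  split = element[split_key]
  if isinstance(split, bytes):
    split = split.decode('utf-8')
  # Unknown splits (e.g. 'FOO') map to the final 'discard' partition.
  return _SPLITS.index(split) if split in _SPLITS else num_partitions - 1
```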

@@ -105,14 +105,14 @@ def setUp(self):

def test_all_splits(self):
  """Tests case where train, validation and test data exists"""
-  expected = {'TRAIN': 2, 'VALIDATION': 2, 'TEST': 2}
+  expected = {'train': 2, 'validation': 2, 'test': 2}
  actual = beam_pipeline.get_split_counts(self.df, self.split_key)
  self.assertEqual(actual, expected)

def test_one_split(self):
  """Tests case where only one split (train) exists."""
-  df = self.df[self.df.split == 'TRAIN']
-  expected = {'TRAIN': 2}
+  df = self.df[self.df.split == 'train']
+  expected = {'train': 2}
  actual = beam_pipeline.get_split_counts(df, self.split_key)
  self.assertEqual(actual, expected)

@@ -158,7 +158,7 @@ def test_train(self):

with self.pipeline as p:
  with tft_beam.Context(temp_dir=os.path.join(self.test_dir, 'tmp')):
-    df = self.pre_tft_df[self.pre_tft_df.split == 'TRAIN']
+    df = self.pre_tft_df[self.pre_tft_df.split == 'train']
    dataset = self._get_dataset(p, df)
    preprocessing_fn = functools.partial(
        beam_pipeline._preprocessing_fn,
@@ -185,7 +185,7 @@ def test_non_training(self):
with self.pipeline as p:
  with tft_beam.Context(temp_dir=os.path.join(self.test_dir, 'tmp')):

-    df = self.pre_tft_df[self.pre_tft_df.split == 'TEST']
+    df = self.pre_tft_df[self.pre_tft_df.split == 'test']
    dataset = self._get_dataset(p, df)
    transform_fn = p | tft_beam.ReadTransformFn(self.transform_fn_path)
    beam_pipeline._transform_and_write_tfr(
4 changes: 2 additions & 2 deletions tfrecorder/converter.py
@@ -100,15 +100,15 @@ def _read_image_directory(image_dir: str) -> pd.DataFrame:

Example expected directory structure:
  image_dir/
-    TRAIN/
+    train/
      label0/
        image_000.jpg
        image_001.jpg
        ...
      label1/
        image_100.jpg
        ...
-    VALIDATION/
+    validation/
      ...

Output will be based on `schema.image_csv_schema`.
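To illustrate the layout above, a rough sketch of turning it into a DataFrame with the `split,image_uri,label` columns used in `test_data/data.csv`; the actual `_read_image_directory` may differ:

```python
import os
import pandas as pd

def read_image_directory_sketch(image_dir: str) -> pd.DataFrame:
  rows = []
  for split in sorted(os.listdir(image_dir)):       # e.g. 'train', 'validation'
    split_dir = os.path.join(image_dir, split)
    for label in sorted(os.listdir(split_dir)):     # e.g. 'label0', 'label1'
      label_dir = os.path.join(split_dir, label)
      for fname in sorted(os.listdir(label_dir)):   # e.g. 'image_000.jpg'
        rows.append((split, os.path.join(label_dir, fname), label))
  return pd.DataFrame(rows, columns=['split', 'image_uri', 'label'])
```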
6 changes: 3 additions & 3 deletions tfrecorder/converter_test.py
@@ -352,9 +352,9 @@ def setUp(self):
self.tfrecord_dir = '/path/to/tfrecords'
self.dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
self.datasets = {
-    'TRAIN': self.dataset,
-    'VALIDATION': self.dataset,
-    'TEST': self.dataset,
+    'train': self.dataset,
+    'validation': self.dataset,
+    'test': self.dataset,
}

@mock.patch.object(dataset_loader, 'load', autospec=True)
6 changes: 3 additions & 3 deletions tfrecorder/dataset_loader.py
@@ -85,9 +85,9 @@ def load(tfrecord_dir: str) -> Dict[str, tf.data.Dataset]:
This returns a `dict` keyed by dataset split, e.g.
```
{
-    'TRAIN': <tf.data.Dataset>,
-    'VALIDATION': <tf.data.Dataset>,
-    'TEST': <tf.data.Dataset>,
+    'train': <tf.data.Dataset>,
+    'validation': <tf.data.Dataset>,
+    'test': <tf.data.Dataset>,
}
```
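A short consumption sketch for the dict shape documented above:

```python
import tfrecorder

# Iterate whichever splits were written; keys are lowercase after this change.
for split_name, dataset in tfrecorder.load('/path/to/tfrecord_dir').items():
  print(split_name, dataset.element_spec)
```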

12 changes: 6 additions & 6 deletions tfrecorder/test_data/data.csv
@@ -1,7 +1,7 @@
split,image_uri,label
-TEST,tfrecorder/test_data/images/TEST/cat/cat-800x600-3.jpg,cat
-TEST,tfrecorder/test_data/images/TEST/goat/goat-640x427-3.jpg,goat
-TRAIN,tfrecorder/test_data/images/TRAIN/cat/cat-640x853-1.jpg,cat
-TRAIN,tfrecorder/test_data/images/TRAIN/goat/goat-640x640-1.jpg,goat
-VALIDATION,tfrecorder/test_data/images/VALIDATION/cat/cat-800x600-2.jpg,cat
-VALIDATION,tfrecorder/test_data/images/VALIDATION/goat/goat-320x320-2.jpg,goat
+test,tfrecorder/test_data/images/TEST/cat/cat-800x600-3.jpg,cat
+test,tfrecorder/test_data/images/TEST/goat/goat-640x427-3.jpg,goat
+train,tfrecorder/test_data/images/TRAIN/cat/cat-640x853-1.jpg,cat
+train,tfrecorder/test_data/images/TRAIN/goat/goat-640x640-1.jpg,goat
+validation,tfrecorder/test_data/images/VALIDATION/cat/cat-800x600-2.jpg,cat
+validation,tfrecorder/test_data/images/VALIDATION/goat/goat-320x320-2.jpg,goat
4 changes: 2 additions & 2 deletions tfrecorder/test_utils.py
@@ -28,8 +28,8 @@

from tfrecorder import input_schema


-TEST_DIR = 'tfrecorder/test_data'
+# On Windows, use 'test_data/' as the path.
+TEST_DIR = 'tfrecorder/test_data/'
TEST_TFRECORDS_DIR = os.path.join(TEST_DIR, 'sample_tfrecords')


2 changes: 1 addition & 1 deletion tfrecorder/types.py
@@ -46,7 +46,7 @@ class ImageUri(SupportedType):
class SplitKey(SupportedType):
  """Supports split key columns."""
  feature_spec = tf.io.FixedLenFeature([], tf.string)
-  allowed_values = ['TRAIN', 'VALIDATION', 'TEST', 'DISCARD']
+  allowed_values = ['train', 'validation', 'test', 'discard']
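A sketch of how a DataFrame's split column might be validated against `allowed_values`; TFRecorder's actual schema checks may differ:

```python
import pandas as pd

ALLOWED_SPLITS = {'train', 'validation', 'test', 'discard'}  # mirrors SplitKey

def check_split_column(df: pd.DataFrame, split_key: str = 'split') -> None:
  # Reject any value outside SplitKey.allowed_values.
  bad = set(df[split_key].unique()) - ALLOWED_SPLITS
  if bad:
    raise ValueError(f'Unsupported split value(s): {sorted(bad)}')
```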


@dataclasses.dataclass
2 changes: 1 addition & 1 deletion tfrecorder/utils.py
@@ -50,7 +50,7 @@ def _save_image_from_record(record: Dict[str, tf.Tensor], outfile: str):

def inspect(
    tfrecord_dir: str,
-    split: str = 'TRAIN',
+    split: str = 'train',
    num_records: int = 1,
    output_dir: str = 'output'):
  """Prints contents of TFRecord files generated by TFRecorder.
4 changes: 2 additions & 2 deletions tfrecorder/utils_test.py
@@ -64,7 +64,7 @@ def setUp(self):
    'image_channels': [image_channels] * num_records,
})
self.tfrecord_dir = 'gs://path/to/tfrecords/dir'
-self.split = 'TRAIN'
+self.split = 'train'
self.num_records = num_records
self.data = data
self.dataset = tf.data.Dataset.from_tensor_slices(self.data)
@@ -75,7 +75,7 @@ def test_valid_records(self, mock_fn):

mock_fn.return_value = {self.split: self.dataset}
num_records = len(self.data['image'])

+# Omit dir='/tmp' on Windows.
with tempfile.TemporaryDirectory(dir='/tmp') as dir_:
  actual_dir = utils.inspect(
      self.tfrecord_dir, split=self.split, num_records=num_records,