black formatting, flake linting, mypy corrections across files. Added test for tf_example_serialization, added / corrected tests for other data / train files.
bgenchel committed Apr 24, 2024
1 parent 4975d82 commit 5f78d3a
Showing 22 changed files with 239 additions and 405 deletions.
8 changes: 2 additions & 6 deletions basic_pitch/commandline_printing.py
@@ -42,9 +42,7 @@ def generating_file_message(output_type: str) -> None:
print(f"\n\n Creating {output_type.replace('_', ' ').lower()}...")


def file_saved_confirmation(
output_type: str, save_path: Union[pathlib.Path, str]
) -> None:
def file_saved_confirmation(output_type: str, save_path: Union[pathlib.Path, str]) -> None:
"""Print a confirmation that the file was saved succesfully
Args:
@@ -63,9 +61,7 @@ def failed_to_save(output_type: str, save_path: Union[pathlib.Path, str]) -> None:
save_path: The path to output file.
"""
print(
f"\n🚨 Failed to save {output_type.replace('_', ' ').lower()} to {save_path} \n"
)
print(f"\n🚨 Failed to save {output_type.replace('_', ' ').lower()} to {save_path} \n")


@contextmanager
12 changes: 3 additions & 9 deletions basic_pitch/constants.py
@@ -51,17 +51,11 @@
}


def _freq_bins(
bins_per_semitone: int, base_frequency: float, n_semitones: int
) -> np.array:
def _freq_bins(bins_per_semitone: int, base_frequency: float, n_semitones: int) -> np.array:
d = 2.0 ** (1.0 / (12 * bins_per_semitone))
bin_freqs = base_frequency * d ** np.arange(bins_per_semitone * n_semitones)
return bin_freqs


FREQ_BINS_NOTES = _freq_bins(
NOTES_BINS_PER_SEMITONE, ANNOTATIONS_BASE_FREQUENCY, ANNOTATIONS_N_SEMITONES
)
FREQ_BINS_CONTOURS = _freq_bins(
CONTOURS_BINS_PER_SEMITONE, ANNOTATIONS_BASE_FREQUENCY, ANNOTATIONS_N_SEMITONES
)
FREQ_BINS_NOTES = _freq_bins(NOTES_BINS_PER_SEMITONE, ANNOTATIONS_BASE_FREQUENCY, ANNOTATIONS_N_SEMITONES)
FREQ_BINS_CONTOURS = _freq_bins(CONTOURS_BINS_PER_SEMITONE, ANNOTATIONS_BASE_FREQUENCY, ANNOTATIONS_N_SEMITONES)
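
A note on what the reflowed `_freq_bins` computes: adjacent bins are spaced by the (12 * bins_per_semitone)-th root of 2, so each bin steps up by a fixed fraction of a semitone and `FREQ_BINS_NOTES` / `FREQ_BINS_CONTOURS` are geometric frequency grids. A minimal sketch with illustrative values (not the repository's constants):

```python
import numpy as np


def _freq_bins(bins_per_semitone: int, base_frequency: float, n_semitones: int) -> np.ndarray:
    # d is the frequency ratio between adjacent bins.
    d = 2.0 ** (1.0 / (12 * bins_per_semitone))
    return base_frequency * d ** np.arange(bins_per_semitone * n_semitones)


# Illustrative values only; the real constants are defined elsewhere in constants.py.
bins = _freq_bins(bins_per_semitone=3, base_frequency=27.5, n_semitones=12)
print(len(bins))          # 36 bins covering one octave
print(bins[0], bins[-1])  # 27.5 Hz up to just under 55 Hz
```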
25 changes: 12 additions & 13 deletions basic_pitch/data/commandline.py
@@ -18,21 +18,24 @@
import argparse
import os

from pathlib import Path
from typing import Optional


def add_default(parser: argparse.ArgumentParser, dataset_name: str = "") -> None:
default_source = Path.home() / "mir_datasets" / dataset_name
default_destination = Path.home() / "data" / "basic_pitch" / dataset_name
parser.add_argument(
"--source",
default=os.path.join(os.path.expanduser("~"), "mir_datasets", dataset_name),
help="Source directory for mir data. Defaults to local mir_datasets folder.",
default=default_source,
type=Path,
help=f"Source directory for mir data. Defaults to {default_source}",
)
parser.add_argument(
"--destination",
default=os.path.join(
os.path.expanduser("~"), "data", "basic_pitch", dataset_name
),
help="Output directory to write results to. Defaults to local ~/data/basic_pitch/{dataset}/",
default=default_destination,
type=Path,
help=f"Output directory to write results to. Defaults to {default_destination}",
)
parser.add_argument(
"--runner",
@@ -46,9 +49,7 @@ def add_default(parser: argparse.ArgumentParser, dataset_name: str = "") -> None
action="store_true",
help="If passed, the dataset will be put into a timestamp directory instead of 'splits'",
)
parser.add_argument(
"--batch-size", default=5, type=int, help="Number of examples per tfrecord"
)
parser.add_argument("--batch-size", default=5, type=int, help="Number of examples per tfrecord")
parser.add_argument(
"--worker-harness-container-image",
default="",
@@ -58,17 +59,15 @@ def add_default(parser: argparse.ArgumentParser, dataset_name: str = "") -> None


def resolve_destination(namespace: argparse.Namespace, time_created: int) -> str:
return os.path.join(
namespace.destination, str(time_created) if namespace.timestamped else "splits"
)
return os.path.join(namespace.destination, str(time_created) if namespace.timestamped else "splits")


def add_split(
parser: argparse.ArgumentParser,
train_percent: float = 0.8,
validation_percent: float = 0.1,
split_seed: Optional[int] = None,
):
) -> None:
parser.add_argument(
"--train-percent",
type=float,
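
A minimal usage sketch of the Path-based defaults introduced here, assuming the module is imported as `basic_pitch.data.commandline`; the empty argument list and dataset name are illustrative:

```python
import argparse
import time

from basic_pitch.data import commandline

parser = argparse.ArgumentParser()
commandline.add_default(parser, dataset_name="guitarset")
commandline.add_split(parser)

# With no CLI arguments, --source and --destination fall back to the
# pathlib defaults built under the user's home directory.
args = parser.parse_args([])
print(args.source)       # ~/mir_datasets/guitarset
print(args.destination)  # ~/data/basic_pitch/guitarset

# Without --timestamped, output goes to a "splits" subdirectory.
print(commandline.resolve_destination(args, int(time.time())))
```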
1 change: 1 addition & 0 deletions basic_pitch/data/datasets/__init__.py
@@ -0,0 +1 @@
DOWNLOAD = True
33 changes: 13 additions & 20 deletions basic_pitch/data/datasets/guitarset.py
@@ -20,41 +20,39 @@
import os
import random
import time
import sys

from typing import List, Tuple, Optional
from typing import Any, List, Dict, Tuple, Optional

import apache_beam as beam
import mirdata

from basic_pitch.data import commandline, pipeline
from basic_pitch.data.datasets import DOWNLOAD


class GuitarSetInvalidTracks(beam.DoFn):
def process(self, element: Tuple[str, str], *args, **kwargs):
def process(self, element: Tuple[str, str], *args: Tuple[Any, Any], **kwargs: Dict[str, Any]) -> Any:
track_id, split = element
yield beam.pvalue.TaggedOutput(split, track_id)


class GuitarSetToTfExample(beam.DoFn):
DOWNLOAD_ATTRIBUTES = ["audio_mic_path", "jams_path"]

def __init__(self, source: str):
def __init__(self, source: str, download: bool) -> None:
self.source = source
self.download = download

def setup(self):
def setup(self) -> None:
import apache_beam as beam
import mirdata

self.guitarset_remote = mirdata.initialize("guitarset", data_home=self.source)
self.filesystem = beam.io.filesystems.FileSystems()
if (
type(self.filesystem.get_filesystem(self.source)) == beam.io.localfilesystem.LocalFileSystem
and "pytest" not in sys.modules
):
self.filesystem = beam.io.filesystems.FileSystems() # TODO: replace with fsspec
if self.download:
self.guitarset_remote.download()

def process(self, element: List[str], *args, **kwargs):
def process(self, element: List[str], *args: Tuple[Any, Any], **kwargs: Dict[str, Any]) -> List[Any]:
import tempfile

import mirdata
@@ -85,9 +83,7 @@ def process(self, element: List[str], *args, **kwargs):
source = getattr(track_remote, attribute)
destination = getattr(track_local, attribute)
os.makedirs(os.path.dirname(destination), exist_ok=True)
with self.filesystem.open(source) as s, open(
destination, "wb"
) as d:
with self.filesystem.open(source) as s, open(destination, "wb") as d:
d.write(s.read())

local_wav_path = f"{track_local.audio_mic_path}_tmp.wav"
@@ -150,17 +146,14 @@ def determine_split() -> str:
return "test"

guitarset = mirdata.initialize("guitarset")
# guitarset.download()

return [(track_id, determine_split()) for track_id in guitarset.track_ids]


def main(known_args, pipeline_args):
def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
time_created = int(time.time())
destination = commandline.resolve_destination(known_args, time_created)
input_data = create_input_data(
known_args.train_percent, known_args.validation_percent, known_args.split_seed
)
input_data = create_input_data(known_args.train_percent, known_args.validation_percent, known_args.split_seed)

pipeline_options = {
"runner": known_args.runner,
@@ -175,7 +168,7 @@ def main(known_args, pipeline_args):
pipeline.run(
pipeline_options,
input_data,
GuitarSetToTfExample(known_args.source),
GuitarSetToTfExample(known_args.source, DOWNLOAD),
GuitarSetInvalidTracks(),
destination,
known_args.batch_size,
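
One practical effect of passing `download` explicitly (wired from the `DOWNLOAD` constant in `main`) is that a test can construct the DoFn with downloads disabled, instead of relying on the removed pytest-in-`sys.modules` check. A rough sketch; the local path and track id are hypothetical:

```python
from basic_pitch.data.datasets.guitarset import GuitarSetToTfExample

# Hypothetical fixture location; a real test would point at checked-in test data.
fn = GuitarSetToTfExample("/tmp/guitarset", download=False)
fn.setup()  # initializes mirdata and the Beam filesystem, skips guitarset_remote.download()
examples = fn.process(["00_BN1-129-Eb_comp"])  # batch of one illustrative track id
```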
2 changes: 1 addition & 1 deletion basic_pitch/data/download.py
@@ -17,7 +17,7 @@
}


def main():
def main() -> None:
dataset_parser = argparse.ArgumentParser()
dataset_parser.add_argument(
"dataset",
34 changes: 14 additions & 20 deletions basic_pitch/data/pipeline.py
@@ -18,7 +18,7 @@
import logging
import os
import uuid
from typing import Dict, List, Tuple, Callable, Union
from typing import Any, Dict, List, Tuple, Callable, Union

import apache_beam as beam
import tensorflow as tf
@@ -27,39 +27,37 @@

# Because beam.GroupIntoBatches isn't supported as of 2.29
class Batch(beam.DoFn):
def __init__(self, batch_size):
def __init__(self, batch_size: int) -> None:
self.batch_size = batch_size

def process(self, element):
def process(self, element: List[Any], *args: Tuple[Any, Any], **kwargs: Dict[str, Any]) -> Any:
for i in range(0, len(element), self.batch_size):
yield element[i : i + self.batch_size]


class WriteBatchToTfRecord(beam.DoFn):
def __init__(self, destination):
def __init__(self, destination: str) -> None:
self.destination = destination

def process(self, element):
def process(self, element: Any, *args: Tuple[Any, Any], **kwargs: Dict[str, Any]) -> None:
if not isinstance(element, list):
element = [element]

logging.info(f"Writing to file batch of length {len(element)}")
# hopefully uuids are unique enough
with tf.io.TFRecordWriter(
os.path.join(self.destination, f"{uuid.uuid4()}.tfrecord")
) as writer:
with tf.io.TFRecordWriter(os.path.join(self.destination, f"{uuid.uuid4()}.tfrecord")) as writer:
for example in element:
writer.write(example.SerializeToString())


def transcription_dataset_writer(
p: beam.Pipeline,
input_data: List[Tuple[str, str]],
to_tf_example: Union[beam.DoFn, Callable],
to_tf_example: Union[beam.DoFn, Callable[[List[Any]], Any]],
filter_invalid_tracks: beam.PTransform,
destination: str,
batch_size: int,
):
) -> None:
valid_track_ids = (
p
| "Create PCollection of track IDS" >> beam.Create(input_data)
@@ -75,15 +73,13 @@ def transcription_dataset_writer(
(
getattr(valid_track_ids, split)
| f"Combine {split} into giant list" >> beam.transforms.combiners.ToList()
| f"Batch {split}" >> beam.ParDo(Batch(batch_size))
# | f"Batch {split}" >> beam.ParDo(Batch(batch_size))
| f"Batch {split}" >> beam.BatchElements(max_batch_size=batch_size)
| f"Reshuffle {split}" >> beam.Reshuffle() # To prevent fuses
| f"Create tf.Example {split} batch" >> beam.ParDo(to_tf_example)
| f"Write {split} batch to tfrecord"
>> beam.ParDo(WriteBatchToTfRecord(os.path.join(destination, split)))
| f"Write {split} batch to tfrecord" >> beam.ParDo(WriteBatchToTfRecord(os.path.join(destination, split)))
)
getattr(
valid_track_ids, split
) | f"Write {split} index file" >> beam.io.textio.WriteToText(
getattr(valid_track_ids, split) | f"Write {split} index file" >> beam.io.textio.WriteToText(
os.path.join(destination, split, "index.csv"),
num_shards=1,
header="track_id",
@@ -98,8 +94,6 @@ def run(
filter_invalid_tracks: beam.DoFn,
destination: str,
batch_size: int,
):
) -> None:
with beam.Pipeline(options=PipelineOptions(**pipeline_options)) as p:
transcription_dataset_writer(
p, input_data, to_tf_example, filter_invalid_tracks, destination, batch_size
)
transcription_dataset_writer(p, input_data, to_tf_example, filter_invalid_tracks, destination, batch_size)
40 changes: 12 additions & 28 deletions basic_pitch/data/tf_example_serialization.py
@@ -49,9 +49,9 @@ def _to_transcription_tfex(
encoded_wav: bytes,
notes_indices: List[Tuple[int, int]],
notes_values: List[float],
onsets_indices: List[float],
onsets_indices: List[int],
onsets_values: List[float],
contours_indices: List[float],
contours_indices: List[int],
contours_values: List[float],
notes_onsets_shape: Tuple[int, int],
contours_shape: Tuple[int, int],
@@ -62,30 +62,14 @@
"file_id": bytes_feature(bytes(file_id, "utf-8")),
"source": bytes_feature(bytes(source, "utf-8")),
"audio_wav": bytes_feature(encoded_wav),
"notes_indices": bytes_feature(
tf.io.serialize_tensor(np.array(notes_indices, np.int64))
),
"notes_values": bytes_feature(
tf.io.serialize_tensor(np.array(notes_values, np.float32))
),
"onsets_indices": bytes_feature(
tf.io.serialize_tensor(np.array(onsets_indices, np.int64))
),
"onsets_values": bytes_feature(
tf.io.serialize_tensor(np.array(onsets_values, np.float32))
),
"contours_indices": bytes_feature(
tf.io.serialize_tensor(np.array(contours_indices, np.int64))
),
"contours_values": bytes_feature(
tf.io.serialize_tensor(np.array(contours_values, np.float32))
),
"notes_onsets_shape": bytes_feature(
tf.io.serialize_tensor(np.array(notes_onsets_shape, np.int64))
),
"contours_shape": bytes_feature(
tf.io.serialize_tensor(np.array(contours_shape, np.int64))
),
"notes_indices": bytes_feature(tf.io.serialize_tensor(np.array(notes_indices, np.int64))),
"notes_values": bytes_feature(tf.io.serialize_tensor(np.array(notes_values, np.float32))),
"onsets_indices": bytes_feature(tf.io.serialize_tensor(np.array(onsets_indices, np.int64))),
"onsets_values": bytes_feature(tf.io.serialize_tensor(np.array(onsets_values, np.float32))),
"contours_indices": bytes_feature(tf.io.serialize_tensor(np.array(contours_indices, np.int64))),
"contours_values": bytes_feature(tf.io.serialize_tensor(np.array(contours_values, np.float32))),
"notes_onsets_shape": bytes_feature(tf.io.serialize_tensor(np.array(notes_onsets_shape, np.int64))),
"contours_shape": bytes_feature(tf.io.serialize_tensor(np.array(contours_shape, np.int64))),
}
)
)
Expand All @@ -97,9 +81,9 @@ def to_transcription_tfexample(
audio_wav_file_path: str,
notes_indices: List[Tuple[int, int]],
notes_values: List[float],
onsets_indices: List[float],
onsets_indices: List[int],
onsets_values: List[float],
contours_indices: List[float],
contours_indices: List[int],
contours_values: List[float],
notes_onsets_shape: Tuple[int, int],
contours_shape: Tuple[int, int],
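
The serialization pattern above (numpy array -> `tf.io.serialize_tensor` -> bytes feature) round-trips through `tf.io.parse_tensor`, which is roughly what a test for this module can assert. A hedged sketch, not the test added in this commit:

```python
import numpy as np
import tensorflow as tf

# Serialize the way _to_transcription_tfex does for e.g. notes_indices.
notes_indices = np.array([(0, 3), (1, 7)], np.int64)
serialized = tf.io.serialize_tensor(notes_indices)

feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[serialized.numpy()]))
example = tf.train.Example(features=tf.train.Features(feature={"notes_indices": feature}))

# Parse the proto back and recover the original tensor.
parsed = tf.train.Example.FromString(example.SerializeToString())
raw = parsed.features.feature["notes_indices"].bytes_list.value[0]
recovered = tf.io.parse_tensor(raw, out_type=tf.int64)
assert np.array_equal(recovered.numpy(), notes_indices)
```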
