remove mirdata inline import from guitarset bc unnecessary, add ikala dataset file, test file, add as option to download.py, black formatting
bgenchel committed May 1, 2024
1 parent 3296466 commit 0caa311
Showing 5 changed files with 294 additions and 9 deletions.
1 change: 0 additions & 1 deletion basic_pitch/data/datasets/guitarset.py
@@ -55,7 +55,6 @@ def setup(self) -> None:
     def process(self, element: List[str], *args: Tuple[Any, Any], **kwargs: Dict[str, Any]) -> List[Any]:
         import tempfile

-        import mirdata
         import numpy as np
         import sox

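For context: the inline mirdata import in the guitarset DoFn can go because these pipelines run with "save_main_session": True (see the pipeline options in ikala.py below), so module-level imports are pickled with the main session and remain available inside DoFn.process on the workers. Below is a minimal, self-contained sketch of that pattern — an illustration only, not the project's actual DoFn; the element value is made up.

```python
import json  # module-level import; stands in for the module-level mirdata import

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions


class Stringify(beam.DoFn):
    def process(self, element):
        # Uses the module-level json import rather than re-importing inside process();
        # with save_main_session=True the main module's globals travel to the workers.
        yield json.dumps(element)


if __name__ == "__main__":
    options = PipelineOptions(save_main_session=True)
    with beam.Pipeline(options=options) as p:
        (
            p
            | beam.Create([{"track_id": "10161_chorus"}])  # hypothetical element
            | beam.ParDo(Stringify())
            | beam.Map(print)
        )
```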
190 changes: 190 additions & 0 deletions basic_pitch/data/datasets/ikala.py
@@ -0,0 +1,190 @@
#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2022 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import os
import random
import sys
import time
from typing import Any, Dict, List, Tuple, Optional

import apache_beam as beam
import mirdata

from basic_pitch.data import commandline, pipeline
from basic_pitch.data.datasets import DOWNLOAD


class IkalaInvalidTracks(beam.DoFn):
def process(self, element: Tuple[str, str], *args: Tuple[Any, Any], **kwargs: Dict[str, Any]) -> Any:
track_id, split = element
yield beam.pvalue.TaggedOutput(split, track_id)


class IkalaToTfExample(beam.DoFn):
DOWNLOAD_ATTRIBUTES = ["audio_path", "notes_pyin_path", "f0_path"]

def __init__(self, source: str, download: bool) -> None:
self.source = source
self.download = download

def setup(self) -> None:
import apache_beam as beam
import os
import mirdata

self.ikala_remote = mirdata.initialize("ikala", data_home=os.path.join(self.source, "iKala"))
self.filesystem = beam.io.filesystems.FileSystems() # TODO: replace with fsspec
if self.download:
self.ikala_remote.download()

def process(self, element: List[str], *args: Tuple[Any, Any], **kwargs: Dict[str, Any]) -> List[Any]:
import tempfile

import numpy as np
import sox

from basic_pitch.constants import (
AUDIO_N_CHANNELS,
AUDIO_SAMPLE_RATE,
FREQ_BINS_CONTOURS,
FREQ_BINS_NOTES,
ANNOTATION_HOP,
N_FREQ_BINS_CONTOURS,
N_FREQ_BINS_NOTES,
)
from basic_pitch.data import tf_example_serialization

logging.info(f"Processing {element}")
batch = []

for track_id in element:
track_remote = self.ikala_remote.track(track_id)
with tempfile.TemporaryDirectory() as local_tmp_dir:
ikala_local = mirdata.initialize("ikala", local_tmp_dir)
track_local = ikala_local.track(track_id)

for attr in self.DOWNLOAD_ATTRIBUTES:
source = getattr(track_remote, attr)
dest = getattr(track_local, attr)
os.makedirs(os.path.dirname(dest), exist_ok=True)
with self.filesystem.open(source) as s, open(dest, "wb") as d:
d.write(s.read())

local_wav_path = "{}_tmp.wav".format(track_local.audio_path)

tfm = sox.Transformer()
tfm.rate(AUDIO_SAMPLE_RATE)
tfm.remix({1: [2]})
tfm.channels(AUDIO_N_CHANNELS)
tfm.build(track_local.audio_path, local_wav_path)

duration = sox.file_info.duration(local_wav_path)
time_scale = np.arange(0, duration + ANNOTATION_HOP, ANNOTATION_HOP)
n_time_frames = len(time_scale)

if track_local.notes_pyin is not None:
note_indices, note_values = track_local.notes_pyin.to_sparse_index(
time_scale, "s", FREQ_BINS_NOTES, "hz"
)
onset_indices, onset_values = track_local.notes_pyin.to_sparse_index(
time_scale, "s", FREQ_BINS_NOTES, "hz", onsets_only=True
)
note_shape = (n_time_frames, N_FREQ_BINS_NOTES)
# if there are no notes, return empty note indices
else:
note_indices = []
onset_indices = []
note_values = []
onset_values = []
note_shape = (0, 0)

contour_indices, contour_values = track_local.f0.to_sparse_index(
time_scale, "s", FREQ_BINS_CONTOURS, "hz"
)

batch.append(
tf_example_serialization.to_transcription_tfexample(
track_id,
"ikala",
local_wav_path,
note_indices,
note_values,
onset_indices,
onset_values,
contour_indices,
contour_values,
note_shape,
(n_time_frames, N_FREQ_BINS_CONTOURS),
)
)
return [batch]


def create_input_data(train_percent: float, seed: Optional[int] = None) -> List[Tuple[str, str]]:
assert train_percent < 1.0, "Don't over allocate the data!"

    # Everything not assigned to train goes to validation; there is no separate test split here
validation_bound = train_percent

if seed:
random.seed(seed)

def determine_split() -> str:
partition = random.uniform(0, 1)
if partition < validation_bound:
return "train"
return "validation"

ikala = mirdata.initialize("ikala")

return [(track_id, determine_split()) for track_id in ikala.track_ids]


def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
time_created = int(time.time())
destination = commandline.resolve_destination(known_args, time_created)

pipeline_options = {
"runner": known_args.runner,
"job_name": f"ikala-tfrecords-{time_created}",
"machine_type": "e2-standard-4",
"num_workers": 25,
"disk_size_gb": 128,
"experiments": ["use_runner_v2", "no_use_multiple_sdk_containers"],
"save_main_session": True,
"worker_harness_container_image": known_args.worker_harness_container_image,
}
input_data = create_input_data(known_args.train_percent, known_args.split_seed)
pipeline.run(
pipeline_options,
input_data,
IkalaToTfExample(known_args.source, DOWNLOAD),
IkalaInvalidTracks(known_args.source),
destination,
known_args.batch_size,
)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
commandline.add_default(parser, os.path.basename(os.path.splitext(__file__)[0]))
commandline.add_split(parser)
known_args, pipeline_args = parser.parse_known_args(sys.argv)

main(known_args, pipeline_args)
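A note on create_input_data above: the train/validation split is a per-track Bernoulli draw rather than an exact partition, so the realized proportions only approximate train_percent (which is why test_create_input_data below checks against a tolerance). Here is a standalone sketch of the same splitting logic using dummy track IDs instead of mirdata — an illustration, not code from the commit:

```python
import random
from typing import List, Optional, Tuple


def split_track_ids(
    track_ids: List[str], train_percent: float, seed: Optional[int] = None
) -> List[Tuple[str, str]]:
    """Assign each track to 'train' with probability train_percent, else 'validation'."""
    assert train_percent < 1.0, "Don't over allocate the data!"
    if seed:
        random.seed(seed)
    return [
        (track_id, "train" if random.uniform(0, 1) < train_percent else "validation")
        for track_id in track_ids
    ]


if __name__ == "__main__":
    # Dummy IDs standing in for ikala.track_ids.
    data = split_track_ids([str(i) for i in range(1000)], train_percent=0.85, seed=42)
    print(sum(1 for _, split in data if split == "train") / len(data))  # roughly 0.85
```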
2 changes: 2 additions & 0 deletions basic_pitch/data/download.py
@@ -4,6 +4,7 @@

 from basic_pitch.data import commandline
 from basic_pitch.data.datasets.guitarset import main as guitarset_main
+from basic_pitch.data.datasets.ikala import main as ikala_main

 logger = logging.getLogger()
 logger.setLevel(logging.INFO)
@@ -14,6 +15,7 @@

 DATASET_DICT = {
     "guitarset": guitarset_main,
+    "ikala": ikala_main
 }

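With the ikala entry registered in DATASET_DICT, download.py can dispatch to either dataset's main by name. The argument handling lives in the unchanged part of the file and is not shown in this diff, so the snippet below is only an assumed sketch of that dispatch; the --dataset flag and the lambda stand-ins are hypothetical, not taken from download.py:

```python
import argparse

# Stand-ins for guitarset_main / ikala_main; in the real module these are the imported mains.
DATASET_DICT = {
    "guitarset": lambda known, extra: print("would run guitarset", extra),
    "ikala": lambda known, extra: print("would run ikala", extra),
}

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", choices=sorted(DATASET_DICT), required=True)  # hypothetical flag
    known_args, pipeline_args = parser.parse_known_args()
    DATASET_DICT[known_args.dataset](known_args, pipeline_args)
```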
70 changes: 70 additions & 0 deletions tests/data/test_ikala.py
@@ -0,0 +1,70 @@
#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2022 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import apache_beam as beam
import itertools
import os

from apache_beam.testing.test_pipeline import TestPipeline

from basic_pitch.data.datasets.ikala import (
IkalaInvalidTracks,
create_input_data,
)


def test_ikala_to_tf_example(tmpdir: str) -> None:
# TODO: Acquire test data
pass


def test_ikala_invalid_tracks(tmpdir: str) -> None:
split_labels = ["train", "validation"]
input_data = [(str(i), split) for i, split in enumerate(split_labels)]
with TestPipeline() as p:
splits = (
p
| "Create PCollection" >> beam.Create(input_data)
| "Tag it" >> beam.ParDo(IkalaInvalidTracks()).with_outputs(*split_labels)
)

for split in split_labels:
(
getattr(splits, split)
| f"Write {split} to text"
>> beam.io.WriteToText(os.path.join(tmpdir, f"output_{split}.txt"), shard_name_template="")
)

for i, split in enumerate(split_labels):
with open(os.path.join(tmpdir, f"output_{split}.txt"), "r") as fp:
assert fp.read().strip() == str(i)


def test_create_input_data() -> None:
data = create_input_data(train_percent=0.5)
data.sort(key=lambda el: el[1]) # sort by split
tolerance = 0.05
for key, group in itertools.groupby(data, lambda el: el[1]):
assert (0.5 - tolerance) * len(data) <= len(list(group)) <= (0.5 + tolerance) * len(data)


def test_create_input_data_overallocate() -> None:
try:
create_input_data(train_percent=1.1)
except AssertionError:
assert True
else:
assert False
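The try/except/else pattern in test_create_input_data_overallocate works, but assuming pytest is the runner here (the other tests already use pytest-style tmpdir fixtures), an equivalent and slightly more idiomatic sketch would be:

```python
import pytest

from basic_pitch.data.datasets.ikala import create_input_data


def test_create_input_data_overallocate() -> None:
    # create_input_data asserts train_percent < 1.0, so this should raise.
    with pytest.raises(AssertionError):
        create_input_data(train_percent=1.1)
```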
40 changes: 32 additions & 8 deletions tests/data/test_tf_example_serialization.py
@@ -60,34 +60,58 @@ def test_to_transcription_tfexample(tmpdir: str) -> None:
     assert example.features.feature["source"].bytes_list.value[0].decode("utf-8") == source
     assert example.features.feature["audio_wav"].bytes_list.value[0] == open(tmpfile, "rb").read()
     assert tf.reduce_all(
-        tf.io.parse_tensor(example.features.feature["notes_indices"].bytes_list.value[0], out_type=tf.int64)
+        tf.io.parse_tensor(
+            example.features.feature["notes_indices"].bytes_list.value[0],
+            out_type=tf.int64,
+        )
         == notes_indices
     )
     assert tf.reduce_all(
-        tf.io.parse_tensor(example.features.feature["notes_values"].bytes_list.value[0], out_type=tf.float32)
+        tf.io.parse_tensor(
+            example.features.feature["notes_values"].bytes_list.value[0],
+            out_type=tf.float32,
+        )
         == notes_values
     )
     assert tf.reduce_all(
-        tf.io.parse_tensor(example.features.feature["onsets_indices"].bytes_list.value[0], out_type=tf.int64)
+        tf.io.parse_tensor(
+            example.features.feature["onsets_indices"].bytes_list.value[0],
+            out_type=tf.int64,
+        )
         == onsets_indices
     )
     assert tf.reduce_all(
-        tf.io.parse_tensor(example.features.feature["onsets_values"].bytes_list.value[0], out_type=tf.float32)
+        tf.io.parse_tensor(
+            example.features.feature["onsets_values"].bytes_list.value[0],
+            out_type=tf.float32,
+        )
         == onsets_values
     )
     assert tf.reduce_all(
-        tf.io.parse_tensor(example.features.feature["contours_indices"].bytes_list.value[0], out_type=tf.int64)
+        tf.io.parse_tensor(
+            example.features.feature["contours_indices"].bytes_list.value[0],
+            out_type=tf.int64,
+        )
         == contours_indices
     )
     assert tf.reduce_all(
-        tf.io.parse_tensor(example.features.feature["contours_values"].bytes_list.value[0], out_type=tf.float32)
+        tf.io.parse_tensor(
+            example.features.feature["contours_values"].bytes_list.value[0],
+            out_type=tf.float32,
+        )
         == contours_values
     )
     assert tf.reduce_all(
-        tf.io.parse_tensor(example.features.feature["notes_onsets_shape"].bytes_list.value[0], out_type=tf.int64)
+        tf.io.parse_tensor(
+            example.features.feature["notes_onsets_shape"].bytes_list.value[0],
+            out_type=tf.int64,
+        )
         == notes_onsets_shape
     )
     assert tf.reduce_all(
-        tf.io.parse_tensor(example.features.feature["contours_shape"].bytes_list.value[0], out_type=tf.int64)
+        tf.io.parse_tensor(
+            example.features.feature["contours_shape"].bytes_list.value[0],
+            out_type=tf.int64,
+        )
         == contours_shape
     )
