Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

concatenate bins #543

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
42 changes: 42 additions & 0 deletions pyfar/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,45 @@ def concatenate_channels(signals, caxis=0, broadcasting=False):
return pf.TimeData(data, signals[0].times)
else:
return pf.FrequencyData(data, signals[0].frequencies)


def concatenate_bins(signals):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we now have functions to broadcast cshapes of audio objects it would be nice to include this by default to make this more flexible.

"""
Merge multiple FrequencyData objects along the frequency axe.

Parameters
----------
signals : tuple of FrequencyData
The signals to concatenate. All signals must have the same cshape.
the frequency bins get sorted and doubles are removed, the first
entry is used.

Returns
-------
merged : FrequencyData
The merged signal object.
"""
# check input
if not isinstance(signals, (tuple, list)):
raise TypeError(
"Input must be a tuple or list of pf.FrequencyData objects.")

for signal in signals:
if not isinstance(signal, (pf.FrequencyData)):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is True for pyfar.Signal due to the inheritance. I think type(signal) != pf.FrequencyData is what you want.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should also be tested then.

raise TypeError(
"All input data must be of type pf.FrequencyData.")

# check matching meta data of input signals.
[signals[0]._assert_matching_meta_data(s) for s in signals]

# concatenate data
data = np.concatenate([s.freq for s in signals], axis=-1)
frequencies = np.concatenate([s.frequencies for s in signals], axis=-1)

# Sort frequency entries
idx = np.argsort(frequencies)
frequencies = frequencies[idx]
data = data[..., idx]

# return merged Signal
return pf.FrequencyData(data, frequencies)
70 changes: 70 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,73 @@ def test_concatenate_assertions():
pf.TimeData([1, 2, 3], [1, 2, 3]))
with pytest.raises(ValueError, match="Comparison only valid against"):
pf.utils.concatenate_channels(signals)


def test_concatenate_bins_frequencydata():
"""Test concatenate_bins function with FrequencyData objects"""
signals = (
pf.FrequencyData([1, 2, 3], [1, 2, 3]),
pf.FrequencyData([4, 5, 6], [4, 5, 6]))
merged = pf.utils.concatenate_bins(signals)
assert isinstance(merged, pf.FrequencyData)
npt.assert_array_equal(merged.frequencies, np.arange(1, 7))


def test_concatenate_bins_frequencydata_with_sort():
"""Test concatenate_bins function with FrequencyData objects"""
signals = (
pf.FrequencyData([1, 3, 5], [1, 3, 5]),
pf.FrequencyData([2, 4, 6], [2, 4, 6]))
merged = pf.utils.concatenate_bins(signals)
assert isinstance(merged, pf.FrequencyData)
npt.assert_array_equal(merged.frequencies, np.arange(1, 7))
npt.assert_array_equal(merged.freq, np.arange(1, 7).reshape((1, 6)))


def test_concatenate_bins_multidim():
"""Test concatenate_bins function with multidimensional input"""
signals = (
pf.FrequencyData(np.array([[1, 2, 3], [4, 5, 6]]), [1, 2, 3]),
pf.FrequencyData(np.array([[7, 8, 9], [10, 11, 12]]), [4, 5, 6]))
merged = pf.utils.concatenate_bins(signals)
assert isinstance(merged, pf.FrequencyData)
npt.assert_array_equal(merged.frequencies, np.arange(1, 7))
npt.assert_array_equal(merged.freq, np.array(
[[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]))


def test_concatenate_bins_multidim_with_sort():
"""Test concatenate_bins function with multidimensional
input and unsorted frequencies"""
signals = (
pf.FrequencyData(np.array([[1, 3, 5], [7, 9, 11]]), [1, 3, 5]),
pf.FrequencyData(np.array([[2, 4, 6], [8, 10, 12]]), [2, 4, 6]))
merged = pf.utils.concatenate_bins(signals)
assert isinstance(merged, pf.FrequencyData)
npt.assert_array_equal(merged.frequencies, np.arange(1, 7))
npt.assert_array_equal(merged.freq, np.array(
[[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]))


def test_concatenate_bins_assertions():
"""Test assertions"""
# invalid input type
with pytest.raises(
TypeError,
match="All input data must be of type pf.FrequencyData."):
pf.utils.concatenate_bins([1, 2, 3])


def test_input_type():
# Create some FrequencyData objects for testing
signal1 = pf.FrequencyData(np.random.rand(10), np.arange(10))
signal3 = "not a FrequencyData object"

# Test that a TypeError is raised if the input is not a tuple
with pytest.raises(TypeError):
pf.utils.concatenate_bins(signal1)

# Test that a TypeError is raised if the input is not a
# tuple of FrequencyData objects
with pytest.raises(TypeError):
pf.utils.concatenate_bins((signal1, signal3))