Skip to content

Commit

Permalink
[NumPy] Fix uses of the np.stack() family of functions under NumPy 1.25.
Browse files Browse the repository at this point in the history
NumPy 1.25 release notes: https://github.com/numpy/numpy/releases/tag/v1.25.0

Per the release notes, in NumPy 1.25:

A sequence must now be passed into the stacking family of functions (stack, vstack, hstack, dstack and column_stack). (gh-23019)

This change fixes test failures where a non-sequence was passed to a stack function. The most common fix is to convert the input to a list explicitly before passing it to any stacking functions.

PiperOrigin-RevId: 547165426
  • Loading branch information
hawkinsp authored and Magenta Team committed Jul 11, 2023
1 parent 78661d8 commit 548dc4e
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions magenta/models/music_vae/data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,8 +984,11 @@ def testSliced(self):
tensors = converter.to_tensors(self.sequence)
self.assertArraySetsEqual(tensors.inputs, tensors.outputs)
actual_sliced_labels = [
np.stack(np.argmax(s, axis=-1) for s in np.split(t, [90, 180], axis=-1))
for t in tensors.outputs]
np.stack(
[np.argmax(s, axis=-1) for s in np.split(t, [90, 180], axis=-1)]
)
for t in tensors.outputs
]

self.assertArraySetsEqual(self.expected_sliced_labels, actual_sliced_labels)

Expand All @@ -999,8 +1002,11 @@ def testSlicedChordConditioned(self):
tensors = converter.to_tensors(self.sequence)
self.assertArraySetsEqual(tensors.inputs, tensors.outputs)
actual_sliced_labels = [
np.stack(np.argmax(s, axis=-1) for s in np.split(t, [90, 180], axis=-1))
for t in tensors.outputs]
np.stack(
[np.argmax(s, axis=-1) for s in np.split(t, [90, 180], axis=-1)]
)
for t in tensors.outputs
]
actual_sliced_chord_labels = [
np.argmax(t, axis=-1) for t in tensors.controls]

Expand All @@ -1018,8 +1024,11 @@ def testSlicedKeyConditioned(self):
tensors = converter.to_tensors(self.sequence)
self.assertArraySetsEqual(tensors.inputs, tensors.outputs)
actual_sliced_labels = [
np.stack(np.argmax(s, axis=-1) for s in np.split(t, [90, 180], axis=-1))
for t in tensors.outputs]
np.stack(
[np.argmax(s, axis=-1) for s in np.split(t, [90, 180], axis=-1)]
)
for t in tensors.outputs
]
actual_sliced_key_labels = [
np.argmax(t, axis=-1) for t in tensors.controls]

Expand All @@ -1038,8 +1047,11 @@ def testSlicedChordAndKeyConditioned(self):
tensors = converter.to_tensors(self.sequence)
self.assertArraySetsEqual(tensors.inputs, tensors.outputs)
actual_sliced_labels = [
np.stack(np.argmax(s, axis=-1) for s in np.split(t, [90, 180], axis=-1))
for t in tensors.outputs]
np.stack(
[np.argmax(s, axis=-1) for s in np.split(t, [90, 180], axis=-1)]
)
for t in tensors.outputs
]
actual_sliced_chord_labels = [
np.argmax(t[:, :-12], axis=-1) for t in tensors.controls]
actual_sliced_key_labels = [
Expand Down

0 comments on commit 548dc4e

Please sign in to comment.