Replace use of tf.Dimension.value, which has been deprecated in TF2.
PiperOrigin-RevId: 314395568
Change-Id: I29212e28642a063f7989297ecb8dd4f74a09bdcc
adarob authored and Copybara-Service committed Jun 2, 2020
1 parent 20cbc69 commit caf088f
Showing 16 changed files with 43 additions and 43 deletions.
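For context, a minimal sketch (not part of the commit) of the migration pattern applied throughout the diffs below. In TF1, indexing a TensorShape yields a tf.Dimension, so code read shape[i].value; in TF2 it yields a plain int, or None for an unknown dimension. The tf.compat.dimension_value helper shown last works under both versions, though this commit uses int() and direct indexing instead:

    import tensorflow as tf

    x = tf.zeros([8, 16])

    # TF1 pattern removed by this commit (tf.Dimension plus .value):
    #   batch_size = x.shape[0].value
    # TF2-safe replacements used in the diffs below:
    batch_size = int(x.shape[0])   # raises TypeError if the dim is unknown
    maybe_dim = x.shape[1]         # plain int, or None when unknown

    # tf.compat.dimension_value() works under both TF1 and TF2 and
    # returns None (instead of raising) for unknown dimensions.
    num_units = tf.compat.dimension_value(x.shape[1])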
10 changes: 5 additions & 5 deletions magenta/contrib/rnn.py
@@ -111,11 +111,11 @@ def __init__(self,
     for shape in shapes:
       if shape.ndims != 2:
         raise ValueError("linear is expecting 2D arguments: %s" % shapes)
-      if shape.dims[1].value is None:
+      if shape[1] is None:
         raise ValueError("linear expects shape[1] to be provided for shape %s, "
                          "but saw %s" % (shape, shape[1]))
       else:
-        total_arg_size += shape.dims[1].value
+        total_arg_size += int(shape[1])
 
     dtype = [a.dtype for a in args][0]
 
@@ -553,7 +553,7 @@ def _lstm_block_cell(x,
     ValueError: If cell_size is None.
   """
   if wci is None:
-    cell_size = cs_prev.get_shape().with_rank(2).dims[1].value
+    cell_size = int(cs_prev.shape.with_rank(2)[1])
     if cell_size is None:
       raise ValueError("cell_size from `cs_prev` should not be None.")
     wci = tf.constant(0, dtype=tf.float32, shape=[cell_size])
@@ -682,10 +682,10 @@ def output_size(self):
     return self._num_units
 
   def build(self, inputs_shape):
-    if not inputs_shape.dims[1].value:
+    if not inputs_shape[1]:
       raise ValueError(
           "Expecting inputs_shape[1] to be set: %s" % str(inputs_shape))
-    input_size = inputs_shape.dims[1].value
+    input_size = int(inputs_shape[1])
     self._kernel = self.add_variable(
         self._names["W"], [input_size + self._num_units, self._num_units * 4])
     self._bias = self.add_variable(
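A note on the build() change above: under TF2 semantics an unknown dimension indexes as None, which is why plain truthiness and None checks now work without .value. A small sketch, assuming TF2 behavior is enabled:

    import tensorflow as tf

    inputs_shape = tf.TensorShape([None, 128])

    # Unknown dims index as None in TF2, so plain comparisons work:
    assert inputs_shape[0] is None
    assert inputs_shape[1] == 128

    # int() is only safe once the dimension is known to be set:
    input_size = int(inputs_shape[1])    # 128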
4 changes: 2 additions & 2 deletions magenta/contrib/seq2seq.py
@@ -50,8 +50,8 @@ def _transpose_batch_time(x):
       x, tf.concat(([1, 0], tf.range(2, x_rank)), axis=0))
   x_t.set_shape(
       tf.TensorShape(
-          [x_static_shape.dims[1].value,
-           x_static_shape.dims[0].value]).concatenate(x_static_shape[2:]))
+          [x_static_shape[1], x_static_shape[0]]).concatenate(
+              x_static_shape[2:]))
   return x_t
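Illustrative use of the function touched above (a sketch, assuming _transpose_batch_time is in scope): it swaps the leading batch and time axes while carrying over static shape information.

    import tensorflow as tf

    x = tf.zeros([4, 7, 3])          # [batch, time, depth]
    x_t = _transpose_batch_time(x)   # [time, batch, depth]
    assert x_t.shape.as_list() == [7, 4, 3]
    # The set_shape call is what preserves the static [7, 4, 3]: the
    # transpose uses a dynamically built permutation, and without
    # set_shape the result's static shape could degrade to
    # [None, None, 3].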
4 changes: 2 additions & 2 deletions magenta/models/gansynth/lib/model.py
@@ -194,7 +194,7 @@ def __init__(self, stage_id, batch_size, config):
 
     # gen_one_hot_labels = real_one_hot_labels
     gen_one_hot_labels = data_helper.provide_one_hot_labels(batch_size)
-    num_tokens = real_one_hot_labels.shape[1].value
+    num_tokens = int(real_one_hot_labels.shape[1])
 
     current_image_id = tf.train.get_or_create_global_step()
     current_image_id_inc_op = current_image_id.assign_add(batch_size)
@@ -436,7 +436,7 @@ def _add_specgrams_summary(name, specgrams, max_outputs):
     fake_batch_size = config['fake_batch_size']
     real_batch_size = self.batch_size
     real_one_hot_labels = self.real_one_hot_labels
-    num_tokens = real_one_hot_labels.shape[1].value
+    num_tokens = int(real_one_hot_labels.shape[1])
 
     # When making prediction, use the ema smoothed generator vars by
     # `_custom_getter`.
2 changes: 1 addition & 1 deletion magenta/models/gansynth/lib/networks.py
@@ -420,7 +420,7 @@ def _to_rgb(x):
     outputs.append(lod * alpha)
 
   predictions = tf.add_n(outputs)
-  batch_size = z.shape[0].value
+  batch_size = int(z.shape[0])
   predictions.set_shape([batch_size, final_h, final_w, colors])
   end_points['predictions'] = predictions
 
2 changes: 1 addition & 1 deletion magenta/models/gansynth/lib/spectral_ops.py
@@ -252,7 +252,7 @@ def crop_or_pad(waves, length, channels):
     A 3D `Tensor` of NLC format with shape [N, length, channels].
   """
   waves = tf.convert_to_tensor(waves)
-  batch_size = waves.shape[0].value
+  batch_size = int(waves.shape[0])
   waves_shape = tf.shape(waves)
 
   # Force audio length.
2 changes: 1 addition & 1 deletion magenta/models/gansynth/lib/util.py
@@ -30,7 +30,7 @@ def get_default_embedding_size(num_features):
 
 def one_hot_to_embedding(one_hot, embedding_size=None):
   """Gets a dense embedding vector from a one-hot encoding."""
-  num_tokens = one_hot.shape[1].value
+  num_tokens = int(one_hot.shape[1])
   label_id = tf.argmax(one_hot, axis=1)
   if embedding_size is None:
     embedding_size = get_default_embedding_size(num_tokens)
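An illustrative call of the function above (a sketch; one_hot_to_embedding and get_default_embedding_size are assumed to be in scope, and the embedding construction below the shown lines is not part of this diff):

    import tensorflow as tf

    # A batch of 4 one-hot labels over 10 classes; num_tokens is read
    # from the static shape, so one_hot.shape[1] must be known when the
    # function is called.
    one_hot = tf.one_hot([1, 3, 3, 7], depth=10)   # shape [4, 10]
    embedding = one_hot_to_embedding(one_hot)
    # embedding_size defaults to get_default_embedding_size(10).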
20 changes: 10 additions & 10 deletions magenta/models/image_stylization/image_utils.py
@@ -318,7 +318,7 @@ def arbitrary_style_image_inputs(style_dataset_file,
     label = features['label']
 
     if image_size is not None:
-      image_channels = image.shape[2].value
+      image_channels = int(image.shape[2])
       if augment_style_images:
         image_orig = image
         image = tf.image.random_brightness(image, max_delta=0.8)
@@ -436,7 +436,7 @@ def load_image(image_file, image_size=None):
   image = tf.constant(np.uint8(load_np_image(image_file) * 255.0))
   if image_size is not None:
     # Center-crop into a square and resize to image_size
-    small_side = min(image.get_shape()[0].value, image.get_shape()[1].value)
+    small_side = int(min(image.shape[0], image.shape[1]))
     image = tf.image.resize_image_with_crop_or_pad(
         image, small_side, small_side)
     image = tf.image.resize_images(image, [image_size, image_size])
@@ -487,17 +487,17 @@ def form_image_grid(input_tensor, grid_shape, image_shape, num_channels):
     ValueError: The grid shape and minibatch size don't match, or the image
       shape and number of channels are incompatible with the input tensor.
   """
-  if grid_shape[0] * grid_shape[1] != int(input_tensor.get_shape()[0]):
+  if grid_shape[0] * grid_shape[1] != int(input_tensor.shape[0]):
     raise ValueError('Grid shape incompatible with minibatch size.')
-  if len(input_tensor.get_shape()) == 2:
+  if len(input_tensor.shape) == 2:
     num_features = image_shape[0] * image_shape[1] * num_channels
-    if int(input_tensor.get_shape()[1]) != num_features:
+    if int(input_tensor.shape[1]) != num_features:
       raise ValueError('Image shape and number of channels incompatible with '
                        'input tensor.')
-  elif len(input_tensor.get_shape()) == 4:
-    if (int(input_tensor.get_shape()[1]) != image_shape[0] or
-        int(input_tensor.get_shape()[2]) != image_shape[1] or
-        int(input_tensor.get_shape()[3]) != num_channels):
+  elif len(input_tensor.shape) == 4:
+    if (int(input_tensor.shape[1]) != image_shape[0] or
+        int(input_tensor.shape[2]) != image_shape[1] or
+        int(input_tensor.shape[3]) != num_channels):
       raise ValueError('Image shape and number of channels incompatible with '
                        'input tensor.')
   else:
@@ -629,7 +629,7 @@ def _aspect_preserving_resize(image, smallest_side):
   """
   smallest_side = tf.convert_to_tensor(smallest_side, dtype=tf.int32)
 
-  input_rank = len(image.get_shape())
+  input_rank = len(image.shape)
   if input_rank == 3:
     image = tf.expand_dims(image, 0)
 
2 changes: 1 addition & 1 deletion magenta/models/image_stylization/model.py
@@ -169,7 +169,7 @@ def residual_block(input_, kernel_size, scope, activation_fn=tf.nn.relu):
   if kernel_size % 2 == 0:
     raise ValueError('kernel_size is expected to be odd.')
   with tf.variable_scope(scope):
-    num_outputs = input_.get_shape()[-1].value
+    num_outputs = int(input_.get_shape()[-1])
     h_1 = conv2d(input_, kernel_size, 1, num_outputs, 'conv1', activation_fn)
     h_2 = conv2d(h_1, kernel_size, 1, num_outputs, 'conv2', None)
     return input_ + h_2
4 changes: 2 additions & 2 deletions magenta/models/music_vae/base_model.py
@@ -333,10 +333,10 @@ def eval(self, input_sequence, output_sequence, sequence_length,
 
   def sample(self, n, max_length=None, z=None, c_input=None, **kwargs):
     """Sample with an optional conditional embedding `z`."""
-    if z is not None and z.shape[0].value != n:
+    if z is not None and int(z.shape[0]) != n:
       raise ValueError(
           '`z` must have a first dimension that equals `n` when given. '
-          'Got: %d vs %d' % (z.shape[0].value, n))
+          'Got: %d vs %d' % (z.shape[0], n))
 
     if self.hparams.z_size and z is None:
       tf.logging.warning(
20 changes: 10 additions & 10 deletions magenta/models/music_vae/lstm_models.py
@@ -181,7 +181,7 @@ def encode(self, sequence, sequence_length):
     Returns:
       embedding: A batch of embeddings, sized `[batch_size, N]`.
     """
-    batch_size = sequence.shape[0].value
+    batch_size = int(sequence.shape[0])
     sequence_length = lstm_utils.maybe_split_sequence_lengths(
         sequence_length, np.prod(self._level_lengths[1:]),
         self._total_length)
@@ -323,7 +323,7 @@ def reconstruction_loss(self, x_input, x_target, x_length, z=None,
       metric_map: Map from metric name to tf.metrics return values for logging.
      decode_results: The LstmDecodeResults.
     """
-    batch_size = x_input.shape[0].value
+    batch_size = int(x_input.shape[0])
 
     has_z = z is not None
     z = tf.zeros([batch_size, 0]) if z is None else z
@@ -399,10 +399,10 @@ def sample(self, n, max_length=None, z=None, c_input=None, temperature=1.0,
     Raises:
       ValueError: If `z` is provided and its first dimension does not equal `n`.
     """
-    if z is not None and z.shape[0].value != n:
+    if z is not None and int(z.shape[0]) != n:
       raise ValueError(
           '`z` must have a first dimension that equals `n` when given. '
-          'Got: %d vs %d' % (z.shape[0].value, n))
+          'Got: %d vs %d' % (z.shape[0], n))
 
     # Use a dummy Z in unconditional case.
     z = tf.zeros((n, 0), tf.float32) if z is None else z
@@ -547,7 +547,7 @@ def _flat_reconstruction_loss(self, flat_x_target, flat_rnn_output):
             tf.metrics.accuracy(flat_truth, flat_predictions),
         'metrics/mean_per_class_accuracy':
             tf.metrics.mean_per_class_accuracy(
-                flat_truth, flat_predictions, flat_x_target.shape[-1].value),
+                flat_truth, flat_predictions, int(flat_x_target.shape[-1])),
     }
     return r_loss, metric_map
 
@@ -741,10 +741,10 @@ def reconstruction_loss(self, x_input, x_target, x_length, z=None,
 
   def sample(self, n, max_length=None, z=None, c_input=None, temperature=1.0,
              start_inputs=None, **core_sampler_kwargs):
-    if z is not None and z.shape[0].value != n:
+    if z is not None and int(z.shape[0]) != n:
       raise ValueError(
           '`z` must have a first dimension that equals `n` when given. '
-          'Got: %d vs %d' % (z.shape[0].value, n))
+          'Got: %d vs %d' % (z.shape[0], n))
 
     if max_length is None:
       # TODO(adarob): Support variable length outputs.
@@ -1041,7 +1041,7 @@ def reconstruction_loss(self, x_input, x_target, x_length, z=None,
       raise ValueError(
           'Re-encoder mode unsupported when conditioning on controls.')
 
-    batch_size = x_input.shape[0].value
+    batch_size = int(x_input.shape[0])
 
     x_length = lstm_utils.maybe_split_sequence_lengths(
         x_length, np.prod(self._level_lengths[:-1]), self._total_length)
@@ -1135,10 +1135,10 @@ def sample(self, n, max_length=None, z=None, c_input=None,
       ValueError: If `z` is provided and its first dimension does not equal `n`,
         or if `c_input` is provided in re-encoder mode.
     """
-    if z is not None and z.shape[0].value != n:
+    if z is not None and int(z.shape[0]) != n:
       raise ValueError(
          '`z` must have a first dimension that equals `n` when given. '
-          'Got: %d vs %d' % (z.shape[0].value, n))
+          'Got: %d vs %d' % (z.shape[0], n))
     z = tf.zeros([n, 0]) if z is None else z
 
     if self._hierarchical_encoder and c_input is not None:
2 changes: 1 addition & 1 deletion magenta/models/onsets_frames_transcription/model.py
@@ -64,7 +64,7 @@ def conv_net(inputs, hparams):
   # Flatten while preserving batch and time dimensions.
   dims = tf.shape(net)
   net = tf.reshape(
-      net, (dims[0], dims[1], net.shape[2].value * net.shape[3].value),
+      net, (dims[0], dims[1], net.shape[2] * net.shape[3]),
       'flatten_end')
 
   net = slim.fully_connected(net, hparams.fc_size, scope='fc_end')
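Why this hunk mixes tf.shape(net) with net.shape: batch and time are only known at run time, while the frequency and channel dimensions are static, so their product can be computed as a plain Python int. A sketch of the same flattening pattern with hypothetical sizes:

    import tensorflow as tf

    net = tf.zeros([2, 5, 12, 8])    # [batch, time, freq, channels]
    dims = tf.shape(net)             # dynamic shape, resolved at run time
    flat = tf.reshape(net, (dims[0], dims[1], net.shape[2] * net.shape[3]))
    # Running eagerly, flat has shape [2, 5, 96]; the 96 (= 12 * 8) comes
    # from static shapes, so no extra graph ops are needed to compute it.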
2 changes: 1 addition & 1 deletion magenta/models/onsets_frames_transcription/model_tpu.py
@@ -57,7 +57,7 @@ def conv_net(inputs, hparams):
   # Flatten while preserving batch and time dimensions.
   dims = tf.shape(net)
   net = tf.reshape(
-      net, (dims[0], dims[1], net.shape[2].value * net.shape[3].value),
+      net, (dims[0], dims[1], net.shape[2] * net.shape[3]),
       'flatten_end')
 
   net = slim.fully_connected(net, hparams.fc_size, scope='fc_end')
4 changes: 2 additions & 2 deletions magenta/models/pianoroll_rnn_nade/pianoroll_rnn_nade_graph.py
@@ -102,7 +102,7 @@ def _get_state(self,
     Returns:
      final_state: An RnnNadeStateTuple, the final state of the RNN-NADE.
     """
-    batch_size = inputs.shape[0].value
+    batch_size = int(inputs.shape[0])
 
     if lengths is None:
       lengths = tf.tile(tf.shape(inputs)[1:2], [batch_size])
@@ -150,7 +150,7 @@ def log_prob(self, sequences, lengths=None):
       cond_prob: The conditional probabilities at each non-padded value for
         every batch, sized `[sum(lengths), num_dims]`.
     """
-    assert self._num_dims == sequences.shape[2].value
+    assert self._num_dims == int(sequences.shape[2])
 
     # Remove last value from input sequences.
     inputs = sequences[:, 0:-1, :]
2 changes: 1 addition & 1 deletion magenta/models/shared/events_rnn_model.py
@@ -81,7 +81,7 @@ def _build_graph_for_generation(self):
 
   def _batch_size(self):
     """Extracts the batch size from the graph."""
-    return self._session.graph.get_collection('inputs')[0].shape[0].value
+    return int(self._session.graph.get_collection('inputs')[0].shape[0])
 
   def _generate_step_for_batch(self, event_sequences, inputs, initial_state,
                                temperature):
4 changes: 2 additions & 2 deletions magenta/music/melspec_input.py
@@ -77,7 +77,7 @@ def _naive_rdft(signal_tensor, fft_length):
   imag_dft_tensor = tf.constant(
       np.imag(complex_dft_matrix_kept_values).astype(np.float32),
       name='imaginary_dft_matrix')
-  signal_frame_length = signal_tensor.shape[-1].value
+  signal_frame_length = int(signal_tensor.shape[-1])
   half_pad = (fft_length - signal_frame_length) // 2
   pad_values = tf.concat([
       tf.zeros([tf.rank(signal_tensor) - 1, 2], tf.int32),
@@ -263,7 +263,7 @@ def build_mel_calculation_graph(waveform_input,
       waveform_input, window_length_samples, hop_length_samples, fft_length)
 
   # Warp the linear-scale, magnitude spectrograms into the mel-scale.
-  num_spectrogram_bins = magnitude_spectrogram.shape[-1].value
+  num_spectrogram_bins = int(magnitude_spectrogram.shape[-1])
   if tflite_compatible:
     linear_to_mel_weight_matrix = tf.constant(
         mfcc_mel.SpectrogramToMelMatrix(num_mel_bins, num_spectrogram_bins,
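For intuition on the melspec hunks: the last dimension of a magnitude spectrogram from a real FFT is static (fft_length // 2 + 1 bins), so it can be read from .shape rather than tf.shape. A sketch with hypothetical window, hop, and FFT sizes (the module's actual values come from its arguments):

    import tensorflow as tf

    waveform = tf.zeros([16000])
    stft = tf.signal.stft(waveform, frame_length=400, frame_step=160,
                          fft_length=512)
    magnitude = tf.abs(stft)
    num_spectrogram_bins = int(magnitude.shape[-1])   # 512 // 2 + 1 == 257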
2 changes: 1 addition & 1 deletion magenta/version.py
@@ -18,4 +18,4 @@
 pulling in all the dependencies in __init__.py.
 """
 
-__version__ = '2.0.0'
+__version__ = '2.0.1'