Skip to content

Commit

Permalink
Merge pull request #21 from Rishit-dagli/Rishit-dagli-patch-1
Browse files Browse the repository at this point in the history
Fix error with tf.function
  • Loading branch information
Rishit-dagli committed Apr 26, 2021
2 parents 4d3b9b0 + c97c3b4 commit ca44960
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
4 changes: 3 additions & 1 deletion perceiver/perceiver.py
Expand Up @@ -76,6 +76,8 @@ def __init__(
# self.existing_layers = get_latent_attn()(self.existing_layers)
# self.existing_layers = get_latent_ff()(self.existing_layers)

self.existing_layers = tf.keras.Sequential(self.existing_layers)

self.to_logits = tf.keras.Sequential(
[
tf.keras.layers.LayerNormalization(axis=-1),
Expand Down Expand Up @@ -103,7 +105,7 @@ def call(self, data, mask=None):

x = repeat(self.latents, "n d -> b n d", b=b)

x = tf.keras.Sequential(self.existing_layers)(x)
x = self.existing_layers(x)

x = tf.math.reduce_mean(x, axis=-2)
return self.to_logits(x)
2 changes: 1 addition & 1 deletion perceiver/version.py
@@ -1 +1 @@
__version__ = "0.1.1"
__version__ = "0.1.2"
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -7,7 +7,7 @@

setup(
name="perceiver",
version="0.1.1",
version="0.1.2",
description="Implement of Perceiver, General Perception with Iterative Attention in TensorFlow",
packages=["perceiver"],
long_description=long_description,
Expand Down

0 comments on commit ca44960

Please sign in to comment.