Skip to content

Commit

Permalink
Update gnome.py
Browse files Browse the repository at this point in the history
fix pytype
  • Loading branch information
ekindogus committed Apr 8, 2024
1 parent a868630 commit e703937
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions jax_md/_nn/gnome.py
Expand Up @@ -51,6 +51,8 @@

PyTree = util.PyTree

Array = util.Array


def model_from_config(cfg: ConfigDict) -> nn.Module:
model_family = cfg.get('model_family', 'nequip')
Expand All @@ -69,11 +71,11 @@ def minimum_batch_size(cfg: ConfigDict) -> int:


class ScaleLROnPlateau(NamedTuple):
step_size: jnp.ndarray
minimum_loss: jnp.ndarray
steps_without_reduction: jnp.ndarray
max_steps_without_reduction: jnp.ndarray
reduction_factor: jnp.ndarray
step_size: Array
minimum_loss: Array
steps_without_reduction: Array
max_steps_without_reduction: Array
reduction_factor: Array


def scale_lr_on_plateau(initial_step_size: float,
Expand Down

0 comments on commit e703937

Please sign in to comment.