astanziola/siren-flax

SIREN in Flax

Unofficial implementation of SIREN neural networks in Flax, using the Linen Module system.

This repo also includes Modulated Periodic Activations for Generalizable Local Functional Representations.

Examples

An image-fitting example is provided in the Example notebook.

(results image)

Defining a single SIREN layer

Returns a fully connected layer with a sinusoidal activation function, initialized according to the original SIREN paper.

layer = SirenLayer(
    features=32,
    w0=1.0,
    c=6.0,
    is_first=False,
    use_bias=True,
    act=jnp.sin,
    precision=None,
    dtype=jnp.float32,
)

How to use a SIREN neural network

SirenNN = Siren(hidden_dim=512, output_dim=1, final_activation=sigmoid)
params = SirenNN.init(random_key, sample_input)["params"]
output = SirenNN.apply({"params": params}, sample_input)
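Here `random_key` and `sample_input` are placeholders; for a batch of 2-D coordinates they could be created as follows (the shape is assumed for illustration):

```python
import jax
import jax.numpy as jnp

random_key = jax.random.PRNGKey(42)        # PRNG key for parameter init
sample_input = jnp.ones((16, 2))           # batch of 16 two-dimensional coordinates
```

Any array with the right trailing feature dimension works: Flax's `init` only uses it to infer parameter shapes.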

Approximate image on a grid

This can be done easily using the built-in broadcasting features of jax.numpy functions. This repository provides a useful initializer, grid_init, to generate a coordinate grid that can be used as input.

SirenDef = Siren(num_layers=5)

grid = grid_init(grid_dimension, jnp.float32)()
params = SirenDef.init(key, grid)["params"]

image = SirenDef.apply({"params": params}, grid)
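The `grid_init(...)` call above returns a function that produces the coordinate array. A minimal sketch of such an initializer (not the repository's implementation; the [-1, 1] range and `(N, d)` output layout are assumptions) might look like:

```python
import jax.numpy as jnp

def grid_init(grid_dimension, dtype=jnp.float32):
    # grid_dimension: tuple of grid sizes per axis, e.g. (height, width)
    def init():
        # One coordinate axis per dimension, normalized to [-1, 1]
        axes = [jnp.linspace(-1.0, 1.0, n) for n in grid_dimension]
        mesh = jnp.meshgrid(*axes, indexing="ij")
        # Stack into (n1, n2, ..., d), then flatten to (N, d) points
        return jnp.stack(mesh, axis=-1).reshape(-1, len(grid_dimension)).astype(dtype)
    return init

coords = grid_init((4, 4))()
print(coords.shape)  # (16, 2)
```

Each row is then a coordinate fed through the network, so the output can be reshaped back to the grid to recover the image.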

Use Modulated SIREN

SirenDef = ModulatedSiren(num_layers=5)

grid = grid_init(grid_dimension, jnp.float32)()
params = SirenDef.init(key, grid)["params"]

image = SirenDef.apply({"params": params}, grid)
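Under the hood, modulation (following the Mehta et al. paper) scales each hidden sinusoidal activation by a latent-dependent vector produced by a separate modulator network. A simplified single-layer sketch, with all shapes, the per-point latent, and the ReLU modulator assumed for illustration:

```python
import jax
import jax.numpy as jnp

def modulated_sine_layer(x, latent, W, b, M, c, w0=30.0):
    # Modulator maps the latent code to a per-feature scale alpha,
    # which multiplies the sinusoidal activation (simplified sketch).
    alpha = jax.nn.relu(latent @ M + c)
    return alpha * jnp.sin(w0 * (x @ W + b))

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
coords = jnp.ones((16, 2))                 # flattened coordinate grid
latent = jax.random.normal(k1, (16, 8))    # per-point latent code (assumed)
W = jax.random.normal(k2, (2, 32)) * 0.1   # SIREN layer weights
b = jnp.zeros(32)
M = jax.random.normal(k3, (8, 32)) * 0.1   # modulator weights
c = jnp.zeros(32)
out = modulated_sine_layer(coords, latent, W, b, M, c)
print(out.shape)  # (16, 32)
```

The latent code thus selects a member of a family of local functions, which is what makes the representation generalizable across signals.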

References

  1. Implicit Neural Representations with Periodic Activation Functions
  2. Modulated Periodic Activations for Generalizable Local Functional Representations

Related works