
A dead simple Python package for creating custom JAX pytree objects


cgarciae/simple-pytree


Simple Pytree

A dead simple Python package for creating custom JAX pytree objects.

  • Strives to be minimal: the implementation is just ~100 lines of code
  • Has no dependencies other than JAX
  • Is compatible with both dataclasses and regular classes
  • Has no intention of supporting neural network use cases (e.g. partitioning)

Installation

pip install simple-pytree

Usage

import jax
from simple_pytree import Pytree

class Foo(Pytree):
    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo = jax.tree_util.tree_map(lambda x: -x, foo)

assert foo.x == -1 and foo.y == -2
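For intuition, here is one way a base class like Pytree could turn every subclass into a pytree: it registers each subclass with jax.tree_util.register_pytree_node inside __init_subclass__ and flattens the instance's attributes into children. MiniPytree, _flatten, and _unflatten are hypothetical names; this is a sketch under those assumptions, not simple_pytree's actual source.

```python
import jax


class MiniPytree:
    """Hypothetical minimal base class; NOT simple_pytree's implementation."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Register every subclass so JAX knows how to flatten/unflatten it.
        jax.tree_util.register_pytree_node(cls, cls._flatten, cls._unflatten)

    def _flatten(self):
        # Sort attribute names so the leaf order is deterministic.
        names = tuple(sorted(vars(self)))
        children = tuple(getattr(self, n) for n in names)
        return children, names  # (children, hashable aux data)

    @classmethod
    def _unflatten(cls, names, children):
        # Bypass __init__: rebuild the instance directly from its fields.
        obj = object.__new__(cls)
        obj.__dict__.update(zip(names, children))
        return obj


class Foo(MiniPytree):
    def __init__(self, x, y):
        self.x = x
        self.y = y


foo = Foo(1, 2)
leaves = jax.tree_util.tree_leaves(foo)
neg = jax.tree_util.tree_map(lambda v: -v, foo)
```

Skipping __init__ in _unflatten lets JAX rebuild objects from transformed leaves without re-running user code, which matters when leaves become tracers inside jax.jit.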

Static fields

You can mark fields as static by assigning static_field() to a class attribute with the same name as the instance attribute:

import jax
from simple_pytree import Pytree, static_field

class Foo(Pytree):
    y = static_field()
    
    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo = jax.tree_util.tree_map(lambda x: -x, foo) # y is not modified

assert foo.x == -1 and foo.y == 2

Static fields are not included in the pytree leaves; instead, they are passed as pytree metadata.
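To illustrate what "passed as pytree metadata" means, the sketch below assumes static attribute names are listed in a hypothetical static_names class attribute (the real static_field() mechanism is not reproduced here). Dynamic attributes become children, while static ones travel in the aux data stored in the treedef. A practical consequence is that changing a static value produces a different treedef, which under jax.jit means retracing rather than a new traced input.

```python
import jax


class StaticAwarePytree:
    """Hypothetical base class: names in `static_names` go into the treedef, not the leaves."""

    static_names: frozenset = frozenset()

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        jax.tree_util.register_pytree_node(cls, cls._flatten, cls._unflatten)

    def _flatten(self):
        attrs = vars(self)
        dynamic = tuple(sorted(k for k in attrs if k not in self.static_names))
        # Static (name, value) pairs must be hashable: they live in the treedef.
        static = tuple(sorted((k, v) for k, v in attrs.items() if k in self.static_names))
        children = tuple(attrs[k] for k in dynamic)
        return children, (dynamic, static)

    @classmethod
    def _unflatten(cls, aux, children):
        dynamic, static = aux
        obj = object.__new__(cls)
        obj.__dict__.update(zip(dynamic, children))
        obj.__dict__.update(static)
        return obj


class Foo(StaticAwarePytree):
    static_names = frozenset({"y"})

    def __init__(self, x, y):
        self.x = x
        self.y = y


leaves, treedef = jax.tree_util.tree_flatten(Foo(1, 2))
_, same_treedef = jax.tree_util.tree_flatten(Foo(5, 2))   # same static value
_, other_treedef = jax.tree_util.tree_flatten(Foo(1, 3))  # different static value
negated = jax.tree_util.tree_map(lambda v: -v, Foo(1, 2))
```

Only x shows up as a leaf; y is part of the structure, so tree_map never touches it.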

Dataclasses

simple_pytree provides a dataclass decorator that you can use with classes containing static fields:

import jax
from simple_pytree import Pytree, dataclass, static_field

@dataclass
class Foo(Pytree):
    x: int
    y: int = static_field(default=2)
    
foo = Foo(1)
foo = jax.tree_util.tree_map(lambda x: -x, foo) # y is not modified

assert foo.x == -1 and foo.y == 2

simple_pytree.dataclass is just a wrapper around dataclasses.dataclass, but when you use it, static analysis tools and IDEs will understand that static_field is a field specifier, just like dataclasses.field.
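The standard mechanism for teaching type checkers about custom field specifiers is PEP 681's dataclass_transform. The sketch below assumes simple_pytree.dataclass works along these lines; my_dataclass, my_static_field, and the "pytree_node" metadata key are hypothetical stand-ins, not the package's API. At runtime the decorator simply delegates to dataclasses.dataclass; dataclass_transform only adds information for static analysis.

```python
import dataclasses
import sys

if sys.version_info >= (3, 11):
    from typing import dataclass_transform
else:  # typing.dataclass_transform was added in Python 3.11
    def dataclass_transform(**_kwargs):
        return lambda obj: obj  # no-op fallback: runtime behavior is unchanged


def my_static_field(default=dataclasses.MISSING, **kwargs):
    """Hypothetical static_field analogue: a dataclasses.field tagged in its metadata."""
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata["pytree_node"] = False  # assumed marker key, for illustration only
    return dataclasses.field(default=default, metadata=metadata, **kwargs)


@dataclass_transform(field_specifiers=(dataclasses.field, my_static_field))
def my_dataclass(cls):
    # Runtime behavior is exactly dataclasses.dataclass; the decorator above
    # tells type checkers that my_static_field(...) declares a field.
    return dataclasses.dataclass(cls)


@my_dataclass
class Foo:
    x: int
    y: int = my_static_field(default=2)


foo = Foo(1)
static_names = [
    f.name for f in dataclasses.fields(Foo) if not f.metadata.get("pytree_node", True)
]
```

A pytree base class could then read the field metadata (as static_names does here) to decide which fields belong in the treedef.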

Mutability

Pytree objects are immutable by default after __init__:

from simple_pytree import Pytree, static_field

class Foo(Pytree):
    y = static_field()
    
    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo.x = 3 # AttributeError
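One way to get "immutable after __init__" behavior in plain Python is a metaclass whose __call__ marks the instance as initializing while __init__ runs, combined with a __setattr__ that rejects writes afterwards. This is a hypothetical sketch (_InitTrackingMeta, FrozenAfterInit, and the _initializing flag are illustrative names), not necessarily how simple_pytree does it.

```python
class _InitTrackingMeta(type):
    """Hypothetical metaclass: flags the instance as initializing while __init__ runs."""

    def __call__(cls, *args, **kwargs):
        obj = cls.__new__(cls)
        object.__setattr__(obj, "_initializing", True)
        try:
            obj.__init__(*args, **kwargs)
        finally:
            # Always clear the flag, even if __init__ raised.
            object.__setattr__(obj, "_initializing", False)
        return obj


class FrozenAfterInit(metaclass=_InitTrackingMeta):
    def __setattr__(self, name, value):
        if not self._initializing:
            raise AttributeError(f"{type(self).__name__} is immutable; cannot set {name!r}")
        object.__setattr__(self, name, value)


class Foo(FrozenAfterInit):
    def __init__(self, x, y):
        self.x = x
        self.y = y


foo = Foo(1, 2)
try:
    foo.x = 3
    blocked = False
except AttributeError:
    blocked = True
```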

If you want to make them mutable, you can pass the mutable argument in the class definition:

from simple_pytree import Pytree, static_field

class Foo(Pytree, mutable=True):
    y = static_field()
    
    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo.x = 3 # OK
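Class keyword arguments such as mutable=True are delivered to __init_subclass__, which is a natural place to record the flag per subclass. The sketch below uses hypothetical names (MaybeMutable, _mutable) and, to stay short, a simpler rule than "immutable after __init__": it only rejects reassigning an attribute that already exists, so an __init__ that assigns the same attribute twice would also be rejected.

```python
class MaybeMutable:
    """Hypothetical base: subclasses opt into mutation with `mutable=True`."""

    _mutable = False

    def __init_subclass__(cls, mutable=False, **kwargs):
        super().__init_subclass__(**kwargs)
        cls._mutable = mutable  # recorded separately for each subclass

    def __setattr__(self, name, value):
        # Simplified rule: first assignment is allowed, reassignment needs mutable=True.
        if not self._mutable and name in self.__dict__:
            raise AttributeError(f"{type(self).__name__} is immutable; cannot reassign {name!r}")
        object.__setattr__(self, name, value)


class Frozen(MaybeMutable):
    def __init__(self, x):
        self.x = x


class Thawed(MaybeMutable, mutable=True):
    def __init__(self, x):
        self.x = x


thawed = Thawed(1)
thawed.x = 3  # allowed because the class was declared with mutable=True

frozen = Frozen(1)
try:
    frozen.x = 3
    frozen_blocked = False
except AttributeError:
    frozen_blocked = True
```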

Replacing fields

If you want to make a copy of a Pytree object with some fields modified, you can use the .replace() method:

from simple_pytree import Pytree, static_field

class Foo(Pytree):
    y = static_field()
    
    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo = foo.replace(x=10)

assert foo.x == 10 and foo.y == 2

replace works for both mutable and immutable Pytree objects. If the class is a dataclass, replace uses dataclasses.replace internally.
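The replace behavior described above can be approximated by a standalone helper. replace_fields is a hypothetical name and this is a sketch, not the package's code: it defers to dataclasses.replace for dataclass instances and otherwise shallow-copies the object and updates the copy's __dict__ directly, which sidesteps any __setattr__ guard and so also works for immutable objects.

```python
import copy
import dataclasses


def replace_fields(obj, **updates):
    """Hypothetical helper: return a copy of `obj` with some fields replaced."""
    if dataclasses.is_dataclass(obj):
        return dataclasses.replace(obj, **updates)
    unknown = set(updates) - set(vars(obj))
    if unknown:
        raise AttributeError(f"unknown fields: {sorted(unknown)}")
    new = copy.copy(obj)          # shallow copy with its own __dict__
    new.__dict__.update(updates)  # bypasses __setattr__, so immutability guards don't fire
    return new


@dataclasses.dataclass(frozen=True)
class Point:
    x: int
    y: int


class Locked:
    """Plain class that refuses all attribute assignment after construction."""

    def __init__(self, x, y):
        object.__setattr__(self, "x", x)
        object.__setattr__(self, "y", y)

    def __setattr__(self, name, value):
        raise AttributeError("Locked is immutable")


point = Point(1, 2)
locked = Locked(1, 2)
p = replace_fields(point, x=10)
q = replace_fields(locked, x=10)
```

Both originals are left untouched; only the returned copies carry the new value.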
