Releases: google/flax
v0.8.3
What's Changed
- Add git fetch upstream to contributing doc. by @carlosgmartin in #3757
- removed __getattr__/__setattr__ unboxing magic from nnx.Pytree by @chiamp in #3743
- added Einsum layer to NNX by @chiamp in #3741
- Make TrainState's step possibly a jax.Array. This makes replicate valid for type checking. by @copybara-service in #3763
- v0.8.3 by @cgarciae in #3758
- [nnx] fix demo notebook by @cgarciae in #3744
- added nnx api reference by @chiamp in #3762
- updated rng docstring for init, apply and make_rng by @chiamp in #3765
- use note box in make_rng docstring by @cgarciae in #3767
- [nnx] improved graph update mechanism by @cgarciae in #3759
- use note box in docstrings by @chiamp in #3769
- Add reset_gate flag to MGUCell. by @carlosgmartin in #3760
- Access thread_resources via jax.interpreters.pxla instead of jax.experimental.maps by @copybara-service in #3775
- Minor doc improvements by @canyon289 in #3588
- added MGU reset_gate test by @chiamp in #3773
- [nnx] Pytrees are Trees by @cgarciae in #3768
- Use short-circuiting access to debug_key_reuse by @copybara-service in #3781
- fix tabulate on norm wrappers by @chiamp in #3772
- Add kw_only struct.dataclass test by @chiamp in #3651
- extended PyTreeNode to take dataclass kwargs by @chiamp in #3785
- [nnx] Arrays are state by @cgarciae in #3791
- [nnx] add GraphNode base class by @cgarciae in #3790
- [nnx] jit accepts many Modules by @cgarciae in #3783
- Exposing the experimental _split_transpose JAX scan parameter in Flax. by @copybara-service in #3795
- Expose nnx.GraphNode by @chiamp in #3796
- [nnx] Rngs and RngStream inherit from GraphNode by @cgarciae in #3793
- [nnx] TrainState uses struct by @cgarciae in #3788
- [nnx] split returns graphdef first by @cgarciae in #3794
- Remove the uninitialized field "embedding" in nn.Embed by @copybara-service in #3801
- Add nnx.training by @chiamp in #3782
- [nnx] non-str State keys by @cgarciae in #3802
- [nnx] allow all jit kwargs in nnx.jit by @cgarciae in #3809
- [nnx] simplify readme by @cgarciae in #3805
- [nnx] Fix nnx basics by @cgarciae in #3812
- [nnx] grad accepts argnums by @cgarciae in #3798
- [nnx] improve toy examples by @cgarciae in #3813
- [nnx] expose Sequential by @cgarciae in #3814
- [nnx] Rng Variable tags by @cgarciae in #3807
- [nnx] remove copy in graph unflatten by @cgarciae in #3804
- fixed optax guide links and docstring typos by @chiamp in #3789
- added dropout broadcast test by @chiamp in #3776
- relaxed grads kwarg for Optimizer.update by @chiamp in #3818
- added tree_map deprecation warning filter by @chiamp in #3828
- updated tree_map by @chiamp in #3823
- added NNX vs JAX transformations guide by @chiamp in #3819
- Updated NNX MNIST tutorial by @chiamp in #3810
- [nnx] add Dropout.rngs by @cgarciae in #3815
- removed autosummary from linen docs by @chiamp in #3792
- Fix cloudpickle sentinel cloning by @cgarciae in #3825
- [nnx] remove pytreelib by @cgarciae in #3816
- [nnx] fix nnx_basics by @cgarciae in #3839
- [linen] fix DenseGeneral init by @cgarciae in #3834
- [nnx] jit constrain object state by @cgarciae in #3817
- Copybara import of the project: by @copybara-service in #3857
- Add example of unbox() and replace_boxed() to the jit guide by @IvyZX in #3843
- RNNCellBase refactor FLIP by @cgarciae in #3099
- [nnx] Some small documentation suggestions. by @gnecula in #3861
- updated nnx dropout by @chiamp in #3841
- Fix LogicalRules type annotation. (Tuple[str] is a tuple with single element string, by @copybara-service in #3877
- Add option to skip float32 promotion when computing means and variances for normalization. by @copybara-service in #3873
- added nnx api reference link by @chiamp in #3871
- option of forcing the input of softmax to be fp32 for better numerical stability in mixed-precision training. by @copybara-service in #3874
- allow custom dot_general for einsum. by @copybara-service in #3884
- [NVIDIA] Extend the custom fp8 accumulate dtype in non-jit scenarios by @kaixih in #3827
- updated robots.txt by @chiamp in #3886
- fixed autosummary links by @chiamp in #3887
- Fix jax.tree_util.register_dataclass in older JAX versions. by @copybara-service in #3885
- [nnx] v0.1 by @cgarciae in #3876
Full Changelog: v0.8.2...v0.8.3
v0.8.2
What's Changed
- Add +1 to version after 0.8.1 release by @IvyZX in #3684
- fixed rng guide outputs by @chiamp in #3685
- enforce mask kwarg in norm layers by @chiamp in #3663
- added kwargs to self.param and self.variable by @chiamp in #3675
- added nnx normalization tests by @chiamp in #3689
- added NNX init_cache docstring example by @chiamp in #3688
- added nnx attention equivalence test by @chiamp in #3687
- Fix bug that assumed frozen-dict keys were strings. by @copybara-service in #3692
- added nnx rmsnorm by @chiamp in #3691
- updated nnx compute_stats by @chiamp in #3693
- fixed intercept_methods docstring by @chiamp in #3694
- [nnx] Add Sphinx Docs by @cgarciae in #3678
- Fix pointless docstring example of nn.checkpoint / nn.remat. by @levskaya in #3703
- added default params rng to .apply by @chiamp in #3698
- [nnx] add partial_init by @cgarciae in #3674
- make make_rng default to 'params' by @chiamp in #3699
- Add SimpleCell. by @carlosgmartin in #3697
- fix Module.module_paths docstring by @cgarciae in #3709
- Guarantee the latest JAX version on CI by @cgarciae in #3705
- Replace deprecated API jax.tree_map by @copybara-service in #3715
- Use jax.tree_util.tree_map instead of deprecated jax.tree_map. by @copybara-service in #3714
- [nnx] simplify readme by @cgarciae in #3707
- [nnx] add demo.ipynb by @cgarciae in #3680
- Fix Tabulate's compute_flops by @cgarciae in #3721
- [nnx] simplify TraceState by @cgarciae in #3724
- Add broadcast of strides and kernel_dilation to nn.ConvTranspose by @IvyZX in #3731
- [nnx] Fix State.sub by @cgarciae in #3704
- [nnx] always fold_in on fork + new ForkedKeys return type by @cgarciae in #3722
- [nnx] explicit Variables by @cgarciae in #3720
- Improves fingerprint definition for Modules in nn.jit. by @copybara-service in #3736
- Flax: avoid key reuse in tests by @copybara-service in #3740
- added Einsum layer by @chiamp in #3710
- nn.jit: automatic fingerprint definition for dataclass attributes by @cgarciae in #3737
- [NVIDIA] Use custom grad accumulation for FP8 params by @kaixih in #3623
- removed nnx dataclass by @chiamp in #3742
- [nnx] cleanup graph_utils by @cgarciae in #3728
- Fix doctest and unbreak head by @IvyZX in #3753
- [nnx] add pytree support by @cgarciae in #3732
- fixed intercept_methods docstring by @chiamp in #3752
- Add ConvLSTMCell to docs. by @carlosgmartin in #3712
- [nnx] remove flagslib by @cgarciae in #3733
- Fix tests after applying JAX key-reuse checker. See: by @copybara-service in #3748
Full Changelog: v0.8.1...v0.8.2
Version 0.8.1
What's Changed
- bump version number to 0.8.1 by @chiamp in #3649
- Bump pillow from 10.0.1 to 10.2.0 in /examples/vae by @dependabot in #3641
- fixed docstring by @chiamp in #3643
- Add explicit control over frozen/slots setting in flax.struct.dataclass by @copybara-service in #3645
- make Sequential.__call__ compact by @copybara-service in #3647
- add Module.module_paths by @cgarciae in #3654
- added rng_guide by @chiamp in #3497
- Replacing jax.tree_util.tree_map with mapping over leaves. by @copybara-service in #3658
- Copybara import of the project: by @copybara-service in #3659
- added InstanceNorm by @chiamp in #3652
- add Module.module_paths by @copybara-service in #3660
- added norm equivalence tests by @chiamp in #3662
- updated nowrap docstring by @chiamp in #3661
- Add module_paths method to docs by @cgarciae in #3657
- add default make_rng by @chiamp in #3669
- renamed channel_axes to feature_axes in InstanceNorm by @chiamp in #3667
- added flax.typing by @chiamp in #3624
- changed kwargs to actual key-word args by @chiamp in #3562
- updated docs and docstrings by @chiamp in #3670
- re-added linen_intro by @chiamp in #3672
- add compact_name_scope v3 by @cgarciae in #3646
- Release 0.8.1 by @IvyZX in #3682
Full Changelog: v0.8.0...v0.8.1
v0.8.0
What's Changed
- bump version number by @levskaya in #3446
- Add merge / finalize step when using OCDBT driver. Files will be first written to per-process subdirectories, which are later copied by reference to the main directory before the checkpoint is finalized. by @copybara-service in #3426
- fixed quickstart by @chiamp in #3451
- [NVIDIA] Update the algorithm to compute fp8 scales by @kaixih in #3441
- added pre-commit hook that sorts imports and formats code by @chiamp in #3455
- restructured doc folders by @chiamp in #3434
- Forked a subset of JAX configuration APIs by @superbobry in #3448
- Fix Module.clone in deepclone mode for internal usage. by @levskaya in #3459
- Add user-friendly module copy method. by @levskaya in #3461
- Add simple argument-only lifted nn.grad function. by @levskaya in #3463
- exempt a jax.config deprecation warning by @levskaya in #3465
- Clean up pyproject.toml. by @levskaya in #3468
- Allow for fast accumulation selection for FP8 GEMM by @wenscarl in #3416
- re-added quickstart guide by @chiamp in #3471
- fixed tabulate docstring by @chiamp in #3452
- Add NNX by @cgarciae in #3218
- Bump pillow from 9.5.0 to 10.0.1 in /examples/vae by @dependabot in #3390
- updated attention_test by @chiamp in #3454
- [nnx] Improve docs by @cgarciae in #3478
- added example docstrings by @chiamp in #3453
- fix nn.value_and_grad by implementing directly in core by @levskaya in #3479
- Add dataset loading guide (Issue #2116) by @VictorPrins in #3450
- [nnx] Add support for python container types by @cgarciae in #3486
- remove SelfAttention test and warning filter by @chiamp in #3470
- disabled ruff formatter by @chiamp in #3482
- adding doctest to .rst files by @chiamp in #3481
- changed pip installs to use quotes by @chiamp in #3477
- added enum support for tabulate by @chiamp in #3485
- fix bug in optimizer-api.md by @zhaoyang-0204 in #3462
- removed selfattention from doctest by @chiamp in #3489
- [nnx] Add missing import on why.ipynb by @cgarciae in #3503
- [nnx] switch to nested State representation by @cgarciae in #3502
- Improved Rigor of PReLU Test by @Micky774 in #3498
- added geglu activation and tests by @HMUNACHI in #3512
- [nnx] Add LinearGeneral and MultiHeadAttention by @cgarciae in #3487
- Add NNX/Linen consistency test for Embed layer by @Micky774 in #3513
- Add NNX/Linen API consistency test for Conv layer by @Micky774 in #3511
- Prevent crash in dataclasses with no-init params by @NeilGirdhar in #3514
- [nnx] Variable reference sharing by @cgarciae in #3516
- Added NNX/Linen API consistency test for Linear/Dense layer by @Micky774 in #3509
- Add missing mask argument to LayerNorm, RMSNorm, and GroupNorm. by @carlosgmartin in #3510
- [nnx] Fix graph_utils bug by @cgarciae in #3518
- remove deprecated normalize function by @chiamp in #3531
- Reduced number of parameterizations for Conv NNX/Linen consistency test by @Micky774 in #3526
- Ensure that _hashable_filter does not convert strings to a tuple of letters by @copybara-service in #3533
- added sow attention weights by @chiamp in #3529
- Fix scan out_axes by @cgarciae in #3540
- updated embed docstring by @chiamp in #3539
- add test_scan_negative_axes by @cgarciae in #3542
- add module methods to api docs by @chiamp in #3544
- fixed double backquote code font by @chiamp in #3545
- add nnx conv support for int kernel size by @chiamp in #3537
- added sow attention weights to NNX by @chiamp in #3548
- changed return_weights to sow_weights for attention layer by @chiamp in #3550
- format linen_linear_test.py by @chiamp in #3553
- re-factored features arg by @chiamp in #3554
- updated NNX readme by @chiamp in #3556
- Disable ruff sort imports by @cgarciae in #3560
- Add StateVariablesMapping by @cgarciae in #3523
- add kwargs support for nn.jit by @copybara-service in #3559
- [nnx] Fix readme install instruction by @cgarciae in #3565
- implement Rng.__getattr__ by @cgarciae in #3547
- [nnx] add qkv_features back to MHA by @cgarciae in #3566
- updated readme by @chiamp in #3563
- fixed typo by @chiamp in #3561
- Raise an error for a bad key type by @NeilGirdhar in #3527
- re-factored nnx initializers by @chiamp in #3555
- [nnx] Add complex test with scan + batchnorm + dropout by @cgarciae in #3567
- [nnx] Add interacting with JAX section to README by @cgarciae in #3573
- expose ones and zeros initializers by @chiamp in #3574
- Fix promotion bug in MultiHeadDotProductAttention: by @giovannic in #3571
- fixed error doc formatting by @chiamp in #3587
- [nnx] Improve spmd by @cgarciae in #3580
- [nnx] improve graph_utils._set_key_tuple by @cgarciae in #3592
- [nnx] Fix variable unflatten by @cgarciae in #3578
- [nnx] add open in colab button to why nnx by @cgarciae in #3596
- [nnx] Export missing symbols by @cgarciae in #3583
- [nnx] flaglib add get overloads by @cgarciae in #3582
- Fix type in NNX readme by @shoyer in #3591
- [nnx] add submodule iterator by @cgarciae in #3581
- [nnx] delete flaglib duplicated copyright comment by @cgarciae in #3600
- fixed NNX decode and dynamic slicing by @chiamp in #3576
- [nnx] cleanup CallableProxy by @cgarciae in #3608
- [nnx] improve runtime flags by @cgarciae in #3607
- fixed broken links on quick start guide by @chiamp in #3610
- added multiheadattention alias by @chiamp in #3572
- Rollback of Copybara import of the project: by @copybara-service in #3612
- add missing docs for module functions by @cgarciae in #3619
- fix lm1b data sharding by @cgarciae in #3620
- improve embed by @jianyizh in #3590
- disable ruff linter by @chiamp in #3625
- Add compact_name_scope decorator by @cgarciae in #3621
- Copybara import of the project: by @copybara-service in #3638
- added BatchApply by @chiamp in #3634
- add compact_name_scope v2 by @copybara-service in #3640
- add compact_name_scope v2 by @copybara-service in #3642
- release 0.8.0 by @chiamp in #3644
New Contributors
- @superbobry made their first contribution in #3448
- @VictorPrins made their first contribution in #3450
- @zhaoyang-0204 made their first contribution in #3462
- @Micky774 made their first contribution in #3498
- @HMUNACHI made their first contribution in #3512
- @carlosgmartin made their first contribution in #3510
- @giovannic made their first contribution in https://gith...
v0.7.5
What's Changed
- Add method-to-model section to Haiku migration guide by @IvyZX in #3277
- updated haiku guide with new JAX RNG api by @chiamp in #3343
- changed resnet v1 to v1.5 by @chiamp in #3344
- updated flax_basics by @chiamp in #3342
- Add find methods and magic methods for Cursor API by @chiamp in #3306
- fix DeprecationWarnings by @chiamp in #3352
- Make checkpoint guide path absolute by @IvyZX in #3358
- use jax.Array type for rng keys by @chiamp in #3354
- Disable pyink by @cgarciae in #3356
- removed DeprecationWarning filter by @chiamp in #3359
- Don't propagate default args in Tabulate by @cgarciae in #3357
- added spectral norm by @chiamp in #3335
- fixed bind-unbind bug by @chiamp in #3365
- Add MaxText open source LLM to index.rst by @8bitmp3 in #3368
- fixed typo by @chiamp in #3367
- Add fp8 custom op and unit test by @wenscarl in #3322
- fixed scope typing by @chiamp in #3371
- Trailing whitespace fixes. by @levskaya in #3373
- Make fp8 ops use explicit broadcasting. by @levskaya in #3374
- Conditional constraints for numpy and clu by @cgarciae in #3394
- add truncated_normal to initializers by @levskaya in #3401
- fix HEAD by @chiamp in #3404
- updated imagenet readme by @chiamp in #3383
- Add Flax FAQ - how to search, how to take the derivative w.r.t. a hidden layer, remat_scan vs scan(remat), recommended training libraries/metrics by @8bitmp3 in #3267
- added rmsnorm to api docs by @chiamp in #3406
- updated docs by @chiamp in #3370
- fix HEAD by @chiamp in #3408
- added weightnorm layer by @chiamp in #3405
- added attention refactor to changelog by @chiamp in #3412
- added dropout arg to MultiHeadDotProductAttention by @chiamp in #3384
- fix spectralnorm layer by @chiamp in #3403
- remove pdf target by @cgarciae in #3415
- added import precommit hook by @chiamp in #3410
- fixed GRU docstring by @chiamp in #3419
- Replaces pjit with jit in spmd.py by @copybara-service in #3421
- Ignore transient chex deprecationwarning. by @levskaya in #3427
- Remove transformers dependency from docs by @cgarciae in #3431
- updated pytorch upgrade guide by @chiamp in #3432
- Simplify abstract rng creation in param shape check. by @levskaya in #3429
- [NVIDIA] Update the FP8 support by @kaixih in #3435
- fixed mypy errors at HEAD by @chiamp in #3440
- added has_improved field to EarlyStopping by @chiamp in #3385
- fully deprecated old RNN api by @chiamp in #3425
- updated lm1b example with jit by @chiamp in #3302
- added MGU class by @chiamp in #3418
- Make Flax Basics visible by @8bitmp3 in #3443
- update for release v0.7.5 by @levskaya in #3444
New Contributors
Full Changelog: v0.7.4...v0.7.5
v0.7.4
Version 0.7.3
What's Changed
- Fix Documentation Typo by @peterdavidfagan in #3265
- bump version number by @chiamp in #3273
- added cursor api by @chiamp in #3246
- Improve Typing support by @cgarciae in #3242
- feat: add configs for vae example by @SauravMaheshkar in #3254
- fix stackoverflow when loading pickled module by @cgarciae in #3286
- Remove string from checkpointing example and unbreak doctest in head by @IvyZX in #3288
- Improve kw_only_dataclass by @cgarciae in #3293
- updated haiku upgrade guide by @chiamp in #3292
- added logical partitioning to pjit guide by @chiamp in #3290
- Rolling back ba9e24a by @copybara-service in #3295
- Add RNN FLIP by @cgarciae in #2585
- [flax] Add module path to nn.module. by @copybara-service in #3300
- Minor fix to jax_utils.prefetch_to_device by @anuragarnab in #3308
- add tf-text to doc requirements by @cgarciae in #3317
- Allow apply's method argument to accept submodules by @cgarciae in #3281
- Update handling of typed PRNG keys by @jakevdp in #3314
- 0.7.3 release by @IvyZX in #3326
New Contributors
- @peterdavidfagan made their first contribution in #3265
- @SauravMaheshkar made their first contribution in #3254
Full Changelog: v0.7.2...v0.7.3
Version 0.7.2
What's Changed
- make flax.core.copy add_or_replace optional by @PhilipVinc in #3241
- bump version to 0.7.2 by @chiamp in #3258
- Release 0.7.2 by @chiamp in #3269
Full Changelog: v0.7.1...v0.7.2
Version 0.7.1
What's Changed
- added dict migration guide to index by @chiamp in #3188
- [linen] Add alternative, more numerically stable, variance calculation to LayerNorm. by @copybara-service in #3194
- [linen] Minor cleanup to normalization code. by @copybara-service in #3200
- Fix warnings from atari gym. by @levskaya in #3207
- Fix carry slice logic by @cgarciae in #3213
- removed FrozenDict section by @chiamp in #3210
- make flax_basics guide use utility fns by @chiamp in #3214
- Fix typo in struct.py documentation by @marcoselvi in #3209
- updated flax guides to use utility fns by @chiamp in #3215
- [linen] More minor cleanup in normalization compute_stats. by @copybara-service in #3205
- Expose options to customize rich.Table by @cgarciae in #3197
- Use pyink by @cgarciae in #3216
- Fix remaining pyink issues by @cgarciae in #3219
- add scan over layers section by @cgarciae in #3195
- Fix checkpointing guide error at head by @IvyZX in #3223
- Improve scan docs by @cgarciae in #3231
- 0.7.1 release by @chiamp in #3234
New Contributors
- @marcoselvi made their first contribution in #3209
Full Changelog: v0.7.0...v0.7.1
Version 0.7.0
What's Changed
- delete long-unused dotgetter utility by @copybara-service in #3156
- Add sections to flax.linen toctree by @cgarciae in #3073
- update doctest requirements to use recentish jax minver by @levskaya in #3160
- Update gym to gymnasium by @cgarciae in #3133
- Add lifted transforms section to Haiku migration guide by @cgarciae in #3158
- relaxed tolerance by @chiamp in #3162
- fixed broken links by @chiamp in #3161
- added dict migration guide by @chiamp in #3109
- Update requirements, restructure files and fix formatting for VAE example by @canyon289 in #3046
- Update python version support by @cgarciae in #3168
- Improve Haiku migration guide by @cgarciae in #3169
- Set default types in Flax for Orbax restoration and add restore_with_serialized_types in preparation for an upcoming change. by @copybara-service in #3165
- fix precommit issues by @chiamp in #3170
- fixed incorrect reference link by @chiamp in #3167
- added absltest to linen_recurrent_test by @chiamp in #3172
- remove cell_size docs from RNN by @cgarciae in #3186
- 0.7.0 by @cgarciae in #3187
Full Changelog: v0.6.11...v0.7.0