rjax

RL framework in JAX (Flax)

Offline RL framework in JAX that supports:

  • Distributed multi-device training (GPU/TPU)
  • Multiprocessed dataloaders (similar to PyTorch)
  • Modular experiment management and logging

This framework is currently used for my own research, so features may change frequently.
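
As a rough illustration of what multi-device data parallelism looks like in JAX (a generic sketch, not rjax's actual API; all function and variable names below are hypothetical):

from functools import partial

import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # Hypothetical mean-squared-error loss; in practice this would be a
    # Flax model apply plus an RL objective.
    inputs, targets = batch
    preds = inputs @ params
    return jnp.mean((preds - targets) ** 2)

@partial(jax.pmap, axis_name="devices")
def train_step(params, batch):
    # Each device computes gradients on its own shard of the batch...
    grads = jax.grad(loss_fn)(params, batch)
    # ...then gradients are averaged across devices before the update.
    grads = jax.lax.pmean(grads, axis_name="devices")
    return jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)

# Parameters are replicated across devices and batches are sharded along
# the leading axis, e.g.:
# params = jax.device_put_replicated(params, jax.local_devices())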

Install dependencies

conda env create -f environment.yml

conda activate rjax

# Installs GPU dependencies; skip this step for a CPU-only install
pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

pip install -e .
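
After installation, a quick way to confirm that JAX sees your accelerators (standard JAX, not an rjax command):

python -c "import jax; print(jax.devices())"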

To install on TPU, run the provided setup script:

./gcp/setup_tpu.sh

Two Dockerfiles are also provided, one for CPU and one for GPU.

Documentation

There is no official documentation page yet, but the code is extensively documented and the configs are explained inline.

Examples

IQL (Implicit Q-Learning): see implicit_q_learning.

python examples/train_iql.py --prefix iql --env antmaze-medium-play-v2 --num_workers 2
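
At the core of IQL is expectile regression on the value function. A minimal sketch of the expectile loss in JAX (illustrative only; this is the standard IQL objective from the paper, not necessarily how rjax structures its code):

import jax.numpy as jnp

def expectile_loss(diff, expectile=0.7):
    # Asymmetric squared error on diff = Q(s, a) - V(s): residuals where Q
    # exceeds V are weighted by `expectile`, the rest by (1 - expectile).
    # With expectile > 0.5 this pushes V(s) toward an upper expectile of
    # Q(s, a), avoiding queries of out-of-distribution actions.
    weight = jnp.where(diff > 0, expectile, 1 - expectile)
    return weight * (diff ** 2)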
