anagabandi/nn_dynamics

Neural Network Dynamics for Model-Based Deep Reinforcement Learning with Model-Free Fine-Tuning

Arxiv Link

Abstract: Model-free deep reinforcement learning algorithms have been shown to be capable of learning a wide range of robotic skills, but typically require a very large number of samples to achieve good performance. Model-based algorithms, in principle, can provide for much more efficient learning, but have proven difficult to extend to expressive, high-capacity models such as deep neural networks. In this work, we demonstrate that medium-sized neural network models can in fact be combined with model predictive control (MPC) to achieve excellent sample complexity in a model-based reinforcement learning algorithm, producing stable and plausible gaits to accomplish various complex locomotion tasks. We also propose using deep neural network dynamics models to initialize a model-free learner, in order to combine the sample efficiency of model-based approaches with the high task-specific performance of model-free methods. We empirically demonstrate on MuJoCo locomotion tasks that our pure model-based approach trained on just minutes of random action data can follow arbitrary trajectories, and that our hybrid algorithm can accelerate model-free learning on high-speed benchmark tasks, achieving sample efficiency gains of 3-5x on swimmer, cheetah, hopper, and ant agents.

  • For the installation guide, see installation.md
  • For notes on using your own environment, editing envs, etc., see notes.md

How to run everything

cd scripts
./swimmer_mbmf.sh
./cheetah_mbmf.sh
./hopper_mbmf.sh
./ant_mbmf.sh

Each of those scripts does something similar to the following (but for multiple seeds):

python main.py --seed=0 --run_num=1 --yaml_file='swimmer_forward'
python mbmf.py --run_num=1 --which_agent=2
python trpo_run_mf.py --seed=0 --save_trpo_run_num=1 --which_agent=2 --num_workers_trpo=2 --std_on_mlp_policy=0.5
python plot_mbmf.py --trpo_dir=[trpo_dir] --std_on_mlp_policy=0.5 --which_agent=2 --run_nums 1 --seeds 0

Note that [trpo_dir] above is the directory where the TRPO runs are saved, probably somewhere in ~/rllab/data/...
Each of these steps is explained further in the following sections.
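
As a rough illustration of the four steps above, the following Python sketch builds the same command lines for one seed. The helper name pipeline_commands and its parameters are hypothetical and not part of this repository; only the script names and flags are taken from the commands above.

```python
import shlex


def pipeline_commands(seed, run_num, yaml_file, which_agent,
                      save_trpo_run_num=1, std_on_mlp_policy=0.5,
                      num_workers_trpo=2, trpo_dir="[trpo_dir]"):
    """Return the MB, MBMF, MF and plotting commands as argv lists.

    Hypothetical helper: it only mirrors the flags shown in the README
    and does not run anything itself.
    """
    return [
        ["python", "main.py", f"--seed={seed}", f"--run_num={run_num}",
         f"--yaml_file={yaml_file}"],
        ["python", "mbmf.py", f"--run_num={run_num}",
         f"--which_agent={which_agent}"],
        ["python", "trpo_run_mf.py", f"--seed={seed}",
         f"--save_trpo_run_num={save_trpo_run_num}",
         f"--which_agent={which_agent}",
         f"--num_workers_trpo={num_workers_trpo}",
         f"--std_on_mlp_policy={std_on_mlp_policy}"],
        ["python", "plot_mbmf.py", f"--trpo_dir={trpo_dir}",
         f"--std_on_mlp_policy={std_on_mlp_policy}",
         f"--which_agent={which_agent}",
         "--run_nums", str(run_num), "--seeds", str(seed)],
    ]


if __name__ == "__main__":
    # Print the swimmer commands; replace [trpo_dir] before running the last one.
    for cmd in pipeline_commands(seed=0, run_num=1,
                                 yaml_file="swimmer_forward", which_agent=2):
        print(shlex.join(cmd))
```

Each list can be passed to subprocess.run(cmd, check=True) from the directory that contains the corresponding script. How the shell scripts pair seeds with run numbers is not shown here, so looping this helper over several seeds is left to the reader.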


How to run MB

Need to specify:

    --yaml_file Specify the corresponding yaml file
    --seed Random seed for numpy and tensorflow
    --run_num Specify what directory to save files under
    --use_existing_training_data To use the data that already exists in the directory run_num instead of recollecting
    --desired_traj_type What type of trajectory to follow (if you want to follow a trajectory)
    --num_rollouts_save_for_mf Number of on-policy rollouts to save after last aggregation iteration, to be used later
    --might_render If you might want to visualize anything during the run
    --visualize_MPC_rollout To set a breakpoint and visualize the on-policy rollouts after each agg iteration
    --perform_forwardsim_for_vis To visualize an open-loop prediction made by the learned dynamics model
    --print_minimal To not print messages regarding progress/notes/etc.

Examples:
python main.py --seed=0 --run_num=0 --yaml_file='cheetah_forward'
python main.py --seed=0 --run_num=1 --yaml_file='swimmer_forward'
python main.py --seed=0 --run_num=2 --yaml_file='ant_forward'
python main.py --seed=0 --run_num=3 --yaml_file='hopper_forward'
python main.py --seed=0 --run_num=4 --yaml_file='cheetah_trajfollow' --desired_traj_type='straight' --visualize_MPC_rollout
python main.py --seed=0 --run_num=4 --yaml_file='cheetah_trajfollow' --desired_traj_type='backward' --visualize_MPC_rollout --use_existing_training_data --use_existing_dynamics_model
python main.py --seed=0 --run_num=4 --yaml_file='cheetah_trajfollow' --desired_traj_type='forwardbackward' --visualize_MPC_rollout --use_existing_training_data --use_existing_dynamics_model
python main.py --seed=0 --run_num=5 --yaml_file='swimmer_trajfollow' --desired_traj_type='straight' --visualize_MPC_rollout
python main.py --seed=0 --run_num=5 --yaml_file='swimmer_trajfollow' --desired_traj_type='left_turn' --visualize_MPC_rollout --use_existing_training_data --use_existing_dynamics_model
python main.py --seed=0 --run_num=5 --yaml_file='swimmer_trajfollow' --desired_traj_type='right_turn' --visualize_MPC_rollout --use_existing_training_data --use_existing_dynamics_model
python main.py --seed=0 --run_num=6 --yaml_file='ant_trajfollow' --desired_traj_type='straight' --visualize_MPC_rollout
python main.py --seed=0 --run_num=6 --yaml_file='ant_trajfollow' --desired_traj_type='left_turn' --visualize_MPC_rollout --use_existing_training_data --use_existing_dynamics_model
python main.py --seed=0 --run_num=6 --yaml_file='ant_trajfollow' --desired_traj_type='right_turn' --visualize_MPC_rollout --use_existing_training_data --use_existing_dynamics_model
python main.py --seed=0 --run_num=6 --yaml_file='ant_trajfollow' --desired_traj_type='u_turn' --visualize_MPC_rollout --use_existing_training_data --use_existing_dynamics_model
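
The trajectory-following examples above share a pattern: the first run for a given run_num collects data and trains the model, and later runs in the same run_num add --use_existing_training_data and --use_existing_dynamics_model. A minimal sketch of that pattern follows, using a hypothetical helper trajfollow_commands that is not part of this repository.

```python
def trajfollow_commands(run_num, yaml_file, traj_types, seed=0):
    """Build one main.py command per trajectory type.

    Hypothetical helper. Only the first command trains from scratch;
    the rest reuse the data and dynamics model saved under run_num,
    as in the README examples.
    """
    commands = []
    for index, traj_type in enumerate(traj_types):
        cmd = ["python", "main.py", f"--seed={seed}", f"--run_num={run_num}",
               f"--yaml_file={yaml_file}",
               f"--desired_traj_type={traj_type}",
               "--visualize_MPC_rollout"]
        if index > 0:
            # Later trajectories reuse what the first run saved.
            cmd += ["--use_existing_training_data",
                    "--use_existing_dynamics_model"]
        commands.append(cmd)
    return commands


if __name__ == "__main__":
    for cmd in trajfollow_commands(6, "ant_trajfollow",
                                   ["straight", "left_turn", "right_turn", "u_turn"]):
        print(" ".join(cmd))
```

Shell quoting around the yaml file name and trajectory type is omitted because the arguments are built as lists rather than as a single shell string.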

How to run MBMF

Need to specify:

    --save_trpo_run_num Number used as part of the directory name for saving the MBMF TRPO run (you can use 1, 2, 3, etc. to differentiate your different seeds)
    --run_num Specify what directory to get relevant MB data from & to save new MBMF files in
    --which_agent Specify which agent (1 ant, 2 swimmer, 4 cheetah, 6 hopper)
    --std_on_mlp_policy Initial std you want to set on your pre-initialization policy for TRPO to use
    --num_workers_trpo How many worker threads (cpu) for TRPO to use
    --might_render If you might want to visualize anything during the run
    --visualize_mlp_policy To visualize the rollout performed by trained policy (that will serve as pre-initialization for TRPO)
    --visualize_on_policy_rollouts To set a breakpoint and visualize the on-policy rollouts after each agg iteration of dagger
    --print_minimal To not print messages regarding progress/notes/etc.
    --use_existing_pretrained_policy To run only the TRPO part (if you've already done the imitation learning part in the same directory)

Note that the finished TRPO run saves to ~/rllab/data/local/experiments/

Examples:
python mbmf.py --run_num=1 --which_agent=2 --std_on_mlp_policy=1.0
python mbmf.py --run_num=0 --which_agent=4 --std_on_mlp_policy=0.5
python mbmf.py --run_num=3 --which_agent=6 --std_on_mlp_policy=1.0 
python mbmf.py --run_num=2 --which_agent=1 --std_on_mlp_policy=0.5
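
Because --which_agent takes a numeric code, it is easy to pair a run_num with the wrong agent. The sketch below derives the code from the yaml file name used in the MB step. It assumes the yaml file name starts with the agent name, as in the examples in this README; the names AGENT_IDS, which_agent_for, and mbmf_command are hypothetical.

```python
# Agent codes as listed for --which_agent.
AGENT_IDS = {"ant": 1, "swimmer": 2, "cheetah": 4, "hopper": 6}


def which_agent_for(yaml_file):
    """Map a yaml file name such as 'swimmer_forward' to its agent code.

    Assumes the part before the first underscore is the agent name.
    """
    agent = yaml_file.split("_", 1)[0]
    if agent not in AGENT_IDS:
        raise ValueError(f"unknown agent prefix {agent!r} in {yaml_file!r}")
    return AGENT_IDS[agent]


def mbmf_command(run_num, yaml_file, std_on_mlp_policy):
    """Build the mbmf.py command matching an earlier main.py run."""
    return ["python", "mbmf.py", f"--run_num={run_num}",
            f"--which_agent={which_agent_for(yaml_file)}",
            f"--std_on_mlp_policy={std_on_mlp_policy}"]


if __name__ == "__main__":
    print(" ".join(mbmf_command(1, "swimmer_forward", 1.0)))
```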

How to run MF

Run pure TRPO, for comparisons.

Need to specify command-line args as desired:
    --seed Random seed for numpy and tensorflow
    --steps_per_rollout Length of each rollout that TRPO should collect
    --save_trpo_run_num Number used as part of directory name for saving TRPO run (you can use 1,2,3,etc to differentiate your different seeds)
    --which_agent Specify which agent (1 ant, 2 swimmer, 4 cheetah, 6 hopper)
    --num_workers_trpo How many worker threads (cpu) for TRPO to use
    --num_trpo_iters Total number of TRPO iterations to run before stopping

Note that the finished TRPO run saves to ~/rllab/data/local/experiments/

Examples:
python trpo_run_mf.py --seed=0 --save_trpo_run_num=1 --which_agent=4 --num_workers_trpo=4
python trpo_run_mf.py --seed=0 --save_trpo_run_num=1 --which_agent=2 --num_workers_trpo=4
python trpo_run_mf.py --seed=0 --save_trpo_run_num=1 --which_agent=1 --num_workers_trpo=4
python trpo_run_mf.py --seed=0 --save_trpo_run_num=1 --which_agent=6 --num_workers_trpo=4

python trpo_run_mf.py --seed=50 --save_trpo_run_num=2 --which_agent=4 --num_workers_trpo=4
python trpo_run_mf.py --seed=50 --save_trpo_run_num=2 --which_agent=2 --num_workers_trpo=4
python trpo_run_mf.py --seed=50 --save_trpo_run_num=2 --which_agent=1 --num_workers_trpo=4
python trpo_run_mf.py --seed=50 --save_trpo_run_num=2 --which_agent=6 --num_workers_trpo=4
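
The examples above form a grid of (seed, save_trpo_run_num) pairs crossed with the four agents. The following sketch generates that grid; the helper mf_commands is hypothetical, and pairing seed 0 with run 1 and seed 50 with run 2 simply follows the examples.

```python
import itertools


def mf_commands(seeds, save_trpo_run_nums, agents, num_workers_trpo=4):
    """Build trpo_run_mf.py commands for every (seed, run) pair and agent.

    Hypothetical helper. Seeds and run numbers are paired positionally,
    so both sequences must have the same length.
    """
    if len(seeds) != len(save_trpo_run_nums):
        raise ValueError("seeds and save_trpo_run_nums must have equal length")
    pairs = zip(seeds, save_trpo_run_nums)
    return [
        ["python", "trpo_run_mf.py", f"--seed={seed}",
         f"--save_trpo_run_num={run}", f"--which_agent={agent}",
         f"--num_workers_trpo={num_workers_trpo}"]
        for (seed, run), agent in itertools.product(pairs, agents)
    ]


if __name__ == "__main__":
    # Reproduces the eight example commands in the same order.
    for cmd in mf_commands([0, 50], [1, 2], [4, 2, 1, 6]):
        print(" ".join(cmd))
```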

How to plot

  1. MBMF
        -Need to specify the command-line arguments as desired (in plot_mbmf.py)
        -Examples of running the plotting script:
cd plotting
python plot_mbmf.py --trpo_dir=[trpo_dir] --std_on_mlp_policy=1.0 --which_agent=2 --run_nums 1 --seeds 0
python plot_mbmf.py --trpo_dir=[trpo_dir] --std_on_mlp_policy=1.0 --which_agent=2 --run_nums 1 2 3 --seeds 0 70 100

Note that [trpo_dir] above is the directory where the TRPO runs are saved, probably somewhere in ~/rllab/data/...

  2. Dynamics model training and validation losses per aggregation iteration
    IPython notebook: plotting/plot_loss.ipynb
    Example plots: docs/sample_plots/...

  3. Visualize a forward simulation (an open-loop multi-step prediction of the elements of the state space)
    IPython notebook: plotting/plot_forwardsim.ipynb
    Example plots: docs/sample_plots/...

  4. Visualize the trajectories (on-policy rollouts) per aggregation iteration
    IPython notebook: plotting/plot_trajfollow.ipynb
    Example plots: docs/sample_plots/...
