Simplify Reinforcement Learning Agent Creation #3648

Closed
tareknaser opened this issue Mar 2, 2024 · 5 comments

Comments

@tareknaser
Member

What is the desired addition or change?

Simplify the creation of reinforcement learning agents in mlpack by having default values for common parameters, including network architectures and learning rates.

Defining a TD3 agent

Current Approach

// Set up the replay method.
RandomReplay<Pendulum> replayMethod(32, 10000);
  
// Set up the TD3 training configuration.
TrainingConfig config;
config.StepSize() = 0.01;
config.TargetNetworkSyncInterval() = 1;
config.UpdateInterval() = 3;
config.Rho() = 0.001;

// Set up Actor network.
FFN<EmptyLoss, GaussianInitialization>
  policyNetwork(EmptyLoss(), GaussianInitialization(0, 0.1));
policyNetwork.Add(new Linear(128));
policyNetwork.Add(new ReLU());
policyNetwork.Add(new Linear(1));
policyNetwork.Add(new TanH());

// Set up Critic network.
FFN<EmptyLoss, GaussianInitialization>
  qNetwork(EmptyLoss(), GaussianInitialization(0, 0.1));
qNetwork.Add(new Linear(128));
qNetwork.Add(new ReLU());
qNetwork.Add(new Linear(1));

// Set up Twin Delayed Deep Deterministic policy gradient agent.
TD3<Pendulum, decltype(qNetwork), decltype(policyNetwork), AdamUpdate>
  agent(config, qNetwork, policyNetwork, replayMethod);

Proposed Approach

TD3<Pendulum> agent(replayMethod);

What is the motivation for this feature?

The current approach of manually configuring all agent parameters requires extra steps from users who want to quickly set up a basic reinforcement learning agent. Default constructors would simplify agent creation.

If applicable, describe how this feature would be implemented.

  • Implement constructors that accept only the essential parameters (environment, replay method) and use pre-defined defaults for the rest, as sketched below.
  • Keep the current constructors so that users who need more control can still override the defaults through explicit arguments.
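
For concreteness, the first point could look roughly like the interface sketch below. This is only an illustration under the assumption that the defaults mirror the values shown above; the defaulted template parameters and the TD3(ReplayType&) convenience constructor are proposed additions rather than the current mlpack interface, and the existing constructor signature is abridged.

// Sketch of a possible td3.hpp interface (hypothetical; existing
// constructor signature abridged).
template<typename EnvironmentType,
         typename QNetworkType = FFN<EmptyLoss, GaussianInitialization>,
         typename PolicyNetworkType = FFN<EmptyLoss, GaussianInitialization>,
         typename UpdaterType = AdamUpdate,
         typename ReplayType = RandomReplay<EnvironmentType>>
class TD3
{
 public:
  // Fully parameterized constructor: the caller supplies every component
  // explicitly, as today.
  TD3(TrainingConfig& config,
      QNetworkType& learningQ1Network,
      PolicyNetworkType& policyNetwork,
      ReplayType& replayMethod);

  // Proposed convenience constructor: only the replay method is required;
  // the configuration and networks fall back to the defaults above
  // (step size 0.01, 128-unit hidden layers, and so on).
  explicit TD3(ReplayType& replayMethod);

  // ... rest of the existing interface unchanged ...
};

With the defaulted template parameters in place, the proposed one-liner TD3<Pendulum> agent(replayMethod); would then compile without spelling out the network or updater types.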
@shapy051002

(This is my first open-source contribution attempt; I apologise if I have broken any protocol, and I am open to suggestions for improvement.)

I have tried to implement the template as:

template <
  typename EnvironmentType,
  typename QNetworkType = decltype(GetDefaultQNetwork()),
  typename PolicyNetworkType = decltype(GetDefaultPolicyNetwork()),
  typename UpdaterType = AdamUpdate,
  typename ReplayType = RandomReplay<EnvironmentType>
>
and overloaded the TD3(ReplayType& replayMethod); constructor in td3.hpp.

In td3_impl.hpp, I have expanded the constructor:
template <
  typename EnvironmentType,
  typename QNetworkType,
  typename PolicyNetworkType,
  typename UpdaterType,
  typename ReplayType
>
TD3<
  EnvironmentType,
  QNetworkType,
  PolicyNetworkType,
  UpdaterType,
  ReplayType
>::TD3(ReplayType& replayMethod) :
  config(GetDefaultConfig()),
  learningQ1Network(GetDefaultQNetwork()),
  policyNetwork(GetDefaultPolicy()),
  replayMethod(replayMethod),
  // rest of the member initializer list and constructor body remain the same

In training_config.hpp, I have implemented the GetDefaultConfig() method:
TrainingConfig GetDefaultConfig()
{
  TrainingConfig config;
  config.StepSize() = 0.01;
  config.TargetNetworkSyncInterval() = 1;
  config.UpdateInterval() = 3;
  config.Rho() = 0.001;
  return config;
}
And similarly for the other methods:
FFN<EmptyLoss, GaussianInitialization> GetDefaultQNetwork()
{
  FFN<EmptyLoss, GaussianInitialization>
      qNetwork(EmptyLoss(), GaussianInitialization(0, 0.1));
  qNetwork.Add(new Linear(128));
  qNetwork.Add(new ReLU());
  qNetwork.Add(new Linear(1));
  return qNetwork;
}

FFN<EmptyLoss, GaussianInitialization> GetDefaultPolicy()
{
  FFN<EmptyLoss, GaussianInitialization>
      policyNetwork(EmptyLoss(), GaussianInitialization(0, 0.1));
  policyNetwork.Add(new Linear(128));
  policyNetwork.Add(new ReLU());
  policyNetwork.Add(new Linear(1));
  policyNetwork.Add(new TanH());
  return policyNetwork;
}

I have currently made these changes locally in VS Code; however, I am not sure whether there are syntax errors, since I don't know how to compile the entire mlpack library with my modifications.

Is there a way I can give you the specific .hpp files where I have made the changes, or should I open a pull request so that my differences with respect to the original code are visible?

@tareknaser
Member Author

I have currently made these changes locally in VS Code; however, I am not sure whether there are syntax errors, since I don't know how to compile the entire mlpack library with my modifications.

I'd suggest you build the mlpack main branch first. If you face any problems, feel free to open an issue. After you build mlpack, you should be able to follow the RL tutorials.
Then, add your modifications in a separate branch and build mlpack again to make sure the changes work as expected.

Is there a way I can give you the specific .hpp files where I have made the changes, or should I open a pull request so that my differences with respect to the original code are visible?

You should open a PR.

@5advaith

This could also be done with a wrapper function or class: a wrapper that creates a TD3 agent with default parameters. Wouldn't that be better? It is safer than modifying the TD3 class itself, and it would be simpler and more elegant.
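
For reference, such a wrapper might look something like the sketch below. The name DefaultTD3Agent and its layout are assumptions for illustration only, not existing mlpack code; the default values are the ones from the issue description, and the wrapper owns the configuration and networks it builds so their lifetime matches the agent's.

#include <mlpack.hpp>  // Assumes the single mlpack header pulls in the RL code.

using namespace mlpack;

// Hypothetical wrapper class (sketch only, not existing mlpack API).  It
// builds the default configuration and networks, keeps them alive as
// members, and constructs the TD3 agent from them.
template<typename EnvironmentType>
class DefaultTD3Agent
{
 public:
  using NetworkType = FFN<EmptyLoss, GaussianInitialization>;
  using AgentType = TD3<EnvironmentType, NetworkType, NetworkType, AdamUpdate>;

  explicit DefaultTD3Agent(RandomReplay<EnvironmentType>& replayMethod) :
      config(MakeConfig()),
      qNetwork(MakeNetwork(false)),
      policyNetwork(MakeNetwork(true)),
      agent(config, qNetwork, policyNetwork, replayMethod)
  { }

  AgentType& Agent() { return agent; }

 private:
  static TrainingConfig MakeConfig()
  {
    TrainingConfig config;
    config.StepSize() = 0.01;
    config.TargetNetworkSyncInterval() = 1;
    config.UpdateInterval() = 3;
    config.Rho() = 0.001;
    return config;
  }

  static NetworkType MakeNetwork(const bool addTanH)
  {
    NetworkType network(EmptyLoss(), GaussianInitialization(0, 0.1));
    network.Add(new Linear(128));
    network.Add(new ReLU());
    network.Add(new Linear(1));
    if (addTanH)
      network.Add(new TanH());
    return network;
  }

  // Declared (and therefore initialized) before the agent, which is
  // constructed from them last.
  TrainingConfig config;
  NetworkType qNetwork;
  NetworkType policyNetwork;
  AgentType agent;
};

Usage would then mirror the proposed approach from the issue description:

RandomReplay<Pendulum> replayMethod(32, 10000);
DefaultTD3Agent<Pendulum> td3(replayMethod);
td3.Agent().Episode();  // Train for one episode, as with a hand-built agent.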

@tareknaser
Member Author

Hey @5advaith
I am not familiar with this approach, but could you open a PR where you implement it for just one reinforcement learning agent of your choice as a proof of concept?

cc: @zoq


mlpack-bot commented May 1, 2024

This issue has been automatically marked as stale because it has not had any recent activity. It will be closed in 7 days if no further activity occurs. Thank you for your contributions! 👍

@mlpack-bot added the s: stale label on May 1, 2024
@mlpack-bot closed this as completed on May 8, 2024