/
RL_and_SL_demo.py
317 lines (253 loc) · 11.8 KB
/
RL_and_SL_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
""" Demo where we add the same regularization loss from the other examples, but
this time as an `AuxiliaryTask` on top of the BaseMethod.
This makes it easy to create CL methods that apply to both RL and SL Settings!
"""
import copy
import random
import sys
from argparse import Namespace
from dataclasses import dataclass
from typing import ClassVar, Dict, List, Optional

import torch
from simple_parsing import ArgumentParser, field
from torch import Tensor

# This "hack" is required so we can run `python examples/custom_baseline_demo.py`
sys.path.extend([".", ".."])

from sequoia.common.config import Config
from sequoia.common.loss import Loss
from sequoia.methods import BaseMethod
from sequoia.methods.aux_tasks import AuxiliaryTask
from sequoia.methods.models import BaseModel, ForwardPass
from sequoia.methods.trainer import TrainerConfig
from sequoia.settings import Environment, RLSetting, Setting
from sequoia.utils.logging_utils import get_logger
from sequoia.utils.utils import camel_case, dict_intersection
logger = get_logger(__name__)
class SimpleRegularizationAuxTask(AuxiliaryTask):
    """Same regularization loss as in the previous examples, this time
    implemented as an `AuxiliaryTask`, which gets added to the BaseModel,
    making it applicable to both RL and SL.

    This adds a CL regularization loss to the BaseModel.

    The most important method of `AuxiliaryTask` is `get_loss`, which should
    return a `Loss` for the given forward pass and resulting rewards/labels.
    Take a look at the `AuxiliaryTask` class for more info.
    """

    name: ClassVar[str] = "simple_regularization"

    @dataclass
    class Options(AuxiliaryTask.Options):
        """Hyper-parameters / configuration options of this auxiliary task."""

        # Coefficient used to scale this regularization loss before it gets
        # added to the 'base' loss of the model.
        coefficient: float = 0.01
        # Whether to use the absolute difference of the weights or the difference
        # in the `regularize` method below.
        use_abs_diff: bool = False
        # The norm term for the 'distance' between the current and old weights.
        distance_norm: int = 2

    def __init__(
        self,
        *args,
        name: Optional[str] = None,
        options: "SimpleRegularizationAuxTask.Options" = None,
        **kwargs,
    ):
        """Create the auxiliary task.

        Parameters
        ----------
        name : Optional[str]
            Name of the task; defaults to the class-level `name` when None.
        options : SimpleRegularizationAuxTask.Options
            Hyper-parameters of this task.

        Extra positional / keyword arguments are forwarded to `AuxiliaryTask`.
        """
        super().__init__(*args, options=options, name=name, **kwargs)
        self.options: SimpleRegularizationAuxTask.Options
        # ID of the task whose start-of-task weights are the current 'anchor'.
        self.previous_task: Optional[int] = None
        # Snapshot of the model parameters taken at the last task switch.
        # TODO: Figure out a clean way to persist this dict into the state_dict.
        self.previous_model_weights: Dict[str, Tensor] = {}
        # Number of task switches seen so far.
        self.n_switches: int = 0

    def get_loss(self, forward_pass: ForwardPass, y: Tensor = None) -> Loss:
        """Get a `Loss` for the given forward pass and resulting rewards/labels.

        Take a look at the `AuxiliaryTask` class for more info.

        NOTE: This is the same simplified version of EWC used throughout the
        other examples: the loss is the P-norm between the current weights and
        the weights as they were at the beginning of the task.

        Also note, this particular example doesn't actually use the provided
        arguments.
        """
        if self.previous_task is None:
            # We're in the first task: do nothing.
            return Loss(name=self.name)
        old_weights: Dict[str, Tensor] = self.previous_model_weights
        new_weights: Dict[str, Tensor] = dict(self.model.named_parameters())
        loss = 0.0
        # Only parameters present in both dicts contribute to the loss.
        for weight_name, (new_w, old_w) in dict_intersection(new_weights, old_weights):
            loss += torch.dist(new_w, old_w.type_as(new_w), p=self.options.distance_norm)
        ewc_loss = Loss(name=self.name, loss=loss)
        return ewc_loss

    def on_task_switch(self, task_id: int) -> None:
        """Executed when the task switches (to either a new or known task).

        Updates the 'anchor' weights whenever we actually move to a different
        task (and not on the very first task, where there is nothing to anchor to).
        """
        if not self.enabled:
            return
        if self.previous_task is None and self.n_switches == 0:
            logger.debug("Starting the first task, no update.")
        elif task_id is None or task_id != self.previous_task:
            logger.debug(
                f"Switching tasks: {self.previous_task} -> {task_id}: "
                f"Updating the 'anchor' weights."
            )
            self.previous_task = task_id
            self.previous_model_weights.clear()
            self.previous_model_weights.update(
                copy.deepcopy({k: v.detach() for k, v in self.model.named_parameters()})
            )
        self.n_switches += 1
class CustomizedBaselineModel(BaseModel):
    """BaseModel extended with the simple regularization auxiliary task."""

    @dataclass
    class HParams(BaseModel.HParams):
        """Hyper-parameters of our customized baseline model."""

        # Hyper-parameters of our simple new auxiliary task.
        simple_reg: SimpleRegularizationAuxTask.Options = field(
            default_factory=SimpleRegularizationAuxTask.Options
        )

    def __init__(
        self,
        setting: Setting,
        hparams: "CustomizedBaselineModel.HParams",
        config: Config,
    ):
        """Build the base model, then attach our extra auxiliary task."""
        super().__init__(setting=setting, hparams=hparams, config=config)
        self.hp: CustomizedBaselineModel.HParams
        # Or, add replay buffers of some sort:
        self.replay_buffer: List = []
        # (...)
        # Here we add our new auxiliary task:
        aux_task = SimpleRegularizationAuxTask(options=self.hp.simple_reg)
        self.add_auxiliary_task(aux_task)
@dataclass
class CustomMethod(BaseMethod, target_setting=Setting):
    """Example method which adds regularization to the baseline in RL and SL.

    This extends the `BaseMethod` by adding the simple regularization
    auxiliary task defined above to the `BaseModel`.

    NOTE: Since this class inherits from `BaseMethod`, which targets the
    `Setting` setting, i.e. the "root" node, it is applicable to all settings,
    both in RL and SL. However, you could customize the `target_setting`
    argument above to limit this to any particular subtree (only SL, only RL,
    only when task labels are present, etc).
    """

    # Hyper-parameters of the customized Baseline Model used by this method.
    hparams: CustomizedBaselineModel.HParams = field(
        default_factory=CustomizedBaselineModel.HParams
    )

    def __init__(
        self,
        hparams: CustomizedBaselineModel.HParams = None,
        config: Config = None,
        trainer_options: TrainerConfig = None,
        **kwargs,
    ):
        super().__init__(
            hparams=hparams,
            config=config,
            trainer_options=trainer_options,
            **kwargs,
        )

    def create_model(self, setting: Setting) -> CustomizedBaselineModel:
        """Creates the Model to be used for the given `Setting`."""
        return CustomizedBaselineModel(setting=setting, hparams=self.hparams, config=self.config)

    def configure(self, setting: Setting):
        """Configure this Method before being trained / tested on this Setting."""
        super().configure(setting)
        # For example, change the value of the coefficient of our
        # regularization loss when in RL vs SL:
        if isinstance(setting, RLSetting):
            self.hparams.simple_reg.coefficient = 0.01
        else:
            self.hparams.simple_reg.coefficient = 1.0

    def fit(self, train_env: Environment, valid_env: Environment):
        """Called by the Setting to let the Method train on a given task.

        You can do whatever you want with the train and valid
        environments. As it is currently, in most `Settings`, the valid
        environment will contain data from only the current task. (See issue at
        https://github.com/lebrice/Sequoia/issues/46 for more context).
        """
        return super().fit(train_env=train_env, valid_env=valid_env)

    @classmethod
    def add_argparse_args(cls, parser: ArgumentParser, dest: str = ""):
        """Adds command-line arguments for this Method to an argument parser.

        NOTE: This doesn't do anything differently than the base implementation,
        but it's included here just for illustration purposes.

        BUGFIX: accept the optional `dest` argument (as the base implementation
        does) so calls like `CustomMethod.add_argparse_args(parser, dest="method")`
        work; when omitted, it falls back to the previous default.
        """
        # 'dest' is where the arguments will be stored on the namespace.
        dest = dest or camel_case(cls.__qualname__)
        # Add all command-line arguments. This adds arguments for all fields of
        # this dataclass.
        parser.add_arguments(cls, dest=dest)
        # You could add arguments here if you wanted to:
        # parser.add_argument("--foo", default=1.23, help="example argument")

    @classmethod
    def from_argparse_args(cls, args: Namespace, dest: str = ""):
        """Create an instance of this class from the parsed arguments.

        BUGFIX: accept the optional `dest` argument so calls like
        `CustomMethod.from_argparse_args(args, dest="method")` work; when
        omitted, it falls back to the previous default.
        """
        # Retrieve the parsed arguments:
        dest = dest or camel_case(cls.__qualname__)
        method: CustomMethod = getattr(args, dest)
        # You could retrieve other arguments like so:
        # foo: int = args.foo
        return method
def demo_manual():
    """Apply the custom method to a Setting, creating both manually in code."""
    # Create any Setting from the tree:
    from sequoia.settings import TaskIncrementalRLSetting, TaskIncrementalSLSetting

    # setting = TaskIncrementalSLSetting(dataset="mnist", nb_tasks=5) # SL
    setting = TaskIncrementalRLSetting(  # RL
        dataset="cartpole",
        train_task_schedule={
            0: {"gravity": 10, "length": 0.5},
            5000: {"gravity": 10, "length": 1.0},
        },
        train_max_steps=10_000,
    )

    ## Create the BaseMethod:
    baseline_config = Config(debug=True)
    baseline_trainer_options = TrainerConfig(max_epochs=1)
    baseline_hparams = BaseModel.HParams()
    base_method = BaseMethod(
        hparams=baseline_hparams,
        config=baseline_config,
        trainer_options=baseline_trainer_options,
    )
    ## Get the results of the baseline method:
    base_results = setting.apply(base_method, config=baseline_config)

    ## Create the CustomMethod:
    custom_config = Config(debug=True)
    custom_trainer_options = TrainerConfig(max_epochs=1)
    custom_hparams = CustomizedBaselineModel.HParams()
    new_method = CustomMethod(
        hparams=custom_hparams,
        config=custom_config,
        trainer_options=custom_trainer_options,
    )
    ## Get the results for the 'improved' method:
    new_results = setting.apply(new_method, config=custom_config)

    # Show both summaries side by side for comparison.
    print("\n\nComparison: BaseMethod vs CustomMethod")
    print("\n BaseMethod results: ")
    print(base_results.summary())
    print("\n CustomMethod results: ")
    print(new_results.summary())
def demo_command_line():
    """Run the same demo as above, but customizing the Setting and Method from
    the command-line.

    NOTE: Remember to uncomment the function call below to use this instead of
    demo_simple!
    """
    ## Create the `Setting` and the `Config` from the command-line, like in
    ## the other examples.
    parser = ArgumentParser(description=__doc__)

    ## Add command-line arguments for any Setting in the tree:
    from sequoia.settings import TaskIncrementalRLSetting, TaskIncrementalSLSetting

    # parser.add_arguments(TaskIncrementalSLSetting, dest="setting")
    parser.add_arguments(TaskIncrementalRLSetting, dest="setting")
    parser.add_arguments(Config, dest="config")

    # Add the command-line arguments for our CustomMethod (including the
    # arguments for our simple regularization aux task).
    CustomMethod.add_argparse_args(parser, dest="method")

    args = parser.parse_args()

    # BUGFIX: the annotation previously referenced `ClassIncrementalSetting`,
    # which is never imported here; annotate with the type actually parsed above.
    setting: TaskIncrementalRLSetting = args.setting
    config: Config = args.config

    # Create the BaseMethod:
    base_method = BaseMethod.from_argparse_args(args, dest="method")
    # Get the results of the BaseMethod:
    base_results = setting.apply(base_method, config=config)

    ## Create the CustomMethod:
    new_method = CustomMethod.from_argparse_args(args, dest="method")
    # Get the results for the CustomMethod:
    new_results = setting.apply(new_method, config=config)

    print("\n\nComparison: BaseMethod vs CustomMethod:")
    print(base_results.summary())
    print(new_results.summary())
# Script entry point: runs the in-code demo by default; switch the comments
# to drive the demo from the command line instead.
if __name__ == "__main__":
    demo_manual()
    # demo_command_line()