New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Explicit fixed and variable parameters for batchrunner #374
Changes from 11 commits
d231ac9
a39ed63
11daf74
9f16ecc
a896511
998ff41
547d676
8fa2951
f653738
82ad522
52c4331
f00b08e
b1e9ec0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,11 +6,24 @@ | |
A single class to manage a batch run or parameter sweep of a given model. | ||
|
||
""" | ||
from itertools import product | ||
import collections | ||
import copy | ||
from itertools import product, count | ||
import pandas as pd | ||
from tqdm import tqdm | ||
|
||
|
||
class VariableParameterError(TypeError): | ||
MESSAGE = ('variable_parameters must map a name to a sequence of values. ' | ||
'These parameters were given with non-sequence values: {}') | ||
|
||
def __init__(self, bad_names): | ||
self.bad_names = bad_names | ||
|
||
def __str__(self): | ||
return self.MESSAGE.format(self.bad_names) | ||
|
||
|
||
class BatchRunner: | ||
""" This class is instantiated with a model class, and model parameters | ||
associated with one or more values. It is also instantiated with model and | ||
|
@@ -23,19 +36,26 @@ class BatchRunner: | |
entire DataCollector object. | ||
|
||
""" | ||
def __init__(self, model_cls, parameter_values, iterations=1, | ||
max_steps=1000, model_reporters=None, agent_reporters=None, | ||
display_progress=True): | ||
def __init__(self, model_cls, variable_parameters=None, | ||
fixed_parameters=None, iterations=1, max_steps=1000, | ||
model_reporters=None, agent_reporters=None, display_progress=True): | ||
""" Create a new BatchRunner for a given model with the given | ||
parameters. | ||
|
||
Args: | ||
model_cls: The class of model to batch-run. | ||
parameter_values: Dictionary of parameters to their values or | ||
ranges of values. For example: | ||
variable_parameters: Dictionary of parameters to lists of values. | ||
The model will be run with every combination of these parameters. | ||
For example, given variable_parameters of | ||
{"param_1": range(5), | ||
"param_2": [1, 5, 10], | ||
"const_param": 100} | ||
"param_2": [1, 5, 10]} | ||
models will be run with {param_1=0, param_2=1}, | ||
{param_1=1, param_2=1}, ..., {param_1=4, param_2=10}. | ||
fixed_parameters: Dictionary of parameters that stay the same through | ||
all batch runs. For example, given fixed_parameters of | ||
{"constant_parameter": 3}, | ||
every instantiated model will be passed constant_parameter=3 | ||
as a kwarg. | ||
iterations: The total number of times to run the model for each | ||
combination of parameters. | ||
max_steps: The upper limit of steps above which each run will be halted | ||
|
@@ -51,8 +71,8 @@ def __init__(self, model_cls, parameter_values, iterations=1, | |
|
||
""" | ||
self.model_cls = model_cls | ||
self.parameter_values = {param: self.make_iterable(vals) | ||
for param, vals in parameter_values.items()} | ||
self.variable_parameters = self._process_parameters(variable_parameters) | ||
self.fixed_parameters = fixed_parameters or {} | ||
self.iterations = iterations | ||
self.max_steps = max_steps | ||
|
||
|
@@ -67,36 +87,42 @@ def __init__(self, model_cls, parameter_values, iterations=1, | |
|
||
self.display_progress = display_progress | ||
|
||
def _process_parameters(self, params): | ||
params = copy.deepcopy(params) | ||
bad_names = [] | ||
for name, values in params.items(): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This type check could be expressed with PEP 484 type hints, which would be much simpler. The minimum supported Python version would have to be bumped to 3.5. |
||
if (isinstance(values, str) or | ||
not isinstance(values, collections.Sequence)): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This fails with numpy sequences (e.g. in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I found an unresolved issue, but with a patch here numpy/numpy#2776 (comment). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's also resolved with |
||
bad_names.append(name) | ||
if bad_names: | ||
raise VariableParameterError(bad_names) | ||
return params | ||
|
||
def run_all(self): | ||
""" Run the model at all parameter combinations and store results. """ | ||
params = self.parameter_values.keys() | ||
param_ranges = self.parameter_values.values() | ||
run_count = 0 | ||
|
||
if self.display_progress: | ||
pbar = tqdm(total=len(list(product(*param_ranges))) * self.iterations) | ||
|
||
for param_values in list(product(*param_ranges)): | ||
kwargs = dict(zip(params, param_values)) | ||
for _ in range(self.iterations): | ||
param_names, param_ranges = zip(*self.variable_parameters.items()) | ||
run_count = count() | ||
total_iterations = self.iterations | ||
for param_range in param_ranges: | ||
total_iterations *= len(param_range) | ||
with tqdm(total_iterations, disable=not self.display_progress) as pbar: | ||
for param_values in product(*param_ranges): | ||
kwargs = dict(zip(param_names, param_values)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is basically There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, Cartesian |
||
kwargs.update(self.fixed_parameters) | ||
model = self.model_cls(**kwargs) | ||
self.run_model(model) | ||
# Collect and store results: | ||
if self.model_reporters: | ||
key = tuple(list(param_values) + [run_count]) | ||
self.model_vars[key] = self.collect_model_vars(model) | ||
if self.agent_reporters: | ||
agent_vars = self.collect_agent_vars(model) | ||
for agent_id, reports in agent_vars.items(): | ||
key = tuple(list(param_values) + [run_count, agent_id]) | ||
self.agent_vars[key] = reports | ||
if self.display_progress: | ||
pbar.update() | ||
|
||
run_count += 1 | ||
|
||
if self.display_progress: | ||
pbar.close() | ||
for _ in range(self.iterations): | ||
self.run_model(model) | ||
# Collect and store results: | ||
model_key = param_values + (next(run_count),) | ||
if self.model_reporters: | ||
self.model_vars[model_key] = self.collect_model_vars(model) | ||
if self.agent_reporters: | ||
agent_vars = self.collect_agent_vars(model) | ||
for agent_id, reports in agent_vars.items(): | ||
agent_key = model_key + (agent_id,) | ||
self.agent_vars[agent_key] = reports | ||
pbar.update() | ||
|
||
def run_model(self, model): | ||
""" Run a model object to completion, or until reaching max steps. | ||
|
@@ -126,38 +152,36 @@ def collect_agent_vars(self, model): | |
return agent_vars | ||
|
||
def get_model_vars_dataframe(self): | ||
""" Generate a pandas DataFrame from the model-level variables collected. | ||
""" Generate a pandas DataFrame from the model-level variables | ||
collected. | ||
|
||
""" | ||
index_col_names = list(self.parameter_values.keys()) | ||
index_col_names.append("Run") | ||
records = [] | ||
for key, val in self.model_vars.items(): | ||
record = dict(zip(index_col_names, key)) | ||
for k, v in val.items(): | ||
record[k] = v | ||
records.append(record) | ||
return pd.DataFrame(records) | ||
return self._prepare_report_table(self.model_vars) | ||
|
||
def get_agent_vars_dataframe(self): | ||
""" Generate a pandas DataFrame from the agent-level variables | ||
collected. | ||
|
||
""" | ||
index_col_names = list(self.parameter_values.keys()) | ||
index_col_names += ["Run", "AgentID"] | ||
return self._prepare_report_table(self.agent_vars, | ||
extra_cols=['AgentId']) | ||
|
||
def _prepare_report_table(self, vars_dict, extra_cols=None): | ||
""" | ||
Creates a dataframe from collected records and sorts it using 'Run' | ||
column as a key. | ||
""" | ||
extra_cols = ['Run'] + (extra_cols or []) | ||
index_cols = list(self.variable_parameters.keys()) + extra_cols | ||
|
||
records = [] | ||
for key, val in self.agent_vars.items(): | ||
record = dict(zip(index_col_names, key)) | ||
for k, v in val.items(): | ||
record[k] = v | ||
for param_key, values in vars_dict.items(): | ||
record = dict(zip(index_cols, param_key)) | ||
record.update(values) | ||
records.append(record) | ||
return pd.DataFrame(records) | ||
|
||
@staticmethod | ||
def make_iterable(val): | ||
""" Helper method to ensure a value is a non-string iterable. """ | ||
if hasattr(val, "__iter__") and not isinstance(val, str): | ||
return val | ||
else: | ||
return [val] | ||
|
||
df = pd.DataFrame(records) | ||
rest_cols = set(df.columns) - set(index_cols) | ||
ordered = df[index_cols + list(sorted(rest_cols))] | ||
ordered.sort_values(by='Run', inplace=True) | ||
return ordered |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This reads like `fixed_parameters || {}` ;) -- `{}` should have been the default value for fixed_parameters, instead of `None`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, this is absolutely a no-no in Python. Never use mutable objects as default parameters - they're shared between every invocation of the function, so if they're modified, terrible bugs can creep in. The way this is done here is quite standard.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TIL
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is pylint W0102 (dangerous-default-value) [1].
[1] http://pylint-messages.wikidot.com/messages:w0102