Skip to content

Commit

Permalink
Cleaning up variable/fixed parameters handling
Browse files Browse the repository at this point in the history
  • Loading branch information
njvrzm committed May 23, 2017
1 parent 82ad522 commit 52c4331
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 105 deletions.
154 changes: 69 additions & 85 deletions mesa/batchrunner.py
Expand Up @@ -6,13 +6,24 @@
A single class to manage a batch run or parameter sweep of a given model.
"""
from itertools import product
import collections
import copy

from itertools import product, count
import pandas as pd
from tqdm import tqdm


class VariableParameterError(TypeError):
MESSAGE = ('variable_parameters must map a name to a sequence of values. '
'These parameters were given with non-sequence values: {}')

def __init__(self, bad_names):
self.bad_names = bad_names

def __str__(self):
return self.MESSAGE.format(self.bad_names)


class BatchRunner:
""" This class is instantiated with a model class, and model parameters
associated with one or more values. It is also instantiated with model and
Expand All @@ -25,21 +36,26 @@ class BatchRunner:
entire DataCollector object.
"""
def __init__(self, model_cls, variable_parameters, fixed_parameters=None,
iterations=1, max_steps=1000, model_reporters=None,
agent_reporters=None, display_progress=True):
def __init__(self, model_cls, variable_parameters=None,
fixed_parameters=None, iterations=1, max_steps=1000,
model_reporters=None, agent_reporters=None, display_progress=True):
""" Create a new BatchRunner for a given model with the given
parameters.
Args:
model_cls: The class of model to batch-run.
variable_parameters: Dictionary of parameters to their values or
ranges of values. For example:
variable_parameters: Dictionary of parameters to lists of values.
The model will be run with every combination of these paramters.
For example, given variable_parameters of
{"param_1": range(5),
"param_2": [1, 5, 10],
"const_param": 100}
"param_2": [1, 5, 10]}
models will be run with {param_1=1, param_2=1},
{param_1=2, param_2=1}, ..., {param_1=4, param_2=10}.
fixed_parameters: Dictionary of parameters that stay same through
all batch runs.
all batch runs. For example, given fixed_parameters of
{"constant_parameter": 3},
every instantiated model will be passed constant_parameter=3
as a kwarg.
iterations: The total number of times to run the model for each
combination of parameters.
max_steps: The upper limit of steps above which each run will be halted
Expand All @@ -55,9 +71,8 @@ def __init__(self, model_cls, variable_parameters, fixed_parameters=None,
"""
self.model_cls = model_cls
self.parameter_values = {param: self.make_iterable(vals)
for param, vals in variable_parameters.items()}
self.fixed_values = fixed_parameters or {}
self.variable_parameters = self._process_parameters(variable_parameters)
self.fixed_parameters = fixed_parameters or {}
self.iterations = iterations
self.max_steps = max_steps

Expand All @@ -72,38 +87,43 @@ def __init__(self, model_cls, variable_parameters, fixed_parameters=None,

self.display_progress = display_progress

def _process_parameters(self, params):
params = copy.deepcopy(params)
bad_names = []
for name, values in params.items():
if (isinstance(values, str) or
not isinstance(values, collections.Sequence)):
bad_names.append(name)
if bad_names:
raise VariableParameterError(bad_names)
return params

def run_all(self):
""" Run the model at all parameter combinations and store results. """
params = self.parameter_values.keys()
param_ranges = self.parameter_values.values()
run_count = 0
if self.display_progress:
pbar = tqdm(total=len(list(product(*param_ranges))) * self.iterations)

for param_values in list(product(*param_ranges)):
kwargs = dict(zip(params, param_values))
model = self._try_to_init_model(kwargs)

for _ in range(self.iterations):
self.run_model(model)
# Collect and store results:
if self.model_reporters:
key = tuple(list(param_values) + [run_count])
self.model_vars[key] = self.collect_model_vars(model)
if self.agent_reporters:
agent_vars = self.collect_agent_vars(model)
for agent_id, reports in agent_vars.items():
key = tuple(
list(param_values) + [run_count, agent_id])
self.agent_vars[key] = reports
if self.display_progress:
param_names, param_ranges = zip(*self.variable_parameters.items())
run_count = count()
total_iterations = self.iterations
for param_range in param_ranges:
total_iterations *= len(param_range)
with tqdm(total_iterations, disable=not self.display_progress) as pbar:
for param_values in product(*param_ranges):
kwargs = dict(zip(param_names, param_values))
kwargs.update(self.fixed_parameters)
model = self.model_cls(**kwargs)

for _ in range(self.iterations):
self.run_model(model)
# Collect and store results:
model_key = param_values + (next(run_count),)
if self.model_reporters:
self.model_vars[model_key] = self.collect_model_vars(model)
if self.agent_reporters:
agent_vars = self.collect_agent_vars(model)
for agent_id, reports in agent_vars.items():
agent_key = model_key + (agent_id,)
self.agent_vars[agent_key] = reports
pbar.update()

run_count += 1

if self.display_progress:
pbar.close()

def run_model(self, model):
""" Run a model object to completion, or until reaching max steps.
Expand Down Expand Up @@ -143,61 +163,25 @@ def get_agent_vars_dataframe(self):
collected.
"""
return self._prepare_report_table(self.agent_vars)
return self._prepare_report_table(self.agent_vars,
extra_cols=['AgentId'])

def _prepare_report_table(self, vars_dict):
def _prepare_report_table(self, vars_dict, extra_cols=None):
"""
Creates a dataframe from collected records and sorts it using 'Run'
column as a key.
"""
index_cols = list(self.parameter_values.keys()) + ['Run']
extra_cols = ['Run'] + (extra_cols or [])
index_cols = list(self.variable_parameters.keys()) + extra_cols

records = []
for k, v in vars_dict.items():
record = dict(zip(index_cols, k))
record.update(v)
for param_key, values in vars_dict.items():
record = dict(zip(index_cols, param_key))
record.update(values)
records.append(record)

df = pd.DataFrame(records)
rest_cols = set(df.columns) - set(index_cols)
ordered = df[index_cols + list(sorted(rest_cols))]
ordered.sort_values(by='Run', inplace=True)
return ordered

@staticmethod
def make_iterable(val):
""" Helper method to ensure a value is a non-string iterable. """
if hasattr(val, "__iter__") and not isinstance(val, str):
return val
else:
return [val]

def _try_to_init_model(self, variable_params):
"""
Attempts to instantiate a model with specific variable parameters set
and additional fixed parameters if any.
Args:
variable_params: A mapping of a specific set of variable parameters.
"""
if not self.fixed_values:
return self.model_cls(**variable_params)

try:
kv = copy.deepcopy(variable_params)
kv.update(self.fixed_values)
return self.model_cls(**kv)

except TypeError:
import inspect
sig = inspect.signature(self.model_cls.__init__)
last_arg = list(sig.parameters.values())[-1]
valid_types = (last_arg.POSITIONAL_OR_KEYWORD,
last_arg.VAR_POSITIONAL)
if last_arg.kind in valid_types:
variable_params[last_arg.name] = self.fixed_values
return self.model_cls(**variable_params)

msg = ('Cannot configure model with variable '
'params {} and fixed params {}')
raise ValueError(msg.format(variable_params, self.fixed_values))
41 changes: 21 additions & 20 deletions tests/test_batchrunner.py
Expand Up @@ -49,12 +49,12 @@ def step(self):
self.schedule.step()


class MockDictionaryModel(Model):
class MockMixedModel(Model):

def __init__(self, variable_param, fixed_params):
def __init__(self, **other_params):
super().__init__()
self.variable_param = variable_param
self.fixed_name = fixed_params.get('fixed_name', None)
self.variable_name = other_params.get('variable_name', 42)
self.fixed_name = other_params.get('fixed_name')
self.running = True
self.schedule = BaseScheduler(None)
self.schedule.add(MockAgent(1, self, 0))
Expand Down Expand Up @@ -115,7 +115,7 @@ def test_model_level_vars(self):
len(self.model_reporters) +
1) # extra column with run index

assert model_vars.shape == (self.model_runs, expected_cols)
self.assertEqual(model_vars.shape, (self.model_runs, expected_cols))

def test_agent_level_vars(self):
"""
Expand All @@ -125,9 +125,10 @@ def test_agent_level_vars(self):
agent_vars = batch.get_agent_vars_dataframe()
expected_cols = (len(self.variable_params) +
len(self.agent_reporters) +
1) # extra column with run index
2) # extra columns with run index and agentId

assert agent_vars.shape == (self.model_runs * NUM_AGENTS, expected_cols)
self.assertEqual(agent_vars.shape,
(self.model_runs * NUM_AGENTS, expected_cols))

def test_model_with_fixed_parameters_as_kwargs(self):
"""
Expand All @@ -139,23 +140,23 @@ def test_model_with_fixed_parameters_as_kwargs(self):
model_vars = batch.get_model_vars_dataframe()
agent_vars = batch.get_agent_vars_dataframe()

assert len(model_vars) == len(agent_vars)
assert len(model_vars) == self.model_runs
assert model_vars['reported_fixed_value'].unique() == ['Fixed']

def test_model_with_fixed_parameters_as_dict(self):
self.mock_model = MockDictionaryModel
self.model_reporters = {'reported_fixed_param': lambda m: m.fixed_name}
self.agent_reporters = {}
self.fixed_params = {'fixed_name': 'DictModel'}
self.variable_params = {'variable_param': [1, 2, 3]}
self.assertEqual(len(model_vars), len(agent_vars))
self.assertEqual(len(model_vars), self.model_runs)
self.assertEqual(model_vars['reported_fixed_value'].unique(), ['Fixed'])

def test_model_with_variable_and_fixed_kwargs(self):
self.mock_model = MockMixedModel
self.model_reporters = {
'reported_fixed_param': lambda m: m.fixed_name,
'reported_variable_param': lambda m: m.variable_name
}
self.fixed_params = {'fixed_name': 'Fixed'}
self.variable_params = {'variable_name': [1, 2, 3]}
batch = self.launch_batch_processing()
model_vars = batch.get_model_vars_dataframe()
expected_cols = (len(self.variable_params) +
len(self.model_reporters) +
1)

assert model_vars.shape == (self.model_runs, expected_cols)
assert (model_vars['reported_fixed_param'].iloc[0] ==
self.assertEqual(model_vars.shape, (self.model_runs, expected_cols))
self.assertEqual(model_vars['reported_fixed_param'].iloc[0],
self.fixed_params['fixed_name'])

0 comments on commit 52c4331

Please sign in to comment.