diff --git a/mesa/batchrunner.py b/mesa/batchrunner.py index 2c55e901795..1f83de42f3c 100644 --- a/mesa/batchrunner.py +++ b/mesa/batchrunner.py @@ -6,13 +6,24 @@ A single class to manage a batch run or parameter sweep of a given model. """ -from itertools import product +import collections import copy - +from itertools import product, count import pandas as pd from tqdm import tqdm +class VariableParameterError(TypeError): + MESSAGE = ('variable_parameters must map a name to a sequence of values. ' + 'These parameters were given with non-sequence values: {}') + + def __init__(self, bad_names): + self.bad_names = bad_names + + def __str__(self): + return self.MESSAGE.format(self.bad_names) + + class BatchRunner: """ This class is instantiated with a model class, and model parameters associated with one or more values. It is also instantiated with model and @@ -25,21 +36,26 @@ class BatchRunner: entire DataCollector object. """ - def __init__(self, model_cls, variable_parameters, fixed_parameters=None, - iterations=1, max_steps=1000, model_reporters=None, - agent_reporters=None, display_progress=True): + def __init__(self, model_cls, variable_parameters=None, + fixed_parameters=None, iterations=1, max_steps=1000, + model_reporters=None, agent_reporters=None, display_progress=True): """ Create a new BatchRunner for a given model with the given parameters. Args: model_cls: The class of model to batch-run. - variable_parameters: Dictionary of parameters to their values or - ranges of values. For example: + variable_parameters: Dictionary of parameters to lists of values. + The model will be run with every combination of these parameters. + For example, given variable_parameters of {"param_1": range(5), - "param_2": [1, 5, 10], - "const_param": 100} + "param_2": [1, 5, 10]} + models will be run with {param_1=0, param_2=1}, + {param_1=1, param_2=1}, ..., {param_1=4, param_2=10}. fixed_parameters: Dictionary of parameters that stay same through - all batch runs. 
+ all batch runs. For example, given fixed_parameters of + {"constant_parameter": 3}, + every instantiated model will be passed constant_parameter=3 + as a kwarg. iterations: The total number of times to run the model for each combination of parameters. max_steps: The upper limit of steps above which each run will be halted @@ -55,9 +71,8 @@ def __init__(self, model_cls, variable_parameters, fixed_parameters=None, """ self.model_cls = model_cls - self.parameter_values = {param: self.make_iterable(vals) - for param, vals in variable_parameters.items()} - self.fixed_values = fixed_parameters or {} + self.variable_parameters = self._process_parameters(variable_parameters) + self.fixed_parameters = fixed_parameters or {} self.iterations = iterations self.max_steps = max_steps @@ -72,38 +87,43 @@ def __init__(self, model_cls, variable_parameters, fixed_parameters=None, self.display_progress = display_progress + def _process_parameters(self, params): + params = copy.deepcopy(params) + bad_names = [] + for name, values in params.items(): + if (isinstance(values, str) or + not isinstance(values, collections.Sequence)): + bad_names.append(name) + if bad_names: + raise VariableParameterError(bad_names) + return params + def run_all(self): """ Run the model at all parameter combinations and store results. 
""" - params = self.parameter_values.keys() - param_ranges = self.parameter_values.values() - run_count = 0 - if self.display_progress: - pbar = tqdm(total=len(list(product(*param_ranges))) * self.iterations) - - for param_values in list(product(*param_ranges)): - kwargs = dict(zip(params, param_values)) - model = self._try_to_init_model(kwargs) - - for _ in range(self.iterations): - self.run_model(model) - # Collect and store results: - if self.model_reporters: - key = tuple(list(param_values) + [run_count]) - self.model_vars[key] = self.collect_model_vars(model) - if self.agent_reporters: - agent_vars = self.collect_agent_vars(model) - for agent_id, reports in agent_vars.items(): - key = tuple( - list(param_values) + [run_count, agent_id]) - self.agent_vars[key] = reports - if self.display_progress: + param_names, param_ranges = zip(*self.variable_parameters.items()) + run_count = count() + total_iterations = self.iterations + for param_range in param_ranges: + total_iterations *= len(param_range) + with tqdm(total_iterations, disable=not self.display_progress) as pbar: + for param_values in product(*param_ranges): + kwargs = dict(zip(param_names, param_values)) + kwargs.update(self.fixed_parameters) + model = self.model_cls(**kwargs) + + for _ in range(self.iterations): + self.run_model(model) + # Collect and store results: + model_key = param_values + (next(run_count),) + if self.model_reporters: + self.model_vars[model_key] = self.collect_model_vars(model) + if self.agent_reporters: + agent_vars = self.collect_agent_vars(model) + for agent_id, reports in agent_vars.items(): + agent_key = model_key + (agent_id,) + self.agent_vars[agent_key] = reports pbar.update() - run_count += 1 - - if self.display_progress: - pbar.close() - def run_model(self, model): """ Run a model object to completion, or until reaching max steps. @@ -143,19 +163,21 @@ def get_agent_vars_dataframe(self): collected. 
""" - return self._prepare_report_table(self.agent_vars) + return self._prepare_report_table(self.agent_vars, + extra_cols=['AgentId']) - def _prepare_report_table(self, vars_dict): + def _prepare_report_table(self, vars_dict, extra_cols=None): """ Creates a dataframe from collected records and sorts it using 'Run' column as a key. """ - index_cols = list(self.parameter_values.keys()) + ['Run'] + extra_cols = ['Run'] + (extra_cols or []) + index_cols = list(self.variable_parameters.keys()) + extra_cols records = [] - for k, v in vars_dict.items(): - record = dict(zip(index_cols, k)) - record.update(v) + for param_key, values in vars_dict.items(): + record = dict(zip(index_cols, param_key)) + record.update(values) records.append(record) df = pd.DataFrame(records) @@ -163,41 +185,3 @@ def _prepare_report_table(self, vars_dict): ordered = df[index_cols + list(sorted(rest_cols))] ordered.sort_values(by='Run', inplace=True) return ordered - - @staticmethod - def make_iterable(val): - """ Helper method to ensure a value is a non-string iterable. """ - if hasattr(val, "__iter__") and not isinstance(val, str): - return val - else: - return [val] - - def _try_to_init_model(self, variable_params): - """ - Attempts to instantiate a model with specific variable parameters set - and additional fixed parameters if any. - - Args: - variable_params: A mapping of a specific set of variable parameters. 
- """ - if not self.fixed_values: - return self.model_cls(**variable_params) - - try: - kv = copy.deepcopy(variable_params) - kv.update(self.fixed_values) - return self.model_cls(**kv) - - except TypeError: - import inspect - sig = inspect.signature(self.model_cls.__init__) - last_arg = list(sig.parameters.values())[-1] - valid_types = (last_arg.POSITIONAL_OR_KEYWORD, - last_arg.VAR_POSITIONAL) - if last_arg.kind in valid_types: - variable_params[last_arg.name] = self.fixed_values - return self.model_cls(**variable_params) - - msg = ('Cannot configure model with variable ' - 'params {} and fixed params {}') - raise ValueError(msg.format(variable_params, self.fixed_values)) diff --git a/tests/test_batchrunner.py b/tests/test_batchrunner.py index d80e92546c9..acbf7492a67 100644 --- a/tests/test_batchrunner.py +++ b/tests/test_batchrunner.py @@ -51,10 +51,10 @@ def step(self): class MockDictionaryModel(Model): - def __init__(self, variable_param, fixed_params): + def __init__(self, variable_param, fixed_name): super().__init__() self.variable_param = variable_param - self.fixed_name = fixed_params.get('fixed_name', None) + self.fixed_name = fixed_name self.running = True self.schedule = BaseScheduler(None) self.schedule.add(MockAgent(1, self, 0)) @@ -115,7 +115,7 @@ def test_model_level_vars(self): len(self.model_reporters) + 1) # extra column with run index - assert model_vars.shape == (self.model_runs, expected_cols) + self.assertEqual(model_vars.shape, (self.model_runs, expected_cols)) def test_agent_level_vars(self): """ @@ -125,9 +125,10 @@ def test_agent_level_vars(self): agent_vars = batch.get_agent_vars_dataframe() expected_cols = (len(self.variable_params) + len(self.agent_reporters) + - 1) # extra column with run index + 2) # extra columns with run index and agentId - assert agent_vars.shape == (self.model_runs * NUM_AGENTS, expected_cols) + self.assertEqual(agent_vars.shape, + (self.model_runs * NUM_AGENTS, expected_cols)) def 
test_model_with_fixed_parameters_as_kwargs(self): """ @@ -139,9 +140,9 @@ def test_model_with_fixed_parameters_as_kwargs(self): model_vars = batch.get_model_vars_dataframe() agent_vars = batch.get_agent_vars_dataframe() - assert len(model_vars) == len(agent_vars) - assert len(model_vars) == self.model_runs - assert model_vars['reported_fixed_value'].unique() == ['Fixed'] + self.assertEqual(len(model_vars), len(agent_vars)) + self.assertEqual(len(model_vars), self.model_runs) + self.assertEqual(model_vars['reported_fixed_value'].unique(), ['Fixed']) def test_model_with_fixed_parameters_as_dict(self): self.mock_model = MockDictionaryModel @@ -156,6 +157,6 @@ def test_model_with_fixed_parameters_as_dict(self): len(self.model_reporters) + 1) - assert model_vars.shape == (self.model_runs, expected_cols) - assert (model_vars['reported_fixed_param'].iloc[0] == + self.assertEqual(model_vars.shape, (self.model_runs, expected_cols)) + self.assertEqual(model_vars['reported_fixed_param'].iloc[0], self.fixed_params['fixed_name'])