diff --git a/idtxl/network_analysis.py b/idtxl/network_analysis.py index 77c332fe..d7118a67 100644 --- a/idtxl/network_analysis.py +++ b/idtxl/network_analysis.py @@ -257,6 +257,11 @@ def _separate_realisations(self, idx_full, idx_single): def _define_candidates(self, processes, samples): """Build a list of candidate indices. + Build a list of candidate indices. Note that variables that were + manually added to the conditioning set via the 'add_conditionals' + setting are removed from the candidate set if both sets are not + disjoint. + Args: processes : list of int process indices @@ -268,9 +273,37 @@ def _define_candidates(self, processes, samples): candidate and has the form (process index, sample index), indices are absolute values with respect to some data array. """ - candidate_set = [] + candidate_set = self._build_variable_list(processes, samples) + # Remove candidates that were already manullay added to the + # conditioning set via the 'add_conditionals' setting. Otherwise the + # candidates get tested in the inclusion step. + candidate_set = self._remove_forced_conditionals(candidate_set) + return candidate_set + + def _build_variable_list(self, processes, samples): + """Build a list of variable tuples with (process index, sample index). + + Args: + processes : list of int + process indices + samples: list of int + sample indices + + Returns: + a list of variable tuples + """ + var_list = [] for idx in it.product(processes, samples): - candidate_set.append(idx) + var_list.append(idx) + return var_list + + def _remove_forced_conditionals(self, candidate_set): + """Remove enforced conditioning variables from candidate set.""" + if self.settings['add_conditionals'] is not None: + cond = self.settings['add_conditionals'] + if type(cond) is tuple: # easily add single variable + cond = [cond] + candidate_set = list(set(candidate_set).difference(set(cond))) return candidate_set def _append_selected_vars_idx(self, idx): diff --git a/idtxl/network_inference.py b/idtxl/network_inference.py index 7da2f11a..43abba84 100755 --- a/idtxl/network_inference.py +++ b/idtxl/network_inference.py @@ -151,8 +151,8 @@ def _force_conditionals(self, cond, data): # that _define_candidates returns tuples with absolute indices and # not lags. if cond == 'faes': - cond = self._define_candidates(self.source_set, - [self.current_value[1]]) + cond = self._build_variable_list(self.source_set, + [self.current_value[1]]) self._append_selected_vars( cond, data.get_realisations(self.current_value, cond)[0]) diff --git a/test/test_active_information_storage.py b/test/test_active_information_storage.py index e8309ee5..5aeaddf7 100644 --- a/test/test_active_information_storage.py +++ b/test/test_active_information_storage.py @@ -242,7 +242,43 @@ def test_discrete_input(): nw.analyse_single_process(settings=settings, data=data, process=0) +@jpype_missing +def test_define_candidates(): + """Test candidate definition from a list of procs and a list of samples.""" + target = 1 + tau_target = 3 + max_lag_target = 10 + current_val = (target, 10) + procs = [target] + samples = np.arange(current_val[1] - 1, current_val[1] - max_lag_target, + -tau_target) + # Test if candidates that are added manually to the conditioning set are + # removed from the candidate set. + nw = ActiveInformationStorage() + settings = [ + {'add_conditionals': None}, + {'add_conditionals': (2, 3)}, + {'add_conditionals': [(2, 3), (4, 1)]}] + for s in settings: + nw.settings = s + candidates = nw._define_candidates(procs, samples) + assert (1, 9) in candidates, 'Sample missing from candidates: (1, 9).' + assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).' + assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).' + + settings = [ + {'add_conditionals': [(1, 9)]}, + {'add_conditionals': [(1, 9), (2, 3), (4, 1)]}] + for s in settings: + nw.settings = s + candidates = nw._define_candidates(procs, samples) + assert (1, 9) not in candidates, 'Sample missing from candidates: (1, 9).' + assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).' + assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).' + + if __name__ == '__main__': + test_define_candidates() test_return_local_values() test_discrete_input() test_analyse_network() diff --git a/test/test_bivariate_mi.py b/test/test_bivariate_mi.py index cf2774d2..2c2907aa 100644 --- a/test/test_bivariate_mi.py +++ b/test/test_bivariate_mi.py @@ -368,13 +368,30 @@ def test_define_candidates(): procs = [target] samples = np.arange(current_val[1] - 1, current_val[1] - max_lag_sources, -tau_sources) + # Test if candidates that are added manually to the conditioning set are + # removed from the candidate set. nw = BivariateMI() - candidates = nw._define_candidates(procs, samples) - assert (1, 9) in candidates, 'Sample missing from candidates: (1, 9).' + settings = [ + {'add_conditionals': None}, + {'add_conditionals': (2, 3)}, + {'add_conditionals': [(2, 3), (4, 1)]}] + for s in settings: + nw.settings = s + candidates = nw._define_candidates(procs, samples) + assert (1, 9) in candidates, 'Sample missing from candidates: (1, 9).' + assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).' + assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).' + + settings = [ + {'add_conditionals': [(1, 9)]}, + {'add_conditionals': [(1, 9), (2, 3), (4, 1)]}] + for s in settings: + nw.settings = s + candidates = nw._define_candidates(procs, samples) + assert (1, 9) not in candidates, 'Sample missing from candidates: (1, 9).' assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).' assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).' - @jpype_missing def test_analyse_network(): """Test method for full network analysis.""" @@ -495,6 +512,7 @@ def test_indices_to_lags(): if __name__ == '__main__': + test_define_candidates() test_zero_lag() test_gauss_data() test_return_local_values() @@ -506,4 +524,3 @@ def test_indices_to_lags(): test_faes_method() test_add_conditional_manually() test_check_source_set() - test_define_candidates() diff --git a/test/test_bivariate_te.py b/test/test_bivariate_te.py index 961949ee..8cc22f9c 100644 --- a/test/test_bivariate_te.py +++ b/test/test_bivariate_te.py @@ -363,15 +363,33 @@ def test_check_source_set(): def test_define_candidates(): """Test candidate definition from a list of procs and a list of samples.""" target = 1 - tau_target = 3 - max_lag_target = 10 + tau_sources = 3 + max_lag_sources = 10 current_val = (target, 10) procs = [target] - samples = np.arange(current_val[1] - 1, current_val[1] - max_lag_target, - -tau_target) + samples = np.arange(current_val[1] - 1, current_val[1] - max_lag_sources, + -tau_sources) + # Test if candidates that are added manually to the conditioning set are + # removed from the candidate set. nw = BivariateTE() - candidates = nw._define_candidates(procs, samples) - assert (1, 9) in candidates, 'Sample missing from candidates: (1, 9).' + settings = [ + {'add_conditionals': None}, + {'add_conditionals': (2, 3)}, + {'add_conditionals': [(2, 3), (4, 1)]}] + for s in settings: + nw.settings = s + candidates = nw._define_candidates(procs, samples) + assert (1, 9) in candidates, 'Sample missing from candidates: (1, 9).' + assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).' + assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).' + + settings = [ + {'add_conditionals': [(1, 9)]}, + {'add_conditionals': [(1, 9), (2, 3), (4, 1)]}] + for s in settings: + nw.settings = s + candidates = nw._define_candidates(procs, samples) + assert (1, 9) not in candidates, 'Sample missing from candidates: (1, 9).' assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).' assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).' @@ -476,7 +494,7 @@ def test_discrete_input(): @jpype_missing def test_mute_data(): """Test estimation from MuTE data.""" - max_lag = 3 + max_lag = 5 data = Data() data.generate_mute_data(200, 5) settings = { @@ -487,6 +505,7 @@ def test_mute_data(): 'n_perm_omnibus': 21, 'max_lag_sources': max_lag, 'min_lag_sources': 1, + 'add_conditionals': [(1, 3), (1, 2)], 'max_lag_target': max_lag} target = 2 te = BivariateTE() diff --git a/test/test_multivariate_mi.py b/test/test_multivariate_mi.py index 13b631b0..328d66f5 100644 --- a/test/test_multivariate_mi.py +++ b/test/test_multivariate_mi.py @@ -359,9 +359,27 @@ def test_define_candidates(): procs = [target] samples = np.arange(current_val[1] - 1, current_val[1] - max_lag_sources, -tau_sources) + # Test if candidates that are added manually to the conditioning set are + # removed from the candidate set. nw = MultivariateMI() - candidates = nw._define_candidates(procs, samples) - assert (1, 9) in candidates, 'Sample missing from candidates: (1, 9).' + settings = [ + {'add_conditionals': None}, + {'add_conditionals': (2, 3)}, + {'add_conditionals': [(2, 3), (4, 1)]}] + for s in settings: + nw.settings = s + candidates = nw._define_candidates(procs, samples) + assert (1, 9) in candidates, 'Sample missing from candidates: (1, 9).' + assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).' + assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).' + + settings = [ + {'add_conditionals': [(1, 9)]}, + {'add_conditionals': [(1, 9), (2, 3), (4, 1)]}] + for s in settings: + nw.settings = s + candidates = nw._define_candidates(procs, samples) + assert (1, 9) not in candidates, 'Sample missing from candidates: (1, 9).' assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).' assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).' diff --git a/test/test_multivariate_te.py b/test/test_multivariate_te.py index 416d39ff..731a8a45 100644 --- a/test/test_multivariate_te.py +++ b/test/test_multivariate_te.py @@ -365,9 +365,27 @@ def test_define_candidates(): procs = [target] samples = np.arange(current_val[1] - 1, current_val[1] - max_lag_target, -tau_target) + # Test if candidates that are added manually to the conditioning set are + # removed from the candidate set. nw = MultivariateTE() - candidates = nw._define_candidates(procs, samples) - assert (1, 9) in candidates, 'Sample missing from candidates: (1, 9).' + settings = [ + {'add_conditionals': None}, + {'add_conditionals': (2, 3)}, + {'add_conditionals': [(2, 3), (4, 1)]}] + for s in settings: + nw.settings = s + candidates = nw._define_candidates(procs, samples) + assert (1, 9) in candidates, 'Sample missing from candidates: (1, 9).' + assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).' + assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).' + + settings = [ + {'add_conditionals': [(1, 9)]}, + {'add_conditionals': [(1, 9), (2, 3), (4, 1)]}] + for s in settings: + nw.settings = s + candidates = nw._define_candidates(procs, samples) + assert (1, 9) not in candidates, 'Sample missing from candidates: (1, 9).' assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).' assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).'