diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..2bd19dc --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,10 @@ +repos: +- repo: https://github.com/pre-commit/mirrors-isort + rev: 'v4.3.21' + hooks: + - id: isort + additional_dependencies: ['isort[pyproject]'] +- repo: https://github.com/ambv/black + rev: stable + hooks: + - id: black diff --git a/.travis.yml b/.travis.yml index 1dbd25b..95d5147 100644 --- a/.travis.yml +++ b/.travis.yml @@ -37,6 +37,7 @@ install: - pip install -r requirements-dev.txt script: + - black --check epysurv - PYTHONPATH=. pytest --mypy -m mypy tests/ - PYTHONPATH=. pytest --cov=epysurv tests/ diff --git a/epysurv/__init__.py b/epysurv/__init__.py index 3dbb551..64d580f 100644 --- a/epysurv/__init__.py +++ b/epysurv/__init__.py @@ -5,5 +5,6 @@ """ from ._version import get_versions -__version__ = get_versions()['version'] + +__version__ = get_versions()["version"] del get_versions diff --git a/epysurv/_version.py b/epysurv/_version.py index 8b82301..b4eb6be 100644 --- a/epysurv/_version.py +++ b/epysurv/_version.py @@ -1,4 +1,3 @@ - # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build @@ -58,17 +57,18 @@ class NotThisMethod(Exception): def register_vcs_handler(vcs, method): # decorator """Decorator to mark a method as the handler for a particular VCS.""" + def decorate(f): """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f + return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): """Call the given command(s).""" assert isinstance(commands, list) p = None @@ -76,10 +76,13 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, try: dispcmd = str([c] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) + p = subprocess.Popen( + [c] + args, + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr else None), + ) break except EnvironmentError: e = sys.exc_info()[1] @@ -116,16 +119,22 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): for i in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} + return { + "version": dirname[len(parentdir_prefix) :], + "full-revisionid": None, + "dirty": False, + "error": None, + "date": None, + } else: rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @@ -181,7 +190,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)]) if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -190,7 +199,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) + tags = set([r for r in refs if re.search(r"\d", r)]) if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -198,19 +207,26 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] + r = ref[len(tag_prefix) :] if verbose: print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } @register_vcs_handler("git", "pieces_from_vcs") @@ -225,8 +241,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) + out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -234,10 +249,19 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%s*" % tag_prefix], - cwd=root) + describe_out, rc = run_command( + GITS, + [ + "describe", + "--tags", + "--dirty", + "--always", + "--long", + "--match", + "%s*" % tag_prefix, + ], + cwd=root, + ) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") @@ -260,17 +284,16 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] + git_describe = git_describe[: git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: # unparseable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) + pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces # tag @@ -279,10 +302,12 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( + full_tag, + tag_prefix, + ) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] + pieces["closest-tag"] = full_tag[len(tag_prefix) :] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -293,13 +318,13 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) + count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) pieces["distance"] = int(count_out) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], - cwd=root)[0].strip() + date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ + 0 + ].strip() pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces @@ -330,8 +355,7 @@ def render_pep440(pieces): rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -445,11 +469,13 @@ def render_git_describe_long(pieces): def render(pieces, style): """Render the given version pieces into the requested style.""" if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } if not style or style == "default": style = "pep440" # the default @@ -469,9 +495,13 @@ def render(pieces, style): else: raise ValueError("unknown style '%s'" % style) - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } def get_versions(): @@ -485,8 +515,7 @@ def get_versions(): verbose = cfg.verbose try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) except NotThisMethod: pass @@ -495,13 +524,16 @@ def get_versions(): # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. - for i in cfg.versionfile_source.split('/'): + for i in cfg.versionfile_source.split("/"): root = os.path.dirname(root) except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None, + } try: pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) @@ -515,6 +547,10 @@ def get_versions(): except NotThisMethod: pass - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", + "date": None, + } diff --git a/epysurv/data/__init__.py b/epysurv/data/__init__.py index 4e5cd1a..1eed00b 100644 --- a/epysurv/data/__init__.py +++ b/epysurv/data/__init__.py @@ -1,3 +1,3 @@ """Module for handling data transformation and example data.""" +from .disease_loader import load_diseases from .salmonella_data import * -from .disease_loader import load_diseases \ No newline at end of file diff --git a/epysurv/data/disease_loader.py b/epysurv/data/disease_loader.py index 87ee8e7..6f11398 100644 --- a/epysurv/data/disease_loader.py +++ b/epysurv/data/disease_loader.py @@ -1,11 +1,15 @@ import os import pickle + from .filter_combination import FilterCombination # pylint: disable=unused-import + # Needs to be imported because of the pickled object requires it. def load_diseases(path): - disease_pickles = [file for file in os.listdir(path) if os.path.splitext(file)[-1] == '.pickle'] + disease_pickles = [ + file for file in os.listdir(path) if os.path.splitext(file)[-1] == ".pickle" + ] disease_pickles = sorted(disease_pickles) for disease_pickle in disease_pickles: - yield pickle.load(open(os.path.join(path, disease_pickle), 'rb')) + yield pickle.load(open(os.path.join(path, disease_pickle), "rb")) diff --git a/epysurv/data/filter_combination.py b/epysurv/data/filter_combination.py index 5a7a744..272d03d 100644 --- a/epysurv/data/filter_combination.py +++ b/epysurv/data/filter_combination.py @@ -1,7 +1,8 @@ -import pandas as pd -from typing import * -from dataclasses import dataclass, field from collections import namedtuple +from dataclasses import dataclass, field +from typing import * + +import pandas as pd from .utils import timedelta_weeks @@ -11,9 +12,10 @@ class SplitYears: start to middle is the training data. middle to end is the test data. """ + def __init__(self, start: pd.Timestamp, middle: pd.Timestamp, end: pd.Timestamp): if not (start < middle < end): - raise ValueError('start, middle and end must be consecutive.') + raise ValueError("start, middle and end must be consecutive.") self.start = start self.middle = middle self.end = end @@ -27,10 +29,12 @@ def from_ts_input(cls, start, middle, end): return cls(start, middle, end) -TimeseriesClassificationData = namedtuple('TimeseriesClassificationData', - ['train_final', 'test_final', 'train_gen', 'test_gen']) +TimeseriesClassificationData = namedtuple( + "TimeseriesClassificationData", + ["train_final", "test_final", "train_gen", "test_gen"], +) -FREQ = 'W-MON' +FREQ = "W-MON" @dataclass @@ -48,12 +52,15 @@ class FilterCombination: data The case records. """ + disease: str county: str pathogen: str data: pd.DataFrame = field(repr=False) - def expanding_windows(self, min_len_in_weeks: int, split_years: SplitYears) -> TimeseriesClassificationData: + def expanding_windows( + self, min_len_in_weeks: int, split_years: SplitYears + ) -> TimeseriesClassificationData: """ Transform case records into expanding time series. @@ -68,32 +75,56 @@ def expanding_windows(self, min_len_in_weeks: int, split_years: SplitYears) -> T """ self._validate_input(min_len_in_weeks, split_years) - train_data = self.data.query('@split_years.start <= ReportingDate < @split_years.middle') - test_data = self.data.query('@split_years.start <= ReportingDate < @split_years.end') + train_data = self.data.query( + "@split_years.start <= ReportingDate < @split_years.middle" + ) + test_data = self.data.query( + "@split_years.start <= ReportingDate < @split_years.end" + ) offset = split_years.start + timedelta_weeks(min_len_in_weeks) - true_train = (pd.DataFrame(index=pd.date_range(offset, split_years.middle, freq=FREQ, closed='left')) - .join(_to_recent_timeseries(train_data)) - .fillna(0) - .assign(outbreak=lambda df: df.n_outbreak_cases > 0)) - true_test = (pd.DataFrame(index=pd.date_range(split_years.middle, split_years.end, freq=FREQ, closed='left')) - .join(_to_recent_timeseries( - self.data.query('@split_years.middle <= ReportingDate < @split_years.end'))) - .fillna(0) - .assign(outbreak=lambda df: df.n_outbreak_cases > 0)) - - train_gen = self._expanding_frame(train_data, - true_train, - offset=offset, - start=split_years.start, - end=split_years.middle) - test_gen = self._expanding_frame(test_data, - true_test, - offset=split_years.middle, - start=split_years.start, - end=split_years.end) - + true_train = ( + pd.DataFrame( + index=pd.date_range( + offset, split_years.middle, freq=FREQ, closed="left" + ) + ) + .join(_to_recent_timeseries(train_data)) + .fillna(0) + .assign(outbreak=lambda df: df.n_outbreak_cases > 0) + ) + true_test = ( + pd.DataFrame( + index=pd.date_range( + split_years.middle, split_years.end, freq=FREQ, closed="left" + ) + ) + .join( + _to_recent_timeseries( + self.data.query( + "@split_years.middle <= ReportingDate < @split_years.end" + ) + ) + ) + .fillna(0) + .assign(outbreak=lambda df: df.n_outbreak_cases > 0) + ) + + train_gen = self._expanding_frame( + train_data, + true_train, + offset=offset, + start=split_years.start, + end=split_years.middle, + ) + test_gen = self._expanding_frame( + test_data, + true_test, + offset=split_years.middle, + start=split_years.start, + end=split_years.end, + ) return TimeseriesClassificationData(true_train, true_test, train_gen, test_gen) @@ -103,32 +134,57 @@ def _validate_input(self, min_len_in_weeks, split_years): # if self.data.ReportingDate.max() < split_years.end: # raise ValueError(f'The end date must be before the last case, but is {split_years.end}') if split_years.start < self.data.ReportingDate.min(): - raise ValueError(f'The start date must be after the first case, but is {split_years.start}') + raise ValueError( + f"The start date must be after the first case, but is {split_years.start}" + ) if split_years.middle < split_years.start + timedelta_weeks(min_len_in_weeks): - raise ValueError(f'The start date plus the offset must be before the middle date, ' - f'but is {split_years.start + timedelta_weeks(min_len_in_weeks)}') - - def _expanding_frame(self, data: pd.DataFrame, final_data: pd.DataFrame, offset: pd.Timestamp, start: pd.Timestamp, end: pd.Timestamp): - for date in pd.date_range(offset, end, freq=FREQ, closed='left'): - ts = (data - # .copy() - .query('ValidFrom <= @date & (ValidUntil > @date | @pd.isna(ValidUntil))') - .set_index('ReportingDate') - .groupby(pd.Grouper(freq=FREQ)) - .agg({'IdRecord': 'count', 'IdRecordAusbruchOut': 'count'}) - .rename(columns={'IdRecord': 'n_cases', 'IdRecordAusbruchOut': 'n_outbreak_cases'}) - ) - ts = pd.DataFrame(index=pd.date_range(start, date, freq=FREQ)).join(ts).fillna(0) + raise ValueError( + f"The start date plus the offset must be before the middle date, " + f"but is {split_years.start + timedelta_weeks(min_len_in_weeks)}" + ) + + def _expanding_frame( + self, + data: pd.DataFrame, + final_data: pd.DataFrame, + offset: pd.Timestamp, + start: pd.Timestamp, + end: pd.Timestamp, + ): + for date in pd.date_range(offset, end, freq=FREQ, closed="left"): + ts = ( + data + # .copy() + .query( + "ValidFrom <= @date & (ValidUntil > @date | @pd.isna(ValidUntil))" + ) + .set_index("ReportingDate") + .groupby(pd.Grouper(freq=FREQ)) + .agg({"IdRecord": "count", "IdRecordAusbruchOut": "count"}) + .rename( + columns={ + "IdRecord": "n_cases", + "IdRecordAusbruchOut": "n_outbreak_cases", + } + ) + ) + ts = ( + pd.DataFrame(index=pd.date_range(start, date, freq=FREQ)) + .join(ts) + .fillna(0) + ) outbreak = final_data.loc[date].outbreak yield ts, outbreak def _to_recent_timeseries(data: pd.DataFrame) -> pd.DataFrame: """Get a time series from case data, that represents the most recent state.""" - return (data - .query('IsCurrent') - .set_index('ReportingDate') - .groupby(pd.Grouper(freq=FREQ)) - .agg({'IdRecord': 'count', 'IdRecordAusbruchOut': 'count'}) - .rename(columns={'IdRecord': 'n_cases', 'IdRecordAusbruchOut': 'n_outbreak_cases'}) - ) + return ( + data.query("IsCurrent") + .set_index("ReportingDate") + .groupby(pd.Grouper(freq=FREQ)) + .agg({"IdRecord": "count", "IdRecordAusbruchOut": "count"}) + .rename( + columns={"IdRecord": "n_cases", "IdRecordAusbruchOut": "n_outbreak_cases"} + ) + ) diff --git a/epysurv/data/salmonella_data.py b/epysurv/data/salmonella_data.py index 5a580f5..ce8ec4e 100644 --- a/epysurv/data/salmonella_data.py +++ b/epysurv/data/salmonella_data.py @@ -1,29 +1,36 @@ -from typing import * import os -import pandas as pd from collections import namedtuple +from typing import * + +import pandas as pd from .utils import timedelta_weeks -TimeseriesClassificationData = namedtuple('TimeseriesClassificationData', ['train', 'test', 'train_gen', 'test_gen']) +TimeseriesClassificationData = namedtuple( + "TimeseriesClassificationData", ["train", "test", "train_gen", "test_gen"] +) def salmonella(): """Count data from Salmonella newport in Germany.""" - train = _load_data('salmonella_train.csv') - test = _load_data('salmonella_test.csv') + train = _load_data("salmonella_train.csv") + test = _load_data("salmonella_test.csv") return train, test -def timeseries_classifcation(train: pd.DataFrame, test: pd.DataFrame, - offset_in_weeks: int) -> TimeseriesClassificationData: +def timeseries_classifcation( + train: pd.DataFrame, test: pd.DataFrame, offset_in_weeks: int +) -> TimeseriesClassificationData: """Convert standard timeseries for usage in time series classification.""" - train_gen, test_gen = timeseries_classifaction_generator(train, test, offset_in_weeks) + train_gen, test_gen = timeseries_classifaction_generator( + train, test, offset_in_weeks + ) return TimeseriesClassificationData(train, test, train_gen, test_gen) def timeseries_classifaction_generator( - train: pd.DataFrame, test: pd.DataFrame, offset_in_weeks: int) -> Tuple[Generator, Generator]: + train: pd.DataFrame, test: pd.DataFrame, offset_in_weeks: int +) -> Tuple[Generator, Generator]: """Turn a time point classification problem into a time series classification problem.""" offset = train.index[0] + timedelta_weeks(offset_in_weeks) train_generator = _growing_frame(train, offset=offset) @@ -33,15 +40,19 @@ def timeseries_classifaction_generator( def _load_data(filename: str): - data = pd.read_csv(os.path.join(os.path.dirname(__file__), filename), index_col=0, parse_dates=True, - infer_datetime_format=True) + data = pd.read_csv( + os.path.join(os.path.dirname(__file__), filename), + index_col=0, + parse_dates=True, + infer_datetime_format=True, + ) data.index.freq = pd.infer_freq(data.index) return data def _growing_frame(data: pd.DataFrame, offset: pd.Timestamp): - before_begin_data = data.query('index <= @offset') - after_begin_data = data.query('index > @offset') + before_begin_data = data.query("index <= @offset") + after_begin_data = data.query("index > @offset") new_frame = before_begin_data.copy() for idx, row in after_begin_data.iterrows(): new_frame.loc[idx] = row diff --git a/epysurv/data/utils.py b/epysurv/data/utils.py index b3b2291..5eeaca7 100644 --- a/epysurv/data/utils.py +++ b/epysurv/data/utils.py @@ -2,4 +2,4 @@ def timedelta_weeks(weeks: int): - return pd.Timedelta(7 * weeks, unit='D') + return pd.Timedelta(7 * weeks, unit="D") diff --git a/epysurv/metrics/__init__.py b/epysurv/metrics/__init__.py index 7994a0d..50ba5f1 100644 --- a/epysurv/metrics/__init__.py +++ b/epysurv/metrics/__init__.py @@ -1 +1 @@ -from .outbreak_detection import * \ No newline at end of file +from .outbreak_detection import * diff --git a/epysurv/metrics/outbreak_detection.py b/epysurv/metrics/outbreak_detection.py index 40cc8ea..0ff05a7 100644 --- a/epysurv/metrics/outbreak_detection.py +++ b/epysurv/metrics/outbreak_detection.py @@ -1,5 +1,5 @@ -import pandas as pd import numpy as np +import pandas as pd def ghozzi_score(prediction_result: pd.DataFrame) -> float: @@ -19,18 +19,26 @@ def ghozzi_score(prediction_result: pd.DataFrame) -> float: A maximum score of 1. """ # Outbreaks that were correctly predicted. - weighted_true_positives = np.sum(prediction_result.alarm - * prediction_result.outbreak - * prediction_result.n_outbreak_cases) + weighted_true_positives = np.sum( + prediction_result.alarm + * prediction_result.outbreak + * prediction_result.n_outbreak_cases + ) # Outbreaks that were missed. - weighted_false_negatives = np.sum((1 - prediction_result.alarm) - * prediction_result.outbreak - * prediction_result.n_outbreak_cases) + weighted_false_negatives = np.sum( + (1 - prediction_result.alarm) + * prediction_result.outbreak + * prediction_result.n_outbreak_cases + ) # Alarms that were falsely raised. - weighted_false_positives = np.sum(prediction_result.alarm - * (prediction_result.outbreak != prediction_result.alarm) - * np.mean(prediction_result.query('outbreak').n_outbreak_cases)) - absolute_score = weighted_true_positives - weighted_false_negatives - weighted_false_positives + weighted_false_positives = np.sum( + prediction_result.alarm + * (prediction_result.outbreak != prediction_result.alarm) + * np.mean(prediction_result.query("outbreak").n_outbreak_cases) + ) + absolute_score = ( + weighted_true_positives - weighted_false_negatives - weighted_false_positives + ) normalized_score = absolute_score / prediction_result.n_outbreak_cases.sum() return normalized_score @@ -53,17 +61,25 @@ def ghozzi_case_score(prediction_result: pd.DataFrame) -> float: A maximum score of 1. """ # Outbreaks that were correctly predicted. - weighted_true_positives = np.sum(prediction_result.alarm - * prediction_result.outbreak - * prediction_result.n_outbreak_cases) + weighted_true_positives = np.sum( + prediction_result.alarm + * prediction_result.outbreak + * prediction_result.n_outbreak_cases + ) # Outbreaks that were missed. - weighted_false_negatives = np.sum((1 - prediction_result.alarm) - * prediction_result.outbreak - * prediction_result.n_outbreak_cases) + weighted_false_negatives = np.sum( + (1 - prediction_result.alarm) + * prediction_result.outbreak + * prediction_result.n_outbreak_cases + ) # Alarms that were falsely raised. - weighted_false_positives = np.sum(prediction_result.alarm - * (prediction_result.outbreak != prediction_result.alarm) - * prediction_result.n_cases) - absolute_score = weighted_true_positives - weighted_false_negatives - weighted_false_positives + weighted_false_positives = np.sum( + prediction_result.alarm + * (prediction_result.outbreak != prediction_result.alarm) + * prediction_result.n_cases + ) + absolute_score = ( + weighted_true_positives - weighted_false_negatives - weighted_false_positives + ) normalized_score = absolute_score / prediction_result.n_outbreak_cases.sum() return normalized_score diff --git a/epysurv/models/timepoint/__init__.py b/epysurv/models/timepoint/__init__.py index 0be30ee..6e2a362 100644 --- a/epysurv/models/timepoint/__init__.py +++ b/epysurv/models/timepoint/__init__.py @@ -3,7 +3,7 @@ from .cdc import CDC from .cusum import Cusum from .ears import EarsC1, EarsC2, EarsC3 -from .farrington import FarringtonFlexible, Farrington +from .farrington import Farrington, FarringtonFlexible from .glr import GLRNegativeBinomial, GLRPoisson from .hmm import HMM from .outbreak_p import OutbreakP diff --git a/epysurv/models/timepoint/_base.py b/epysurv/models/timepoint/_base.py index 8e83584..2d24e3d 100644 --- a/epysurv/models/timepoint/_base.py +++ b/epysurv/models/timepoint/_base.py @@ -4,9 +4,9 @@ import numpy as np import pandas as pd import rpy2.robjects as robjects -from rpy2.robjects.packages import importr -from rpy2.robjects import r, numpy2ri, pandas2ri from pandas.tseries import offsets +from rpy2.robjects import numpy2ri, pandas2ri, r +from rpy2.robjects.packages import importr from epysurv.metrics.outbreak_detection import ghozzi_score @@ -16,21 +16,22 @@ def silence_R_output(): This is useful, because some algorithm otherwise print every time they are invoked. """ - if platform.system() == 'Linux': - r.sink('/dev/null') - elif platform.system() == 'Windows': - r.sink('NUL') + if platform.system() == "Linux": + r.sink("/dev/null") + elif platform.system() == "Windows": + r.sink("NUL") silence_R_output() numpy2ri.activate() pandas2ri.activate() -surveillance = importr('surveillance') +surveillance = importr("surveillance") @dataclass class TimepointSurveillanceAlgorithm: """Algorithms that predict outbreaks for every timepoint.""" + _training_data: pd.DataFrame = field(init=False, repr=False) def fit(self, data: pd.DataFrame): @@ -53,17 +54,17 @@ def _validate_data(self, data: pd.DataFrame): self._contains_counts(data) def _contains_dates(self, data: pd.DataFrame): - has_dates = 'ds' in data.columns or isinstance(data.index, pd.DatetimeIndex) + has_dates = "ds" in data.columns or isinstance(data.index, pd.DatetimeIndex) if not has_dates: - raise ValueError('No dates') + raise ValueError("No dates") def _contains_counts(self, data: pd.DataFrame): - if not {'n_cases', 'n_outbreak_cases'} < set(data.columns): + if not {"n_cases", "n_outbreak_cases"} < set(data.columns): raise ValueError('No column named "n_cases"') def _data_in_the_future(self, data: pd.DataFrame): if data.index.min() <= self._training_data.index.max(): - raise ValueError('The prediction data overlaps with the training data.') + raise ValueError("The prediction data overlaps with the training data.") offset_to_freq = { @@ -74,10 +75,10 @@ def _data_in_the_future(self, data: pd.DataFrame): } offset_to_attr = { - offsets.Day: 'day', - offsets.Week: 'week', - offsets.MonthBegin: 'month', - offsets.MonthEnd: 'month', + offsets.Day: "day", + offsets.Week: "week", + offsets.MonthBegin: "month", + offsets.MonthEnd: "month", } @@ -108,16 +109,20 @@ def predict(self, data: pd.DataFrame) -> pd.DataFrame: super().predict(data) prediction_len = len(data) # Concat training and prediction data. make index array for range param. - data = (pd.concat((self._training_data, data), keys=['train', 'test']) - .reset_index(level=0) - .rename(columns={'level_0': 'provenance'})) + data = ( + pd.concat((self._training_data, data), keys=["train", "test"]) + .reset_index(level=0) + .rename(columns={"level_0": "provenance"}) + ) r_instance = self._prepare_r_instance(data) # R indexes are 0-based. Therefore we add 1. - detection_range = robjects.IntVector(np.where(data.provenance == 'test')[0] + 1) + detection_range = robjects.IntVector(np.where(data.provenance == "test")[0] + 1) surveillance_result = self._call_surveillance_algo(r_instance, detection_range) alarm = self._extract_alarms(surveillance_result) predictions = data.copy() - predictions['alarm'] = np.append((len(predictions) - len(alarm)) * [np.nan], alarm) + predictions["alarm"] = np.append( + (len(predictions) - len(alarm)) * [np.nan], alarm + ) predictions = predictions.iloc[-prediction_len:] return predictions @@ -143,19 +148,27 @@ def _prepare_r_instance(self, data: pd.DataFrame): if data.index.freq is None: freq = pd.infer_freq(data.index) if freq is None: - raise ValueError(f'The time series index has no valid frequency. Index={data.index}') + raise ValueError( + f"The time series index has no valid frequency. Index={data.index}" + ) data.index.freq = freq - sts = surveillance.sts(start=r.c(data.index[0].year, _get_start_epoch(data)), - epoch=robjects.IntVector( - [r['as.numeric'](r['as.Date'](d.isoformat()))[0] for d in data.index.date]), - # epoch=data.index, - freq=_get_freq(data), - observed=data["n_cases"].values, - epochAsDate=True) + sts = surveillance.sts( + start=r.c(data.index[0].year, _get_start_epoch(data)), + epoch=robjects.IntVector( + [ + r["as.numeric"](r["as.Date"](d.isoformat()))[0] + for d in data.index.date + ] + ), + # epoch=data.index, + freq=_get_freq(data), + observed=data["n_cases"].values, + epochAsDate=True, + ) return sts def _extract_alarms(self, surveillance_result): - return np.asarray(surveillance_result.slots['alarm']) + return np.asarray(surveillance_result.slots["alarm"]) class DisProgBasedAlgorithm(STSBasedAlgorithm): @@ -166,4 +179,6 @@ def _prepare_r_instance(self, data: pd.DataFrame): return surveillance.sts2disProg(sts) def _extract_alarms(self, surveillance_result): - return np.asarray(dict(zip(surveillance_result.names, list(surveillance_result)))['alarm']) + return np.asarray( + dict(zip(surveillance_result.names, list(surveillance_result)))["alarm"] + ) diff --git a/epysurv/models/timepoint/bayes.py b/epysurv/models/timepoint/bayes.py index f603e04..b11dda3 100644 --- a/epysurv/models/timepoint/bayes.py +++ b/epysurv/models/timepoint/bayes.py @@ -5,7 +5,7 @@ from ._base import STSBasedAlgorithm -surveillance = importr('surveillance') +surveillance = importr("surveillance") @dataclass @@ -30,17 +30,20 @@ class Bayes(STSBasedAlgorithm): Surveillance Daten, Bachelor’s thesis [2] Höhle, M., & Riebler, A. (2005). Höhle, Riebler: The R-Package “surveillance.” Sonderforschungsbereich (Vol. 386). Retrieved from https://epub.ub.uni-muenchen.de/1791/1/paper_422.pdf """ + years_back: int = 0 window_half_width: int = 6 include_recent_year: bool = True alpha: float = 0.05 def _call_surveillance_algo(self, sts, detection_range): - control = r.list(range=detection_range, - b=self.years_back, - w=self.window_half_width, - actY=self.include_recent_year, - alpha=self.alpha) + control = r.list( + range=detection_range, + b=self.years_back, + w=self.window_half_width, + actY=self.include_recent_year, + alpha=self.alpha, + ) surv = surveillance.bayes(sts, control=control) return surv diff --git a/epysurv/models/timepoint/boda.py b/epysurv/models/timepoint/boda.py index ab59e75..0fab0f2 100644 --- a/epysurv/models/timepoint/boda.py +++ b/epysurv/models/timepoint/boda.py @@ -1,14 +1,14 @@ -from dataclasses import dataclass import warnings +from dataclasses import dataclass from rpy2 import robjects +from rpy2.rinterface import RRuntimeError from rpy2.robjects import r from rpy2.robjects.packages import importr -from rpy2.rinterface import RRuntimeError -from ._base import STSBasedAlgorithm +from ._base import STSBasedAlgorithm -surveillance = importr('surveillance') +surveillance = importr("surveillance") @dataclass @@ -40,33 +40,38 @@ class Boda(STSBasedAlgorithm): of the parameters and then compute the quantile of the mixture distribution using bisectioning, which is faster. """ + trend: bool = False season: bool = False - prior: str = 'iid' + prior: str = "iid" alpha: float = 0.05 mc_munu: int = 100 mc_y: int = 10 - sampling_method = 'joint' - quantile_method: str = 'MM' + sampling_method = "joint" + quantile_method: str = "MM" def _call_surveillance_algo(self, sts, detection_range): try: - importr('INLA') + importr("INLA") except RRuntimeError: raise ImportError( - 'For the Boda algortihm to run you need the INLA package (http://www.r-inla.org/). ' + "For the Boda algortihm to run you need the INLA package (http://www.r-inla.org/). " 'Install it by running install.packages("INLA", repos = c(getOption("repos"), INLA = "https://inla.r-inla-download.org/R/stable"), dep = TRUE) ' - 'in the R console.' + "in the R console." ) - control = r.list(**{'range': detection_range, - 'X': robjects.NULL, - 'trend': self.trend, - 'season': self.season, - 'prior': self.prior, - 'alpha': self.alpha, - 'mc.munu': self.mc_munu, - 'mc.y': self.mc_y, - 'samplingMethod': self.sampling_method, - 'quantileMethod': self.quantile_method}) + control = r.list( + **{ + "range": detection_range, + "X": robjects.NULL, + "trend": self.trend, + "season": self.season, + "prior": self.prior, + "alpha": self.alpha, + "mc.munu": self.mc_munu, + "mc.y": self.mc_y, + "samplingMethod": self.sampling_method, + "quantileMethod": self.quantile_method, + } + ) surv = surveillance.boda(sts, control=control) return surv diff --git a/epysurv/models/timepoint/cdc.py b/epysurv/models/timepoint/cdc.py index b426166..5ace2ed 100644 --- a/epysurv/models/timepoint/cdc.py +++ b/epysurv/models/timepoint/cdc.py @@ -1,10 +1,11 @@ from dataclasses import dataclass + from rpy2.robjects import r from rpy2.robjects.packages import importr from ._base import DisProgBasedAlgorithm -surveillance = importr('surveillance') +surveillance = importr("surveillance") @dataclass @@ -25,14 +26,17 @@ class CDC(DisProgBasedAlgorithm): [2] Farrington, C. and N. Andrews (2003). Monitoring the Health of Populations, Chapter Outbreak Detection: Application to Infectious Disease Surveillance, pp. 203-231. Oxford University Press. """ + years_back: int = 5 window_half_width: int = 1 alpha: float = 0.001 def _call_surveillance_algo(self, sts, detection_range): - control = r.list(range=detection_range, - b=self.years_back, - m=self.window_half_width, - alpha=self.alpha) + control = r.list( + range=detection_range, + b=self.years_back, + m=self.window_half_width, + alpha=self.alpha, + ) surv = surveillance.algo_cdc(sts, control=control) return surv diff --git a/epysurv/models/timepoint/cusum.py b/epysurv/models/timepoint/cusum.py index 97c249e..746131f 100644 --- a/epysurv/models/timepoint/cusum.py +++ b/epysurv/models/timepoint/cusum.py @@ -1,12 +1,13 @@ -from typing import * from dataclasses import dataclass +from typing import * + from rpy2 import robjects from rpy2.robjects import r from rpy2.robjects.packages import importr from ._base import STSBasedAlgorithm -surveillance = importr('surveillance') +surveillance = importr("surveillance") @dataclass @@ -40,18 +41,23 @@ class Cusum(STSBasedAlgorithm): [2] D. A. Pierce and D. W. Schafer (1986), Residuals in Generalized Linear Models, Journal of the American Statistical Association, 81, 977–986 """ + reference_value: float = 1.04 decision_boundary: float = 2.26 expected_numbers_method: str = "mean" - transform: str = 'standard' + transform: str = "standard" negbin_alpha: float = 0.1 def _call_surveillance_algo(self, sts, detection_range): - control = r.list(range=detection_range, - k=self.reference_value, - h=self.decision_boundary, - m=robjects.NULL if self.expected_numbers_method == "mean" else self.expected_numbers_method, - trans=self.transform, - alpha=self.negbin_alpha) + control = r.list( + range=detection_range, + k=self.reference_value, + h=self.decision_boundary, + m=robjects.NULL + if self.expected_numbers_method == "mean" + else self.expected_numbers_method, + trans=self.transform, + alpha=self.negbin_alpha, + ) surv = surveillance.cusum(sts, control=control) return surv diff --git a/epysurv/models/timepoint/ears.py b/epysurv/models/timepoint/ears.py index baf51f5..1474699 100644 --- a/epysurv/models/timepoint/ears.py +++ b/epysurv/models/timepoint/ears.py @@ -6,7 +6,7 @@ from ._base import STSBasedAlgorithm -surveillance = importr('surveillance') +surveillance = importr("surveillance") @dataclass @@ -14,6 +14,7 @@ class _EarsBase(STSBasedAlgorithm): """ Base class for the Ears models. """ + alpha: float = 0.001 baseline: int = 7 min_sigma: float = 0 @@ -21,11 +22,13 @@ class _EarsBase(STSBasedAlgorithm): method: ClassVar = None def _call_surveillance_algo(self, sts, detection_range): - control = r.list(range=detection_range, - method=self.method, - baseline=self.baseline, - minSigma=self.min_sigma, - alpha=self.alpha) + control = r.list( + range=detection_range, + method=self.method, + baseline=self.baseline, + minSigma=self.min_sigma, + alpha=self.alpha, + ) surv = surveillance.earsC(sts, control=control) return surv @@ -53,7 +56,8 @@ class EarsC1(_EarsBase): [2] Salmon, M., Schumacher, D. and Höhle, M. (2016): Monitoring count time series in R: Aberration detection in public health surveillance. Journal of Statistical Software, 70 (10), 1-35. doi: 10.18637/jss.v070.i10 """ - method = 'C1' + + method = "C1" class EarsC2(_EarsBase): @@ -79,7 +83,8 @@ class EarsC2(_EarsBase): [2] Salmon, M., Schumacher, D. and Höhle, M. (2016): Monitoring count time series in R: Aberration detection in public health surveillance. Journal of Statistical Software, 70 (10), 1-35. doi: 10.18637/jss.v070.i10 """ - method = 'C2' + + method = "C2" @dataclass @@ -105,14 +110,17 @@ class EarsC3(_EarsBase): [2] Salmon, M., Schumacher, D. and Höhle, M. (2016): Monitoring count time series in R: Aberration detection in public health surveillance. Journal of Statistical Software, 70 (10), 1-35. doi: 10.18637/jss.v070.i10 """ + alpha: float = 0.001 baseline: int = 7 def _call_surveillance_algo(self, sts, detection_range): - control = r.list(range=detection_range, - method="C3", - baseline=self.baseline, - minSigma=self.min_sigma, - alpha=self.alpha) + control = r.list( + range=detection_range, + method="C3", + baseline=self.baseline, + minSigma=self.min_sigma, + alpha=self.alpha, + ) surv = surveillance.earsC(sts, control=control) return surv diff --git a/epysurv/models/timepoint/farrington.py b/epysurv/models/timepoint/farrington.py index 5cb7c99..44ed429 100644 --- a/epysurv/models/timepoint/farrington.py +++ b/epysurv/models/timepoint/farrington.py @@ -3,9 +3,9 @@ from rpy2.robjects import r from rpy2.robjects.packages import importr -from ._base import STSBasedAlgorithm, DisProgBasedAlgorithm +from ._base import DisProgBasedAlgorithm, STSBasedAlgorithm -surveillance = importr('surveillance') +surveillance = importr("surveillance") @dataclass @@ -47,6 +47,7 @@ class Farrington(DisProgBasedAlgorithm): [1] Farrington, C.P., Andrews, N.J, Beale A.D. and Catchpole, M.A. (1996): A statistical algorithm for the early detection of outbreaks of infectious disease. J. R. Statist. Soc. A, 159, 547-563. """ + years_back: int = 3 window_half_width: int = 3 reweight: bool = True @@ -54,17 +55,19 @@ class Farrington(DisProgBasedAlgorithm): trend: bool = True past_period_cutoff: int = 4 min_cases_in_past_periods: int = 5 - power_transform: str = '2/3' + power_transform: str = "2/3" def _call_surveillance_algo(self, disprog_obj, detection_range): - control = r.list(range=detection_range, - b=self.years_back, - w=self.window_half_width, - reweight=self.reweight, - alpha=self.alpha, - trend=self.trend, - limit54=r.c(self.min_cases_in_past_periods, self.past_period_cutoff), - powertrans=self.power_transform) + control = r.list( + range=detection_range, + b=self.years_back, + w=self.window_half_width, + reweight=self.reweight, + alpha=self.alpha, + trend=self.trend, + limit54=r.c(self.min_cases_in_past_periods, self.past_period_cutoff), + powertrans=self.power_transform, + ) surv = surveillance.algo_farrington(disprog_obj, control=control) return surv @@ -140,23 +143,25 @@ class FarringtonFlexible(STSBasedAlgorithm): trend_threshold: float = 0.05 past_period_cutoff: int = 4 min_cases_in_past_periods: int = 5 - power_transform: str = '2/3' + power_transform: str = "2/3" past_weeks_not_included: int = 26 - threshold_method: str = 'delta' + threshold_method: str = "delta" def _call_surveillance_algo(self, sts, detection_range): - control = r.list(range=detection_range, - b=self.years_back, - w=self.window_half_width, - reweight=self.reweight, - weightsThreshold=self.weights_threshold, - alpha=self.alpha, - trend=self.trend, - trend_threshold=self.trend_threshold, - limit54=r.c(self.min_cases_in_past_periods, self.past_period_cutoff), - powertrans=self.power_transform, - pastWeeksNotIncluded=self.past_weeks_not_included, - thresholdMethod=self.threshold_method) + control = r.list( + range=detection_range, + b=self.years_back, + w=self.window_half_width, + reweight=self.reweight, + weightsThreshold=self.weights_threshold, + alpha=self.alpha, + trend=self.trend, + trend_threshold=self.trend_threshold, + limit54=r.c(self.min_cases_in_past_periods, self.past_period_cutoff), + powertrans=self.power_transform, + pastWeeksNotIncluded=self.past_weeks_not_included, + thresholdMethod=self.threshold_method, + ) surv = surveillance.farringtonFlexible(sts, control=control) return surv diff --git a/epysurv/models/timepoint/glr.py b/epysurv/models/timepoint/glr.py index 2e5e948..83e2574 100644 --- a/epysurv/models/timepoint/glr.py +++ b/epysurv/models/timepoint/glr.py @@ -1,15 +1,15 @@ """Count data regression charts for the monitoring of surveillance time series as proposed by Höhle and Paul (2008). The implementation is described in Salmon et al. (2016).""" from dataclasses import dataclass, field -from rpy2 import robjects +from typing import Tuple, Union +from rpy2 import robjects from rpy2.robjects import r from rpy2.robjects.packages import importr -from typing import Tuple, Union from ._base import STSBasedAlgorithm -surveillance = importr('surveillance') +surveillance = importr("surveillance") @dataclass @@ -58,28 +58,32 @@ class GLRNegativeBinomial(STSBasedAlgorithm): detection in public health surveillance. Journal of Statistical Software, 70 (10), 1-35. doi: 10.18637/jss.v070.i10 """ + alpha: float = 0 glr_test_threshold: int = 5 m: int = -1 - change: str = 'intercept' - direction: Union[Tuple[str, str], Tuple[str]] = ('inc', 'dec') - upperbound_statistic: str = 'cases' + change: str = "intercept" + direction: Union[Tuple[str, str], Tuple[str]] = ("inc", "dec") + upperbound_statistic: str = "cases" x_max: float = 1e4 def _call_surveillance_algo(self, sts, detection_range): - control = r.list(**{'range': detection_range, - 'c.ARL': self.glr_test_threshold, - 'm0': robjects.NULL, - 'alpha': self.alpha, - # Mtilde is set to 1, since that is the only valid value for "epi" and "intercept" - 'Mtilde': 1, - 'M': self.m, - 'change': self.change, - 'theta': robjects.NULL, - 'dir': r.c(*self.direction), - 'ret': self.upperbound_statistic, - 'xMax': self.x_max - }) + control = r.list( + **{ + "range": detection_range, + "c.ARL": self.glr_test_threshold, + "m0": robjects.NULL, + "alpha": self.alpha, + # Mtilde is set to 1, since that is the only valid value for "epi" and "intercept" + "Mtilde": 1, + "M": self.m, + "change": self.change, + "theta": robjects.NULL, + "dir": r.c(*self.direction), + "ret": self.upperbound_statistic, + "xMax": self.x_max, + } + ) surv = surveillance.glrnb(sts, control=control) return surv @@ -116,30 +120,34 @@ class GLRPoisson(STSBasedAlgorithm): detection in public health surveillance. Journal of Statistical Software, 70 (10), 1-35. doi: 10.18637/jss.v070.i10 """ + glr_test_threshold: int = 5 """threshold in the GLR test, i.e. cγ.""" m: int = -1 """number of time instances back in time in the window-limited approach, i.e. the last value considered is max 1, n − M. To always look back until the first observation use M=-1.""" - change: str = 'intercept' + change: str = "intercept" """a string specifying the type of the alternative. Currently the two choices are intercept and epi. See the SFB Discussion Paper 500 for details""" - direction: Union[Tuple[str, str], Tuple[str]] = ('inc', 'dec') + direction: Union[Tuple[str, str], Tuple[str]] = ("inc", "dec") """Specifying the direction of testing in GLR scheme. With "inc" only increases in x are considered in the GLR-statistic, with "dec" decreases are regarded.""" - upperbound_statistic: str = 'cases' + upperbound_statistic: str = "cases" """a string specifying the type of upperbound-statistic that is returned. With "cases" the number of cases that would have been necessary to produce an alarm or with "value" the GLR-statistic is computed (see below)""" def _call_surveillance_algo(self, sts, detection_range): - control = r.list(**{'range': detection_range, - 'c.ARL': self.glr_test_threshold, - 'm0': robjects.NULL, - # Mtilde is set to 1, since that is the only valid value for "epi" and "intercept" - 'Mtilde': 1, - 'M': self.m, - 'change': self.change, - # Role of theta: If NULL then the GLR scheme is used. If not NULL the prespecified value for κ or λ is used in a recursive LR scheme, which is faster.""" - 'theta': robjects.NULL, - 'dir': r.c(*self.direction), - 'ret': self.upperbound_statistic, - }) + control = r.list( + **{ + "range": detection_range, + "c.ARL": self.glr_test_threshold, + "m0": robjects.NULL, + # Mtilde is set to 1, since that is the only valid value for "epi" and "intercept" + "Mtilde": 1, + "M": self.m, + "change": self.change, + # Role of theta: If NULL then the GLR scheme is used. If not NULL the prespecified value for κ or λ is used in a recursive LR scheme, which is faster.""" + "theta": robjects.NULL, + "dir": r.c(*self.direction), + "ret": self.upperbound_statistic, + } + ) surv = surveillance.glrpois(sts, control=control) return surv diff --git a/epysurv/models/timepoint/hmm.py b/epysurv/models/timepoint/hmm.py index 8b26d76..832ec29 100644 --- a/epysurv/models/timepoint/hmm.py +++ b/epysurv/models/timepoint/hmm.py @@ -1,10 +1,11 @@ from dataclasses import dataclass + from rpy2.robjects import r from rpy2.robjects.packages import importr from ._base import DisProgBasedAlgorithm -surveillance = importr('surveillance') +surveillance = importr("surveillance") @dataclass @@ -36,6 +37,7 @@ class HMM(DisProgBasedAlgorithm): [2] I.L. MacDonald and W. Zucchini, Hidden Markov and Other Models for Discrete-valued Time Series, (1997), Chapman & Hall, Monographs on Statistics and applied Probability 70 """ + n_observations: int = -1 n_hidden_states: int = 2 trend: bool = True @@ -43,11 +45,13 @@ class HMM(DisProgBasedAlgorithm): equal_covariate_effects: bool = False def _call_surveillance_algo(self, disprog_obj, detection_range): - control = r.list(range=detection_range, - Mtilde=self.n_observations, - noStates=self.n_hidden_states, - trend=self.trend, - noHarmonics=self.n_harmonics, - covEffectEqual=self.equal_covariate_effects) + control = r.list( + range=detection_range, + Mtilde=self.n_observations, + noStates=self.n_hidden_states, + trend=self.trend, + noHarmonics=self.n_harmonics, + covEffectEqual=self.equal_covariate_effects, + ) surv = surveillance.algo_hmm(disprog_obj, control=control) return surv diff --git a/epysurv/models/timepoint/outbreak_p.py b/epysurv/models/timepoint/outbreak_p.py index fafd985..d492e0d 100644 --- a/epysurv/models/timepoint/outbreak_p.py +++ b/epysurv/models/timepoint/outbreak_p.py @@ -1,10 +1,11 @@ from dataclasses import dataclass + from rpy2.robjects import r from rpy2.robjects.packages import importr from ._base import STSBasedAlgorithm -surveillance = importr('surveillance') +surveillance = importr("surveillance") @dataclass @@ -27,14 +28,17 @@ class OutbreakP(STSBasedAlgorithm): [1] Frisén, M., Andersson and Schiöler, L., (2009), Robust outbreak surveillance of epidemics in Sweden, Statistics in Medicine, 28(3):476-493. [2] Frisén, M. and Andersson, E., (2009) Semiparametric Surveillance of Monotonic Changes, Sequential Analysis 28(4):434-454. """ + threshold: int = 100 - upperbound_statistic: str = 'cases' + upperbound_statistic: str = "cases" max_upperbound_cases: int = 100_000 def _call_surveillance_algo(self, sts, detection_range): - control = r.list(range=detection_range, - k=self.threshold, - ret=self.upperbound_statistic, - maxUpperboundCases=self.max_upperbound_cases) + control = r.list( + range=detection_range, + k=self.threshold, + ret=self.upperbound_statistic, + maxUpperboundCases=self.max_upperbound_cases, + ) surv = surveillance.outbreakP(sts, control=control) return surv diff --git a/epysurv/models/timepoint/rki.py b/epysurv/models/timepoint/rki.py index 3c6f1c2..d776764 100644 --- a/epysurv/models/timepoint/rki.py +++ b/epysurv/models/timepoint/rki.py @@ -5,7 +5,7 @@ from ._base import STSBasedAlgorithm -surveillance = importr('surveillance') +surveillance = importr("surveillance") @dataclass @@ -21,15 +21,18 @@ class RKI(STSBasedAlgorithm): include_recent_year Is a boolean to decide if the year of timePoint also contributes w reference values. """ + years_back: int = 0 window_half_width: int = 6 include_recent_year: bool = True def _call_surveillance_algo(self, sts, detection_range): - control = r.list(range=detection_range, - b=self.years_back, - w=self.window_half_width, - actY=self.include_recent_year) + control = r.list( + range=detection_range, + b=self.years_back, + w=self.window_half_width, + actY=self.include_recent_year, + ) surv = surveillance.rki(sts, control=control) return surv diff --git a/epysurv/models/timeseries/__init__.py b/epysurv/models/timeseries/__init__.py index 6361b18..52104a1 100644 --- a/epysurv/models/timeseries/__init__.py +++ b/epysurv/models/timeseries/__init__.py @@ -1,16 +1,18 @@ from .convert_interface import * -__all__ = ['Bayes', - 'Boda', - 'CDC', - 'Cusum', - 'EarsC1', - 'EarsC2', - 'EarsC3', - 'FarringtonFlexible', - 'Farrington', - 'GLRNegativeBinomial', - 'GLRPoisson', - 'HMM', - 'OutbreakP', - 'RKI'] +__all__ = [ + "Bayes", + "Boda", + "CDC", + "Cusum", + "EarsC1", + "EarsC2", + "EarsC3", + "FarringtonFlexible", + "Farrington", + "GLRNegativeBinomial", + "GLRPoisson", + "HMM", + "OutbreakP", + "RKI", +] diff --git a/epysurv/models/timeseries/_base.py b/epysurv/models/timeseries/_base.py index 997f572..baf36ba 100644 --- a/epysurv/models/timeseries/_base.py +++ b/epysurv/models/timeseries/_base.py @@ -3,7 +3,6 @@ class NonLearningTimeseriesClassificationMixin: - def fit(self, data_generator): """These types of algorithms do not learn from previous time series.""" pass @@ -15,10 +14,11 @@ def predict(self, data_generator) -> pd.DataFrame: # Fit on all data, except the last point, that is to be predicted. super().fit(X.iloc[:-1]) prediction = super().predict( - X.iloc[[-1]]) # Use inner brackets to get dytpe preserving frame and not series. + X.iloc[[-1]] + ) # Use inner brackets to get dytpe preserving frame and not series. # As only a single value should be returned, we can access this single item. [alarm] = prediction.alarm [time] = prediction.index alarms.append(alarm) times.append(time) - return pd.DataFrame({'alarm': alarms}, index=pd.DatetimeIndex(times)) + return pd.DataFrame({"alarm": alarms}, index=pd.DatetimeIndex(times)) diff --git a/epysurv/models/timeseries/convert_interface.py b/epysurv/models/timeseries/convert_interface.py index 001e056..6c43cc2 100644 --- a/epysurv/models/timeseries/convert_interface.py +++ b/epysurv/models/timeseries/convert_interface.py @@ -7,10 +7,11 @@ for name, obj in vars(timepoint).items(): try: if issubclass(obj, timepoint._base.TimepointSurveillanceAlgorithm): - globals()[name] = type(name, (NonLearningTimeseriesClassificationMixin, obj), {}) + globals()[name] = type( + name, (NonLearningTimeseriesClassificationMixin, obj), {} + ) __all__.append(name) except TypeError: continue # print(__all__) - diff --git a/epysurv/simulation/naive_poisson.py b/epysurv/simulation/naive_poisson.py index 4949a63..4e2790f 100644 --- a/epysurv/simulation/naive_poisson.py +++ b/epysurv/simulation/naive_poisson.py @@ -1,7 +1,8 @@ -from typing import Set import random -import pandas as pd +from typing import Set + import numpy as np +import pandas as pd from scipy import stats @@ -20,8 +21,13 @@ def get_outbreak_begins(n: int, outbreak_length: int, n_outbreaks: int) -> Set[i return outbreaks_starts -def simulate_outbreaks(n: int = 104, outbreak_length: int = 5, n_outbreaks: int = 3, mu: float = 1, - outbreak_mu: float = 10) -> pd.DataFrame: +def simulate_outbreaks( + n: int = 104, + outbreak_length: int = 5, + n_outbreaks: int = 3, + mu: float = 1, + outbreak_mu: float = 10, +) -> pd.DataFrame: """Simulate outbreaks based on Poisson distribution. Parameters @@ -49,15 +55,21 @@ def simulate_outbreaks(n: int = 104, outbreak_length: int = 5, n_outbreaks: int for start in outbreaks_starts: outbreak_cases = stats.poisson.rvs(mu=outbreak_mu, size=outbreak_length) - outbreak_cases += (outbreak_cases == 0) # Ensure that there is a least on case during the outbreak. - n_cases[start: start + outbreak_length] += outbreak_cases - n_outbreak_cases[start: start + outbreak_length] = outbreak_cases + outbreak_cases += ( + outbreak_cases == 0 + ) # Ensure that there is a least on case during the outbreak. + n_cases[start : start + outbreak_length] += outbreak_cases + n_outbreak_cases[start : start + outbreak_length] = outbreak_cases - data = pd.DataFrame({'n_cases': n_cases, - 'n_outbreak_cases': n_outbreak_cases, - 'outbreak': n_outbreak_cases > 0, - 'baseline': baseline, - }, index=pd.date_range(start='2020', periods=baseline.size, freq='W-MON')) + data = pd.DataFrame( + { + "n_cases": n_cases, + "n_outbreak_cases": n_outbreak_cases, + "outbreak": n_outbreak_cases > 0, + "baseline": baseline, + }, + index=pd.date_range(start="2020", periods=baseline.size, freq="W-MON"), + ) - data.index.name = 'date' + data.index.name = "date" return data diff --git a/epysurv/visualization/model_diagnostics.py b/epysurv/visualization/model_diagnostics.py index a9a293d..2fad19b 100644 --- a/epysurv/visualization/model_diagnostics.py +++ b/epysurv/visualization/model_diagnostics.py @@ -1,13 +1,14 @@ -import pandas as pd -import seaborn as sns -import numpy as np import matplotlib import matplotlib.pyplot as plt +import numpy as np +import pandas as pd import plotnine as gg +import seaborn as sns -def plot_confusion_matrix(confusion_matrix: np.ndarray, class_names: list, - ax: matplotlib.axes.Axes = None) -> matplotlib.axes.Axes: +def plot_confusion_matrix( + confusion_matrix: np.ndarray, class_names: list, ax: matplotlib.axes.Axes = None +) -> matplotlib.axes.Axes: """Plots a confusion matrix, as returned by sklearn.metrics.confusion_matrix, as a heatmap. Based on https://gist.github.com/shaypal5/94c53d765083101efc0240d776a23823 @@ -28,32 +29,43 @@ def plot_confusion_matrix(confusion_matrix: np.ndarray, class_names: list, The resulting confusion matrix figure """ - df_cm = pd.DataFrame( - confusion_matrix, index=class_names, columns=class_names, - ) + df_cm = pd.DataFrame(confusion_matrix, index=class_names, columns=class_names) if ax is None: fig, ax = plt.subplots() heatmap = sns.heatmap(df_cm, annot=True, cmap="Blues", ax=ax) - heatmap.set(ylabel='True label', xlabel='Predicted label') + heatmap.set(ylabel="True label", xlabel="Predicted label") return ax -def plot_prediction(train_data, test_data, prediction, ax: matplotlib.axes.Axes = None) -> matplotlib.axes.Axes: +def plot_prediction( + train_data, test_data, prediction, ax: matplotlib.axes.Axes = None +) -> matplotlib.axes.Axes: """Plots case counts as step line, with outbreaks and alarms indicated by triangles.""" whole_data = pd.concat((train_data, test_data), sort=False) fontsize = 20 if ax is None: fig, ax = plt.subplots(figsize=(12, 8)) - ax.step(x=whole_data.index, y=whole_data.n_cases, where='mid', - color='blue', label='_nolegend_') - alarms = prediction.query('alarm == 1') - ax.plot(alarms.index, [0] * len(alarms), 'g^', label='alarm', markersize=12) - outbreaks = test_data.query('outbreak') - ax.plot(outbreaks.index, outbreaks.n_outbreak_cases, 'rv', label='outbreak', markersize=12) - ax.set_xlabel('time', fontsize=fontsize) - ax.set_ylabel('cases', fontsize=fontsize) - ax.legend(fontsize='xx-large') + ax.step( + x=whole_data.index, + y=whole_data.n_cases, + where="mid", + color="blue", + label="_nolegend_", + ) + alarms = prediction.query("alarm == 1") + ax.plot(alarms.index, [0] * len(alarms), "g^", label="alarm", markersize=12) + outbreaks = test_data.query("outbreak") + ax.plot( + outbreaks.index, + outbreaks.n_outbreak_cases, + "rv", + label="outbreak", + markersize=12, + ) + ax.set_xlabel("time", fontsize=fontsize) + ax.set_ylabel("cases", fontsize=fontsize) + ax.legend(fontsize="xx-large") return ax @@ -69,41 +81,70 @@ def ghozzi_score_plot(prediction_result: pd.DataFrame, filename: str): File name to write the plot to. """ # Outbreaks that were recognized. - prediction_result['weighted_true_positives'] = (prediction_result.alarm - * prediction_result.outbreak - * prediction_result.n_outbreak_cases) + prediction_result["weighted_true_positives"] = ( + prediction_result.alarm + * prediction_result.outbreak + * prediction_result.n_outbreak_cases + ) # Outbreaks that were missed. - prediction_result['weighted_false_negatives'] = ((1 - prediction_result.alarm) - * prediction_result.outbreak - * prediction_result.n_outbreak_cases) + prediction_result["weighted_false_negatives"] = ( + (1 - prediction_result.alarm) + * prediction_result.outbreak + * prediction_result.n_outbreak_cases + ) # Alarms that were falsely raised. - prediction_result['weighted_false_positives'] = (prediction_result.alarm - * (prediction_result.outbreak != prediction_result.alarm) - * np.mean(prediction_result.query('outbreak').n_outbreak_cases)) + prediction_result["weighted_false_positives"] = ( + prediction_result.alarm + * (prediction_result.outbreak != prediction_result.alarm) + * np.mean(prediction_result.query("outbreak").n_outbreak_cases) + ) melted_prediction_result = ( - prediction_result - .reset_index() - .rename(columns={'index': 'date'}) - .melt(id_vars=['date', 'county', 'pathogen', 'n_cases', 'n_outbreak_cases', 'outbreak', 'alarm'], - var_name='prediction', value_name='weighting') + prediction_result.reset_index() + .rename(columns={"index": "date"}) + .melt( + id_vars=[ + "date", + "county", + "pathogen", + "n_cases", + "n_outbreak_cases", + "outbreak", + "alarm", + ], + var_name="prediction", + value_name="weighting", + ) ) - case_color = 'grey' + case_color = "grey" n_cols = 4 - n_filter_combinations = len(prediction_result[['county', 'pathogen']].drop_duplicates()) - - chart = (gg.ggplot(melted_prediction_result, gg.aes(x='date')) - + gg.geom_bar(prediction_result, gg.aes(x='prediction_result.index', y='n_cases'), fill=case_color, - stat='identity') - + gg.geom_line(gg.aes(y=0), color=case_color) - + gg.geom_bar(gg.aes(y='weighting', fill='prediction'), stat='identity') - + gg.facet_wrap(['county', 'pathogen'], ncol=n_cols) - + gg.scale_x_date(date_breaks='4 month', date_labels='%Y-%m') - + gg.ylab('# cases') - + gg.scale_fill_manual(name='weighting', values=['red', 'orange', 'green']) - + gg.theme(panel_grid_minor=gg.element_blank()) - + gg.theme_light() - ) - chart.save(filename, width=5 * n_cols, height=4 * n_filter_combinations / n_cols, unit='cm', limitsize=False) + n_filter_combinations = len( + prediction_result[["county", "pathogen"]].drop_duplicates() + ) + + chart = ( + gg.ggplot(melted_prediction_result, gg.aes(x="date")) + + gg.geom_bar( + prediction_result, + gg.aes(x="prediction_result.index", y="n_cases"), + fill=case_color, + stat="identity", + ) + + gg.geom_line(gg.aes(y=0), color=case_color) + + gg.geom_bar(gg.aes(y="weighting", fill="prediction"), stat="identity") + + gg.facet_wrap(["county", "pathogen"], ncol=n_cols) + + gg.scale_x_date(date_breaks="4 month", date_labels="%Y-%m") + + gg.ylab("# cases") + + gg.scale_fill_manual(name="weighting", values=["red", "orange", "green"]) + + gg.theme(panel_grid_minor=gg.element_blank()) + + gg.theme_light() + ) + chart.save( + filename, + width=5 * n_cols, + height=4 * n_filter_combinations / n_cols, + unit="cm", + limitsize=False, + ) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..882578b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,26 @@ +[tool.black] +line-length = 88 +target-version = ['py37'] +include = '\.pyi?$' +exclude = ''' +( + /( + \.eggs # exclude a few common directories in the + | \.git # root of the project + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist + )/ +) +''' + +[tool.isort] +known_first_party = 'epysurv' +line_length = 88 +multi_line_output = 3 +include_trailing_comma = "true" \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index 06312af..52e0efe 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,3 +4,5 @@ pytest-datadir pytest-mypy pytest-cov codecov +black +pre-commit diff --git a/run.py b/run.py index ddf922d..d6e85cd 100644 --- a/run.py +++ b/run.py @@ -1,10 +1,13 @@ -import pandas as pd import matplotlib.pyplot as plt +import pandas as pd + from epysurv.models import timepoint def read_surv_data(filename): - data = pd.read_csv(filename, index_col=0, parse_dates=True, infer_datetime_format=True) + data = pd.read_csv( + filename, index_col=0, parse_dates=True, infer_datetime_format=True + ) data.index.freq = pd.infer_freq(data.index) return data @@ -13,17 +16,22 @@ def plot_prediction(train_data, prediction): prediction = pd.concat((train_data, prediction), sort=False) fontsize = 20 fig, ax = plt.subplots(figsize=(12, 8)) - ax.step(x=prediction.index, y=prediction.n_cases, where='mid', - color='blue', label='_nolegend_') - outbreaks = prediction.query('alarm == 1') - ax.plot(outbreaks.index, outbreaks.n_cases, 'r^', label='alarm', markersize=12) - ax.set_xlabel('time', fontsize=fontsize) - ax.set_ylabel('cases', fontsize=fontsize) - ax.legend(fontsize='xx-large') - - -data_train = read_surv_data('tests/data/salmonella_train.csv') -data_test = read_surv_data('tests/data/salmonella_test.csv') + ax.step( + x=prediction.index, + y=prediction.n_cases, + where="mid", + color="blue", + label="_nolegend_", + ) + outbreaks = prediction.query("alarm == 1") + ax.plot(outbreaks.index, outbreaks.n_cases, "r^", label="alarm", markersize=12) + ax.set_xlabel("time", fontsize=fontsize) + ax.set_ylabel("cases", fontsize=fontsize) + ax.legend(fontsize="xx-large") + + +data_train = read_surv_data("tests/data/salmonella_train.csv") +data_test = read_surv_data("tests/data/salmonella_test.csv") algos = [ # timepoint.EarsC1, @@ -46,7 +54,7 @@ def plot_prediction(train_data, prediction): model = Algo() model.fit(data_train) pred = model.predict(data_test) - pred.to_csv(f'tests/data/{model.__class__.__name__}_pred.csv') + pred.to_csv(f"tests/data/{model.__class__.__name__}_pred.csv") plot_prediction(data_train, pred) plt.title(model.__class__.__name__) plt.show() diff --git a/setup.py b/setup.py index 1b2ce08..7d3d16a 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,7 @@ -from setuptools import setup, find_packages import os +from setuptools import find_packages, setup + import versioneer package = "epysurv" diff --git a/tests/conftest.py b/tests/conftest.py index a9c0716..597815a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,55 +1,73 @@ -import pytest import pickle from collections import namedtuple + import pandas as pd -from epysurv import data +import pytest -from epysurv.data.filter_combination import FilterCombination, SplitYears, TimeseriesClassificationData +from epysurv import data +from epysurv.data.filter_combination import ( + FilterCombination, + SplitYears, + TimeseriesClassificationData, +) -TSCGenerator = namedtuple('TSCGenerator', 'train_gen test_gen') +TSCGenerator = namedtuple("TSCGenerator", "train_gen test_gen") @pytest.fixture def train_data(shared_datadir): - data = pd.read_csv(shared_datadir / 'salmonella_train.csv', index_col=0, parse_dates=True, - infer_datetime_format=True) + data = pd.read_csv( + shared_datadir / "salmonella_train.csv", + index_col=0, + parse_dates=True, + infer_datetime_format=True, + ) data.index.freq = pd.infer_freq(data.index) return data @pytest.fixture def test_data(shared_datadir): - data = pd.read_csv(shared_datadir / 'salmonella_test.csv', index_col=0, parse_dates=True, - infer_datetime_format=True) + data = pd.read_csv( + shared_datadir / "salmonella_test.csv", + index_col=0, + parse_dates=True, + infer_datetime_format=True, + ) data.index.freq = pd.infer_freq(data.index) return data @pytest.fixture def tsc_generator(train_data, test_data): - train_generator, test_generator = data.timeseries_classifaction_generator(train_data, test_data, - offset_in_weeks=5 * 52) + train_generator, test_generator = data.timeseries_classifaction_generator( + train_data, test_data, offset_in_weeks=5 * 52 + ) return TSCGenerator(train_generator, test_generator) @pytest.fixture def filter_combination(shared_datadir): - with open(shared_datadir / 'cases.pickle', 'rb') as handle: + with open(shared_datadir / "cases.pickle", "rb") as handle: cases = pickle.load(handle) cases_in_berlin = cases.query('county == "Berlin"') - return FilterCombination(disease='SAL', county='Berlin', pathogen='SAL', data=cases_in_berlin) + return FilterCombination( + disease="SAL", county="Berlin", pathogen="SAL", data=cases_in_berlin + ) @pytest.fixture def expanding_windows(filter_combination): - tsc_data = filter_combination.expanding_windows(min_len_in_weeks=104, - split_years=SplitYears.from_ts_input('2005', '2009', '2011')) + tsc_data = filter_combination.expanding_windows( + min_len_in_weeks=104, + split_years=SplitYears.from_ts_input("2005", "2009", "2011"), + ) return tsc_data -@pytest.fixture(params=['timeseries', 'cases']) +@pytest.fixture(params=["timeseries", "cases"]) def tsc_data(request, tsc_generator, expanding_windows): - if request.param == 'timeseries': + if request.param == "timeseries": return tsc_generator - elif request.param == 'cases': + elif request.param == "cases": return expanding_windows diff --git a/tests/test_base.py b/tests/test_base.py index 4655e19..7ddf557 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1,34 +1,33 @@ -import pytest import numpy as np import pandas as pd +import pytest + from epysurv.models.timepoint import _base def random_cases_for_dates(dates): - return {'n_cases': np.random.randint(low=5, high=10, size=len(dates)), - 'n_outbreak_cases': np.random.randint(low=0, high=5, size=len(dates))} + return { + "n_cases": np.random.randint(low=5, high=10, size=len(dates)), + "n_outbreak_cases": np.random.randint(low=0, high=5, size=len(dates)), + } def test_data_in_the_future(): model = _base.TimepointSurveillanceAlgorithm() - dates = pd.date_range(start='2011', end='2011-12-31', freq='W-Mon') - future_dates = pd.date_range(start='2012-01-01', end='2013') - train_data = pd.DataFrame(random_cases_for_dates(dates), - index=dates) - test_data = pd.DataFrame(random_cases_for_dates(future_dates), - index=future_dates) + dates = pd.date_range(start="2011", end="2011-12-31", freq="W-Mon") + future_dates = pd.date_range(start="2012-01-01", end="2013") + train_data = pd.DataFrame(random_cases_for_dates(dates), index=dates) + test_data = pd.DataFrame(random_cases_for_dates(future_dates), index=future_dates) model.fit(train_data) model.predict(test_data) def test_data_in_the_future_should_raise(): model = _base.TimepointSurveillanceAlgorithm() - dates = pd.date_range(start='2011', end='2011-12-31', freq='W-Mon') - future_dates = pd.date_range(start='2011-12-01', end='2013') - train_data = pd.DataFrame(random_cases_for_dates(dates), - index=dates) - test_data = pd.DataFrame(random_cases_for_dates(future_dates), - index=future_dates) + dates = pd.date_range(start="2011", end="2011-12-31", freq="W-Mon") + future_dates = pd.date_range(start="2011-12-01", end="2013") + train_data = pd.DataFrame(random_cases_for_dates(dates), index=dates) + test_data = pd.DataFrame(random_cases_for_dates(future_dates), index=future_dates) model.fit(train_data) with pytest.raises(ValueError): model.predict(test_data) @@ -36,10 +35,10 @@ def test_data_in_the_future_should_raise(): def test_fit(): model = _base.TimepointSurveillanceAlgorithm() - dates = pd.date_range(start='2011', freq='W-Mon', periods=2) - train_data = pd.DataFrame({'n_cases': [5, 6], - 'n_outbreak_cases': [0, 1]}, - index=dates) + dates = pd.date_range(start="2011", freq="W-Mon", periods=2) + train_data = pd.DataFrame( + {"n_cases": [5, 6], "n_outbreak_cases": [0, 1]}, index=dates + ) train_data_before_fit = train_data.copy() model.fit(train_data) pd.testing.assert_frame_equal(train_data, train_data_before_fit) diff --git a/tests/test_data.py b/tests/test_data.py index d3822d6..dcc66d1 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -30,18 +30,20 @@ def test_train_ends_before_test(tsc_data): def test_train_frequency(tsc_data): for frame, y in tsc_data.train_gen: - assert pd.infer_freq(frame.index) == 'W-MON' + assert pd.infer_freq(frame.index) == "W-MON" def test_test_frequency(tsc_data): for frame, y in tsc_data.test_gen: - assert pd.infer_freq(frame.index) == 'W-MON' + assert pd.infer_freq(frame.index) == "W-MON" def test_raises_on_early_start(filter_combination): - with raises(ValueError, match='The start date'): - filter_combination.expanding_windows(min_len_in_weeks=104, - split_years=SplitYears.from_ts_input('1999', '2009', '2011')) + with raises(ValueError, match="The start date"): + filter_combination.expanding_windows( + min_len_in_weeks=104, + split_years=SplitYears.from_ts_input("1999", "2009", "2011"), + ) # TODO: reactivate later @@ -52,11 +54,13 @@ def test_raises_on_early_start(filter_combination): def test_raises_on_high_offset(filter_combination): - with raises(ValueError, match='The start date plus the offset'): - filter_combination.expanding_windows(min_len_in_weeks=500, - split_years=SplitYears.from_ts_input('2005', '2009', '2011')) + with raises(ValueError, match="The start date plus the offset"): + filter_combination.expanding_windows( + min_len_in_weeks=500, + split_years=SplitYears.from_ts_input("2005", "2009", "2011"), + ) def test_split_year_order(): - with raises(ValueError, match='consecutive'): - SplitYears.from_ts_input('2011', '2012', '2010') + with raises(ValueError, match="consecutive"): + SplitYears.from_ts_input("2011", "2012", "2010") diff --git a/tests/test_metrics.py b/tests/test_metrics.py index fe35161..88b6a44 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -1,38 +1,35 @@ import pandas as pd import pytest +from pytest import approx from epysurv.metrics import ghozzi_score -from pytest import approx @pytest.fixture def prediction_results(): - return pd.DataFrame({ - 'outbreak': [1, 1, 0, 0], - 'n_outbreak_cases': [6, 10, 0, 0] - }) + return pd.DataFrame({"outbreak": [1, 1, 0, 0], "n_outbreak_cases": [6, 10, 0, 0]}) def test_ghozzi_score_mixed(prediction_results): - prediction_results['alarm'] = [1, 0, 0, 1] + prediction_results["alarm"] = [1, 0, 0, 1] assert ghozzi_score(prediction_results) == approx((6 - 10 + 0 - 8) / 16) def test_ghozzi_score_always_outbreak(prediction_results): - prediction_results['alarm'] = [1, 1, 1, 1] + prediction_results["alarm"] = [1, 1, 1, 1] assert ghozzi_score(prediction_results) == approx((6 + 10 - 8 - 8) / 16) def test_ghozzi_score_never_outbreak(prediction_results): - prediction_results['alarm'] = [0, 0, 0, 0] + prediction_results["alarm"] = [0, 0, 0, 0] assert ghozzi_score(prediction_results) == approx(-1) def test_ghozzi_score_correct(prediction_results): - prediction_results['alarm'] = prediction_results['outbreak'] + prediction_results["alarm"] = prediction_results["outbreak"] assert ghozzi_score(prediction_results) == approx(1) def test_ghozzi_score_always_incorrect(prediction_results): - prediction_results['alarm'] = (1 - prediction_results['outbreak']) + prediction_results["alarm"] = 1 - prediction_results["outbreak"] assert ghozzi_score(prediction_results) == approx((-6 + -10 - 8 - 8) / 16) diff --git a/tests/test_timepoint_models.py b/tests/test_timepoint_models.py index cb10e7d..7a11baf 100644 --- a/tests/test_timepoint_models.py +++ b/tests/test_timepoint_models.py @@ -1,13 +1,29 @@ -import pytest import pandas as pd +import pytest from pandas.testing import assert_frame_equal -from epysurv.models.timepoint import (EarsC1, EarsC2, EarsC3, Farrington, FarringtonFlexible, Cusum, Bayes, RKI, - GLRNegativeBinomial, GLRPoisson, OutbreakP, CDC, HMM, Boda) +from epysurv.models.timepoint import ( + CDC, + HMM, + RKI, + Bayes, + Boda, + Cusum, + EarsC1, + EarsC2, + EarsC3, + Farrington, + FarringtonFlexible, + GLRNegativeBinomial, + GLRPoisson, + OutbreakP, +) def load_predictions(filepath): - predictions = pd.read_csv(filepath, index_col=0, parse_dates=True, infer_datetime_format=True) + predictions = pd.read_csv( + filepath, index_col=0, parse_dates=True, infer_datetime_format=True + ) return predictions @@ -23,33 +39,33 @@ def load_predictions(filepath): GLRNegativeBinomial, GLRPoisson, OutbreakP, - CDC + CDC, ] -@pytest.mark.parametrize('Algo', algos_to_test) +@pytest.mark.parametrize("Algo", algos_to_test) def test_prediction(train_data, test_data, shared_datadir, Algo): """Regression tests against a change in the prediction behavior.""" model = Algo() model.fit(train_data) pred = model.predict(test_data) - saved_predictions = load_predictions(shared_datadir / f'{Algo.__name__}_pred.csv') + saved_predictions = load_predictions(shared_datadir / f"{Algo.__name__}_pred.csv") assert_frame_equal(pred, saved_predictions) # These algorithms take to long to be tested every time. long_algos_to_test = [ HMM, - Boda # TODO: Boda throws strange error when run in the test. -] + Boda, +] # TODO: Boda throws strange error when run in the test. @pytest.mark.skip -@pytest.mark.parametrize('Algo', long_algos_to_test) +@pytest.mark.parametrize("Algo", long_algos_to_test) def test_long_prediction(train_data, test_data, shared_datadir, Algo): """Regression tests against a change in the prediction behavior.""" model = Algo() model.fit(train_data) pred = model.predict(test_data) - saved_predictions = load_predictions(shared_datadir / f'{Algo.__name__}_pred.csv') + saved_predictions = load_predictions(shared_datadir / f"{Algo.__name__}_pred.csv") assert_frame_equal(pred, saved_predictions) diff --git a/tests/test_timeseries_models.py b/tests/test_timeseries_models.py index 160205c..42cd15d 100644 --- a/tests/test_timeseries_models.py +++ b/tests/test_timeseries_models.py @@ -1,6 +1,7 @@ -import pytest -import pandas as pd import numpy as np +import pandas as pd +import pytest + from epysurv.models.timeseries import Farrington, GLRPoisson # type: ignore @@ -8,9 +9,12 @@ def test_farrington_timeseries_prediciton(tsc_generator, shared_datadir): model = Farrington(alpha=0.1) model.fit(tsc_generator.train_gen) pred = model.predict(tsc_generator.test_gen) - saved_predictions = pd.read_csv(shared_datadir / 'farrington_timeseries_predictions.csv', index_col=0, - parse_dates=True, - infer_datetime_format=True) + saved_predictions = pd.read_csv( + shared_datadir / "farrington_timeseries_predictions.csv", + index_col=0, + parse_dates=True, + infer_datetime_format=True, + ) pred.alarm.plot() saved_predictions.alarm.plot() @@ -19,9 +23,10 @@ def test_farrington_timeseries_prediciton(tsc_generator, shared_datadir): def test_outbreak_case_subtraction(): def test_gen(): - df = pd.DataFrame({'n_cases': np.ones(100) * 5, - 'n_outbreak_cases': np.ones(100) * 3}, - index=pd.date_range('2020', freq='W-MON', periods=100)) + df = pd.DataFrame( + {"n_cases": np.ones(100) * 5, "n_outbreak_cases": np.ones(100) * 3}, + index=pd.date_range("2020", freq="W-MON", periods=100), + ) yield df, True model = Farrington() diff --git a/versioneer.py b/versioneer.py index 64fea1c..8aed693 100644 --- a/versioneer.py +++ b/versioneer.py @@ -1,4 +1,3 @@ - # Version: 0.18 """The Versioneer - like a rocketeer, but for versions. @@ -277,10 +276,7 @@ """ from __future__ import print_function -try: - import configparser -except ImportError: - import ConfigParser as configparser + import errno import json import os @@ -288,6 +284,11 @@ import subprocess import sys +try: + import configparser +except ImportError: + import ConfigParser as configparser + class VersioneerConfig: """Container for Versioneer configuration parameters.""" @@ -308,11 +309,13 @@ def get_root(): setup_py = os.path.join(root, "setup.py") versioneer_py = os.path.join(root, "versioneer.py") if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - err = ("Versioneer was unable to run the project root directory. " - "Versioneer requires setup.py to be executed from " - "its immediate directory (like 'python setup.py COMMAND'), " - "or in a way that lets it use sys.argv[0] to find the root " - "(like 'python path/to/setup.py COMMAND').") + err = ( + "Versioneer was unable to run the project root directory. " + "Versioneer requires setup.py to be executed from " + "its immediate directory (like 'python setup.py COMMAND'), " + "or in a way that lets it use sys.argv[0] to find the root " + "(like 'python path/to/setup.py COMMAND')." + ) raise VersioneerBadRootError(err) try: # Certain runtime workflows (setup.py install/develop in a setuptools @@ -325,8 +328,10 @@ def get_root(): me_dir = os.path.normcase(os.path.splitext(me)[0]) vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) if me_dir != vsr_dir: - print("Warning: build in %s is using versioneer.py from %s" - % (os.path.dirname(me), versioneer_py)) + print( + "Warning: build in %s is using versioneer.py from %s" + % (os.path.dirname(me), versioneer_py) + ) except NameError: pass return root @@ -348,6 +353,7 @@ def get(parser, name): if parser.has_option("versioneer", name): return parser.get("versioneer", name) return None + cfg = VersioneerConfig() cfg.VCS = VCS cfg.style = get(parser, "style") or "" @@ -372,17 +378,18 @@ class NotThisMethod(Exception): def register_vcs_handler(vcs, method): # decorator """Decorator to mark a method as the handler for a particular VCS.""" + def decorate(f): """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f + return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): """Call the given command(s).""" assert isinstance(commands, list) p = None @@ -390,10 +397,13 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, try: dispcmd = str([c] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) + p = subprocess.Popen( + [c] + args, + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr else None), + ) break except EnvironmentError: e = sys.exc_info()[1] @@ -418,7 +428,9 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, return stdout, p.returncode -LONG_VERSION_PY['git'] = ''' +LONG_VERSION_PY[ + "git" +] = ''' # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build @@ -993,7 +1005,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)]) if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -1002,7 +1014,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) + tags = set([r for r in refs if re.search(r"\d", r)]) if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -1010,19 +1022,26 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] + r = ref[len(tag_prefix) :] if verbose: print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } @register_vcs_handler("git", "pieces_from_vcs") @@ -1037,8 +1056,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) + out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -1046,10 +1064,19 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%s*" % tag_prefix], - cwd=root) + describe_out, rc = run_command( + GITS, + [ + "describe", + "--tags", + "--dirty", + "--always", + "--long", + "--match", + "%s*" % tag_prefix, + ], + cwd=root, + ) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") @@ -1072,17 +1099,16 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] + git_describe = git_describe[: git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: # unparseable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) + pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces # tag @@ -1091,10 +1117,12 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( + full_tag, + tag_prefix, + ) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] + pieces["closest-tag"] = full_tag[len(tag_prefix) :] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -1105,13 +1133,13 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) + count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) pieces["distance"] = int(count_out) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], - cwd=root)[0].strip() + date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ + 0 + ].strip() pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces @@ -1167,16 +1195,22 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): for i in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} + return { + "version": dirname[len(parentdir_prefix) :], + "full-revisionid": None, + "dirty": False, + "error": None, + "date": None, + } else: rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @@ -1205,11 +1239,13 @@ def versions_from_file(filename): contents = f.read() except EnvironmentError: raise NotThisMethod("unable to read _version.py") - mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) + mo = re.search( + r"version_json = '''\n(.*)''' # END VERSION_JSON", contents, re.M | re.S + ) if not mo: - mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) + mo = re.search( + r"version_json = '''\r\n(.*)''' # END VERSION_JSON", contents, re.M | re.S + ) if not mo: raise NotThisMethod("no version_json in _version.py") return json.loads(mo.group(1)) @@ -1218,8 +1254,7 @@ def versions_from_file(filename): def write_to_version_file(filename, versions): """Write the given version number to the given _version.py file.""" os.unlink(filename) - contents = json.dumps(versions, sort_keys=True, - indent=1, separators=(",", ": ")) + contents = json.dumps(versions, sort_keys=True, indent=1, separators=(",", ": ")) with open(filename, "w") as f: f.write(SHORT_VERSION_PY % contents) @@ -1251,8 +1286,7 @@ def render_pep440(pieces): rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -1366,11 +1400,13 @@ def render_git_describe_long(pieces): def render(pieces, style): """Render the given version pieces into the requested style.""" if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } if not style or style == "default": style = "pep440" # the default @@ -1390,9 +1426,13 @@ def render(pieces, style): else: raise ValueError("unknown style '%s'" % style) - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } class VersioneerBadRootError(Exception): @@ -1415,8 +1455,9 @@ def get_versions(verbose=False): handlers = HANDLERS.get(cfg.VCS) assert handlers, "unrecognized VCS '%s'" % cfg.VCS verbose = verbose or cfg.verbose - assert cfg.versionfile_source is not None, \ - "please set versioneer.versionfile_source" + assert ( + cfg.versionfile_source is not None + ), "please set versioneer.versionfile_source" assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" versionfile_abs = os.path.join(root, cfg.versionfile_source) @@ -1470,9 +1511,13 @@ def get_versions(verbose=False): if verbose: print("unable to compute version") - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, "error": "unable to compute version", - "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", + "date": None, + } def get_version(): @@ -1521,6 +1566,7 @@ def run(self): print(" date: %s" % vers.get("date")) if vers["error"]: print(" error: %s" % vers["error"]) + cmds["version"] = cmd_version # we override "build_py" in both distutils and setuptools @@ -1553,14 +1599,15 @@ def run(self): # now locate _version.py in the new build/ directory and replace # it with an updated value if cfg.versionfile_build: - target_versionfile = os.path.join(self.build_lib, - cfg.versionfile_build) + target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build) print("UPDATING %s" % target_versionfile) write_to_version_file(target_versionfile, versions) + cmds["build_py"] = cmd_build_py if "cx_Freeze" in sys.modules: # cx_freeze enabled? from cx_Freeze.dist import build_exe as _build_exe + # nczeczulin reports that py2exe won't like the pep440-style string # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. # setup(console=[{ @@ -1581,17 +1628,21 @@ def run(self): os.unlink(target_versionfile) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + cmds["build_exe"] = cmd_build_exe del cmds["build_py"] - if 'py2exe' in sys.modules: # py2exe enabled? + if "py2exe" in sys.modules: # py2exe enabled? try: from py2exe.distutils_buildexe import py2exe as _py2exe # py3 except ImportError: @@ -1610,13 +1661,17 @@ def run(self): os.unlink(target_versionfile) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + cmds["py2exe"] = cmd_py2exe # we override different "sdist" commands for both environments @@ -1643,8 +1698,10 @@ def make_release_tree(self, base_dir, files): # updated value target_versionfile = os.path.join(base_dir, cfg.versionfile_source) print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, - self._versioneer_generated_versions) + write_to_version_file( + target_versionfile, self._versioneer_generated_versions + ) + cmds["sdist"] = cmd_sdist return cmds @@ -1699,11 +1756,13 @@ def do_setup(): root = get_root() try: cfg = get_config_from_root(root) - except (EnvironmentError, configparser.NoSectionError, - configparser.NoOptionError) as e: + except ( + EnvironmentError, + configparser.NoSectionError, + configparser.NoOptionError, + ) as e: if isinstance(e, (EnvironmentError, configparser.NoSectionError)): - print("Adding sample versioneer config to setup.cfg", - file=sys.stderr) + print("Adding sample versioneer config to setup.cfg", file=sys.stderr) with open(os.path.join(root, "setup.cfg"), "a") as f: f.write(SAMPLE_CONFIG) print(CONFIG_ERROR, file=sys.stderr) @@ -1712,15 +1771,18 @@ def do_setup(): print(" creating %s" % cfg.versionfile_source) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - - ipy = os.path.join(os.path.dirname(cfg.versionfile_source), - "__init__.py") + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + + ipy = os.path.join(os.path.dirname(cfg.versionfile_source), "__init__.py") if os.path.exists(ipy): try: with open(ipy, "r") as f: @@ -1762,8 +1824,10 @@ def do_setup(): else: print(" 'versioneer.py' already in MANIFEST.in") if cfg.versionfile_source not in simple_includes: - print(" appending versionfile_source ('%s') to MANIFEST.in" % - cfg.versionfile_source) + print( + " appending versionfile_source ('%s') to MANIFEST.in" + % cfg.versionfile_source + ) with open(manifest_in, "a") as f: f.write("include %s\n" % cfg.versionfile_source) else: