diff --git a/oximachine_featurizer/__init__.py b/oximachine_featurizer/__init__.py index f663644..e238ba7 100644 --- a/oximachine_featurizer/__init__.py +++ b/oximachine_featurizer/__init__.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- +"""Featurization tools for the oxiMachine""" from ._version import get_versions __version__ = get_versions()["version"] del get_versions -from .featurize import FeatureCollector, GetFeatures, featurize \ No newline at end of file +from .featurize import FeatureCollector, GetFeatures, featurize diff --git a/oximachine_featurizer/_version.py b/oximachine_featurizer/_version.py index 072d29e..b07dee4 100644 --- a/oximachine_featurizer/_version.py +++ b/oximachine_featurizer/_version.py @@ -24,10 +24,10 @@ def get_keywords(): # setup.py/versioneer.py will grep for the variable names, so they must # each be defined on a line of their own. _version.py will just call # get_keywords(). - git_refnames = '$Format:%d$' - git_full = '$Format:%H$' - git_date = '$Format:%ci$' - keywords = {'refnames': git_refnames, 'full': git_full, 'date': git_date} + git_refnames = "$Format:%d$" + git_full = "$Format:%H$" + git_date = "$Format:%ci$" + keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} return keywords @@ -40,11 +40,11 @@ def get_config(): # these strings are filled in when 'setup.py versioneer' creates # _version.py cfg = VersioneerConfig() - cfg.VCS = 'git' - cfg.style = 'pep440' - cfg.tag_prefix = '' - cfg.parentdir_prefix = '' - cfg.versionfile_source = 'mine_mof_oxstate/_version.py' + cfg.VCS = "git" + cfg.style = "pep440" + cfg.tag_prefix = "" + cfg.parentdir_prefix = "" + cfg.versionfile_source = "mine_mof_oxstate/_version.py" cfg.verbose = False return cfg @@ -59,17 +59,18 @@ class NotThisMethod(Exception): def register_vcs_handler(vcs, method): # decorator """Decorator to mark a method as the handler for a particular VCS.""" + def decorate(f): """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f + return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): """Call the given command(s).""" assert isinstance(commands, list) p = None @@ -77,30 +78,33 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, try: dispcmd = str([c] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) + p = subprocess.Popen( + [c] + args, + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr else None), + ) break except EnvironmentError: e = sys.exc_info()[1] if e.errno == errno.ENOENT: continue if verbose: - print('unable to run %s' % dispcmd) + print("unable to run %s" % dispcmd) print(e) return None, None else: if verbose: - print('unable to find command, tried %s' % (commands,)) + print("unable to find command, tried %s" % (commands,)) return None, None stdout = p.communicate()[0].strip() if sys.version_info[0] >= 3: stdout = stdout.decode() if p.returncode != 0: if verbose: - print('unable to run %s (error)' % dispcmd) - print('stdout was %s' % stdout) + print("unable to run %s (error)" % dispcmd) + print("stdout was %s" % stdout) return None, p.returncode return stdout, p.returncode @@ -117,20 +121,26 @@ def versions_from_parentdir(parentdir_prefix, root, 
verbose): for i in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return {'version': dirname[len(parentdir_prefix):], - 'full-revisionid': None, - 'dirty': False, 'error': None, 'date': None} + return { + "version": dirname[len(parentdir_prefix) :], + "full-revisionid": None, + "dirty": False, + "error": None, + "date": None, + } else: rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: - print('Tried directories %s but none started with prefix %s' % - (str(rootdirs), parentdir_prefix)) + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") -@register_vcs_handler('git', 'get_keywords') +@register_vcs_handler("git", "get_keywords") def git_get_keywords(versionfile_abs): """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these @@ -139,32 +149,32 @@ def git_get_keywords(versionfile_abs): # _version.py. keywords = {} try: - f = open(versionfile_abs, 'r') + f = open(versionfile_abs, "r") for line in f.readlines(): - if line.strip().startswith('git_refnames ='): + if line.strip().startswith("git_refnames ="): mo = re.search(r'=\s*"(.*)"', line) if mo: - keywords['refnames'] = mo.group(1) - if line.strip().startswith('git_full ='): + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): mo = re.search(r'=\s*"(.*)"', line) if mo: - keywords['full'] = mo.group(1) - if line.strip().startswith('git_date ='): + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): mo = re.search(r'=\s*"(.*)"', line) if mo: - keywords['date'] = mo.group(1) + keywords["date"] = mo.group(1) f.close() except EnvironmentError: pass return keywords -@register_vcs_handler('git', 'keywords') +@register_vcs_handler("git", "keywords") def git_versions_from_keywords(keywords, tag_prefix, verbose): """Get version information from git keywords.""" if not keywords: - raise NotThisMethod('no keywords at all, weird') - date = keywords.get('date') + raise NotThisMethod("no keywords at all, weird") + date = keywords.get("date") if date is not None: # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 @@ -172,17 +182,17 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # it's been around since git-1.5.3, and it's too difficult to # discover which version we're using, or to work around using an # older one. - date = date.strip().replace(' ', 'T', 1).replace(' ', '', 1) - refnames = keywords['refnames'].strip() - if refnames.startswith('$Format'): + date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + refnames = keywords["refnames"].strip() + if refnames.startswith("$Format"): if verbose: - print('keywords are unexpanded, not using') - raise NotThisMethod('unexpanded keywords, not a git-archive tarball') - refs = set([r.strip() for r in refnames.strip('()').split(',')]) + print("keywords are unexpanded, not using") + raise NotThisMethod("unexpanded keywords, not a git-archive tarball") + refs = set([r.strip() for r in refnames.strip("()").split(",")]) # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. 
- TAG = 'tag: ' - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + TAG = "tag: " + tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)]) if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -191,30 +201,37 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) + tags = set([r for r in refs if re.search(r"\d", r)]) if verbose: - print("discarding '%s', no digits" % ','.join(refs - tags)) + print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: - print('likely tags: %s' % ','.join(sorted(tags))) + print("likely tags: %s" % ",".join(sorted(tags))) for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] + r = ref[len(tag_prefix) :] if verbose: - print('picking %s' % r) - return {'version': r, - 'full-revisionid': keywords['full'].strip(), - 'dirty': False, 'error': None, - 'date': date} + print("picking %s" % r) + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: - print('no suitable tags, using unknown + full revision id') - return {'version': '0+unknown', - 'full-revisionid': keywords['full'].strip(), - 'dirty': False, 'error': 'no suitable tags', 'date': None} + print("no suitable tags, using unknown + full revision id") + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } -@register_vcs_handler('git', 'pieces_from_vcs') +@register_vcs_handler("git", "pieces_from_vcs") def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): """Get version from 'git describe' in the root of the source tree. @@ -222,56 +239,63 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): expanded, and _version.py hasn't already been rewritten with a short version string, meaning we're inside a checked out source tree. 
""" - GITS = ['git'] - if sys.platform == 'win32': - GITS = ['git.cmd', 'git.exe'] + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ['rev-parse', '--git-dir'], cwd=root, - hide_stderr=True) + out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) if rc != 0: if verbose: - print('Directory %s not under git control' % root) + print("Directory %s not under git control" % root) raise NotThisMethod("'git rev-parse --git-dir' returned error") # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ['describe', '--tags', '--dirty', - '--always', '--long', - '--match', '%s*' % tag_prefix], - cwd=root) + describe_out, rc = run_command( + GITS, + [ + "describe", + "--tags", + "--dirty", + "--always", + "--long", + "--match", + "%s*" % tag_prefix, + ], + cwd=root, + ) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ['rev-parse', 'HEAD'], cwd=root) + full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) if full_out is None: raise NotThisMethod("'git rev-parse' failed") full_out = full_out.strip() pieces = {} - pieces['long'] = full_out - pieces['short'] = full_out[:7] # maybe improved later - pieces['error'] = None + pieces["long"] = full_out + pieces["short"] = full_out[:7] # maybe improved later + pieces["error"] = None # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] # TAG might have hyphens. git_describe = describe_out # look for -dirty suffix - dirty = git_describe.endswith('-dirty') - pieces['dirty'] = dirty + dirty = git_describe.endswith("-dirty") + pieces["dirty"] = dirty if dirty: - git_describe = git_describe[:git_describe.rindex('-dirty')] + git_describe = git_describe[: git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX - if '-' in git_describe: + if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: # unparseable. Maybe git-describe is misbehaving? 
- pieces['error'] = ("unable to parse git-describe output: '%s'" - % describe_out) + pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces # tag @@ -280,37 +304,39 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces['error'] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( + full_tag, + tag_prefix, + ) return pieces - pieces['closest-tag'] = full_tag[len(tag_prefix):] + pieces["closest-tag"] = full_tag[len(tag_prefix) :] # distance: number of commits since tag - pieces['distance'] = int(mo.group(2)) + pieces["distance"] = int(mo.group(2)) # commit: short hex revision ID - pieces['short'] = mo.group(3) + pieces["short"] = mo.group(3) else: # HEX: no tags - pieces['closest-tag'] = None - count_out, rc = run_command(GITS, ['rev-list', 'HEAD', '--count'], - cwd=root) - pieces['distance'] = int(count_out) # total number of commits + pieces["closest-tag"] = None + count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) + pieces["distance"] = int(count_out) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ['show', '-s', '--format=%ci', 'HEAD'], - cwd=root)[0].strip() - pieces['date'] = date.strip().replace(' ', 'T', 1).replace(' ', '', 1) + date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ + 0 + ].strip() + pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces def plus_or_dot(pieces): """Return a + if we don't already have one, else return a .""" - if '+' in pieces.get('closest-tag', ''): - return '.' - return '+' + if "+" in pieces.get("closest-tag", ""): + return "." + return "+" def render_pep440(pieces): @@ -322,19 +348,18 @@ def render_pep440(pieces): Exceptions: 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] """ - if pieces['closest-tag']: - rendered = pieces['closest-tag'] - if pieces['distance'] or pieces['dirty']: + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: rendered += plus_or_dot(pieces) - rendered += '%d.g%s' % (pieces['distance'], pieces['short']) - if pieces['dirty']: - rendered += '.dirty' + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" else: # exception #1 - rendered = '0+untagged.%d.g%s' % (pieces['distance'], - pieces['short']) - if pieces['dirty']: - rendered += '.dirty' + rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" return rendered @@ -344,13 +369,13 @@ def render_pep440_pre(pieces): Exceptions: 1: no tags. 0.post.devDISTANCE """ - if pieces['closest-tag']: - rendered = pieces['closest-tag'] - if pieces['distance']: - rendered += '.post.dev%d' % pieces['distance'] + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += ".post.dev%d" % pieces["distance"] else: # exception #1 - rendered = '0.post.dev%d' % pieces['distance'] + rendered = "0.post.dev%d" % pieces["distance"] return rendered @@ -364,20 +389,20 @@ def render_pep440_post(pieces): Exceptions: 1: no tags. 
0.postDISTANCE[.dev0] """ - if pieces['closest-tag']: - rendered = pieces['closest-tag'] - if pieces['distance'] or pieces['dirty']: - rendered += '.post%d' % pieces['distance'] - if pieces['dirty']: - rendered += '.dev0' + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" rendered += plus_or_dot(pieces) - rendered += 'g%s' % pieces['short'] + rendered += "g%s" % pieces["short"] else: # exception #1 - rendered = '0.post%d' % pieces['distance'] - if pieces['dirty']: - rendered += '.dev0' - rendered += '+g%s' % pieces['short'] + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] return rendered @@ -389,17 +414,17 @@ def render_pep440_old(pieces): Eexceptions: 1: no tags. 0.postDISTANCE[.dev0] """ - if pieces['closest-tag']: - rendered = pieces['closest-tag'] - if pieces['distance'] or pieces['dirty']: - rendered += '.post%d' % pieces['distance'] - if pieces['dirty']: - rendered += '.dev0' + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" else: # exception #1 - rendered = '0.post%d' % pieces['distance'] - if pieces['dirty']: - rendered += '.dev0' + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" return rendered @@ -411,15 +436,15 @@ def render_git_describe(pieces): Exceptions: 1: no tags. HEX[-dirty] (note: no 'g' prefix) """ - if pieces['closest-tag']: - rendered = pieces['closest-tag'] - if pieces['distance']: - rendered += '-%d-g%s' % (pieces['distance'], pieces['short']) + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) else: # exception #1 - rendered = pieces['short'] - if pieces['dirty']: - rendered += '-dirty' + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" return rendered @@ -432,47 +457,53 @@ def render_git_describe_long(pieces): Exceptions: 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) """ - if pieces['closest-tag']: - rendered = pieces['closest-tag'] - rendered += '-%d-g%s' % (pieces['distance'], pieces['short']) + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) else: # exception #1 - rendered = pieces['short'] - if pieces['dirty']: - rendered += '-dirty' + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" return rendered def render(pieces, style): """Render the given version pieces into the requested style.""" - if pieces['error']: - return {'version': 'unknown', - 'full-revisionid': pieces.get('long'), - 'dirty': None, - 'error': pieces['error'], - 'date': None} - - if not style or style == 'default': - style = 'pep440' # the default - - if style == 'pep440': + if pieces["error"]: + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } + + if not style or style == "default": + style = "pep440" # the default + + if style == "pep440": rendered = render_pep440(pieces) - elif style == 'pep440-pre': + elif style == "pep440-pre": rendered = render_pep440_pre(pieces) - elif style == 'pep440-post': + elif style == "pep440-post": rendered = render_pep440_post(pieces) - elif style == 'pep440-old': + elif style == "pep440-old": rendered = render_pep440_old(pieces) - elif style == 'git-describe': + elif style == "git-describe": rendered = render_git_describe(pieces) - elif style == 'git-describe-long': + elif style == "git-describe-long": rendered = render_git_describe_long(pieces) else: raise ValueError("unknown style '%s'" % style) - return {'version': rendered, 'full-revisionid': pieces['long'], - 'dirty': pieces['dirty'], 'error': None, - 'date': pieces.get('date')} + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } def get_versions(): @@ -486,8 +517,7 @@ def get_versions(): verbose = cfg.verbose try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) except NotThisMethod: pass @@ -496,13 +526,16 @@ def get_versions(): # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. 
- for i in cfg.versionfile_source.split('/'): + for i in cfg.versionfile_source.split("/"): root = os.path.dirname(root) except NameError: - return {'version': '0+unknown', 'full-revisionid': None, - 'dirty': None, - 'error': 'unable to find root of source tree', - 'date': None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None, + } try: pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) @@ -516,6 +549,10 @@ def get_versions(): except NotThisMethod: pass - return {'version': '0+unknown', 'full-revisionid': None, - 'dirty': None, - 'error': 'unable to compute version', 'date': None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", + "date": None, + } diff --git a/oximachine_featurizer/featurize.py b/oximachine_featurizer/featurize.py index 36aaa2b..4dffa6f 100644 --- a/oximachine_featurizer/featurize.py +++ b/oximachine_featurizer/featurize.py @@ -569,7 +569,7 @@ def __init__( # pylint:disable=too-many-arguments percentage_holdout: float = 0, outdir_holdout: Union[str, Path] = None, forbidden_picklepath: Union[str, Path] = None, - exclude_dir: Union[str, Path] = "../test_structures/showcases", + exclude_dir: Union[str, Path] = None, selected_features: List[str] = CHEMISTRY_FEATURES + METAL_CENTER_FEATURES + ["crystal_nn_fingerprint"], diff --git a/oximachine_featurizer/utils.py b/oximachine_featurizer/utils.py index 22907ba..dafdf9f 100644 --- a/oximachine_featurizer/utils.py +++ b/oximachine_featurizer/utils.py @@ -20,7 +20,7 @@ def read_pickle(filepath: str): """Does what it says. Nothing more and nothing less. Takes a pickle file path and unpickles it""" - with open(filepath, 'rb') as fh: # pylint: disable=invalid-name + with open(filepath, "rb") as fh: # pylint: disable=invalid-name result = pickle.load(fh) # pylint: disable=invalid-name return result @@ -36,10 +36,9 @@ def flatten(items): def chunks(l, n): - """ Yield successive n-sized chunks from l. - """ + """Yield successive n-sized chunks from l.""" for i in range(0, len(l), n): - yield l[i:i + n] + yield l[i : i + n] def diff_to_18e(nvalence): @@ -50,9 +49,9 @@ def diff_to_18e(nvalence): def apricot_select(data, k, standardize=True, chunksize=20000): """Does 'farthest point sampling' with apricot. - For memory limitation reasons it is chunked with a hardcoded chunksize. 
""" + For memory limitation reasons it is chunked with a hardcoded chunksize.""" if standardize: - print('standardizing data') + print("standardizing data") data = StandardScaler().fit_transform(data) data = data.astype(np.float64) @@ -71,11 +70,11 @@ def apricot_select(data, k, standardize=True, chunksize=20000): to_select = int(k / num_chunks) - print(('Will use {} chunks of size {}'.format(num_chunks, chunksize))) + print(("Will use {} chunks of size {}".format(num_chunks, chunksize))) num_except = 0 for d_ in tqdm(chunks(data, chunksize)): - print(('Current chunk has size {}'.format(len(d_)))) + print(("Current chunk has size {}".format(len(d_)))) if len(d_) > to_select: # otherwise it makes no sense to select something try: X_subset = FacilityLocationSelection(to_select).fit_transform(d_) @@ -84,11 +83,12 @@ def apricot_select(data, k, standardize=True, chunksize=20000): num_except += 1 if num_except > 1: # pylint:disable=no-else-return warnings.warn( - 'Could not perform diverse set selection for two attempts, will perform random choice') + "Could not perform diverse set selection for two attempts, will perform random choice" + ) return np.random.choice(len(data), k, replace=False) else: - print('will use greedy select now') - X_subset = _greedy_loop(d_, to_select, 'euclidean') + print("will use greedy select now") + X_subset = _greedy_loop(d_, to_select, "euclidean") chunklist.append(X_subset) greedy_indices = [] subset = np.vstack(chunklist) @@ -124,21 +124,20 @@ def _greedy_loop(remaining, k, metric): return greedy_data -def _greedy_farthest_point_samples_non_chunked(data, - k: int, - metric: str = 'euclidean', - standardize: bool = True) -> list: +def _greedy_farthest_point_samples_non_chunked( + data, k: int, metric: str = "euclidean", standardize: bool = True +) -> list: """ - Args: - data (np.array) - k (int) - metric (string): metric to use for the distance, can be one from - https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cdist.html - defaults to euclidean - standardize (bool): flag that indicates whether features are standardized prior to sampling - Returns: - list with the sampled names - list of indices + Args: + data (np.array) + k (int) + metric (string): metric to use for the distance, can be one from + https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cdist.html + defaults to euclidean + standardize (bool): flag that indicates whether features are standardized prior to sampling + Returns: + list with the sampled names + list of indices """ data = data.astype(np.float32) @@ -164,45 +163,49 @@ def _greedy_farthest_point_samples_non_chunked(data, def greedy_farthest_point_samples( - data, - k: int, - metric: str = 'euclidean', - standardize: bool = True, - chunked: bool = False, + data, + k: int, + metric: str = "euclidean", + standardize: bool = True, + chunked: bool = False, ) -> list: """ - Args: - data (np.array) - k (int) - metric (string): metric to use for the distance, can be one from - https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cdist.html - defaults to euclidean - standardize (bool): flag that indicates whether features are standardized prior to sampling - Returns: - list with the sampled names - list of indices + Args: + data (np.array) + k (int) + metric (string): metric to use for the distance, can be one from + https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cdist.html + defaults to euclidean + standardize (bool): flag that indicates whether 
features are standardized prior to sampling + Returns: + list with the sampled names + list of indices """ if chunked: result = _greedy_farthest_point_samples_chunked(data, k, metric, standardize) else: - result = _greedy_farthest_point_samples_non_chunked(data, k, metric, standardize) + result = _greedy_farthest_point_samples_non_chunked( + data, k, metric, standardize + ) return result -def _greedy_farthest_point_samples_chunked(data, k: int, metric: str = 'euclidean', standardize: bool = True) -> list: +def _greedy_farthest_point_samples_chunked( + data, k: int, metric: str = "euclidean", standardize: bool = True +) -> list: """ - Args: - data (np.array) - k (int) - metric (string): metric to use for the distance, can be one from - https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cdist.html - defaults to euclidean - standardize (bool): flag that indicates whether features are standardized prior to sampling - Returns: - list with the sampled names - list of indices + Args: + data (np.array) + k (int) + metric (string): metric to use for the distance, can be one from + https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cdist.html + defaults to euclidean + standardize (bool): flag that indicates whether features are standardized prior to sampling + Returns: + list with the sampled names + list of indices """ data = data.astype(np.float32) @@ -219,7 +222,7 @@ def _greedy_farthest_point_samples_chunked(data, k: int, metric: str = 'euclidea # memory intensive i = 0 for d_ in chunks(data, chunksize): - print(('chunk {} out of {}'.format(i, num_chunks))) + print(("chunk {} out of {}".format(i, num_chunks))) d = d_ if len(d) > 2: index = np.random.randint(0, len(d) - 1) @@ -252,8 +255,10 @@ class SymbolNameDict: def __init__(self): with open( - os.path.join(Path(__file__).absolute().parent, 'data', 'periodic_table.json'), - 'r', + os.path.join( + Path(__file__).absolute().parent, "data", "periodic_table.json" + ), + "r", ) as periodic_table_file: self.pt_data = json.load(periodic_table_file) self.symbol_name_dict = {} @@ -265,8 +270,8 @@ def get_symbol_name_dict(self, only_metal=True): for key, value in self.pt_data.items(): if only_metal: if Element(key).is_metal: - self.symbol_name_dict[key] = value['Name'].lower() + self.symbol_name_dict[key] = value["Name"].lower() else: - self.symbol_name_dict[key] = value['Name'].lower() + self.symbol_name_dict[key] = value["Name"].lower() return self.symbol_name_dict diff --git a/run/_featurecollect_mp_solids.py b/run/_featurecollect_mp_solids.py index 2d50580..07ae691 100644 --- a/run/_featurecollect_mp_solids.py +++ b/run/_featurecollect_mp_solids.py @@ -18,22 +18,24 @@ def write_labels_to_stupid_format(df, outdir): # pylint:disable = invalid-name with the automatic featurecollection to keep our lives easier""" stupid_dict = {} for _, row in df.iterrows(): - stupid_dict[row['name']] = {row['metal']: [row['oxidationstate']]} + stupid_dict[row["name"]] = {row["metal"]: [row["oxidationstate"]]} - with open(os.path.join(outdir, 'materialsproject_structure_labels.pkl'), 'wb') as fh: # pylint:disable=invalid-name + with open( + os.path.join(outdir, "materialsproject_structure_labels.pkl"), "wb" + ) as fh: # pylint:disable=invalid-name pickle.dump(stupid_dict, fh) -@click.command('cli') -@click.argument('dfpath') -@click.argument('inpath') -@click.argument('outdir_labels') -@click.argument('outdir_features') -@click.argument('outdir_helper') -@click.argument('percentage_holdout') 
-@click.argument('outdir_holdout') -@click.argument('training_set_size') -@click.argument('features', nargs=-1) +@click.command("cli") +@click.argument("dfpath") +@click.argument("inpath") +@click.argument("outdir_labels") +@click.argument("outdir_features") +@click.argument("outdir_helper") +@click.argument("percentage_holdout") +@click.argument("outdir_holdout") +@click.argument("training_set_size") +@click.argument("features", nargs=-1) def main( # pylint:disable=too-many-arguments dfpath, inpath, @@ -50,7 +52,7 @@ def main( # pylint:disable=too-many-arguments dirpath = tempfile.mkdtemp() write_labels_to_stupid_format(df, outdir=dirpath) - print(f'rewrote labels to {dirpath}') + print(f"rewrote labels to {dirpath}") try: training_set_size = int(training_set_size) except Exception: # pylint:disable=broad-except @@ -59,7 +61,7 @@ def main( # pylint:disable=too-many-arguments fc = FeatureCollector( # pylint:disable=invalid-name inpath, - os.path.join(dirpath, 'materialsproject_structure_labels.pkl'), + os.path.join(dirpath, "materialsproject_structure_labels.pkl"), outdir_labels, outdir_features, outdir_helper, @@ -71,5 +73,5 @@ def main( # pylint:disable=too-many-arguments fc.dump_featurecollection() -if __name__ == '__main__': +if __name__ == "__main__": main() # pylint:disable=no-value-for-parameter diff --git a/run/_featurize_mp_structures.py b/run/_featurize_mp_structures.py index be4b26b..d022f79 100644 --- a/run/_featurize_mp_structures.py +++ b/run/_featurize_mp_structures.py @@ -13,13 +13,15 @@ from oximachine_featurizer.featurize import GetFeatures -MPDIR = '/Users/kevinmaikjablonka/Dropbox (LSMO)/proj62_guess_oxidation_states/mp_structures' -ALREADY_FEATURIZED = glob(os.path.join(MPDIR, '*.pkl')) -OUTDIR = ('/Users/kevinmaikjablonka/Dropbox (LSMO)/proj62_guess_oxidation_states//mp_features') +MPDIR = "/Users/kevinmaikjablonka/Dropbox (LSMO)/proj62_guess_oxidation_states/mp_structures" +ALREADY_FEATURIZED = glob(os.path.join(MPDIR, "*.pkl")) +OUTDIR = ( + "/Users/kevinmaikjablonka/Dropbox (LSMO)/proj62_guess_oxidation_states//mp_features" +) def load_pickle(f): # pylint:disable=invalid-name - with open(f, 'rb') as fh: # pylint:disable=invalid-name + with open(f, "rb") as fh: # pylint:disable=invalid-name result = pickle.load(fh) return result @@ -31,12 +33,14 @@ def featurize_single(structure, outdir=OUTDIR): def main(): """CLI""" - all_structures = glob(os.path.join(MPDIR, '*.cif')) - print(f'found {len(all_structures)} structures') + all_structures = glob(os.path.join(MPDIR, "*.cif")) + print(f"found {len(all_structures)} structures") with concurrent.futures.ProcessPoolExecutor(max_workers=4) as executor: - for _ in tqdm(executor.map(featurize_single, all_structures), total=len(all_structures)): + for _ in tqdm( + executor.map(featurize_single, all_structures), total=len(all_structures) + ): pass -if __name__ == '__main__': +if __name__ == "__main__": main() # pylint: disable=no-value-for-parameter diff --git a/run/_run_chemical_formulas.py b/run/_run_chemical_formulas.py index 8b36b8a..0652a48 100644 --- a/run/_run_chemical_formulas.py +++ b/run/_run_chemical_formulas.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- - +"""Get the chemical formulas for the structures +(used to plot the frequency of the elements) +""" import pickle import re import time @@ -8,8 +10,8 @@ def load_pickle(f): - with open(f, 'rb') as fh: - result = pickle.load(fh) + with open(f, "rb") as handle: + result = pickle.load(handle) return result @@ -18,8 +20,7 @@ def get_chemical_formula(csd_reader, 
database_id): formula = entry_object.crystal.formula formula_dict = {} for elem in formula.split(): - count = int(re.search(r'\d+', elem).group()) - symbol = elem.strip(str(count)) + count = int(re.search(r"\d+", elem).group()) formula_dict[elem] = count return formula_dict @@ -28,22 +29,22 @@ def get_chemical_formula(csd_reader, database_id): def main(): # oxidation_parse_dict = load_pickle( # "/home/kevin/Dropbox (LSMO)/proj62_guess_oxidation_states/oxidation_state_book/data/20190820-173457-csd_ox_parse_output.pkl" - #) + # ) oxidation_reference_dict = load_pickle( - '/home/kevin/Dropbox (LSMO)/proj62_guess_oxidation_states/mine_csd/20190921-142007-csd_ox_parse_output_reference.pkl' + "/home/kevin/Dropbox (LSMO)/proj62_guess_oxidation_states/mine_csd/20190921-142007-csd_ox_parse_output_reference.pkl" ) database_ids = list(oxidation_reference_dict.keys()) - csd_reader = io.EntryReader('CSD') + csd_reader = io.EntryReader("CSD") formula_dicts = {} for database_id in database_ids: formula_dicts[database_id] = get_chemical_formula(csd_reader, database_id) - timestr = time.strftime('%Y%m%d-%H%M%S') - output_name = '-'.join([timestr, 'get_chemical_formulas']) - with open(output_name + '.pkl', 'wb') as filehandle: + timestr = time.strftime("%Y%m%d-%H%M%S") + output_name = "-".join([timestr, "get_chemical_formulas"]) + with open(output_name + ".pkl", "wb") as filehandle: pickle.dump(formula_dicts, filehandle) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/run/_run_featurecollection.py b/run/_run_featurecollection.py index 08ae035..cca348f 100644 --- a/run/_run_featurecollection.py +++ b/run/_run_featurecollection.py @@ -1,9 +1,6 @@ # -*- coding: utf-8 -*- # pylint:disable=relative-beyond-top-level, line-too-long, too-many-arguments -""" -Status: Dev -Run the featurization on one structure -""" +"""Run the featurization on one structure""" import os @@ -12,19 +9,19 @@ from oximachine_featurizer.featurize import FeatureCollector -@click.command('cli') -@click.argument('inpath', type=click.Path(exists=True)) -@click.argument('labelpath', type=click.Path(exists=True)) -@click.argument('outdir_labels', type=click.Path(exists=True)) -@click.argument('outdir_features', type=click.Path(exists=True)) -@click.argument('outdir_helper', type=click.Path(exists=True)) -@click.argument('percentage_holdout') -@click.argument('outdir_holdout', type=click.Path(exists=True)) -@click.argument('training_set_size') -@click.argument('racsfile') -@click.argument('features', nargs=-1) -@click.option('--only_racs', is_flag=True) -@click.option('--do_not_drop_duplicates', is_flag=True) +@click.command("cli") +@click.argument("inpath", type=click.Path(exists=True)) +@click.argument("labelpath", type=click.Path(exists=True)) +@click.argument("outdir_labels", type=click.Path(exists=True)) +@click.argument("outdir_features", type=click.Path(exists=True)) +@click.argument("outdir_helper", type=click.Path(exists=True)) +@click.argument("percentage_holdout") +@click.argument("outdir_holdout", type=click.Path(exists=True)) +@click.argument("training_set_size") +@click.argument("racsfile") +@click.argument("features", nargs=-1) +@click.option("--only_racs", is_flag=True) +@click.option("--do_not_drop_duplicates", is_flag=True) def main( inpath, labelpath, @@ -39,9 +36,7 @@ def main( only_racs, do_not_drop_duplicates, ): - """ - CLI function - """ + """CLI function""" do_not_drop_duplicates = not do_not_drop_duplicates @@ -73,5 +68,5 @@ def main( fc.dump_featurecollection() -if __name__ == '__main__': 
+if __name__ == "__main__": main() # pylint: disable=no-value-for-parameter diff --git a/run/_run_featurization_cod.py b/run/_run_featurization_cod.py index 7b8ff55..10d6465 100644 --- a/run/_run_featurization_cod.py +++ b/run/_run_featurization_cod.py @@ -14,25 +14,28 @@ from oximachine_featurizer.featurize import GetFeatures -OUTDIR = '/scratch/kjablonk/oximachine_all/features_cod' -INDIR = '/work/lsmo/jablonka/cod_to_featurize' -ALREADY_FEATURIZED = [Path(p).stem for p in glob(os.path.join(OUTDIR, '*.pkl'))] +OUTDIR = "/scratch/kjablonk/oximachine_all/features_cod" +INDIR = "/work/lsmo/jablonka/cod_to_featurize" +ALREADY_FEATURIZED = [Path(p).stem for p in glob(os.path.join(OUTDIR, "*.pkl"))] def read_already_featurized(): - if os.path.exists('already_featurized.txt'): - with open('already_featurized.txt', 'r') as fh: + """Reads a file with list of already featurized files""" + if os.path.exists("already_featurized.txt"): + with open("already_featurized.txt", "r") as fh: already_featurized = fh.readlines() ALREADY_FEATURIZED.extend(already_featurized) def load_pickle(f): # pylint:disable=invalid-name - with open(f, 'rb') as fh: # pylint:disable=invalid-name - result = pickle.load(fh) + """Returns content of pickle file.""" + with open(f, "rb") as handle: # pylint:disable=invalid-name + result = pickle.load(handle) return result def featurize_single(structure, outdir=OUTDIR): + """Runs featurization on one structure.""" if Path(structure).stem not in ALREADY_FEATURIZED: try: gf = GetFeatures.from_file(structure, outdir) # pylint:disable=invalid-name @@ -41,22 +44,23 @@ def featurize_single(structure, outdir=OUTDIR): pass -@click.command('cli') -@click.option('--reverse', is_flag=True) +@click.command("cli") +@click.option("--reverse", is_flag=True) def main(reverse): + """CLI""" read_already_featurized() if reverse: - all_structures = sorted(glob(os.path.join(INDIR, '*.cif')), reverse=True) + all_structures = sorted(glob(os.path.join(INDIR, "*.cif")), reverse=True) else: - all_structures = sorted(glob(os.path.join(INDIR, '*.cif'))) + all_structures = sorted(glob(os.path.join(INDIR, "*.cif"))) with concurrent.futures.ProcessPoolExecutor() as executor: for _ in tqdm( - list(executor.map(featurize_single, all_structures)), - total=len(all_structures), + list(executor.map(featurize_single, all_structures)), + total=len(all_structures), ): pass -if __name__ == '__main__': - print(('working in {}'.format(INDIR))) +if __name__ == "__main__": + print(("working in {}".format(INDIR))) main() # pylint: disable=no-value-for-parameter diff --git a/run/_run_featurization_many.py b/run/_run_featurization_many.py index 1c79c51..124f105 100644 --- a/run/_run_featurization_many.py +++ b/run/_run_featurization_many.py @@ -14,21 +14,21 @@ from oximachine_featurizer.featurize import GetFeatures -OUTDIR = '/scratch/kjablonk/oximachine_all/features' -INDIR = '/work/lsmo/jablonka/2020-4-7_all_csd_for_oximachine/cif_for_feat' -ALREADY_FEATURIZED = [Path(p).stem for p in glob(os.path.join(OUTDIR, '*.pkl'))] +OUTDIR = "/scratch/kjablonk/oximachine_all/features" +INDIR = "/work/lsmo/jablonka/2020-4-7_all_csd_for_oximachine/cif_for_feat" +ALREADY_FEATURIZED = [Path(p).stem for p in glob(os.path.join(OUTDIR, "*.pkl"))] def read_already_featurized(): - if os.path.exists('already_featurized.txt'): - with open('already_featurized.txt', 'r') as fh: + if os.path.exists("already_featurized.txt"): + with open("already_featurized.txt", "r") as fh: already_featurized = fh.readlines() 
ALREADY_FEATURIZED.extend(already_featurized) def load_pickle(f): # pylint:disable=invalid-name - with open(f, 'rb') as fh: # pylint:disable=invalid-name - result = pickle.load(fh) + with open(f, "rb") as handle: # pylint:disable=invalid-name + result = pickle.load(handle) return result @@ -41,22 +41,22 @@ def featurize_single(structure, outdir=OUTDIR): pass -@click.command('cli') -@click.option('--reverse', is_flag=True) +@click.command("cli") +@click.option("--reverse", is_flag=True) def main(reverse): read_already_featurized() if reverse: - all_structures = sorted(glob(os.path.join(INDIR, '*.cif')), reverse=True) + all_structures = sorted(glob(os.path.join(INDIR, "*.cif")), reverse=True) else: - all_structures = sorted(glob(os.path.join(INDIR, '*.cif'))) + all_structures = sorted(glob(os.path.join(INDIR, "*.cif"))) with concurrent.futures.ProcessPoolExecutor() as executor: for _ in tqdm( - list(executor.map(featurize_single, all_structures)), - total=len(all_structures), + list(executor.map(featurize_single, all_structures)), + total=len(all_structures), ): pass -if __name__ == '__main__': - print(('working in {}'.format(INDIR))) +if __name__ == "__main__": + print(("working in {}".format(INDIR))) main() # pylint: disable=no-value-for-parameter diff --git a/run/_run_featurization_slurm_serial.py b/run/_run_featurization_slurm_serial.py index 7838ac9..6e9b1c9 100644 --- a/run/_run_featurization_slurm_serial.py +++ b/run/_run_featurization_slurm_serial.py @@ -18,16 +18,18 @@ import click -featurizer = logging.getLogger('featurizer') # pylint:disable=invalid-name +featurizer = logging.getLogger("featurizer") # pylint:disable=invalid-name featurizer.setLevel(logging.DEBUG) -logging.basicConfig(filename='featurizer.log', format='%(filename)s: %(message)s', level=logging.DEBUG) +logging.basicConfig( + filename="featurizer.log", format="%(filename)s: %(message)s", level=logging.DEBUG +) THIS_DIR = os.path.dirname(__file__) -OUTDIR = '/scratch/kjablonk/proj62_featurization/extended_chemspace' -CSDDIR = '/work/lsmo/mof_subset_csdmay2019' -ALREADY_FEAUTRIZED = [Path(p).stem for p in glob(os.path.join(OUTDIR, '*.pkl'))] -NAME_LIST = '/scratch/kjablonk/oxidationstates/to_sample_new.pkl' +OUTDIR = "/scratch/kjablonk/proj62_featurization/extended_chemspace" +CSDDIR = "/work/lsmo/mof_subset_csdmay2019" +ALREADY_FEAUTRIZED = [Path(p).stem for p in glob(os.path.join(OUTDIR, "*.pkl"))] +NAME_LIST = "/scratch/kjablonk/oxidationstates/to_sample_new.pkl" SUBMISSION_TEMPLATE = """#!/bin/bash -l #SBATCH --chdir ./ @@ -46,8 +48,8 @@ def load_pickle(f): # pylint:disable=invalid-name """Loads a pickle file""" - with open(f, 'rb') as fh: # pylint:disable=invalid-name - result = pickle.load(fh) + with open(f, "rb") as handle: # pylint:disable=invalid-name + result = pickle.load(handle) return result @@ -56,23 +58,27 @@ def load_pickle(f): # pylint:disable=invalid-name def write_and_submit_slurm(workdir, name, structure, outdir, submit=False): """writes a slurm submission script and submits it if requested""" - submission_template = SUBMISSION_TEMPLATE.format(name=name + '_featurize', structure=structure, outdir=outdir) - with open(os.path.join(workdir, name + '.slurm'), 'w') as fh: # pylint:disable=invalid-name + submission_template = SUBMISSION_TEMPLATE.format( + name=name + "_featurize", structure=structure, outdir=outdir + ) + with open( + os.path.join(workdir, name + ".slurm"), "w" + ) as fh: # pylint:disable=invalid-name for line in submission_template: fh.write(line) - featurizer.info('prepared {} for 
submission'.format(name)) + featurizer.info("prepared {} for submission".format(name)) if submit: - subprocess.call('sbatch {}'.format('{}.slurm'.format(name)), shell=True) + subprocess.call("sbatch {}".format("{}.slurm".format(name)), shell=True) time.sleep(2) - featurizer.info('submitted {}'.format(name)) + featurizer.info("submitted {}".format(name)) -@click.command('cli') -@click.argument('outdir') -@click.argument('start') -@click.argument('end') -@click.option('--submit', is_flag=True, help='actually submit slurm job') +@click.command("cli") +@click.argument("outdir") +@click.argument("start") +@click.argument("end") +@click.option("--submit", is_flag=True, help="actually submit slurm job") def main(outdir, start, end, submit): """Runs the CLI""" if not os.path.exists(outdir): @@ -83,8 +89,10 @@ def main(outdir, start, end, submit): for structure in TO_SAMPLE[start:end:5]: if structure not in ALREADY_FEAUTRIZED: name = structure - write_and_submit_slurm(THIS_DIR, name, os.path.join(CSDDIR, structure + '.cif'), outdir, submit) + write_and_submit_slurm( + THIS_DIR, name, os.path.join(CSDDIR, structure + ".cif"), outdir, submit + ) -if __name__ == '__main__': +if __name__ == "__main__": main() # pylint:disable=no-value-for-parameter diff --git a/run/_submit_slurm_from_folder.py b/run/_submit_slurm_from_folder.py index 90bf606..b0b83fa 100644 --- a/run/_submit_slurm_from_folder.py +++ b/run/_submit_slurm_from_folder.py @@ -1,9 +1,11 @@ # -*- coding: utf-8 -*- # pylint:disable = logging-format-interpolation """ -This runscript can be used to submit the featurizations on a HPC clusters with the SLURM workload manager. +This runscript can be used to submit the featurizations on +a HPC clusters with the SLURM workload manager. -Usage: Install script in conda enviornment called ml on cluster and then run it using the +Usage: Install script in conda environment called ml +on the cluster and then run it using the the outdir, start and end indices and submit flag. 
""" @@ -16,15 +18,17 @@ import click -featurizer = logging.getLogger('featurizer') # pylint:disable=invalid-name +featurizer = logging.getLogger("featurizer") # pylint:disable=invalid-name featurizer.setLevel(logging.DEBUG) -logging.basicConfig(filename='featurizer.log', format='%(filename)s: %(message)s', level=logging.DEBUG) +logging.basicConfig( + filename="featurizer.log", format="%(filename)s: %(message)s", level=logging.DEBUG +) THIS_DIR = os.path.dirname(__file__) -OUTDIR = '/scratch/kjablonk/oximachine_all' -INDIR = '/work/lsmo/jablonka/2020-4-7_all_csd_for_oximachine/cif_for_feat' -ALREADY_FEAUTRIZED = [Path(p).stem for p in glob(os.path.join(OUTDIR, '*.pkl'))] +OUTDIR = "/scratch/kjablonk/oximachine_all" +INDIR = "/work/lsmo/jablonka/2020-4-7_all_csd_for_oximachine/cif_for_feat" +ALREADY_FEAUTRIZED = [Path(p).stem for p in glob(os.path.join(OUTDIR, "*.pkl"))] SUBMISSION_TEMPLATE = """#!/bin/bash -l #SBATCH --chdir ./ @@ -41,28 +45,32 @@ run_featurization {structure} {outdir} """ -all_structures = sorted(glob(os.path.join(INDIR, '*.cif'))) +all_structures = sorted(glob(os.path.join(INDIR, "*.cif"))) def write_and_submit_slurm(workdir, name, structure, outdir, submit=False): """writes a slurm submission script and submits it if requested""" - submission_template = SUBMISSION_TEMPLATE.format(name=name + '_featurize', structure=structure, outdir=outdir) - with open(os.path.join(workdir, name + '.slurm'), 'w') as fh: # pylint:disable=invalid-name + submission_template = SUBMISSION_TEMPLATE.format( + name=name + "_featurize", structure=structure, outdir=outdir + ) + with open( + os.path.join(workdir, name + ".slurm"), "w" + ) as fh: # pylint:disable=invalid-name for line in submission_template: fh.write(line) - featurizer.info('prepared {} for submission'.format(name)) + featurizer.info("prepared {} for submission".format(name)) if submit: - subprocess.call('sbatch {}'.format('{}.slurm'.format(name)), shell=True) + subprocess.call("sbatch {}".format("{}.slurm".format(name)), shell=True) time.sleep(2) - featurizer.info('submitted {}'.format(name)) + featurizer.info("submitted {}".format(name)) -@click.command('cli') -@click.argument('outdir') -@click.argument('start') -@click.argument('end') -@click.option('--submit', is_flag=True, help='actually submit slurm job') +@click.command("cli") +@click.argument("outdir") +@click.argument("start") +@click.argument("end") +@click.option("--submit", is_flag=True, help="actually submit slurm job") def main(outdir, start, end, submit): """Runs the CLI""" if not os.path.exists(outdir): @@ -76,5 +84,5 @@ def main(outdir, start, end, submit): write_and_submit_slurm(THIS_DIR, name, structure, outdir, submit) -if __name__ == '__main__': +if __name__ == "__main__": main() # pylint:disable=no-value-for-parameter diff --git a/run/merge_two_x_y.py b/run/merge_two_x_y.py index d5fc1d8..e3f5095 100644 --- a/run/merge_two_x_y.py +++ b/run/merge_two_x_y.py @@ -100,9 +100,9 @@ def from_files( # pylint:disable=too-many-arguments ) @staticmethod - def output( # pylint:disable = invalid-name - X, # pylint:disable = invalid-name - y, # pylint:disable = invalid-name + def output( # pylint:disable=invalid-name + X, + y, names, outdir_features, outdir_labels, @@ -111,10 +111,10 @@ def output( # pylint:disable = invalid-name """Write the new training set files for the merged training set""" features, labels, names = shuffle(X, y, names, random_state=RANDOM_SEED) - np.save(os.path.join(outdir_features, 'features'), features) - np.save(os.path.join(outdir_labels, 
'labels'), labels) + np.save(os.path.join(outdir_features, "features"), features) + np.save(os.path.join(outdir_labels, "labels"), labels) - with open(os.path.join(outdir_names, 'names.pkl'), 'wb') as picklefile: + with open(os.path.join(outdir_names, "names.pkl"), "wb") as picklefile: pickle.dump(names, picklefile) def merge(self): @@ -129,19 +129,21 @@ def merge(self): ) # Now shuffle and output - Merger.output(X, y, names, self.outdir_features, self.outdir_labels, self.outdir_names) - - -@click.command('cli') -@click.argument('features0path') -@click.argument('features1path') -@click.argument('labels0path') -@click.argument('labels1path') -@click.argument('names0path') -@click.argument('names1path') -@click.argument('outdir_features') -@click.argument('outdir_labels') -@click.argument('outdir_names') + Merger.output( + X, y, names, self.outdir_features, self.outdir_labels, self.outdir_names + ) + + +@click.command("cli") +@click.argument("features0path") +@click.argument("features1path") +@click.argument("labels0path") +@click.argument("labels1path") +@click.argument("names0path") +@click.argument("names1path") +@click.argument("outdir_features") +@click.argument("outdir_labels") +@click.argument("outdir_names") def run_merging( # pylint:disable=too-many-arguments features0path, features1path, @@ -168,5 +170,5 @@ def run_merging( # pylint:disable=too-many-arguments merger.merge() -if __name__ == '__main__': +if __name__ == "__main__": run_merging() # pylint:disable=no-value-for-parameter diff --git a/run/run_create_features_labels.py b/run/run_create_features_labels.py index 97b9e83..fbf650f 100644 --- a/run/run_create_features_labels.py +++ b/run/run_create_features_labels.py @@ -11,15 +11,15 @@ from oximachine_featurizer.featurize import FeatureCollector -@click.command('cli') -@click.argument('inpath') -@click.argument('labelsfile') -@click.argument('outdir') -def main(inpath, labelsfile, outdir): #pylint:disable=unused-argument +@click.command("cli") +@click.argument("inpath") +@click.argument("labelsfile") +@click.argument("outdir") +def main(inpath, labelsfile, outdir): # pylint:disable=unused-argument """Run the CLI""" - fc = FeatureCollector(inpath, labelsfile, outdir) #pylint:disable=invalid-name + fc = FeatureCollector(inpath, labelsfile, outdir) # pylint:disable=invalid-name fc.dump_featurecollection() -if __name__ == '__main__': - main() #pylint:disable=no-value-for-parameter +if __name__ == "__main__": + main() # pylint:disable=no-value-for-parameter diff --git a/run/run_featurization.py b/run/run_featurization.py index 8860b90..34f0131 100644 --- a/run/run_featurization.py +++ b/run/run_featurization.py @@ -1,8 +1,5 @@ # -*- coding: utf-8 -*- -# pylint:disable=relative-beyond-top-level -""" -Run the featurization on one structure -""" +"""Run the featurization on one structure""" import click import numpy as np @@ -14,14 +11,12 @@ @click.command("cli") @click.argument("structure") @click.argument("outname") -def main(structure, outname): - """ - CLI function - """ +def main(structure: str, outname: str): + """CLI function""" structure = Structure.from_file(structure) - X, _, _ = featurize(structure) # pylint: disable=invalid-name + feature_matrix, _, _ = featurize(structure) - np.save(outname, X) + np.save(outname, feature_matrix) if __name__ == "__main__": diff --git a/run/run_mine_mp.py b/run/run_mine_mp.py index 6475403..9b49f7a 100644 --- a/run/run_mine_mp.py +++ b/run/run_mine_mp.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- """ Get some structures and labels for solids. 
-I probably should have used one simple GET request instead of querying multiple times, but I'll go with it now, +I probably should have used one simple GET request +instead of querying multiple times, but I'll go with it now, it is not too slow """ @@ -12,78 +13,78 @@ from pymatgen import MPRester from tqdm import tqdm -mp_api = MPRester(os.getenv('MP_API_KEY', None)) # pylint:disable=invalid-name +mp_api = MPRester(os.getenv("MP_API_KEY", None)) # pylint:disable=invalid-name # Select metals and anions that are of interest for us anions_dict = { # pylint:disable=invalid-name - 'I': -1, - 'Cl': -1, - 'Br': -1, - 'F': -1, - 'O': -2, - 'S': -2, - 'N': -3, + "I": -1, + "Cl": -1, + "Br": -1, + "F": -1, + "O": -2, + "S": -2, + "N": -3, } metals = [ # pylint:disable=invalid-name - 'Li', - 'Na', - 'K', - 'Rb', - 'Cs', - 'Be', - 'Mg', - 'Ca', - 'Sr', - 'Ba', - 'Sc', - 'Ti', - 'V', - 'Cr', - 'Mn', - 'Fe', - 'Co', - 'Ni', - 'Cu', - 'Zn', - 'Y', - 'Zr', - 'Nb', - 'Mo', - 'Tc', - 'Ru', - 'Rh', - 'Pd', - 'Ag', - 'Cd', - 'Hf', - 'Ta', - 'W', - 'Re', - 'Os', - 'Ir', - 'Pt', - 'Au', - 'Hg', - 'B', - 'Al', - 'Ga', - 'In', - 'Tl', - 'Sn', - 'Pb', - 'Bi', - 'La', - 'Ce', - 'Pr', - 'Eu', - 'Gd', - 'Tb', - 'Dy', - 'Ho', - 'Er', - 'Tm', - 'U', - 'Pu', + "Li", + "Na", + "K", + "Rb", + "Cs", + "Be", + "Mg", + "Ca", + "Sr", + "Ba", + "Sc", + "Ti", + "V", + "Cr", + "Mn", + "Fe", + "Co", + "Ni", + "Cu", + "Zn", + "Y", + "Zr", + "Nb", + "Mo", + "Tc", + "Ru", + "Rh", + "Pd", + "Ag", + "Cd", + "Hf", + "Ta", + "W", + "Re", + "Os", + "Ir", + "Pt", + "Au", + "Hg", + "B", + "Al", + "Ga", + "In", + "Tl", + "Sn", + "Pb", + "Bi", + "La", + "Ce", + "Pr", + "Eu", + "Gd", + "Tb", + "Dy", + "Ho", + "Er", + "Tm", + "U", + "Pu", ] anions = list(anions_dict.keys()) # pylint:disable=invalid-name @@ -91,18 +92,21 @@ def check_stable(entry_id): """Check if energy is at hull minimum""" - return mp_api.get_data(entry_id, prop='e_above_hull')[0]['e_above_hull'] == 0 + return mp_api.get_data(entry_id, prop="e_above_hull")[0]["e_above_hull"] == 0 def get_binary_combinations( # pylint:disable=dangerous-default-value, redefined-outer-name - metals: list = metals, anions: list = anions) -> list: + metals: list = metals, anions: list = anions +) -> list: """Create list of entries of binary compounds with the metals and anions we defined which are stable""" combinations = list(product(metals, anions)) entries = [] for combination in tqdm(combinations): combination_entries = mp_api.get_entries_in_chemsys(combination) for combination_entry in combination_entries: - if check_stable(combination_entry.entry_id) and (len(combination_entry.as_dict()['composition']) > 1): + if check_stable(combination_entry.entry_id) and ( + len(combination_entry.as_dict()["composition"]) > 1 + ): entries.append(combination_entry.entry_id) return entries @@ -127,10 +131,10 @@ def calculate_metal_oxidation_state(formula: dict, metal: str, anion: str): """This returns the metal oxidation state for a composition dict""" # first check if first or second group, always set them to +1 and +2, respectively and see oxidationstate = None - if metal in ['Li', 'Na', 'K', 'Rb', 'Cs']: + if metal in ["Li", "Na", "K", "Rb", "Cs"]: if _check_consistency_ox_state(formula, 1.0, metal, anion): oxidationstate = 1.0 - elif metal in ['Be', 'Mg', 'Ca', 'Sr', 'Ba']: + elif metal in ["Be", "Mg", "Ca", "Sr", "Ba"]: if _check_consistency_ox_state(formula, 2.0, metal, anion): oxidationstate = 2.0 else: @@ -159,45 +163,47 @@ def which_is_the_metal( # pylint:disable=dangerous-default-value return 
metal, anion -def collect_for_id(entry_id, outdir='mp_structures'): +def collect_for_id(entry_id, outdir="mp_structures"): """Run the collections for one materials project id""" if not os.path.exists(outdir): os.mkdir(outdir) outdict = {} - outdict['material_id'] = entry_id + outdict["material_id"] = entry_id s = mp_api.get_structure_by_material_id(entry_id) # pylint:disable=invalid-name formula_dict = dict(s.composition.get_el_amt_dict()) # it returns a defaultdict metal, anion = which_is_the_metal(formula_dict) - outdict['metal'] = metal - outdict['anion'] = anion + outdict["metal"] = metal + outdict["anion"] = anion formula_string = s.formula - outdict['formula'] = formula_string + outdict["formula"] = formula_string oxidationstate = calculate_metal_oxidation_state(formula_dict, metal, anion) - outdict['oxidationstate'] = oxidationstate - name = entry_id + formula_string.replace(' ', '_') - outdict['name'] = name + outdict["oxidationstate"] = oxidationstate + name = entry_id + formula_string.replace(" ", "_") + outdict["name"] = name if oxidationstate is not None: - s.to(filename=os.path.join(outdir, name + '.cif')) + s.to(filename=os.path.join(outdir, name + ".cif")) return outdict def collect_entries(): """Runs the whole thing""" - print('*** Starting collect entries for all binary combinations and check if they are stable ***') + print( + "*** Starting collect entries for all binary combinations and check if they are stable ***" + ) entries = get_binary_combinations() - print('*** Now iterating over all the entries to find out oxidation states ***') + print("*** Now iterating over all the entries to find out oxidation states ***") results = [] for entry in entries: outdict = collect_for_id(entry) results.append(outdict) print(f"Worked on {outdict['name']}") - print('*** Finished datacollection ***') - print('found {} materials'.format(len(results))) + print("*** Finished datacollection ***") + print("found {} materials".format(len(results))) df = pd.DataFrame(results) # pylint:disable=invalid-name - df.to_csv('mp_parsing_results.csv') + df.to_csv("mp_parsing_results.csv") -if __name__ == '__main__': +if __name__ == "__main__": collect_entries() diff --git a/run/run_parsing.py b/run/run_parsing.py index 2aa6142..75202fc 100644 --- a/run/run_parsing.py +++ b/run/run_parsing.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # pylint:disable=relative-beyond-top-level """ -Run the oxidation state mining. +Run the oxidation state mining. Note that the default filepaths are specific to our enviornment. 
""" @@ -16,7 +16,7 @@ from oximachine_featurizer.parse import GetOxStatesCSD -def prepare_list(indir='/mnt/lsmo_databases/mof_subset_csdmay2019'): +def prepare_list(indir="/mnt/lsmo_databases/mof_subset_csdmay2019"): """ Args: @@ -26,7 +26,7 @@ def prepare_list(indir='/mnt/lsmo_databases/mof_subset_csdmay2019'): list: filestemms """ - names = glob(os.path.join(indir, '*.cif')) + names = glob(os.path.join(indir, "*.cif")) names_cleaned = [Path(n).stem for n in names] return names_cleaned @@ -44,23 +44,23 @@ def run_parsing(names_cleaned, output_name=None): """ getoxstatesobject = GetOxStatesCSD(names_cleaned) if output_name is None: - timestr = time.strftime('%Y%m%d-%H%M%S') - output_name = '-'.join([timestr, 'csd_ox_parse_output']) + timestr = time.strftime("%Y%m%d-%H%M%S") + output_name = "-".join([timestr, "csd_ox_parse_output"]) outputdict = getoxstatesobject.run_parsing(njobs=4) - with open(output_name + '.pkl', 'wb') as filehandle: + with open(output_name + ".pkl", "wb") as filehandle: pickle.dump(outputdict, filehandle) -@click.command('cli') -@click.argument('indir', default='/mnt/lsmo_databases/mof_subset_csdmay2019') -@click.argument('outname', default=None) +@click.command("cli") +@click.argument("indir", default="/mnt/lsmo_databases/mof_subset_csdmay2019") +@click.argument("outname", default=None) def main(indir, outname): """CLI function""" names_cleaned = prepare_list(indir) run_parsing(names_cleaned, outname) -if __name__ == '__main__': +if __name__ == "__main__": main() # pylint: disable=no-value-for-parameter diff --git a/run/run_parsing_reference.py b/run/run_parsing_reference.py index 06d4342..44c22a2 100644 --- a/run/run_parsing_reference.py +++ b/run/run_parsing_reference.py @@ -17,7 +17,7 @@ def generate_id_list(num_samples=1009141): """Sample some random entries from the CSD""" ids = [] - csd_reader = io.EntryReader('CSD') + csd_reader = io.EntryReader("CSD") idxs = random.sample(list(range(len(csd_reader))), num_samples) for idx in idxs: ids.append(csd_reader[idx].identifier) @@ -37,23 +37,21 @@ def run_parsing(output_name=None): # all database entries getoxstatesobject = GetOxStatesCSD(generate_id_list()) if output_name is None: - timestr = time.strftime('%Y%m%d-%H%M%S') - output_name = '-'.join([timestr, 'csd_ox_parse_output_reference']) + timestr = time.strftime("%Y%m%d-%H%M%S") + output_name = "-".join([timestr, "csd_ox_parse_output_reference"]) outputdict = getoxstatesobject.run_parsing(njobs=4) - with open(output_name + '.pkl', 'wb') as filehandle: + with open(output_name + ".pkl", "wb") as filehandle: pickle.dump(outputdict, filehandle) -@click.command('cli') -@click.option('--outname', default=None) -def main(outname): - """ - CLI function - """ +@click.command("cli") +@click.option("--outname", default=None) +def main(outname: str): + """CLI function""" run_parsing(outname) -if __name__ == '__main__': +if __name__ == "__main__": main() # pylint: disable=no-value-for-parameter diff --git a/test/__init__.py b/test/__init__.py index 38ac08b..1f982ee 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -2,4 +2,4 @@ import os # This export here is not really nice for sharing the tests, but it is the easiest way to get it running ... 
-os.environ['CSDHOME'] = '/Applications/CCDC/CSD_2020' +os.environ["CSDHOME"] = "/Applications/CCDC/CSD_2020" diff --git a/test/conftest.py b/test/conftest.py index 3f2edd1..65a9dfc 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,16 +1,16 @@ # -*- coding: utf-8 -*- import numpy as np -import pytest import pandas as pd +import pytest -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def get_oxidationstate_dict(): names = [ - 'Tetracarbonyl-(h6-1,2,4,5-tetramethylbenzene)-vanadium(i) hexacarbonyl-vanadium', - '(m-carbonyl)-([(phenylphosphanediyl)di(2,1-phenylene)]bis[di(propan-2-yl)phosphane])-tris[1,2-bis(methoxy)ethane]-sodium-cobalt(-1)', + "Tetracarbonyl-(h6-1,2,4,5-tetramethylbenzene)-vanadium(i) hexacarbonyl-vanadium", + "(m-carbonyl)-([(phenylphosphanediyl)di(2,1-phenylene)]bis[di(propan-2-yl)phosphane])-tris[1,2-bis(methoxy)ethane]-sodium-cobalt(-1)", ] - expected_result = [{'V': [1, np.nan]}, {'Co': [-1]}] + expected_result = [{"V": [1, np.nan]}, {"Co": [-1]}] return names, expected_result diff --git a/test/main/test_utils.py b/test/main/test_utils.py index 7d4ee4d..354bf89 100644 --- a/test/main/test_utils.py +++ b/test/main/test_utils.py @@ -7,5 +7,5 @@ def test_symbolnamedict(): symbol_name_dict = SymbolNameDict().get_symbol_name_dict() assert isinstance(symbol_name_dict, dict) - assert 'H' not in list(symbol_name_dict.keys()) # default only metals - assert symbol_name_dict['Zn'] == 'zinc' + assert "H" not in list(symbol_name_dict.keys()) # default only metals + assert symbol_name_dict["Zn"] == "zinc"
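
For context, the public entry point touched above — `featurize`, re-exported from `oximachine_featurizer/__init__.py` and driven by `run/run_featurization.py` — can be exercised roughly as follows. This is a minimal sketch, not part of the patch: the CIF path is a placeholder, and depending on the installed pymatgen version `Structure` is imported from either `pymatgen` or `pymatgen.core`.

import numpy as np
from pymatgen.core import Structure  # on older pymatgen releases: from pymatgen import Structure

from oximachine_featurizer import featurize

# Load a structure from a CIF file (placeholder path) and featurize it.
# featurize() returns the feature matrix plus two further outputs, which the
# run/run_featurization.py CLI ignores.
structure = Structure.from_file("structure.cif")
feature_matrix, _, _ = featurize(structure)

# Persist the features the same way the CLI does.
np.save("features.npy", feature_matrix)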