Skip to content
This repository has been archived by the owner on Jun 6, 2023. It is now read-only.

Recovering MatBench results from weights for submission #67

Open
PatReis opened this issue Mar 31, 2023 · 0 comments
Open

Recovering MatBench results from weights for submission #67

PatReis opened this issue Mar 31, 2023 · 0 comments

Comments

@PatReis
Copy link

PatReis commented Mar 31, 2023

Dear M3GNet developers,

Thanks for providing the pretrained models and the code for M3GNet. I am trying to recover the M3GNet predictions from the MatBench training runs so that your results can be submitted to https://github.com/materialsproject/matbench. With the code below I get values that are very close to, but not exactly the same as, those reported in the paper.
Can you help me? Are there differences in the setup, or problems in my conversion, that I have missed?

import os.path
import os
import requests
import zipfile
import numpy as np
import tensorflow as tf
from m3gnet.models import M3GNet
from matbench.bench import MatbenchBenchmark
from pymatgen.core import Lattice, Structure
import logging
import urllib.request
from m3gnet.layers import AtomRef

# Download the pretrained MatBench weights archive once, then unpack it into
# ``model_weights/``. The archive is cached on disk as ``weights.zip``.
download_url = "https://figshare.com/ndownloader/files/35948966"
full_file_path = "weights.zip"
if not os.path.exists(full_file_path):
    r = requests.get(download_url, allow_redirects=True, timeout=600)
    # Fail loudly on an HTTP error instead of caching an error page as
    # "weights.zip", which would make every later run fail inside zipfile.
    r.raise_for_status()
    # Write to a temporary name and rename afterwards, so an interrupted write
    # never leaves a truncated archive that the existence check would accept.
    tmp_path = full_file_path + ".part"
    with open(tmp_path, "wb") as f:
        f.write(r.content)
    os.replace(tmp_path, full_file_path)

file_path = "model_weights"
os.makedirs(file_path, exist_ok=True)
# The context manager closes the archive even if extraction raises.
with zipfile.ZipFile(full_file_path, "r") as archive:
    archive.extractall(file_path)

# MatBench tasks evaluated with the pretrained M3GNet weights.
subsets_compatible = [
    "matbench_mp_e_form",
    "matbench_mp_gap",
    "matbench_mp_is_metal",
    "matbench_perovskites",
    "matbench_log_kvrh",
    "matbench_log_gvrh",
    "matbench_dielectric",
    "matbench_phonons",
    "matbench_jdft2d",
]

# Per-dataset factor that raw model outputs are multiplied by before they are
# recorded; datasets missing from this mapping are used unscaled.
units = dict(matbench_jdft2d=1000, matbench_phonons=1000)

# If True, fit an AtomRef per-element offset on each training fold and attach
# it to the model before predicting.
fit_per_element_offset = False

# If True, recompute predictions even when a cached .npy file already exists.
overwrite = False

# Tasks are loaded lazily, one at a time, inside the prediction loop.
mb = MatbenchBenchmark(autoload=False, subset=subsets_compatible)

# For every MatBench task and fold, obtain M3GNet test-set predictions (from a
# cached .npy file when available, otherwise by running the pretrained model
# for that fold), rescale them, and record them in the benchmark object.
for task in mb.tasks:
    task.load()
    # Fold-invariant: multiplier applied to raw model outputs for this dataset.
    scale_unit = units.get(task.dataset_name, 1.0)

    for i, fold in enumerate(task.folds):
        tf.keras.backend.clear_session()
        # NOTE(review): float64 precision was tried here via
        # tf.keras.backend.set_floatx("float64"); it may be relevant when
        # chasing small discrepancies against published numbers.

        # Predictions made with a fitted AtomRef go to a separate cache file,
        # so toggling fit_per_element_offset never reuses stale predictions.
        # The default (False) keeps the original file name.
        suffix = "_atomref" if fit_per_element_offset else ""
        predictions_path = (
            f"{task.dataset_name}_predictions_m3gnet{suffix}_fold_{i}.npy"
        )

        if overwrite or not os.path.exists(predictions_path):
            # Only load the model and the data when we actually have to predict.
            model = M3GNet.from_dir(
                f"model_weights/m3gnet_models/{task.dataset_name}/{fold}/m3gnet"
            )
            test_inputs = task.get_test_data(fold, include_target=False)
            if fit_per_element_offset:
                train_inputs, train_outputs = task.get_train_and_val_data(fold)
                graphs = [model.graph_converter(s) for s in train_inputs]
                ar = AtomRef(max_z=model.n_atom_types + 1)
                ar.fit(graphs, train_outputs)
                model.set_element_refs(ar.property_per_element)
            # np.asarray makes fresh and cached predictions the same type.
            predictions = np.asarray(model.predict_structures(test_inputs))
            np.save(predictions_path, predictions)
        else:
            predictions = np.load(predictions_path)
            print(f"loaded predictions: {predictions_path}")

        # Drop a trailing singleton output axis so predictions are 1-D. The
        # ndim guard keeps a 1-D single-sample result from collapsing to 0-d.
        if predictions.ndim > 1 and predictions.shape[-1] == 1:
            predictions = np.squeeze(predictions, axis=-1)

        predictions = scale_unit * predictions

        # Record data!
        task.record(fold, predictions, params={})

# Save your results
mb.to_file("results.json.gz")

# Datasets whose MAE is multiplied by 1000 before printing.
_SCALED_BY_1000 = ("matbench_mp_e_form", "matbench_mp_gap", "matbench_perovskites")

# Print "<dataset> <mean> <std>" per task: ROC-AUC for the metal
# classification task, (possibly rescaled) MAE for everything else.
for dataset_name, score in mb.scores.items():
    if dataset_name == "matbench_mp_is_metal":
        auc = score["rocauc"]
        print(dataset_name, auc["mean"], auc["std"])
        continue
    scale = 1000.0 if dataset_name in _SCALED_BY_1000 else 1.0
    mae = score["mae"]
    print(dataset_name, scale * mae["mean"], scale * mae["std"])

With this script I got:

matbench_mp_e_form 19.48588313396765 0.19626422885988018
matbench_mp_gap 194.9911496262826 6.773441655230544
matbench_mp_is_metal 0.9397143291057087 0.0028206924210185786
matbench_perovskites 32.99242523617427 1.3762758094915537
matbench_log_kvrh 0.07380833213983187 0.010354237232502703
matbench_log_gvrh 0.09913490092743893 0.011390576803829782
matbench_dielectric 0.3168320033220523 0.06471518661133054
matbench_phonons 34.112623907973145 4.56153709897016
matbench_jdft2d 50.06711240004724 11.892898998285041

Only log_kvrh, log_gvrh and mp_gap seem worse than they should be.

Thanks in advance.

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant