Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revise misplaced columns in dynotears estimation, and fix a minor bug in dataset generation. #119

Open
wants to merge 9 commits into
base: develop
Choose a base branch
from

Conversation

ThinkNaive
Copy link

No description provided.

.gitignore Outdated Show resolved Hide resolved
@@ -547,8 +547,8 @@ def generate_dataframe_dynamic( # pylint: disable=R0914
st=sem_type, sts=s_types
)
)
intra_nodes = sorted(el for el in g.nodes if "_lag0" in el)
inter_nodes = sorted(el for el in g.nodes if "_lag0" not in el)
intra_nodes = sorted([el for el in g.nodes if "_lag0" in el], key=lambda t: t.split('_lag')[1])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorted accepts a generator expression, doesn't need to (waste memory!) coercing it into a list first :)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorted(el for el in g.nodes if "_lag0" not in el) will lead to misplaced columns, e.g. [0_lag1, 0_lag2, 0_lag3, 1_lag1, 1_lag2, 1_lag3, 2_lag1, 2_lag2, 2_lag3] (node first, then lag, case1). But the dynotears model accepts columns like [0_lag1, 1_lag1, 2_lag1, 0_lag2, 1_lag2, 2_lag2, 0_lag3, 1_lag3, 2_lag3] (lag first, then node, case2). The default sorted expression takes a case1 style. I've test the revised code using the code below.

__test__.py

# __test__.py
from causalnex.structure.transformers import DynamicDataTransformer
import os
import random
from matplotlib import ticker
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from causalnex.structure.data_generators import gen_stationary_dyn_net_and_df
from netgraph import InteractiveGraph
from causalnex.structure.dynotears import from_pandas_dynamic
import scipy.linalg as slin

# metrics
def count_accuracy(B_true, B_est) -> tuple:
    B_true = B_true != 0
    B_est = B_est != 0
    d = min(B_est.shape)

    # fdr
    fp = np.sum(~B_true & B_est)
    pp = np.sum(B_est)
    fdr = fp / pp

    # tpr
    tp = np.sum(B_true & B_est)
    tt = np.sum(B_true)
    tpr = tp / tt

    # fpr
    tf = d * (d - 1) / 2 - np.sum(B_true) + \
        B_est.shape[0] * B_est.shape[1] - d * d
    fpr = fp / tf

    # shd
    shd = np.sum(B_true != B_est)

    # nnz
    nnz = pp

    return fdr, tpr, fpr, shd, nnz

# reproduce
def set_seed(seed):
    """
    Referred from:
    - https://stackoverflow.com/questions/38469632/tensorflow-non-repeatable-results
    """
    # Reproducibility
    random.seed(seed)
    np.random.seed(seed)
    try:
        os.environ["PYTHONHASHSEED"] = str(seed)
    except:
        pass

# plot weighted directed graph
def plot_graph(g, node_size=8, font_size=8):
    # define layout
    pos = {}
    for node in g.nodes():
        # get node vertex index and slice index
        index = node.split('_lag')
        vertex = int(index[0]) + 1
        slice = int(index[1])

        # generate position with each slice as round shape
        x = np.random.uniform(low=-0.0, high=0.0) + 0.3 * \
            np.cos(2 * np.pi * vertex / d) + 1.0 * (p - slice)
        y = np.random.uniform(low=-0.1, high=0.1) + \
            np.sin(2 * np.pi * vertex / d)
        pos[node] = (x, y)

    g = nx.DiGraph(g)
    fig, ax = plt.subplots()
    plot_instance = InteractiveGraph(
        g,
        node_size=node_size,
        node_labels=True,
        node_label_fontdict=dict(size=font_size),
        node_layout=pos,
        arrows=True,
        ax=ax
    )

# transform graph to matrix
def to_matrix(g):
    a = nx.to_numpy_array(g)
    # a /= np.abs(a).max()

    # permute array by order from max lag to instant.
    sorted_nodes = sorted(g.nodes(), key=lambda k: d * int(k[-1]) + int(k[0]))
    sorted_index = [list(g.nodes()).index(i) for i in sorted_nodes]
    p = np.zeros(a.shape)
    for y, x in enumerate(sorted_index):
        p[y, x] = 1.0

    a = p.dot(a).dot(p.T)
    return a[:, :d]

# plot matrix
def plot_matrix(gt, ge=None, names=None, rng=2.0):
    # transform graph to matrix and compute layout
    bt = to_matrix(gt)
    d = bt.shape[1]
    n_col = bt.shape[0] // d
    if ge:
        be = to_matrix(ge)
        n_row = 2
        b = [bt, be]
    else:
        n_row = 1
        b = [bt]

    # plot matrix
    fig, ax = plt.subplots(n_row, n_col, figsize=(16, 6))
    ax = ax.flatten()
    for row in range(n_row):
        for col in range(n_col):
            # split matrix for intra and inter-p
            mat = b[row][d * col: d * (col + 1), :]
            # plot matrix
            im = ax[row * n_col + col].imshow(mat, cmap="seismic",
                                            interpolation="none", vmin=-rng, vmax=rng)
            # add value labels
            for i in range(d):
                for j in range(d):
                    color = 'white' if abs(mat[i, j]) > 0.3 * rng else 'black'
                    value = '{:g}'.format(round(mat[i, j], 3))
                    ax[row * n_col + col].text(j, i, value,
                                            ha="center", va="center", color=color, fontsize=9)
            # hide axis
            ax[row * n_col + col].set_xticks([])
            ax[row * n_col + col].set_yticks([])
            # show graph type labels
            if col == 0:
                ax[row * n_col + col].set_ylabel(names[row], fontsize=10)
            # show instant and lagged labels
            if row == 0:
                if col == 0:
                    label = '$W$ (Intra-slice)'
                else:
                    label = '$A_{%d}$ (Inter-slice)' % col
                ax[row * n_col + col].set_xlabel(label, fontsize=10)
                ax[row * n_col + col].xaxis.set_label_position('top')

    # add colorbar
    cb = fig.colorbar(im, ax=ax, shrink=0.7)
    tick_locator = ticker.MaxNLocator(nbins=5)
    cb.locator = tick_locator
    cb.set_ticks([t - rng for t in range(int(rng) * 2 + 1)])
    cb.update_ticks()

    return fig

# compute loss
def compute_loss(dataset, g, p):
    X, Xlags = DynamicDataTransformer(
        p=p).fit_transform(dataset, return_df=False)
    n, d_vars = X.shape

    wa = to_matrix(g)
    w_mat, a_mat = wa[: d_vars, :], wa[d_vars:, :]

    loss = (
        0.5
        / n
        * np.square(
            np.linalg.norm(
                X.dot(np.eye(d_vars, d_vars) - w_mat)
                - Xlags.dot(a_mat), "fro"
            )
        )
    )

    h = np.trace(slin.expm(w_mat * w_mat)) - d_vars

    return loss, h

# main
set_seed(0)
d = 5
p = 3

# generate dataset
g, df, intra, inter = gen_stationary_dyn_net_and_df(
    n_samples=500,
    num_nodes=5,
    p=3,
    w_min_intra=0.5,
    w_max_intra=2.0,
    w_min_inter=0.3,
    w_max_inter=0.5,
    degree_intra=4,
    degree_inter=1,
    w_decay=1.3
)

# convert dataset for learning
dataset = df.iloc[:, :d]
dataset.columns = [t.split('_lag')[0] for t in dataset.columns]

# estimate graph
gg = from_pandas_dynamic(
    time_series=[dataset], p=3, tau_w=0.01, tau_a=0.01, lambda_w=0.05, lambda_a=0.05)

# compute metrics
fdr, tpr, fpr, shd, nnz = count_accuracy(
    to_matrix(g), to_matrix(gg))
print("FDR:{:.2%} TPR:{:.2%} FPR:{:.2%} SHD:{:2d} NNZ:{:2d}".format(
    fdr, tpr, fpr, shd, nnz
))

# compute loss for both true graph and estimated graph
print('TrueGraph: loss={} h={}'.format(*compute_loss(dataset, g, p)))
print('Estimated: loss={} h={}'.format(*compute_loss(dataset, gg, p)))

# plot graph and matrix
# plot_graph(g)
# plot_graph(gg)
fig = plot_matrix(g, gg, names=('True', 'Estimated'))

# save image
imgfile = "graph.svg"
fig.savefig(imgfile, transparent=True)

plt.show(block=True)

@oentaryorj oentaryorj added the bug Something isn't working label Sep 8, 2021
@oentaryorj
Copy link
Contributor

Hi @ThinkNaive, thanks for the PR. It seems that the unit tests are failing now. Would you be able to help address them?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants