MEME: Generating RNN Model Explanations via Model Extraction

This repository contains an implementation of MEME, a (M)odel (E)xplanation via (M)odel (E)xtraction framework. MEME analyses RNN models via explainable, concept-based extracted models, which can be used to explain and improve the performance of RNNs, as well as to extract useful knowledge from them.

For further details, see our paper (arXiv:2012.06954; full citation below).

The experiments use open-source datasets covering the two case studies described below: Room Occupation Prediction and In-Hospital Mortality Prediction.

Abstract

Recurrent Neural Networks (RNNs) have achieved remarkable performance on a range of tasks. A key step to further empowering RNN-based approaches is improving their explainability and interpretability. In this work we present MEME: a model extraction approach capable of approximating RNNs with interpretable models represented by human-understandable concepts and their interactions. We demonstrate how MEME can be applied to two multivariate, continuous-data case studies: Room Occupation Prediction, and In-Hospital Mortality Prediction. Using these case studies, we show how our extracted models can be used to interpret RNNs both locally and globally, by approximating RNN decision-making via interpretable concept interactions.

Visual Abstract

(Figure: overview of the MEME extraction approach; see caption below.)

Given an RNN model, we: (1) approximate its hidden space by a set of concepts; (2) approximate its hidden-space dynamics by a set of transition functions, one per concept; (3) approximate its output behaviour by a concept-class mapping, specifying an output class label for every concept. For every step in (1)-(3), the parts of the RNN being approximated are highlighted in red. In (a)-(c) we cluster the RNN's training data points in their hidden representation (assumed to be two-dimensional in this example), and use the clustering to produce a set of concepts (in this case: sick, healthy, and uncertain, written as unc.). In (d)-(f) we approximate the hidden function of the RNN by a function F_C, which predicts transitions between the concepts. We represent this function by a set of functions, one per concept (in this case: F_s, F_u, F_h). In (g)-(i) we approximate the output behaviour of the RNN by a function S, which predicts the output class from a concept. This function is represented by a concept-class mapping, specifying an output label for every concept (in this case: healthy→0, sick→1, and unc→1). Collectively, steps (1)-(3) are used to produce our extracted model, consisting of concepts, their interactions, and their corresponding class labels.
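The sketch below illustrates what these three approximation steps could look like in code. It is a minimal, hedged example and not the repository's actual API: it assumes the RNN's hidden states over the training data are available as a NumPy array, and uses k-means clustering and decision trees purely as illustrative choices of concept extractor and transition-function learner.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.tree import DecisionTreeClassifier

# (1) Approximate the hidden space by a set of concepts:
#     cluster the RNN's hidden states collected over the training sequences.
def extract_concepts(hidden_states, n_concepts=3):
    # hidden_states: (n_samples, hidden_dim) array of RNN hidden vectors
    clustering = KMeans(n_clusters=n_concepts, random_state=0)
    concept_ids = clustering.fit_predict(hidden_states)
    return clustering, concept_ids

# (2) Approximate the hidden-space dynamics with one transition function per
#     concept: given the input at time t, predict the concept at time t+1.
def fit_transition_functions(inputs, concept_ids, next_concept_ids, n_concepts):
    transition_fns = {}
    for c in range(n_concepts):
        mask = concept_ids == c
        f_c = DecisionTreeClassifier(max_depth=3)
        f_c.fit(inputs[mask], next_concept_ids[mask])
        transition_fns[c] = f_c
    return transition_fns

# (3) Approximate the output behaviour by a concept -> class mapping, here
#     taken as the RNN's majority output label within each concept's cluster.
def fit_concept_class_mapping(concept_ids, rnn_labels, n_concepts):
    return {c: int(np.bincount(rnn_labels[concept_ids == c]).argmax())
            for c in range(n_concepts)}
```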

Processing Example

(Figure: extracted-model sequence processing example; see caption below.)

Extracted model sequence processing for three timesteps (t = 0, 1, 2), with uncertain as the initial concept. For each timestep t, the concept the model is in at time t is highlighted with a double border. We show the input data (x_1, x_2, x_3), the corresponding concept transition sequence (uncertain → uncertain → sick → sick), and the explanations for each transition function prediction. In this example, the class labels output by the model are not shown.
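As a rough sketch of this processing loop (assuming the transition functions and concept-class mapping from the example above, which are illustrative assumptions rather than the repository's actual interface), an extracted model of this form could process a sequence as follows:

```python
def run_extracted_model(x_sequence, initial_concept, transition_fns, concept_class_map):
    concept = initial_concept
    trace = [concept]                      # concept transition sequence
    for x_t in x_sequence:                 # x_t: feature vector at timestep t
        f_c = transition_fns[concept]      # transition function of the current concept
        concept = int(f_c.predict(x_t.reshape(1, -1))[0])
        trace.append(concept)
    # Class label of the final concept, plus the full concept trace.
    return concept_class_map[concept], trace
```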

Prerequisites

TBC...

Citing

If you find this code useful in your research, please consider citing:

@article{kazhdan2020meme,
  title={MEME: Generating RNN Model Explanations via Model Extraction},
  author={Kazhdan, Dmitry and Dimanov, Botty and Jamnik, Mateja and Li{\`o}, Pietro},
  journal={arXiv preprint arXiv:2012.06954},
  year={2020}
}