Explainability for Photonics

Introduction

Welcome to the Raman Lab GitHub! This repo will walk you through the code used in the following publication: https://pubs.acs.org/doi/full/10.1021/acsphotonics.0c01067

Here, we use Deep SHAP (or SHAP) to explain the behavior of nanophotonic structures learned by a convolutional neural network (CNN).

Requirements

The following libraries are required to run the provided scripts. The specific versions matter due to compatibility issues between TensorFlow and SHAP (as of this writing):

- Python 3.7

- TensorFlow 1.14.0

- Keras 2.3.1

- SHAP 0.31.0

- Matplotlib 3.3.2

Installation and usage instructions for Deep SHAP are available at: https://github.com/slundberg/shap. For convenience, the following commands set up a suitable Conda environment (after installing Anaconda: https://www.anaconda.com/products/individual):

conda create -n myenv python=3.7
conda activate myenv
conda install tensorflow-gpu==1.14.0
conda install keras
conda install -c anaconda opencv
conda install -c anaconda scikit-learn
conda install matplotlib==3.3.2
conda install pandas
conda install -c conda-forge shap==0.31.0
conda install spyder
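
To confirm that the pinned versions were picked up, here is a quick sanity check to run inside the activated environment:

import tensorflow as tf
import keras
import shap
import matplotlib

print(tf.__version__)          # expect 1.14.0
print(keras.__version__)       # expect 2.3.1
print(shap.__version__)        # expect 0.31.0
print(matplotlib.__version__)  # expect 3.3.2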

Steps

1) Train the CNN (CNN_Train.py)

Download the files in the 'Training Data' folder and update the following lines in the 'CNN_Train.py' file:

## Define File Locations (Images, Spectra, and CNN Model Save)
img_path = 'C:/.../*.png'
spectra_path = 'C:/.../Spectra.csv'
save_dir = 'C:/.../model.h5'

Running this file will train the CNN and save the model in the specified location. Depending on the available hardware, the CNN training process can take up to a few hours.
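
For orientation, here is a minimal sketch of what this step amounts to. The architecture shown is illustrative only (the actual network is defined in 'CNN_Train.py'), and the layout of 'Spectra.csv' (one spectrum per row, one column per sampled wavelength) is an assumption:

import glob
import cv2
import numpy as np
import pandas as pd
from keras.models import Sequential
from keras.layers import Conv2D, Flatten, Dense

# img_path, spectra_path, and save_dir as defined above
imgs = np.array([cv2.imread(f, cv2.IMREAD_GRAYSCALE) for f in sorted(glob.glob(img_path))])
X = imgs.reshape(-1, 40, 40, 1) / 255.0             # 40x40 grayscale structure images
y = pd.read_csv(spectra_path, header=None).values   # one absorption spectrum per image

model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(40, 40, 1)),
    Conv2D(64, (3, 3), activation='relu'),
    Flatten(),
    Dense(128, activation='relu'),
    Dense(y.shape[1]),  # one output per sampled wavelength
])
model.compile(optimizer='adam', loss='mse')
model.fit(X, y, epochs=100, batch_size=32, validation_split=0.1)
model.save(save_dir)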

2) Explain CNN Results (SHAP_Explanation.py)

Deep SHAP explains the predictions for a 'Base' image in reference to a 'Background', which can be a collection of images or a single image. To minimize noise, we recommend using a 'white' image as the Base and the image to be evaluated as the Background. This compares the importance of each feature, relative to its absence, toward a target output. Simply update the following paths and run the 'SHAP_Explanation.py' script (refer to the 'Examples' folder for sample Background and Base images):

## Define File Locations (CNN, Test Image, and Background Image)
model = load_model('C:/.../model.h5', compile=False)
back_img_path = 'C:/.../Background.png'
base_img_path = 'C:/.../Base.png'
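
The script handles the explainer internally; for reference, here is a minimal sketch of the underlying Deep SHAP calls, following the Base/Background terminology above. The image loading and the variable names 'back_img' and 'base_img' are illustrative assumptions, and the exact roles of the two images in 'SHAP_Explanation.py' may differ:

import cv2
import shap

back_img = cv2.imread(back_img_path, cv2.IMREAD_GRAYSCALE).reshape(1, 40, 40, 1) / 255.0
base_img = cv2.imread(base_img_path, cv2.IMREAD_GRAYSCALE).reshape(1, 40, 40, 1) / 255.0

explainer = shap.DeepExplainer(model, back_img)  # Background serves as the reference
shap_values = explainer.shap_values(base_img)    # explain the Base image's predictions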

After running the script, a list of SHAP value heatmaps (shap_values) will be generated. The size and order of this list reflect the CNN's outputs, and each heatmap has the same resolution as the CNN's input images. Therefore, to plot the heatmap corresponding to a particular wavelength, simply index the list as follows:

shap.image_plot(shap_values[i], back_img.reshape(1, 40, 40, 1), show=False)  # where 'i' is an index between 0 and len(shap_values) - 1
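
The normalization and validation snippets below load the SHAP values from a pickle file; saving the list after this step would look like the following sketch (the filename is illustrative):

import pickle

with open('C:/.../shap_explanations.data', 'wb') as filehandle:
    pickle.dump(shap_values, filehandle)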

Optionally, for ease of viewing, the SHAP values can be normalized and replotted like so:

import numpy as np
import pickle
import matplotlib.pyplot as plt
import matplotlib.colors as colors

with open('C:/.../shap_explanations.data', 'rb') as filehandle:
    shap_values = pickle.load(filehandle)

# Grid coordinates matching the 40x40 input images
X = np.arange(-20, 20, 1)
Y = np.arange(-20, 20, 1)
X, Y = np.meshgrid(X, Y)

# Scale positive and negative SHAP values separately so both fall within [-1, 1]
maximum = np.max(shap_values)
minimum = -np.min(shap_values)  # magnitude of the most negative SHAP value

shap_i = np.array(shap_values[i])  # copy heatmap 'i' (an index between 0 and len(shap_values) - 1)
shap_i[shap_i > 0] = shap_i[shap_i > 0] / maximum
shap_i[shap_i < 0] = shap_i[shap_i < 0] / minimum
shap_values_normalized = shap_i.squeeze()[::-1]  # drop singleton dimensions and flip vertically

fig = plt.figure()
ax = fig.gca()
pcm = ax.pcolormesh(X, Y, shap_values_normalized,
                    norm=colors.SymLogNorm(linthresh=0.01, linscale=1, vmin=-1, vmax=1),
                    cmap='bwr')
fig.colorbar(pcm)
ax.axis('off')
plt.show()

3) Explanation Validation (SHAP_Validation.py)

To validate that the explanations represent physical phenomena, we use the SHAP explanations to reconstruct a version of the original image that either suppresses or enhances the absorption spectrum. The reconstructed image can be imported directly into EM simulation software (e.g., Lumerical FDTD). Run the 'SHAP_Validation.py' script after specifying the location of the saved SHAP values:

#Import SHAP Values
with open('C:/.../shap_explanations.data', 'rb') as filehandle:
    shap_values = pickle.load(filehandle)

Tune the conversion settings by modifying the following line in the script (as written, it converts red pixels whose SHAP values exceed 5% of the maximum, i.e., the top 95% of the positive range):

if np.max(shap_values_convert) > shap_values_convert[i][j] > np.max(shap_values_convert)*0.05: #Convert Top 95% of Red Pixels        
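
For context, here is a minimal sketch of how such a threshold can turn a heatmap into a binary structure image suitable for import into a simulator; 'shap_values_convert' follows the line above, while the output array and its name are illustrative assumptions:

import numpy as np

shap_values_convert = np.squeeze(shap_values[i])  # heatmap for output wavelength 'i'
threshold = np.max(shap_values_convert) * 0.05

# Red (positive) pixels above the threshold become material; everything else stays empty
reconstructed = np.zeros_like(shap_values_convert)
reconstructed[shap_values_convert > threshold] = 1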

Citation

If you find this repo helpful, or use any of the code you find here, please cite our work using the following:

C. Yeung et al., "Elucidating the Behavior of Nanophotonic Structures through Explainable Machine Learning Algorithms," ACS Photonics, 2020.
