Add VAE example #37

Draft · wants to merge 1 commit into base: feature/SOF-4481
4 changes: 2 additions & 2 deletions examples/job/train-vae.ipynb
Git LFS file not shown
58 changes: 43 additions & 15 deletions examples/job/train-vae.py
@@ -4,7 +4,7 @@
# In[]:


-get_ipython().system('pip install tensorflow rdkit_pypi pymonad tqdm deepchem')
+get_ipython().system('pip install tensorflow rdkit_pypi pymonad tqdm deepchem pillow')


# In[]:
@@ -23,6 +23,7 @@
import rdkit.Chem
import tqdm

rdkit.RDLogger.DisableLog('rdApp.*')
RANDOM_SEED = 42
tqdm.tqdm.pandas()

@@ -33,6 +34,7 @@


data = pd.read_csv("assets/curated-solubility-dataset.csv")
data.to_pickle("vae_data.pkl")

# Drop duplicates based on the InChI representation
nondupes = data.drop_duplicates(subset="InChI")
@@ -43,7 +45,7 @@

# # Data Augmentation

-# In[ ]:
+# In[]:


augmented_data = list(smiles.copy())
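The body of augment_smiles is collapsed in this diff. A minimal sketch of one common implementation, assuming RDKit's randomized SMILES enumeration (doRandom is a real MolToSmiles keyword); the variant count and function body here are illustrative, not necessarily what the commit contains:

def augment_smiles(smiles_string: str, n_variants: int = 10):
    # Parse once; skip strings RDKit cannot read
    mol = rdkit.Chem.MolFromSmiles(smiles_string)
    if mol is None:
        return []
    # Each call with doRandom=True walks the atoms in a random order,
    # yielding a different but equivalent SMILES for the same molecule
    variants = {rdkit.Chem.MolToSmiles(mol, canonical=False, doRandom=True)
                for _ in range(n_variants)}
    return list(variants)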
@@ -74,7 +76,7 @@ def augment_smiles(smiles_string: str):

# # Tokenization with DeepChem

-# In[ ]:
+# In[]:


# DeepChem's SmilesTokenizer uses the WordPiece tokenizer by HuggingFace (https://huggingface.co/transformers/tokenizer_summary.html), with the regular-expression SMILES tokenization strategy developed by Schwaller, P. et al. in https://doi.org/10.1039/c8sc02339e
@@ -93,7 +95,7 @@ def augment_smiles(smiles_string: str):
tokenized_smiles
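The cell that builds tokenized_smiles is collapsed above. A hedged sketch of what it plausibly does with DeepChem's SmilesTokenizer; the vocab path is an assumption, and encode/add_padding_tokens are taken from DeepChem's tokenizer API:

import deepchem as dc

# Vocab file path is assumed; SmilesTokenizer requires a WordPiece-style vocab
tokenizer = dc.feat.SmilesTokenizer("assets/vocab.txt")

# Encode every augmented SMILES, then right-pad to a shared maximum length
encoded = [tokenizer.encode(s) for s in augmented_data]
maxlen = max(len(seq) for seq in encoded)
tokenized_smiles = [tokenizer.add_padding_tokens(seq, maxlen) for seq in encoded]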


-# In[ ]:
+# In[]:


# Next up, we'll remove SMILES with any unknown characters, since we don't want our generator putting those tokens in the output
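The filtering cell itself is collapsed; a minimal sketch, assuming the [UNK] token id is looked up from the same tokenizer:

unk_id = tokenizer.convert_tokens_to_ids("[UNK]")

# Keep only sequences that contain no unknown-token ids
tokenized_smiles = [seq for seq in tokenized_smiles if unk_id not in seq]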
@@ -108,13 +110,13 @@ def augment_smiles(smiles_string: str):

# # Create the VAE

-# In[ ]:
+# In[]:


from tensorflow.keras.layers import Input, Dense, Conv1D, Layer, Flatten, Reshape, Conv1DTranspose


-# In[ ]:
+# In[]:


class Sampling(Layer):
@@ -129,7 +131,7 @@ def call(self, inputs):
return z_mean + tf.exp(0.5 * z_log_var) * epsilon
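Most of the Sampling layer is collapsed; the visible return line matches the standard Keras reparameterization layer, so the full body is very likely equivalent to this sketch:

class Sampling(Layer):
    """Reparameterization trick: draw z from N(z_mean, exp(z_log_var))."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        # Sample epsilon ~ N(0, 1) and shift/scale it by the predicted moments
        epsilon = tf.random.normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon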


-# In[ ]:
+# In[]:


latent_dim = 1
@@ -147,7 +149,7 @@
encoder.summary()
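The encoder construction between latent_dim and encoder.summary() is collapsed. A hedged sketch consistent with the imported layers (Conv1D, Flatten, Dense) and the (maxlen, 1) training tensors built later; all filter counts and kernel sizes are assumptions:

encoder_inputs = Input(shape=(maxlen, 1))
x = Conv1D(32, 3, strides=2, padding="same", activation="relu")(encoder_inputs)
x = Conv1D(64, 3, strides=2, padding="same", activation="relu")(x)
x = Flatten()(x)
x = Dense(16, activation="relu")(x)
# Two heads parameterize the approximate posterior q(z|x)
z_mean = Dense(latent_dim, name="z_mean")(x)
z_log_var = Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
encoder = tf.keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")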


-# In[ ]:
+# In[]:


latent_inputs = Input(shape=(latent_dim,))
@@ -161,7 +163,7 @@
decoder.summary()
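The decoder body is likewise collapsed; a sketch that mirrors the encoder sketch above and uses the imported Reshape/Conv1DTranspose layers (sizes assume maxlen is divisible by 4, and are again illustrative only):

x = Dense((maxlen // 4) * 64, activation="relu")(latent_inputs)
x = Reshape((maxlen // 4, 64))(x)
x = Conv1DTranspose(64, 3, strides=2, padding="same", activation="relu")(x)
x = Conv1DTranspose(32, 3, strides=2, padding="same", activation="relu")(x)
# Sigmoid keeps outputs in [0, 1], matching token ids scaled by maxlen
decoder_outputs = Conv1DTranspose(1, 3, padding="same", activation="sigmoid")(x)
decoder = tf.keras.Model(latent_inputs, decoder_outputs, name="decoder")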


-# In[ ]:
+# In[]:


class VAE(tf.keras.Model):
@@ -209,21 +211,22 @@ def train_step(self, data):
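The collapsed train_step is, going by the class skeleton, the standard Keras VAE training loop: a reconstruction loss plus the KL divergence of q(z|x) from the unit Gaussian prior. A sketch of that standard loop, not necessarily the exact body in the commit:

class VAE(tf.keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # Per-sample reconstruction error, summed over the sequence axis
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    tf.keras.losses.binary_crossentropy(data, reconstruction),
                    axis=1,
                )
            )
            # KL(q(z|x) || N(0, I)) in closed form for diagonal Gaussians
            kl_loss = -0.5 * tf.reduce_mean(
                tf.reduce_sum(
                    1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var),
                    axis=1,
                )
            )
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {"loss": total_loss}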

# # Train the VAE

-# In[ ]:
+# In[]:


# Cast to float32, scale token ids into [0, 1], and shape as (samples, maxlen, 1)
train_data = np.array(tokenized_smiles)
train_data = np.expand_dims(train_data, -1).astype("float32") / maxlen
train_data = train_data.reshape(-1, maxlen, 1)
train_data.shape


-# In[ ]:
+# In[]:


vae = VAE(encoder, decoder)
vae.compile(optimizer=tf.keras.optimizers.Adam())
-vae.fit(train_data, epochs=1000, batch_size=64)
+vae.fit(train_data, epochs=100, batch_size=64)


# In[ ]:
@@ -254,8 +257,33 @@ def invert_tokenization(tokens):
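invert_tokenization is collapsed in the diff; a hedged sketch of the inverse mapping it presumably performs (clipping to the vocab range and dropping bracketed special tokens are defensive choices of this sketch, not confirmed details):

def invert_tokenization(tokens):
    # Keep only ids that fall inside the tokenizer's vocabulary
    ids = [int(t) for t in tokens if 0 <= t < tokenizer.vocab_size]
    chars = tokenizer.convert_ids_to_tokens(ids)
    # Drop special tokens such as [CLS], [SEP], and [PAD]
    return "".join(c for c in chars if not (c.startswith("[") and c.endswith("]")))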
# In[ ]:


-predictions = (vae.decoder.predict([np.random.uniform(low=-10, high=10, size=100)])[:,:,0] * maxlen).astype(int)
-list(map(lambda tokenized: invert_tokenization(tokenized), predictions))
+predictions = (vae.decoder.predict([np.linspace(-100, 1000, 50)])[:,:,0] * maxlen).astype(int)
+pred_smiles = list(map(invert_tokenization, predictions))


# In[ ]:


from rdkit.Chem.Draw import IPythonConsole  # enables inline molecule rendering

# Keep only the predicted SMILES that RDKit can parse into valid molecules
pred_mols = []
for smi in pred_smiles:
    mol = rdkit.Chem.MolFromSmiles(smi)
    if mol:
        pred_mols.append(mol)


# In[ ]:


# 300x300 px per molecule keeps the grid a manageable size
img = rdkit.Chem.Draw.MolsToGridImage(pred_mols, subImgSize=(300, 300))
img


# In[ ]:





# In[ ]:
Binary file added examples/vae_data.pkl
Binary file not shown.