GVAE Version #1:
import math
import torch
import torch.nn as nn
from torch_scatter import scatter_softmax
from torch.nn import Linear
from torch_geometric.nn.conv import TransformerConv
from torch_geometric.nn import BatchNorm
from config import SUPPORTED_ATOMS, SUPPORTED_EDGES, MAX_MOLECULE_SIZE, ATOMIC_NUMBERS
from utils import graph_representation_to_molecule, to_one_hot
from tqdm import tqdm
class GVAE(nn.Module):
def __init__(self, encoder_embedding_size=64, latent_embedding_size=128, heads=4,
node_input_dim=30,
edge_input_dim=3,
hidden_dim=32,
latent_dim=16):
super(GVAE, self).__init__()
self.node_input_dim = node_input_dim
self.edge_input_dim = edge_input_dim
self.encoder_embedding_size = encoder_embedding_size
self.edge_dim = None
self.heads = heads
self.latent_embedding_size = latent_embedding_size
self.latent_dim = latent_dim
self.conv1 = None
self.conv1_initialized = False
self.num_edge_types = len(SUPPORTED_EDGES)
self.num_atom_types = len(SUPPORTED_ATOMS)
self.max_num_atoms = MAX_MOLECULE_SIZE
self.atomic_numbers = ATOMIC_NUMBERS
self.decoder_hidden_neurons = 512
self.lin_edge = None
self.bn1 = BatchNorm(self.encoder_embedding_size)
self.conv2 = TransformerConv(self.encoder_embedding_size,
self.encoder_embedding_size,
heads=self.heads,
concat=False,
beta=True,
edge_dim=self.edge_dim)
self.conv2.message = self.message
self.bn2 = BatchNorm(self.encoder_embedding_size)
self.conv3 = TransformerConv(self.encoder_embedding_size,
self.encoder_embedding_size,
heads=self.heads,
concat=False,
beta=True,
edge_dim=self.edge_dim)
self.conv3.message = self.message
self.bn3 = BatchNorm(self.encoder_embedding_size)
self.conv4 = TransformerConv(self.encoder_embedding_size,
self.encoder_embedding_size,
heads=self.heads,
concat=False,
beta=True,
edge_dim=self.edge_dim)
self.conv4.message = self.message
self.mu_transform = Linear(self.encoder_embedding_size, self.latent_dim)
self.logvar_transform = Linear(self.encoder_embedding_size, self.latent_dim)
self.linear_1 = Linear(self.latent_dim, self.decoder_hidden_neurons)
self.linear_2 = Linear(self.decoder_hidden_neurons, self.decoder_hidden_neurons)
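# The decoder emits, for each of the max_num_atoms node slots, logits over the supported
# atom types plus one extra "no atom" class, and, for each of the n(n-1)/2 upper-triangular
# node pairs, logits over the supported edge types plus one extra "no edge" class.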
atom_output_dim = self.max_num_atoms * (self.num_atom_types + 1)
self.atom_decode = Linear(self.decoder_hidden_neurons, atom_output_dim)
edge_output_dim = int(((self.max_num_atoms * (self.max_num_atoms - 1)) / 2) * (self.num_edge_types + 1))
self.edge_decode = Linear(self.decoder_hidden_neurons, edge_output_dim)
def encode(self, x, edge_attr, edge_index, batch_index):
x = self.conv1(x, edge_index, edge_attr).relu()
x = self.bn1(x)
x = self.conv2(x, edge_index, edge_attr).relu()
x = self.bn2(x)
x = self.conv3(x, edge_index, edge_attr).relu()
x = self.bn3(x)
x = self.conv4(x, edge_index, edge_attr).relu()
mu = self.mu_transform(x)
logvar = self.logvar_transform(x)
return mu, logvar
def decode_graph(self, graph_z):
z = self.linear_1(graph_z).relu()
z = self.linear_2(z).relu()
atom_logits = self.atom_decode(z)
edge_logits = self.edge_decode(z)
return atom_logits, edge_logits
def decode(self, z, batch_index):
node_logits = []
triu_logits = []
for graph_id in torch.unique(batch_index):
graph_z = z[graph_id]
atom_logits, edge_logits = self.decode_graph(graph_z)
node_logits.append(atom_logits)
triu_logits.append(edge_logits)
node_logits = torch.cat(node_logits)
triu_logits = torch.cat(triu_logits)
return triu_logits, node_logits
def reparameterize(self, mu, logvar):
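# Note: although the tensor is named logvar, it is used here as a log standard deviation
# (std = exp(logvar)); utils.kl_loss applies the same interpretation, so the two stay consistent.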
if self.training:
std = torch.exp(logvar)
eps = torch.randn_like(std)
return eps.mul(std).add_(mu)
else:
return mu
def forward(self, x, edge_attr, edge_index, batch_index):
if not self.conv1_initialized:
feature_size = x.shape[1]
self.edge_dim = edge_attr.shape[1]
if self.conv1 is None or self.conv1.in_channels != feature_size or self.conv1.edge_dim != self.edge_dim:
print(f"Initializing/Reinitializing conv1 with feature_size: {feature_size}, edge_dim: {self.edge_dim}")
self.conv1 = TransformerConv(feature_size, self.encoder_embedding_size,
heads=self.heads, concat=False, beta=True, edge_dim=self.edge_dim)
self.conv1.message = self.message
self.lin_edge = nn.Linear(self.edge_dim, self.encoder_embedding_size)
self.conv1_initialized = True
mu, logvar = self.encode(x, edge_attr, edge_index, batch_index)
z = self.reparameterize(mu, logvar)
triu_logits, node_logits = self.decode(z, batch_index)
return triu_logits, node_logits, mu, logvar
def message(self, query_i, key_j, value_j, edge_attr, index, ptr, size_i):
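# Custom message function: edge features are projected to the embedding size, reshaped per
# head, and added to the keys when the shapes line up; attention is scaled dot-product per
# head, normalized with scatter_softmax over each destination node's incoming messages.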
if edge_attr is not None:
edge_attr = self.lin_edge(edge_attr).view(-1, self.heads, self.encoder_embedding_size // self.heads)
if edge_attr.size(1) == key_j.size(1) and edge_attr.size(2) == key_j.size(2):
key_j = key_j + edge_attr
alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.encoder_embedding_size // self.heads)
alpha = scatter_softmax(alpha, index, dim=0)
return value_j * alpha.view(-1, self.heads, 1)
GVAE Version #2:
import math
import torch
import torch.nn as nn
from torch_scatter import scatter_softmax
from torch.nn import Linear
from torch_geometric.nn.conv import TransformerConv
from torch_geometric.nn import Set2Set
from torch_geometric.nn import BatchNorm
from config import SUPPORTED_ATOMS, SUPPORTED_EDGES, MAX_MOLECULE_SIZE, ATOMIC_NUMBERS
from utils import graph_representation_to_molecule, to_one_hot
from tqdm import tqdm
class GVAE(nn.Module):
def __init__(self, encoder_embedding_size=64, latent_embedding_size=128, heads=4,
node_input_dim=30, # From x shape: 30 features per node
edge_input_dim=3, # Updated to match the actual edge attribute dimensions
hidden_dim=32,
latent_dim=16):
super(GVAE, self).__init__()
self.node_input_dim = node_input_dim
self.edge_input_dim = edge_input_dim
self.encoder_embedding_size = encoder_embedding_size
self.edge_dim = None
self.heads = heads
self.latent_embedding_size = latent_embedding_size
self.latent_dim = latent_dim
self.conv1 = None
self.conv1_initialized = False
self.num_edge_types = len(SUPPORTED_EDGES)
self.num_atom_types = len(SUPPORTED_ATOMS)
self.max_num_atoms = MAX_MOLECULE_SIZE
self.atomic_numbers = ATOMIC_NUMBERS
self.decoder_hidden_neurons = 512
self.lin_edge = None
self.bn1 = BatchNorm(self.encoder_embedding_size)
self.conv2 = TransformerConv(self.encoder_embedding_size,
self.encoder_embedding_size,
heads=self.heads,
concat=False,
beta=True,
edge_dim=self.edge_dim)
self.conv2.message = self.message
self.bn2 = BatchNorm(self.encoder_embedding_size)
self.conv3 = TransformerConv(self.encoder_embedding_size,
self.encoder_embedding_size,
heads=self.heads,
concat=False,
beta=True,
edge_dim=self.edge_dim)
self.conv3.message = self.message
self.bn3 = BatchNorm(self.encoder_embedding_size)
self.conv4 = TransformerConv(self.encoder_embedding_size,
self.encoder_embedding_size,
heads=self.heads,
concat=False,
beta=True,
edge_dim=self.edge_dim)
self.conv4.message = self.message
self.pooling = Set2Set(self.encoder_embedding_size, processing_steps=4)
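# Set2Set returns a graph embedding of size 2 * encoder_embedding_size, which is why the
# mu/logvar projections below take a doubled input dimension.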
self.mu_transform = Linear(self.encoder_embedding_size * 2, self.latent_dim)
self.logvar_transform = Linear(self.encoder_embedding_size * 2, self.latent_dim)
self.linear_1 = Linear(self.latent_dim, self.decoder_hidden_neurons)
self.linear_2 = Linear(self.decoder_hidden_neurons, self.decoder_hidden_neurons)
atom_output_dim = self.max_num_atoms * (self.num_atom_types + 1)
self.atom_decode = Linear(self.decoder_hidden_neurons, atom_output_dim)
edge_output_dim = int(((self.max_num_atoms * (self.max_num_atoms - 1)) / 2) * (self.num_edge_types + 1))
self.edge_decode = Linear(self.decoder_hidden_neurons, edge_output_dim)
def encode(self, x, edge_attr, edge_index, batch_index):
x = self.conv1(x, edge_index, edge_attr).relu()
x = self.bn1(x)
x = self.conv2(x, edge_index, edge_attr).relu()
x = self.bn2(x)
x = self.conv3(x, edge_index, edge_attr).relu()
x = self.bn3(x)
x = self.conv4(x, edge_index, edge_attr).relu()
x = self.pooling(x, batch_index)
mu = self.mu_transform(x)
logvar = self.logvar_transform(x)
return mu, logvar
def decode_graph(self, graph_z):
z = self.linear_1(graph_z).relu()
z = self.linear_2(z).relu()
atom_logits = self.atom_decode(z)
edge_logits = self.edge_decode(z)
return atom_logits, edge_logits
def decode(self, z, batch_index):
node_logits = []
triu_logits = []
for graph_id in torch.unique(batch_index):
graph_z = z[graph_id]
atom_logits, edge_logits = self.decode_graph(graph_z)
node_logits.append(atom_logits)
triu_logits.append(edge_logits)
node_logits = torch.cat(node_logits)
triu_logits = torch.cat(triu_logits)
return triu_logits, node_logits
def reparameterize(self, mu, logvar):
if self.training:
std = torch.exp(logvar)
eps = torch.randn_like(std)
return eps.mul(std).add_(mu)
else:
return mu
def forward(self, x, edge_attr, edge_index, batch_index):
if not self.conv1_initialized:
feature_size = x.shape[1]
self.edge_dim = edge_attr.shape[1]
if self.conv1 is None or self.conv1.in_channels != feature_size or self.conv1.edge_dim != self.edge_dim:
print(f"Initializing/Reinitializing conv1 with feature_size: {feature_size}, edge_dim: {self.edge_dim}")
self.conv1 = TransformerConv(feature_size, self.encoder_embedding_size,
heads=self.heads, concat=False, beta=True, edge_dim=self.edge_dim)
self.conv1.message = self.message
self.lin_edge = nn.Linear(self.edge_dim, self.encoder_embedding_size)
self.conv1_initialized = True
mu, logvar = self.encode(x, edge_attr, edge_index, batch_index)
z = self.reparameterize(mu, logvar)
triu_logits, node_logits = self.decode(z, batch_index)
return triu_logits, node_logits, mu, logvar
def sample_mols(self, num=10000):
print("Sampling molecules ... ")
n_valid = 0
for _ in tqdm(range(num)):
z = torch.randn(1, self.latent_dim)  # sample from the prior; the decoder's linear_1 expects latent_dim inputs
dummy_batch_index = torch.Tensor([0]).int()
triu_logits, node_logits = self.decode(z, dummy_batch_index)
edge_matrix_shape = (int((MAX_MOLECULE_SIZE * (MAX_MOLECULE_SIZE - 1)) / 2), len(SUPPORTED_EDGES) + 1)
triu_preds_matrix = triu_logits.reshape(edge_matrix_shape)
triu_preds = torch.argmax(triu_preds_matrix, dim=1)
node_matrix_shape = (MAX_MOLECULE_SIZE, (len(SUPPORTED_ATOMS) + 1))
node_preds_matrix = node_logits.reshape(node_matrix_shape)
node_preds = torch.argmax(node_preds_matrix[:, :len(SUPPORTED_ATOMS)], dim=1)
node_preds_one_hot = to_one_hot(node_preds, options=ATOMIC_NUMBERS)
atom_numbers_dummy = torch.Tensor(ATOMIC_NUMBERS).repeat(node_preds_one_hot.shape[0], 1)
atom_types = torch.masked_select(atom_numbers_dummy, node_preds_one_hot.bool())
smiles, mol = graph_representation_to_molecule(atom_types, triu_preds.float())
if smiles and "." not in smiles:
print("Successfully generated: ", smiles)
n_valid += 1
return n_valid
def message(self, query_i, key_j, value_j, edge_attr, index, ptr, size_i):
if edge_attr is not None:
edge_attr = self.lin_edge(edge_attr).view(-1, self.heads, self.encoder_embedding_size // self.heads)
if edge_attr.size(1) == key_j.size(1) and edge_attr.size(2) == key_j.size(2):
key_j = key_j + edge_attr
alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.encoder_embedding_size // self.heads)
alpha = scatter_softmax(alpha, index, dim=0)
return value_j * alpha.view(-1, self.heads, 1)
Reconstruction Training:
def graph_representation_to_molecule(tensor):
# Create an empty RDKit molecule
mol = Chem.RWMol()
carbon_indices = []
atoms=[]
atomattributes = tensor_to_atoms(tensor)
for atom in atomattributes:
atm = mol.AddAtom(Chem.Atom(atom.return_symbol()))
atoms.append(atm)
counter = 0
for atom in atomattributes:
if atom.return_symbol() == 'C':
carbon_indices.append(counter)
counter +=1
counter = 0
for carbons in carbon_indices:
if counter+1<len(carbon_indices) and carbon_indices[counter+1]<len(atoms) and counter<6:
if counter%2 == 0:
mol.AddBond(atoms[carbons], atoms[carbon_indices[counter+1]], Chem.BondType.SINGLE)
atomattributes[carbons].valence-=1
atomattributes[carbon_indices[counter+1]].valence -= 1
atomattributes[carbons].inRing = True
atomattributes[carbon_indices[counter + 1]].inRing = True
else:
mol.AddBond(atoms[carbons], atoms[carbon_indices[counter+1]], Chem.BondType.DOUBLE)
atomattributes[carbons].inRing = True
atomattributes[carbon_indices[counter+1]].inRing = True
atomattributes[carbons].valence -= 2
atomattributes[carbon_indices[counter + 1]].valence -= 2
counter += 1
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'C' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
#Oxygen
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'O' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
#Chlorine
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'Cl' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
# Sulfur
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'S' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
# Nitrogen
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'N' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter<len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
#Bromine
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'Br' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
# Fluorine
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'F' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
# Phosphorus
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'P' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
#Iodine
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'I' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
atoms_to_remove = []
for atom in mol.GetAtoms():
if atom.GetDegree() == 0: # If the atom has no bonds (degree 0)
atoms_to_remove.append(atom.GetIdx())
# Remove atoms in reverse order to avoid indexing issues
for atom_idx in sorted(atoms_to_remove, reverse=True):
mol.RemoveAtom(atom_idx)
img = Draw.MolToImage(mol)
smiles = Chem.MolToSmiles(mol)
print(smiles)
return mol
# Unreachable post-processing kept for reference (it follows the return above and relies on
# helpers such as reorder_smiles_to_create_rings that are not defined in this listing):
# AllChem.Compute2DCoords(mol)
# rings = mol.GetRingInfo().AtomRings()
# carbon_rings = [ring for ring in rings if all(mol.GetAtomWithIdx(idx).GetSymbol() == 'C' for idx in ring)]
# Reorder the SMILES string to create rings with carbon blocks
# reordered_smiles = reorder_smiles_to_create_rings(mol_representation, atom_indices)
# Print the original and reordered SMILES representations
# print("Original SMILES Representation:", smiles_representation)
# print("Reordered SMILES Representation:", reordered_smiles)
Hybrid Between AI and Rule-Based:
import random
import torch
import torch.nn as nn
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from utils import tensor_to_atoms  # tensor_to_atoms is defined in the UTILS listing below
class BondPredictionModel(nn.Module):
def __init__(self, input_dim, hidden_dim):
super(BondPredictionModel, self).__init__()
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return self.sigmoid(x)
model = BondPredictionModel(input_dim=5, hidden_dim=32)
def predict_bond(atom1, atom2, atom_features):
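# Assumes atom_features[atom] yields the per-atom feature vector; the two vectors are
# concatenated, so their combined length must equal the model's input_dim (5 in the model above).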
features = torch.cat((atom_features[atom1], atom_features[atom2]), dim=0)
bond_prediction = model(features)
return bond_prediction.item() > 0.5
def graph_representation_to_molecule(tensor):
mol = Chem.RWMol()
carbon_indices = []
atoms = []
atomattributes = tensor_to_atoms(tensor)
for atom in atomattributes:
atm = mol.AddAtom(Chem.Atom(atom.return_symbol()))
atoms.append(atm)
carbon_indices = [idx for idx, atom in enumerate(atomattributes) if atom.return_symbol() == 'C']
for i in range(len(carbon_indices)):
for j in range(i + 1, len(carbon_indices)):
if predict_bond(carbon_indices[i], carbon_indices[j], tensor):
mol.AddBond(atoms[carbon_indices[i]], atoms[carbon_indices[j]], Chem.BondType.SINGLE)
atomattributes[carbon_indices[i]].valence -= 1
atomattributes[carbon_indices[j]].valence -= 1
for atom_idx, atom in enumerate(atomattributes):
if atom.return_symbol() in ['O', 'N', 'Cl', 'P', 'S', 'Br', 'I'] and carbon_indices:
connected_carbon_idx = random.choice(carbon_indices)
if predict_bond(atom_idx, connected_carbon_idx, tensor):
mol.AddBond(atoms[atom_idx], atoms[connected_carbon_idx], Chem.BondType.SINGLE)
atomattributes[atom_idx].valence -= 1
atomattributes[connected_carbon_idx].valence -= 1
atoms_to_remove = []
for atom in mol.GetAtoms():
if atom.GetDegree() == 0:
atoms_to_remove.append(atom.GetIdx())
for atom_idx in sorted(atoms_to_remove, reverse=True):
mol.RemoveAtom(atom_idx)
AllChem.Compute2DCoords(mol)
smiles = Chem.MolToSmiles(mol)
print(f"Generated SMILES: {smiles}")
img = Draw.MolToImage(mol)
img.show()
return mol
Other Important Files:
UTILS:
import random
import torch
from torch_geometric.utils import to_dense_adj
from rdkit import Chem
import random
from rdkit import RDLogger
from ShapeDebugger import smiles
from config import DEVICE as device
from config import (SUPPORTED_ATOMS, SUPPORTED_EDGES, MAX_MOLECULE_SIZE, ATOMIC_NUMBERS, valence_dict,
DISABLE_RDKIT_WARNINGS)
import matplotlib.pyplot as plt
# Disable rdkit warnings
if DISABLE_RDKIT_WARNINGS:
RDLogger.DisableLog('rdApp.*')
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
def count_rings(mol):
"""Count and print the number of rings in the molecule."""
ssr = Chem.GetSymmSSSR(mol)
num_rings = len(ssr)
print(f"Number of rings detected: {num_rings}")
for i, ring in enumerate(ssr):
print(f"Ring {i + 1}: Atoms {list(ring)}")
return num_rings
def log_ring_info(mol):
"""Log information about rings in the molecule."""
ssr = Chem.GetSymmSSSR(mol)
num_rings = len(ssr)
print(f"Number of rings detected: {num_rings}")
for i, ring in enumerate(ssr):
print(f"Ring {i + 1}: Atoms {list(ring)}")
return ssr
def break_down_rings(mol):
"""Break down complex ring structures in the molecule."""
ssr = log_ring_info(mol)
if not ssr:
print("No rings detected in the molecule.")
return mol
ring_bonds = set()
# GetSymmSSSR returns atom-index rings; FragmentOnBonds needs bond indices, so use BondRings()
for ring in mol.GetRingInfo().BondRings():
for bond_idx in ring:
ring_bonds.add(bond_idx)
if not ring_bonds:
print("No ring bonds found in the molecule.")
return mol
fragmented_mol = AllChem.FragmentOnBonds(mol, list(ring_bonds), addDummies=False)
fragmented_mol = Chem.RemoveHs(fragmented_mol) # Remove hydrogen atoms added during fragmentation
return fragmented_mol
def count_parameters(model):
"""
Counts the number of parameters for a Pytorch model
"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def kl_loss(mu=None, logstd=None):
"""
Closed formula of the KL divergence for normal distributions
"""
MAX_LOGSTD = 10
logstd = logstd.clamp(max=MAX_LOGSTD)
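# Closed form: KL( N(mu, sigma) || N(0, I) ) = -0.5 * sum(1 + 2*log(sigma) - mu^2 - sigma^2), averaged over the batch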
kl_div = -0.5 * torch.mean(torch.sum(1 + 2 * logstd - mu ** 2 - logstd.exp() ** 2, dim=1))
kl_div = kl_div.clamp(max=1000)
return kl_div
def slice_graph_targets(graph_id, edge_targets, node_targets, batch_index):
"""
Slices out the upper triangular part of an adjacency matrix for
a single graph from a large adjacency matrix for a full batch.
For the node features the corresponding section in the batch is sliced out.
--------
graph_id: The ID of the graph (in the batch index) to slice
edge_targets: A dense adjacency matrix for the whole batch
node_targets: A tensor of node labels for the whole batch
batch_index: The node to graph map for the batch
"""
graph_mask = torch.eq(batch_index, graph_id)
graph_edge_targets = edge_targets[graph_mask][:, graph_mask]
size = graph_edge_targets.shape[0]
triu_indices = torch.triu_indices(size, size, offset=1)
triu_mask = torch.squeeze(to_dense_adj(triu_indices)).bool()
graph_edge_targets = graph_edge_targets[triu_mask]
graph_node_targets = node_targets[graph_mask]
return graph_edge_targets, graph_node_targets
def slice_graph_predictions(triu_logits, node_logits, graph_triu_size, triu_start_point, graph_size, node_start_point):
"""
Slices out the corresponding section from a list of batch triu values.
Given a start point and the size of a graph's triu, simply slices
the section from the batch list.
-------
triu_logits: A batch of triu predictions of different graphs
node_logits: A batch of node predictions with fixed size MAX_GRAPH_SIZE
graph_triu_size: Size of the triu of the graph to slice
triu_start_point: Index of the first node of this graph in the triu batch
graph_size: Max graph size
node_start_point: Index of the first node of this graph in the nodes batch
"""
graph_logits_triu = torch.squeeze(triu_logits[triu_start_point:triu_start_point + graph_triu_size])
graph_node_logits = torch.squeeze(node_logits[node_start_point:node_start_point + graph_size])
return graph_logits_triu, graph_node_logits
def slice_edge_type_from_edge_feats(edge_feats):
"""
This function only works for the MolGraphConvFeaturizer used in the dataset.
It slices the one-hot encoded edge type from the edge feature matrix.
The first 4 values stand for ["SINGLE", "DOUBLE", "TRIPLE", "AROMATIC"].
"""
edge_types_one_hot = edge_feats[:, :4]
edge_types = edge_types_one_hot.nonzero(as_tuple=False)
edge_types[:, 1] = edge_types[:, 1] + 1
return edge_types
def slice_atom_type_from_node_feats(node_features, as_index=False):
"""
This function only works for the MolGraphConvFeaturizer used in the dataset.
It slices the one-hot encoded atom type from the node feature matrix.
Unknown atom types are not considered and are not expected in the dataset.
"""
supported_atoms = SUPPORTED_ATOMS
atomic_numbers = ATOMIC_NUMBERS
atom_types_one_hot = node_features[:, :len(supported_atoms)]
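# Example: a one-hot row for oxygen ([0, 0, 1, 0, ...]) maps to atomic number 8 via masked_select, or to index 2 via argmax.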
if not as_index:
atom_numbers_dummy = torch.Tensor(atomic_numbers).repeat(atom_types_one_hot.shape[0], 1)
atom_types = torch.masked_select(atom_numbers_dummy, atom_types_one_hot.bool())
else:
atom_types = torch.argmax(atom_types_one_hot, dim=1)
return atom_types
def to_one_hot(x, options):
"""
Converts a tensor of values to a one-hot vector
based on the entries in options.
"""
return torch.nn.functional.one_hot(x.long(), len(options))
def squared_difference(input, target):
return (input - target) ** 2
def triu_to_dense(triu_values, num_nodes):
"""
Converts a triangular upper part of a matrix as flat vector
to a squared adjacency matrix with a specific size (num_nodes).
"""
dense_adj = torch.zeros((num_nodes, num_nodes)).to(device).float()
triu_indices = torch.triu_indices(num_nodes, num_nodes, offset=1)
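# Example: for num_nodes = 4 the flat triu order is (0,1), (0,2), (0,3), (1,2), (1,3), (2,3)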
dense_adj[triu_indices[0], triu_indices[1]] = triu_values
# Mirror the same values into the lower triangle; the index order must match the upper triangle
dense_adj[triu_indices[1], triu_indices[0]] = triu_values
return dense_adj
def triu_to_3d_dense(triu_values, num_nodes, depth=len(SUPPORTED_EDGES)):
"""
Converts the triangular upper part of a matrix
for several dimensions into a 3d tensor.
"""
adj_matrix_3d = torch.empty((num_nodes, num_nodes, depth), dtype=torch.float, device=device)
for edge_type in range(len(SUPPORTED_EDGES)):
adj_mat_edge_type = triu_to_dense(triu_values[:, edge_type].float(), num_nodes)
adj_matrix_3d[:, :, edge_type] = adj_mat_edge_type
return adj_matrix_3d
def calculate_node_edge_pair_loss(node_tar, edge_tar, node_pred, edge_pred):
"""
Calculates a loss based on the sum of node-edge pairs.
node_tar: [nodes, supported atoms]
node_pred: [max nodes, supported atoms + 1]
edge_tar: [triu values for target nodes, supported edges]
edge_pred: [triu values for predicted nodes, supported edges]
"""
edge_pred_3d = triu_to_3d_dense(edge_pred, node_pred.shape[0])
edge_tar_3d = triu_to_3d_dense(edge_tar, node_tar.shape[0])
node_edge_preds = torch.empty((MAX_MOLECULE_SIZE, len(SUPPORTED_ATOMS), len(SUPPORTED_EDGES)), dtype=torch.float, device=device)
for edge in range(len(SUPPORTED_EDGES)):
node_edge_preds[:, :, edge] = torch.matmul(edge_pred_3d[:, :, edge], node_pred[:, :9])
node_edge_tar = torch.empty((node_tar.shape[0], len(SUPPORTED_ATOMS), len(SUPPORTED_EDGES)), dtype=torch.float, device=device)
for edge in range(len(SUPPORTED_EDGES)):
node_edge_tar[:, :, edge] = torch.matmul(edge_tar_3d[:, :, edge], node_tar.float())
node_edge_pred_matrix = torch.sum(node_edge_preds, dim=0)
node_edge_tar_matrix = torch.sum(node_edge_tar, dim=0)
if torch.equal(node_edge_pred_matrix.int(), node_edge_tar_matrix.int()):
print("Reconstructed node-edge pairs: ", node_edge_pred_matrix.int())
node_edge_loss = torch.mean(sum(squared_difference(node_edge_pred_matrix, node_edge_tar_matrix.float())))
node_edge_node_preds = torch.empty((MAX_MOLECULE_SIZE, MAX_MOLECULE_SIZE, len(SUPPORTED_EDGES)), dtype=torch.float, device=device)
for edge in range(len(SUPPORTED_EDGES)):
node_edge_node_preds[:, :, edge] = torch.matmul(node_edge_preds[:, :, edge], node_pred[:, :9].t())
node_edge_node_tar = torch.empty((node_tar.shape[0], node_tar.shape[0], len(SUPPORTED_EDGES)), dtype=torch.float, device=device)
for edge in range(len(SUPPORTED_EDGES)):
node_edge_node_tar[:, :, edge] = torch.matmul(node_edge_tar[:, :, edge], node_tar.float().t())
node_edge_node_loss = sum(squared_difference(torch.sum(node_edge_node_preds, [0, 1]), torch.sum(node_edge_node_tar, [0, 1])))
return node_edge_loss
def approximate_recon_loss(node_targets, node_preds, triu_targets, triu_preds):
"""
See: https://github.com/seokhokang/graphvae_approx/
"""
onehot_node_targets = to_one_hot(node_targets, SUPPORTED_ATOMS)
onehot_triu_targets = to_one_hot(triu_targets, ["None"] + SUPPORTED_EDGES)
node_matrix_shape = (MAX_MOLECULE_SIZE, (len(SUPPORTED_ATOMS) + 1))
node_preds_matrix = node_preds.reshape(node_matrix_shape)
edge_matrix_shape = (int((MAX_MOLECULE_SIZE * (MAX_MOLECULE_SIZE - 1)) / 2), len(SUPPORTED_EDGES) + 1)
triu_preds_matrix = triu_preds.reshape(edge_matrix_shape)
node_preds_reduced = torch.sum(node_preds_matrix[:, :9], 0)
node_targets_reduced = torch.sum(onehot_node_targets, 0)
triu_preds_reduced = torch.sum(triu_preds_matrix[:, 1:], 0)
triu_targets_reduced = torch.sum(onehot_triu_targets[:, 1:], 0)
node_loss = sum(squared_difference(node_preds_reduced, node_targets_reduced.float()))
edge_loss = sum(squared_difference(triu_preds_reduced, triu_targets_reduced.float()))
node_edge_loss = calculate_node_edge_pair_loss(onehot_node_targets, onehot_triu_targets, node_preds_matrix, triu_preds_matrix)
approx_loss = node_loss + edge_loss + node_edge_loss
if all(node_targets_reduced == node_preds_reduced.int()) and all(triu_targets_reduced == triu_preds_reduced.int()):
print("Reconstructed all nodes: ", node_targets_reduced)
print("and all edges: ", triu_targets_reduced)
return approx_loss
def gvae_loss(triu_logits, node_logits, edge_index, edge_types, node_types, mu, logvar, batch_index, kl_beta):
"""
Calculates the loss for the graph variational autoencoder,
consisting of a node loss, an edge loss and the KL divergence.
"""
batch_edge_targets = torch.squeeze(to_dense_adj(edge_index))
num_edges = edge_index.shape[1]
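# Guard against edge feature tensors whose first dimension does not match the number of directed
# edges in edge_index (this can happen when edge features are stored once per bond).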
if edge_types.shape[0] < num_edges:
print(f"Warning: edge_types ({edge_types.shape[0]}) is smaller than number of edges ({num_edges})")
edge_types = torch.nn.functional.pad(edge_types, (0, 0, 0, num_edges - edge_types.shape[0]))
elif edge_types.shape[0] > num_edges:
print(f"Warning: edge_types ({edge_types.shape[0]}) is larger than number of edges ({num_edges})")
edge_types = edge_types[:num_edges]
batch_edge_targets[edge_index[0], edge_index[1]] = edge_types[:, 1].float()
graph_size = MAX_MOLECULE_SIZE * (len(SUPPORTED_ATOMS) + 1)
graph_triu_size = int((MAX_MOLECULE_SIZE * (MAX_MOLECULE_SIZE - 1)) / 2) * (len(SUPPORTED_EDGES) + 1)
batch_recon_loss = []
triu_indices_counter = 0
graph_size_counter = 0
for graph_id in torch.unique(batch_index):
triu_targets, node_targets = slice_graph_targets(graph_id, batch_edge_targets, node_types, batch_index)
triu_preds, node_preds = slice_graph_predictions(triu_logits, node_logits, graph_triu_size, triu_indices_counter, graph_size, graph_size_counter)
triu_indices_counter = triu_indices_counter + graph_triu_size
graph_size_counter = graph_size_counter + graph_size
recon_loss = approximate_recon_loss(node_targets, node_preds, triu_targets, triu_preds)
batch_recon_loss.append(recon_loss)
num_graphs = torch.unique(batch_index).shape[0]
batch_recon_loss = torch.true_divide(sum(batch_recon_loss), num_graphs)
kl_divergence = kl_loss(mu, logvar)
return batch_recon_loss + kl_beta * kl_divergence, kl_divergence
def tensor_to_symbols(node_types_tensor):
"""
Convert a tensor of atomic numbers to a list of element symbols.
"""
atomic_numbers = node_types_tensor.numpy().astype(int)
symbols = [Chem.GetPeriodicTable().GetElementSymbol(int(num)) for num in atomic_numbers]
return symbols
def triu_to_dense_loop(adjacency_triu, num_nodes):
"""
Loop-based variant that converts an upper triangular adjacency vector into a dense
adjacency matrix. It is kept under a separate name so that it does not shadow the
tensor-based triu_to_dense above, which triu_to_3d_dense depends on.
"""
adjacency_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.int)
idx = 0
for i in range(num_nodes):
for j in range(i + 1, num_nodes):
adjacency_matrix[i, j] = adjacency_triu[idx].item()
adjacency_matrix[j, i] = adjacency_triu[idx].item()
idx += 1
return adjacency_matrix
def random_bond_type():
"""
Randomly selects a bond type (single, double, or triple).
"""
bond_types = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE]
return random.choice(bond_types)
def can_add_bond(element_symbol, current_bonds):
"""
Checks if a bond can be added based on the atom's valence and current bonds.
"""
return current_bonds < valence_dict.get(element_symbol, 0)
class Atom:
def __init__(self, symbol, valence, atomic_number,inRing):
self.symbol = symbol
self.valence = valence
self.atomic_number = atomic_number
self.inRing = inRing
def return_symbol(self):
return self.symbol
def return_valence(self):
return self.valence
def return_atomic_number(self):
return self.atomic_number
def return_inRing(self):
return self.inRing
# Convert a tensor of atomic numbers into Atom objects carrying symbol, valence and ring flag
def tensor_to_atoms(tensor):
atoms = []
number_to_symbol = dict(zip(ATOMIC_NUMBERS, SUPPORTED_ATOMS))
for atomno in tensor:
symbol = number_to_symbol.get(int(atomno))
if symbol is None:
continue  # skip padding entries and unsupported atomic numbers
valence = valence_dict.get(symbol)
atoms.append(Atom(symbol, valence, int(atomno), False))
return atoms
def graph_representation_to_molecule(tensor):
mol = Chem.RWMol()
carbon_indices = []
atoms = []
atomattributes = tensor_to_atoms(tensor)
for atom in atomattributes:
atm = mol.AddAtom(Chem.Atom(atom.return_symbol()))
atoms.append(atm)
carbon_indices = [idx for idx, atom in enumerate(atomattributes) if atom.return_symbol() == 'C']
for i in range(len(carbon_indices)):
for j in range(i + 1, len(carbon_indices)):
if predict_bond(carbon_indices[i], carbon_indices[j], tensor):
mol.AddBond(atoms[carbon_indices[i]], atoms[carbon_indices[j]], Chem.BondType.SINGLE)
atomattributes[carbon_indices[i]].valence -= 1
atomattributes[carbon_indices[j]].valence -= 1
for atom_idx, atom in enumerate(atomattributes):
if atom.return_symbol() in ['O', 'N', 'Cl', 'P', 'S', 'Br', 'I'] and carbon_indices:
connected_carbon_idx = random.choice(carbon_indices)
if predict_bond(atom_idx, connected_carbon_idx, tensor):
mol.AddBond(atoms[atom_idx], atoms[connected_carbon_idx], Chem.BondType.SINGLE)
atomattributes[atom_idx].valence -= 1
atomattributes[connected_carbon_idx].valence -= 1
atoms_to_remove = []
for atom in mol.GetAtoms():
if atom.GetDegree() == 0:
atoms_to_remove.append(atom.GetIdx())
for atom_idx in sorted(atoms_to_remove, reverse=True):
mol.RemoveAtom(atom_idx)
AllChem.Compute2DCoords(mol)
smiles = Chem.MolToSmiles(mol)
print(f"Generated SMILES: {smiles}")
img = Draw.MolToImage(mol)
img.show()
return mol
def graph_to_molecule(tensor):
# Create an empty RDKit molecule
mol = Chem.RWMol()
carbon_indices = []
atoms=[]
if check_carbon_polymer(tensor):
counter = 0
for atom in tensor:
if counter % 6 != 0:
tensor[counter] = random.choice(ATOMIC_NUMBERS)
counter += 1
atomattributes = tensor_to_atoms(tensor)
for atom in atomattributes:
atm = mol.AddAtom(Chem.Atom(atom.return_symbol()))
atoms.append(atm)
counter = 0
for atom in atomattributes:
if atom.return_symbol() == 'C':
carbon_indices.append(counter)
counter +=1
counter = 0
for carbons in carbon_indices:
if counter+1<len(carbon_indices) and carbon_indices[counter+1]<len(atoms) and counter<6:
if counter%2 == 0:
mol.AddBond(atoms[carbons], atoms[carbon_indices[counter+1]], Chem.BondType.SINGLE)
atomattributes[carbons].valence-=1
atomattributes[carbon_indices[counter+1]].valence -= 1
atomattributes[carbons].inRing = True
atomattributes[carbon_indices[counter + 1]].inRing = True
else:
mol.AddBond(atoms[carbons], atoms[carbon_indices[counter+1]], Chem.BondType.DOUBLE)
atomattributes[carbons].inRing = True
atomattributes[carbon_indices[counter+1]].inRing = True
atomattributes[carbons].valence -= 2
atomattributes[carbon_indices[counter + 1]].valence -= 2
counter += 1
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'C' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
#Oxygen
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'O' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
#Chlorine
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'Cl' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
# Sulfur
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'S' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
# Nitrogen
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'N' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter<len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
#Bromine
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'Br' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
# Fluorine
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'F' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
# Phosphorus
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'P' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
#Iodine
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'I' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
atoms_to_remove = []
for atom in mol.GetAtoms():
if atom.GetDegree() == 0: # If the atom has no bonds (degree 0)
atoms_to_remove.append(atom.GetIdx())
# Remove atoms in reverse order to avoid indexing issues
for atom_idx in sorted(atoms_to_remove, reverse=True):
mol.RemoveAtom(atom_idx)
img = Draw.MolToImage(mol)
smiles = Chem.MolToSmiles(mol)
print(smiles)
return mol
# Unreachable post-processing kept for reference (it follows the return above and relies on
# helpers such as reorder_smiles_to_create_rings that are not defined in this listing):
# AllChem.Compute2DCoords(mol)
# rings = mol.GetRingInfo().AtomRings()
# carbon_rings = [ring for ring in rings if all(mol.GetAtomWithIdx(idx).GetSymbol() == 'C' for idx in ring)]
# Reorder the SMILES string to create rings with carbon blocks
# reordered_smiles = reorder_smiles_to_create_rings(mol_representation, atom_indices)
# Print the original and reordered SMILES representations
# print("Original SMILES Representation:", smiles_representation)
# print("Reordered SMILES Representation:", reordered_smiles)
def graph_smiles(mol):
# Draw the molecule using RDKit
img = Draw.MolToImage(mol)
# Display the image with Matplotlib
plt.figure(figsize=(50, 50))
plt.imshow(img)
plt.axis('off')
plt.show()
Config:
DEVICE = "cpu"
SUPPORTED_EDGES = ["SINGLE", "DOUBLE", "TRIPLE", "AROMATIC"]
SUPPORTED_ATOMS = ["C", "N", "O", "F", "P", "S", "Cl", "Br", "I"]
ATOMIC_NUMBERS = [6, 7, 8, 9, 15, 16, 17, 35, 53]
valence_dict = {
"H": 1,
"C": 4,
"N": 3,
"O": 2,
"F": 1,
"P": 3,
"S": 2,
"Cl":1,
"Br": 1,
"I": 1,
}
MAX_MOLECULE_SIZE = 50
DISABLE_RDKIT_WARNINGS = True
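For reference, a minimal sanity-check sketch (not part of the original config) of the decoder output sizes these constants imply, matching the shapes used in the GVAE and in approximate_recon_loss:
# Derived sizes implied by the constants above
NUM_TRIU_ENTRIES = MAX_MOLECULE_SIZE * (MAX_MOLECULE_SIZE - 1) // 2      # 50 * 49 / 2 = 1225 node pairs
NODE_LOGITS_PER_GRAPH = MAX_MOLECULE_SIZE * (len(SUPPORTED_ATOMS) + 1)   # 50 * (9 + 1) = 500
EDGE_LOGITS_PER_GRAPH = NUM_TRIU_ENTRIES * (len(SUPPORTED_EDGES) + 1)    # 1225 * (4 + 1) = 6125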
Training:
import torch
from torch_geometric.data import DataLoader
from inhibitordataset import MoleculeDataset
from tqdm import tqdm
import numpy as np
import mlflow.pytorch
from rdkit import Chem
from utils import (count_parameters, gvae_loss,
slice_edge_type_from_edge_feats, slice_atom_type_from_node_feats,
graph_representation_to_molecule, to_one_hot)
from Gvaeold import GVAE
from config import DEVICE as device
strings = []
train_dataset = MoleculeDataset(root="data/", filename="/Users/Siddhu/Desktop/AB Drugs/ABYN.csv")[600:]
test_dataset = MoleculeDataset(root="data/", filename="/Users/Siddhu/Desktop/AB Drugs/ABYN.csv", test=True)[:600]
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)
sample_data = train_dataset[0]
node_input_dim = sample_data.x.shape[1]
edge_input_dim = sample_data.edge_attr.shape[1]
model = GVAE(
node_input_dim=node_input_dim,
hidden_dim=32,
latent_dim=16,
encoder_embedding_size=64,
latent_embedding_size=128,
heads=4,
edge_input_dim=edge_input_dim
)
model = model.to(device)
print("Model parameters: ", count_parameters(model))
loss_fn = gvae_loss
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
kl_beta = 0.5
def validate_molecule(mol):
try:
Chem.SanitizeMol(mol)
return True
except:
return False
def run_one_epoch(data_loader, type, epoch, kl_beta):
all_losses = []
all_kldivs = []
for _, batch in enumerate(tqdm(data_loader)):
try:
batch.to(device)
optimizer.zero_grad()
total_elements = batch.edge_attr.numel()
if batch.x.shape[1] != model.node_input_dim:
batch.x = batch.x.view(-1, model.node_input_dim)
if total_elements % model.edge_input_dim == 0:
batch.edge_attr = batch.edge_attr.view(-1, model.edge_input_dim)
else:
print(f"Skipping reshape for batch.edge_attr due to incompatible dimensions: {batch.edge_attr.shape}")
triu_logits, node_logits, mu, logvar = model(batch.x.float(),
batch.edge_attr.float(),
batch.edge_index,
batch.batch)
edge_targets = slice_edge_type_from_edge_feats(batch.edge_attr.float())
node_targets = slice_atom_type_from_node_feats(batch.x.float(), as_index=True)
loss, kl_div = loss_fn(triu_logits, node_logits,
batch.edge_index, edge_targets,
node_targets, mu, logvar,
batch.batch, kl_beta)
if type == "Train":
loss.backward()
optimizer.step()
all_losses.append(loss.detach().cpu().numpy())
all_kldivs.append(kl_div.detach().cpu().numpy())
except IndexError as error:
print("Error: ", error)
# Perform sampling
if type == "Test":
print(f"Starting test {epoch}")
generated_smiles = []
for _ in tqdm(range(100), desc="Sampling molecules"):
z = torch.randn(1, model.latent_dim).to(device)
dummy_batch_index = torch.Tensor([0]).int().to(device)
triu_logits, node_logits = model.decode(z, dummy_batch_index)
edge_matrix_shape = (int((model.max_num_atoms * (model.max_num_atoms - 1)) / 2), model.num_edge_types + 1)
triu_preds_matrix = triu_logits.reshape(edge_matrix_shape)
triu_preds = torch.argmax(triu_preds_matrix, dim=1)
node_matrix_shape = (model.max_num_atoms, model.num_atom_types + 1)
node_preds_matrix = node_logits.reshape(node_matrix_shape)
node_preds = torch.argmax(node_preds_matrix[:, :model.num_atom_types], dim=1)
node_preds_one_hot = to_one_hot(node_preds, options=model.atomic_numbers)
atom_numbers_dummy = torch.Tensor(model.atomic_numbers).repeat(node_preds_one_hot.shape[0], 1).to(device)
atom_types = torch.masked_select(atom_numbers_dummy, node_preds_one_hot.bool())
mol = graph_representation_to_molecule(atom_types)
if mol and validate_molecule(mol):
smiles = Chem.MolToSmiles(mol)
generated_smiles.append(smiles)
strings.append(smiles)
mlflow.log_metric(key=f"Sampled molecules", value=float(len(generated_smiles)), step=epoch)
mlflow.log_metric(key=f"{type} Epoch Loss", value=float(np.array(all_losses).mean()), step=epoch)
mlflow.log_metric(key=f"{type} KL Divergence", value=float(np.array(all_kldivs).mean()), step=epoch)
mlflow.pytorch.log_model(model, "model")
if type == "Train":
torch.save(model.state_dict(), f'model_weights_epoch_{epoch}.pth')
print(f'Model weights saved for epoch {epoch}')
with mlflow.start_run() as run:
for epoch in range(5):
model.train()
run_one_epoch(train_loader, type="Train", epoch=epoch, kl_beta=kl_beta)
model.eval()
run_one_epoch(train_loader, type="Test", epoch=epoch, kl_beta=kl_beta)
print(strings)
'''
if epoch % 5 == 0:
print("Start test epoch...")
model.eval()
run_one_epoch(test_loader, type="Test", epoch=epoch, kl_beta=kl_beta)
'''
torch.save(model.state_dict(), 'final_model_weights.pth')
print('Final model weights saved')
Evaluation:
import json
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from rdkit import Chem
from rdkit.Chem import Descriptors
import pandas as pd
import numpy as np
from Visualize import smiles_to_image
import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit.Chem.rdMolDescriptors import CalcMolFormula
def smiles_to_features(smiles):
mol = Chem.MolFromSmiles(smiles)
if mol:
return [
Descriptors.MolWt(mol),
Descriptors.NumRotatableBonds(mol),
Descriptors.TPSA(mol),
Descriptors.MolLogP(mol),
Descriptors.HeavyAtomCount(mol),
Descriptors.FractionCSP3(mol),
Descriptors.NumHAcceptors(mol),
Descriptors.NumHDonors(mol),
Descriptors.RingCount(mol),
]
else:
return [0] * 9
class SMILESDataset(Dataset):
def __init__(self, file_path, test=False):
self.data = pd.read_csv(file_path)
self.features = np.array([smiles_to_features(smiles) for smiles in self.data['Canonical SMILES']])
self.labels = np.array(self.data['Targeting-protein'], dtype=np.float32)
if test:
self.features = self.features[:600]
self.labels = self.labels[:600]
else:
self.features = self.features[600:]
self.labels = self.labels[600:]
def __len__(self):
return len(self.features)
def __getitem__(self, idx):
return torch.tensor(self.features[idx], dtype=torch.float32), torch.tensor(self.labels[idx], dtype=torch.float32)
class InhibitionClassifier(nn.Module):
def __init__(self, input_size=9):
super(InhibitionClassifier, self).__init__()
self.fc1 = nn.Linear(input_size, 16)
self.fc2 = nn.Linear(16, 8)
self.fc3 = nn.Linear(8, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.sigmoid(self.fc3(x))
return x
def train_model(dataset_path, epochs=20000, batch_size=16, learning_rate=0.001):
arr=[]
train_dataset = SMILESDataset(dataset_path)
test_dataset = SMILESDataset(dataset_path, test=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
model = InhibitionClassifier()
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss_history = []
for epoch in range(epochs):
correct = 0
total = 0
total_loss = 0
for features, labels in train_loader:
optimizer.zero_grad()
outputs = model(features).squeeze()
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
total_loss += loss.item()
predictions = (outputs >= 0.5).float()
correct += (predictions == labels).sum().item()
total += labels.size(0)
accuracy = correct / total
arr.append(accuracy)
print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(train_loader):.4f}, Accuracy: {accuracy:.2f}")
torch.save(model.state_dict(), f'model_weights_epoch_{epoch}.pth')
return model,arr
def predict_inhibition(smiles, model_path='model_weights_epoch_9.pth'):
features = smiles_to_features(smiles)
features = np.array(features, dtype=np.float32)
features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0)
model = InhibitionClassifier()
model.load_state_dict(torch.load(model_path))
model.eval()
with torch.no_grad():
output = model(features_tensor)
probability = output.item()
if probability > 0.5:
print(f"The molecule is likely an inhibitor with a probability of {probability:.4f}.")
return True
else:
print(f"The molecule is likely not an inhibitor with a probability of {probability:.4f}.")
return False
def evaluate_model(model, dataset_path):
test_dataset = SMILESDataset(dataset_path, test=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
correct = 0
total = 0
with torch.no_grad():
for features, labels in test_loader:
outputs = model(features).squeeze()
predictions = (outputs >= 0.5).float()
correct += (predictions == labels).sum().item()
total += labels.size(0)
accuracy = correct / total * 100
print(f"Test Accuracy: {accuracy:.2f}%")
return accuracy
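A minimal usage sketch for the classifier above; the CSV path, epoch count, and example SMILES are placeholders rather than values from the original experiments:
# Hypothetical driver (the CSV must provide 'Canonical SMILES' and 'Targeting-protein' columns)
model, accuracy_history = train_model("inhibitor_data.csv", epochs=20, batch_size=16)
evaluate_model(model, "inhibitor_data.csv")
predict_inhibition("CC(=O)Oc1ccccc1C(=O)O", model_path="model_weights_epoch_19.pth")  # aspirin as an example input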
Inhibitor Dataset:
import pandas as pd
import torch
import torch_geometric
from torch_geometric.data import Dataset, Data
import numpy as np
import os
from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem import AllChem
from torch_geometric.utils import dense_to_sparse
print(f"Torch version: {torch.__version__}")
print(f"Cuda available: {torch.cuda.is_available()}")
print(f"Torch geometric version: {torch_geometric.__version__}")
class MoleculeDataset(Dataset):
def __init__(self, root, filename, test=False, transform=None, pre_transform=None):
self.test = test
self.filename = filename
super(MoleculeDataset, self).__init__(root, transform, pre_transform)
@property
def raw_file_names(self):
return [self.filename]
@property
def processed_file_names(self):
self.data = pd.read_csv(self.raw_paths[0]).reset_index()
if self.test:
return [f'data_test_{i}.pt' for i in list(self.data.index)]
else:
return [f'data_{i}.pt' for i in list(self.data.index)]
def download(self):
pass
def process(self):
self.data = pd.read_csv(self.raw_paths[0]).reset_index()
for index, row in tqdm(self.data.iterrows(), total=self.data.shape[0]):
mol = Chem.MolFromSmiles(row["Canonical SMILES"])
if mol is None:
continue
mol = Chem.AddHs(mol)
AllChem.EmbedMolecule(mol, randomSeed=42)
AllChem.UFFOptimizeMolecule(mol)
# Featurize molecule
node_features = self.get_node_features(mol)
edge_index, edge_attr = self.get_edge_info(mol)
data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr)
data.y = self._get_label(row["Targeting-protein"])
data.smiles = row["Canonical SMILES"]
self._validate_data(data)
if self.test:
torch.save(data, os.path.join(self.processed_dir, f'data_test_{index}.pt'))
else:
torch.save(data, os.path.join(self.processed_dir, f'data_{index}.pt'))
def get_node_features(self, mol):
all_node_feats = []
for atom in mol.GetAtoms():
node_feats = [
atom.GetAtomicNum(),
atom.GetTotalValence(),
# Gasteiger charge when available (the element symbol is omitted: a string cannot be stored in the float tensor below)
atom.GetDoubleProp("_GasteigerCharge") if atom.HasProp("_GasteigerCharge") else 0.0,
int(atom.GetHybridization()),
atom.GetFormalCharge(),
int(atom.IsInRing())
]
print(node_feats)
all_node_feats.append(node_feats)
return torch.tensor(all_node_feats, dtype=torch.float)
def get_edge_info(self, mol):
edge_indices = []
edge_feats = []
for bond in mol.GetBonds():
i = bond.GetBeginAtomIdx()
j = bond.GetEndAtomIdx()
# Store both directions so that edge_attr rows stay aligned with edge_index columns
edge_indices += [[i, j], [j, i]]
edge_feats += [[bond.GetBondTypeAsDouble()], [bond.GetBondTypeAsDouble()]]
edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
edge_attr = torch.tensor(edge_feats, dtype=torch.float)
return edge_index, edge_attr
def _get_label(self, label):
label = np.asarray([label])
return torch.tensor(label, dtype=torch.int64)
def _validate_data(self, data):
node_feature_dim = data.x.shape[1]
if not hasattr(self, 'node_feature_dim'):
self.node_feature_dim = node_feature_dim
if node_feature_dim != self.node_feature_dim:
raise ValueError(f"Inconsistent node feature dimensions: expected {self.node_feature_dim}, got {node_feature_dim}")
def len(self):
return self.data.shape[0]
def get(self, idx):
if self.test:
data = torch.load(os.path.join(self.processed_dir, f'data_test_{idx}.pt'))
else:
data = torch.load(os.path.join(self.processed_dir, f'data_{idx}.pt'))
return data