GVAE Version #1:
import math
import torch
import torch.nn as nn
from torch_scatter import scatter_softmax
from torch.nn import Linear
from torch_geometric.nn.conv import TransformerConv
from torch_geometric.nn import BatchNorm
from config import SUPPORTED_ATOMS, SUPPORTED_EDGES, MAX_MOLECULE_SIZE, ATOMIC_NUMBERS
from utils import graph_representation_to_molecule, to_one_hot
from tqdm import tqdm
class GVAE(nn.Module):
def __init__(self, encoder_embedding_size=64, latent_embedding_size=128, heads=4,
node_input_dim=30,
edge_input_dim=3,
hidden_dim=32,
latent_dim=16):
super(GVAE, self).__init__()
self.node_input_dim = node_input_dim
self.edge_input_dim = edge_input_dim
self.encoder_embedding_size = encoder_embedding_size
self.edge_dim = None
self.heads = heads
self.latent_embedding_size = latent_embedding_size
self.latent_dim = latent_dim
self.conv1 = None
self.conv1_initialized = False
self.num_edge_types = len(SUPPORTED_EDGES)
self.num_atom_types = len(SUPPORTED_ATOMS)
self.max_num_atoms = MAX_MOLECULE_SIZE
self.atomic_numbers = ATOMIC_NUMBERS
self.decoder_hidden_neurons = 512
self.lin_edge = None
self.bn1 = BatchNorm(self.encoder_embedding_size)
self.conv2 = TransformerConv(self.encoder_embedding_size,
self.encoder_embedding_size,
heads=self.heads,
concat=False,
beta=True,
edge_dim=self.edge_dim)
self.conv2.message = self.message
self.bn2 = BatchNorm(self.encoder_embedding_size)
self.conv3 = TransformerConv(self.encoder_embedding_size,
self.encoder_embedding_size,
heads=self.heads,
concat=False,
beta=True,
edge_dim=self.edge_dim)
self.conv3.message = self.message
self.bn3 = BatchNorm(self.encoder_embedding_size)
self.conv4 = TransformerConv(self.encoder_embedding_size,
self.encoder_embedding_size,
heads=self.heads,
concat=False,
beta=True,
edge_dim=self.edge_dim)
self.conv4.message = self.message
self.mu_transform = Linear(self.encoder_embedding_size, self.latent_dim)
self.logvar_transform = Linear(self.encoder_embedding_size, self.latent_dim)
self.linear_1 = Linear(self.latent_dim, self.decoder_hidden_neurons)
self.linear_2 = Linear(self.decoder_hidden_neurons, self.decoder_hidden_neurons)
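# The decoder emits, for each of the max_num_atoms node slots, logits over the supported
# atom types plus one extra "no atom" class, and, for each of the n(n-1)/2 upper-triangular
# node pairs, logits over the supported edge types plus one extra "no edge" class.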
atom_output_dim = self.max_num_atoms * (self.num_atom_types + 1)
self.atom_decode = Linear(self.decoder_hidden_neurons, atom_output_dim)
edge_output_dim = int(((self.max_num_atoms * (self.max_num_atoms - 1)) / 2) * (self.num_edge_types + 1))
self.edge_decode = Linear(self.decoder_hidden_neurons, edge_output_dim)
def encode(self, x, edge_attr, edge_index, batch_index):
x = self.conv1(x, edge_index, edge_attr).relu()
x = self.bn1(x)
x = self.conv2(x, edge_index, edge_attr).relu()
x = self.bn2(x)
x = self.conv3(x, edge_index, edge_attr).relu()
x = self.bn3(x)
x = self.conv4(x, edge_index, edge_attr).relu()
mu = self.mu_transform(x)
logvar = self.logvar_transform(x)
return mu, logvar
def decode_graph(self, graph_z):
z = self.linear_1(graph_z).relu()
z = self.linear_2(z).relu()
atom_logits = self.atom_decode(z)
edge_logits = self.edge_decode(z)
return atom_logits, edge_logits
def decode(self, z, batch_index):
node_logits = []
triu_logits = []
for graph_id in torch.unique(batch_index):
graph_z = z[graph_id]
atom_logits, edge_logits = self.decode_graph(graph_z)
node_logits.append(atom_logits)
triu_logits.append(edge_logits)
node_logits = torch.cat(node_logits)
triu_logits = torch.cat(triu_logits)
return triu_logits, node_logits
def reparameterize(self, mu, logvar):
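# Note: although the tensor is named logvar, it is used here as a log standard deviation
# (std = exp(logvar)); utils.kl_loss applies the same interpretation, so the two stay consistent.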
if self.training:
std = torch.exp(logvar)
eps = torch.randn_like(std)
return eps.mul(std).add_(mu)
else:
return mu
def forward(self, x, edge_attr, edge_index, batch_index):
if not self.conv1_initialized:
feature_size = x.shape[1]
self.edge_dim = edge_attr.shape[1]
if self.conv1 is None or self.conv1.in_channels != feature_size or self.conv1.edge_dim != self.edge_dim:
print(f"Initializing/Reinitializing conv1 with feature_size: {feature_size}, edge_dim: {self.edge_dim}")
self.conv1 = TransformerConv(feature_size, self.encoder_embedding_size,
heads=self.heads, concat=False, beta=True, edge_dim=self.edge_dim)
self.conv1.message = self.message
self.lin_edge = nn.Linear(self.edge_dim, self.encoder_embedding_size)
self.conv1_initialized = True
mu, logvar = self.encode(x, edge_attr, edge_index, batch_index)
z = self.reparameterize(mu, logvar)
triu_logits, node_logits = self.decode(z, batch_index)
return triu_logits, node_logits, mu, logvar
def message(self, query_i, key_j, value_j, edge_attr, index, ptr, size_i):
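# Custom message function: edge features are projected to the embedding size, reshaped per
# head, and added to the keys when the shapes line up; attention is scaled dot-product per
# head, normalized with scatter_softmax over each destination node's incoming messages.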
if edge_attr is not None:
edge_attr = self.lin_edge(edge_attr).view(-1, self.heads, self.encoder_embedding_size // self.heads)
if edge_attr.size(1) == key_j.size(1) and edge_attr.size(2) == key_j.size(2):
key_j = key_j + edge_attr
alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.encoder_embedding_size // self.heads)
alpha = scatter_softmax(alpha, index, dim=0)
return value_j * alpha.view(-1, self.heads, 1)
GVAE Version #2:
import math
import torch
import torch.nn as nn
from torch_scatter import scatter_softmax
from torch.nn import Linear
from torch_geometric.nn.conv import TransformerConv
from torch_geometric.nn import Set2Set
from torch_geometric.nn import BatchNorm
from config import SUPPORTED_ATOMS, SUPPORTED_EDGES, MAX_MOLECULE_SIZE, ATOMIC_NUMBERS
from utils import graph_representation_to_molecule, to_one_hot
from tqdm import tqdm
class GVAE(nn.Module):
def __init__(self, encoder_embedding_size=64, latent_embedding_size=128, heads=4,
node_input_dim=30, # From x shape: 30 features per node
edge_input_dim=3, # Updated to match the actual edge attribute dimensions
hidden_dim=32,
latent_dim=16):
super(GVAE, self).__init__()
self.node_input_dim = node_input_dim
self.edge_input_dim = edge_input_dim
self.encoder_embedding_size = encoder_embedding_size
self.edge_dim = None
self.heads = heads
self.latent_embedding_size = latent_embedding_size
self.latent_dim = latent_dim
self.conv1 = None
self.conv1_initialized = False
self.num_edge_types = len(SUPPORTED_EDGES)
self.num_atom_types = len(SUPPORTED_ATOMS)
self.max_num_atoms = MAX_MOLECULE_SIZE
self.atomic_numbers = ATOMIC_NUMBERS
self.decoder_hidden_neurons = 512
self.lin_edge = None
self.bn1 = BatchNorm(self.encoder_embedding_size)
self.conv2 = TransformerConv(self.encoder_embedding_size,
self.encoder_embedding_size,
heads=self.heads,
concat=False,
beta=True,
edge_dim=self.edge_dim)
self.conv2.message = self.message
self.bn2 = BatchNorm(self.encoder_embedding_size)
self.conv3 = TransformerConv(self.encoder_embedding_size,
self.encoder_embedding_size,
heads=self.heads,
concat=False,
beta=True,
edge_dim=self.edge_dim)
self.conv3.message = self.message
self.bn3 = BatchNorm(self.encoder_embedding_size)
self.conv4 = TransformerConv(self.encoder_embedding_size,
self.encoder_embedding_size,
heads=self.heads,
concat=False,
beta=True,
edge_dim=self.edge_dim)
self.conv4.message = self.message
self.pooling = Set2Set(self.encoder_embedding_size, processing_steps=4)
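# Set2Set returns a graph embedding of size 2 * encoder_embedding_size, which is why the
# mu/logvar projections below take a doubled input dimension.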
self.mu_transform = Linear(self.encoder_embedding_size * 2, self.latent_dim)
self.logvar_transform = Linear(self.encoder_embedding_size * 2, self.latent_dim)
self.linear_1 = Linear(self.latent_dim, self.decoder_hidden_neurons)
self.linear_2 = Linear(self.decoder_hidden_neurons, self.decoder_hidden_neurons)
atom_output_dim = self.max_num_atoms * (self.num_atom_types + 1)
self.atom_decode = Linear(self.decoder_hidden_neurons, atom_output_dim)
edge_output_dim = int(((self.max_num_atoms * (self.max_num_atoms - 1)) / 2) * (self.num_edge_types + 1))
self.edge_decode = Linear(self.decoder_hidden_neurons, edge_output_dim)
def encode(self, x, edge_attr, edge_index, batch_index):
x = self.conv1(x, edge_index, edge_attr).relu()
x = self.bn1(x)
x = self.conv2(x, edge_index, edge_attr).relu()
x = self.bn2(x)
x = self.conv3(x, edge_index, edge_attr).relu()
x = self.bn3(x)
x = self.conv4(x, edge_index, edge_attr).relu()
x = self.pooling(x, batch_index)
mu = self.mu_transform(x)
logvar = self.logvar_transform(x)
return mu, logvar
def decode_graph(self, graph_z):
z = self.linear_1(graph_z).relu()
z = self.linear_2(z).relu()
atom_logits = self.atom_decode(z)
edge_logits = self.edge_decode(z)
return atom_logits, edge_logits
def decode(self, z, batch_index):
node_logits = []
triu_logits = []
for graph_id in torch.unique(batch_index):
graph_z = z[graph_id]
atom_logits, edge_logits = self.decode_graph(graph_z)
node_logits.append(atom_logits)
triu_logits.append(edge_logits)
node_logits = torch.cat(node_logits)
triu_logits = torch.cat(triu_logits)
return triu_logits, node_logits
def reparameterize(self, mu, logvar):
if self.training:
std = torch.exp(logvar)
eps = torch.randn_like(std)
return eps.mul(std).add_(mu)
else:
return mu
def forward(self, x, edge_attr, edge_index, batch_index):
if not self.conv1_initialized:
feature_size = x.shape[1]
self.edge_dim = edge_attr.shape[1]
if self.conv1 is None or self.conv1.in_channels != feature_size or self.conv1.edge_dim != self.edge_dim:
print(f"Initializing/Reinitializing conv1 with feature_size: {feature_size}, edge_dim: {self.edge_dim}")
self.conv1 = TransformerConv(feature_size, self.encoder_embedding_size,
heads=self.heads, concat=False, beta=True, edge_dim=self.edge_dim)
self.conv1.message = self.message
self.lin_edge = nn.Linear(self.edge_dim, self.encoder_embedding_size)
self.conv1_initialized = True
mu, logvar = self.encode(x, edge_attr, edge_index, batch_index)
z = self.reparameterize(mu, logvar)
triu_logits, node_logits = self.decode(z, batch_index)
return triu_logits, node_logits, mu, logvar
def sample_mols(self, num=10000):
print("Sampling molecules ... ")
n_valid = 0
for _ in tqdm(range(num)):
z = torch.randn(1, self.latent_dim)  # sample from the prior; the decoder's linear_1 expects latent_dim inputs
dummy_batch_index = torch.Tensor([0]).int()
triu_logits, node_logits = self.decode(z, dummy_batch_index)
edge_matrix_shape = (int((MAX_MOLECULE_SIZE * (MAX_MOLECULE_SIZE - 1)) / 2), len(SUPPORTED_EDGES) + 1)
triu_preds_matrix = triu_logits.reshape(edge_matrix_shape)
triu_preds = torch.argmax(triu_preds_matrix, dim=1)
node_matrix_shape = (MAX_MOLECULE_SIZE, (len(SUPPORTED_ATOMS) + 1))
node_preds_matrix = node_logits.reshape(node_matrix_shape)
node_preds = torch.argmax(node_preds_matrix[:, :len(SUPPORTED_ATOMS)], dim=1)
node_preds_one_hot = to_one_hot(node_preds, options=ATOMIC_NUMBERS)
atom_numbers_dummy = torch.Tensor(ATOMIC_NUMBERS).repeat(node_preds_one_hot.shape[0], 1)
atom_types = torch.masked_select(atom_numbers_dummy, node_preds_one_hot.bool())
smiles, mol = graph_representation_to_molecule(atom_types, triu_preds.float())
if smiles and "." not in smiles:
print("Successfully generated: ", smiles)
n_valid += 1
return n_valid
def message(self, query_i, key_j, value_j, edge_attr, index, ptr, size_i):
if edge_attr is not None:
edge_attr = self.lin_edge(edge_attr).view(-1, self.heads, self.encoder_embedding_size // self.heads)
if edge_attr.size(1) == key_j.size(1) and edge_attr.size(2) == key_j.size(2):
key_j = key_j + edge_attr
alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.encoder_embedding_size // self.heads)
alpha = scatter_softmax(alpha, index, dim=0)
return value_j * alpha.view(-1, self.heads, 1)
Reconstruction Training:
def graph_representation_to_molecule(tensor):
# Create an empty RDKit molecule
mol = Chem.RWMol()
carbon_indices = []
atoms=[]
atomattributes = tensor_to_atoms(tensor)
for atom in atomattributes:
atm = mol.AddAtom(Chem.Atom(atom.return_symbol()))
atoms.append(atm)
counter = 0
for atom in atomattributes:
if atom.return_symbol() == 'C':
carbon_indices.append(counter)
counter +=1
counter = 0
for carbons in carbon_indices:
if counter+1<len(carbon_indices) and carbon_indices[counter+1]<len(atoms) and counter<6:
if counter%2 == 0:
mol.AddBond(atoms[carbons], atoms[carbon_indices[counter+1]], Chem.BondType.SINGLE)
atomattributes[carbons].valence-=1
atomattributes[carbon_indices[counter+1]].valence -= 1
atomattributes[carbons].inRing = True
atomattributes[carbon_indices[counter + 1]].inRing = True
else:
mol.AddBond(atoms[carbons], atoms[carbon_indices[counter+1]], Chem.BondType.DOUBLE)
atomattributes[carbons].inRing = True
atomattributes[carbon_indices[counter+1]].inRing = True
atomattributes[carbons].valence -= 2
atomattributes[carbon_indices[counter + 1]].valence -= 2
counter += 1
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'C' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
#Oxygen
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'O' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
#Chlorine
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'Cl' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
# Sulfur
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'S' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
# Nitrogen
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'N' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter<len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
#Bromine
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'Br' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
# Fluorine
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'F' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
# Phosphorus
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'P' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
#Iodine
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'I' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
atoms_to_remove = []
for atom in mol.GetAtoms():
if atom.GetDegree() == 0: # If the atom has no bonds (degree 0)
atoms_to_remove.append(atom.GetIdx())
# Remove atoms in reverse order to avoid indexing issues
for atom_idx in sorted(atoms_to_remove, reverse=True):
mol.RemoveAtom(atom_idx)
img = Draw.MolToImage(mol)
smiles = Chem.MolToSmiles(mol)
print(smiles)
return mol
# Unreachable post-processing kept for reference (it follows the return above and relies on
# helpers such as reorder_smiles_to_create_rings that are not defined in this listing):
# AllChem.Compute2DCoords(mol)
# rings = mol.GetRingInfo().AtomRings()
# carbon_rings = [ring for ring in rings if all(mol.GetAtomWithIdx(idx).GetSymbol() == 'C' for idx in ring)]
# Reorder the SMILES string to create rings with carbon blocks
# reordered_smiles = reorder_smiles_to_create_rings(mol_representation, atom_indices)
# Print the original and reordered SMILES representations
# print("Original SMILES Representation:", smiles_representation)
# print("Reordered SMILES Representation:", reordered_smiles)
Hybrid Between AI and Rule-Based:
import random
import torch
import torch.nn as nn
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from utils import tensor_to_atoms  # tensor_to_atoms is defined in the UTILS listing below
class BondPredictionModel(nn.Module):
def __init__(self, input_dim, hidden_dim):
super(BondPredictionModel, self).__init__()
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return self.sigmoid(x)
model = BondPredictionModel(input_dim=5, hidden_dim=32)
def predict_bond(atom1, atom2, atom_features):
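# Assumes atom_features[atom] yields the per-atom feature vector; the two vectors are
# concatenated, so their combined length must equal the model's input_dim (5 in the model above).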
features = torch.cat((atom_features[atom1], atom_features[atom2]), dim=0)
bond_prediction = model(features)
return bond_prediction.item() > 0.5
def graph_representation_to_molecule(tensor):
mol = Chem.RWMol()
carbon_indices = []
atoms = []
atomattributes = tensor_to_atoms(tensor)
for atom in atomattributes:
atm = mol.AddAtom(Chem.Atom(atom.return_symbol()))
atoms.append(atm)
carbon_indices = [idx for idx, atom in enumerate(atomattributes) if atom.return_symbol() == 'C']
for i in range(len(carbon_indices)):
for j in range(i + 1, len(carbon_indices)):
if predict_bond(carbon_indices[i], carbon_indices[j], tensor):
mol.AddBond(atoms[carbon_indices[i]], atoms[carbon_indices[j]], Chem.BondType.SINGLE)
atomattributes[carbon_indices[i]].valence -= 1
atomattributes[carbon_indices[j]].valence -= 1
for atom_idx, atom in enumerate(atomattributes):
if atom.return_symbol() in ['O', 'N', 'Cl', 'P', 'S', 'Br', 'I'] and carbon_indices:
connected_carbon_idx = random.choice(carbon_indices)
if predict_bond(atom_idx, connected_carbon_idx, tensor):
mol.AddBond(atoms[atom_idx], atoms[connected_carbon_idx], Chem.BondType.SINGLE)
atomattributes[atom_idx].valence -= 1
atomattributes[connected_carbon_idx].valence -= 1
atoms_to_remove = []
for atom in mol.GetAtoms():
if atom.GetDegree() == 0:
atoms_to_remove.append(atom.GetIdx())
for atom_idx in sorted(atoms_to_remove, reverse=True):
mol.RemoveAtom(atom_idx)
AllChem.Compute2DCoords(mol)
smiles = Chem.MolToSmiles(mol)
print(f"Generated SMILES: {smiles}")
img = Draw.MolToImage(mol)
img.show()
return mol
Other Important Files:
UTILS:
import random
import torch
from torch_geometric.utils import to_dense_adj
from rdkit import Chem
import random
from rdkit import RDLogger
from ShapeDebugger import smiles
from config import DEVICE as device
from config import (SUPPORTED_ATOMS, SUPPORTED_EDGES, MAX_MOLECULE_SIZE, ATOMIC_NUMBERS, valence_dict,
DISABLE_RDKIT_WARNINGS)
import matplotlib.pyplot as plt
# Disable rdkit warnings
if DISABLE_RDKIT_WARNINGS:
RDLogger.DisableLog('rdApp.*')
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
def count_rings(mol):
"""Count and print the number of rings in the molecule."""
ssr = Chem.GetSymmSSSR(mol)
num_rings = len(ssr)
print(f"Number of rings detected: {num_rings}")
for i, ring in enumerate(ssr):
print(f"Ring {i + 1}: Atoms {list(ring)}")
return num_rings
def log_ring_info(mol):
"""Log information about rings in the molecule."""
ssr = Chem.GetSymmSSSR(mol)
num_rings = len(ssr)
print(f"Number of rings detected: {num_rings}")
for i, ring in enumerate(ssr):
print(f"Ring {i + 1}: Atoms {list(ring)}")
return ssr
def break_down_rings(mol):
"""Break down complex ring structures in the molecule."""
ssr = log_ring_info(mol)
if not ssr:
print("No rings detected in the molecule.")
return mol
ring_bonds = set()
# GetSymmSSSR returns atom-index rings; FragmentOnBonds needs bond indices, so use BondRings()
for ring in mol.GetRingInfo().BondRings():
for bond_idx in ring:
ring_bonds.add(bond_idx)
if not ring_bonds:
print("No ring bonds found in the molecule.")
return mol
fragmented_mol = AllChem.FragmentOnBonds(mol, list(ring_bonds), addDummies=False)
fragmented_mol = Chem.RemoveHs(fragmented_mol) # Remove hydrogen atoms added during fragmentation
return fragmented_mol
def count_parameters(model):
"""
Counts the number of parameters for a Pytorch model
"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def kl_loss(mu=None, logstd=None):
"""
Closed formula of the KL divergence for normal distributions
"""
MAX_LOGSTD = 10
logstd = logstd.clamp(max=MAX_LOGSTD)
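# Closed form: KL( N(mu, sigma) || N(0, I) ) = -0.5 * sum(1 + 2*log(sigma) - mu^2 - sigma^2), averaged over the batch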
kl_div = -0.5 * torch.mean(torch.sum(1 + 2 * logstd - mu ** 2 - logstd.exp() ** 2, dim=1))
kl_div = kl_div.clamp(max=1000)
return kl_div
def slice_graph_targets(graph_id, edge_targets, node_targets, batch_index):
"""
Slices out the upper triangular part of an adjacency matrix for
a single graph from a large adjacency matrix for a full batch.
For the node features the corresponding section in the batch is sliced out.
--------
graph_id: The ID of the graph (in the batch index) to slice
edge_targets: A dense adjacency matrix for the whole batch
node_targets: A tensor of node labels for the whole batch
batch_index: The node to graph map for the batch
"""
graph_mask = torch.eq(batch_index, graph_id)
graph_edge_targets = edge_targets[graph_mask][:, graph_mask]
size = graph_edge_targets.shape[0]
triu_indices = torch.triu_indices(size, size, offset=1)
triu_mask = torch.squeeze(to_dense_adj(triu_indices)).bool()
graph_edge_targets = graph_edge_targets[triu_mask]
graph_node_targets = node_targets[graph_mask]
return graph_edge_targets, graph_node_targets
def slice_graph_predictions(triu_logits, node_logits, graph_triu_size, triu_start_point, graph_size, node_start_point):
"""
Slices out the corresponding section from a list of batch triu values.
Given a start point and the size of a graph's triu, simply slices
the section from the batch list.
-------
triu_logits: A batch of triu predictions of different graphs
node_logits: A batch of node predictions with fixed size MAX_GRAPH_SIZE
graph_triu_size: Size of the triu of the graph to slice
triu_start_point: Index of the first node of this graph in the triu batch
graph_size: Max graph size
node_start_point: Index of the first node of this graph in the nodes batch
"""
graph_logits_triu = torch.squeeze(triu_logits[triu_start_point:triu_start_point + graph_triu_size])
graph_node_logits = torch.squeeze(node_logits[node_start_point:node_start_point + graph_size])
return graph_logits_triu, graph_node_logits
def slice_edge_type_from_edge_feats(edge_feats):
"""
This function only works for the MolGraphConvFeaturizer used in the dataset.
It slices the one-hot encoded edge type from the edge feature matrix.
The first 4 values stand for ["SINGLE", "DOUBLE", "TRIPLE", "AROMATIC"].
"""
edge_types_one_hot = edge_feats[:, :4]
edge_types = edge_types_one_hot.nonzero(as_tuple=False)
edge_types[:, 1] = edge_types[:, 1] + 1
return edge_types
def slice_atom_type_from_node_feats(node_features, as_index=False):
"""
This function only works for the MolGraphConvFeaturizer used in the dataset.
It slices the one-hot encoded atom type from the node feature matrix.
Unknown atom types are not considered and are not expected in the dataset.
"""
supported_atoms = SUPPORTED_ATOMS
atomic_numbers = ATOMIC_NUMBERS
atom_types_one_hot = node_features[:, :len(supported_atoms)]
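# Example: a one-hot row for oxygen ([0, 0, 1, 0, ...]) maps to atomic number 8 via masked_select, or to index 2 via argmax.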
if not as_index:
atom_numbers_dummy = torch.Tensor(atomic_numbers).repeat(atom_types_one_hot.shape[0], 1)
atom_types = torch.masked_select(atom_numbers_dummy, atom_types_one_hot.bool())
else:
atom_types = torch.argmax(atom_types_one_hot, dim=1)
return atom_types
def to_one_hot(x, options):
"""
Converts a tensor of values to a one-hot vector
based on the entries in options.
"""
return torch.nn.functional.one_hot(x.long(), len(options))
def squared_difference(input, target):
return (input - target) ** 2
def triu_to_dense(triu_values, num_nodes):
"""
Converts a triangular upper part of a matrix as flat vector
to a squared adjacency matrix with a specific size (num_nodes).
"""
dense_adj = torch.zeros((num_nodes, num_nodes)).to(device).float()
triu_indices = torch.triu_indices(num_nodes, num_nodes, offset=1)
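# Example: for num_nodes = 4 the flat triu order is (0,1), (0,2), (0,3), (1,2), (1,3), (2,3)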
dense_adj[triu_indices[0], triu_indices[1]] = triu_values
# Mirror the same values into the lower triangle; the index order must match the upper triangle
dense_adj[triu_indices[1], triu_indices[0]] = triu_values
return dense_adj
def triu_to_3d_dense(triu_values, num_nodes, depth=len(SUPPORTED_EDGES)):
"""
Converts the triangular upper part of a matrix
for several dimensions into a 3d tensor.
"""
adj_matrix_3d = torch.empty((num_nodes, num_nodes, depth), dtype=torch.float, device=device)
for edge_type in range(len(SUPPORTED_EDGES)):
adj_mat_edge_type = triu_to_dense(triu_values[:, edge_type].float(), num_nodes)
adj_matrix_3d[:, :, edge_type] = adj_mat_edge_type
return adj_matrix_3d
def calculate_node_edge_pair_loss(node_tar, edge_tar, node_pred, edge_pred):
"""
Calculates a loss based on the sum of node-edge pairs.
node_tar: [nodes, supported atoms]
node_pred: [max nodes, supported atoms + 1]
edge_tar: [triu values for target nodes, supported edges]
edge_pred: [triu values for predicted nodes, supported edges]
"""
edge_pred_3d = triu_to_3d_dense(edge_pred, node_pred.shape[0])
edge_tar_3d = triu_to_3d_dense(edge_tar, node_tar.shape[0])
node_edge_preds = torch.empty((MAX_MOLECULE_SIZE, len(SUPPORTED_ATOMS), len(SUPPORTED_EDGES)), dtype=torch.float, device=device)
for edge in range(len(SUPPORTED_EDGES)):
node_edge_preds[:, :, edge] = torch.matmul(edge_pred_3d[:, :, edge], node_pred[:, :9])
node_edge_tar = torch.empty((node_tar.shape[0], len(SUPPORTED_ATOMS), len(SUPPORTED_EDGES)), dtype=torch.float, device=device)
for edge in range(len(SUPPORTED_EDGES)):
node_edge_tar[:, :, edge] = torch.matmul(edge_tar_3d[:, :, edge], node_tar.float())
node_edge_pred_matrix = torch.sum(node_edge_preds, dim=0)
node_edge_tar_matrix = torch.sum(node_edge_tar, dim=0)
if torch.equal(node_edge_pred_matrix.int(), node_edge_tar_matrix.int()):
print("Reconstructed node-edge pairs: ", node_edge_pred_matrix.int())
node_edge_loss = torch.mean(sum(squared_difference(node_edge_pred_matrix, node_edge_tar_matrix.float())))
node_edge_node_preds = torch.empty((MAX_MOLECULE_SIZE, MAX_MOLECULE_SIZE, len(SUPPORTED_EDGES)), dtype=torch.float, device=device)
for edge in range(len(SUPPORTED_EDGES)):
node_edge_node_preds[:, :, edge] = torch.matmul(node_edge_preds[:, :, edge], node_pred[:, :9].t())
node_edge_node_tar = torch.empty((node_tar.shape[0], node_tar.shape[0], len(SUPPORTED_EDGES)), dtype=torch.float, device=device)
for edge in range(len(SUPPORTED_EDGES)):
node_edge_node_tar[:, :, edge] = torch.matmul(node_edge_tar[:, :, edge], node_tar.float().t())
node_edge_node_loss = sum(squared_difference(torch.sum(node_edge_node_preds, [0, 1]), torch.sum(node_edge_node_tar, [0, 1])))
return node_edge_loss
def approximate_recon_loss(node_targets, node_preds, triu_targets, triu_preds):
"""
See: https://github.com/seokhokang/graphvae_approx/
"""
onehot_node_targets = to_one_hot(node_targets, SUPPORTED_ATOMS)
onehot_triu_targets = to_one_hot(triu_targets, ["None"] + SUPPORTED_EDGES)
node_matrix_shape = (MAX_MOLECULE_SIZE, (len(SUPPORTED_ATOMS) + 1))
node_preds_matrix = node_preds.reshape(node_matrix_shape)
edge_matrix_shape = (int((MAX_MOLECULE_SIZE * (MAX_MOLECULE_SIZE - 1)) / 2), len(SUPPORTED_EDGES) + 1)
triu_preds_matrix = triu_preds.reshape(edge_matrix_shape)
node_preds_reduced = torch.sum(node_preds_matrix[:, :9], 0)
node_targets_reduced = torch.sum(onehot_node_targets, 0)
triu_preds_reduced = torch.sum(triu_preds_matrix[:, 1:], 0)
triu_targets_reduced = torch.sum(onehot_triu_targets[:, 1:], 0)
node_loss = sum(squared_difference(node_preds_reduced, node_targets_reduced.float()))
edge_loss = sum(squared_difference(triu_preds_reduced, triu_targets_reduced.float()))
node_edge_loss = calculate_node_edge_pair_loss(onehot_node_targets, onehot_triu_targets, node_preds_matrix, triu_preds_matrix)
approx_loss = node_loss + edge_loss + node_edge_loss
if all(node_targets_reduced == node_preds_reduced.int()) and all(triu_targets_reduced == triu_preds_reduced.int()):
print("Reconstructed all nodes: ", node_targets_reduced)
print("and all edges: ", triu_targets_reduced)
return approx_loss
def gvae_loss(triu_logits, node_logits, edge_index, edge_types, node_types, mu, logvar, batch_index, kl_beta):
"""
Calculates the loss for the graph variational autoencoder,
consisting of a node loss, an edge loss and the KL divergence.
"""
batch_edge_targets = torch.squeeze(to_dense_adj(edge_index))
num_edges = edge_index.shape[1]
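# Guard against edge feature tensors whose first dimension does not match the number of directed
# edges in edge_index (this can happen when edge features are stored once per bond).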
if edge_types.shape[0] < num_edges:
print(f"Warning: edge_types ({edge_types.shape[0]}) is smaller than number of edges ({num_edges})")
edge_types = torch.nn.functional.pad(edge_types, (0, 0, 0, num_edges - edge_types.shape[0]))
elif edge_types.shape[0] > num_edges:
print(f"Warning: edge_types ({edge_types.shape[0]}) is larger than number of edges ({num_edges})")
edge_types = edge_types[:num_edges]
batch_edge_targets[edge_index[0], edge_index[1]] = edge_types[:, 1].float()
graph_size = MAX_MOLECULE_SIZE * (len(SUPPORTED_ATOMS) + 1)
graph_triu_size = int((MAX_MOLECULE_SIZE * (MAX_MOLECULE_SIZE - 1)) / 2) * (len(SUPPORTED_EDGES) + 1)
batch_recon_loss = []
triu_indices_counter = 0
graph_size_counter = 0
for graph_id in torch.unique(batch_index):
triu_targets, node_targets = slice_graph_targets(graph_id, batch_edge_targets, node_types, batch_index)
triu_preds, node_preds = slice_graph_predictions(triu_logits, node_logits, graph_triu_size, triu_indices_counter, graph_size, graph_size_counter)
triu_indices_counter = triu_indices_counter + graph_triu_size
graph_size_counter = graph_size_counter + graph_size
recon_loss = approximate_recon_loss(node_targets, node_preds, triu_targets, triu_preds)
batch_recon_loss.append(recon_loss)
num_graphs = torch.unique(batch_index).shape[0]
batch_recon_loss = torch.true_divide(sum(batch_recon_loss), num_graphs)
kl_divergence = kl_loss(mu, logvar)
return batch_recon_loss + kl_beta * kl_divergence, kl_divergence
def tensor_to_symbols(node_types_tensor):
"""
Convert a tensor of atomic numbers to a list of element symbols.
"""
atomic_numbers = node_types_tensor.numpy().astype(int)
symbols = [Chem.GetPeriodicTable().GetElementSymbol(int(num)) for num in atomic_numbers]
return symbols
def triu_to_dense_loop(adjacency_triu, num_nodes):
"""
Loop-based variant that converts an upper triangular adjacency vector into a dense
adjacency matrix. It is kept under a separate name so that it does not shadow the
tensor-based triu_to_dense above, which triu_to_3d_dense depends on.
"""
adjacency_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.int)
idx = 0
for i in range(num_nodes):
for j in range(i + 1, num_nodes):
adjacency_matrix[i, j] = adjacency_triu[idx].item()
adjacency_matrix[j, i] = adjacency_triu[idx].item()
idx += 1
return adjacency_matrix
def random_bond_type():
"""
Randomly selects a bond type (single, double, or triple).
"""
bond_types = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE]
return random.choice(bond_types)
def can_add_bond(element_symbol, current_bonds):
"""
Checks if a bond can be added based on the atom's valence and current bonds.
"""
return current_bonds < valence_dict.get(element_symbol, 0)
class Atom:
def __init__(self, symbol, valence, atomic_number,inRing):
self.symbol = symbol
self.valence = valence
self.atomic_number = atomic_number
self.inRing = inRing
def return_symbol(self):
return self.symbol
def return_valence(self):
return self.valence
def return_atomic_number(self):
return self.atomic_number
def return_inRing(self):
return self.inRing
# Convert a tensor of atomic numbers into Atom objects carrying symbol, valence and ring flag
def tensor_to_atoms(tensor):
atoms = []
number_to_symbol = dict(zip(ATOMIC_NUMBERS, SUPPORTED_ATOMS))
for atomno in tensor:
symbol = number_to_symbol.get(int(atomno))
if symbol is None:
continue  # skip padding entries and unsupported atomic numbers
valence = valence_dict.get(symbol)
atoms.append(Atom(symbol, valence, int(atomno), False))
return atoms
def graph_representation_to_molecule(tensor):
mol = Chem.RWMol()
carbon_indices = []
atoms = []
atomattributes = tensor_to_atoms(tensor)
for atom in atomattributes:
atm = mol.AddAtom(Chem.Atom(atom.return_symbol()))
atoms.append(atm)
carbon_indices = [idx for idx, atom in enumerate(atomattributes) if atom.return_symbol() == 'C']
for i in range(len(carbon_indices)):
for j in range(i + 1, len(carbon_indices)):
if predict_bond(carbon_indices[i], carbon_indices[j], tensor):
mol.AddBond(atoms[carbon_indices[i]], atoms[carbon_indices[j]], Chem.BondType.SINGLE)
atomattributes[carbon_indices[i]].valence -= 1
atomattributes[carbon_indices[j]].valence -= 1
for atom_idx, atom in enumerate(atomattributes):
if atom.return_symbol() in ['O', 'N', 'Cl', 'P', 'S', 'Br', 'I'] and carbon_indices:
connected_carbon_idx = random.choice(carbon_indices)
if predict_bond(atom_idx, connected_carbon_idx, tensor):
mol.AddBond(atoms[atom_idx], atoms[connected_carbon_idx], Chem.BondType.SINGLE)
atomattributes[atom_idx].valence -= 1
atomattributes[connected_carbon_idx].valence -= 1
atoms_to_remove = []
for atom in mol.GetAtoms():
if atom.GetDegree() == 0:
atoms_to_remove.append(atom.GetIdx())
for atom_idx in sorted(atoms_to_remove, reverse=True):
mol.RemoveAtom(atom_idx)
AllChem.Compute2DCoords(mol)
smiles = Chem.MolToSmiles(mol)
print(f"Generated SMILES: {smiles}")
img = Draw.MolToImage(mol)
img.show()
return mol
def graph_to_molecule(tensor):
# Create an empty RDKit molecule
mol = Chem.RWMol()
carbon_indices = []
atoms=[]
if check_carbon_polymer(tensor):
counter = 0
for atom in tensor:
if counter % 6 != 0:
tensor[counter] = random.choice(ATOMIC_NUMBERS)
counter += 1
atomattributes = tensor_to_atoms(tensor)
for atom in atomattributes:
atm = mol.AddAtom(Chem.Atom(atom.return_symbol()))
atoms.append(atm)
counter = 0
for atom in atomattributes:
if atom.return_symbol() == 'C':
carbon_indices.append(counter)
counter +=1
counter = 0
for carbons in carbon_indices:
if counter+1<len(carbon_indices) and carbon_indices[counter+1]<len(atoms) and counter<6:
if counter%2 == 0:
mol.AddBond(atoms[carbons], atoms[carbon_indices[counter+1]], Chem.BondType.SINGLE)
atomattributes[carbons].valence-=1
atomattributes[carbon_indices[counter+1]].valence -= 1
atomattributes[carbons].inRing = True
atomattributes[carbon_indices[counter + 1]].inRing = True
else:
mol.AddBond(atoms[carbons], atoms[carbon_indices[counter+1]], Chem.BondType.DOUBLE)
atomattributes[carbons].inRing = True
atomattributes[carbon_indices[counter+1]].inRing = True
atomattributes[carbons].valence -= 2
atomattributes[carbon_indices[counter + 1]].valence -= 2
counter += 1
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'C' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
#Oxygen
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'O' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
#Chlorine
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'Cl' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
# Sulfur
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'S' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
# Nitrogen
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'N' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter<len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
#Bromine
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'Br' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
# Fluorine
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'F' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
# Phosphorus
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'P' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
#Iodine
counter = 0
carboncounter = 0
for atom in atomattributes:
if atom.return_symbol() == 'I' and carboncounter + 1 < len(carbon_indices) and carbon_indices[
carboncounter + 1] < len(atoms) and counter < len(atomattributes):
if random.randint(0, 10) > 5 and atomattributes[counter].return_valence() > 0 and atomattributes[
carbon_indices[carboncounter + 1]].return_valence() > 0:
if (mol.GetBondBetweenAtoms(atoms[counter], atoms[carbon_indices[carboncounter + 1]]) == None):
mol.AddBond(atoms[counter], atoms[carbon_indices[carboncounter + 1]], Chem.BondType.SINGLE)
atomattributes[counter].valence -= 1
atomattributes[carbon_indices[carboncounter + 1]].valence -= 1
carbon_indices.append(atoms[counter])
counter += 1
carboncounter += 1
if carboncounter == 6:
carboncounter = 0
atoms_to_remove = []
for atom in mol.GetAtoms():
if atom.GetDegree() == 0: # If the atom has no bonds (degree 0)
atoms_to_remove.append(atom.GetIdx())
# Remove atoms in reverse order to avoid indexing issues
for atom_idx in sorted(atoms_to_remove, reverse=True):
mol.RemoveAtom(atom_idx)
img = Draw.MolToImage(mol)
smiles = Chem.MolToSmiles(mol)
print(smiles)
return mol
# Unreachable post-processing kept for reference (it follows the return above and relies on
# helpers such as reorder_smiles_to_create_rings that are not defined in this listing):
# AllChem.Compute2DCoords(mol)
# rings = mol.GetRingInfo().AtomRings()
# carbon_rings = [ring for ring in rings if all(mol.GetAtomWithIdx(idx).GetSymbol() == 'C' for idx in ring)]
# Reorder the SMILES string to create rings with carbon blocks
# reordered_smiles = reorder_smiles_to_create_rings(mol_representation, atom_indices)
# Print the original and reordered SMILES representations
# print("Original SMILES Representation:", smiles_representation)
# print("Reordered SMILES Representation:", reordered_smiles)
def graph_smiles(mol):
# Draw the molecule using RDKit
img = Draw.MolToImage(mol)
# Display the image with Matplotlib
plt.figure(figsize=(50, 50))
plt.imshow(img)
plt.axis('off')
plt.show()
Config:
DEVICE = "cpu"
SUPPORTED_EDGES = ["SINGLE", "DOUBLE", "TRIPLE", "AROMATIC"]
SUPPORTED_ATOMS = ["C", "N", "O", "F", "P", "S", "Cl", "Br", "I"]
ATOMIC_NUMBERS = [6, 7, 8, 9, 15, 16, 17, 35, 53]
valence_dict = {
"H": 1,
"C": 4,
"N": 3,
"O": 2,
"F": 1,
"P": 3,
"S": 2,
"Cl":1,
"Br": 1,
"I": 1,
}
MAX_MOLECULE_SIZE = 50
DISABLE_RDKIT_WARNINGS = True
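For reference, a minimal sanity-check sketch (not part of the original config) of the decoder output sizes these constants imply, matching the shapes used in the GVAE and in approximate_recon_loss:
# Derived sizes implied by the constants above
NUM_TRIU_ENTRIES = MAX_MOLECULE_SIZE * (MAX_MOLECULE_SIZE - 1) // 2      # 50 * 49 / 2 = 1225 node pairs
NODE_LOGITS_PER_GRAPH = MAX_MOLECULE_SIZE * (len(SUPPORTED_ATOMS) + 1)   # 50 * (9 + 1) = 500
EDGE_LOGITS_PER_GRAPH = NUM_TRIU_ENTRIES * (len(SUPPORTED_EDGES) + 1)    # 1225 * (4 + 1) = 6125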
Training:
import torch
from torch_geometric.data import DataLoader
from inhibitordataset import MoleculeDataset
from tqdm import tqdm
import numpy as np
import mlflow.pytorch
from rdkit import Chem
from utils import (count_parameters, gvae_loss,
slice_edge_type_from_edge_feats, slice_atom_type_from_node_feats,
graph_representation_to_molecule, to_one_hot)
from Gvaeold import GVAE
from config import DEVICE as device
strings = []
train_dataset = MoleculeDataset(root="data/", filename="/Users/Siddhu/Desktop/AB Drugs/ABYN.csv")[600:]
test_dataset = MoleculeDataset(root="data/", filename="/Users/Siddhu/Desktop/AB Drugs/ABYN.csv", test=True)[:600]
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)
sample_data = train_dataset[0]
node_input_dim = sample_data.x.shape[1]
edge_input_dim = sample_data.edge_attr.shape[1]
model = GVAE(
node_input_dim=node_input_dim,
hidden_dim=32,
latent_dim=16,
encoder_embedding_size=64,
latent_embedding_size=128,
heads=4,
edge_input_dim=edge_input_dim
)
model = model.to(device)
print("Model parameters: ", count_parameters(model))
loss_fn = gvae_loss
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
kl_beta = 0.5
def validate_molecule(mol):
try:
Chem.SanitizeMol(mol)
return True
except:
return False
def run_one_epoch(data_loader, type, epoch, kl_beta):
all_losses = []
all_kldivs = []
for _, batch in enumerate(tqdm(data_loader)):
try:
batch.to(device)
optimizer.zero_grad()
total_elements = batch.edge_attr.numel()
if batch.x.shape[1] != model.node_input_dim:
batch.x = batch.x.view(-1, model.node_input_dim)
if total_elements % model.edge_input_dim == 0:
batch.edge_attr = batch.edge_attr.view(-1, model.edge_input_dim)
else:
print(f"Skipping reshape for batch.edge_attr due to incompatible dimensions: {batch.edge_attr.shape}")
triu_logits, node_logits, mu, logvar = model(batch.x.float(),
batch.edge_attr.float(),
batch.edge_index,
batch.batch)
edge_targets = slice_edge_type_from_edge_feats(batch.edge_attr.float())
node_targets = slice_atom_type_from_node_feats(batch.x.float(), as_index=True)
loss, kl_div = loss_fn(triu_logits, node_logits,
batch.edge_index, edge_targets,
node_targets, mu, logvar,
batch.batch, kl_beta)
if type == "Train":
loss.backward()
optimizer.step()
all_losses.append(loss.detach().cpu().numpy())
all_kldivs.append(kl_div.detach().cpu().numpy())
except IndexError as error:
print("Error: ", error)
# Perform sampling
if type == "Test":
print(f"Starting test {epoch}")
generated_smiles = []
for _ in tqdm(range(100), desc="Sampling molecules"):
z = torch.randn(1, model.latent_dim).to(device)
dummy_batch_index = torch.Tensor([0]).int().to(device)
triu_logits, node_logits = model.decode(z, dummy_batch_index)
edge_matrix_shape = (int((model.max_num_atoms * (model.max_num_atoms - 1)) / 2), model.num_edge_types + 1)
triu_preds_matrix = triu_logits.reshape(edge_matrix_shape)
triu_preds = torch.argmax(triu_preds_matrix, dim=1)
node_matrix_shape = (model.max_num_atoms, model.num_atom_types + 1)
node_preds_matrix = node_logits.reshape(node_matrix_shape)
node_preds = torch.argmax(node_preds_matrix[:, :model.num_atom_types], dim=1)
node_preds_one_hot = to_one_hot(node_preds, options=model.atomic_numbers)
atom_numbers_dummy = torch.Tensor(model.atomic_numbers).repeat(node_preds_one_hot.shape[0], 1).to(device)
atom_types = torch.masked_select(atom_numbers_dummy, node_preds_one_hot.bool())
mol = graph_representation_to_molecule(atom_types)
if mol and validate_molecule(mol):
smiles = Chem.MolToSmiles(mol)
generated_smiles.append(smiles)
strings.append(smiles)
mlflow.log_metric(key=f"Sampled molecules", value=float(len(generated_smiles)), step=epoch)
mlflow.log_metric(key=f"{type} Epoch Loss", value=float(np.array(all_losses).mean()), step=epoch)
mlflow.log_metric(key=f"{type} KL Divergence", value=float(np.array(all_kldivs).mean()), step=epoch)
mlflow.pytorch.log_model(model, "model")
if type == "Train":
torch.save(model.state_dict(), f'model_weights_epoch_{epoch}.pth')
print(f'Model weights saved for epoch {epoch}')
with mlflow.start_run() as run:
for epoch in range(5):
model.train()
run_one_epoch(train_loader, type="Train", epoch=epoch, kl_beta=kl_beta)
model.eval()
run_one_epoch(train_loader, type="Test", epoch=epoch, kl_beta=kl_beta)
print(strings)
'''
if epoch % 5 == 0:
print("Start test epoch...")
model.eval()
run_one_epoch(test_loader, type="Test", epoch=epoch, kl_beta=kl_beta)
'''
torch.save(model.state_dict(), 'final_model_weights.pth')
print('Final model weights saved')
Evaluation:
import json
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from rdkit import Chem
from rdkit.Chem import Descriptors
import pandas as pd
import numpy as np
from Visualize import smiles_to_image
import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit.Chem.rdMolDescriptors import CalcMolFormula
def smiles_to_features(smiles):
mol = Chem.MolFromSmiles(smiles)
if mol:
return [
Descriptors.MolWt(mol),
Descriptors.NumRotatableBonds(mol),
Descriptors.TPSA(mol),
Descriptors.MolLogP(mol),
Descriptors.HeavyAtomCount(mol),
Descriptors.FractionCSP3(mol),
Descriptors.NumHAcceptors(mol),
Descriptors.NumHDonors(mol),
Descriptors.RingCount(mol),
]
else:
return [0] * 9
class SMILESDataset(Dataset):
def __init__(self, file_path, test=False):
self.data = pd.read_csv(file_path)
self.features = np.array([smiles_to_features(smiles) for smiles in self.data['Canonical SMILES']])
self.labels = np.array(self.data['Targeting-protein'], dtype=np.float32)
if test:
self.features = self.features[:600]
self.labels = self.labels[:600]
else:
self.features = self.features[600:]
self.labels = self.labels[600:]
def __len__(self):
return len(self.features)
def __getitem__(self, idx):
return torch.tensor(self.features[idx], dtype=torch.float32), torch.tensor(self.labels[idx], dtype=torch.float32)
class InhibitionClassifier(nn.Module):
def __init__(self, input_size=9):
super(InhibitionClassifier, self).__init__()
self.fc1 = nn.Linear(input_size, 16)
self.fc2 = nn.Linear(16, 8)
self.fc3 = nn.Linear(8, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.sigmoid(self.fc3(x))
return x
def train_model(dataset_path, epochs=20000, batch_size=16, learning_rate=0.001):
arr=[]
train_dataset = SMILESDataset(dataset_path)
test_dataset = SMILESDataset(dataset_path, test=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
model = InhibitionClassifier()
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss_history = []
for epoch in range(epochs):
correct = 0
total = 0
total_loss = 0
for features, labels in train_loader:
optimizer.zero_grad()
outputs = model(features).squeeze()
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
total_loss += loss.item()
predictions = (outputs >= 0.5).float()
correct += (predictions == labels).sum().item()
total += labels.size(0)
accuracy = correct / total
arr.append(accuracy)
print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(train_loader):.4f}, Accuracy: {accuracy:.2f}")
torch.save(model.state_dict(), f'model_weights_epoch_{epoch}.pth')
return model,arr
def predict_inhibition(smiles, model_path='model_weights_epoch_9.pth'):
features = smiles_to_features(smiles)
features = np.array(features, dtype=np.float32)
features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0)
model = InhibitionClassifier()
model.load_state_dict(torch.load(model_path))
model.eval()
with torch.no_grad():
output = model(features_tensor)
probability = output.item()
if probability > 0.5:
print(f"The molecule is likely an inhibitor with a probability of {probability:.4f}.")
return True
else:
print(f"The molecule is likely not an inhibitor with a probability of {probability:.4f}.")
return False
def evaluate_model(model, dataset_path):
test_dataset = SMILESDataset(dataset_path, test=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
correct = 0
total = 0
with torch.no_grad():
for features, labels in test_loader:
outputs = model(features).squeeze()
predictions = (outputs >= 0.5).float()
correct += (predictions == labels).sum().item()
total += labels.size(0)
accuracy = correct / total * 100
print(f"Test Accuracy: {accuracy:.2f}%")
return accuracy
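A minimal usage sketch for the classifier above; the CSV path, epoch count, and example SMILES are placeholders rather than values from the original experiments:
# Hypothetical driver (the CSV must provide 'Canonical SMILES' and 'Targeting-protein' columns)
model, accuracy_history = train_model("inhibitor_data.csv", epochs=20, batch_size=16)
evaluate_model(model, "inhibitor_data.csv")
predict_inhibition("CC(=O)Oc1ccccc1C(=O)O", model_path="model_weights_epoch_19.pth")  # aspirin as an example input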
Inhibitor Dataset:
import pandas as pd
import torch
import torch_geometric
from torch_geometric.data import Dataset, Data
import numpy as np
import os
from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem import AllChem
from torch_geometric.utils import dense_to_sparse
print(f"Torch version: {torch.__version__}")
print(f"Cuda available: {torch.cuda.is_available()}")
print(f"Torch geometric version: {torch_geometric.__version__}")
class MoleculeDataset(Dataset):
def __init__(self, root, filename, test=False, transform=None, pre_transform=None):
self.test = test
self.filename = filename
super(MoleculeDataset, self).__init__(root, transform, pre_transform)
@property
def raw_file_names(self):
return [self.filename]
@property
def processed_file_names(self):
self.data = pd.read_csv(self.raw_paths[0]).reset_index()
if self.test:
return [f'data_test_{i}.pt' for i in list(self.data.index)]
else:
return [f'data_{i}.pt' for i in list(self.data.index)]
def download(self):
pass
def process(self):
self.data = pd.read_csv(self.raw_paths[0]).reset_index()
for index, row in tqdm(self.data.iterrows(), total=self.data.shape[0]):
mol = Chem.MolFromSmiles(row["Canonical SMILES"])
if mol is None:
continue
mol = Chem.AddHs(mol)
AllChem.EmbedMolecule(mol, randomSeed=42)
AllChem.UFFOptimizeMolecule(mol)
# Featurize molecule
node_features = self.get_node_features(mol)
edge_index, edge_attr = self.get_edge_info(mol)
data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr)
data.y = self._get_label(row["Targeting-protein"])
data.smiles = row["Canonical SMILES"]
self._validate_data(data)
if self.test:
torch.save(data, os.path.join(self.processed_dir, f'data_test_{index}.pt'))
else:
torch.save(data, os.path.join(self.processed_dir, f'data_{index}.pt'))
def get_node_features(self, mol):
all_node_feats = []
for atom in mol.GetAtoms():
node_feats = [
atom.GetAtomicNum(),
atom.GetTotalValence(),
# Gasteiger charge when available (the element symbol is omitted: a string cannot be stored in the float tensor below)
atom.GetDoubleProp("_GasteigerCharge") if atom.HasProp("_GasteigerCharge") else 0.0,
int(atom.GetHybridization()),
atom.GetFormalCharge(),
int(atom.IsInRing())
]
print(node_feats)
all_node_feats.append(node_feats)
return torch.tensor(all_node_feats, dtype=torch.float)
def get_edge_info(self, mol):
edge_indices = []
edge_feats = []
for bond in mol.GetBonds():
i = bond.GetBeginAtomIdx()
j = bond.GetEndAtomIdx()
# Store both directions so that edge_attr rows stay aligned with edge_index columns
edge_indices += [[i, j], [j, i]]
edge_feats += [[bond.GetBondTypeAsDouble()], [bond.GetBondTypeAsDouble()]]
edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
edge_attr = torch.tensor(edge_feats, dtype=torch.float)
return edge_index, edge_attr
def _get_label(self, label):
label = np.asarray([label])
return torch.tensor(label, dtype=torch.int64)
def _validate_data(self, data):
node_feature_dim = data.x.shape[1]
if not hasattr(self, 'node_feature_dim'):
self.node_feature_dim = node_feature_dim
if node_feature_dim != self.node_feature_dim:
raise ValueError(f"Inconsistent node feature dimensions: expected {self.node_feature_dim}, got {node_feature_dim}")
def len(self):
return self.data.shape[0]
def get(self, idx):
if self.test:
data = torch.load(os.path.join(self.processed_dir, f'data_test_{idx}.pt'))
else:
data = torch.load(os.path.join(self.processed_dir, f'data_{idx}.pt'))
return data