# Source code for syndirella.slipper.slipper_synthesizer.Labeler

#!venv/bin/env python3
"""
slipper_synthesizer/Labeler.py

This module contains a helper class for the SlipperSynthesizer class that labels the products according to which
labeled scaffold atom ids should, or should not, be expanded.
"""
import os
from rdkit import Chem
from rdkit.Chem import rdFMCS
from typing import (List, Dict, Tuple, Union, Optional)
import pandas as pd
from rdkit.Chem.Draw import rdMolDraw2D
from syndirella.route.Library import Library
import time

class Labeler:
    """
    Label each product by whether it grew on scaffold atoms marked to expand or not to expand.

    The scaffold is read from ``library.reaction.scaffold`` and every product is read from the
    ``smiles`` column of ``products``. For each atom id in ``atom_ids_expansion`` a column
    ``expanded_on_atom_<id>`` is added, plus the summary columns ``good_expansion`` and
    ``bad_expansion``.
    """

    def __init__(self, products: pd.DataFrame, atom_ids_expansion: Dict[int, bool], library: "Library"):
        """
        :param products: DataFrame of products; must provide a ``smiles`` column.
        :param atom_ids_expansion: Maps a scaffold atom index to ``True`` (expand here) or ``False``
            (do not expand here). Entries with any other value are ignored.
        :param library: Library providing ``reaction.scaffold`` and ``output_dir``.
        """
        self.products: pd.DataFrame = products
        self.atom_ids_expansion: Dict[int, bool] = atom_ids_expansion
        self.library: "Library" = library
        self.output_dir: str = library.output_dir

    def label_products(self) -> pd.DataFrame:
        """
        This is the main entry function for the Labeler class.

        :return: The products DataFrame with the expansion label columns added.
        """
        self.show_atoms_to_expand_and_not_expand()  # save svg of colored atoms
        self.show_mcs_on_products()  # placeholder, not implemented yet
        self.label_products_with_atom_ids()
        return self.products

    def show_atoms_to_expand_and_not_expand(self) -> None:
        """
        Save an SVG of the scaffold compound (``base_expansion.svg`` in ``output_dir``) with atoms to
        expand colored in green and atoms not to expand colored in red.

        :raises AssertionError: If an atom id in ``atom_ids_expansion`` is not an atom index of the scaffold.
        """
        # Annotate a copy: the atom-map labels are only wanted in the picture and must not leak
        # into the scaffold object that the rest of the library keeps using.
        base_compound = Chem.Mol(self.library.reaction.scaffold)

        # label atom ids, collecting the valid indices once for the membership check below
        scaffold_atom_ids = set()
        for atom in base_compound.GetAtoms():
            atom_index = atom.GetIdx()
            atom.SetProp("molAtomMapNumber", str(atom_index))
            scaffold_atom_ids.add(atom_index)

        # assert that atom ids are in the scaffold compound
        assert all(atom_id in scaffold_atom_ids for atom_id in self.atom_ids_expansion), \
            "Atom ids to expand are not in the scaffold compound."

        # only atoms explicitly flagged True/False are highlighted
        atom_ids: List[int] = [atom_id for atom_id, flag in self.atom_ids_expansion.items()
                               if flag is True or flag is False]
        green = (0, 1, 0)
        red = (1, 0, 0)
        atom_colors: Dict[int, Tuple[int, int, int]] = {
            atom_id: green if self.atom_ids_expansion[atom_id] is True else red
            for atom_id in atom_ids
        }

        drawer = rdMolDraw2D.MolDraw2DSVG(400, 400)
        drawer.DrawMolecule(base_compound, highlightAtoms=atom_ids, highlightAtomColors=atom_colors)
        drawer.FinishDrawing()
        svg = drawer.GetDrawingText()
        with open(os.path.join(self.output_dir, "base_expansion.svg"), "w") as f:
            f.write(svg)

    def label_products_with_atom_ids(self) -> None:
        """
        Label every row of ``self.products`` with the atom ids to expand and not expand.

        ``self.products`` is replaced by the labeled DataFrame and the elapsed time is printed.
        """
        # get the atom ids to expand and not expand
        atom_ids_to_expand: List[int] = [atom_id for atom_id, flag in self.atom_ids_expansion.items()
                                         if flag is True]
        atom_ids_to_not_expand: List[int] = [atom_id for atom_id, flag in self.atom_ids_expansion.items()
                                             if flag is False]

        # time how long this takes (perf_counter is monotonic, unlike the wall clock)
        start = time.perf_counter()
        self.products = self.products.apply(
            lambda row: self._label_product_with_atom_ids(atom_ids_to_expand, atom_ids_to_not_expand, row),
            axis=1)
        elapsed_time = time.perf_counter() - start

        # Convert elapsed time to hours, minutes, and seconds
        hours, remainder = divmod(elapsed_time, 3600)
        minutes, seconds = divmod(remainder, 60)
        print(f"The function took {int(hours)} hours, {int(minutes)} minutes, "
              f"and {seconds:.2f} seconds to complete.")

    def _label_product_with_atom_ids(self,
                                     atom_ids_to_expand: List[int],
                                     atom_ids_to_not_expand: List[int],
                                     row: pd.Series) -> pd.Series:
        """
        Label one product row with where it grew relative to the scaffold.

        Adds ``expanded_on_atom_<id>`` for every listed atom id, then ``good_expansion`` and
        ``bad_expansion``:

        * both ``'non_mcs_match'`` if any atom could not be compared,
        * good ``True`` / bad ``False`` if it grew on an atom to expand and on no atom not to expand,
        * good ``False`` / bad ``True`` if it grew on any atom not to expand,
        * both ``None`` if it grew on none of the listed atoms.

        :param atom_ids_to_expand: Scaffold atom ids where growth is wanted.
        :param atom_ids_to_not_expand: Scaffold atom ids where growth is not wanted.
        :param row: Product row; its ``smiles`` entry is read and the label entries are written.
        :return: The same row with the label entries added.
        """
        results_atom_ids_to_expand: Dict[int, Union[bool, str]] = {}
        results_atom_ids_to_not_expand: Dict[int, Union[bool, str]] = {}

        for atom_ids, results in ((atom_ids_to_expand, results_atom_ids_to_expand),
                                  (atom_ids_to_not_expand, results_atom_ids_to_not_expand)):
            for atom_id in atom_ids:
                results[atom_id] = self._has_non_mcs_bond(row["smiles"],
                                                          self.library.reaction.scaffold,
                                                          atom_id)
                row[f"expanded_on_atom_{atom_id}"] = results[atom_id]

        all_results = list(results_atom_ids_to_expand.values()) + list(results_atom_ids_to_not_expand.values())
        if any(value == 'non_mcs_match' for value in all_results):
            # at least one atom could not be mapped, so no verdict is possible
            row["bad_expansion"] = 'non_mcs_match'
            row["good_expansion"] = 'non_mcs_match'
            return row

        grew_on_wanted = any(results_atom_ids_to_expand.values())
        grew_on_unwanted = any(results_atom_ids_to_not_expand.values())
        if grew_on_unwanted:
            # expanding on any bad atom --> bad, even if a good atom was expanded too
            row["bad_expansion"] = True
            row["good_expansion"] = False
        elif grew_on_wanted:
            # expanding on a good atom and on no bad atom --> good
            row["bad_expansion"] = False
            row["good_expansion"] = True
        else:
            # no expansion on any listed atom --> indeterminate --> None
            row["bad_expansion"] = None
            row["good_expansion"] = None
        return row

    def _has_non_mcs_bond(self, mol_smiles: str, ref_mol: "Chem.Mol", atom_id_ref: int) -> Union[bool, str]:
        """
        Check if the atom of the given molecule that corresponds to a reference atom has a bond
        that is not in the MCS with the reference molecule.

        :param mol_smiles: SMILES string of the molecule to check
        :param ref_mol: Reference molecule (the scaffold)
        :param atom_id_ref: Atom index in the reference molecule to check
        :return: True if the atom has a non-MCS bond, False otherwise; the string
            ``'non_mcs_match'`` if the reference atom cannot be mapped onto the molecule
            (unparsable SMILES, MCS smaller or larger than the reference, or no match).
        """
        mol = Chem.MolFromSmiles(mol_smiles)
        if mol is None:
            # RDKit returns None for SMILES it cannot parse; nothing can be compared
            return 'non_mcs_match'

        # Find MCS between mol and ref_mol
        mcs_res = rdFMCS.FindMCS([mol, ref_mol])
        mcs_mol = Chem.MolFromSmarts(mcs_res.smartsString)

        # check that mcs matches the ref_mol exactly so we are accurately comparing atom indices
        if mcs_mol is None or mcs_mol.GetNumAtoms() != ref_mol.GetNumAtoms():
            return 'non_mcs_match'

        # reference atom index -> MCS atom index
        ref_match = {ref_idx: mcs_idx for mcs_idx, ref_idx in enumerate(ref_mol.GetSubstructMatch(mcs_mol))}
        # MCS atom index -> atom index in mol (tuple position is the MCS atom index)
        mcs_match = mol.GetSubstructMatch(mcs_mol)

        if atom_id_ref not in ref_match or not mcs_match:
            return 'non_mcs_match'  # corresponds to not found

        atom_id = mcs_match[ref_match[atom_id_ref]]

        # A bond whose either end lies outside the MCS is growth off the scaffold
        mcs_atom_ids = set(mcs_match)
        return any(
            bond.GetBeginAtomIdx() not in mcs_atom_ids or bond.GetEndAtomIdx() not in mcs_atom_ids
            for bond in mol.GetAtomWithIdx(atom_id).GetBonds()
        )

    def show_mcs_on_products(self):
        """
        Placeholder for outputting an image of the MCS on the products; not implemented.

        NOTE: this deliberately returns (does not raise) ``NotImplementedError`` because
        ``label_products`` calls it unconditionally and ignores the result; raising here
        would abort labeling.
        """
        return NotImplementedError