#!/usr/bin/env python3
"""
syndirella.fairy.py
This module provides functions to run retrosynthesis queries for a list of scaffold SMILES and save the results.
"""
import json
import logging
import os.path
from typing import Dict, List
import pandas as pd
from syndirella.Postera import Postera
from syndirella.error import APIQueryError
from .cli_defaults import cli_default_settings
[docs]
logger = logging.getLogger(__name__)

# Reaction SMARTS data, loaded once at import time from the path configured in
# the CLI defaults. JSON is UTF-8 by specification, so decode it explicitly
# instead of relying on the platform's default encoding.
with open(cli_default_settings['rxn_smarts_path'], encoding='utf-8') as f:
    reaction_smarts = json.load(f)

# Top-level keys of the loaded JSON, in file order (assumes the JSON top level
# is an object — presumably keyed by reaction name; confirm against the data file).
reaction_smarts_names: List[str] = list(reaction_smarts.keys())
[docs]
def save_df(df: pd.DataFrame, output_dir: str, csv_path: str) -> str:
    """
    Save the DataFrame to the output directory as a gzip-compressed pickle.

    The output file is named ``justretroquery_<stem>.pkl.gz``, where ``<stem>``
    is the basename of ``csv_path`` with a trailing ``.csv`` extension removed.

    :param df: DataFrame to save.
    :param output_dir: Directory to write the pickle into.
    :param csv_path: Path of the input CSV; only its basename is used to name the output.
    :return: Path of the saved pickle file.
    """
    csv_basename = os.path.basename(csv_path)
    # Strip only a trailing '.csv'. A plain str.replace would also rewrite any
    # '.csv' occurring mid-name, and would leave non-'.csv' inputs without the
    # '.pkl.gz' suffix that pandas uses to infer gzip compression.
    if csv_basename.endswith('.csv'):
        stem = csv_basename[:-len('.csv')]
    else:
        stem = csv_basename
    saved_path = os.path.join(output_dir, f'justretroquery_{stem}.pkl.gz')
    df.to_pickle(saved_path)
    return saved_path
[docs]
def retro_search(scaffold: str) -> pd.DataFrame:
    """
    Perform a retrosynthesis search on the given scaffold and format the outputs.

    :param scaffold: SMILES string of the scaffold to query.
    :return: Single-row DataFrame with a ``smiles`` column holding ``scaffold``,
        plus one column per key of the dict returned by ``format_routes``.
    :raises APIQueryError: If the Postera route search returns ``None``.
    """
    postera = Postera()
    routes: List[Dict[str, List[Dict[str, str]]]] | None = postera.perform_route_search(scaffold)
    if routes is None:
        logger.critical(f"API retrosynthesis query failed for {scaffold}.")
        raise APIQueryError(message=f"API retrosynthesis query failed for {scaffold}.", smiles=scaffold)
    # NOTE(review): format_routes is not defined or imported in this module as
    # shown — confirm where it is expected to come from.
    formatted_routes: Dict = format_routes(routes)
    # Include context; a bare integer in the log stream is meaningless.
    logger.info("Formatted %d route(s) for %s.", len(formatted_routes), scaffold)
    # 'smiles' is set first, so a route key named 'smiles' would override it
    # (same precedence as the original key-by-key copy).
    row = {'smiles': scaffold, **formatted_routes}
    return pd.DataFrame([row])  # one dictionary -> one row
[docs]
def process_df(df: pd.DataFrame) -> pd.DataFrame:
    """
    Process the input DataFrame and create an output DataFrame with retrosynthesis information.

    Each distinct value in the ``smiles`` column is queried once via
    ``retro_search``, and the resulting route columns are merged back onto
    every input row with that SMILES.

    :param df: Input DataFrame with a ``smiles`` column.
    :return: ``df`` merged with the route columns, re-indexed from 0.
        An empty input is returned unchanged (re-indexed).
    :raises APIQueryError: Propagated from ``retro_search`` when a query fails.
    """
    logger.info(f"Processing DataFrame with len {len(df)}")
    # Query each scaffold once: duplicated SMILES would otherwise produce
    # duplicated route rows, and the merge below would then multiply the
    # matching input rows (k duplicates -> k*k rows) and repeat API calls.
    unique_scaffolds = df['smiles'].drop_duplicates().tolist()
    if not unique_scaffolds:
        # pd.concat raises ValueError on an empty list; there is nothing to query.
        return df.reset_index(drop=True)
    route_info_df = pd.concat(
        [retro_search(scaffold) for scaffold in unique_scaffolds],
        ignore_index=True,
    )
    merged_df = df.merge(route_info_df, on='smiles')
    merged_df.reset_index(drop=True, inplace=True)
    return merged_df
#######################################
[docs]
def run_justretroquery(settings: Dict):
    """
    Run the justretroquery pipeline with the given settings.

    Reads the input CSV (which needs a ``smiles`` column), queries
    retrosynthesis routes for each scaffold via ``process_df`` and saves the
    result via ``save_df``.

    :param settings: Mapping with at least ``'input'`` (path to the input CSV)
        and ``'output'`` (directory the result pickle is written to).
    :raises KeyError: If ``'input'`` or ``'output'`` is missing from ``settings``.
    """
    try:
        csv_path: str = settings['input']
        output_dir: str = settings['output']
    except KeyError as e:
        # Chain the original lookup error so the traceback shows its cause.
        raise KeyError(f"Missing critical argument to run justretroquery: {e}") from e
    input_df: pd.DataFrame = pd.read_csv(csv_path)
    result_df: pd.DataFrame = process_df(input_df)
    saved_path: str = save_df(result_df, output_dir, csv_path)
    logger.info(f"Saved DataFrame to {saved_path}")
    logger.info('Justretroquery execution completed successfully.')