Source code for syndirella.Postera

#!/usr/bin/env python3
"""
Postera.py

This module contains the functionality for a Postera search.
"""
import json
import logging
import os
import random
import sys
import time
from typing import (Any, List, Dict, Tuple, Optional)

import requests
from rdkit import Chem

import syndirella.fairy as fairy
from syndirella.DatabaseSearch import DatabaseSearch


class Postera(DatabaseSearch):
    """
    Helpers for querying the Postera API.

    An instance stores the API base URL, the API key read from the ``MANIFOLD_API_KEY`` environment
    variable, and a module logger. The static helpers POST a JSON payload with retries
    (``get_resp_json``) and page through an endpoint (``get_search_results``); ``structure_output``
    turns raw hits into ``(smiles, catalogue)`` tuples.

    NOTE(review): the original docstring says the search is run through ``perform_database_search`` and
    that results are stored as a .csv file; neither is implemented in this class body, so that behaviour
    presumably lives in ``DatabaseSearch`` or in callers -- confirm.
    """

    def __init__(self):
        super().__init__()
        self.url = "https://api.postera.ai"
        # Fails fast with KeyError when the environment variable is not set.
        self.api_key = os.environ["MANIFOLD_API_KEY"]
        self.logger = logging.getLogger(__name__)

    def structure_output(self,
                         hits: List[Dict] | None,
                         query_smiles: str,
                         keep_catalogue: bool = False) -> List[Tuple[str, Tuple[str, str] | None]] | None:
        """
        Formats output into a list of tuples with smiles and catalogue.

        :param hits: hit dictionaries as returned by the API, or None if the API call failed.
        :param query_smiles: SMILES of the query; used in log messages and as the fallback entry.
        :param keep_catalogue: if True, emit one tuple per catalogue entry of a hit, carrying
            ``(catalogName, catalogId)``; otherwise the second tuple element is None.
        :returns: None when ``hits`` is None, otherwise a list of ``(smiles, catalogue_or_None)``
            tuples. When no hits were supplied the list contains only ``(query_smiles, None)``.
        """
        if hits is None:
            # The message says "empty list" but callers receive None; the text is kept unchanged.
            self.logger.critical(
                f"Error with API output, returning empty list for superstructure search for {query_smiles}!")
            return None

        hits_info: List[Tuple[str, Tuple[str, str] | None]] = []
        for hit in hits:
            # Only look at catalogue entries when asked to; a missing key is treated like an empty list.
            entries = hit.get('catalogEntries') if keep_catalogue else None
            if isinstance(entries, list) and entries:
                hits_info.extend(
                    (hit['smiles'], (entry['catalogName'], entry['catalogId']))
                    for entry in entries
                )
            else:
                hits_info.append((hit['smiles'], None))

        if not hits_info:
            self.logger.warning(
                f"No superstructures found for {query_smiles}, returning original query reactant, {query_smiles}.")
            # add query smiles just in case it's not returned
            hits_info.append((query_smiles, None))
        return hits_info

    @staticmethod
    def get_resp_json(url: str,
                      api_key: str,
                      data: Optional[Dict] = None,
                      retries: int = 50,
                      backoff_factor: float = 0.5) -> Optional[Dict]:
        """
        Directly get the response json from a request, with retry mechanism for handling 429 and 504
        status codes.

        :param url: full endpoint URL to POST to.
        :param api_key: value sent in the ``X-API-KEY`` header.
        :param data: payload, serialised with ``json.dumps`` and sent as the request body.
        :param retries: maximum number of attempts.
        :param backoff_factor: base of the exponential wait, ``backoff_factor * 2 ** attempt`` seconds.
        :returns: the decoded JSON body, or None if retries are exhausted or a request error occurs.
        """
        logger = logging.getLogger(__name__)
        headers = {
            'X-API-KEY': api_key,
            'Content-Type': 'application/json',
        }
        payload = json.dumps(data)

        for attempt in range(retries):
            try:
                response = requests.post(url, headers=headers, data=payload)
                if response.status_code in (429, 504):
                    if attempt >= retries - 1:
                        logger.error(
                            f"Max retries exceeded with status code {response.status_code}. Please try again later. "
                            f"{response.status_code}")
                        return None
                    # Exponential backoff; once it would exceed 3 minutes, wait a random time in
                    # [0, 180] seconds instead.
                    wait_time = backoff_factor * (2 ** attempt)
                    if wait_time > 180:
                        wait_time = random.uniform(0, 180)
                    error_type = "Rate limit exceeded" if response.status_code == 429 else "Gateway timeout"
                    logger.warning(f"{error_type}. Waiting for {wait_time} seconds before retrying...")
                    time.sleep(wait_time)
                    continue
                response.raise_for_status()
                return response.json()
            except requests.exceptions.HTTPError as err:
                logger.error(f"HTTP error: {err}")
            except requests.exceptions.ConnectionError as err:
                logger.error(f"Connection error: {err}")
            except requests.exceptions.Timeout as err:
                logger.error(f"Timeout error: {err}")
            except requests.exceptions.RequestException as err:
                logger.error(f"Error: {err}")
            break  # Exit the loop on non-recoverable errors
        return None

    @staticmethod
    def get_search_results(url: str,
                           api_key: str,
                           data: Dict[str, Any],
                           page: int = 1,
                           max_pages: int = 900) -> List[Dict] | None:
        """
        Recursively get all pages for the endpoint until reach null next page or the max_pages threshold.
        The default max_pages is set to 900 since the recursion limit is 1000.

        :param url: full endpoint URL to POST to.
        :param api_key: value sent in the ``X-API-KEY`` header.
        :param data: request payload; a ``'page'`` key is added or overwritten on a copy.
        :param page: page number to request.
        :param max_pages: page number at which to stop following ``nextPage``.
        :returns: the hits of this page and all following pages, or None if any page request failed.
        :raises ValueError: if ``max_pages`` is within 100 of the interpreter recursion limit.
        """
        current_limit = sys.getrecursionlimit()
        if max_pages > (current_limit - 100):
            raise ValueError(f"max_pages ({max_pages}) too close to recursion limit ({current_limit})")

        data = {
            **data,
            'page': page,
        }
        response: Dict | None = Postera.get_resp_json(url, api_key, data)
        if response is None:
            # API error, returning None
            return None

        # Hits are found under 'results' or, failing that, under 'routes'.
        all_hits: List[Dict] = []
        if response.get('results') is not None:
            all_hits.extend(response['results'])
        elif response.get('routes') is not None:
            all_hits.extend(response['routes'])

        if page >= max_pages:
            return all_hits

        # Grab more hits if there is a next page supplied.
        next_page = response.get('nextPage', None)
        if next_page is not None:
            next_hits = Postera.get_search_results(
                url,
                api_key,
                data,
                next_page,
                max_pages
            )
            if next_hits is None:
                # A later page failed: report the API error the same way as a first-page failure
                # instead of crashing on ``extend(None)``.
                logging.getLogger(__name__).error(
                    f"API error while fetching page {next_page}; discarding partial results.")
                return None
            all_hits.extend(next_hits)
        return all_hits

    def make_query(self):
        """
        This function is used to make a query to the Postera database.

        Currently a no-op placeholder.
        """
        pass