class Controller:
    """
    A utility class for printing, logging, and controlling the status of the algorithm.

    The class generates and displays generation-specific statistics, appends
    basic statistics to ``statistics.csv``, and writes the archive and the
    molecule memory to CSV files in the current working directory.

    Attributes:
        archive: An archive object containing elite molecules and associated data.
        surrogate: A surrogate model object used in the optimization process.
        generation: The current generation of the optimization algorithm.
        fitness_calls: The number of fitness function calls made so far.
        memory_of_molecules: A list to store molecules across generations.
        max_generations: The maximum number of generations for the optimization process.
        max_fitness_calls: The maximum number of fitness function calls allowed.
        remaining_fitness_calls: The remaining number of fitness function calls.

    Methods:
        set_archive(archive): Sets the archive object for the controller.
        active(): Checks if the optimization process is still active.
        update(): Updates the controller state and archives statistics.
        add_fitness_calls(fitness_calls): Adds to the total number of fitness function calls.
        write_statistics(statistics, metrics): Prints archive statistics to the console.
        store_statistics(statistics, metrics): Appends archive statistics to a CSV file.
        calculate_statistics(archive_data): Calculates various statistics from archive data.
        calculate_surrogate_metrics(molecules): Calculates surrogate model metrics.
        get_archive_data(): Retrieves elite molecule attributes and creates a DataFrame.
        store_molecules(): Stores all molecules in memory to a CSV file.
    """

    def __init__(self, config) -> None:
        """
        Initializes a Controller object with the given configuration.

        Args:
            config: Configuration object providing ``max_generations`` and
                ``max_fitness_calls``.
        """
        self.archive = None
        self.surrogate = None
        self.generation = 0
        self.fitness_calls = 0
        self.memory_of_molecules = []
        self.max_generations = config.max_generations
        self.max_fitness_calls = config.max_fitness_calls
        self.remaining_fitness_calls = self.max_fitness_calls

    def set_archive(self, archive) -> None:
        """
        Sets the archive object for the controller.

        Args:
            archive: An archive object containing elite molecules and associated data.
        """
        self.archive = archive

    def active(self) -> bool:
        """
        Checks if the optimization process is still active.

        Returns:
            bool: True while both the generation budget and the fitness-call
            budget are not yet exhausted, False otherwise.
        """
        return (
            self.generation < self.max_generations
            and self.fitness_calls < self.max_fitness_calls
        )

    def update(self) -> None:
        """
        Updates the controller state and archives statistics.

        Prints and stores the statistics of the current generation, writes the
        archive data to ``archive_<generation>.csv``, and then advances the
        generation counter by one.
        """
        archive_data = self.get_archive_data()
        molecules = self.archive.incoming_molecules
        archive_statistics = self.calculate_statistics(archive_data)
        surrogate_metrics = self.calculate_surrogate_metrics(molecules)
        self.write_statistics(archive_statistics, surrogate_metrics)
        self.store_statistics(archive_statistics, surrogate_metrics)
        pd.DataFrame(data=archive_data).to_csv(
            "archive_{}.csv".format(self.generation), index=False
        )
        self.generation += 1

    def add_fitness_calls(self, fitness_calls: int) -> None:
        """
        Adds to the total number of fitness function calls.

        Args:
            fitness_calls: The number of fitness calls to add.
        """
        self.fitness_calls += fitness_calls
        self.remaining_fitness_calls = self.max_fitness_calls - self.fitness_calls

    def write_statistics(self, statistics: Dict, metrics: Dict) -> None:
        """
        Prints statistics about the archive to the console.

        Args:
            statistics: A dictionary as returned by ``calculate_statistics``.
            metrics: A dictionary as returned by ``calculate_surrogate_metrics``.
        """
        print(
            "Generation: {}, Size: {:.2f}%, QD Score: {:.2f}".format(
                self.generation,
                statistics["coverage"] * 100,
                statistics["quality_diversity_score"],
            )
        )
        print(
            "Fitness Max: {:.5f}, Fitness Mean: {:.5f}, Function Calls: {:.0f}".format(
                statistics["max_fitness"],
                statistics["mean_fitness"],
                self.fitness_calls,
            )
        )
        print(
            "Surrogate model overview | Max Error: {:.2f}, MSE: {:.4f}, MAE: {:.4f}".format(
                metrics["max_err"], metrics["mse"], metrics["mae"]
            )
        )

    def store_statistics(self, statistics: Dict, metrics: Dict) -> None:
        """
        Appends basic archive statistics to ``statistics.csv`` on disk.

        The header row is written only when the file does not exist yet.

        Args:
            statistics: A dictionary as returned by ``calculate_statistics``.
            metrics: A dictionary as returned by ``calculate_surrogate_metrics``.
        """
        header = [
            "generation",
            "maximum fitness",
            "mean fitness",
            "quality diversity score",
            "coverage",
            "function calls",
            "max_err",
            "mse",
            "mae",
        ]
        row = [
            self.generation,
            statistics["max_fitness"],
            statistics["mean_fitness"],
            statistics["quality_diversity_score"],
            statistics["coverage"] * 100,
            self.fitness_calls,
            metrics["max_err"],
            metrics["mse"],
            metrics["mae"],
        ]
        write_header = not os.path.isfile("statistics.csv")
        # Append mode creates the file when it is missing. newline="" lets the
        # csv module control line endings itself, as its documentation
        # recommends (avoids doubled carriage returns on Windows).
        with open("statistics.csv", "a", newline="") as file:
            writer = csv.writer(file)
            if write_header:
                writer.writerow(header)
            writer.writerow(row)

    def calculate_statistics(self, archive_data) -> Dict:
        """
        Calculates various statistics from the provided archive data.

        Args:
            archive_data: A DataFrame containing data about the elite molecules
                of the archive, with at least ``smiles`` and ``fitness`` columns.

        Returns:
            Dict: ``coverage`` (fraction of the archive that is filled),
            ``max_fitness``, ``mean_fitness`` and ``quality_diversity_score``
            (sum of the elite fitness values). When the archive data is empty,
            coverage and quality-diversity score are 0.0 and the fitness
            statistics are NaN.
        """
        if len(archive_data) == 0:
            # No elites yet: an empty frame has no "smiles"/"fitness" columns.
            return {
                "coverage": 0.0,
                "max_fitness": np.nan,
                "mean_fitness": np.nan,
                "quality_diversity_score": 0.0,
            }
        fitness = archive_data["fitness"]
        return {
            "coverage": len(archive_data["smiles"]) / self.archive.archive_size,
            "max_fitness": np.max(fitness),
            "mean_fitness": np.mean(fitness),
            "quality_diversity_score": np.sum(fitness),
        }

    def calculate_surrogate_metrics(self, molecules: List["Molecule"]) -> Dict:
        """
        Calculates metrics for the surrogate model based on the provided molecules.

        As a side effect, the archive's incoming molecules are appended to
        ``memory_of_molecules``.

        Args:
            molecules: A list of Molecule objects exposing ``fitness`` and
                ``predicted_fitness``.

        Returns:
            Dict: ``max_err``, ``mae`` and ``mse`` between actual and predicted
            fitness. All three are NaN in generation 0 and when ``molecules``
            is empty.
        """
        self.memory_of_molecules.extend(self.archive.incoming_molecules)
        if self.generation == 0 or not molecules:
            # Generation 0 is reported without surrogate errors, and the error
            # metrics cannot be computed on empty input.
            return {"max_err": np.nan, "mae": np.nan, "mse": np.nan}
        fitnesses = np.array([molecule.fitness for molecule in molecules])
        predicted_fitnesses = np.array(
            [molecule.predicted_fitness for molecule in molecules]
        )
        return {
            "max_err": max_error(fitnesses, predicted_fitnesses),
            "mae": mean_absolute_error(fitnesses, predicted_fitnesses),
            "mse": mean_squared_error(fitnesses, predicted_fitnesses),
        }

    def get_archive_data(self) -> pd.DataFrame:
        """
        Retrieves elite molecule attributes and creates a DataFrame.

        Returns:
            pd.DataFrame: One row per occupied niche, one column per
            non-callable, non-dunder attribute of the elite molecule.
        """
        return pd.DataFrame(
            [
                self._molecule_attributes(elite.molecule)
                for elite in self.archive.elites
                if elite.molecule
            ]
        )

    def store_molecules(self) -> None:
        """
        Stores all molecules in memory to ``molecules.csv``.
        """
        molecule_df = pd.DataFrame(
            [
                self._molecule_attributes(molecule)
                for molecule in self.memory_of_molecules
            ]
        )
        molecule_df.to_csv("molecules.csv", index=False)

    @staticmethod
    def _molecule_attributes(molecule) -> Dict:
        """Returns the non-callable, non-dunder attributes of a molecule as a dict."""
        attributes = {}
        for name in dir(molecule):
            if name.startswith("__"):
                continue
            # Read each attribute once so properties are not evaluated twice.
            value = getattr(molecule, name)
            if not callable(value):
                attributes[name] = value
        return attributes
class Archive:
    """
    A composite class containing the current elite molecules, organised by CVT niches.

    This class allows for processing new molecules and sampling existing elite
    molecules. The CVT centers are either loaded from or saved to a cache file
    on disk.

    Attributes:
        archive_size: The size of the archive (number of niches).
        archive_accuracy: The accuracy setting for the archive; used as the
            number of random points the CVT centers are fitted on.
        archive_dimensions: The dimensionality of the archive.
        cache_string: The string for cache file naming.
        cvt_location: The location path for the CVT cache file.
        cvt_centers: The centers of the CVT clusters.
        cvt: A KDTree structure for CVT centers.
        elites: A list of Elite objects representing the archive.
        incoming_molecules: The molecules most recently passed to ``add_to_archive``.

    Methods:
        cvt_index(descriptor): Returns the CVT index for the niche nearest to the given descriptor.
        update_niche_index(molecule): Calculates and stores the niche index of a molecule.
        add_to_archive(molecules): Offers molecules to the elite of their nearest niche.
        sample(size): Returns a list of elite molecules of the requested size, weighted by fitness.
        sample_pairs(size): Returns a list of pairs of elite molecules of the requested size.
    """

    def __init__(self, config, archive_dimensions: int) -> None:
        """
        Initializes the Archive with the given configuration and dimensions.

        Args:
            config: Configuration object providing ``size`` and ``accuracy``.
            archive_dimensions: The dimensionality of the archive.
        """
        self.archive_size = config.size
        self.archive_accuracy = config.accuracy
        self.archive_dimensions = archive_dimensions
        self.cache_string = "cache_{}_{}.csv".format(
            self.archive_dimensions, self.archive_accuracy
        )
        self.cvt_location = hydra.utils.to_absolute_path(
            "data/cvt/" + self.cache_string
        )
        self.cvt_centers = self._load_or_create_centers()
        self.cvt = KDTree(self.cvt_centers, metric="euclidean")
        self.elites = [Elite(index) for index in range(len(self.cvt_centers))]
        self.incoming_molecules = []

    def _load_or_create_centers(self):
        """
        Loads the CVT centers from the cache file or fits and caches new ones.

        The cache file name encodes only the dimensions and the accuracy, so a
        cache written for a different archive size is detected by its shape
        and rebuilt instead of being used silently.

        Returns:
            A 2-D array of shape ``(archive_size, archive_dimensions)``.
        """
        expected_shape = (self.archive_size, self.archive_dimensions)
        if os.path.isfile(self.cvt_location):
            # ndmin=2 keeps one-column / one-row caches two-dimensional.
            centers = np.loadtxt(self.cvt_location, ndmin=2)
            if centers.shape == expected_shape:
                return centers
        kmeans = KMeans(n_clusters=self.archive_size)
        kmeans = kmeans.fit(
            np.random.rand(self.archive_accuracy, self.archive_dimensions)
        )
        centers = kmeans.cluster_centers_
        cache_directory = os.path.dirname(self.cvt_location)
        if cache_directory:
            # The cache directory may not exist on a fresh checkout.
            os.makedirs(cache_directory, exist_ok=True)
        np.savetxt(self.cvt_location, centers)
        return centers

    def cvt_index(self, descriptor: List[float]) -> int:
        """
        Returns the CVT index for the niche nearest to the given descriptor.

        Args:
            descriptor: A list of descriptor values for the molecule.

        Returns:
            int: The CVT index for the nearest niche.
        """
        # query returns (distances, indices); take the single nearest index.
        return int(self.cvt.query([descriptor], k=1)[1][0][0])

    def update_niche_index(self, molecule: "Molecule") -> "Molecule":
        """
        Calculates and stores the niche index of a molecule in the molecule object.

        Args:
            molecule: The molecule for which to calculate the niche index.

        Returns:
            Molecule: The same molecule with ``niche_index`` set.
        """
        molecule.niche_index = self.cvt_index(molecule.descriptor)
        return molecule

    def add_to_archive(self, molecules: List["Molecule"]) -> None:
        """
        Offers each molecule to the elite of its nearest niche and records the
        batch as ``incoming_molecules``.

        Args:
            molecules: A list of molecules to be added to the archive.
        """
        for molecule in molecules:
            self.elites[self.cvt_index(molecule.descriptor)].update(molecule)
        self.incoming_molecules = molecules

    def sample(self, size: int) -> List["Molecule"]:
        """
        Returns a list of elite molecules of the requested size, weighted by fitness.

        Args:
            size: The number of elite molecules to sample.

        Returns:
            List[Molecule]: Elite molecules sampled with replacement.

        Raises:
            ValueError: If the archive holds no elite molecules.
        """
        molecules, weights = self._weighted_elites()
        return random.choices(molecules, k=size, weights=weights)

    def sample_pairs(self, size: int) -> List[Tuple["Molecule", "Molecule"]]:
        """
        Returns a list of pairs of elite molecules of the requested size.

        A pool of ``size`` elites is first drawn weighted by fitness; both
        members of every pair are then drawn uniformly, with replacement,
        from that pool.

        Args:
            size: The number of pairs of elite molecules to sample.

        Returns:
            List[Tuple[Molecule, Molecule]]: A list of sampled pairs of elite molecules.

        Raises:
            ValueError: If the archive holds no elite molecules.
        """
        molecules, weights = self._weighted_elites()
        pool = random.choices(molecules, k=size, weights=weights)
        if not pool:
            return []
        # Draw indices rather than the objects themselves so numpy never has
        # to coerce arbitrary molecule objects into an array.
        indices = np.random.randint(len(pool), size=(size, 2))
        return [(pool[first], pool[second]) for first, second in indices]

    def _weighted_elites(self):
        """Returns the molecules of all occupied niches and their fitness values as weights."""
        molecules = [elite.molecule for elite in self.elites if elite.molecule]
        if not molecules:
            raise ValueError("Cannot sample from an archive without elite molecules.")
        weights = [molecule.fitness for molecule in molecules]
        return molecules, weights
class Arbiter:
    """
    A catalog class containing different drug-like filters for small molecules.

    This class includes the option to run the structural filters listed in
    ``data/smarts/alert_collection.csv``.

    Attributes:
        cache_smiles: A list of the SMILES strings seen so far, used for duplication checks.
        rules_dict: A DataFrame containing the filter rules selected by ``config.rules``.
        rules_list: A list of SMARTS strings for the filter rules.
        tolerance_list: A list of tolerance values (maximum match counts) for the filter rules.
        pattern_list: A list of RDKit molecule patterns for the filter rules.

    Methods:
        __call__(molecules): Applies the chosen filters to a list of molecules and removes duplicates.
        unique_molecules(molecules): Removes molecules whose SMILES was seen before.
        molecule_filter(molecular_graph): Checks if a molecular structure passes the chosen filters.
        toxicity(molecular_graph): Checks if a given molecule fails the structural filters.
        hologenicity(molecular_graph): Checks if a given molecule fails the hologenicity filters.
        ring_infraction(molecular_graph): Checks if a given molecule fails the ring infraction filters.
        veber_infraction(molecular_graph): Checks if a given molecule fails the veber infraction filters.
    """

    # Compiled query molecules for the constant SMARTS strings used by the
    # filters, shared by all instances so each string is parsed only once.
    _SMARTS_CACHE = {}

    # (SMARTS, maximum allowed number of matches) per halogen.
    _HALOGEN_LIMITS = (("[F]", 6), ("[Br]", 3), ("[Cl]", 3))

    def __init__(self, config) -> None:
        """
        Initializes the Arbiter with the given configuration.

        Args:
            config: Configuration object whose ``rules`` selects the rule sets
                (by ``rule_set_name``) to load from the alert collection.
        """
        self.cache_smiles = []
        self.rules_dict = pd.read_csv(
            hydra.utils.to_absolute_path("data/smarts/alert_collection.csv")
        )
        self.rules_dict = self.rules_dict[
            self.rules_dict.rule_set_name.isin(config.rules)
        ]
        self.rules_list = self.rules_dict["smarts"].values.tolist()
        self.tolerance_list = pd.to_numeric(self.rules_dict["max"]).values.tolist()
        self.pattern_list = [
            Chem.MolFromSmarts(smarts) for smarts in self.rules_list
        ]

    def __call__(self, molecules: List["Molecule"]) -> List["Molecule"]:
        """
        Applies the chosen filters (hologenicity, veber infractions, structural
        alerts, etc.) to a list of molecules and removes duplicates.

        Molecules whose SMILES cannot be parsed are dropped.

        Args:
            molecules: A list of molecules to be filtered.

        Returns:
            List[Molecule]: A list of filtered molecules.
        """
        filtered_molecules = []
        for molecule in self.unique_molecules(molecules):
            molecular_graph = Chem.MolFromSmiles(molecule.smiles)
            # MolFromSmiles returns None for SMILES it cannot parse; such a
            # molecule cannot be checked and is rejected.
            if molecular_graph is not None and self.molecule_filter(molecular_graph):
                filtered_molecules.append(molecule)
        return filtered_molecules

    def unique_molecules(self, molecules: List["Molecule"]) -> List["Molecule"]:
        """
        Removes molecules that are duplicated, either in this batch or before.

        The SMILES of every molecule that is kept is appended to ``cache_smiles``.

        Args:
            molecules: A list of molecules to check for duplicates.

        Returns:
            List[Molecule]: A list of unique molecules.
        """
        # One set per call gives constant-time lookups while the cache itself
        # stays a list.
        seen = set(self.cache_smiles)
        unique_molecules = []
        for molecule in molecules:
            smiles = molecule.smiles
            if smiles not in seen:
                seen.add(smiles)
                self.cache_smiles.append(smiles)
                unique_molecules.append(molecule)
        return unique_molecules

    def molecule_filter(self, molecular_graph: "Chem.Mol") -> bool:
        """
        Checks if a given molecular structure passes through the chosen filters
        (hologenicity, veber infractions, structural alerts, etc.).

        Args:
            molecular_graph: The molecular graph to be checked.

        Returns:
            bool: True if the molecule passes all filters, False otherwise.
        """
        if (
            self.toxicity(molecular_graph)
            or self.hologenicity(molecular_graph)
            or self.veber_infraction(molecular_graph)
        ):
            return False
        # Ring checks are only relevant when the molecule has ring atoms.
        if molecular_graph.HasSubstructMatch(Arbiter._pattern("[R]")):
            return not self.ring_infraction(molecular_graph)
        return True

    def toxicity(self, molecular_graph: "Chem.Mol") -> bool:
        """
        Checks if a given molecule fails the structural filters.

        Args:
            molecular_graph: The molecular graph to be checked.

        Returns:
            bool: True if any rule pattern matches more often than its
            tolerance allows, False otherwise.
        """
        return any(
            len(molecular_graph.GetSubstructMatches(pattern)) > tolerance
            for pattern, tolerance in zip(self.pattern_list, self.tolerance_list)
        )

    @staticmethod
    def hologenicity(molecular_graph: "Chem.Mol") -> bool:
        """
        Checks if a given molecule fails the hologenicity filters.

        Args:
            molecular_graph: The molecular graph to be checked.

        Returns:
            bool: True if the molecule has more than 6 fluorine, more than 3
            bromine or more than 3 chlorine atoms, False otherwise.
        """
        return any(
            len(molecular_graph.GetSubstructMatches(Arbiter._pattern(smarts))) > limit
            for smarts, limit in Arbiter._HALOGEN_LIMITS
        )

    @staticmethod
    def ring_infraction(molecular_graph: "Chem.Mol") -> bool:
        """
        Checks if a given molecule fails the ring infraction filters.

        Args:
            molecular_graph: The molecular graph to be checked.

        Returns:
            bool: True if the molecule contains a ring allene, a ring of more
            than 6 atoms, or a double bond within a 3- or 4-membered ring,
            False otherwise (including for molecules without rings).
        """
        ring_allene = molecular_graph.HasSubstructMatch(
            Arbiter._pattern("[R]=[R]=[R]")
        )
        # default=0 keeps acyclic molecules from raising on an empty sequence.
        largest_ring = max(
            (len(ring) for ring in molecular_graph.GetRingInfo().AtomRings()),
            default=0,
        )
        macro_cycle = largest_ring > 6
        double_bond_in_small_ring = molecular_graph.HasSubstructMatch(
            Arbiter._pattern("[r3,r4]=[r3,r4]")
        )
        return ring_allene or macro_cycle or double_bond_in_small_ring

    @staticmethod
    def veber_infraction(molecular_graph: "Chem.Mol") -> bool:
        """
        Checks if a given molecule fails the veber infraction filters.

        Args:
            molecular_graph: The molecular graph to be checked.

        Returns:
            bool: True if the molecule has more than 10 rotatable bonds or more
            than 10 hydrogen-bond acceptors and donors combined, False otherwise.
        """
        rotatable_bond_saturation = Lipinski.NumRotatableBonds(molecular_graph) > 10
        hydrogen_bond_saturation = (
            Lipinski.NumHAcceptors(molecular_graph)
            + Lipinski.NumHDonors(molecular_graph)
            > 10
        )
        return rotatable_bond_saturation or hydrogen_bond_saturation

    @staticmethod
    def _pattern(smarts: str):
        """Returns the compiled query molecule for a SMARTS string, compiling it on first use."""
        pattern = Arbiter._SMARTS_CACHE.get(smarts)
        if pattern is None:
            pattern = Chem.MolFromSmarts(smarts)
            Arbiter._SMARTS_CACHE[smarts] = pattern
        return pattern