from abc import ABC, abstractmethod

import numpy as np
import selfies as sf
import torch
from botorch.fit import fit_gpytorch_model
from gauche.kernels.fingerprint_kernels.tanimoto_kernel import TanimotoKernel
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.models import ExactGP
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator
from sklearn.feature_extraction.text import CountVectorizer


class GP_Surrogate(ABC):
    """
    A strategy class representing a surrogate model for predicting fitness values of molecules.

    The surrogate model is a Gaussian Process (GP) regression model using the Tanimoto kernel
    (see ``TanimotoGP``).

    Attributes:
        config: The configuration object passed to the constructor.
        model: The Gaussian Process (GP) regression model, or None before the first fit.
        mll: The Exact Marginal Log-Likelihood (mll) associated with the GP model.
        state_dict: The state dictionary of the most recently fitted GP model; it is loaded
            into the next model before fitting.
        encodings: Numpy array storing the encodings of the prior (training) molecules.
        fitnesses: Numpy array storing the fitness values of the prior (training) molecules.

    Methods:
        __init__: Initializes the surrogate with its configuration and empty state.
        __call__: Refits the model and evaluates it on a set of molecules.
        update_model: Rebuilds and refits the GP on the stored encodings and fitnesses.
        inference: Writes the predicted fitness and uncertainty onto a set of molecules.
        add_to_prior_data: Abstract; adds molecules and their fitnesses to the training data.
        calculate_encodings: Abstract; calculates encodings of molecules for the GP.
    """

    def __init__(self, config) -> None:
        """
        Initializes the Surrogate object with the provided configuration.

        Args:
            config: An object specifying the configuration for the surrogate model.
        """
        self.config = config
        self.model = None
        self.mll = None
        self.state_dict = None
        self.encodings = None
        self.fitnesses = None

    def __call__(self, molecules):
        """
        Evaluates the surrogate model on a set of molecules, updating their predicted fitness
        and uncertainty.

        Args:
            molecules: A list of Molecule objects to be evaluated.

        Returns:
            List[Molecule]: The same molecules with updated predicted fitness and uncertainty.
        """
        # Encode first: a subclass's calculate_encodings may refresh self.encodings as a
        # side effect, and update_model reads self.encodings.
        encodings = self.calculate_encodings(molecules)
        self.update_model()
        return self.inference(molecules, encodings)

    def update_model(self) -> None:
        """
        Rebuilds the GP from the stored encodings and fitnesses and refits it.

        If a state dictionary from a previous fit exists, it is loaded into the new model
        before fitting. The state dictionary of the newly fitted model is then stored.
        """
        train_x = torch.tensor(self.encodings)
        train_y = torch.tensor(self.fitnesses).flatten()
        self.model = TanimotoGP(train_x, train_y)
        self.mll = ExactMarginalLogLikelihood(self.model.likelihood, self.model)
        if self.state_dict is not None:
            self.model.load_state_dict(self.state_dict)
        fit_gpytorch_model(self.mll)
        self.state_dict = self.model.state_dict()

    def inference(self, molecules, encodings):
        """
        Performs inference on a set of molecules to predict fitness and uncertainty.

        The model is called directly (not through its likelihood). Each molecule receives the
        predictive mean as ``predicted_fitness`` and the predictive variance as
        ``predicted_uncertainty``.

        Args:
            molecules: A list of Molecule objects to be evaluated.
            encodings: The encodings of the molecules, in the same order as ``molecules``.

        Returns:
            List[Molecule]: The same molecules with updated predicted fitness and uncertainty.
        """
        self.mll.eval()
        self.model.eval()
        # Only plain floats are extracted below, so no autograd graph is needed.
        # mean and variance are evaluated inside the context as well.
        with torch.no_grad():
            predictions = self.model(torch.tensor(encodings))
            means = predictions.mean
            variances = predictions.variance
        for molecule, mean, variance in zip(molecules, means, variances):
            molecule.predicted_fitness = mean.item()
            molecule.predicted_uncertainty = variance.item()
        return molecules

    @abstractmethod
    def add_to_prior_data(self, molecules):
        """
        Adds new molecules to the prior data for the surrogate model.

        Args:
            molecules: A list of Molecule objects to be added to the prior data.

        Returns:
            None
        """
        raise NotImplementedError

    @abstractmethod
    def calculate_encodings(self, molecules):
        """
        Calculates representations (encodings) of molecules for the surrogate model.

        Args:
            molecules: A list of Molecule objects to be encoded.

        Returns:
            Numpy array: Encodings of the molecules.
        """
        raise NotImplementedError


class TanimotoGP(ExactGP):
    """
    A Gaussian Process (GP) regression model using the Tanimoto kernel for molecular data.

    Attributes:
        mean_module: The mean function of the GP model (constant mean).
        covar_module: The covariance function of the GP model (scaled Tanimoto kernel).
        likelihood: The likelihood function of the GP model (Gaussian).

    Methods:
        __init__: Initializes the TanimotoGP model with training data.
        forward: Computes the mean and covariance of the GP at the given inputs.
    """

    def __init__(self, train_X, train_Y):
        """
        Initializes the TanimotoGP model with training data.

        Args:
            train_X: Tensor of training input data (molecular encodings).
            train_Y: Tensor of training output data (fitness values).
        """
        super().__init__(train_X, train_Y, likelihood=GaussianLikelihood())
        self.mean_module = ConstantMean()
        self.covar_module = ScaleKernel(base_kernel=TanimotoKernel())
        # Align the module's dtype/device with the training inputs.
        self.to(train_X)

    def forward(self, x):
        """
        Performs the forward pass of the GP model.

        Args:
            x: Tensor of input data for which the distribution is computed.

        Returns:
            MultivariateNormal: Distribution built from the mean and covariance at ``x``.
        """
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)
[docs]classFingerprint_Surrogate(GP_Surrogate):""" A surrogate model using molecular fingerprints for predicting fitness values. The surrogate model is based on Gaussian Processes (GP) regression. Attributes: representation: The type of molecular fingerprint used for encoding molecules. generator: The fingerprint generator corresponding to the chosen representation. Methods: __init__: Initializes the Fingerprint_Surrogate object with the specified fingerprint representation. calculate_encodings: Calculates fingerprint encodings for a list of molecules. add_to_prior_data: Adds new molecules and their fitness values to the training data for the GP model. """def__init__(self,config):""" Initializes the Fingerprint_Surrogate object with the specified fingerprint representation. Args: config: An object specifying the configuration for the surrogate model. """super().__init__(config)self.representation=self.config.representationmatchself.representation:case"ECFP4":self.generator=rdFingerprintGenerator.GetMorganGenerator(radius=2,fpSize=2048)case"ECFP6":self.generator=rdFingerprintGenerator.GetMorganGenerator(radius=3,fpSize=2048)case"FCFP4":self.generator=rdFingerprintGenerator.GetMorganGenerator(radius=2,fpSize=2048,atomInvariantsGenerator=rdFingerprintGenerator.GetMorganFeatureAtomInvGen())case"FCFP6":self.generator=rdFingerprintGenerator.GetMorganGenerator(radius=3,fpSize=2048,atomInvariantsGenerator=rdFingerprintGenerator.GetMorganFeatureAtomInvGen())case"RDFP":self.generator=rdFingerprintGenerator.GetRDKitFPGenerator(fpSize=2048)case"APFP":self.generator=rdFingerprintGenerator.GetAtomPairGenerator(fpSize=2048)case"TTFP":self.generator=rdFingerprintGenerator.GetTopologicalTorsionGenerator(fpSize=2048)case_:raiseValueError(f"{self.representation} is not a supported fingerprint type.")
[docs]defcalculate_encodings(self,molecules):""" Calculates fingerprint encodings for a list of molecules. Args: molecules: A list of molecules to be encoded. Returns: List[np.ndarray]: A list of fingerprint encodings. """molecular_graphs=[Chem.MolFromSmiles(Chem.CanonSmiles(molecule.smiles))formoleculeinmolecules]returnnp.array([self.generator.GetFingerprintAsNumPy(molecular_graph)formolecular_graphinmolecular_graphs]).astype(np.float64)
[docs]defadd_to_prior_data(self,molecules):""" Adds new molecules and their fitness values to the training data for the GP model. Args: molecules: A list of new molecules to be added to the training data. Returns: None """ifself.encodingsisnotNoneandself.fitnessesisnotNone:self.encodings=np.append(self.encodings,self.calculate_encodings(molecules),axis=0)self.fitnesses=np.append(self.fitnesses,np.array([molecule.fitnessformoleculeinmolecules]),axis=None,)else:self.encodings=self.calculate_encodings(molecules)self.fitnesses=np.array([molecule.fitnessformoleculeinmolecules])returnNone
class String_Surrogate(GP_Surrogate):
    """
    A surrogate model using molecular string representations (SMILES or SELFIES) for
    predicting fitness values.

    The surrogate model is based on Gaussian Processes (GP) regression.

    Attributes:
        smiles: A list of SMILES strings of the prior (training) molecules.
        representation: The type of molecular string representation used ("Smiles" or
            "Selfies").
        cv: A CountVectorizer that converts molecular strings into character n-gram counts.

    Methods:
        __init__: Initializes the String_Surrogate object with the specified molecular string
            representation.
        calculate_encodings: Calculates string encodings for a list of molecules.
        add_to_prior_data: Adds new molecules and their fitness values to the training data
            for the GP model.
    """

    def __init__(self, config):
        """
        Initializes the String_Surrogate object with the specified molecular string
        representation.

        Args:
            config: An object specifying the configuration for the surrogate model. Its
                ``representation`` attribute selects the string type and ``max_ngram`` the
                largest character n-gram counted.
        """
        super().__init__(config)
        self.smiles = []
        self.representation = self.config.representation
        self.cv = CountVectorizer(
            ngram_range=(1, self.config.max_ngram),
            analyzer="char",
            lowercase=False,
        )

    def calculate_encodings(self, molecules):
        """
        Calculates string encodings for a list of molecules.

        The vectorizer is refitted on the prior molecules together with the new ones, so both
        share one vocabulary. As a side effect, ``self.encodings`` is replaced by the
        encodings of the prior molecules under that vocabulary.

        Args:
            molecules: A list of molecules to be encoded; each must expose a ``smiles``
                attribute.

        Returns:
            np.ndarray: A 2D float64 array of string encodings, one row per new molecule.

        Raises:
            ValueError: If ``self.representation`` is neither "Smiles" nor "Selfies".
        """
        new_smiles = [molecule.smiles for molecule in molecules]
        combined_smiles = self.smiles + new_smiles
        if self.representation == "Smiles":
            strings = combined_smiles
        elif self.representation == "Selfies":
            strings = [sf.encoder(smiles_string) for smiles_string in combined_smiles]
        else:
            raise ValueError(f"{self.representation} is not a supported type of molecular string.")
        bag_of_characters = self.cv.fit_transform(strings)
        # Split at the prior/new boundary. A negative slice ([-len(new_smiles):]) would
        # return every row when no new molecules are given, because -0 == 0.
        n_prior = len(self.smiles)
        self.encodings = bag_of_characters[:n_prior].toarray().astype(np.float64)
        return bag_of_characters[n_prior:].toarray().astype(np.float64)

    def add_to_prior_data(self, molecules):
        """
        Adds new molecules and their fitness values to the training data for the GP model.

        Only the SMILES strings and fitnesses are stored here; the encodings of the prior
        molecules are recomputed by ``calculate_encodings``.

        Args:
            molecules: A list of new molecules to be added to the training data; each must
                expose ``smiles`` and ``fitness`` attributes.

        Returns:
            None
        """
        new_smiles = [molecule.smiles for molecule in molecules]
        new_fitnesses = np.array([molecule.fitness for molecule in molecules])
        if self.smiles is not None and self.fitnesses is not None:
            self.smiles = self.smiles + new_smiles
            self.fitnesses = np.append(self.fitnesses, new_fitnesses, axis=None)
        else:
            # First batch: the new data becomes the whole training set.
            self.smiles = new_smiles
            self.fitnesses = new_fitnesses
        return None