ESM-2 RNA Binding Site LoRA
This is a Parameter-Efficient Fine-Tuning (PEFT) Low-Rank Adaptation (LoRA) of the esm2_t12_35M_UR50D model for the binary token classification task of predicting RNA binding sites in proteins. You can also find a version of this model that was fine-tuned without LoRA here.
Training procedure
This is a Low-Rank Adaptation (LoRA) of esm2_t12_35M_UR50D, trained on 166 protein sequences in the RNA binding sites dataset using an 85/15 train/test split. The model was trained with class weighting due to the imbalanced nature of the RNA binding site dataset (far fewer binding-site residues than non-binding-site residues). It has slightly improved precision, recall, and F1 score over AmelieSchreiber/esm2_t12_35M_weighted_lora_rna_binding, but may suffer from mild overfitting, as indicated by the training loss being slightly lower than the eval loss (see metrics below). If you are searching for binding sites and aren't worried about false positives, the higher recall may make this model preferable to the other RNA binding site predictors.
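Because binding-site residues are rare, the class weighting mentioned above is typically implemented by overriding the loss inside a custom Hugging Face Trainer. The sketch below illustrates one way to do this; the weight values are placeholder assumptions, not the exact weights used to train this model.

import torch
from torch import nn
from transformers import Trainer

class WeightedLossTrainer(Trainer):
    """Trainer that applies per-class weights to the token classification loss."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # Placeholder weights: up-weight the rare "binding site" class (label 1).
        class_weights = torch.tensor([0.3, 1.7], device=logits.device)
        loss_fct = nn.CrossEntropyLoss(weight=class_weights, ignore_index=-100)
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return (loss, outputs) if return_outputs else loss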
You can train your own version using this notebook! You just need the RNA binding_sites.xml file found here. You may also need to run some pip install statements at the beginning of the script. If you are running in Colab, run:
!pip install transformers[torch] datasets peft -q
!pip install accelerate -U -q
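For orientation, here is a minimal sketch of how a LoRA adapter for token classification can be set up with peft before training. The rank, alpha, dropout, and target modules below are illustrative assumptions, not the exact hyperparameters used for this checkpoint; the notebook linked above contains the actual training setup.

from transformers import AutoModelForTokenClassification
from peft import LoraConfig, get_peft_model

# Base ESM-2 model with a binary token classification head.
base_model = AutoModelForTokenClassification.from_pretrained(
    "facebook/esm2_t12_35M_UR50D", num_labels=2
)

# Illustrative LoRA hyperparameters -- tune these for your own run.
lora_config = LoraConfig(
    task_type="TOKEN_CLS",
    r=16,                                       # assumed rank
    lora_alpha=16,                              # assumed scaling factor
    lora_dropout=0.1,                           # assumed dropout
    target_modules=["query", "key", "value"],   # assumed attention projections
)

model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()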
Try to improve upon these metrics by adjusting the hyperparameters:
{'eval_loss': 0.500779926776886,
'eval_precision': 0.1708695652173913,
'eval_recall': 0.8397435897435898,
'eval_f1': 0.2839595375722543,
'eval_auc': 0.771835775620126,
'epoch': 11.0}
{'loss': 0.4171,
'learning_rate': 0.00032491416877500004,
'epoch': 11.43}
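To reproduce these metrics in your own run, a compute_metrics function along the following lines can be passed to the Trainer. This is a sketch using scikit-learn; it assumes padded positions are labeled -100, which may differ from the tokenization used in the notebook.

import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # Stable softmax to get the probability of the positive ("binding site") class.
    exps = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs = exps / exps.sum(axis=-1, keepdims=True)
    preds = np.argmax(logits, axis=-1)

    # Flatten and drop padded positions (assumed to be labeled -100).
    mask = labels != -100
    labels_flat = labels[mask]
    preds_flat = preds[mask]
    probs_flat = probs[..., 1][mask]

    return {
        "precision": precision_score(labels_flat, preds_flat),
        "recall": recall_score(labels_flat, preds_flat),
        "f1": f1_score(labels_flat, preds_flat),
        "auc": roc_auc_score(labels_flat, probs_flat),
    }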
A similar model can also be trained using the GitHub repository, which includes a training script and a conda environment YAML and can be found here. That version uses wandb sweeps for hyperparameter search; however, it does not use class weighting.
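As an illustration of the sweep-based approach, a wandb sweep is defined by a small configuration passed to wandb.sweep. The search method, metric name, and parameter ranges below are assumptions for demonstration only; see the linked repository for the actual sweep configuration.

import wandb

# Hypothetical sweep configuration for a LoRA fine-tuning run.
sweep_config = {
    "method": "bayes",
    "metric": {"name": "eval/f1", "goal": "maximize"},
    "parameters": {
        "learning_rate": {"min": 1e-5, "max": 1e-3},
        "lora_r": {"values": [4, 8, 16]},
        "lora_alpha": {"values": [8, 16, 32]},
        "num_train_epochs": {"values": [5, 10, 15]},
    },
}

sweep_id = wandb.sweep(sweep_config, project="esm2-rna-binding-lora")
# wandb.agent(sweep_id, function=train)  # `train` would be your training function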
Framework versions
- PEFT 0.4.0
Using the Model
To use the model, try running the following pip install statements:
!pip install transformers peft -q
then try running:
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch
# Path to the saved LoRA model
model_path = "AmelieSchreiber/esm2_t12_35M_UR50D_RNA_LoRA_weighted"
# ESM2 base model
base_model_path = "facebook/esm2_t12_35M_UR50D"
# Load the model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)
# Ensure the model is in evaluation mode
loaded_model.eval()
# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)
# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT" # Replace with your actual sequence
# Tokenize the sequence
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')
# Run the model
with torch.no_grad():
    logits = loaded_model(**inputs).logits
# Get predictions
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
predictions = torch.argmax(logits, dim=2)
# Define labels
id2label = {
0: "No binding site",
1: "Binding site"
}
# Print the predicted labels for each token
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print((token, id2label[prediction]))
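If you want the positions of the predicted binding sites rather than a per-token printout, a small post-processing step like the following (a sketch continuing from the variables above) collects the 1-based residue indices predicted as binding sites.

# Collect 1-based residue positions predicted as binding sites.
binding_positions = []
residue_index = 0
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token in ['<pad>', '<cls>', '<eos>']:
        continue
    residue_index += 1
    if prediction == 1:
        binding_positions.append(residue_index)

print("Predicted binding-site residues:", binding_positions)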