---
metrics:
- matthews_correlation
- f1
tags:
- biology
- medical
---

This version of DNABERT2 has been modified so that it can also output the attention, for attention analysis.

**To the author of DNABERT2: feel free to use these modifications.**

Use ```--model_name_or_path jaandoui/DNABERT2-AttentionExtracted``` instead of the original repository to get access to the attention.
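Outside a training loop, the attention can also be requested directly. The sketch below assumes that this checkpoint loads through ```AutoModelForSequenceClassification``` with ```trust_remote_code=True``` and that its forward pass accepts ```output_attentions``` and ```return_dict```, as described in this card. The helper names ```load_attention_model``` and ```extract_attentions```, and the placeholder ```num_labels=2```, are hypothetical choices for illustration, not part of the repository.

```python
import torch

# Repository id from this model card.
REPO_ID = "jaandoui/DNABERT2-AttentionExtracted"


def load_attention_model(repo_id=REPO_ID, num_labels=2):
    """Load the tokenizer and a sequence-classification model from the Hub.

    Assumption: the repository's remote code supports
    AutoModelForSequenceClassification; num_labels=2 is a placeholder.
    """
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForSequenceClassification.from_pretrained(
        repo_id, num_labels=num_labels, trust_remote_code=True
    )
    return tokenizer, model


def extract_attentions(model, tokenizer, sequence):
    """Run one DNA sequence through the model and return outputs.attentions.

    Inputs stay on the CPU; move them with .to(device) if the model is on a GPU.
    """
    inputs = tokenizer(sequence, return_tensors="pt")
    model.eval()  # disable dropout so the attention maps are deterministic
    with torch.no_grad():
        outputs = model(**inputs, return_dict=True, output_attentions=True)
    return outputs.attentions
```

A possible usage is ```tokenizer, model = load_attention_model()``` followed by ```attentions = extract_attentions(model, tokenizer, "ACGTAGCATCGG")```. Loaded this way, the classification head is freshly initialized; to analyse a fine-tuned model, pass its output directory as ```repo_id``` instead.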

Most of the modifications were made in Bert_Layer.py. The model has been modified specifically for fine-tuning and has not been tested for pretraining.

Each modification is marked with ```"JAANDOUI"``` before or next to it, so to see all modifications, search for ```"JAANDOUI"```.

The marker ```"JAANDOUI TODO"``` means that something might still be missing if that part of the code is used.
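Any text search works for finding these markers. As a small illustration, the hypothetical helper below lists every line that contains the marker and flags the ```"JAANDOUI TODO"``` ones, so the possibly incomplete parts stand out.

```python
from pathlib import Path


def find_markers(path, marker="JAANDOUI"):
    """Return (line_number, is_todo, stripped_line) for each line mentioning marker."""
    hits = []
    with Path(path).open(encoding="utf-8") as fh:
        for lineno, line in enumerate(fh, start=1):
            if marker in line:
                # A "JAANDOUI TODO" line marks a part that may be incomplete.
                hits.append((lineno, f"{marker} TODO" in line, line.strip()))
    return hits
```

For example, ```for hit in find_markers("Bert_Layer.py"): print(hit)``` prints one tuple per marked line; adjust the path to wherever the file sits in your checkout.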

To activate the extraction of attention, pass ```output_attentions=True``` (and, optionally, ```return_dict=True```) when the model is called inside ```compute_loss(..)``` of ```Trainer``` (or of your ```CustomTrainer```, if you override it):

```outputs = model(**inputs, return_dict=True, output_attentions=True)```

The attention is then available in ```outputs.attentions```.
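A minimal sketch of such an override, assuming each batch contains ```labels``` so that the model returns ```outputs.loss```. The class name ```AttentionTrainer``` and the attribute ```last_attentions``` are hypothetical; ```**kwargs``` absorbs extra arguments (such as ```num_items_in_batch```) that newer ```transformers``` versions pass to ```compute_loss```.

```python
from transformers import Trainer


class AttentionTrainer(Trainer):
    """Trainer that also asks the model for its attention maps.

    Hypothetical subclass for illustration; it assumes the modified DNABERT2
    model accepts output_attentions and that each batch includes labels.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Same call as in this model card: request attention alongside the loss.
        outputs = model(**inputs, return_dict=True, output_attentions=True)
        # Keep a detached copy of the latest batch's attention for analysis
        # (hypothetical attribute; it is overwritten on every step).
        self.last_attentions = tuple(a.detach() for a in outputs.attentions)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss
```

Because ```outputs``` is returned during evaluation, the attention tensors may be collected together with the logits; ```Trainer.evaluate``` and ```Trainer.predict``` accept an ```ignore_keys``` argument (for example ```ignore_keys=["attentions"]```) that excludes such output keys.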

Note that the output has a third dimension, usually of size 12, which refers to the layer; ```outputs.attentions[-1]``` is the attention of the last layer.
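As one way to use these tensors, the sketch below averages the last layer over its heads and keeps the row of the first token ([CLS]), giving one attention distribution over the sequence. It assumes the usual Hugging Face layout, where ```outputs.attentions``` is a per-layer sequence of tensors shaped (batch, num_heads, seq_len, seq_len), and that the tokenizer places [CLS] first; print the shapes from your own run before relying on it, since the layout in this modified model may differ. ```cls_attention_last_layer``` is a hypothetical helper name.

```python
import torch


def cls_attention_last_layer(attentions):
    """Head-averaged attention from the first ([CLS]) token in the last layer.

    Assumes attentions is a sequence of per-layer tensors shaped
    (batch, num_heads, seq_len, seq_len). Returns a (batch, seq_len) tensor.
    """
    last = attentions[-1]            # attention of the last layer
    head_mean = last.mean(dim=1)     # (batch, seq_len, seq_len), averaged over heads
    return head_mean[:, 0, :]        # query row 0: how [CLS] attends to each token
```

Since each attention row is a softmax distribution, the returned rows still sum to 1, which makes it easy to compare how strongly [CLS] attends to different tokens.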

Read more about model outputs here: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/output#transformers.utils.ModelOutput

I'm also not using Triton, so I cannot guarantee that this works with it.

I also read that there were some problems with extracting attention when using Flash Attention ([transformers issue #28903](https://github.com/huggingface/transformers/issues/28903)). I'm not sure whether that is relevant here, since that issue is about Mistral models.

I'm still exploring this attention, so please don't assume it works 100%. I'll update the repository once I'm sure.

The official DNABERT2 paper: [DNABERT-2: Efficient Foundation Model and Benchmark For Multi-Species Genome](https://arxiv.org/pdf/2306.15006.pdf).

README OF THE OFFICIAL DNABERT2:

We sincerely appreciate the MosaicML team for the [MosaicBERT](https://openreview.net/forum?id=5zipcfLC2Z) implementation, which serves as the base of DNABERT-2 development.

DNABERT-2 is a transformer-based genome foundation model trained on multi-species genomes.

To load the model from Hugging Face:

```python
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
model = AutoModel.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
```

To calculate the embedding of a DNA sequence:

```python
dna = "ACGTAGCATCGGATCTATCTATCGACACTTGGTTATCGATCTACGAGCATCTCGTTAGC"
inputs = tokenizer(dna, return_tensors="pt")["input_ids"]
hidden_states = model(inputs)[0]  # [1, sequence_length, 768]

# embedding with mean pooling
embedding_mean = torch.mean(hidden_states[0], dim=0)
print(embedding_mean.shape)  # expect to be 768

# embedding with max pooling
embedding_max = torch.max(hidden_states[0], dim=0)[0]
print(embedding_max.shape)  # expect to be 768
```
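The pooling above uses a single, unpadded sequence, so every position is a real token. When embedding a padded batch, plain ```torch.mean``` would also average the padding positions. A minimal sketch of masked mean pooling, assuming hidden states shaped [batch, sequence_length, 768] and the tokenizer's ```attention_mask``` (1 for real tokens, 0 for padding); ```masked_mean_pool``` is a hypothetical helper name:

```python
import torch


def masked_mean_pool(hidden_states, attention_mask):
    """Mean over real tokens only.

    hidden_states: (batch, seq_len, hidden); attention_mask: (batch, seq_len).
    Returns a (batch, hidden) tensor.
    """
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(dim=1)                   # zero out padding, then sum
    counts = mask.sum(dim=1).clamp(min=1e-9)                     # real-token count per sequence
    return summed / counts
```

For example, after ```batch = tokenizer(list_of_sequences, return_tensors="pt", padding=True)``` (where ```list_of_sequences``` is your own list of DNA strings), pool the model's hidden states with ```masked_mean_pool(hidden_states, batch["attention_mask"])```.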