Update code snippet to use sentence-level embeddings
Hello!
## Pull Request overview
* Use sentence-level embeddings (pooler output) in the code snippet
## Details
Previously, the snippet returned only token-level embeddings, which are generally not very useful in practice. I think the pooler output will be more useful.
- Tom Aarsen
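For context, here is a minimal sketch of the distinction this change targets. It assumes the remote CodeSage modeling code returns the usual `last_hidden_state` for token-level vectors alongside the `pooler_output` used in the updated snippet; only `pooler_output` is confirmed by this diff.

```python
import torch
from transformers import AutoModel, AutoTokenizer

checkpoint = "codesage/codesage-large"
device = "cuda"  # or "cpu"

tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True).to(device)

inputs = tokenizer("def print_hello_world():\tprint('Hello World!')", return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(**inputs)

# Token-level: one 2048-dim vector per input token (what the old snippet exposed).
token_embeddings = outputs.last_hidden_state  # assumed attribute; shape [1, num_tokens, 2048]

# Sentence-level: a single pooled 2048-dim vector for the whole snippet (this PR).
snippet_embedding = outputs.pooler_output     # shape [1, 2048]
```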
README.md CHANGED

@@ -24,7 +24,7 @@ This checkpoint is first trained on code data via masked language modeling (MLM)
 ### How to use
 This checkpoint consists of an encoder (1.3B model), which can be used to extract code embeddings of 2048 dimension. It can be easily loaded using the AutoModel functionality and employs the Starcoder tokenizer (https://arxiv.org/pdf/2305.06161.pdf).

-```
+```python
 from transformers import AutoModel, AutoTokenizer

 checkpoint = "codesage/codesage-large"
@@ -33,10 +33,10 @@ device = "cuda" # for GPU usage or "cpu" for CPU usage
 tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
 model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True).to(device)

-inputs = tokenizer
-embedding = model(inputs)
-print(f'Dimension of the embedding: {embedding
-# Dimension of the embedding: torch.Size([
+inputs = tokenizer("def print_hello_world():\tprint('Hello World!')", return_tensors="pt").to(device)
+embedding = model(**inputs).pooler_output
+print(f'Dimension of the embedding: {embedding.size()}')
+# Dimension of the embedding: torch.Size([1, 2048])
 ```

 ### BibTeX entry and citation info
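As an illustrative follow-up (not part of the diff), the pooled sentence-level output is what makes snippet-to-snippet comparison convenient. The sketch below reuses the `model`, `tokenizer`, and `device` from the snippet above and is only a suggestion of how the embedding might be used.

```python
import torch
import torch.nn.functional as F

def embed(code: str) -> torch.Tensor:
    # One pooled 2048-dim embedding per snippet.
    inputs = tokenizer(code, return_tensors="pt").to(device)
    with torch.no_grad():
        return model(**inputs).pooler_output  # shape [1, 2048]

a = embed("def add(x, y):\n    return x + y")
b = embed("def sum_two(a, b):\n    return a + b")

# Cosine similarity between the two pooled embeddings (closer to 1.0 = more similar).
print(F.cosine_similarity(a, b).item())
```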