Getting log probabilities from the Inference API?
I was interested in getting log-probabilities per token (including both prompt and completion tokens), similar to the output produced by the logprobs argument to the OpenAI models. Is this available in the API at all, or could it be?
+1 for this. To emulate the "best of" option available with GPT-3, it would be great to be able to specify different seeds and get the total likelihood for each result. Then one could just take the best of 20 requests.
At a glance, num_return_sequences seems like it might be helpful, but in my testing I always get just a single result when calling the API, even while using this option. Either way, without the probabilities it wouldn't be much use anyhow.
@sleven what do you mean by the "best of" option? Isn't what you are describing exactly what num_beams would be doing (using beam search to find the best generation among n)?
Looking at their docs, it's exactly what num_beams does in transformers. https://beta.openai.com/docs/api-reference/completions/create#completions/create-best_of
Is that what you are referring to?
As for the logprobs, it's not yet available for bloom, but we could enable part of it at some point.
Are there any other parameters from gpt-3 that would be important ?
Actually no, beam search is different from "best of". I mean, I might be wrong, I don't have GPT-3's source code, but I think their "best of" feature is the "sample and rank" method from Google's Meena paper (Adiwardana et al).
Sample and rank generates N fully independent responses using plain random sampling with a temperature, then selects the response with the highest total log likelihood divided by length ("length-normalized log-likelihood scores").
Beam search is at its core still a greedy algorithm that continuously extends a set of B "leading" hypotheses at each step. Beam search is more efficient since it throws away "losers" early, but it tends to suffer from repetitive and uninteresting results since it can't go very far down a path with an unlikely (but perhaps more interesting) token in it.
Since sample and rank generates full completions it can sample some more unusual tokens and still generate a completion with good average likelihood by choosing its other tokens well. So you get more diverse outcomes.
One could implement sample and rank with the API by requesting X completions (no beam search, sampling on, some temperature > 0) with different seeds. But in addition to the textual response you'd need the likelihood assigned to each token in the response so you can calculate the total log likelihood for each of the X completions. (Then pick the top 1 or present the top 2 to the end user for selection or whatever you want.)
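The ranking step described above can be sketched in a few lines. This is only an illustration under stated assumptions: the completions and their per-token log-probabilities are made up, and `length_normalized_score` and `rank_completions` are hypothetical helper names, not part of any API. It assumes you already have the per-token log-probabilities for each sampled completion.

```python
from typing import List, Sequence, Tuple


def length_normalized_score(token_logprobs: Sequence[float]) -> float:
    """Total log-likelihood divided by the number of tokens."""
    if not token_logprobs:
        raise ValueError("a completion needs at least one token")
    return sum(token_logprobs) / len(token_logprobs)


def rank_completions(
    completions: Sequence[Tuple[str, Sequence[float]]],
) -> List[Tuple[str, float]]:
    """Sort (text, token_logprobs) pairs from best to worst score."""
    scored = [(text, length_normalized_score(lps)) for text, lps in completions]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


# Made-up samples and log-probabilities, purely for illustration.
samples = [
    ("Paris.", [-0.2, -0.1]),
    ("It is Paris, of course.", [-1.0, -0.5, -0.3, -0.2, -0.4, -0.6]),
    ("Lyon.", [-3.0, -0.1]),
]
ranked = rank_completions(samples)
best_text, best_score = ranked[0]
```

Dividing by length keeps longer completions from being penalized simply for containing more tokens; whether that is the right normalization for a given use case is a judgment call.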
Are there any other parameters from gpt-3 that would be important ?
stop for specifying stop tokens would be nice. So you could ask to generate exactly one line by giving \n as a stop token, for example.
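Until something like that exists server-side, a rough client-side approximation is to cut the returned text at the first stop string. The sketch below is an assumption-laden workaround, not an API feature: `truncate_at_stop` is a hypothetical helper, and unlike a real stop parameter it does not save any generation time, since the extra tokens are still produced and then discarded.

```python
from typing import Sequence


def truncate_at_stop(text: str, stop: Sequence[str]) -> str:
    """Cut `text` just before the earliest occurrence of any stop string."""
    cut = len(text)
    for s in stop:
        if not s:
            continue  # an empty stop string would match everywhere
        idx = text.find(s)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut]


# Made-up generation, purely for illustration.
generated = "Paris is the capital.\nLondon is the capital of England."
first_line = truncate_at_stop(generated, ["\n"])
```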
And although I requested logprobs in the context of getting the likelihood per generated token (which would correspond to logprobs=1), you can also do logprobs=5, which gives you the top k candidates for each token. Example here showing that when prompted with "what is the capital of France", GPT-3's most likely candidates are Paris, par, PAR, etc. This can help give a sense of how "confident" the model is about a response beyond the simple likelihood of the token. E.g. in this example, if we had a way to recognise that Paris, par, PAR are all the same idea, we can conclude the model is more confident the answer is Paris than just the likelihood of the Paris token alone suggests.
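One crude way to approximate "these candidates are the same idea" is to add up the probability of every candidate that could be the start of the expected answer, ignoring case and leading whitespace. The sketch below uses invented probabilities and a hypothetical helper `answer_mass`; the prefix-matching rule is an assumption chosen for illustration, and a real system would likely need something smarter.

```python
import math
from typing import Dict


def answer_mass(top_logprobs: Dict[str, float], answer: str) -> float:
    """Probability mass of candidates that could start `answer`."""
    target = answer.lower()
    total = 0.0
    for token, logprob in top_logprobs.items():
        piece = token.strip().lower()
        if piece and target.startswith(piece):
            total += math.exp(logprob)  # log-probability -> probability
    return total


# Made-up top-5 candidates for the first answer token.
candidates = {
    " Paris": math.log(0.50),
    " par": math.log(0.20),
    " PAR": math.log(0.10),
    " London": math.log(0.05),
    " The": math.log(0.03),
}
mass = answer_mass(candidates, "Paris")
```

With these invented numbers the merged mass is higher than the probability of the " Paris" token on its own, which is the effect described above.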
Sorry to hijack this post for a related question. I'm also very interested to access the logprobs from BLOOM. I'm not an expert, but I've read that the BloomForCausalLM has the logits return values, which should be the prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). This is the first time working in-depth with these models, so maybe I'm completely on the wrong track and "prediction scores of the language modeling head" are not "likelihood per generated token".
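For intuition on how pre-softmax scores relate to log-probabilities: applying a log-softmax to one row of logits gives a log-probability for every vocabulary token, and those probabilities sum to 1. The following is a dependency-free sketch with a made-up three-token vocabulary, not the library implementation.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Turn raw scores into log-probabilities: x_i - log(sum_j exp(x_j))."""
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


toy_logits = [2.0, 1.0, 0.1]  # made-up scores for a 3-token vocabulary
toy_logprobs = log_softmax(toy_logits)
total_prob = sum(math.exp(lp) for lp in toy_logprobs)
```

The log-probability of a particular token is then just the entry of that vector at the token's id.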
From "Natural Language Processing with Transformers" By Lewis Tunstall, Leandro von Werra, Thomas Wolf. I'm currently trying to implement this.
https://github.com/nlp-with-transformers/notebooks/blob/main/05_text-generation.ipynb
This has become an awkward monologue. Anyways, I'm suggesting the following:
import torch
import torch.nn.functional as F
from transformers import BloomTokenizerFast, BloomForCausalLM

tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m")

prompt = "The horse raced past the barn fell."
encoded = tokenizer(prompt, return_tensors="pt").to("cpu")
input_ids = encoded["input_ids"]
with torch.no_grad():  # inference only, no gradients needed
    output = model(input_ids=input_ids)

# neglecting the first token, since we make no prediction about it
shift_labels = input_ids[..., 1:].contiguous()
shift_logits = output.logits[..., :-1, :].contiguous()

print("TOKEN : LOGPROB\n")
print(tokenizer.decode(input_ids[0].tolist()[0]), " : ", None)
# logits at position i score the token at position i + 1
for label_id, logit in zip(shift_labels[0].tolist(), shift_logits[0]):
    logprob = F.log_softmax(logit, dim=0).tolist()[label_id]
    print(tokenizer.decode(label_id), " : ", logprob)
This results in:
TOKEN : LOGPROB
The : None
horse : -10.14762020111084
rac : -8.096358299255371
ed : -0.07634077966213226
past : -3.50999116897583
the : -1.575127363204956
barn : -4.399406433105469
fell : -10.955772399902344
. : -4.558294296264648
Please let me know if that makes sense. If it does, I'll write up a function for @Narsil
Thank you very much for the well crafted answer !
Really extremely helpful to get me up to speed as to what OpenAI does.
I don't think there is support for BestOf within transformers in the generate function itself (which we hackishly leverage).
Any PR here would be super welcome.
And for the logprobs, that's what I had in mind. It's sort of supported, but I think we have multiple sorts of logprobs to output, so we'd have to be careful about how we return them.
Just to be super straightforward, there's a prioritization for commercial offerings over bloom, so if that's something that interests you I think sending an email to [email protected] is the best bet to fast track this. For the free offering of bloom (the widget of this page), we'll probably integrate back some of this at some point, but it'll take more work as we're more committed to actually changing our libs to make things generic and usable by everyone (and right now there's still quite a lot to do just to enable what you're currently seeing in a nice open source form).
No problem. Glad to help! I appreciate the open source of your work. Is there a simple way I can contribute this? Probably not enough for a pull request, but let me know:
def logprobs_from_prompt(prompt, tokenizer, model):
    encoded = tokenizer(prompt, return_tensors="pt").to("cpu")
    input_ids = encoded["input_ids"]
    output = model(input_ids=input_ids)
    shift_labels = input_ids[..., 1:].contiguous()
    shift_logits = output.logits[..., :-1, :].contiguous()
    log_probs = []
    log_probs.append((tokenizer.decode(input_ids[0].tolist()[0]), None))
    for idx, (label_id, logit) in enumerate(zip(shift_labels[0].tolist(), shift_logits[0])):
        logprob = F.log_softmax(logit, dim=0).tolist()[label_id]
        log_probs.append((tokenizer.decode(label_id), float(logprob)))
    return log_probs
Nice work @Pwicke. The logic looks sound to me. Possibly you can avoid that tolist, which might be expensive when all you want is a single element. That list will be like 250880 elements long, one element per possible token, won't it? Maybe just F.log_softmax(logit, dim=0)[label_id].item() would work.
A sniff test that things look right is that ed has a high probability after rac in your test data, so that indicates there's no off by one error or anything like that.
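To make the alignment and the single-element point concrete, here is a dependency-free sketch with made-up numbers. It pairs the logits at position i with the token at position i + 1, and computes only the one entry that is needed as `row[label] - logsumexp(row)` instead of building the whole log-softmax vector. The helper names are hypothetical; in PyTorch the analogous expression would presumably be `logit[label_id] - torch.logsumexp(logit, dim=0)`.

```python
import math
from typing import List, Optional, Sequence


def logsumexp(row: Sequence[float]) -> float:
    """log(sum(exp(x))) computed stably by factoring out the max."""
    m = max(row)
    return m + math.log(sum(math.exp(x - m) for x in row))


def token_logprobs(
    input_ids: Sequence[int], logits: Sequence[Sequence[float]]
) -> List[Optional[float]]:
    """Log-probability of each token given the tokens before it.

    logits[i] holds the scores for the token at position i + 1, so the
    first token has no prediction and the last row of logits is unused.
    """
    if len(input_ids) != len(logits):
        raise ValueError("expected one row of logits per input token")
    result: List[Optional[float]] = [None]
    for label_id, row in zip(input_ids[1:], logits[:-1]):
        # Only the chosen entry is needed: x[label] - logsumexp(x).
        result.append(row[label_id] - logsumexp(row))
    return result


# Made-up logits over a 3-token vocabulary for a 3-token input.
ids = [0, 2, 1]
rows = [[0.0, 0.0, 0.0], [0.0, 5.0, 0.0], [1.0, 2.0, 3.0]]
scores = token_logprobs(ids, rows)
```

In this toy input the second row strongly favours token 1, and token 1 is indeed what follows, so the last score is close to zero. That is the same kind of sanity check as ed scoring well after rac.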
Just for reference, this is what the OpenAI API responds with when the logprobs setting is used, same prompt:
curl https://api.openai.com/v1/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer $OPENAI_API_KEY" \
  -d '{
    "model": "text-davinci-002",
    "prompt": "The horse raced past the barn fell.",
    "temperature": 0.7,
    "max_tokens": 0,
    "top_p": 1,
    "frequency_penalty": 0,
    "presence_penalty": 0,
    "logprobs": 1,
    "echo": true
  }' | jq .
{
  "id": "cmpl-5qTJCj3Q8PO8gQvAQbmBMH1TlyAHf",
  "object": "text_completion",
  "created": 1663187402,
  "model": "text-davinci-002",
  "choices": [
    {
      "text": "The horse raced past the barn fell.",
      "index": 0,
      "logprobs": {
        "tokens": ["The", " horse", " raced", " past", " the", " barn", " fell", "."],
        "token_logprobs": [null, -9.926095, -5.8342543, -0.5715009, -0.008218852, -0.0009814043, -0.3402846, -1.3558618],
        "top_logprobs": [
          null,
          {" first": -4.1999307},
          {" is": -2.169767},
          {" past": -0.5715009},
          {" the": -0.008218852},
          {" barn": -0.0009814043},
          {" fell": -0.3402846},
          {"\n": -0.9338951}
        ],
        "text_offset": [0, 3, 9, 15, 20, 24, 29, 34]
      },
      "finish_reason": "length"
    }
  ],
  "usage": {
    "prompt_tokens": 8,
    "total_tokens": 8
  }
}
So GPT-3's equivalent output values are:
The: null,
horse: -9.926095,
raced: -5.8342543,
past: -0.5715009,
the: -0.008218852,
barn: -0.0009814043,
fell: -0.3402846,
.: -1.3558618
I think this serves as excellent validation that @Pwicke's method is sound. The scores and tokenization are different (GPT is really confident it's a barn), naturally, since it's a different model, but if you squint they look the same.
(Note that in addition to the actual token used, OpenAI sends top_logprobs for the top n tokens (1 in this case, but they allow up to 5). So in our case "horse" scored -9.926095 but "first" was a more probable choice with -4.1999307. That could be grabbed with a simple torch.topk call on the logprob vector, which gives both indices and values.)
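As a dependency-free stand-in for that idea, the sketch below picks the k largest entries of a log-probability vector and returns them in the same `{token: logprob}` shape as the top_logprobs entries in the response. It uses `heapq.nlargest` rather than torch.topk, the helper name `top_k_logprobs` is hypothetical, and the three-token vocabulary is invented: the values for " horse" and " first" are taken from the response, while " cat" is made up.

```python
import heapq
from typing import Dict, Sequence


def top_k_logprobs(
    logprobs: Sequence[float], vocab: Sequence[str], k: int = 1
) -> Dict[str, float]:
    """Return the k most likely tokens as a {token: logprob} mapping."""
    best = heapq.nlargest(k, range(len(logprobs)), key=logprobs.__getitem__)
    return {vocab[i]: logprobs[i] for i in best}


# Toy vocabulary; only the first two values come from the response.
vocab = [" horse", " first", " cat"]
logprobs = [-9.926095, -4.1999307, -7.5]
top1 = top_k_logprobs(logprobs, vocab, k=1)
top2 = top_k_logprobs(logprobs, vocab, k=2)
```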
(took the liberty of formatting some of the comments in this thread to add syntax highlighting)
Hello @sleven, @Brendan, @Pwicke,
You can now ask for the log probabilities by using the details: true parameter.
curl https://api-inference.huggingface.co/models/bigscience/bloom \
  -X POST \
  -d '{"inputs": "test", "parameters":{"details":true,"max_new_tokens":2}}'
[
  {
    "details": {
      "finish_reason": "length",
      "generated_tokens": 2,
      "tokens": [
        [9234, "test", null],
        [17, ".", -1.7421875],
        [16357, "mark", -2.421875]
      ]
    },
    "generated_text": "test.mark"
  }
]
If you face any problem, feel free to comment or open an issue here.
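A response in that shape can be turned into per-token log-probabilities and a total score with a few lines of parsing. The sketch below assumes the response looks exactly like the example above, where each token is an `[id, text, logprob]` triple; the helper name is hypothetical and the format may differ in other versions of the API.

```python
import json
from typing import List, Optional, Tuple

# Response copied from the example above.
raw = """
[{"details": {"finish_reason": "length", "generated_tokens": 2,
              "tokens": [[9234, "test", null], [17, ".", -1.7421875],
                         [16357, "mark", -2.421875]]},
  "generated_text": "test.mark"}]
"""


def token_logprobs_from_details(response: list) -> List[Tuple[str, Optional[float]]]:
    """Extract (token_text, logprob) pairs from the first result."""
    tokens = response[0]["details"]["tokens"]
    return [(text, logprob) for _token_id, text, logprob in tokens]


pairs = token_logprobs_from_details(json.loads(raw))
# Tokens without a log-probability (the first prompt token) are skipped.
total = sum(lp for _text, lp in pairs if lp is not None)
```

The total, optionally divided by the number of scored tokens, is the quantity needed for the sample and rank approach discussed earlier in the thread.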
def logprobs_from_prompt(prompt, tokenizer, model):
    encoded = tokenizer(prompt, return_tensors="pt").to("cpu")
    input_ids = encoded["input_ids"]
    output = model(input_ids=input_ids)
    shift_labels = input_ids[..., 1:].contiguous()
    shift_logits = output.logits[..., :-1, :].contiguous()
    log_probs = []
    log_probs.append((tokenizer.decode(input_ids[0].tolist()[0]), None))
    for idx, (label_id, logit) in enumerate(zip(shift_labels[0].tolist(), shift_logits[0])):
        logprob = F.log_softmax(logit, dim=0).tolist()[label_id]
        log_probs.append((tokenizer.decode(label_id), float(logprob)))
    return log_probs
Thanks for this. This returns log probabilities for the input tokens, right?
The shape of output.logits seems to correspond to [batch_size, input_size, vocab_size]. I'm wondering, is it possible to get logits for each token in the output sequence?