import enum
import warnings

from ..utils import add_end_docstrings, is_tf_available, is_torch_available
from .base import PIPELINE_INIT_ARGS, Pipeline


if is_torch_available():
    from ..models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

if is_tf_available():
    import tensorflow as tf

    from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES


class ReturnType(enum.Enum):
    TENSORS = 0
    NEW_TEXT = 1
    FULL_TEXT = 2


@add_end_docstrings(PIPELINE_INIT_ARGS)
class TextGenerationPipeline(Pipeline):
""" |
|
Language generation pipeline using any `ModelWithLMHead`. This pipeline predicts the words that will follow a |
|
specified text prompt. |
|
|
|
Example: |
|
|
|
```python |
|
>>> from transformers import pipeline |
|
|
|
>>> generator = pipeline(model="gpt2") |
|
>>> generator("I can't believe you did such a ", do_sample=False) |
|
[{'generated_text': "I can't believe you did such a icky thing to me. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I"}] |
|
|
|
>>> # These parameters will return suggestions, and only the newly created text making it easier for prompting suggestions. |
|
>>> outputs = generator("My tart needs some", num_return_sequences=4, return_full_text=False) |
|
``` |
|
|
|
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial). You can pass text |
|
generation parameters to this pipeline to control stopping criteria, decoding strategy, and more. Learn more about |
|
text generation parameters in [Text generation strategies](../generation_strategies) and [Text |
|
generation](text_generation). |
|
|
|
This language generation pipeline can currently be loaded from [`pipeline`] using the following task identifier: |
|
`"text-generation"`. |
|
|
|
The models that this pipeline can use are models that have been trained with an autoregressive language modeling |
|
objective, which includes the uni-directional models in the library (e.g. gpt2). See the list of available models |
|
on [huggingface.co/models](https://huggingface.co/models?filter=text-generation). |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
XL_PREFIX = """ |
|
In 1991, the remains of Russian Tsar Nicholas II and his family (except for Alexei and Maria) are discovered. The |
|
voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the remainder of the story. 1883 Western |
|
Siberia, a young Grigori Rasputin is asked by his father and a group of men to perform magic. Rasputin has a vision |
|
and denounces one of the men as a horse thief. Although his father initially slaps him for making such an |
|
accusation, Rasputin watches as the man is chased outside and beaten. Twenty years later, Rasputin sees a vision of |
|
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous, with people, even a bishop, |
|
begging for his blessing. <eod> </s> <eos> |
|
""" |

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.check_model_type(
            TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES if self.framework == "tf" else MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
        )
        if "prefix" not in self._preprocess_params:
            # No prefix was set at pipeline creation time, so resolve a default one: the model config may define
            # one, and XLNet/Transformer-XL models fall back to `XL_PREFIX`.
            prefix = None
            if self.model.config.prefix is not None:
                prefix = self.model.config.prefix
            if prefix is None and self.model.__class__.__name__ in [
                "XLNetLMHeadModel",
                "TransfoXLLMHeadModel",
                "TFXLNetLMHeadModel",
                "TFTransfoXLLMHeadModel",
            ]:
                # These models rely on a padding prefix to generate sensibly from short prompts.
                prefix = self.XL_PREFIX
            if prefix is not None:
                # Recompute the preprocess/forward parameters that depend on the prefix (e.g. `prefix_length`).
                preprocess_params, forward_params, _ = self._sanitize_parameters(prefix=prefix, **self._forward_params)
                self._preprocess_params = {**self._preprocess_params, **preprocess_params}
                self._forward_params = {**self._forward_params, **forward_params}

    def _sanitize_parameters(
        self,
        return_full_text=None,
        return_tensors=None,
        return_text=None,
        return_type=None,
        clean_up_tokenization_spaces=None,
        prefix=None,
        handle_long_generation=None,
        stop_sequence=None,
        add_special_tokens=False,
        **generate_kwargs,
    ):
        preprocess_params = {"add_special_tokens": add_special_tokens}
        if prefix is not None:
            preprocess_params["prefix"] = prefix
        if prefix:
            prefix_inputs = self.tokenizer(
                prefix, padding=False, add_special_tokens=add_special_tokens, return_tensors=self.framework
            )
            generate_kwargs["prefix_length"] = prefix_inputs["input_ids"].shape[-1]

        if handle_long_generation is not None:
            if handle_long_generation not in {"hole"}:
                raise ValueError(
                    f"{handle_long_generation} is not a valid value for the `handle_long_generation` parameter,"
                    " expected one of [None, 'hole']"
                )
            preprocess_params["handle_long_generation"] = handle_long_generation

        preprocess_params.update(generate_kwargs)
        forward_params = generate_kwargs

        postprocess_params = {}
        if return_full_text is not None and return_type is None:
            if return_text is not None:
                raise ValueError("`return_text` is mutually exclusive with `return_full_text`")
            if return_tensors is not None:
                raise ValueError("`return_full_text` is mutually exclusive with `return_tensors`")
            return_type = ReturnType.FULL_TEXT if return_full_text else ReturnType.NEW_TEXT
        if return_tensors is not None and return_type is None:
            if return_text is not None:
                raise ValueError("`return_text` is mutually exclusive with `return_tensors`")
            return_type = ReturnType.TENSORS
        if return_type is not None:
            postprocess_params["return_type"] = return_type
        if clean_up_tokenization_spaces is not None:
            postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces

        if stop_sequence is not None:
            stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False)
            if len(stop_sequence_ids) > 1:
                warnings.warn(
                    "Stopping on a multi-token sequence is not yet supported in transformers. The first token of"
                    " the stop sequence will be used as the stop token in the interim."
                )
            generate_kwargs["eos_token_id"] = stop_sequence_ids[0]

        return preprocess_params, forward_params, postprocess_params
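
    # Illustrative summary (not executed) of how the flags above resolve to a `ReturnType`, assuming a pipeline
    # instance `generator`:
    #   generator("Hello", return_full_text=True)   -> ReturnType.FULL_TEXT  (prompt + generated text)
    #   generator("Hello", return_full_text=False)  -> ReturnType.NEW_TEXT   (generated text only)
    #   generator("Hello", return_tensors=True)     -> ReturnType.TENSORS    (generated token ids, no decoded text)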

    def _parse_and_tokenize(self, *args, **kwargs):
        """
        Parse arguments and tokenize
        """
        # The Transformer-XL tokenizer needs to add spaces before punctuation symbols.
        if self.model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
            kwargs.update({"add_space_before_punct_symbol": True})

        return super()._parse_and_tokenize(*args, **kwargs)

    def __call__(self, text_inputs, **kwargs):
        """
        Complete the prompt(s) given as inputs.

        Args:
            text_inputs (`str` or `List[str]`):
                One or several prompts (or one list of prompts) to complete.
            return_tensors (`bool`, *optional*, defaults to `False`):
                Whether or not to return the tensors of predictions (as token indices) in the outputs. If set to
                `True`, the decoded text is not returned.
            return_text (`bool`, *optional*, defaults to `True`):
                Whether or not to return the decoded texts in the outputs.
            return_full_text (`bool`, *optional*, defaults to `True`):
                If set to `False`, only the added text is returned; otherwise the full text is returned. Only
                meaningful if *return_text* is set to `True`.
            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
                Whether or not to clean up the potential extra spaces in the text output.
            prefix (`str`, *optional*):
                Prefix added to the prompt.
            handle_long_generation (`str`, *optional*):
                By default, this pipeline does not handle long generation (generation that, in one form or another,
                exceeds the model's maximum length). There is no perfect way to address this (more info:
                https://github.com/huggingface/transformers/issues/14033#issuecomment-948385227). This provides
                common strategies to work around the problem, depending on your use case.

                - `None`: default strategy, where nothing in particular happens
                - `"hole"`: Truncates the left of the input and leaves a gap wide enough to let generation happen
                  (this might truncate a lot of the prompt, and is not suitable when the desired generation exceeds
                  the model capacity)

            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate
                method corresponding to your framework [here](./model#generative-models)).

        Return:
            A list or a list of lists of `dict`: Returns one of the following dictionaries (cannot return a
            combination of both `generated_text` and `generated_token_ids`):

            - **generated_text** (`str`, present when `return_text=True`) -- The generated text.
            - **generated_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token
              ids of the generated text.
        """
        return super().__call__(text_inputs, **kwargs)
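
    # Usage sketch (illustrative, not executed): the keyword arguments documented above are forwarded through
    # `_sanitize_parameters`, so assuming `generator = pipeline("text-generation")`, calls can look like:
    #   generator("My tart needs some", num_return_sequences=4, return_full_text=False)
    #   generator(long_prompt, handle_long_generation="hole", max_new_tokens=100)  # `long_prompt`: any long string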

    def preprocess(
        self, prompt_text, prefix="", handle_long_generation=None, add_special_tokens=False, **generate_kwargs
    ):
        inputs = self.tokenizer(
            prefix + prompt_text, padding=False, add_special_tokens=add_special_tokens, return_tensors=self.framework
        )
        inputs["prompt_text"] = prompt_text

        if handle_long_generation == "hole":
            cur_len = inputs["input_ids"].shape[-1]
            if "max_new_tokens" in generate_kwargs:
                new_tokens = generate_kwargs["max_new_tokens"]
            else:
                new_tokens = generate_kwargs.get("max_length", self.model.config.max_length) - cur_len
                if new_tokens < 0:
                    raise ValueError("We cannot infer how many new tokens are expected")
            if cur_len + new_tokens > self.tokenizer.model_max_length:
                keep_length = self.tokenizer.model_max_length - new_tokens
                if keep_length <= 0:
                    raise ValueError(
                        "We cannot use `hole` to handle this generation: the number of desired tokens exceeds the"
                        " model's max length"
                    )

                inputs["input_ids"] = inputs["input_ids"][:, -keep_length:]
                if "attention_mask" in inputs:
                    inputs["attention_mask"] = inputs["attention_mask"][:, -keep_length:]

        return inputs
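
    # Worked example (illustrative) of the "hole" strategy above, assuming `tokenizer.model_max_length == 1024`,
    # `max_new_tokens == 100`, and a tokenized prompt of 1200 tokens: 1200 + 100 > 1024, so
    # keep_length = 1024 - 100 = 924 and only the last 924 prompt tokens are kept before generation.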

    def _forward(self, model_inputs, **generate_kwargs):
        input_ids = model_inputs["input_ids"]
        attention_mask = model_inputs.get("attention_mask", None)
        # Allow empty prompts
        if input_ids.shape[1] == 0:
            input_ids = None
            attention_mask = None
            in_b = 1
        else:
            in_b = input_ids.shape[0]
        prompt_text = model_inputs.pop("prompt_text")

        # If a prefix was prepended, extend the generation length so the prefix tokens do not count against the
        # requested budget, unless the user already set `max_new_tokens`/`min_new_tokens`.
        prefix_length = generate_kwargs.pop("prefix_length", 0)
        if prefix_length > 0:
            has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
                "generation_config" in generate_kwargs
                and generate_kwargs["generation_config"].max_new_tokens is not None
            )
            if not has_max_new_tokens:
                generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
                generate_kwargs["max_length"] += prefix_length
            has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
                "generation_config" in generate_kwargs
                and generate_kwargs["generation_config"].min_new_tokens is not None
            )
            if not has_min_new_tokens and "min_length" in generate_kwargs:
                generate_kwargs["min_length"] += prefix_length

        # `generate` returns (batch_size * num_return_sequences, sequence_length); reshape to group the generated
        # sequences per input.
        generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
        out_b = generated_sequence.shape[0]
        if self.framework == "pt":
            generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
        elif self.framework == "tf":
            generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
        return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}

    def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
        generated_sequence = model_outputs["generated_sequence"][0]
        input_ids = model_outputs["input_ids"]
        prompt_text = model_outputs["prompt_text"]
        generated_sequence = generated_sequence.numpy().tolist()
        records = []
        for sequence in generated_sequence:
            if return_type == ReturnType.TENSORS:
                record = {"generated_token_ids": sequence}
            elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
                # Decode the full sequence (prompt + generated tokens)
                text = self.tokenizer.decode(
                    sequence,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                )

                # Measure the decoded prompt so it can be stripped from the decoded sequence (this also removes any
                # padding prefix, e.g. for XLNet and Transformer-XL models).
                if input_ids is None:
                    prompt_length = 0
                else:
                    prompt_length = len(
                        self.tokenizer.decode(
                            input_ids[0],
                            skip_special_tokens=True,
                            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                        )
                    )

                all_text = text[prompt_length:]
                if return_type == ReturnType.FULL_TEXT:
                    all_text = prompt_text + all_text

                record = {"generated_text": all_text}
            records.append(record)

        return records
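
    # Illustrative sketch (not executed) of the prompt stripping in `postprocess`, assuming the decoded input ids
    # are "Hello" and the decoded full sequence is "Hello world":
    #   prompt_length = len("Hello") == 5
    #   NEW_TEXT  -> " world"
    #   FULL_TEXT -> prompt_text + " world"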