from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from transformers.file_utils import ModelOutput

@dataclass
class TranceptionCausalLMOutputWithCrossAttentions(ModelOutput):
""" | |
Class for Tranception causal language model (or autoregressive) outputs. | |
Args: | |
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): | |
Language modeling loss (for next-token prediction). | |
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): | |
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). | |
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): | |
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of | |
shape `(batch_size, sequence_length, hidden_size)`. | |
Hidden-states of the model at the output of each layer plus the initial embedding outputs. | |
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): | |
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, | |
sequence_length)`. | |
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention | |
heads. | |
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): | |
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, | |
sequence_length)`. | |
Cross attentions weights after the attention softmax, used to compute the weighted average in the | |
cross-attention heads. | |
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): | |
Tuple of `torch.FloatTensor` tuples of length `config.n_layers`, with each tuple containing the cached key, | |
value states of the self-attention and the cross-attention layers if model is used in encoder-decoder | |
setting. Only relevant if `config.is_decoder = True`. | |
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see | |
`past_key_values` input) to speed up sequential decoding. | |
fused_shift_log_probas (`torch.FloatTensor` of shape (batch_size, sequence_length, config.vocab_size), *optional*, returned when config.retrieval_aggregation_mode is not None. | |
log_probas for each residue position after aggregating autoregressive logits and retrieval logits. | |
""" | |
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    fused_shift_log_probas: Optional[torch.FloatTensor] = None
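
# --- Illustrative usage (hedged sketch, not part of the original Tranception code) ---
# A minimal example of how this output class might be populated and consumed. The
# tensor shapes, the equal-weight fusion of autoregressive and retrieval
# distributions, and the `retrieval_log_probas` tensor below are assumptions for
# illustration only; the actual aggregation logic (including the "shift" semantics)
# lives in the Tranception model's forward pass.
if __name__ == "__main__":
    batch_size, sequence_length, vocab_size = 2, 8, 25

    # Stand-ins for the model's autoregressive logits and hypothetical retrieval log-probabilities.
    logits = torch.randn(batch_size, sequence_length, vocab_size)
    retrieval_log_probas = torch.log_softmax(torch.randn(batch_size, sequence_length, vocab_size), dim=-1)

    # One plausible aggregation: an equal-weight mixture of the two distributions in probability space.
    fused_probas = 0.5 * torch.softmax(logits, dim=-1) + 0.5 * retrieval_log_probas.exp()

    output = TranceptionCausalLMOutputWithCrossAttentions(
        logits=logits,
        fused_shift_log_probas=fused_probas.log(),
    )

    # ModelOutput supports both attribute and key access.
    print(output.logits.shape)                     # torch.Size([2, 8, 25])
    print(output["fused_shift_log_probas"].shape)  # torch.Size([2, 8, 25])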