---
library_name: keras-hub
---

## Model Overview

BART encoder-decoder network.

This class implements a Transformer-based encoder-decoder model as described in ["BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension"](https://arxiv.org/abs/1910.13461).

The default constructor gives a fully customizable, randomly initialized BART model with any number of layers, heads, and embedding dimensions. To load preset architectures and weights, use the `from_preset` constructor.

Disclaimer: Pre-trained models are provided on an "as is" basis, without warranties or conditions of any kind. The underlying model is provided by a third party and subject to a separate license, available [here](https://github.com/facebookresearch/fairseq/).

__Arguments__

- __vocabulary_size__: int. The size of the token vocabulary.
- __num_layers__: int. The number of transformer encoder layers and transformer decoder layers.
- __num_heads__: int. The number of attention heads in each transformer layer. The hidden size must be divisible by the number of attention heads.
- __hidden_dim__: int. The size of the transformer encoder and decoder hidden states.
- __intermediate_dim__: int. The output dimension of the first Dense layer in the two-layer feedforward network of each transformer layer.
- __dropout__: float. Dropout probability for the Transformer encoder and decoder.
- __max_sequence_length__: int. The maximum sequence length that the encoder and decoder can consume. If `None`, the length of the input sequences is used instead. This determines the variable shape for positional embeddings.

## Example Usage

```python
import keras
import keras_hub
import numpy as np
```

Use `generate()` to do text generation, given an input context.

```python
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("bart_base_en")
bart_lm.generate("The quick brown fox", max_length=30)

# Generate with batched inputs.
bart_lm.generate(["The quick brown fox", "The whale"], max_length=30)
```

Compile the `generate()` function with a custom sampler.

```python
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("bart_base_en")
bart_lm.compile(sampler="greedy")
bart_lm.generate("The quick brown fox", max_length=30)
```

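With `sampler="greedy"`, each decoding step keeps the single highest-scoring next token. The NumPy sketch below contrasts that rule with top-k sampling on one made-up logits vector. It illustrates the idea only and is not KerasHub's sampler implementation; the function names `greedy_pick` and `top_k_pick` are hypothetical.

```python
import numpy as np


def greedy_pick(logits: np.ndarray) -> int:
    """Greedy decoding: always choose the highest-scoring token id."""
    return int(np.argmax(logits))


def top_k_pick(logits: np.ndarray, k: int, rng: np.random.Generator) -> int:
    """Top-k sampling: draw a token id from the k highest-scoring candidates."""
    top_ids = np.argsort(logits)[-k:]  # ids of the k largest logits
    top_logits = logits[top_ids]
    probs = np.exp(top_logits - top_logits.max())  # numerically stable softmax
    probs /= probs.sum()
    return int(rng.choice(top_ids, p=probs))


# Toy next-token scores over a 5-token vocabulary (made-up numbers).
logits = np.array([0.1, 2.0, 0.3, 1.5, -1.0])
rng = np.random.default_rng(0)
print(greedy_pick(logits))               # 1
print(top_k_pick(logits, k=2, rng=rng))  # 1 or 3
```

Greedy decoding is deterministic, so repeated calls on the same input return the same text; sampling-based strategies trade that repeatability for more varied output.
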
Use `generate()` with encoder inputs and an incomplete decoder input (prompt).

```python
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("bart_base_en")
bart_lm.generate(
    {
        "encoder_text": "The quick brown fox",
        "decoder_text": "The fast",
    }
)
```

Use `generate()` without preprocessing.

```python
# Preprocessed inputs, with encoder inputs corresponding to
# "The quick brown fox", and the decoder inputs to "The fast".
# Each padding mask must have the same shape as its token ids. Positions
# marked True in "decoder_padding_mask" are kept as the prompt; positions
# marked False are overridden by generation.
prompt = {
    "encoder_token_ids": np.array([[0, 133, 2119, 6219, 23602, 2, 1, 1]]),
    "encoder_padding_mask": np.array(
        [[True, True, True, True, True, True, False, False]]
    ),
    "decoder_token_ids": np.array([[2, 0, 133, 1769, 2, 1, 1]]),
    "decoder_padding_mask": np.array(
        [[True, True, True, True, False, False, False]]
    ),
}

bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset(
    "bart_base_en",
    preprocessor=None,
)
bart_lm.generate(prompt)
```

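The arrays above follow a simple pattern: token ids are right-padded to a fixed length, and the padding mask is `True` only for real encoder tokens or decoder prompt tokens. Below is a minimal NumPy sketch, assuming the pad id is `1` as in the example arrays. `pad_and_mask` is a hypothetical helper, not part of KerasHub.

```python
import numpy as np


def pad_and_mask(token_ids, length, pad_id=1):
    """Right-pad `token_ids` to `length` and mark the original positions True.

    Hypothetical helper. Assumes pad id 1, matching the example arrays.
    Returns arrays with a leading batch dimension of 1.
    """
    if len(token_ids) > length:
        raise ValueError("token_ids is longer than the target length")
    n_pad = length - len(token_ids)
    ids = np.array([list(token_ids) + [pad_id] * n_pad])
    mask = np.array([[True] * len(token_ids) + [False] * n_pad])
    return ids, mask


# Encoder input for "The quick brown fox" (ids copied from the example above).
enc_ids, enc_mask = pad_and_mask([0, 133, 2119, 6219, 23602, 2], length=8)

# Decoder prompt for "The fast": generation continues after these 4 tokens.
dec_ids, dec_mask = pad_and_mask([2, 0, 133, 1769], length=7)
print(dec_ids.tolist())   # [[2, 0, 133, 1769, 1, 1, 1]]
print(dec_mask.tolist())  # [[True, True, True, True, False, False, False]]
```

The decoder ids differ from the example only at index 4, where this sketch writes the pad id instead of `2`. That position is masked `False`, so generation overrides it and its stored value does not matter.
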
Call `fit()` on a single batch.

```python
features = {
    "encoder_text": ["The quick brown fox jumped.", "I forgot my homework."],
    "decoder_text": ["The fast hazel fox leapt.", "I forgot my assignment."],
}
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("bart_base_en")
bart_lm.fit(x=features, batch_size=2)
```

Call `fit()` without preprocessing.

```python
x = {
    "encoder_token_ids": np.array([[0, 133, 2119, 2, 1]] * 2),
    "encoder_padding_mask": np.array([[1, 1, 1, 1, 0]] * 2),
    "decoder_token_ids": np.array([[2, 0, 133, 1769, 2]] * 2),
    "decoder_padding_mask": np.array([[1, 1, 1, 1, 1]] * 2),
}
y = np.array([[0, 133, 1769, 2, 1]] * 2)
sw = np.array([[1, 1, 1, 1, 0]] * 2)

bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset(
    "bart_base_en",
    preprocessor=None,
)
bart_lm.fit(x=x, y=y, sample_weight=sw, batch_size=2)
```

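In the example above, `y` is `decoder_token_ids` shifted left by one position, so each label is the token the decoder should predict next. The freed last slot is filled with the pad id `1`, and `sample_weight` gives that padded label zero weight. The NumPy sketch below reproduces the example's `y` and `sw` from `decoder_token_ids`. `shift_labels` is a hypothetical helper; it shows one way to build these arrays and does not describe what `BartSeq2SeqLMPreprocessor` does internally.

```python
import numpy as np


def shift_labels(decoder_token_ids, pad_id=1):
    """Build next-token labels and sample weights for teacher forcing.

    Hypothetical helper: drops the first decoder token, appends `pad_id`,
    and gives zero weight to padded label positions.
    """
    batch = decoder_token_ids.shape[0]
    pad_column = np.full((batch, 1), pad_id, dtype=decoder_token_ids.dtype)
    y = np.concatenate([decoder_token_ids[:, 1:], pad_column], axis=1)
    sample_weight = (y != pad_id).astype("int32")
    return y, sample_weight


decoder_token_ids = np.array([[2, 0, 133, 1769, 2]] * 2)
y, sw = shift_labels(decoder_token_ids)
print(y.tolist())   # [[0, 133, 1769, 2, 1], [0, 133, 1769, 2, 1]]
print(sw.tolist())  # [[1, 1, 1, 1, 0], [1, 1, 1, 1, 0]]
```

Deriving weights from `y != pad_id` assumes the pad id never appears as a real target token.
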
## Example Usage with Hugging Face URI

```python
import keras
import keras_hub
import numpy as np
```

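Each example in this section matches its counterpart in the previous section. Only the preset argument changes, from a KerasHub preset name to a Hugging Face Hub handle. Below is a trivial sketch of that mapping. `hf_handle` is a hypothetical helper, and the `keras` organization name is taken from the examples that follow.

```python
def hf_handle(preset: str, org: str = "keras") -> str:
    """Turn a KerasHub preset name into the hf:// handle used in this section.

    Hypothetical helper; `org` defaults to the "keras" organization that the
    examples in this section load from.
    """
    return f"hf://{org}/{preset}"


print(hf_handle("bart_base_en"))  # hf://keras/bart_base_en
```
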
Use `generate()` to do text generation, given an input context.

```python
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("hf://keras/bart_base_en")
bart_lm.generate("The quick brown fox", max_length=30)

# Generate with batched inputs.
bart_lm.generate(["The quick brown fox", "The whale"], max_length=30)
```

Compile the `generate()` function with a custom sampler.

```python
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("hf://keras/bart_base_en")
bart_lm.compile(sampler="greedy")
bart_lm.generate("The quick brown fox", max_length=30)
```

Use `generate()` with encoder inputs and an incomplete decoder input (prompt).

```python
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("hf://keras/bart_base_en")
bart_lm.generate(
    {
        "encoder_text": "The quick brown fox",
        "decoder_text": "The fast",
    }
)
```

Use `generate()` without preprocessing.

```python
# Preprocessed inputs, with encoder inputs corresponding to
# "The quick brown fox", and the decoder inputs to "The fast".
# Each padding mask must have the same shape as its token ids. Positions
# marked True in "decoder_padding_mask" are kept as the prompt; positions
# marked False are overridden by generation.
prompt = {
    "encoder_token_ids": np.array([[0, 133, 2119, 6219, 23602, 2, 1, 1]]),
    "encoder_padding_mask": np.array(
        [[True, True, True, True, True, True, False, False]]
    ),
    "decoder_token_ids": np.array([[2, 0, 133, 1769, 2, 1, 1]]),
    "decoder_padding_mask": np.array(
        [[True, True, True, True, False, False, False]]
    ),
}

bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset(
    "hf://keras/bart_base_en",
    preprocessor=None,
)
bart_lm.generate(prompt)
```

Call `fit()` on a single batch.

```python
features = {
    "encoder_text": ["The quick brown fox jumped.", "I forgot my homework."],
    "decoder_text": ["The fast hazel fox leapt.", "I forgot my assignment."],
}
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("hf://keras/bart_base_en")
bart_lm.fit(x=features, batch_size=2)
```

Call `fit()` without preprocessing.

```python
x = {
    "encoder_token_ids": np.array([[0, 133, 2119, 2, 1]] * 2),
    "encoder_padding_mask": np.array([[1, 1, 1, 1, 0]] * 2),
    "decoder_token_ids": np.array([[2, 0, 133, 1769, 2]] * 2),
    "decoder_padding_mask": np.array([[1, 1, 1, 1, 1]] * 2),
}
y = np.array([[0, 133, 1769, 2, 1]] * 2)
sw = np.array([[1, 1, 1, 1, 0]] * 2)

bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset(
    "hf://keras/bart_base_en",
    preprocessor=None,
)
bart_lm.fit(x=x, y=y, sample_weight=sw, batch_size=2)
```