Spaces:
Runtime error
Runtime error
rynmurdock
commited on
Commit
•
c5ca37a
1
Parent(s):
5b8f2e0
init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- Optimus/.gitignore +8 -0
- Optimus/README.md +121 -0
- Optimus/code/README.md +41 -0
- Optimus/code/app.py +0 -0
- Optimus/code/examples/README.md +392 -0
- Optimus/code/examples/__pycache__/utils_glue.cpython-37.pyc +0 -0
- Optimus/code/examples/big_ae/__pycache__/grad_app.cpython-310.pyc +0 -0
- Optimus/code/examples/big_ae/__pycache__/utils.cpython-37.pyc +0 -0
- Optimus/code/examples/big_ae/debug_data.py +6 -0
- Optimus/code/examples/big_ae/eval_dialog_multi_response.py +378 -0
- Optimus/code/examples/big_ae/eval_dialog_response.py +295 -0
- Optimus/code/examples/big_ae/grad_app.py +486 -0
- Optimus/code/examples/big_ae/metrics.py +196 -0
- Optimus/code/examples/big_ae/modules/__init__.py +7 -0
- Optimus/code/examples/big_ae/modules/__pycache__/__init__.cpython-310.pyc +0 -0
- Optimus/code/examples/big_ae/modules/__pycache__/__init__.cpython-37.pyc +0 -0
- Optimus/code/examples/big_ae/modules/__pycache__/arae.cpython-310.pyc +0 -0
- Optimus/code/examples/big_ae/modules/__pycache__/arae.cpython-37.pyc +0 -0
- Optimus/code/examples/big_ae/modules/__pycache__/cara.cpython-310.pyc +0 -0
- Optimus/code/examples/big_ae/modules/__pycache__/cara.cpython-37.pyc +0 -0
- Optimus/code/examples/big_ae/modules/__pycache__/spacefusion.cpython-310.pyc +0 -0
- Optimus/code/examples/big_ae/modules/__pycache__/spacefusion.cpython-37.pyc +0 -0
- Optimus/code/examples/big_ae/modules/__pycache__/utils.cpython-310.pyc +0 -0
- Optimus/code/examples/big_ae/modules/__pycache__/utils.cpython-37.pyc +0 -0
- Optimus/code/examples/big_ae/modules/__pycache__/vae.cpython-310.pyc +0 -0
- Optimus/code/examples/big_ae/modules/__pycache__/vae.cpython-37.pyc +0 -0
- Optimus/code/examples/big_ae/modules/arae.py +274 -0
- Optimus/code/examples/big_ae/modules/cara.py +374 -0
- Optimus/code/examples/big_ae/modules/ctrl_gen.py +371 -0
- Optimus/code/examples/big_ae/modules/decoders/dec_gpt2.py +358 -0
- Optimus/code/examples/big_ae/modules/decoders/decoder.py +79 -0
- Optimus/code/examples/big_ae/modules/encoders/__init__.py +1 -0
- Optimus/code/examples/big_ae/modules/encoders/enc_lstm.py +126 -0
- Optimus/code/examples/big_ae/modules/encoders/encoder.py +58 -0
- Optimus/code/examples/big_ae/modules/encoders/gaussian_encoder.py +147 -0
- Optimus/code/examples/big_ae/modules/spacefusion.py +143 -0
- Optimus/code/examples/big_ae/modules/utils.py +40 -0
- Optimus/code/examples/big_ae/modules/vae.py +638 -0
- Optimus/code/examples/big_ae/run_data_filtering.py +507 -0
- Optimus/code/examples/big_ae/run_dialog_dataloader.py +483 -0
- Optimus/code/examples/big_ae/run_encoding_generation.py +487 -0
- Optimus/code/examples/big_ae/run_generation_from_prior.py +414 -0
- Optimus/code/examples/big_ae/run_gpt2_generation.py +390 -0
- Optimus/code/examples/big_ae/run_latent_generation.py +577 -0
- Optimus/code/examples/big_ae/run_lm_ae_pretraining.py +692 -0
- Optimus/code/examples/big_ae/run_lm_causal_pretraining.py +692 -0
- Optimus/code/examples/big_ae/run_lm_finetuning_baseline.py +573 -0
- Optimus/code/examples/big_ae/run_lm_gpt2_training.py +658 -0
- Optimus/code/examples/big_ae/run_lm_vae_label_ctrl_gen.py +875 -0
- Optimus/code/examples/big_ae/run_lm_vae_pretraining.py +669 -0
Optimus/.gitignore
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data/datasets/glue_data/glue_data
|
2 |
+
data/datasets/glue_data/train.tx
|
3 |
+
data/datasets/glue_data/cached_lm_gpt_bert_256_train.jsont
|
4 |
+
code/runs
|
5 |
+
output/*
|
6 |
+
code/pytorch_transformers/__pycache__/*
|
7 |
+
code/examples/big_ae/modules/encoders/__pycache__/*
|
8 |
+
|
Optimus/README.md
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Optimus: the first pre-trained Big VAE language model <img src="doc/figs/logo_optimus.png" width="100" align="right">
|
2 |
+
|
3 |
+
This repository contains source code necessary to reproduce the results presented in the EMNLP 2020 paper [Optimus: Organizing Sentences via Pre-trained Modeling of a Latent Space](https://arxiv.org/abs/2004.04092).
|
4 |
+
|
5 |
+
|
6 |
+
|<img src="doc/figs/optimus_scheme.png" width="350"> | <img src="doc/figs/headfig_optimus.png" width="800">
|
7 |
+
|-------------------------|:-------------------------:|
|
8 |
+
| The network architecture of Optimus: encoder for representation learning and decoder for generation | Sentences are organized and manipulated in a pre-trained compact and smooth latent space
|
9 |
+
|
10 |
+
|
11 |
+
For more on this project, see the [Microsoft Research Blog post](https://www.microsoft.com/en-us/research/blog/a-deep-generative-model-trifecta-three-advances-that-work-towards-harnessing-large-scale-power/).
|
12 |
+
|
13 |
+
|
14 |
+
## News
|
15 |
+
|
16 |
+
May 21, 2020: Releasing a [`demo`](http://40.71.23.172:8899/) for latent space manipulation, including sentence interpolation and analogy. Check out the [`website`](http://40.71.23.172:8899/).
|
17 |
+
|
18 |
+
May 20, 2020: The latent space manipulation code is cleaned and released. See instructions at [`optimius_for_snli.md`](doc/optimius_for_snli.md).
|
19 |
+
|
20 |
+
May 13, 2020: The fine-tuning code for langauge modeling is released. See instructions at [`optimus_finetune_language_models.md`](doc/optimus_finetune_language_models.md)
|
21 |
+
|
22 |
+
## Contents
|
23 |
+
There are four steps to use this codebase to reproduce the results in the paper.
|
24 |
+
|
25 |
+
1. [Dependencies](#dependencies)
|
26 |
+
2. [Prepare datasets](#prepare-datasets)
|
27 |
+
3. [Model training](#Model-training)
|
28 |
+
1. Pre-training on setences in Wikipedia
|
29 |
+
2. Languange Modeling
|
30 |
+
3. Guided Language Generation
|
31 |
+
4. Low-resource Language Understanding
|
32 |
+
4. [Collect and plot results](#collect-and-plot-results)
|
33 |
+
|
34 |
+
|
35 |
+
## Dependencies
|
36 |
+
|
37 |
+
Pull docker from Docker Hub at: `chunyl/pytorch-transformers:v2`. Please see the instruction at [`doc/env.md`](doc/env.md)
|
38 |
+
|
39 |
+
The project is organized into the following structures, with ensential files & folders visualized. `output` saves the models checkpoints.
|
40 |
+
```
|
41 |
+
├── Optimus
|
42 |
+
└── code
|
43 |
+
├── examples
|
44 |
+
├── big_ae
|
45 |
+
├── modules
|
46 |
+
├── vae.py
|
47 |
+
└── ...
|
48 |
+
├── run_lm_vae_pretraining_phdist_beta.py
|
49 |
+
├── run_lm_vae_training.py
|
50 |
+
└── ...
|
51 |
+
├── pytorch_transformers
|
52 |
+
├── modeling_bert.py
|
53 |
+
├── modeling_gpt2.py
|
54 |
+
└── ...
|
55 |
+
├── scripts
|
56 |
+
├── scripts_docker
|
57 |
+
├── scripts_local
|
58 |
+
├── scripts_philly
|
59 |
+
└── data
|
60 |
+
└── datasets
|
61 |
+
├── wikipedia_json_64_filtered
|
62 |
+
└── ...
|
63 |
+
├── snli_data
|
64 |
+
└── ...
|
65 |
+
└── output
|
66 |
+
├── pretrain
|
67 |
+
├── LM
|
68 |
+
└── ...
|
69 |
+
```
|
70 |
+
|
71 |
+
## Prepare Datasets
|
72 |
+
|
73 |
+
Please download or preparation the data via following the instructions at [`data/download_datasets.md`](data/download_datasets.md).
|
74 |
+
|
75 |
+
## Model Training
|
76 |
+
|
77 |
+
**1. Pre-training on setences in Wikipedia**
|
78 |
+
|
79 |
+
We pre-trained our models on Philly (a Microsoft internal compute cluster), the code is specialized for multi-node multi-GPU compute on this platform. The pre-training main python is [`run_lm_vae_pretraining_phdist_beta.py`](code/examples/big_ae/run_lm_vae_pretraining_phdist_beta.py). You may need to adjust the distributed training scripts.
|
80 |
+
|
81 |
+
**2. Languange Modeling**
|
82 |
+
|
83 |
+
To have a fair comparison with existing VAE languange models, we consider a model with latent dimension 32. The pre-trained model is fine-tuned on four commonly datasets for one epoch. Please see the details at [`doc/optimus_finetune_language_models.md`](doc/optimus_finetune_language_models.md)
|
84 |
+
|
85 |
+
**3. Guided Language Generation**
|
86 |
+
|
87 |
+
|
88 |
+
**Latent Space Manipulation** To ensure good performance, we consider a model with latent dimension 768. The pre-trained model is fine-tuned on SNLI dataset, where sentences show related patterns. Please see the details at
|
89 |
+
Please see the details at [`doc/optimius_for_snli.md`](doc/optimius_for_snli.md)
|
90 |
+
|
91 |
+
**4. Low-resource Language Understanding**
|
92 |
+
|
93 |
+
## Collect and Plot Results
|
94 |
+
|
95 |
+
Once the networks are trained and the results are saved, we extracted key results using Python script. The results can be plotted using the included IPython notebook `plots/main_plots.ipynb`.
|
96 |
+
Start the IPython Notebook server:
|
97 |
+
|
98 |
+
```
|
99 |
+
$ cd plots
|
100 |
+
$ ipython notebook
|
101 |
+
```
|
102 |
+
|
103 |
+
Select the `main_plots.ipynb` notebook and execute the included
|
104 |
+
code. Note that without modification, we have copyed our extracted results into the notebook, and script will output figures in the paper. If you've run your own training and wish to plot results, you'll have to organize your results in the same format instead.
|
105 |
+
|
106 |
+
|
107 |
+
## Questions?
|
108 |
+
|
109 |
+
Please drop me ([Chunyuan](http://chunyuan.li/)) a line if you have any questions.
|
110 |
+
|
111 |
+
|
112 |
+
```
|
113 |
+
@inproceedings{li2020_Optimus,
|
114 |
+
title={Optimus: Organizing Sentences via Pre-trained Modeling of a Latent Space},
|
115 |
+
author={Li, Chunyuan and Gao, Xiang and Li, Yuan and Li, Xiujun and Peng, Baolin and Zhang, Yizhe and Gao, Jianfeng},
|
116 |
+
booktitle={EMNLP},
|
117 |
+
year={2020}
|
118 |
+
}
|
119 |
+
```
|
120 |
+
|
121 |
+
|
Optimus/code/README.md
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Set up Environment
|
2 |
+
|
3 |
+
Pull docker from Docker Hub at: chunyl/pytorch-transformers:v2
|
4 |
+
|
5 |
+
Edit the project path to the absolute path on your computer by changing the "SCRIPTPATH" in [run_docker.sh](./scripts/scripts_docker/run_docker.sh)
|
6 |
+
|
7 |
+
In this directory ("code"), and run docker
|
8 |
+
|
9 |
+
sh scripts/scripts_docker/run_docker.sh
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
## Fine-tune Language Models
|
15 |
+
|
16 |
+
sh scripts/scripts_local/run_ft_lm_vae_optimus.sh
|
17 |
+
|
18 |
+
|
19 |
+
The main training script is [`run_lm_vae_training.py`](./examples/big_ae/run_lm_vae_training.py) and conducts the fine-tuning loop, taking the following options (among others) as arguments:
|
20 |
+
|
21 |
+
- `--checkpoint_dir`: the folder that the pre-trained Optimus is saved.
|
22 |
+
- `--gloabl_step_eval`: it specifies the checkpoint (the steps that Optimus is trained).
|
23 |
+
- `--train_data_file` and `--eval_data_file`: the path for training and testing datasets for the downstream fine-tuning.
|
24 |
+
- `--dataset`: the dataset for fine-tuning. such as `Penn`
|
25 |
+
- `--num_train_epochs`: number of training epochs (type=int); default 1.
|
26 |
+
- `--dim_target_kl`: the hyper-paramter used in dimension-wise thresholding used in fine-tuning(type=float); default 0.5.
|
27 |
+
- `--beta`: the maximum beta value used in cyclical annealing schedule used in fine-tuning(type=float); default 1.0.
|
28 |
+
- `--ratio_zero`: the proportion of beta=0 in one period for fine-tuning(type=float); default 0.5
|
29 |
+
- `--ratio_increase`: the proportion of beta that increases from 0 to the maximum value in one period in cyclical annealing schedule used in fine-tuning(type=float); default 0.25.
|
30 |
+
|
31 |
+
|
32 |
+
For more options, please see [`run_lm_vae_training.py`](./examples/big_ae/run_lm_vae_training.py) and see the examples we provided in [`run_ft_lm_vae_optimus.sh`](./scripts/scripts_local/run_ft_lm_vae_optimus.sh), or [more running scripts we used to run the code on a cluster](./scripts/scripts_philly).
|
33 |
+
|
34 |
+
|
35 |
+
## Play with the latent space
|
36 |
+
|
37 |
+
sh scripts/scripts_local/eval_optimus_latent_space.sh
|
38 |
+
|
39 |
+
The main training script is [`run_latent_generation.py`](./examples/big_ae/run_latent_generation.py) and evaluates the various ways to generate text conditioned on latent vectors, taking the following options (among others) as arguments:
|
40 |
+
|
41 |
+
- `--play_mode`: The current scripts supports two ways to play with the pre-trained VAE models: [`reconstrction`, `interpolation`]
|
Optimus/code/app.py
ADDED
File without changes
|
Optimus/code/examples/README.md
ADDED
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Examples
|
2 |
+
|
3 |
+
In this section a few examples are put together. All of these examples work for several models, making use of the very
|
4 |
+
similar API between the different models.
|
5 |
+
|
6 |
+
| Section | Description |
|
7 |
+
|----------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
8 |
+
| [Language Model fine-tuning](#language-model-fine-tuning) | Fine-tuning the library models for language modeling on a text dataset. Causal language modeling for GPT/GPT-2, masked language modeling for BERT/RoBERTa. |
|
9 |
+
| [Language Generation](#language-generation) | Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, Transformer-XL and XLNet. |
|
10 |
+
| [GLUE](#glue) | Examples running BERT/XLM/XLNet/RoBERTa on the 9 GLUE tasks. Examples feature distributed training as well as half-precision. |
|
11 |
+
| [SQuAD](#squad) | Using BERT for question answering, examples with distributed training. |
|
12 |
+
| [Multiple Choice](#multiple choice) | Examples running BERT/XLNet/RoBERTa on the SWAG/RACE/ARC tasks.
|
13 |
+
|
14 |
+
## Language model fine-tuning
|
15 |
+
|
16 |
+
Based on the script [`run_lm_finetuning.py`](https://github.com/huggingface/pytorch-transformers/blob/master/examples/run_lm_finetuning.py).
|
17 |
+
|
18 |
+
Fine-tuning the library models for language modeling on a text dataset for GPT, GPT-2, BERT and RoBERTa (DistilBERT
|
19 |
+
to be added soon). GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa
|
20 |
+
are fine-tuned using a masked language modeling (MLM) loss.
|
21 |
+
|
22 |
+
Before running the following example, you should get a file that contains text on which the language model will be
|
23 |
+
fine-tuned. A good example of such text is the [WikiText-2 dataset](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/).
|
24 |
+
|
25 |
+
We will refer to two different files: `$TRAIN_FILE`, which contains text for training, and `$TEST_FILE`, which contains
|
26 |
+
text that will be used for evaluation.
|
27 |
+
|
28 |
+
### GPT-2/GPT and causal language modeling
|
29 |
+
|
30 |
+
The following example fine-tunes GPT-2 on WikiText-2. We're using the raw WikiText-2 (no tokens were replaced before
|
31 |
+
the tokenization). The loss here is that of causal language modeling.
|
32 |
+
|
33 |
+
```bash
|
34 |
+
export TRAIN_FILE=/path/to/dataset/wiki.train.raw
|
35 |
+
export TEST_FILE=/path/to/dataset/wiki.test.raw
|
36 |
+
|
37 |
+
python run_lm_finetuning.py \
|
38 |
+
--output_dir=output \
|
39 |
+
--model_type=gpt2 \
|
40 |
+
--model_name_or_path=gpt2 \
|
41 |
+
--do_train \
|
42 |
+
--train_data_file=$TRAIN_FILE \
|
43 |
+
--do_eval \
|
44 |
+
--eval_data_file=$TEST_FILE
|
45 |
+
```
|
46 |
+
|
47 |
+
This takes about half an hour to train on a single K80 GPU and about one minute for the evaluation to run. It reaches
|
48 |
+
a score of ~20 perplexity once fine-tuned on the dataset.
|
49 |
+
|
50 |
+
### RoBERTa/BERT and masked language modeling
|
51 |
+
|
52 |
+
The following example fine-tunes RoBERTa on WikiText-2. Here too, we're using the raw WikiText-2. The loss is different
|
53 |
+
as BERT/RoBERTa have a bidirectional mechanism; we're therefore using the same loss that was used during their
|
54 |
+
pre-training: masked language modeling.
|
55 |
+
|
56 |
+
In accordance to the RoBERTa paper, we use dynamic masking rather than static masking. The model may, therefore, converge
|
57 |
+
slightly slower (over-fitting takes more epochs).
|
58 |
+
|
59 |
+
We use the `--mlm` flag so that the script may change its loss function.
|
60 |
+
|
61 |
+
```bash
|
62 |
+
export TRAIN_FILE=/path/to/dataset/wiki.train.raw
|
63 |
+
export TEST_FILE=/path/to/dataset/wiki.test.raw
|
64 |
+
|
65 |
+
python run_lm_finetuning.py \
|
66 |
+
--output_dir=output \
|
67 |
+
--model_type=roberta \
|
68 |
+
--model_name_or_path=roberta-base \
|
69 |
+
--do_train \
|
70 |
+
--train_data_file=$TRAIN_FILE \
|
71 |
+
--do_eval \
|
72 |
+
--eval_data_file=$TEST_FILE \
|
73 |
+
--mlm
|
74 |
+
```
|
75 |
+
|
76 |
+
## Language generation
|
77 |
+
|
78 |
+
Based on the script [`run_generation.py`](https://github.com/huggingface/pytorch-transformers/blob/master/examples/run_generation.py).
|
79 |
+
|
80 |
+
Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, Transformer-XL and XLNet.
|
81 |
+
A similar script is used for our official demo [Write With Transfomer](https://transformer.huggingface.co), where you
|
82 |
+
can try out the different models available in the library.
|
83 |
+
|
84 |
+
Example usage:
|
85 |
+
|
86 |
+
```bash
|
87 |
+
python run_generation.py \
|
88 |
+
--model_type=gpt2 \
|
89 |
+
--model_name_or_path=gpt2
|
90 |
+
```
|
91 |
+
|
92 |
+
## GLUE
|
93 |
+
|
94 |
+
Based on the script [`run_glue.py`](https://github.com/huggingface/pytorch-transformers/blob/master/examples/run_glue.py).
|
95 |
+
|
96 |
+
Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding
|
97 |
+
Evaluation](https://gluebenchmark.com/). This script can fine-tune the following models: BERT, XLM, XLNet and RoBERTa.
|
98 |
+
|
99 |
+
GLUE is made up of a total of 9 different tasks. We get the following results on the dev set of the benchmark with an
|
100 |
+
uncased BERT base model (the checkpoint `bert-base-uncased`). All experiments ran on 8 V100 GPUs with a total train
|
101 |
+
batch size of 24. Some of these tasks have a small dataset and training can lead to high variance in the results
|
102 |
+
between different runs. We report the median on 5 runs (with different seeds) for each of the metrics.
|
103 |
+
|
104 |
+
| Task | Metric | Result |
|
105 |
+
|-------|------------------------------|-------------|
|
106 |
+
| CoLA | Matthew's corr | 48.87 |
|
107 |
+
| SST-2 | Accuracy | 91.74 |
|
108 |
+
| MRPC | F1/Accuracy | 90.70/86.27 |
|
109 |
+
| STS-B | Person/Spearman corr. | 91.39/91.04 |
|
110 |
+
| QQP | Accuracy/F1 | 90.79/87.66 |
|
111 |
+
| MNLI | Matched acc./Mismatched acc. | 83.70/84.83 |
|
112 |
+
| QNLI | Accuracy | 89.31 |
|
113 |
+
| RTE | Accuracy | 71.43 |
|
114 |
+
| WNLI | Accuracy | 43.66 |
|
115 |
+
|
116 |
+
Some of these results are significantly different from the ones reported on the test set
|
117 |
+
of GLUE benchmark on the website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the webite.
|
118 |
+
|
119 |
+
Before running anyone of these GLUE tasks you should download the
|
120 |
+
[GLUE data](https://gluebenchmark.com/tasks) by running
|
121 |
+
[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
|
122 |
+
and unpack it to some directory `$GLUE_DIR`.
|
123 |
+
|
124 |
+
```bash
|
125 |
+
export GLUE_DIR=/path/to/glue
|
126 |
+
export TASK_NAME=MRPC
|
127 |
+
|
128 |
+
python run_glue.py \
|
129 |
+
--model_type bert \
|
130 |
+
--model_name_or_path bert-base-cased \
|
131 |
+
--task_name $TASK_NAME \
|
132 |
+
--do_train \
|
133 |
+
--do_eval \
|
134 |
+
--do_lower_case \
|
135 |
+
--data_dir $GLUE_DIR/$TASK_NAME \
|
136 |
+
--max_seq_length 128 \
|
137 |
+
--per_gpu_train_batch_size 32 \
|
138 |
+
--learning_rate 2e-5 \
|
139 |
+
--num_train_epochs 3.0 \
|
140 |
+
--output_dir /tmp/$TASK_NAME/
|
141 |
+
```
|
142 |
+
|
143 |
+
where task name can be one of CoLA, SST-2, MRPC, STS-B, QQP, MNLI, QNLI, RTE, WNLI.
|
144 |
+
|
145 |
+
The dev set results will be present within the text file `eval_results.txt` in the specified output_dir.
|
146 |
+
In case of MNLI, since there are two separate dev sets (matched and mismatched), there will be a separate
|
147 |
+
output folder called `/tmp/MNLI-MM/` in addition to `/tmp/MNLI/`.
|
148 |
+
|
149 |
+
The code has not been tested with half-precision training with apex on any GLUE task apart from MRPC, MNLI,
|
150 |
+
CoLA, SST-2. The following section provides details on how to run half-precision training with MRPC. With that being
|
151 |
+
said, there shouldn’t be any issues in running half-precision training with the remaining GLUE tasks as well,
|
152 |
+
since the data processor for each task inherits from the base class DataProcessor.
|
153 |
+
|
154 |
+
### MRPC
|
155 |
+
|
156 |
+
#### Fine-tuning example
|
157 |
+
|
158 |
+
The following examples fine-tune BERT on the Microsoft Research Paraphrase Corpus (MRPC) corpus and runs in less
|
159 |
+
than 10 minutes on a single K-80 and in 27 seconds (!) on single tesla V100 16GB with apex installed.
|
160 |
+
|
161 |
+
Before running anyone of these GLUE tasks you should download the
|
162 |
+
[GLUE data](https://gluebenchmark.com/tasks) by running
|
163 |
+
[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
|
164 |
+
and unpack it to some directory `$GLUE_DIR`.
|
165 |
+
|
166 |
+
```bash
|
167 |
+
export GLUE_DIR=/path/to/glue
|
168 |
+
|
169 |
+
python run_glue.py \
|
170 |
+
--model_type bert \
|
171 |
+
--model_name_or_path bert-base-cased \
|
172 |
+
--task_name MRPC \
|
173 |
+
--do_train \
|
174 |
+
--do_eval \
|
175 |
+
--do_lower_case \
|
176 |
+
--data_dir $GLUE_DIR/MRPC/ \
|
177 |
+
--max_seq_length 128 \
|
178 |
+
--per_gpu_train_batch_size 32 \
|
179 |
+
--learning_rate 2e-5 \
|
180 |
+
--num_train_epochs 3.0 \
|
181 |
+
--output_dir /tmp/mrpc_output/
|
182 |
+
```
|
183 |
+
|
184 |
+
Our test ran on a few seeds with [the original implementation hyper-
|
185 |
+
parameters](https://github.com/google-research/bert#sentence-and-sentence-pair-classification-tasks) gave evaluation
|
186 |
+
results between 84% and 88%.
|
187 |
+
|
188 |
+
#### Using Apex and mixed-precision
|
189 |
+
|
190 |
+
Using Apex and 16 bit precision, the fine-tuning on MRPC only takes 27 seconds. First install
|
191 |
+
[apex](https://github.com/NVIDIA/apex), then run the following example:
|
192 |
+
|
193 |
+
```bash
|
194 |
+
export GLUE_DIR=/path/to/glue
|
195 |
+
|
196 |
+
python run_glue.py \
|
197 |
+
--model_type bert \
|
198 |
+
--model_name_or_path bert-base-cased \
|
199 |
+
--task_name MRPC \
|
200 |
+
--do_train \
|
201 |
+
--do_eval \
|
202 |
+
--do_lower_case \
|
203 |
+
--data_dir $GLUE_DIR/MRPC/ \
|
204 |
+
--max_seq_length 128 \
|
205 |
+
--per_gpu_train_batch_size 32 \
|
206 |
+
--learning_rate 2e-5 \
|
207 |
+
--num_train_epochs 3.0 \
|
208 |
+
--output_dir /tmp/mrpc_output/ \
|
209 |
+
--fp16
|
210 |
+
```
|
211 |
+
|
212 |
+
#### Distributed training
|
213 |
+
|
214 |
+
Here is an example using distributed training on 8 V100 GPUs. The model used is the BERT whole-word-masking and it
|
215 |
+
reaches F1 > 92 on MRPC.
|
216 |
+
|
217 |
+
```bash
|
218 |
+
export GLUE_DIR=/path/to/glue
|
219 |
+
|
220 |
+
python -m torch.distributed.launch \
|
221 |
+
--nproc_per_node 8 run_glue.py \
|
222 |
+
--model_type bert \
|
223 |
+
--model_name_or_path bert-base-cased \
|
224 |
+
--task_name MRPC \
|
225 |
+
--do_train \
|
226 |
+
--do_eval \
|
227 |
+
--do_lower_case \
|
228 |
+
--data_dir $GLUE_DIR/MRPC/ \
|
229 |
+
--max_seq_length 128 \
|
230 |
+
--per_gpu_train_batch_size 8 \
|
231 |
+
--learning_rate 2e-5 \
|
232 |
+
--num_train_epochs 3.0 \
|
233 |
+
--output_dir /tmp/mrpc_output/
|
234 |
+
```
|
235 |
+
|
236 |
+
Training with these hyper-parameters gave us the following results:
|
237 |
+
|
238 |
+
```bash
|
239 |
+
acc = 0.8823529411764706
|
240 |
+
acc_and_f1 = 0.901702786377709
|
241 |
+
eval_loss = 0.3418912578906332
|
242 |
+
f1 = 0.9210526315789473
|
243 |
+
global_step = 174
|
244 |
+
loss = 0.07231863956341798
|
245 |
+
```
|
246 |
+
|
247 |
+
### MNLI
|
248 |
+
|
249 |
+
The following example uses the BERT-large, uncased, whole-word-masking model and fine-tunes it on the MNLI task.
|
250 |
+
|
251 |
+
```bash
|
252 |
+
export GLUE_DIR=/path/to/glue
|
253 |
+
|
254 |
+
python -m torch.distributed.launch \
|
255 |
+
--nproc_per_node 8 run_glue.py \
|
256 |
+
--model_type bert \
|
257 |
+
--model_name_or_path bert-base-cased \
|
258 |
+
--task_name mnli \
|
259 |
+
--do_train \
|
260 |
+
--do_eval \
|
261 |
+
--do_lower_case \
|
262 |
+
--data_dir $GLUE_DIR/MNLI/ \
|
263 |
+
--max_seq_length 128 \
|
264 |
+
--per_gpu_train_batch_size 8 \
|
265 |
+
--learning_rate 2e-5 \
|
266 |
+
--num_train_epochs 3.0 \
|
267 |
+
--output_dir output_dir \
|
268 |
+
```
|
269 |
+
|
270 |
+
The results are the following:
|
271 |
+
|
272 |
+
```bash
|
273 |
+
***** Eval results *****
|
274 |
+
acc = 0.8679706601466992
|
275 |
+
eval_loss = 0.4911287787382479
|
276 |
+
global_step = 18408
|
277 |
+
loss = 0.04755385363816904
|
278 |
+
|
279 |
+
***** Eval results *****
|
280 |
+
acc = 0.8747965825874695
|
281 |
+
eval_loss = 0.45516540421714036
|
282 |
+
global_step = 18408
|
283 |
+
loss = 0.04755385363816904
|
284 |
+
```
|
285 |
+
|
286 |
+
##Multiple Choice
|
287 |
+
|
288 |
+
Based on the script [`run_multiple_choice.py`]().
|
289 |
+
|
290 |
+
#### Fine-tuning on SWAG
|
291 |
+
Download [swag](https://github.com/rowanz/swagaf/tree/master/data) data
|
292 |
+
|
293 |
+
```
|
294 |
+
#training on 4 tesla V100(16GB) GPUS
|
295 |
+
export SWAG_DIR=/path/to/swag_data_dir
|
296 |
+
python ./examples/single_model_scripts/run_multiple_choice.py \
|
297 |
+
--model_type roberta \
|
298 |
+
--task_name swag \
|
299 |
+
--model_name_or_path roberta-base \
|
300 |
+
--do_train \
|
301 |
+
--do_eval \
|
302 |
+
--do_lower_case \
|
303 |
+
--data_dir $SWAG_DIR \
|
304 |
+
--learning_rate 5e-5 \
|
305 |
+
--num_train_epochs 3 \
|
306 |
+
--max_seq_length 80 \
|
307 |
+
--output_dir models_bert/swag_base \
|
308 |
+
--per_gpu_eval_batch_size=16 \
|
309 |
+
--per_gpu_train_batch_size=16 \
|
310 |
+
--gradient_accumulation_steps 2 \
|
311 |
+
--overwrite_output
|
312 |
+
```
|
313 |
+
Training with the defined hyper-parameters yields the following results:
|
314 |
+
```
|
315 |
+
***** Eval results *****
|
316 |
+
eval_acc = 0.8338998300509847
|
317 |
+
eval_loss = 0.44457291918821606
|
318 |
+
```
|
319 |
+
|
320 |
+
## SQuAD
|
321 |
+
|
322 |
+
Based on the script [`run_squad.py`](https://github.com/huggingface/pytorch-transformers/blob/master/examples/run_squad.py).
|
323 |
+
|
324 |
+
#### Fine-tuning on SQuAD
|
325 |
+
|
326 |
+
This example code fine-tunes BERT on the SQuAD dataset. It runs in 24 min (with BERT-base) or 68 min (with BERT-large)
|
327 |
+
on a single tesla V100 16GB. The data for SQuAD can be downloaded with the following links and should be saved in a
|
328 |
+
$SQUAD_DIR directory.
|
329 |
+
|
330 |
+
* [train-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json)
|
331 |
+
* [dev-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json)
|
332 |
+
* [evaluate-v1.1.py](https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py)
|
333 |
+
|
334 |
+
```bash
|
335 |
+
export SQUAD_DIR=/path/to/SQUAD
|
336 |
+
|
337 |
+
python run_squad.py \
|
338 |
+
--model_type bert \
|
339 |
+
--model_name_or_path bert-base-cased \
|
340 |
+
--do_train \
|
341 |
+
--do_eval \
|
342 |
+
--do_lower_case \
|
343 |
+
--train_file $SQUAD_DIR/train-v1.1.json \
|
344 |
+
--predict_file $SQUAD_DIR/dev-v1.1.json \
|
345 |
+
--per_gpu_train_batch_size 12 \
|
346 |
+
--learning_rate 3e-5 \
|
347 |
+
--num_train_epochs 2.0 \
|
348 |
+
--max_seq_length 384 \
|
349 |
+
--doc_stride 128 \
|
350 |
+
--output_dir /tmp/debug_squad/
|
351 |
+
```
|
352 |
+
|
353 |
+
Training with the previously defined hyper-parameters yields the following results:
|
354 |
+
|
355 |
+
```bash
|
356 |
+
f1 = 88.52
|
357 |
+
exact_match = 81.22
|
358 |
+
```
|
359 |
+
|
360 |
+
#### Distributed training
|
361 |
+
|
362 |
+
|
363 |
+
Here is an example using distributed training on 8 V100 GPUs and Bert Whole Word Masking uncased model to reach a F1 > 93 on SQuAD:
|
364 |
+
|
365 |
+
```bash
|
366 |
+
python -m torch.distributed.launch --nproc_per_node=8 run_squad.py \
|
367 |
+
--model_type bert \
|
368 |
+
--model_name_or_path bert-base-cased \
|
369 |
+
--do_train \
|
370 |
+
--do_eval \
|
371 |
+
--do_lower_case \
|
372 |
+
--train_file $SQUAD_DIR/train-v1.1.json \
|
373 |
+
--predict_file $SQUAD_DIR/dev-v1.1.json \
|
374 |
+
--learning_rate 3e-5 \
|
375 |
+
--num_train_epochs 2 \
|
376 |
+
--max_seq_length 384 \
|
377 |
+
--doc_stride 128 \
|
378 |
+
--output_dir ../models/wwm_uncased_finetuned_squad/ \
|
379 |
+
--per_gpu_train_batch_size 24 \
|
380 |
+
--gradient_accumulation_steps 12
|
381 |
+
```
|
382 |
+
|
383 |
+
Training with the previously defined hyper-parameters yields the following results:
|
384 |
+
|
385 |
+
```bash
|
386 |
+
f1 = 93.15
|
387 |
+
exact_match = 86.91
|
388 |
+
```
|
389 |
+
|
390 |
+
This fine-tuneds model is available as a checkpoint under the reference
|
391 |
+
`bert-large-uncased-whole-word-masking-finetuned-squad`.
|
392 |
+
|
Optimus/code/examples/__pycache__/utils_glue.cpython-37.pyc
ADDED
Binary file (21.5 kB). View file
|
|
Optimus/code/examples/big_ae/__pycache__/grad_app.cpython-310.pyc
ADDED
Binary file (14 kB). View file
|
|
Optimus/code/examples/big_ae/__pycache__/utils.cpython-37.pyc
ADDED
Binary file (40.3 kB). View file
|
|
Optimus/code/examples/big_ae/debug_data.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
|
4 |
+
output_dir = "../output/philly_rr1_vae_wikipedia_pretraining_2nd_file"
|
5 |
+
|
6 |
+
data = torch.load(os.path.join(output_dir, 'batch_debug_6621.pt')
|
Optimus/code/examples/big_ae/eval_dialog_multi_response.py
ADDED
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from nltk.translate.bleu_score import sentence_bleu
|
5 |
+
from nltk.translate.bleu_score import SmoothingFunction
|
6 |
+
from sklearn.metrics.pairwise import cosine_similarity as cosine
|
7 |
+
from collections import Counter
|
8 |
+
import os, pickle, pdb
|
9 |
+
|
10 |
+
class Metrics:
|
11 |
+
# based on https://raw.githubusercontent.com/guxd/DialogWAE/29f206af05bfe5fe28fec4448e208310a7c9258d/experiments/metrics.py
|
12 |
+
|
13 |
+
def __init__(self, path_word2vec='../data/datasets/dailydialog_data/glove.twitter.27B.200d.txt'):
|
14 |
+
"""
|
15 |
+
:param word2vec - a numpy array of word2vec with shape [vocab_size x emb_size]
|
16 |
+
"""
|
17 |
+
super(Metrics, self).__init__()
|
18 |
+
self.load_word2vec(path_word2vec)
|
19 |
+
#self.word2vec = dict()
|
20 |
+
|
21 |
+
def load_word2vec(self, path_word2vec):
|
22 |
+
path_pkl = path_word2vec + '.pkl'
|
23 |
+
if os.path.exists(path_pkl):
|
24 |
+
print('loading word2vec from '+path_pkl)
|
25 |
+
self.word2vec = pickle.load(open(path_pkl, 'rb'))
|
26 |
+
else:
|
27 |
+
self.word2vec = dict()
|
28 |
+
for i, line in enumerate(open(path_word2vec, encoding='utf-8')):
|
29 |
+
ss = line.strip('\n').split()
|
30 |
+
self.word2vec[ss[0]] = [float(v) for v in ss[1:]]
|
31 |
+
if i % 1e4 == 0:
|
32 |
+
print('processed %ik word2vec'%(i/1e3))
|
33 |
+
print('dumping word2vec to '+path_pkl)
|
34 |
+
pickle.dump(self.word2vec, open(path_pkl, 'wb'))
|
35 |
+
self.embed_dim = len(list(self.word2vec.values())[0])
|
36 |
+
print('loaded %i word2vec of dim %i'%(len(self.word2vec), self.embed_dim))
|
37 |
+
|
38 |
+
def embedding(self, seqs):
|
39 |
+
# note: different from original implementation
|
40 |
+
batch_size, seqlen = seqs.shape
|
41 |
+
embs = np.zeros([batch_size, seqlen, self.embed_dim])
|
42 |
+
for i in range(batch_size):
|
43 |
+
for j in range(seqlen):
|
44 |
+
w = seqs[i,j]
|
45 |
+
if w != '' and w in self.word2vec:
|
46 |
+
embs[i, j, :] = self.word2vec[w]
|
47 |
+
return embs
|
48 |
+
|
49 |
+
|
50 |
+
def extrema(self, embs, lens): # embs: [batch_size x seq_len x emb_size] lens: [batch_size]
|
51 |
+
"""
|
52 |
+
computes the value of every single dimension in the word vectors which has the greatest
|
53 |
+
difference from zero.
|
54 |
+
:param seq: sequence
|
55 |
+
:param seqlen: length of sequence
|
56 |
+
"""
|
57 |
+
# Find minimum and maximum value for every dimension in predictions
|
58 |
+
batch_size, seq_len, emb_size = embs.shape
|
59 |
+
max_mask = np.zeros((batch_size, seq_len, emb_size), dtype=np.int)
|
60 |
+
for i,length in enumerate(lens):
|
61 |
+
max_mask[i,:length,:]=1
|
62 |
+
min_mask = 1-max_mask
|
63 |
+
seq_max = (embs*max_mask).max(1) # [batch_sz x emb_sz]
|
64 |
+
seq_min = (embs+min_mask).min(1)
|
65 |
+
# Find the maximum absolute value in min and max data
|
66 |
+
comp_mask = seq_max >= np.abs(seq_min)# [batch_sz x emb_sz]
|
67 |
+
# Add vectors for finding final sequence representation for predictions
|
68 |
+
extrema_emb = seq_max* comp_mask + seq_min* np.logical_not(comp_mask)
|
69 |
+
return extrema_emb
|
70 |
+
|
71 |
+
def mean(self, embs, lens):
|
72 |
+
batch_size, seq_len, emb_size=embs.shape
|
73 |
+
mask = np.zeros((batch_size, seq_len, emb_size), dtype=np.int)
|
74 |
+
for i,length in enumerate(lens):
|
75 |
+
mask[i,:length,:]=1
|
76 |
+
return (embs*mask).sum(1)/(mask.sum(1)+1e-8)
|
77 |
+
|
78 |
+
def sim_bleu(self, hyps, ref):
|
79 |
+
"""
|
80 |
+
:param ref - a list of tokens of the reference
|
81 |
+
:param hyps - a list of tokens of the hypothesis
|
82 |
+
|
83 |
+
:return maxbleu - recall bleu
|
84 |
+
:return avgbleu - precision bleu
|
85 |
+
"""
|
86 |
+
scores = []
|
87 |
+
for hyp in hyps:
|
88 |
+
try:
|
89 |
+
scores.append(sentence_bleu([ref], hyp, smoothing_function=SmoothingFunction().method7,
|
90 |
+
weights=[1./3, 1./3, 1./3]))
|
91 |
+
except:
|
92 |
+
scores.append(0.0)
|
93 |
+
return np.max(scores), np.mean(scores)
|
94 |
+
|
95 |
+
|
96 |
+
def sim_bow(self, pred, pred_lens, ref, ref_lens):
|
97 |
+
"""
|
98 |
+
:param pred - ndarray [batch_size x seqlen]
|
99 |
+
:param pred_lens - list of integers
|
100 |
+
:param ref - ndarray [batch_size x seqlen]
|
101 |
+
"""
|
102 |
+
# look up word embeddings for prediction and reference
|
103 |
+
emb_pred = self.embedding(pred) # [batch_sz x seqlen1 x emb_sz]
|
104 |
+
emb_ref = self.embedding(ref) # [batch_sz x seqlen2 x emb_sz]
|
105 |
+
|
106 |
+
ext_emb_pred=self.extrema(emb_pred, pred_lens)
|
107 |
+
ext_emb_ref=self.extrema(emb_ref, ref_lens)
|
108 |
+
bow_extrema=cosine(ext_emb_pred, ext_emb_ref) # [batch_sz_pred x batch_sz_ref]
|
109 |
+
|
110 |
+
avg_emb_pred = self.mean(emb_pred, pred_lens) # Calculate mean over seq
|
111 |
+
avg_emb_ref = self.mean(emb_ref, ref_lens)
|
112 |
+
bow_avg = cosine(avg_emb_pred, avg_emb_ref) # [batch_sz_pred x batch_sz_ref]
|
113 |
+
|
114 |
+
|
115 |
+
batch_pred, seqlen_pred, emb_size=emb_pred.shape
|
116 |
+
batch_ref, seqlen_ref, emb_size=emb_ref.shape
|
117 |
+
cos_sim = cosine(emb_pred.reshape((-1, emb_size)), emb_ref.reshape((-1, emb_size))) # [(batch_sz*seqlen1)x(batch_sz*seqlen2)]
|
118 |
+
cos_sim = cos_sim.reshape((batch_pred, seqlen_pred, batch_ref, seqlen_ref))
|
119 |
+
# Find words with max cosine similarity
|
120 |
+
max12 = cos_sim.max(1).mean(2) # max over seqlen_pred
|
121 |
+
max21 = cos_sim.max(3).mean(1) # max over seqlen_ref
|
122 |
+
bow_greedy=(max12+max21)/2 # [batch_pred x batch_ref(1)]
|
123 |
+
return np.max(bow_extrema), np.max(bow_avg), np.max(bow_greedy)
|
124 |
+
|
125 |
+
def div_distinct(self, seqs, seq_lens):
|
126 |
+
"""
|
127 |
+
distinct-1 distinct-2 metrics for diversity measure proposed
|
128 |
+
by Li et al. "A Diversity-Promoting Objective Function for Neural Conversation Models"
|
129 |
+
we counted numbers of distinct unigrams and bigrams in the generated responses
|
130 |
+
and divide the numbers by total number of unigrams and bigrams.
|
131 |
+
The two metrics measure how informative and diverse the generated responses are.
|
132 |
+
High numbers and high ratios mean that there is much content in the generated responses,
|
133 |
+
and high numbers further indicate that the generated responses are long
|
134 |
+
"""
|
135 |
+
batch_size = seqs.shape[0]
|
136 |
+
intra_dist1, intra_dist2=np.zeros(batch_size), np.zeros(batch_size)
|
137 |
+
|
138 |
+
n_unigrams, n_bigrams, n_unigrams_total , n_bigrams_total = 0. ,0., 0., 0.
|
139 |
+
unigrams_all, bigrams_all = Counter(), Counter()
|
140 |
+
for b in range(batch_size):
|
141 |
+
unigrams= Counter([tuple(seqs[b,i:i+1]) for i in range(seq_lens[b])])
|
142 |
+
bigrams = Counter([tuple(seqs[b,i:i+2]) for i in range(seq_lens[b]-1)])
|
143 |
+
intra_dist1[b]=(len(unigrams.items())+1e-12)/(seq_lens[b]+1e-5)
|
144 |
+
intra_dist2[b]=(len(bigrams.items())+1e-12)/(max(0, seq_lens[b]-1)+1e-5)
|
145 |
+
|
146 |
+
unigrams_all.update([tuple(seqs[b,i:i+1]) for i in range(seq_lens[b])])
|
147 |
+
bigrams_all.update([tuple(seqs[b,i:i+2]) for i in range(seq_lens[b]-1)])
|
148 |
+
n_unigrams_total += seq_lens[b]
|
149 |
+
n_bigrams_total += max(0, seq_lens[b]-1)
|
150 |
+
|
151 |
+
inter_dist1 = (len(unigrams_all.items())+1e-12)/(n_unigrams_total+1e-5)
|
152 |
+
inter_dist2 = (len(bigrams_all.items())+1e-12)/(n_bigrams_total+1e-5)
|
153 |
+
return intra_dist1, intra_dist2, inter_dist1, inter_dist2
|
154 |
+
|
155 |
+
import pdb
|
156 |
+
|
157 |
+
def eval_multi_ref(path, path_multi_ref=None):
|
158 |
+
"""
|
159 |
+
based on: https://github.com/guxd/DialogWAE/blob/29f206af05bfe5fe28fec4448e208310a7c9258d/sample.py
|
160 |
+
path: each line is '\t'.join([src, ref, hyp])
|
161 |
+
path_multi_ref: each line is '\t'.join([src, hyp])
|
162 |
+
the order of unique src appeared in `path_multi_ref` should be the same as that in `path`
|
163 |
+
"""
|
164 |
+
metrics = Metrics()
|
165 |
+
d_ref = dict()
|
166 |
+
d_hyp = dict()
|
167 |
+
src2ix = dict()
|
168 |
+
ix2src = dict()
|
169 |
+
ix = 0
|
170 |
+
for line in open(path, encoding='utf-8'):
|
171 |
+
line = line.strip('\n').strip()
|
172 |
+
if len(line) == 0:
|
173 |
+
continue
|
174 |
+
|
175 |
+
# pdb.set_trace()
|
176 |
+
src, ref, hyp = line.split('\t')
|
177 |
+
#src, ref = line.split('\t'); hyp = ref
|
178 |
+
src = src.replace(' EOS ',' [SEP] ').strip()
|
179 |
+
ref = ref.strip().split()
|
180 |
+
hyp = hyp.strip().split()
|
181 |
+
if src not in d_ref:
|
182 |
+
d_ref[src] = ref
|
183 |
+
d_hyp[src] = [hyp]
|
184 |
+
src2ix[src] = ix
|
185 |
+
ix2src[ix] = src
|
186 |
+
ix += 1
|
187 |
+
else:
|
188 |
+
d_hyp[src].append(hyp)
|
189 |
+
print('loaded %i src-ref-hyp tuples'%(len(d_ref)))
|
190 |
+
|
191 |
+
def chr_only(s):
|
192 |
+
ret = ''
|
193 |
+
for c in s:
|
194 |
+
if c.isalpha():
|
195 |
+
ret += c
|
196 |
+
return ret
|
197 |
+
|
198 |
+
if path_multi_ref is not None:
|
199 |
+
set_src4multiref = set()
|
200 |
+
ix = -1
|
201 |
+
d_multi_ref = dict()
|
202 |
+
for line in open(path_multi_ref, encoding='utf-8'):
|
203 |
+
line = line.strip('\n').strip()
|
204 |
+
if len(line) == 0:
|
205 |
+
continue
|
206 |
+
src4multiref, ref = line.split('\t')[:2]
|
207 |
+
src4multiref = src4multiref.replace(' EOS ', ' ').replace(' [SEP] ',' ').strip()
|
208 |
+
ref = ref.strip().split()
|
209 |
+
if src4multiref not in set_src4multiref:
|
210 |
+
set_src4multiref.add(src4multiref)
|
211 |
+
ix += 1
|
212 |
+
src = ix2src[ix]
|
213 |
+
id_hyp = chr_only(src)
|
214 |
+
id_multiref = chr_only(src4multiref)
|
215 |
+
if id_multiref != id_hyp:
|
216 |
+
print('[ERROR] cannot match src4multiref and src4hyp')
|
217 |
+
print('src4multiref:', src4multiref)
|
218 |
+
print('src4hyp:', ix2src[ix])
|
219 |
+
# pdb.set_trace()
|
220 |
+
raise ValueError
|
221 |
+
d_multi_ref[src] = [ref]
|
222 |
+
else:
|
223 |
+
d_multi_ref[src].append(ref)
|
224 |
+
|
225 |
+
n_ref = [len(d_multi_ref[k]) for k in d_multi_ref]
|
226 |
+
print('loaded %i src with multi-ref, avg n_ref = %.3f'%(len(d_multi_ref), np.mean(n_ref)))
|
227 |
+
|
228 |
+
n_miss = 0
|
229 |
+
for src in d_ref:
|
230 |
+
if src not in d_multi_ref:
|
231 |
+
n_miss += 1
|
232 |
+
print('[WARNING] cannot find multiref for src: '+src)
|
233 |
+
d_multi_ref[src] = [d_ref[src]]
|
234 |
+
if n_miss > 5:
|
235 |
+
raise ValueError
|
236 |
+
|
237 |
+
n = len(d_ref)
|
238 |
+
print(path)
|
239 |
+
print('n_src\t%i'%n)
|
240 |
+
|
241 |
+
avg_lens = 0
|
242 |
+
maxbleu = 0
|
243 |
+
avgbleu = 0
|
244 |
+
intra_dist1, intra_dist2, inter_dist1, inter_dist2 = 0,0,0,0
|
245 |
+
bow_extrema, bow_avg, bow_greedy = 0,0,0
|
246 |
+
for src in d_ref:
|
247 |
+
|
248 |
+
# BLEU ----
|
249 |
+
|
250 |
+
if path_multi_ref is None:
|
251 |
+
m, a = metrics.sim_bleu(d_hyp[src], d_ref[src])
|
252 |
+
else:
|
253 |
+
n_ref = len(d_multi_ref[src])
|
254 |
+
m, a = 0, 0
|
255 |
+
for ref in d_multi_ref[src]:
|
256 |
+
_m, _a = metrics.sim_bleu(d_hyp[src], ref)
|
257 |
+
m += _m
|
258 |
+
a += _a
|
259 |
+
m /= n_ref
|
260 |
+
a /= n_ref
|
261 |
+
|
262 |
+
maxbleu += m
|
263 |
+
avgbleu += a
|
264 |
+
|
265 |
+
# diversity ----
|
266 |
+
|
267 |
+
seq_len = [len(hyp) for hyp in d_hyp[src]]
|
268 |
+
max_len = max(seq_len)
|
269 |
+
seqs = []
|
270 |
+
for hyp in d_hyp[src]:
|
271 |
+
padded = hyp + [''] * (max_len - len(hyp))
|
272 |
+
seqs.append(np.reshape(padded, [1, -1]))
|
273 |
+
seqs = np.concatenate(seqs, axis=0)
|
274 |
+
intra1, intra2, inter1, inter2 = metrics.div_distinct(seqs, seq_len)
|
275 |
+
intra_dist1 += np.mean(intra1)
|
276 |
+
intra_dist2 += np.mean(intra2)
|
277 |
+
inter_dist1 += inter1
|
278 |
+
inter_dist2 += inter2
|
279 |
+
|
280 |
+
avg_lens += np.mean(seq_len)
|
281 |
+
|
282 |
+
# BOW ----
|
283 |
+
|
284 |
+
def calc_bow(ref):
|
285 |
+
n_hyp = len(d_hyp[src])
|
286 |
+
seqs_ref = np.concatenate([np.reshape(ref, [1,-1])] * n_hyp, axis=0)
|
287 |
+
seq_len_ref = [len(ref)] * n_hyp
|
288 |
+
return metrics.sim_bow(seqs, seq_len, seqs_ref, seq_len_ref)
|
289 |
+
|
290 |
+
if path_multi_ref is None:
|
291 |
+
extrema, avg, greedy = calc_bow(d_ref[src])
|
292 |
+
else:
|
293 |
+
extrema, avg, greedy = 0, 0, 0
|
294 |
+
for ref in d_multi_ref[src]:
|
295 |
+
e, a, g = calc_bow(ref)
|
296 |
+
extrema += e
|
297 |
+
avg += a
|
298 |
+
greedy += g
|
299 |
+
extrema /= n_ref
|
300 |
+
avg /= n_ref
|
301 |
+
greedy /= n_ref
|
302 |
+
|
303 |
+
bow_extrema += extrema
|
304 |
+
bow_avg += avg
|
305 |
+
bow_greedy += greedy
|
306 |
+
|
307 |
+
recall_bleu = maxbleu/n
|
308 |
+
prec_bleu = avgbleu/n
|
309 |
+
f1 = 2*(prec_bleu*recall_bleu) / (prec_bleu+recall_bleu+10e-12)
|
310 |
+
|
311 |
+
print('BLEU')
|
312 |
+
print(' R\t%.3f'%recall_bleu)
|
313 |
+
print(' P\t%.3f'%prec_bleu)
|
314 |
+
print(' F1\t%.3f'%f1)
|
315 |
+
print('BOW')
|
316 |
+
print(' A\t%.3f'%(bow_avg/n))
|
317 |
+
print(' E\t%.3f'%(bow_extrema/n))
|
318 |
+
print(' G\t%.3f'%(bow_greedy/n))
|
319 |
+
print('intra_dist')
|
320 |
+
print(' 1\t%.3f'%(intra_dist1/n))
|
321 |
+
print(' 2\t%.3f'%(intra_dist2/n))
|
322 |
+
print('inter_dist')
|
323 |
+
print(' 1\t%.3f'%(inter_dist1/n))
|
324 |
+
print(' 2\t%.3f'%(inter_dist2/n))
|
325 |
+
print('avg_L\t%.1f'%(avg_lens/n))
|
326 |
+
|
327 |
+
results = {
|
328 |
+
"BLEU_R": recall_bleu, "BLEU_P": prec_bleu, "BLEU_F1": f1, "BOW_A": bow_avg/n, "BOW_E": bow_extrema/n, "BOW_G": bow_greedy/n, "intra_dist1": intra_dist1/n, "intra_dist2": intra_dist2/n, "inter_dist1": inter_dist1/n, "inter_dist2": inter_dist2/n, "avg_L": avg_lens/n
|
329 |
+
}
|
330 |
+
|
331 |
+
return results
|
332 |
+
|
333 |
+
|
334 |
+
def create_rand_baseline():
|
335 |
+
path = 'data/datasets/dailydialog_data/test.txt'
|
336 |
+
srcs = []
|
337 |
+
refs = []
|
338 |
+
for line in open(path, encoding='utf-8'):
|
339 |
+
src, ref = line.strip('\n').split('\t')
|
340 |
+
srcs.append(src.strip())
|
341 |
+
refs.append(ref.strip())
|
342 |
+
|
343 |
+
hyps = set()
|
344 |
+
path = 'data/datasets/dailydialog_data/train.txt'
|
345 |
+
for line in open(path, encoding='utf-8'):
|
346 |
+
_, ref = line.strip('\n').split('\t')
|
347 |
+
hyps.add(ref)
|
348 |
+
if len(hyps) == len(srcs) *10:
|
349 |
+
print('collected training ref')
|
350 |
+
break
|
351 |
+
|
352 |
+
hyps = list(hyps)
|
353 |
+
lines = []
|
354 |
+
j = 0
|
355 |
+
for i in range(len(srcs)):
|
356 |
+
lines += ['\t'.join([srcs[i], refs[i], hyp]) for hyp in hyps[j:j+10]]
|
357 |
+
j = j + 10
|
358 |
+
with open('out/rand.tsv', 'w', encoding='utf-8') as f:
|
359 |
+
f.write('\n'.join(lines))
|
360 |
+
|
361 |
+
|
362 |
+
def create_human_baseline():
|
363 |
+
path = 'data/datasets/dailydialog_data/test.txt'
|
364 |
+
lines = []
|
365 |
+
for line in open(path, encoding='utf-8'):
|
366 |
+
src, ref = line.strip('\n').split('\t')
|
367 |
+
src = src.strip()
|
368 |
+
ref = ref.strip()
|
369 |
+
lines.append('\t'.join([src, ref, ref]))
|
370 |
+
|
371 |
+
with open('out/human.tsv', 'w', encoding='utf-8') as f:
|
372 |
+
f.write('\n'.join(lines))
|
373 |
+
|
374 |
+
|
375 |
+
if __name__ == "__main__":
|
376 |
+
path = 'D:/data/switchboard/test.txt.1ref'
|
377 |
+
path_multi_ref = 'D:/data/switchboard/test.txt'
|
378 |
+
eval_multi_ref(path_multi_ref, path)
|
Optimus/code/examples/big_ae/eval_dialog_response.py
ADDED
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from nltk.translate.bleu_score import sentence_bleu
|
5 |
+
from nltk.translate.bleu_score import SmoothingFunction
|
6 |
+
from sklearn.metrics.pairwise import cosine_similarity as cosine
|
7 |
+
from collections import Counter
|
8 |
+
import os, pickle
|
9 |
+
|
10 |
+
class Metrics:
|
11 |
+
# based on https://raw.githubusercontent.com/guxd/DialogWAE/29f206af05bfe5fe28fec4448e208310a7c9258d/experiments/metrics.py
|
12 |
+
|
13 |
+
def __init__(self, path_word2vec='../data/datasets/dailydialog_data/glove.twitter.27B.200d.txt'):
|
14 |
+
"""
|
15 |
+
:param word2vec - a numpy array of word2vec with shape [vocab_size x emb_size]
|
16 |
+
"""
|
17 |
+
self.path_word2vec = path_word2vec
|
18 |
+
super(Metrics, self).__init__()
|
19 |
+
self.load_word2vec(path_word2vec)
|
20 |
+
|
21 |
+
def load_word2vec(self, path_word2vec):
|
22 |
+
path_pkl = path_word2vec + '.pkl'
|
23 |
+
if os.path.exists(path_pkl):
|
24 |
+
print('loading word2vec from '+path_pkl)
|
25 |
+
self.word2vec = pickle.load(open(path_pkl, 'rb'))
|
26 |
+
else:
|
27 |
+
self.word2vec = dict()
|
28 |
+
for i, line in enumerate(open(path_word2vec, encoding='utf-8')):
|
29 |
+
ss = line.strip('\n').split()
|
30 |
+
self.word2vec[ss[0]] = [float(v) for v in ss[1:]]
|
31 |
+
if i % 1e4 == 0:
|
32 |
+
print('processed %ik word2vec'%(i/1e3))
|
33 |
+
print('dumping word2vec to '+path_pkl)
|
34 |
+
pickle.dump(self.word2vec, open(path_pkl, 'wb'))
|
35 |
+
# pdb.set_trace()
|
36 |
+
self.embed_dim = len(self.word2vec["."]) # len(self.word2vec.values()[0])
|
37 |
+
print('loaded %i word2vec of dim %i'%(len(self.word2vec), self.embed_dim))
|
38 |
+
|
39 |
+
def embedding(self, seqs):
|
40 |
+
# note: different from original implementation
|
41 |
+
batch_size, seqlen = seqs.shape
|
42 |
+
embs = np.zeros([batch_size, seqlen, self.embed_dim])
|
43 |
+
for i in range(batch_size):
|
44 |
+
for j in range(seqlen):
|
45 |
+
w = seqs[i,j]
|
46 |
+
if w != '' and w in self.word2vec:
|
47 |
+
embs[i, j, :] = self.word2vec[w]
|
48 |
+
return embs
|
49 |
+
|
50 |
+
|
51 |
+
def extrema(self, embs, lens): # embs: [batch_size x seq_len x emb_size] lens: [batch_size]
|
52 |
+
"""
|
53 |
+
computes the value of every single dimension in the word vectors which has the greatest
|
54 |
+
difference from zero.
|
55 |
+
:param seq: sequence
|
56 |
+
:param seqlen: length of sequence
|
57 |
+
"""
|
58 |
+
# Find minimum and maximum value for every dimension in predictions
|
59 |
+
batch_size, seq_len, emb_size = embs.shape
|
60 |
+
max_mask = np.zeros((batch_size, seq_len, emb_size), dtype=np.int)
|
61 |
+
for i,length in enumerate(lens):
|
62 |
+
max_mask[i,:length,:]=1
|
63 |
+
min_mask = 1-max_mask
|
64 |
+
seq_max = (embs*max_mask).max(1) # [batch_sz x emb_sz]
|
65 |
+
seq_min = (embs+min_mask).min(1)
|
66 |
+
# Find the maximum absolute value in min and max data
|
67 |
+
comp_mask = seq_max >= np.abs(seq_min)# [batch_sz x emb_sz]
|
68 |
+
# Add vectors for finding final sequence representation for predictions
|
69 |
+
extrema_emb = seq_max* comp_mask + seq_min* np.logical_not(comp_mask)
|
70 |
+
return extrema_emb
|
71 |
+
|
72 |
+
def mean(self, embs, lens):
|
73 |
+
batch_size, seq_len, emb_size=embs.shape
|
74 |
+
mask = np.zeros((batch_size, seq_len, emb_size), dtype=np.int)
|
75 |
+
for i,length in enumerate(lens):
|
76 |
+
mask[i,:length,:]=1
|
77 |
+
return (embs*mask).sum(1)/(mask.sum(1)+1e-8)
|
78 |
+
|
79 |
+
def sim_bleu(self, hyps, ref):
|
80 |
+
"""
|
81 |
+
:param ref - a list of tokens of the reference
|
82 |
+
:param hyps - a list of tokens of the hypothesis
|
83 |
+
|
84 |
+
:return maxbleu - recall bleu
|
85 |
+
:return avgbleu - precision bleu
|
86 |
+
"""
|
87 |
+
scores = []
|
88 |
+
for hyp in hyps:
|
89 |
+
try:
|
90 |
+
scores.append(sentence_bleu([ref], hyp, smoothing_function=SmoothingFunction().method7,
|
91 |
+
weights=[1./3, 1./3, 1./3]))
|
92 |
+
except:
|
93 |
+
scores.append(0.0)
|
94 |
+
return np.max(scores), np.mean(scores)
|
95 |
+
|
96 |
+
|
97 |
+
def sim_bow(self, pred, pred_lens, ref, ref_lens):
|
98 |
+
"""
|
99 |
+
:param pred - ndarray [batch_size x seqlen]
|
100 |
+
:param pred_lens - list of integers
|
101 |
+
:param ref - ndarray [batch_size x seqlen]
|
102 |
+
"""
|
103 |
+
# look up word embeddings for prediction and reference
|
104 |
+
emb_pred = self.embedding(pred) # [batch_sz x seqlen1 x emb_sz]
|
105 |
+
emb_ref = self.embedding(ref) # [batch_sz x seqlen2 x emb_sz]
|
106 |
+
|
107 |
+
ext_emb_pred=self.extrema(emb_pred, pred_lens)
|
108 |
+
ext_emb_ref=self.extrema(emb_ref, ref_lens)
|
109 |
+
bow_extrema=cosine(ext_emb_pred, ext_emb_ref) # [batch_sz_pred x batch_sz_ref]
|
110 |
+
|
111 |
+
avg_emb_pred = self.mean(emb_pred, pred_lens) # Calculate mean over seq
|
112 |
+
avg_emb_ref = self.mean(emb_ref, ref_lens)
|
113 |
+
bow_avg = cosine(avg_emb_pred, avg_emb_ref) # [batch_sz_pred x batch_sz_ref]
|
114 |
+
|
115 |
+
|
116 |
+
batch_pred, seqlen_pred, emb_size=emb_pred.shape
|
117 |
+
batch_ref, seqlen_ref, emb_size=emb_ref.shape
|
118 |
+
cos_sim = cosine(emb_pred.reshape((-1, emb_size)), emb_ref.reshape((-1, emb_size))) # [(batch_sz*seqlen1)x(batch_sz*seqlen2)]
|
119 |
+
cos_sim = cos_sim.reshape((batch_pred, seqlen_pred, batch_ref, seqlen_ref))
|
120 |
+
# Find words with max cosine similarity
|
121 |
+
max12 = cos_sim.max(1).mean(2) # max over seqlen_pred
|
122 |
+
max21 = cos_sim.max(3).mean(1) # max over seqlen_ref
|
123 |
+
bow_greedy=(max12+max21)/2 # [batch_pred x batch_ref(1)]
|
124 |
+
return np.max(bow_extrema), np.max(bow_avg), np.max(bow_greedy)
|
125 |
+
|
126 |
+
def div_distinct(self, seqs, seq_lens):
|
127 |
+
"""
|
128 |
+
distinct-1 distinct-2 metrics for diversity measure proposed
|
129 |
+
by Li et al. "A Diversity-Promoting Objective Function for Neural Conversation Models"
|
130 |
+
we counted numbers of distinct unigrams and bigrams in the generated responses
|
131 |
+
and divide the numbers by total number of unigrams and bigrams.
|
132 |
+
The two metrics measure how informative and diverse the generated responses are.
|
133 |
+
High numbers and high ratios mean that there is much content in the generated responses,
|
134 |
+
and high numbers further indicate that the generated responses are long
|
135 |
+
"""
|
136 |
+
batch_size = seqs.shape[0]
|
137 |
+
intra_dist1, intra_dist2=np.zeros(batch_size), np.zeros(batch_size)
|
138 |
+
|
139 |
+
n_unigrams, n_bigrams, n_unigrams_total , n_bigrams_total = 0. ,0., 0., 0.
|
140 |
+
unigrams_all, bigrams_all = Counter(), Counter()
|
141 |
+
for b in range(batch_size):
|
142 |
+
unigrams= Counter([tuple(seqs[b,i:i+1]) for i in range(seq_lens[b])])
|
143 |
+
bigrams = Counter([tuple(seqs[b,i:i+2]) for i in range(seq_lens[b]-1)])
|
144 |
+
intra_dist1[b]=(len(unigrams.items())+1e-12)/(seq_lens[b]+1e-5)
|
145 |
+
intra_dist2[b]=(len(bigrams.items())+1e-12)/(max(0, seq_lens[b]-1)+1e-5)
|
146 |
+
|
147 |
+
unigrams_all.update([tuple(seqs[b,i:i+1]) for i in range(seq_lens[b])])
|
148 |
+
bigrams_all.update([tuple(seqs[b,i:i+2]) for i in range(seq_lens[b]-1)])
|
149 |
+
n_unigrams_total += seq_lens[b]
|
150 |
+
n_bigrams_total += max(0, seq_lens[b]-1)
|
151 |
+
|
152 |
+
inter_dist1 = (len(unigrams_all.items())+1e-12)/(n_unigrams_total+1e-5)
|
153 |
+
inter_dist2 = (len(bigrams_all.items())+1e-12)/(n_bigrams_total+1e-5)
|
154 |
+
return intra_dist1, intra_dist2, inter_dist1, inter_dist2
|
155 |
+
|
156 |
+
import pdb
|
157 |
+
|
158 |
+
def eval_dialog_response(generated_text_file_path):
|
159 |
+
"""
|
160 |
+
based on: https://github.com/guxd/DialogWAE/blob/29f206af05bfe5fe28fec4448e208310a7c9258d/sample.py
|
161 |
+
quoted from the DialogWAE paper: https://arxiv.org/pdf/1805.12352.pdf
|
162 |
+
* "For each test context, we sample 10 responses from the models and compute their BLEU scores"
|
163 |
+
* "We use Glove vectors" "For each test context, we report the maximum BOW embedding score among the 10 sampled responses."
|
164 |
+
* "intra-dist as the average of distinct values within each sampled response"
|
165 |
+
" "inter-dist as the distinct value among all sampled responses."
|
166 |
+
"""
|
167 |
+
metrics = Metrics()
|
168 |
+
d_ref = dict()
|
169 |
+
d_hyp = dict()
|
170 |
+
for line in open(generated_text_file_path, encoding='utf-8'):
|
171 |
+
line = line.strip('\n').strip()
|
172 |
+
if len(line) == 0:
|
173 |
+
continue
|
174 |
+
src, ref, hyp = line.split('\t')
|
175 |
+
src = src.strip()
|
176 |
+
ref = ref.strip().split()
|
177 |
+
hyp = hyp.strip().split()
|
178 |
+
if src not in d_ref:
|
179 |
+
d_ref[src] = ref
|
180 |
+
d_hyp[src] = [hyp]
|
181 |
+
else:
|
182 |
+
d_hyp[src].append(hyp)
|
183 |
+
|
184 |
+
n = len(d_ref)
|
185 |
+
print(generated_text_file_path)
|
186 |
+
print('n_src\t%i'%n)
|
187 |
+
|
188 |
+
avg_lens = 0
|
189 |
+
maxbleu = 0
|
190 |
+
avgbleu = 0
|
191 |
+
intra_dist1, intra_dist2, inter_dist1, inter_dist2 = 0,0,0,0
|
192 |
+
bow_extrema, bow_avg, bow_greedy = 0,0,0
|
193 |
+
for src in d_ref:
|
194 |
+
m, a = metrics.sim_bleu(d_hyp[src], d_ref[src])
|
195 |
+
maxbleu += m
|
196 |
+
avgbleu += a
|
197 |
+
|
198 |
+
seq_len = [len(hyp) for hyp in d_hyp[src]]
|
199 |
+
max_len = max(seq_len)
|
200 |
+
seqs = []
|
201 |
+
for hyp in d_hyp[src]:
|
202 |
+
padded = hyp + [''] * (max_len - len(hyp))
|
203 |
+
seqs.append(np.reshape(padded, [1, -1]))
|
204 |
+
seqs = np.concatenate(seqs, axis=0)
|
205 |
+
intra1, intra2, inter1, inter2 = metrics.div_distinct(seqs, seq_len)
|
206 |
+
intra_dist1 += np.mean(intra1)
|
207 |
+
intra_dist2 += np.mean(intra2)
|
208 |
+
inter_dist1 += inter1
|
209 |
+
inter_dist2 += inter2
|
210 |
+
|
211 |
+
n_hyp = len(d_hyp[src])
|
212 |
+
seqs_ref = np.concatenate([np.reshape(d_ref[src], [1,-1])] * n_hyp, axis=0)
|
213 |
+
seq_len_ref = [len(d_ref[src])] * n_hyp
|
214 |
+
if metrics.word2vec is not None:
|
215 |
+
extrema, avg, greedy = metrics.sim_bow(seqs, seq_len, seqs_ref, seq_len_ref)
|
216 |
+
bow_extrema += extrema
|
217 |
+
bow_avg += avg
|
218 |
+
bow_greedy += greedy
|
219 |
+
|
220 |
+
avg_lens += np.mean(seq_len)
|
221 |
+
|
222 |
+
recall_bleu = maxbleu/n
|
223 |
+
prec_bleu = avgbleu/n
|
224 |
+
f1 = 2*(prec_bleu*recall_bleu) / (prec_bleu+recall_bleu+10e-12)
|
225 |
+
|
226 |
+
print('BLEU')
|
227 |
+
print(' R\t%.3f'%recall_bleu)
|
228 |
+
print(' P\t%.3f'%prec_bleu)
|
229 |
+
print(' F1\t%.3f'%f1)
|
230 |
+
print('BOW')
|
231 |
+
print(' A\t%.3f'%(bow_avg/n))
|
232 |
+
print(' E\t%.3f'%(bow_extrema/n))
|
233 |
+
print(' G\t%.3f'%(bow_greedy/n))
|
234 |
+
print('intra_dist')
|
235 |
+
print(' 1\t%.3f'%(intra_dist1/n))
|
236 |
+
print(' 2\t%.3f'%(intra_dist2/n))
|
237 |
+
print('inter_dist')
|
238 |
+
print(' 1\t%.3f'%(inter_dist1/n))
|
239 |
+
print(' 2\t%.3f'%(inter_dist2/n))
|
240 |
+
print('avg_L\t%.1f'%(avg_lens/n))
|
241 |
+
|
242 |
+
results = {
|
243 |
+
"BLEU_R": recall_bleu, "BLEU_P": prec_bleu, "BLEU_F1": f1, "BOW_A": bow_avg/n, "BOW_E": bow_extrema/n, "BOW_G": bow_greedy/n, "intra_dist1": intra_dist1/n, "intra_dist2": intra_dist2/n, "inter_dist1": inter_dist1/n, "inter_dist2": inter_dist2/n, "avg_L": avg_lens/n
|
244 |
+
}
|
245 |
+
|
246 |
+
return results
|
247 |
+
|
248 |
+
|
249 |
+
|
250 |
+
def create_rand_baseline():
|
251 |
+
path = 'data/datasets/dailydialog_data/test.txt'
|
252 |
+
srcs = []
|
253 |
+
refs = []
|
254 |
+
for line in open(path, encoding='utf-8'):
|
255 |
+
src, ref = line.strip('\n').split('\t')
|
256 |
+
srcs.append(src.strip())
|
257 |
+
refs.append(ref.strip())
|
258 |
+
|
259 |
+
hyps = set()
|
260 |
+
path = 'data/datasets/dailydialog_data/train.txt'
|
261 |
+
for line in open(path, encoding='utf-8'):
|
262 |
+
_, ref = line.strip('\n').split('\t')
|
263 |
+
hyps.add(ref)
|
264 |
+
if len(hyps) == len(srcs) *10:
|
265 |
+
print('collected training ref')
|
266 |
+
break
|
267 |
+
|
268 |
+
hyps = list(hyps)
|
269 |
+
lines = []
|
270 |
+
j = 0
|
271 |
+
for i in range(len(srcs)):
|
272 |
+
lines += ['\t'.join([srcs[i], refs[i], hyp]) for hyp in hyps[j:j+10]]
|
273 |
+
j = j + 10
|
274 |
+
with open('out/rand.tsv', 'w', encoding='utf-8') as f:
|
275 |
+
f.write('\n'.join(lines))
|
276 |
+
|
277 |
+
|
278 |
+
def create_human_baseline():
|
279 |
+
path = 'data/datasets/dailydialog_data/test.txt'
|
280 |
+
lines = []
|
281 |
+
for line in open(path, encoding='utf-8'):
|
282 |
+
src, ref = line.strip('\n').split('\t')
|
283 |
+
src = src.strip()
|
284 |
+
ref = ref.strip()
|
285 |
+
lines.append('\t'.join([src, ref, ref]))
|
286 |
+
|
287 |
+
with open('out/human.tsv', 'w', encoding='utf-8') as f:
|
288 |
+
f.write('\n'.join(lines))
|
289 |
+
|
290 |
+
|
291 |
+
if __name__ == "__main__":
|
292 |
+
#create_rand_baseline()
|
293 |
+
#create_human_baseline()
|
294 |
+
eval_dialog_response('out/eval_text_generation_results (1).txt')
|
295 |
+
#eval('out/rand.tsv')
|
Optimus/code/examples/big_ae/grad_app.py
ADDED
@@ -0,0 +1,486 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""message_bottle.ipynb
|
3 |
+
|
4 |
+
Automatically generated by Colab.
|
5 |
+
|
6 |
+
Original file is located at
|
7 |
+
https://colab.research.google.com/drive/1I47sLakpuwERGzn-XoNct67mwiDS1mQD
|
8 |
+
"""
|
9 |
+
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
import matplotlib
|
12 |
+
|
13 |
+
import argparse
|
14 |
+
import glob
|
15 |
+
import logging
|
16 |
+
import os
|
17 |
+
import pickle
|
18 |
+
import random
|
19 |
+
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn.functional as F
|
23 |
+
import numpy as np
|
24 |
+
|
25 |
+
from tqdm import tqdm, trange
|
26 |
+
from types import SimpleNamespace
|
27 |
+
|
28 |
+
import sys
|
29 |
+
sys.path.append('/home/ryn_mote/Misc/generative_recommender/text_space/Optimus/code/examples/big_ae/')
|
30 |
+
sys.path.append('/home/ryn_mote/Misc/generative_recommender/text_space/Optimus/code/')
|
31 |
+
from pytorch_transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, BertConfig
|
32 |
+
from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2ForLatentConnector
|
33 |
+
from pytorch_transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
|
34 |
+
from pytorch_transformers import XLNetLMHeadModel, XLNetTokenizer
|
35 |
+
from pytorch_transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
|
36 |
+
from pytorch_transformers import BertForLatentConnector, BertTokenizer
|
37 |
+
|
38 |
+
from modules import VAE
|
39 |
+
|
40 |
+
import torch
|
41 |
+
import torch.nn as nn
|
42 |
+
import torch.nn.functional as F
|
43 |
+
torch.set_float32_matmul_precision('high')
|
44 |
+
|
45 |
+
from tqdm import tqdm
|
46 |
+
|
47 |
+
################################################
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
+
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
|
52 |
+
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
53 |
+
Args:
|
54 |
+
logits: logits distribution shape (vocabulary size)
|
55 |
+
top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
56 |
+
top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
57 |
+
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
58 |
+
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
59 |
+
"""
|
60 |
+
assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
|
61 |
+
top_k = min(top_k, logits.size(-1)) # Safety check
|
62 |
+
if top_k > 0:
|
63 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
64 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
65 |
+
logits[indices_to_remove] = filter_value
|
66 |
+
|
67 |
+
if top_p > 0.0:
|
68 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
69 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
70 |
+
|
71 |
+
# Remove tokens with cumulative probability above the threshold
|
72 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
73 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
74 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
75 |
+
sorted_indices_to_remove[..., 0] = 0
|
76 |
+
|
77 |
+
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
78 |
+
logits[indices_to_remove] = filter_value
|
79 |
+
return logits
|
80 |
+
|
81 |
+
def sample_sequence_conditional(model, length, context, past=None, num_samples=1, temperature=1, top_k=0, top_p=0.0, device='cpu', decoder_tokenizer=None):
|
82 |
+
|
83 |
+
context = torch.tensor(context, dtype=torch.long, device=device)
|
84 |
+
context = context.unsqueeze(0).repeat(num_samples, 1)
|
85 |
+
generated = context
|
86 |
+
with torch.no_grad():
|
87 |
+
while True:
|
88 |
+
# for _ in trange(length):
|
89 |
+
inputs = {'input_ids': generated, 'past': past}
|
90 |
+
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
|
91 |
+
next_token_logits = outputs[0][0, -1, :] / temperature
|
92 |
+
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
93 |
+
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
|
94 |
+
generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
|
95 |
+
|
96 |
+
# pdb.set_trace()
|
97 |
+
if next_token.unsqueeze(0)[0,0].item() == decoder_tokenizer.encode('<EOS>')[0]:
|
98 |
+
break
|
99 |
+
|
100 |
+
return generated
|
101 |
+
|
102 |
+
|
103 |
+
def latent_code_from_text(text,):# args):
|
104 |
+
tokenized1 = tokenizer_encoder.encode(text)
|
105 |
+
tokenized1 = [101] + tokenized1 + [102]
|
106 |
+
coded1 = torch.Tensor([tokenized1])
|
107 |
+
coded1 =torch.Tensor.long(coded1)
|
108 |
+
with torch.no_grad():
|
109 |
+
x0 = coded1
|
110 |
+
x0 = x0.to('cuda')
|
111 |
+
pooled_hidden_fea = model_vae.encoder(x0, attention_mask=(x0 > 0).float())[1]
|
112 |
+
mean, logvar = model_vae.encoder.linear(pooled_hidden_fea).chunk(2, -1)
|
113 |
+
latent_z = mean.squeeze(1)
|
114 |
+
coded_length = len(tokenized1)
|
115 |
+
return latent_z, coded_length
|
116 |
+
|
117 |
+
# args
|
118 |
+
def text_from_latent_code(latent_z):
|
119 |
+
past = latent_z
|
120 |
+
context_tokens = tokenizer_decoder.encode('<BOS>')
|
121 |
+
|
122 |
+
length = 128 # maximum length, but not used
|
123 |
+
out = sample_sequence_conditional(
|
124 |
+
model=model_vae.decoder,
|
125 |
+
context=context_tokens,
|
126 |
+
past=past,
|
127 |
+
length= length, # Chunyuan: Fix length; or use <EOS> to complete a sentence
|
128 |
+
temperature=.2,
|
129 |
+
top_k=50,
|
130 |
+
top_p=.98,
|
131 |
+
device='cuda',
|
132 |
+
decoder_tokenizer = tokenizer_decoder
|
133 |
+
)
|
134 |
+
text_x1 = tokenizer_decoder.decode(out[0,:].tolist(), clean_up_tokenization_spaces=True)
|
135 |
+
text_x1 = text_x1.split()[1:-1]
|
136 |
+
text_x1 = ' '.join(text_x1)
|
137 |
+
return text_x1
|
138 |
+
|
139 |
+
|
140 |
+
################################################
|
141 |
+
# Load model
|
142 |
+
|
143 |
+
|
144 |
+
MODEL_CLASSES = {
|
145 |
+
'gpt2': (GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer),
|
146 |
+
'bert': (BertConfig, BertForLatentConnector, BertTokenizer)
|
147 |
+
}
|
148 |
+
|
149 |
+
latent_size = 768
|
150 |
+
model_path = '/home/ryn_mote/Misc/generative_recommender/text_space/1.0_checkpoint-31250/checkpoint-31250/checkpoint-full-31250/'
|
151 |
+
encoder_path = '/home/ryn_mote/Misc/generative_recommender/text_space/1.0_checkpoint-31250/checkpoint-31250/checkpoint-encoder-31250/'
|
152 |
+
decoder_path = '/home/ryn_mote/Misc/generative_recommender/text_space/1.0_checkpoint-31250/checkpoint-31250/checkpoint-decoder-31250/'
|
153 |
+
block_size = 100
|
154 |
+
|
155 |
+
# Load a trained Encoder model and vocabulary that you have fine-tuned
|
156 |
+
encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES['bert']
|
157 |
+
model_encoder = encoder_model_class.from_pretrained(encoder_path, latent_size=latent_size)
|
158 |
+
tokenizer_encoder = encoder_tokenizer_class.from_pretrained('bert-base-cased', do_lower_case=True)
|
159 |
+
|
160 |
+
model_encoder.to('cuda')
|
161 |
+
if block_size <= 0:
|
162 |
+
block_size = tokenizer_encoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
163 |
+
block_size = min(block_size, tokenizer_encoder.max_len_single_sentence)
|
164 |
+
|
165 |
+
# Load a trained Decoder model and vocabulary that you have fine-tuned
|
166 |
+
decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES['gpt2']
|
167 |
+
model_decoder = decoder_model_class.from_pretrained(decoder_path, latent_size=latent_size)
|
168 |
+
tokenizer_decoder = decoder_tokenizer_class.from_pretrained('gpt2', do_lower_case=False)
|
169 |
+
model_decoder.to('cuda')
|
170 |
+
if block_size <= 0:
|
171 |
+
block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
172 |
+
block_size = min(block_size, tokenizer_decoder.max_len_single_sentence)
|
173 |
+
|
174 |
+
# Load full model
|
175 |
+
output_full_dir = '/home/ryn_mote/Misc/generative_recommender/text_space/'
|
176 |
+
checkpoint = torch.load(os.path.join(model_path, 'training.bin'))
|
177 |
+
|
178 |
+
# Chunyuan: Add Padding token to GPT2
|
179 |
+
special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
|
180 |
+
num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
|
181 |
+
print('We have added', num_added_toks, 'tokens to GPT2')
|
182 |
+
model_decoder.resize_token_embeddings(len(tokenizer_decoder)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
|
183 |
+
assert tokenizer_decoder.pad_token == '<PAD>'
|
184 |
+
|
185 |
+
|
186 |
+
# Evaluation
|
187 |
+
model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, SimpleNamespace(**{'latent_size': latent_size, 'device':'cuda'}))
|
188 |
+
model_vae.load_state_dict(checkpoint['model_state_dict'])
|
189 |
+
print("Pre-trained Optimus is successfully loaded")
|
190 |
+
model_vae.to('cuda').to(torch.bfloat16)
|
191 |
+
|
192 |
+
l = latent_code_from_text('A photo of a mountain.')[0]
|
193 |
+
t = text_from_latent_code(l)
|
194 |
+
print(t, l, l.shape)
|
195 |
+
################################################
|
196 |
+
|
197 |
+
import gradio as gr
|
198 |
+
import numpy as np
|
199 |
+
from sklearn.svm import SVC
|
200 |
+
from sklearn.inspection import permutation_importance
|
201 |
+
from sklearn import preprocessing
|
202 |
+
import pandas as pd
|
203 |
+
import random
|
204 |
+
import time
|
205 |
+
|
206 |
+
|
207 |
+
dtype = torch.bfloat16
|
208 |
+
torch.set_grad_enabled(False)
|
209 |
+
|
210 |
+
prompt_list = [p for p in list(set(
|
211 |
+
pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
|
212 |
+
|
213 |
+
start_time = time.time()
|
214 |
+
|
215 |
+
####################### Setup Model
|
216 |
+
|
217 |
+
# TODO put back
|
218 |
+
# @spaces.GPU()
|
219 |
+
def generate(prompt, in_embs=None,):
|
220 |
+
if prompt != '':
|
221 |
+
print(prompt)
|
222 |
+
#in_embs = in_embs / in_embs.abs().max() * .15 if in_embs != None else None
|
223 |
+
in_embs = .9 * in_embs.to('cuda') + .5 * latent_code_from_text(prompt)[0] if in_embs != None else latent_code_from_text(prompt)[0]
|
224 |
+
else:
|
225 |
+
print('From embeds.')
|
226 |
+
in_embs = in_embs / in_embs.abs().max() * .6
|
227 |
+
in_embs = in_embs.to('cuda').to(torch.bfloat16)
|
228 |
+
plt.close('all')
|
229 |
+
plt.hist(np.array(in_embs.detach().to('cpu').to(torch.float)).flatten(), bins=5)
|
230 |
+
plt.savefig('real_im_emb_plot.jpg')
|
231 |
+
|
232 |
+
|
233 |
+
text = text_from_latent_code(in_embs)
|
234 |
+
in_embs = latent_code_from_text(text)[0]
|
235 |
+
print(text)
|
236 |
+
return text, in_embs.to('cpu')
|
237 |
+
|
238 |
+
|
239 |
+
#######################
|
240 |
+
|
241 |
+
# TODO add to state instead of shared across all
|
242 |
+
glob_idx = 0
|
243 |
+
|
244 |
+
def next_one(embs, ys, calibrate_prompts):
|
245 |
+
global glob_idx
|
246 |
+
glob_idx = glob_idx + 1
|
247 |
+
|
248 |
+
with torch.no_grad():
|
249 |
+
if len(calibrate_prompts) > 0:
|
250 |
+
print('######### Calibrating with sample prompts #########')
|
251 |
+
prompt = calibrate_prompts.pop(0)
|
252 |
+
text, img_embs = generate(prompt)
|
253 |
+
embs += img_embs
|
254 |
+
print(len(embs))
|
255 |
+
return text, embs, ys, calibrate_prompts
|
256 |
+
else:
|
257 |
+
print('######### Roaming #########')
|
258 |
+
|
259 |
+
|
260 |
+
# handle case where every instance of calibration prompts is 'Neither' or 'Like' or 'Dislike'
|
261 |
+
if len(list(set(ys))) <= 1:
|
262 |
+
embs.append(.01*torch.randn(latent_size))
|
263 |
+
embs.append(.01*torch.randn(latent_size))
|
264 |
+
ys.append(0)
|
265 |
+
ys.append(1)
|
266 |
+
if len(list(ys)) < 10:
|
267 |
+
embs += [.01*torch.randn(latent_size)] * 3
|
268 |
+
ys += [0] * 3
|
269 |
+
|
270 |
+
pos_indices = [i for i in range(len(embs)) if ys[i] == 1]
|
271 |
+
neg_indices = [i for i in range(len(embs)) if ys[i] == 0]
|
272 |
+
|
273 |
+
# the embs & ys stay tied by index but we shuffle to drop randomly
|
274 |
+
random.shuffle(pos_indices)
|
275 |
+
random.shuffle(neg_indices)
|
276 |
+
|
277 |
+
#if len(pos_indices) - len(neg_indices) > 48 and len(pos_indices) > 80:
|
278 |
+
# pos_indices = pos_indices[32:]
|
279 |
+
if len(neg_indices) - len(pos_indices) > 48/16 and len(pos_indices) > 6:
|
280 |
+
pos_indices = pos_indices[5:]
|
281 |
+
if len(neg_indices) - len(pos_indices) > 48/16 and len(neg_indices) > 6:
|
282 |
+
neg_indices = neg_indices[5:]
|
283 |
+
|
284 |
+
|
285 |
+
if len(neg_indices) > 25:
|
286 |
+
neg_indices = neg_indices[1:]
|
287 |
+
|
288 |
+
print(len(pos_indices), len(neg_indices))
|
289 |
+
indices = pos_indices + neg_indices
|
290 |
+
|
291 |
+
embs = [embs[i] for i in indices]
|
292 |
+
ys = [ys[i] for i in indices]
|
293 |
+
|
294 |
+
|
295 |
+
indices = list(range(len(embs)))
|
296 |
+
|
297 |
+
# also add the latest 0 and the latest 1
|
298 |
+
has_0 = False
|
299 |
+
has_1 = False
|
300 |
+
for i in reversed(range(len(ys))):
|
301 |
+
if ys[i] == 0 and has_0 == False:
|
302 |
+
indices.append(i)
|
303 |
+
has_0 = True
|
304 |
+
elif ys[i] == 1 and has_1 == False:
|
305 |
+
indices.append(i)
|
306 |
+
has_1 = True
|
307 |
+
if has_0 and has_1:
|
308 |
+
break
|
309 |
+
|
310 |
+
# we may have just encountered a rare multi-threading diffusers issue (https://github.com/huggingface/diffusers/issues/5749);
|
311 |
+
# this ends up adding a rating but losing an embedding, it seems.
|
312 |
+
# let's take off a rating if so to continue without indexing errors.
|
313 |
+
if len(ys) > len(embs):
|
314 |
+
print('ys are longer than embs; popping latest rating')
|
315 |
+
ys.pop(-1)
|
316 |
+
|
317 |
+
feature_embs = np.array(torch.stack([embs[i].to('cpu') for i in indices]).to('cpu'))
|
318 |
+
scaler = preprocessing.StandardScaler().fit(feature_embs)
|
319 |
+
feature_embs = scaler.transform(feature_embs)
|
320 |
+
chosen_y = np.array([ys[i] for i in indices])
|
321 |
+
|
322 |
+
print('Gathering coefficients')
|
323 |
+
lin_class = SVC(max_iter=50000, kernel='linear', class_weight='balanced', C=.1).fit(feature_embs, chosen_y)
|
324 |
+
coef_ = torch.tensor(lin_class.coef_, dtype=torch.double)
|
325 |
+
print(coef_.shape, 'COEF')
|
326 |
+
print('Gathered')
|
327 |
+
|
328 |
+
rng_prompt = random.choice(prompt_list)
|
329 |
+
w = 1# if len(embs) % 2 == 0 else 0
|
330 |
+
im_emb = w * coef_.to(dtype=dtype)
|
331 |
+
|
332 |
+
prompt= '' if glob_idx % 3 != 0 else rng_prompt
|
333 |
+
text, im_emb = generate(prompt, im_emb)
|
334 |
+
embs += im_emb
|
335 |
+
|
336 |
+
|
337 |
+
return text, embs, ys, calibrate_prompts
|
338 |
+
|
339 |
+
|
340 |
+
|
341 |
+
|
342 |
+
|
343 |
+
|
344 |
+
|
345 |
+
|
346 |
+
|
347 |
+
def start(_, embs, ys, calibrate_prompts):
|
348 |
+
text, embs, ys, calibrate_prompts = next_one(embs, ys, calibrate_prompts)
|
349 |
+
return [
|
350 |
+
gr.Button(value='Like (L)', interactive=True),
|
351 |
+
gr.Button(value='Neither (Space)', interactive=True),
|
352 |
+
gr.Button(value='Dislike (A)', interactive=True),
|
353 |
+
gr.Button(value='Start', interactive=False),
|
354 |
+
text,
|
355 |
+
embs,
|
356 |
+
ys,
|
357 |
+
calibrate_prompts
|
358 |
+
]
|
359 |
+
|
360 |
+
|
361 |
+
def choose(text, choice, embs, ys, calibrate_prompts):
|
362 |
+
if choice == 'Like (L)':
|
363 |
+
choice = 1
|
364 |
+
elif choice == 'Neither (Space)':
|
365 |
+
embs = embs[:-1]
|
366 |
+
text, embs, ys, calibrate_prompts = next_one(embs, ys, calibrate_prompts)
|
367 |
+
return text, embs, ys, calibrate_prompts
|
368 |
+
else:
|
369 |
+
choice = 0
|
370 |
+
|
371 |
+
# if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
|
372 |
+
# TODO skip allowing rating
|
373 |
+
if text == None:
|
374 |
+
print('NSFW -- choice is disliked')
|
375 |
+
choice = 0
|
376 |
+
|
377 |
+
ys += [choice]*1
|
378 |
+
text, embs, ys, calibrate_prompts = next_one(embs, ys, calibrate_prompts)
|
379 |
+
return text, embs, ys, calibrate_prompts
|
380 |
+
|
381 |
+
css = '''.gradio-container{max-width: 700px !important}
|
382 |
+
#description{text-align: center}
|
383 |
+
#description h1, #description h3{display: block}
|
384 |
+
#description p{margin-top: 0}
|
385 |
+
.fade-in-out {animation: fadeInOut 3s forwards}
|
386 |
+
@keyframes fadeInOut {
|
387 |
+
0% {
|
388 |
+
background: var(--bg-color);
|
389 |
+
}
|
390 |
+
100% {
|
391 |
+
background: var(--button-secondary-background-fill);
|
392 |
+
}
|
393 |
+
}
|
394 |
+
'''
|
395 |
+
js_head = '''
|
396 |
+
<script>
|
397 |
+
document.addEventListener('keydown', function(event) {
|
398 |
+
if (event.key === 'a' || event.key === 'A') {
|
399 |
+
// Trigger click on 'dislike' if 'A' is pressed
|
400 |
+
document.getElementById('dislike').click();
|
401 |
+
} else if (event.key === ' ' || event.keyCode === 32) {
|
402 |
+
// Trigger click on 'neither' if Spacebar is pressed
|
403 |
+
document.getElementById('neither').click();
|
404 |
+
} else if (event.key === 'l' || event.key === 'L') {
|
405 |
+
// Trigger click on 'like' if 'L' is pressed
|
406 |
+
document.getElementById('like').click();
|
407 |
+
}
|
408 |
+
});
|
409 |
+
function fadeInOut(button, color) {
|
410 |
+
button.style.setProperty('--bg-color', color);
|
411 |
+
button.classList.remove('fade-in-out');
|
412 |
+
void button.offsetWidth; // This line forces a repaint by accessing a DOM property
|
413 |
+
|
414 |
+
button.classList.add('fade-in-out');
|
415 |
+
button.addEventListener('animationend', () => {
|
416 |
+
button.classList.remove('fade-in-out'); // Reset the animation state
|
417 |
+
}, {once: true});
|
418 |
+
}
|
419 |
+
document.body.addEventListener('click', function(event) {
|
420 |
+
const target = event.target;
|
421 |
+
if (target.id === 'dislike') {
|
422 |
+
fadeInOut(target, '#ff1717');
|
423 |
+
} else if (target.id === 'like') {
|
424 |
+
fadeInOut(target, '#006500');
|
425 |
+
} else if (target.id === 'neither') {
|
426 |
+
fadeInOut(target, '#cccccc');
|
427 |
+
}
|
428 |
+
});
|
429 |
+
|
430 |
+
</script>
|
431 |
+
'''
|
432 |
+
|
433 |
+
with gr.Blocks(css=css, head=js_head) as demo:
|
434 |
+
gr.Markdown('''# Compass
|
435 |
+
### Generative Recommenders for Exporation of Text
|
436 |
+
|
437 |
+
Explore the latent space without prompting based on your preferences. Learn more on [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/).
|
438 |
+
''', elem_id="description")
|
439 |
+
embs = gr.State([])
|
440 |
+
ys = gr.State([])
|
441 |
+
calibrate_prompts = gr.State([
|
442 |
+
'the moon is melting into my glass of tea',
|
443 |
+
'a sea slug -- pair of claws scuttling -- jelly fish glowing',
|
444 |
+
'an adorable creature. It may be a goblin or a pig or a slug.',
|
445 |
+
'an animation about a gorgeous nebula',
|
446 |
+
'a sketch of an impressive mountain by da vinci',
|
447 |
+
'a watercolor painting: the octopus writhes',
|
448 |
+
])
|
449 |
+
def l():
|
450 |
+
return None
|
451 |
+
|
452 |
+
with gr.Row(elem_id='output-image'):
|
453 |
+
text = gr.Textbox(interactive=False, elem_id="text")
|
454 |
+
with gr.Row(equal_height=True):
|
455 |
+
b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
|
456 |
+
b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither")
|
457 |
+
b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")
|
458 |
+
b1.click(
|
459 |
+
choose,
|
460 |
+
[text, b1, embs, ys, calibrate_prompts],
|
461 |
+
[text, embs, ys, calibrate_prompts]
|
462 |
+
)
|
463 |
+
b2.click(
|
464 |
+
choose,
|
465 |
+
[text, b2, embs, ys, calibrate_prompts],
|
466 |
+
[text, embs, ys, calibrate_prompts]
|
467 |
+
)
|
468 |
+
b3.click(
|
469 |
+
choose,
|
470 |
+
[text, b3, embs, ys, calibrate_prompts],
|
471 |
+
[text, embs, ys, calibrate_prompts]
|
472 |
+
)
|
473 |
+
with gr.Row():
|
474 |
+
b4 = gr.Button(value='Start')
|
475 |
+
b4.click(start,
|
476 |
+
[b4, embs, ys, calibrate_prompts],
|
477 |
+
[b1, b2, b3, b4, text, embs, ys, calibrate_prompts])
|
478 |
+
with gr.Row():
|
479 |
+
html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several prompts and then roam. </ div><br><br><br>
|
480 |
+
<div style='text-align:center; font-size:14px'>Note that while the model is unlikely to produce NSFW text, this may still occur, and users should avoid NSFW content when rating.
|
481 |
+
</ div>
|
482 |
+
<br><br>
|
483 |
+
<div style='text-align:center; font-size:14px'>Thanks to @multimodalart for their contributions to the demo, esp. the interface and @maxbittker for feedback.
|
484 |
+
</ div>''')
|
485 |
+
|
486 |
+
demo.launch(share=True)
|
Optimus/code/examples/big_ae/metrics.py
ADDED
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from multiprocessing import Pool
|
3 |
+
import pdb
|
4 |
+
import numpy as np
|
5 |
+
import nltk
|
6 |
+
nltk.download('punkt')
|
7 |
+
|
8 |
+
from nltk.translate.bleu_score import SmoothingFunction
|
9 |
+
|
10 |
+
try:
|
11 |
+
from multiprocessing import cpu_count
|
12 |
+
except:
|
13 |
+
from os import cpu_count
|
14 |
+
|
15 |
+
class Metrics(object):
|
16 |
+
def __init__(self):
|
17 |
+
self.name = 'Metric'
|
18 |
+
|
19 |
+
def get_name(self):
|
20 |
+
return self.name
|
21 |
+
|
22 |
+
def set_name(self, name):
|
23 |
+
self.name = name
|
24 |
+
|
25 |
+
def get_score(self):
|
26 |
+
pass
|
27 |
+
|
28 |
+
|
29 |
+
class Bleu(Metrics):
|
30 |
+
def __init__(self, test_text='', real_text='', gram=3, num_real_sentences=500, num_fake_sentences=10000):
|
31 |
+
super(Bleu, self).__init__()
|
32 |
+
self.name = 'Bleu'
|
33 |
+
self.test_data = test_text
|
34 |
+
self.real_data = real_text
|
35 |
+
self.gram = gram
|
36 |
+
self.sample_size = num_real_sentences
|
37 |
+
self.reference = None
|
38 |
+
self.is_first = True
|
39 |
+
self.num_sentences = num_fake_sentences
|
40 |
+
|
41 |
+
|
42 |
+
def get_name(self):
|
43 |
+
return self.name
|
44 |
+
|
45 |
+
def get_score(self, is_fast=True, ignore=False):
|
46 |
+
if ignore:
|
47 |
+
return 0
|
48 |
+
if self.is_first:
|
49 |
+
self.get_reference()
|
50 |
+
self.is_first = False
|
51 |
+
if is_fast:
|
52 |
+
return self.get_bleu_fast()
|
53 |
+
return self.get_bleu_parallel()
|
54 |
+
|
55 |
+
# fetch REAL DATA
|
56 |
+
def get_reference(self):
|
57 |
+
if self.reference is None:
|
58 |
+
reference = list()
|
59 |
+
with open(self.real_data) as real_data:
|
60 |
+
for text in real_data:
|
61 |
+
text = nltk.word_tokenize(text)
|
62 |
+
reference.append(text)
|
63 |
+
self.reference = reference
|
64 |
+
return reference
|
65 |
+
else:
|
66 |
+
return self.reference
|
67 |
+
|
68 |
+
def get_bleu(self):
|
69 |
+
raise Exception('make sure you call BLEU paralell')
|
70 |
+
ngram = self.gram
|
71 |
+
bleu = list()
|
72 |
+
reference = self.get_reference()
|
73 |
+
weight = tuple((1. / ngram for _ in range(ngram)))
|
74 |
+
with open(self.test_data) as test_data:
|
75 |
+
for hypothesis in test_data:
|
76 |
+
hypothesis = nltk.word_tokenize(hypothesis)
|
77 |
+
bleu.append(nltk.translate.bleu_score.sentence_bleu(reference, hypothesis, weight,
|
78 |
+
smoothing_function=SmoothingFunction().method1))
|
79 |
+
return sum(bleu) / len(bleu)
|
80 |
+
|
81 |
+
def calc_bleu(self, reference, hypothesis, weight):
|
82 |
+
return nltk.translate.bleu_score.sentence_bleu(reference, hypothesis, weight,
|
83 |
+
smoothing_function=SmoothingFunction().method1)
|
84 |
+
|
85 |
+
def get_bleu_fast(self):
|
86 |
+
reference = self.get_reference()
|
87 |
+
reference = reference[0:self.sample_size]
|
88 |
+
return self.get_bleu_parallel(reference=reference)
|
89 |
+
|
90 |
+
def get_bleu_parallel(self, reference=None):
|
91 |
+
ngram = self.gram
|
92 |
+
if reference is None:
|
93 |
+
reference = self.get_reference()
|
94 |
+
weight = tuple((1. / ngram for _ in range(ngram)))
|
95 |
+
pool = Pool(cpu_count())
|
96 |
+
result = list()
|
97 |
+
maxx = self.num_sentences
|
98 |
+
with open(self.test_data) as test_data:
|
99 |
+
for i, hypothesis in enumerate(test_data):
|
100 |
+
#print('i : {}'.format(i))
|
101 |
+
hypothesis = nltk.word_tokenize(hypothesis)
|
102 |
+
result.append(pool.apply_async(self.calc_bleu, args=(reference, hypothesis, weight)))
|
103 |
+
if i > maxx : break
|
104 |
+
score = 0.0
|
105 |
+
cnt = 0
|
106 |
+
for it, i in enumerate(result):
|
107 |
+
#print('i : {}'.format(it))
|
108 |
+
score += i.get()
|
109 |
+
cnt += 1
|
110 |
+
pool.close()
|
111 |
+
pool.join()
|
112 |
+
return score / cnt
|
113 |
+
|
114 |
+
|
115 |
+
|
116 |
+
|
117 |
+
class SelfBleu(Metrics):
|
118 |
+
def __init__(self, test_text='', gram=3, model_path='', num_sentences=500):
|
119 |
+
super(SelfBleu, self).__init__()
|
120 |
+
self.name = 'Self-Bleu'
|
121 |
+
self.test_data = test_text
|
122 |
+
self.gram = gram
|
123 |
+
self.sample_size = num_sentences
|
124 |
+
self.reference = None
|
125 |
+
self.is_first = True
|
126 |
+
|
127 |
+
|
128 |
+
def get_name(self):
|
129 |
+
return self.name
|
130 |
+
|
131 |
+
def get_score(self, is_fast=True, ignore=False):
|
132 |
+
if ignore:
|
133 |
+
return 0
|
134 |
+
if self.is_first:
|
135 |
+
self.get_reference()
|
136 |
+
self.is_first = False
|
137 |
+
if is_fast:
|
138 |
+
return self.get_bleu_fast()
|
139 |
+
return self.get_bleu_parallel()
|
140 |
+
|
141 |
+
def get_reference(self):
|
142 |
+
if self.reference is None:
|
143 |
+
reference = list()
|
144 |
+
with open(self.test_data) as real_data:
|
145 |
+
for text in real_data:
|
146 |
+
text = nltk.word_tokenize(text)
|
147 |
+
reference.append(text)
|
148 |
+
self.reference = reference
|
149 |
+
return reference
|
150 |
+
else:
|
151 |
+
return self.reference
|
152 |
+
|
153 |
+
def get_bleu(self):
|
154 |
+
ngram = self.gram
|
155 |
+
bleu = list()
|
156 |
+
reference = self.get_reference()
|
157 |
+
weight = tuple((1. / ngram for _ in range(ngram)))
|
158 |
+
with open(self.test_data) as test_data:
|
159 |
+
for hypothesis in test_data:
|
160 |
+
hypothesis = nltk.word_tokenize(hypothesis)
|
161 |
+
bleu.append(nltk.translate.bleu_score.sentence_bleu(reference, hypothesis, weight,
|
162 |
+
smoothing_function=SmoothingFunction().method1))
|
163 |
+
return sum(bleu) / len(bleu)
|
164 |
+
|
165 |
+
def calc_bleu(self, reference, hypothesis, weight):
|
166 |
+
return nltk.translate.bleu_score.sentence_bleu(reference, hypothesis, weight,
|
167 |
+
smoothing_function=SmoothingFunction().method1)
|
168 |
+
|
169 |
+
def get_bleu_fast(self):
|
170 |
+
reference = self.get_reference()
|
171 |
+
# random.shuffle(reference)
|
172 |
+
reference = reference[0:self.sample_size]
|
173 |
+
return self.get_bleu_parallel(reference=reference)
|
174 |
+
|
175 |
+
def get_bleu_parallel(self, reference=None):
|
176 |
+
ngram = self.gram
|
177 |
+
if reference is None:
|
178 |
+
reference = self.get_reference()
|
179 |
+
weight = tuple((1. / ngram for _ in range(ngram)))
|
180 |
+
pool = Pool(cpu_count())
|
181 |
+
result = list()
|
182 |
+
sentence_num = len(reference)
|
183 |
+
for index in range(sentence_num):
|
184 |
+
#genious:
|
185 |
+
hypothesis = reference[index]
|
186 |
+
other = reference[:index] + reference[index+1:]
|
187 |
+
result.append(pool.apply_async(self.calc_bleu, args=(other, hypothesis, weight)))
|
188 |
+
|
189 |
+
score = 0.0
|
190 |
+
cnt = 0
|
191 |
+
for i in result:
|
192 |
+
score += i.get()
|
193 |
+
cnt += 1
|
194 |
+
pool.close()
|
195 |
+
pool.join()
|
196 |
+
return score / cnt
|
Optimus/code/examples/big_ae/modules/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .encoders import *
|
2 |
+
from .decoders import *
|
3 |
+
from .vae import *
|
4 |
+
from .utils import *
|
5 |
+
from .spacefusion import *
|
6 |
+
from .cara import *
|
7 |
+
from .arae import *
|
Optimus/code/examples/big_ae/modules/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (327 Bytes). View file
|
|
Optimus/code/examples/big_ae/modules/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (270 Bytes). View file
|
|
Optimus/code/examples/big_ae/modules/__pycache__/arae.cpython-310.pyc
ADDED
Binary file (6.64 kB). View file
|
|
Optimus/code/examples/big_ae/modules/__pycache__/arae.cpython-37.pyc
ADDED
Binary file (6.44 kB). View file
|
|
Optimus/code/examples/big_ae/modules/__pycache__/cara.cpython-310.pyc
ADDED
Binary file (8.63 kB). View file
|
|
Optimus/code/examples/big_ae/modules/__pycache__/cara.cpython-37.pyc
ADDED
Binary file (8.41 kB). View file
|
|
Optimus/code/examples/big_ae/modules/__pycache__/spacefusion.cpython-310.pyc
ADDED
Binary file (4.44 kB). View file
|
|
Optimus/code/examples/big_ae/modules/__pycache__/spacefusion.cpython-37.pyc
ADDED
Binary file (4.37 kB). View file
|
|
Optimus/code/examples/big_ae/modules/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (1.34 kB). View file
|
|
Optimus/code/examples/big_ae/modules/__pycache__/utils.cpython-37.pyc
ADDED
Binary file (1.28 kB). View file
|
|
Optimus/code/examples/big_ae/modules/__pycache__/vae.cpython-310.pyc
ADDED
Binary file (14.8 kB). View file
|
|
Optimus/code/examples/big_ae/modules/__pycache__/vae.cpython-37.pyc
ADDED
Binary file (15 kB). View file
|
|
Optimus/code/examples/big_ae/modules/arae.py
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from .utils import log_sum_exp
|
5 |
+
import pdb
|
6 |
+
import sys
|
7 |
+
sys.path.append('../../')
|
8 |
+
from pytorch_transformers.modeling_bert import BertEmbeddings
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
|
12 |
+
class ARAE(nn.Module):
|
13 |
+
def __init__(self, encoder, decoder, tokenizer_encoder, tokenizer_decoder, args): #
|
14 |
+
super(ARAE, self).__init__()
|
15 |
+
self.encoder = encoder
|
16 |
+
self.decoder = decoder
|
17 |
+
self.tokenizer_encoder = tokenizer_encoder
|
18 |
+
self.tokenizer_decoder = tokenizer_decoder
|
19 |
+
|
20 |
+
self.args = args
|
21 |
+
self.nz = args.latent_size
|
22 |
+
|
23 |
+
self.bos_token_id_list = self.tokenizer_decoder.encode(self.tokenizer_decoder.bos_token)
|
24 |
+
self.pad_token_id = self.tokenizer_decoder.encode(self.tokenizer_decoder.pad_token)[0]
|
25 |
+
|
26 |
+
# connector: from Bert hidden units to the latent space
|
27 |
+
self.linear = nn.Linear(encoder.config.hidden_size, self.nz, bias=False)
|
28 |
+
|
29 |
+
# # Standard Normal prior
|
30 |
+
# loc = torch.zeros(self.nz, device=args.device)
|
31 |
+
# scale = torch.ones(self.nz, device=args.device)
|
32 |
+
# self.prior = torch.distributions.normal.Normal(loc, scale)
|
33 |
+
|
34 |
+
self.label_embedding = nn.Embedding(args.label_size, self.nz, padding_idx=0) # use the same size as latent_z so as to use the same decoder.linear()
|
35 |
+
self.latent_generator = nn.Linear(self.nz, self.nz)
|
36 |
+
self.latent_classifier = nn.Linear(self.nz, args.label_size if args.label_size > 2 else 1)
|
37 |
+
self.latent_discriminator = nn.Linear(self.nz, 1)
|
38 |
+
|
39 |
+
self.gpt_embeddings = nn.Embedding(self.decoder.config.vocab_size, self.decoder.config.n_embd)
|
40 |
+
self.gpt_embeddings.weight.data = decoder.transformer.wte.weight.data
|
41 |
+
|
42 |
+
self.conv1 = nn.Conv1d(self.encoder.config.hidden_size, self.encoder.config.hidden_size, 3)
|
43 |
+
self.classifier = nn.Linear(self.encoder.config.hidden_size, 1 if args.label_size <= 2 else args.label_size)
|
44 |
+
|
45 |
+
self.CrossEntropyLoss = torch.nn.CrossEntropyLoss()
|
46 |
+
self.BCEWithLogitsLoss = torch.nn.BCEWithLogitsLoss()
|
47 |
+
|
48 |
+
def forward(self, input_seq_ids, tgt_seq_ids, cond_labels, attention_mask=None):
|
49 |
+
# inputs: (B, seq_len)
|
50 |
+
# labels: (B, seq_len)
|
51 |
+
# cond_labels: (B), conditional labels.
|
52 |
+
|
53 |
+
ones_label = torch.ones_like(cond_labels).to(dtype=torch.float32)
|
54 |
+
zeros_label = torch.zeros_like(cond_labels).to(dtype=torch.float32)
|
55 |
+
random_noise = torch.nn.init.normal_(torch.empty(input_seq_ids.size(0), self.nz)).to(device=input_seq_ids.device, dtype=torch.float32)
|
56 |
+
|
57 |
+
# Encode inputs
|
58 |
+
outputs = self.encoder(input_seq_ids, attention_mask=attention_mask)
|
59 |
+
pooled_hidden_fea = outputs[1] # (B, dim_h)
|
60 |
+
|
61 |
+
# Encode z
|
62 |
+
latent_z = self.linear(pooled_hidden_fea) # (B, nz)
|
63 |
+
|
64 |
+
# Generate z
|
65 |
+
gen_z = self.latent_generator(random_noise) # (B, nz)
|
66 |
+
|
67 |
+
# Latent discriminator
|
68 |
+
prob_encode_z_dis = self.latent_discriminator(latent_z).squeeze(1).float() # (B)
|
69 |
+
prob_gen_z_dis = self.latent_discriminator(gen_z).squeeze(1).float() # (B)
|
70 |
+
# Train latent discriminator
|
71 |
+
loss_lsd = self.BCEWithLogitsLoss(prob_gen_z_dis, zeros_label) + self.BCEWithLogitsLoss(prob_encode_z_dis, ones_label)
|
72 |
+
acc_encode_z_dis = ((prob_encode_z_dis >= 0).float() == ones_label).float()
|
73 |
+
acc_gen_z_dis = ((prob_gen_z_dis >= 0).float() == zeros_label).float()
|
74 |
+
# Train sampler adversarially
|
75 |
+
loss_lsg = self.BCEWithLogitsLoss(prob_gen_z_dis, ones_label)
|
76 |
+
|
77 |
+
# Latent classifier
|
78 |
+
prob_encode_z_cls = self.latent_classifier(latent_z) # (B, n_labels)
|
79 |
+
if self.args.label_size <= 2:
|
80 |
+
prob_encode_z_cls = prob_encode_z_cls.squeeze(1) # (B)
|
81 |
+
# Train latent classifier
|
82 |
+
loss_lsc = self.BCEWithLogitsLoss(prob_encode_z_cls, cond_labels.float())
|
83 |
+
acc_encode_z_cls = ((prob_encode_z_cls >= 0).float() == cond_labels.float()).float()
|
84 |
+
# Train encoder adversarially
|
85 |
+
loss_encoder = 1 - self.BCEWithLogitsLoss(prob_encode_z_cls, cond_labels.float())
|
86 |
+
else:
|
87 |
+
# Train latent classifier
|
88 |
+
loss_lsc = self.CrossEntropyLoss(prob_encode_z_cls, cond_labels)
|
89 |
+
acc_encode_z_cls = (torch.argmax(prob_encode_z_cls, dim=-1) == cond_labels).float()
|
90 |
+
# Train encoder adversarially
|
91 |
+
loss_encoder = 1 - self.CrossEntropyLoss(prob_encode_z_cls, cond_labels)
|
92 |
+
|
93 |
+
# Embed labels
|
94 |
+
label_emb = self.label_embedding(cond_labels) # (B, hidden_size)
|
95 |
+
past_label = self.decoder.linear(label_emb) # (B, n_blocks * hidden_size) # todo: use the same linear layer for latent_z for now.
|
96 |
+
if self.args.label_size <= 2:
|
97 |
+
sampled_cond_labels = 1 - cond_labels
|
98 |
+
else:
|
99 |
+
raise NotImplementedError # todo: currently only implemented for binary labels. need to change for multi-class labels.
|
100 |
+
sampled_label_emb = self.label_embedding(sampled_cond_labels) # (B, hidden_size)
|
101 |
+
past_sampled_label = self.decoder.linear(sampled_label_emb) # (B, n_blocks * hidden_size) # todo: use the same linear layer for latent_z for now.
|
102 |
+
|
103 |
+
# Generate based on encoded z and gt labels. (reconstruction)
|
104 |
+
past_z = self.decoder.linear(latent_z) # (B, n_blocks * hidden_size)
|
105 |
+
gen_past_z = self.decoder.linear(gen_z) # (B, n_blocks * hidden_size)
|
106 |
+
|
107 |
+
past = torch.cat([past_z.unsqueeze(1), past_label.unsqueeze(1)], dim=1) # (B, 2, n_blocks * hidden_size)
|
108 |
+
outputs = self.decoder(input_ids=tgt_seq_ids, past=past, labels=tgt_seq_ids, label_ignore=self.pad_token_id)
|
109 |
+
loss_rec = outputs[0]
|
110 |
+
|
111 |
+
# Train a classifier in the observation space
|
112 |
+
tgt_emb = self.gpt_embeddings(tgt_seq_ids)
|
113 |
+
tgt_encode = self.conv1(tgt_emb.transpose(1, 2)) # (B, dim_h, seq_len)
|
114 |
+
tgt_encode = torch.mean(tgt_encode, dim=-1) # (B, dim_h)
|
115 |
+
prob_cls = self.classifier(tgt_encode) # (B, n_labels)
|
116 |
+
if self.args.label_size <= 2:
|
117 |
+
prob_cls = prob_cls.squeeze(1)
|
118 |
+
loss_cls = self.BCEWithLogitsLoss(prob_cls, cond_labels.float())
|
119 |
+
pred_cls = (prob_cls >= 0).to(dtype=torch.long)
|
120 |
+
else:
|
121 |
+
loss_cls = self.CrossEntropyLoss(prob_cls, cond_labels)
|
122 |
+
pred_cls = torch.argmax(prob_cls, dim=-1)
|
123 |
+
acc_cls = (pred_cls == cond_labels).float()
|
124 |
+
|
125 |
+
# Loss
|
126 |
+
loss = loss_rec + loss_encoder + loss_lsc + loss_lsd + loss_lsg + loss_cls
|
127 |
+
|
128 |
+
if not self.training:
|
129 |
+
# Generate based on encoded z and gt labels
|
130 |
+
generated = self.sample_sequence_conditional_batch(past=past, context=self.bos_token_id_list)
|
131 |
+
|
132 |
+
# Generate based on encoded z and sampled labels (attribute transfer)
|
133 |
+
at_past = torch.cat([past_z.unsqueeze(1), past_sampled_label.unsqueeze(1)], dim=1) # (B, 2, n_blocks * hidden_size)
|
134 |
+
at_generated = self.sample_sequence_conditional_batch(past=at_past, context=self.bos_token_id_list) # (B, seq_len)
|
135 |
+
|
136 |
+
# Generate based on sampled z and sampled labels. (conditional generation)
|
137 |
+
cg_past = torch.cat([gen_past_z.unsqueeze(1), past_sampled_label.unsqueeze(1)], dim=1) # (B, 2, n_blocks * hidden_size)
|
138 |
+
cg_generated = self.sample_sequence_conditional_batch(past=cg_past, context=self.bos_token_id_list) # (B, seq_len)
|
139 |
+
|
140 |
+
# classifier on gt generated sentences.
|
141 |
+
ge_emb = self.gpt_embeddings(generated)
|
142 |
+
ge_encode = self.conv1(ge_emb.transpose(1, 2)) # (B, dim_h, seq_len)
|
143 |
+
ge_encode = torch.mean(ge_encode, dim=-1) # (B, dim_h)
|
144 |
+
prob_ge_cls = self.classifier(ge_encode) # (B, 1)
|
145 |
+
|
146 |
+
if self.args.label_size <= 2:
|
147 |
+
pred_ge_cls = (prob_ge_cls.squeeze(1) >= 0).to(torch.long)
|
148 |
+
else:
|
149 |
+
pred_ge_cls = torch.argmax(prob_ge_cls, dim=-1)
|
150 |
+
acc_ge_cls = (pred_ge_cls == cond_labels).float()
|
151 |
+
|
152 |
+
# classifier on attribute transfer generated sentences.
|
153 |
+
at_emb = self.gpt_embeddings(at_generated)
|
154 |
+
at_encode = self.conv1(at_emb.transpose(1, 2)) # (B, dim_h, seq_len)
|
155 |
+
at_encode = torch.mean(at_encode, dim=-1) # (B, dim_h)
|
156 |
+
prob_at_cls = self.classifier(at_encode) # (B, 1)
|
157 |
+
if self.args.label_size <= 2:
|
158 |
+
pred_at_cls = (prob_at_cls.squeeze(1) >= 0).to(torch.long)
|
159 |
+
else:
|
160 |
+
pred_at_cls = torch.argmax(prob_at_cls, dim=-1)
|
161 |
+
acc_at_cls = (pred_at_cls == sampled_cond_labels).float()
|
162 |
+
|
163 |
+
# classifier on conditional generated sentences.
|
164 |
+
cg_emb = self.gpt_embeddings(cg_generated)
|
165 |
+
cg_encode = self.conv1(cg_emb.transpose(1, 2)) # (B, dim_h, seq_len)
|
166 |
+
cg_encode = torch.mean(cg_encode, dim=-1) # (B, dim_h)
|
167 |
+
prob_cg_cls = self.classifier(cg_encode) # (B, 1)
|
168 |
+
if self.args.label_size <= 2:
|
169 |
+
pred_cg_cls = (prob_cg_cls.squeeze(1) >= 0).to(torch.long)
|
170 |
+
else:
|
171 |
+
pred_cg_cls = torch.argmax(prob_cg_cls, dim=-1)
|
172 |
+
acc_cg_cls = (pred_cg_cls == sampled_cond_labels).float()
|
173 |
+
|
174 |
+
result = {
|
175 |
+
'sampled_cond_labels': sampled_cond_labels,
|
176 |
+
'cond_labels': cond_labels,
|
177 |
+
|
178 |
+
'tgt_seq_ids': tgt_seq_ids,
|
179 |
+
'generated': generated,
|
180 |
+
'at_generated': at_generated,
|
181 |
+
'cg_generated': cg_generated,
|
182 |
+
|
183 |
+
'acc_encode_z_dis': acc_encode_z_dis,
|
184 |
+
'acc_gen_z_dis': acc_gen_z_dis,
|
185 |
+
'acc_encode_z_cls': acc_encode_z_cls,
|
186 |
+
'acc_cls': acc_cls,
|
187 |
+
'acc_ge_cls': acc_ge_cls,
|
188 |
+
'acc_at_cls': acc_at_cls,
|
189 |
+
'acc_cg_cls': acc_cg_cls,
|
190 |
+
|
191 |
+
'pred_cls': pred_cls,
|
192 |
+
'pred_ge_cls': pred_ge_cls,
|
193 |
+
'pred_at_cls': pred_at_cls,
|
194 |
+
'pred_cg_cls': pred_cg_cls,
|
195 |
+
}
|
196 |
+
|
197 |
+
return result
|
198 |
+
|
199 |
+
loss_dict = {
|
200 |
+
'loss': loss,
|
201 |
+
'loss_rec': loss_rec,
|
202 |
+
'loss_encoder': loss_encoder,
|
203 |
+
'loss_lsc': loss_lsc,
|
204 |
+
'loss_lsd': loss_lsd,
|
205 |
+
'loss_lsg': loss_lsg,
|
206 |
+
'loss_cls': loss_cls,
|
207 |
+
}
|
208 |
+
acc_dict = {
|
209 |
+
'acc_encode_z_dis': acc_encode_z_dis,
|
210 |
+
'acc_gen_z_dis': acc_gen_z_dis,
|
211 |
+
'acc_encode_z_cls': acc_encode_z_cls,
|
212 |
+
'acc_cls': acc_cls,
|
213 |
+
}
|
214 |
+
return loss_dict, acc_dict
|
215 |
+
|
216 |
+
def sample_sequence_conditional_batch(self, past, context):
|
217 |
+
# context: a single id of <BOS>
|
218 |
+
# past: (B, past_seq_len dim_h)
|
219 |
+
num_samples = past.size(0)
|
220 |
+
context = torch.tensor(context, dtype=torch.long, device=past.device)
|
221 |
+
context = context.unsqueeze(0).repeat(num_samples, 1)
|
222 |
+
generated = context # (B, 1)
|
223 |
+
|
224 |
+
# with torch.no_grad():
|
225 |
+
while generated.size(-1) < self.args.block_size:
|
226 |
+
inputs = {'input_ids': generated, 'past': past}
|
227 |
+
outputs = self.decoder(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
|
228 |
+
lm_logits = outputs[0]
|
229 |
+
next_tokens_logits = lm_logits[:, -1, :] / self.args.temperature # (B, 1, vocab_size)
|
230 |
+
filtered_logits = self.top_k_top_p_filtering_batch(next_tokens_logits, top_k=self.args.top_k, top_p=self.args.top_p) # (B, vocab_size)
|
231 |
+
filtered_logits = F.softmax(filtered_logits, dim=-1)
|
232 |
+
next_tokens = torch.multinomial(filtered_logits, num_samples=1) # (B, 1)
|
233 |
+
generated = torch.cat((generated, next_tokens), dim=1) # (B, seq_len+1)
|
234 |
+
|
235 |
+
not_finished = next_tokens != self.tokenizer_decoder.encode('<EOS>')[0]
|
236 |
+
if torch.sum(not_finished) == 0:
|
237 |
+
break
|
238 |
+
|
239 |
+
return generated # (B, seq_len)
|
240 |
+
|
241 |
+
def top_k_top_p_filtering_batch(self, logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
|
242 |
+
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
243 |
+
Args:
|
244 |
+
logits: logits distribution shape (vocabulary size)
|
245 |
+
top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
246 |
+
top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
247 |
+
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
248 |
+
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
249 |
+
"""
|
250 |
+
# assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
|
251 |
+
|
252 |
+
top_k = min(top_k, logits.size(-1)) # Safety check
|
253 |
+
|
254 |
+
if top_k > 0:
|
255 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
256 |
+
threshold = torch.topk(logits, top_k, dim=-1)[0][:, -1, None]
|
257 |
+
logits.masked_fill_(logits < threshold, filter_value) # (B, vocab_size)
|
258 |
+
|
259 |
+
if top_p > 0.0:
|
260 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True) # (B, vocab_size)
|
261 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # (B, vocab_size)
|
262 |
+
|
263 |
+
# Remove tokens with cumulative probability above the threshold
|
264 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
265 |
+
|
266 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
267 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
268 |
+
sorted_indices_to_remove[..., 0] = 0
|
269 |
+
|
270 |
+
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
271 |
+
|
272 |
+
logits.masked_fill_(indices_to_remove, filter_value)
|
273 |
+
|
274 |
+
return logits
|
Optimus/code/examples/big_ae/modules/cara.py
ADDED
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from .utils import log_sum_exp
|
5 |
+
import pdb
|
6 |
+
import sys
|
7 |
+
sys.path.append('../../')
|
8 |
+
from pytorch_transformers.modeling_bert import BertEmbeddings
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
|
12 |
+
class CARA(nn.Module):
|
13 |
+
def __init__(self, encoder, decoder, tokenizer_encoder, tokenizer_decoder, args): #
|
14 |
+
super(CARA, self).__init__()
|
15 |
+
self.encoder = encoder
|
16 |
+
self.decoder = decoder
|
17 |
+
self.tokenizer_encoder = tokenizer_encoder
|
18 |
+
self.tokenizer_decoder = tokenizer_decoder
|
19 |
+
|
20 |
+
self.args = args
|
21 |
+
self.nz = args.latent_size
|
22 |
+
|
23 |
+
self.bos_token_id_list = self.tokenizer_decoder.encode(self.tokenizer_decoder.bos_token)
|
24 |
+
self.pad_token_id = self.tokenizer_decoder.encode(self.tokenizer_decoder.pad_token)[0]
|
25 |
+
|
26 |
+
# connector: from Bert hidden units to the latent space
|
27 |
+
self.linear = nn.Linear(encoder.config.hidden_size, self.nz, bias=False)
|
28 |
+
|
29 |
+
# # Standard Normal prior
|
30 |
+
# loc = torch.zeros(self.nz, device=args.device)
|
31 |
+
# scale = torch.ones(self.nz, device=args.device)
|
32 |
+
# self.prior = torch.distributions.normal.Normal(loc, scale)
|
33 |
+
|
34 |
+
self.label_embedding = nn.Embedding(args.label_size, self.nz, padding_idx=0) # use the same size as latent_z so as to use the same decoder.linear()
|
35 |
+
self.latent_generator = nn.Linear(self.nz, self.nz)
|
36 |
+
self.latent_classifier = nn.Linear(self.nz, args.label_size if args.label_size > 2 else 1)
|
37 |
+
self.latent_discriminator = nn.Linear(self.nz, 1)
|
38 |
+
|
39 |
+
self.gpt_embeddings = nn.Embedding(self.decoder.config.vocab_size, self.decoder.config.n_embd)
|
40 |
+
self.gpt_embeddings.weight.data = decoder.transformer.wte.weight.data
|
41 |
+
|
42 |
+
self.conv1 = nn.Conv1d(self.encoder.config.hidden_size, self.encoder.config.hidden_size, 3)
|
43 |
+
self.classifier = nn.Linear(self.encoder.config.hidden_size, 1 if args.label_size <= 2 else args.label_size)
|
44 |
+
|
45 |
+
self.CrossEntropyLoss = torch.nn.CrossEntropyLoss()
|
46 |
+
self.BCEWithLogitsLoss = torch.nn.BCEWithLogitsLoss()
|
47 |
+
|
48 |
+
def forward(self, input_seq_ids, tgt_seq_ids, cond_labels, attention_mask):
|
49 |
+
# inputs: (B, seq_len)
|
50 |
+
# labels: (B, seq_len)
|
51 |
+
# cond_labels: (B), conditional labels.
|
52 |
+
|
53 |
+
ones_label = torch.ones_like(cond_labels).to(dtype=torch.float32)
|
54 |
+
zeros_label = torch.zeros_like(cond_labels).to(dtype=torch.float32)
|
55 |
+
random_noise = torch.nn.init.normal_(torch.empty(input_seq_ids.size(0), self.nz)).to(device=input_seq_ids.device, dtype=torch.float32)
|
56 |
+
|
57 |
+
# Encode inputs
|
58 |
+
outputs = self.encoder(input_seq_ids, attention_mask=attention_mask)
|
59 |
+
pooled_hidden_fea = outputs[1] # (B, dim_h)
|
60 |
+
|
61 |
+
# Encode z
|
62 |
+
latent_z = self.linear(pooled_hidden_fea) # (B, nz)
|
63 |
+
|
64 |
+
# Generate z
|
65 |
+
gen_z = self.latent_generator(random_noise) # (B, nz)
|
66 |
+
|
67 |
+
#################### Latent discriminator for sampling from a simple distribution ####################
|
68 |
+
prob_encode_z_dis = self.latent_discriminator(latent_z).squeeze(1).float() # (B)
|
69 |
+
prob_gen_z_dis = self.latent_discriminator(gen_z).squeeze(1).float() # (B)
|
70 |
+
# Train latent discriminator
|
71 |
+
loss_lsd = self.BCEWithLogitsLoss(prob_gen_z_dis, zeros_label) + self.BCEWithLogitsLoss(prob_encode_z_dis, ones_label)
|
72 |
+
acc_encode_z_dis = ((prob_encode_z_dis >= 0).float() == ones_label).float()
|
73 |
+
acc_gen_z_dis = ((prob_gen_z_dis >= 0).float() == zeros_label).float()
|
74 |
+
# Train sampler adversarially
|
75 |
+
loss_lsg = self.BCEWithLogitsLoss(prob_gen_z_dis, ones_label)
|
76 |
+
|
77 |
+
#################### Latent classifier for disentanglement ####################
|
78 |
+
prob_encode_z_cls = self.latent_classifier(latent_z) # (B, n_labels)
|
79 |
+
if self.args.label_size <= 2:
|
80 |
+
prob_encode_z_cls = prob_encode_z_cls.squeeze(1) # (B)
|
81 |
+
# Train latent classifier
|
82 |
+
loss_lsc = self.BCEWithLogitsLoss(prob_encode_z_cls, cond_labels.float())
|
83 |
+
acc_encode_z_cls = ((prob_encode_z_cls >= 0).float() == cond_labels.float()).float()
|
84 |
+
# Train encoder adversarially
|
85 |
+
loss_encoder = 1 - self.BCEWithLogitsLoss(prob_encode_z_cls, cond_labels.float())
|
86 |
+
else:
|
87 |
+
# Train latent classifier
|
88 |
+
loss_lsc = self.CrossEntropyLoss(prob_encode_z_cls, cond_labels)
|
89 |
+
acc_encode_z_cls = (torch.argmax(prob_encode_z_cls, dim=-1) == cond_labels).float()
|
90 |
+
# Train encoder adversarially
|
91 |
+
loss_encoder = 1 - self.CrossEntropyLoss(prob_encode_z_cls, cond_labels)
|
92 |
+
|
93 |
+
|
94 |
+
#################### Recontruction loss with latent z and label emb ####################
|
95 |
+
# Embed labels
|
96 |
+
label_emb = self.label_embedding(cond_labels) # (B, hidden_size)
|
97 |
+
# past_label = self.decoder.linear(label_emb) # (B, n_blocks * hidden_size) # todo: use the same linear layer for latent_z for now.
|
98 |
+
if self.args.label_size <= 2:
|
99 |
+
sampled_cond_labels = 1 - cond_labels
|
100 |
+
else:
|
101 |
+
raise NotImplementedError # todo: currently only implemented for binary labels. need to change for multi-class labels.
|
102 |
+
sampled_label_emb = self.label_embedding(sampled_cond_labels) # (B, hidden_size)
|
103 |
+
# past_sampled_label = self.decoder.linear(sampled_label_emb) # (B, n_blocks * hidden_size) # todo: use the same linear layer for latent_z for now.
|
104 |
+
past_sampled_label = sampled_label_emb
|
105 |
+
|
106 |
+
# Generate based on encoded z and gt labels. (reconstruction)
|
107 |
+
# past_z = self.decoder.linear(latent_z) # (B, n_blocks * hidden_size)
|
108 |
+
past_z = latent_z
|
109 |
+
# gen_past_z = self.decoder.linear(gen_z) # (B, n_blocks * hidden_size)
|
110 |
+
gen_past_z = gen_z # (B, n_blocks * hidden_size)
|
111 |
+
|
112 |
+
# past = torch.cat([past_z.unsqueeze(1), past_label.unsqueeze(1)], dim=1) # (B, 2, n_blocks * hidden_size)
|
113 |
+
|
114 |
+
past = latent_z + label_emb # (B, n_blocks * hidden_size)
|
115 |
+
|
116 |
+
outputs = self.decoder(input_ids=tgt_seq_ids, past=past, labels=tgt_seq_ids, label_ignore=self.pad_token_id)
|
117 |
+
loss_rec = outputs[0]
|
118 |
+
|
119 |
+
#################### Train a classifier in the observation space ####################
|
120 |
+
tgt_emb = self.gpt_embeddings(tgt_seq_ids)
|
121 |
+
tgt_encode = self.conv1(tgt_emb.transpose(1, 2)) # (B, dim_h, seq_len)
|
122 |
+
tgt_encode = torch.mean(tgt_encode, dim=-1) # (B, dim_h)
|
123 |
+
prob_cls = self.classifier(tgt_encode) # (B, n_labels)
|
124 |
+
if self.args.label_size <= 2:
|
125 |
+
prob_cls = prob_cls.squeeze(1)
|
126 |
+
loss_cls = self.BCEWithLogitsLoss(prob_cls, cond_labels.float())
|
127 |
+
pred_cls = (prob_cls >= 0).to(dtype=torch.long)
|
128 |
+
else:
|
129 |
+
loss_cls = self.CrossEntropyLoss(prob_cls, cond_labels)
|
130 |
+
pred_cls = torch.argmax(prob_cls, dim=-1)
|
131 |
+
acc_cls = (pred_cls == cond_labels).float()
|
132 |
+
|
133 |
+
# Generate based on encoded z and sampled labels (attribute transfer)
|
134 |
+
# at_past = torch.cat([past_z.unsqueeze(1), past_sampled_label.unsqueeze(1)], dim=1) # (B, 2, n_blocks * hidden_size)
|
135 |
+
# at_generated_soft = self.sample_sequence_conditional_batch_soft(past=at_past, context=self.bos_token_id_list) # (B, seq_len, vocab_size)
|
136 |
+
|
137 |
+
# # Classifier on attribute transfer generated sentences. Train Generator on attribute transfer.
|
138 |
+
# at_soft_emb = torch.matmul(at_generated_soft, self.gpt_embeddings.weight)
|
139 |
+
# at_soft_encode = self.conv1(at_soft_emb.transpose(1, 2)) # (B, dim_h, seq_len)
|
140 |
+
# at_soft_encode = torch.mean(at_soft_encode, dim=-1) # (B, dim_h)
|
141 |
+
# prob_at_soft_cls = self.classifier(at_soft_encode) # (B, 1)
|
142 |
+
# if self.args.label_size <= 2:
|
143 |
+
# prob_at_soft_cls = prob_at_soft_cls.squeeze(1)
|
144 |
+
# loss_at_soft_cls = self.BCEWithLogitsLoss(prob_at_soft_cls, sampled_cond_labels.float())
|
145 |
+
# pred_at_soft_cls = (prob_at_soft_cls >= 0).to(torch.long)
|
146 |
+
# else:
|
147 |
+
# loss_at_soft_cls = self.CrossEntropyLoss(prob_at_soft_cls, sampled_cond_labels)
|
148 |
+
# pred_at_soft_cls = torch.argmax(prob_at_soft_cls, dim=-1)
|
149 |
+
# acc_at_soft_cls = (pred_at_soft_cls == sampled_cond_labels).float()
|
150 |
+
|
151 |
+
# Loss
|
152 |
+
loss_latent_space = (loss_encoder + loss_lsc) + (loss_lsd + loss_lsg) + self.args.beta_cls * loss_cls # + loss_at_soft_cls
|
153 |
+
loss = loss_rec + 0.0 * loss_latent_space
|
154 |
+
|
155 |
+
if not self.training:
|
156 |
+
# Generate based on encoded z and gt labels
|
157 |
+
generated = self.sample_sequence_conditional_batch(past=past, context=self.bos_token_id_list)
|
158 |
+
|
159 |
+
# Generate based on encoded z and sampled labels (attribute transfer)
|
160 |
+
# at_past = torch.cat([past_z.unsqueeze(1), past_sampled_label.unsqueeze(1)], dim=1) # (B, 2, n_blocks * hidden_size)
|
161 |
+
at_past = past_z + past_sampled_label # (B, n_blocks * hidden_size)
|
162 |
+
at_generated = self.sample_sequence_conditional_batch(past=at_past, context=self.bos_token_id_list) # (B, seq_len)
|
163 |
+
|
164 |
+
# Generate based on sampled z and sampled labels. (conditional generation)
|
165 |
+
# cg_past = torch.cat([gen_past_z.unsqueeze(1), past_sampled_label.unsqueeze(1)], dim=1) # (B, 2, n_blocks * hidden_size)
|
166 |
+
cg_past = gen_past_z + past_sampled_label # (B, n_blocks * hidden_size)
|
167 |
+
cg_generated = self.sample_sequence_conditional_batch(past=cg_past, context=self.bos_token_id_list) # (B, seq_len)
|
168 |
+
|
169 |
+
# classifier on gt generated sentences.
|
170 |
+
ge_emb = self.gpt_embeddings(generated)
|
171 |
+
ge_encode = self.conv1(ge_emb.transpose(1, 2)) # (B, dim_h, seq_len)
|
172 |
+
ge_encode = torch.mean(ge_encode, dim=-1) # (B, dim_h)
|
173 |
+
prob_ge_cls = self.classifier(ge_encode) # (B, 1)
|
174 |
+
|
175 |
+
if self.args.label_size <= 2:
|
176 |
+
pred_ge_cls = (prob_ge_cls.squeeze(1) >= 0).to(torch.long)
|
177 |
+
else:
|
178 |
+
pred_ge_cls = torch.argmax(prob_ge_cls, dim=-1)
|
179 |
+
acc_ge_cls = (pred_ge_cls == cond_labels).float()
|
180 |
+
|
181 |
+
# classifier on attribute transfer generated sentences.
|
182 |
+
at_emb = self.gpt_embeddings(at_generated)
|
183 |
+
at_encode = self.conv1(at_emb.transpose(1, 2)) # (B, dim_h, seq_len)
|
184 |
+
at_encode = torch.mean(at_encode, dim=-1) # (B, dim_h)
|
185 |
+
prob_at_cls = self.classifier(at_encode) # (B, 1)
|
186 |
+
if self.args.label_size <= 2:
|
187 |
+
pred_at_cls = (prob_at_cls.squeeze(1) >= 0).to(torch.long)
|
188 |
+
else:
|
189 |
+
pred_at_cls = torch.argmax(prob_at_cls, dim=-1)
|
190 |
+
acc_at_cls = (pred_at_cls == sampled_cond_labels).float()
|
191 |
+
|
192 |
+
# classifier on conditional generated sentences.
|
193 |
+
cg_emb = self.gpt_embeddings(cg_generated)
|
194 |
+
cg_encode = self.conv1(cg_emb.transpose(1, 2)) # (B, dim_h, seq_len)
|
195 |
+
cg_encode = torch.mean(cg_encode, dim=-1) # (B, dim_h)
|
196 |
+
prob_cg_cls = self.classifier(cg_encode) # (B, 1)
|
197 |
+
if self.args.label_size <= 2:
|
198 |
+
pred_cg_cls = (prob_cg_cls.squeeze(1) >= 0).to(torch.long)
|
199 |
+
else:
|
200 |
+
pred_cg_cls = torch.argmax(prob_cg_cls, dim=-1)
|
201 |
+
acc_cg_cls = (pred_cg_cls == sampled_cond_labels).float()
|
202 |
+
|
203 |
+
result = {
|
204 |
+
'sampled_cond_labels': sampled_cond_labels,
|
205 |
+
'cond_labels': cond_labels,
|
206 |
+
|
207 |
+
'tgt_seq_ids': tgt_seq_ids,
|
208 |
+
'generated': generated,
|
209 |
+
'at_generated': at_generated,
|
210 |
+
'cg_generated': cg_generated,
|
211 |
+
|
212 |
+
'acc_encode_z_dis': acc_encode_z_dis,
|
213 |
+
'acc_gen_z_dis': acc_gen_z_dis,
|
214 |
+
'acc_encode_z_cls': acc_encode_z_cls,
|
215 |
+
'acc_cls': acc_cls,
|
216 |
+
'acc_ge_cls': acc_ge_cls,
|
217 |
+
'acc_at_cls': acc_at_cls,
|
218 |
+
'acc_cg_cls': acc_cg_cls,
|
219 |
+
|
220 |
+
'pred_cls': pred_cls,
|
221 |
+
'pred_ge_cls': pred_ge_cls,
|
222 |
+
'pred_at_cls': pred_at_cls,
|
223 |
+
'pred_cg_cls': pred_cg_cls,
|
224 |
+
}
|
225 |
+
|
226 |
+
return result
|
227 |
+
|
228 |
+
loss_dict = {
|
229 |
+
'loss': loss,
|
230 |
+
'loss_rec': loss_rec,
|
231 |
+
'loss_encoder': loss_encoder,
|
232 |
+
'loss_lsc': loss_lsc,
|
233 |
+
'loss_lsd': loss_lsd,
|
234 |
+
'loss_lsg': loss_lsg,
|
235 |
+
'loss_cls': loss_cls,
|
236 |
+
# 'loss_at_soft_cls': loss_at_soft_cls,
|
237 |
+
}
|
238 |
+
acc_dict = {
|
239 |
+
'acc_encode_z_dis': acc_encode_z_dis,
|
240 |
+
'acc_gen_z_dis': acc_gen_z_dis,
|
241 |
+
'acc_encode_z_cls': acc_encode_z_cls,
|
242 |
+
'acc_cls': acc_cls,
|
243 |
+
# 'acc_at_soft_cls': acc_at_soft_cls,
|
244 |
+
}
|
245 |
+
return loss_dict, acc_dict
|
246 |
+
|
247 |
+
def sample_sequence_conditional_batch(self, past, context):
|
248 |
+
# context: a single id of <BOS>
|
249 |
+
# past: (B, past_seq_len dim_h)
|
250 |
+
num_samples = past.size(0)
|
251 |
+
context = torch.tensor(context, dtype=torch.long, device=past.device)
|
252 |
+
context = context.unsqueeze(0).repeat(num_samples, 1)
|
253 |
+
generated = context # (B, 1)
|
254 |
+
|
255 |
+
# with torch.no_grad():
|
256 |
+
while generated.size(-1) < self.args.block_size:
|
257 |
+
inputs = {'input_ids': generated, 'past': past}
|
258 |
+
outputs = self.decoder(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
|
259 |
+
lm_logits = outputs[0]
|
260 |
+
|
261 |
+
# softmax sample
|
262 |
+
next_tokens_logits = lm_logits[:, -1, :] / self.args.temperature # (B, 1, vocab_size)
|
263 |
+
filtered_logits = self.top_k_top_p_filtering_batch(next_tokens_logits, top_k=self.args.top_k, top_p=self.args.top_p) # (B, 1, vocab_size)
|
264 |
+
filtered_logits = F.softmax(filtered_logits, dim=-1)
|
265 |
+
next_tokens = torch.multinomial(filtered_logits, num_samples=1) # (B, 1)
|
266 |
+
generated = torch.cat((generated, next_tokens), dim=1) # (B, seq_len+1)
|
267 |
+
|
268 |
+
not_finished = next_tokens != self.tokenizer_decoder.encode('<EOS>')[0]
|
269 |
+
if torch.sum(not_finished) == 0:
|
270 |
+
break
|
271 |
+
|
272 |
+
return generated # (B, seq_len)
|
273 |
+
|
274 |
+
def top_k_top_p_filtering_batch(self, logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
|
275 |
+
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
276 |
+
Args:
|
277 |
+
logits: logits distribution shape (vocabulary size)
|
278 |
+
top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
279 |
+
top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
280 |
+
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
281 |
+
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
282 |
+
"""
|
283 |
+
# assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
|
284 |
+
|
285 |
+
top_k = min(top_k, logits.size(-1)) # Safety check
|
286 |
+
|
287 |
+
if top_k > 0:
|
288 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
289 |
+
threshold = torch.topk(logits, top_k, dim=-1)[0][:, -1, None]
|
290 |
+
logits.masked_fill_(logits < threshold, filter_value) # (B, vocab_size)
|
291 |
+
|
292 |
+
if top_p > 0.0:
|
293 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True) # (B, vocab_size)
|
294 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # (B, vocab_size)
|
295 |
+
|
296 |
+
# Remove tokens with cumulative probability above the threshold
|
297 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
298 |
+
|
299 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
300 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
301 |
+
sorted_indices_to_remove[..., 0] = 0
|
302 |
+
|
303 |
+
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
304 |
+
|
305 |
+
logits.masked_fill_(indices_to_remove, filter_value)
|
306 |
+
|
307 |
+
return logits
|
308 |
+
|
309 |
+
def sample_sequence_conditional_batch_soft(self, past, context):
|
310 |
+
# context: a single id of <BOS>
|
311 |
+
# past: (B, past_seq_len dim_h)
|
312 |
+
num_samples = past.size(0)
|
313 |
+
context = torch.tensor(context, dtype=torch.long, device=past.device).unsqueeze(0).repeat(num_samples, 1) # (B, 1)
|
314 |
+
context_soft = torch.FloatTensor(num_samples, self.decoder.config.vocab_size).zero_().to(device=past.device) # (B, vocab_size)
|
315 |
+
context_soft.scatter_(1, context, 1) # (B, vocab_size)
|
316 |
+
generated_soft = context_soft.unsqueeze(1) # (B, 1, vocab_size)
|
317 |
+
|
318 |
+
# with torch.no_grad():
|
319 |
+
while generated_soft.size(1) < self.args.block_size: # generated_soft: (B, seq_len, vocab_size)
|
320 |
+
inputs = {'soft_ids': generated_soft, 'past': past}
|
321 |
+
outputs = self.decoder(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
|
322 |
+
lm_logits = outputs[0] # (B, seq_len, vocab_size)
|
323 |
+
|
324 |
+
# Gumbel softmax sample
|
325 |
+
next_tokens_soft = gumbel_softmax(logits=lm_logits[:, -1:, :], temperature=self.args.soft_temperature, hard=False) # (B, 1, vocab_size)
|
326 |
+
generated_soft = torch.cat((generated_soft, next_tokens_soft), dim=1) # (B, seq_len+1, vocab_size)
|
327 |
+
|
328 |
+
# # softmax sample
|
329 |
+
# next_tokens_logits = lm_logits[:, -1, :] / self.args.temperature # (B, 1, vocab_size)
|
330 |
+
# filtered_logits = self.top_k_top_p_filtering_batch(next_tokens_logits, top_k=self.args.top_k, top_p=self.args.top_p) # (B, 1, vocab_size)
|
331 |
+
# filtered_logits = F.softmax(filtered_logits, dim=-1)
|
332 |
+
# next_tokens = torch.multinomial(filtered_logits, num_samples=1) # (B, 1)
|
333 |
+
# generated = torch.cat((generated, next_tokens), dim=1) # (B, seq_len+1)
|
334 |
+
|
335 |
+
next_tokens = torch.argmax(next_tokens_soft, dim=-1) # (B, 1)
|
336 |
+
not_finished = next_tokens != self.tokenizer_decoder.encode('<EOS>')[0]
|
337 |
+
if torch.sum(not_finished) == 0:
|
338 |
+
break
|
339 |
+
|
340 |
+
return generated_soft # (B, seq_len, vocab_size)
|
341 |
+
|
342 |
+
|
343 |
+
### Gumbel Softmax
|
344 |
+
def gumbel_softmax(logits, temperature, hard=False):
|
345 |
+
"""Sample from the Gumbel-Softmax distribution and optionally discretize.
|
346 |
+
Args:
|
347 |
+
logits: [..., n_class] unnormalized log-probs
|
348 |
+
temperature: non-negative scalar
|
349 |
+
hard: if True, take argmax, but differentiate w.r.t. soft sample y
|
350 |
+
Returns:
|
351 |
+
[..., n_class] sample from the Gumbel-Softmax distribution.
|
352 |
+
If hard=True, then the returned sample will be one-hot, otherwise it will be a probabilitiy distribution that sums to 1 across classes
|
353 |
+
"""
|
354 |
+
y = gumbel_softmax_sample(logits, temperature) # (..., n_class)
|
355 |
+
|
356 |
+
if hard: # return onehot
|
357 |
+
shape = y.size()
|
358 |
+
_, ind = y.max(dim=-1)
|
359 |
+
y_hard = torch.zeros_like(y).view(-1, shape[-1])
|
360 |
+
y_hard.scatter_(1, ind.view(-1, 1), 1) # one hot
|
361 |
+
y_hard = y_hard.view(*shape)
|
362 |
+
# Set gradients w.r.t. y_hard gradients w.r.t. y
|
363 |
+
y = (y_hard - y).detach() + y
|
364 |
+
|
365 |
+
return y # (..., n_class)
|
366 |
+
|
367 |
+
from torch.nn import functional as F
|
368 |
+
def gumbel_softmax_sample(logits, temperature):
|
369 |
+
y = logits + sample_gumbel(logits.size(), logits.device)
|
370 |
+
return F.softmax(y / temperature, dim=-1)
|
371 |
+
|
372 |
+
def sample_gumbel(shape, device, eps=1e-20):
|
373 |
+
U = torch.rand(shape).to(device=device)
|
374 |
+
return -torch.log(-torch.log(U + eps) + eps)
|
Optimus/code/examples/big_ae/modules/ctrl_gen.py
ADDED
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from .utils import log_sum_exp
|
5 |
+
import pdb
|
6 |
+
import sys
|
7 |
+
sys.path.append('../../')
|
8 |
+
from pytorch_transformers.modeling_bert import BertEmbeddings
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
|
12 |
+
class Ctrl_Gen(nn.Module):
|
13 |
+
def __init__(self, encoder, decoder, tokenizer_encoder, tokenizer_decoder, args): #
|
14 |
+
super(Ctrl_Gen, self).__init__()
|
15 |
+
self.encoder = encoder
|
16 |
+
self.decoder = decoder
|
17 |
+
self.tokenizer_encoder = tokenizer_encoder
|
18 |
+
self.tokenizer_decoder = tokenizer_decoder
|
19 |
+
|
20 |
+
self.args = args
|
21 |
+
self.nz = args.latent_size
|
22 |
+
|
23 |
+
self.bos_token_id_list = self.tokenizer_decoder.encode(self.tokenizer_decoder.bos_token)
|
24 |
+
self.pad_token_id = self.tokenizer_decoder.encode(self.tokenizer_decoder.pad_token)[0]
|
25 |
+
|
26 |
+
# connector: from Bert hidden units to the latent space
|
27 |
+
self.linear = nn.Linear(encoder.config.hidden_size, self.nz, bias=False)
|
28 |
+
|
29 |
+
# # Standard Normal prior
|
30 |
+
# loc = torch.zeros(self.nz, device=args.device)
|
31 |
+
# scale = torch.ones(self.nz, device=args.device)
|
32 |
+
# self.prior = torch.distributions.normal.Normal(loc, scale)
|
33 |
+
|
34 |
+
self.label_embedding = nn.Embedding(args.label_size, self.nz, padding_idx=0) # use the same size as latent_z so as to use the same decoder.linear()
|
35 |
+
self.latent_generator = nn.Linear(self.nz, self.nz)
|
36 |
+
self.latent_classifier = nn.Linear(self.nz, args.label_size if args.label_size > 2 else 1)
|
37 |
+
self.latent_discriminator = nn.Linear(self.nz, 1)
|
38 |
+
|
39 |
+
self.gpt_embeddings = nn.Embedding(self.decoder.config.vocab_size, self.decoder.config.n_embd)
|
40 |
+
self.gpt_embeddings.weight.data = decoder.transformer.wte.weight.data
|
41 |
+
|
42 |
+
self.conv1 = nn.Conv1d(self.encoder.config.hidden_size, self.encoder.config.hidden_size, 3)
|
43 |
+
self.classifier = nn.Linear(self.encoder.config.hidden_size, 1 if args.label_size <= 2 else args.label_size)
|
44 |
+
|
45 |
+
self.CrossEntropyLoss = torch.nn.CrossEntropyLoss()
|
46 |
+
self.BCEWithLogitsLoss = torch.nn.BCEWithLogitsLoss()
|
47 |
+
|
48 |
+
def forward(self, input_seq_ids, tgt_seq_ids, cond_labels, attention_mask):
|
49 |
+
# inputs: (B, seq_len)
|
50 |
+
# labels: (B, seq_len)
|
51 |
+
# cond_labels: (B), conditional labels.
|
52 |
+
|
53 |
+
ones_label = torch.ones_like(cond_labels).to(dtype=torch.float32)
|
54 |
+
zeros_label = torch.zeros_like(cond_labels).to(dtype=torch.float32)
|
55 |
+
random_noise = torch.nn.init.normal_(torch.empty(input_seq_ids.size(0), self.nz)).to(device=input_seq_ids.device, dtype=torch.float32)
|
56 |
+
|
57 |
+
# Encode inputs
|
58 |
+
outputs = self.encoder(input_seq_ids, attention_mask=attention_mask)
|
59 |
+
pooled_hidden_fea = outputs[1] # (B, dim_h)
|
60 |
+
|
61 |
+
# Encode z
|
62 |
+
latent_z = self.linear(pooled_hidden_fea) # (B, nz)
|
63 |
+
|
64 |
+
# Generate z
|
65 |
+
gen_z = self.latent_generator(random_noise) # (B, nz)
|
66 |
+
|
67 |
+
# Latent discriminator
|
68 |
+
prob_encode_z_dis = self.latent_discriminator(latent_z).squeeze(1).float() # (B)
|
69 |
+
prob_gen_z_dis = self.latent_discriminator(gen_z).squeeze(1).float() # (B)
|
70 |
+
# Train latent discriminator
|
71 |
+
loss_lsd = self.BCEWithLogitsLoss(prob_gen_z_dis, zeros_label) + self.BCEWithLogitsLoss(prob_encode_z_dis, ones_label)
|
72 |
+
acc_encode_z_dis = ((prob_encode_z_dis >= 0).float() == ones_label).float()
|
73 |
+
acc_gen_z_dis = ((prob_gen_z_dis >= 0).float() == zeros_label).float()
|
74 |
+
# Train sampler adversarially
|
75 |
+
loss_lsg = self.BCEWithLogitsLoss(prob_gen_z_dis, ones_label)
|
76 |
+
|
77 |
+
# Latent classifier
|
78 |
+
prob_encode_z_cls = self.latent_classifier(latent_z) # (B, n_labels)
|
79 |
+
if self.args.label_size <= 2:
|
80 |
+
prob_encode_z_cls = prob_encode_z_cls.squeeze(1) # (B)
|
81 |
+
# Train latent classifier
|
82 |
+
loss_lsc = self.BCEWithLogitsLoss(prob_encode_z_cls, cond_labels.float())
|
83 |
+
acc_encode_z_cls = ((prob_encode_z_cls >= 0).float() == cond_labels.float()).float()
|
84 |
+
# Train encoder adversarially
|
85 |
+
loss_encoder = 1 - self.BCEWithLogitsLoss(prob_encode_z_cls, cond_labels.float())
|
86 |
+
else:
|
87 |
+
# Train latent classifier
|
88 |
+
loss_lsc = self.CrossEntropyLoss(prob_encode_z_cls, cond_labels)
|
89 |
+
acc_encode_z_cls = (torch.argmax(prob_encode_z_cls, dim=-1) == cond_labels).float()
|
90 |
+
# Train encoder adversarially
|
91 |
+
loss_encoder = 1 - self.CrossEntropyLoss(prob_encode_z_cls, cond_labels)
|
92 |
+
|
93 |
+
# Embed labels
|
94 |
+
label_emb = self.label_embedding(cond_labels) # (B, hidden_size)
|
95 |
+
# past_label = self.decoder.linear(label_emb) # (B, n_blocks * hidden_size) # todo: use the same linear layer for latent_z for now.
|
96 |
+
if self.args.label_size <= 2:
|
97 |
+
sampled_cond_labels = 1 - cond_labels
|
98 |
+
else:
|
99 |
+
raise NotImplementedError # todo: currently only implemented for binary labels. need to change for multi-class labels.
|
100 |
+
sampled_label_emb = self.label_embedding(sampled_cond_labels) # (B, hidden_size)
|
101 |
+
# past_sampled_label = self.decoder.linear(sampled_label_emb) # (B, n_blocks * hidden_size) # todo: use the same linear layer for latent_z for now.
|
102 |
+
past_sampled_label = sampled_label_emb
|
103 |
+
|
104 |
+
# Generate based on encoded z and gt labels. (reconstruction)
|
105 |
+
# past_z = self.decoder.linear(latent_z) # (B, n_blocks * hidden_size)
|
106 |
+
past_z = latent_z
|
107 |
+
# gen_past_z = self.decoder.linear(gen_z) # (B, n_blocks * hidden_size)
|
108 |
+
gen_past_z = gen_z # (B, n_blocks * hidden_size)
|
109 |
+
|
110 |
+
# past = torch.cat([past_z.unsqueeze(1), past_label.unsqueeze(1)], dim=1) # (B, 2, n_blocks * hidden_size)
|
111 |
+
|
112 |
+
past = latent_z + label_emb # (B, n_blocks * hidden_size)
|
113 |
+
|
114 |
+
outputs = self.decoder(input_ids=tgt_seq_ids, past=past, labels=tgt_seq_ids, label_ignore=self.pad_token_id)
|
115 |
+
loss_rec = outputs[0]
|
116 |
+
|
117 |
+
# Train a classifier in the observation space
|
118 |
+
tgt_emb = self.gpt_embeddings(tgt_seq_ids)
|
119 |
+
tgt_encode = self.conv1(tgt_emb.transpose(1, 2)) # (B, dim_h, seq_len)
|
120 |
+
tgt_encode = torch.mean(tgt_encode, dim=-1) # (B, dim_h)
|
121 |
+
prob_cls = self.classifier(tgt_encode) # (B, n_labels)
|
122 |
+
if self.args.label_size <= 2:
|
123 |
+
prob_cls = prob_cls.squeeze(1)
|
124 |
+
loss_cls = self.BCEWithLogitsLoss(prob_cls, cond_labels.float())
|
125 |
+
pred_cls = (prob_cls >= 0).to(dtype=torch.long)
|
126 |
+
else:
|
127 |
+
loss_cls = self.CrossEntropyLoss(prob_cls, cond_labels)
|
128 |
+
pred_cls = torch.argmax(prob_cls, dim=-1)
|
129 |
+
acc_cls = (pred_cls == cond_labels).float()
|
130 |
+
|
131 |
+
# Generate based on encoded z and sampled labels (attribute transfer)
|
132 |
+
# at_past = torch.cat([past_z.unsqueeze(1), past_sampled_label.unsqueeze(1)], dim=1) # (B, 2, n_blocks * hidden_size)
|
133 |
+
# at_generated_soft = self.sample_sequence_conditional_batch_soft(past=at_past, context=self.bos_token_id_list) # (B, seq_len, vocab_size)
|
134 |
+
|
135 |
+
# # Classifier on attribute transfer generated sentences. Train Generator on attribute transfer.
|
136 |
+
# at_soft_emb = torch.matmul(at_generated_soft, self.gpt_embeddings.weight)
|
137 |
+
# at_soft_encode = self.conv1(at_soft_emb.transpose(1, 2)) # (B, dim_h, seq_len)
|
138 |
+
# at_soft_encode = torch.mean(at_soft_encode, dim=-1) # (B, dim_h)
|
139 |
+
# prob_at_soft_cls = self.classifier(at_soft_encode) # (B, 1)
|
140 |
+
# if self.args.label_size <= 2:
|
141 |
+
# prob_at_soft_cls = prob_at_soft_cls.squeeze(1)
|
142 |
+
# loss_at_soft_cls = self.BCEWithLogitsLoss(prob_at_soft_cls, sampled_cond_labels.float())
|
143 |
+
# pred_at_soft_cls = (prob_at_soft_cls >= 0).to(torch.long)
|
144 |
+
# else:
|
145 |
+
# loss_at_soft_cls = self.CrossEntropyLoss(prob_at_soft_cls, sampled_cond_labels)
|
146 |
+
# pred_at_soft_cls = torch.argmax(prob_at_soft_cls, dim=-1)
|
147 |
+
# acc_at_soft_cls = (pred_at_soft_cls == sampled_cond_labels).float()
|
148 |
+
|
149 |
+
# Loss
|
150 |
+
loss = loss_rec + loss_encoder + loss_lsc + loss_lsd + loss_lsg + self.args.beta_cls * loss_cls # + loss_at_soft_cls
|
151 |
+
|
152 |
+
if not self.training:
|
153 |
+
# Generate based on encoded z and gt labels
|
154 |
+
generated = self.sample_sequence_conditional_batch(past=past, context=self.bos_token_id_list)
|
155 |
+
|
156 |
+
# Generate based on encoded z and sampled labels (attribute transfer)
|
157 |
+
# at_past = torch.cat([past_z.unsqueeze(1), past_sampled_label.unsqueeze(1)], dim=1) # (B, 2, n_blocks * hidden_size)
|
158 |
+
at_past = past_z + past_sampled_label # (B, n_blocks * hidden_size)
|
159 |
+
at_generated = self.sample_sequence_conditional_batch(past=at_past, context=self.bos_token_id_list) # (B, seq_len)
|
160 |
+
|
161 |
+
# Generate based on sampled z and sampled labels. (conditional generation)
|
162 |
+
# cg_past = torch.cat([gen_past_z.unsqueeze(1), past_sampled_label.unsqueeze(1)], dim=1) # (B, 2, n_blocks * hidden_size)
|
163 |
+
cg_past = gen_past_z + past_sampled_label # (B, n_blocks * hidden_size)
|
164 |
+
cg_generated = self.sample_sequence_conditional_batch(past=cg_past, context=self.bos_token_id_list) # (B, seq_len)
|
165 |
+
|
166 |
+
# classifier on gt generated sentences.
|
167 |
+
ge_emb = self.gpt_embeddings(generated)
|
168 |
+
ge_encode = self.conv1(ge_emb.transpose(1, 2)) # (B, dim_h, seq_len)
|
169 |
+
ge_encode = torch.mean(ge_encode, dim=-1) # (B, dim_h)
|
170 |
+
prob_ge_cls = self.classifier(ge_encode) # (B, 1)
|
171 |
+
|
172 |
+
if self.args.label_size <= 2:
|
173 |
+
pred_ge_cls = (prob_ge_cls.squeeze(1) >= 0).to(torch.long)
|
174 |
+
else:
|
175 |
+
pred_ge_cls = torch.argmax(prob_ge_cls, dim=-1)
|
176 |
+
acc_ge_cls = (pred_ge_cls == cond_labels).float()
|
177 |
+
|
178 |
+
# classifier on attribute transfer generated sentences.
|
179 |
+
at_emb = self.gpt_embeddings(at_generated)
|
180 |
+
at_encode = self.conv1(at_emb.transpose(1, 2)) # (B, dim_h, seq_len)
|
181 |
+
at_encode = torch.mean(at_encode, dim=-1) # (B, dim_h)
|
182 |
+
prob_at_cls = self.classifier(at_encode) # (B, 1)
|
183 |
+
if self.args.label_size <= 2:
|
184 |
+
pred_at_cls = (prob_at_cls.squeeze(1) >= 0).to(torch.long)
|
185 |
+
else:
|
186 |
+
pred_at_cls = torch.argmax(prob_at_cls, dim=-1)
|
187 |
+
acc_at_cls = (pred_at_cls == sampled_cond_labels).float()
|
188 |
+
|
189 |
+
# classifier on conditional generated sentences.
|
190 |
+
cg_emb = self.gpt_embeddings(cg_generated)
|
191 |
+
cg_encode = self.conv1(cg_emb.transpose(1, 2)) # (B, dim_h, seq_len)
|
192 |
+
cg_encode = torch.mean(cg_encode, dim=-1) # (B, dim_h)
|
193 |
+
prob_cg_cls = self.classifier(cg_encode) # (B, 1)
|
194 |
+
if self.args.label_size <= 2:
|
195 |
+
pred_cg_cls = (prob_cg_cls.squeeze(1) >= 0).to(torch.long)
|
196 |
+
else:
|
197 |
+
pred_cg_cls = torch.argmax(prob_cg_cls, dim=-1)
|
198 |
+
acc_cg_cls = (pred_cg_cls == sampled_cond_labels).float()
|
199 |
+
|
200 |
+
result = {
|
201 |
+
'sampled_cond_labels': sampled_cond_labels,
|
202 |
+
'cond_labels': cond_labels,
|
203 |
+
|
204 |
+
'tgt_seq_ids': tgt_seq_ids,
|
205 |
+
'generated': generated,
|
206 |
+
'at_generated': at_generated,
|
207 |
+
'cg_generated': cg_generated,
|
208 |
+
|
209 |
+
'acc_encode_z_dis': acc_encode_z_dis,
|
210 |
+
'acc_gen_z_dis': acc_gen_z_dis,
|
211 |
+
'acc_encode_z_cls': acc_encode_z_cls,
|
212 |
+
'acc_cls': acc_cls,
|
213 |
+
'acc_ge_cls': acc_ge_cls,
|
214 |
+
'acc_at_cls': acc_at_cls,
|
215 |
+
'acc_cg_cls': acc_cg_cls,
|
216 |
+
|
217 |
+
'pred_cls': pred_cls,
|
218 |
+
'pred_ge_cls': pred_ge_cls,
|
219 |
+
'pred_at_cls': pred_at_cls,
|
220 |
+
'pred_cg_cls': pred_cg_cls,
|
221 |
+
}
|
222 |
+
|
223 |
+
return result
|
224 |
+
|
225 |
+
loss_dict = {
|
226 |
+
'loss': loss,
|
227 |
+
'loss_rec': loss_rec,
|
228 |
+
'loss_encoder': loss_encoder,
|
229 |
+
'loss_lsc': loss_lsc,
|
230 |
+
'loss_lsd': loss_lsd,
|
231 |
+
'loss_lsg': loss_lsg,
|
232 |
+
'loss_cls': loss_cls,
|
233 |
+
# 'loss_at_soft_cls': loss_at_soft_cls,
|
234 |
+
}
|
235 |
+
acc_dict = {
|
236 |
+
'acc_encode_z_dis': acc_encode_z_dis,
|
237 |
+
'acc_gen_z_dis': acc_gen_z_dis,
|
238 |
+
'acc_encode_z_cls': acc_encode_z_cls,
|
239 |
+
'acc_cls': acc_cls,
|
240 |
+
# 'acc_at_soft_cls': acc_at_soft_cls,
|
241 |
+
}
|
242 |
+
return loss_dict, acc_dict
|
243 |
+
|
244 |
+
def sample_sequence_conditional_batch(self, past, context):
|
245 |
+
# context: a single id of <BOS>
|
246 |
+
# past: (B, past_seq_len dim_h)
|
247 |
+
num_samples = past.size(0)
|
248 |
+
context = torch.tensor(context, dtype=torch.long, device=past.device)
|
249 |
+
context = context.unsqueeze(0).repeat(num_samples, 1)
|
250 |
+
generated = context # (B, 1)
|
251 |
+
|
252 |
+
# with torch.no_grad():
|
253 |
+
while generated.size(-1) < self.args.block_size:
|
254 |
+
inputs = {'input_ids': generated, 'past': past}
|
255 |
+
outputs = self.decoder(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
|
256 |
+
lm_logits = outputs[0]
|
257 |
+
|
258 |
+
# softmax sample
|
259 |
+
next_tokens_logits = lm_logits[:, -1, :] / self.args.temperature # (B, 1, vocab_size)
|
260 |
+
filtered_logits = self.top_k_top_p_filtering_batch(next_tokens_logits, top_k=self.args.top_k, top_p=self.args.top_p) # (B, 1, vocab_size)
|
261 |
+
filtered_logits = F.softmax(filtered_logits, dim=-1)
|
262 |
+
next_tokens = torch.multinomial(filtered_logits, num_samples=1) # (B, 1)
|
263 |
+
generated = torch.cat((generated, next_tokens), dim=1) # (B, seq_len+1)
|
264 |
+
|
265 |
+
not_finished = next_tokens != self.tokenizer_decoder.encode('<EOS>')[0]
|
266 |
+
if torch.sum(not_finished) == 0:
|
267 |
+
break
|
268 |
+
|
269 |
+
return generated # (B, seq_len)
|
270 |
+
|
271 |
+
def top_k_top_p_filtering_batch(self, logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
|
272 |
+
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
273 |
+
Args:
|
274 |
+
logits: logits distribution shape (vocabulary size)
|
275 |
+
top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
276 |
+
top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
277 |
+
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
278 |
+
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
279 |
+
"""
|
280 |
+
# assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
|
281 |
+
|
282 |
+
top_k = min(top_k, logits.size(-1)) # Safety check
|
283 |
+
|
284 |
+
if top_k > 0:
|
285 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
286 |
+
threshold = torch.topk(logits, top_k, dim=-1)[0][:, -1, None]
|
287 |
+
logits.masked_fill_(logits < threshold, filter_value) # (B, vocab_size)
|
288 |
+
|
289 |
+
if top_p > 0.0:
|
290 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True) # (B, vocab_size)
|
291 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # (B, vocab_size)
|
292 |
+
|
293 |
+
# Remove tokens with cumulative probability above the threshold
|
294 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
295 |
+
|
296 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
297 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
298 |
+
sorted_indices_to_remove[..., 0] = 0
|
299 |
+
|
300 |
+
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
301 |
+
|
302 |
+
logits.masked_fill_(indices_to_remove, filter_value)
|
303 |
+
|
304 |
+
return logits
|
305 |
+
|
306 |
+
def sample_sequence_conditional_batch_soft(self, past, context):
|
307 |
+
# context: a single id of <BOS>
|
308 |
+
# past: (B, past_seq_len dim_h)
|
309 |
+
num_samples = past.size(0)
|
310 |
+
context = torch.tensor(context, dtype=torch.long, device=past.device).unsqueeze(0).repeat(num_samples, 1) # (B, 1)
|
311 |
+
context_soft = torch.FloatTensor(num_samples, self.decoder.config.vocab_size).zero_().to(device=past.device) # (B, vocab_size)
|
312 |
+
context_soft.scatter_(1, context, 1) # (B, vocab_size)
|
313 |
+
generated_soft = context_soft.unsqueeze(1) # (B, 1, vocab_size)
|
314 |
+
|
315 |
+
# with torch.no_grad():
|
316 |
+
while generated_soft.size(1) < self.args.block_size: # generated_soft: (B, seq_len, vocab_size)
|
317 |
+
inputs = {'soft_ids': generated_soft, 'past': past}
|
318 |
+
outputs = self.decoder(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
|
319 |
+
lm_logits = outputs[0] # (B, seq_len, vocab_size)
|
320 |
+
|
321 |
+
# Gumbel softmax sample
|
322 |
+
next_tokens_soft = gumbel_softmax(logits=lm_logits[:, -1:, :], temperature=self.args.soft_temperature, hard=False) # (B, 1, vocab_size)
|
323 |
+
generated_soft = torch.cat((generated_soft, next_tokens_soft), dim=1) # (B, seq_len+1, vocab_size)
|
324 |
+
|
325 |
+
# # softmax sample
|
326 |
+
# next_tokens_logits = lm_logits[:, -1, :] / self.args.temperature # (B, 1, vocab_size)
|
327 |
+
# filtered_logits = self.top_k_top_p_filtering_batch(next_tokens_logits, top_k=self.args.top_k, top_p=self.args.top_p) # (B, 1, vocab_size)
|
328 |
+
# filtered_logits = F.softmax(filtered_logits, dim=-1)
|
329 |
+
# next_tokens = torch.multinomial(filtered_logits, num_samples=1) # (B, 1)
|
330 |
+
# generated = torch.cat((generated, next_tokens), dim=1) # (B, seq_len+1)
|
331 |
+
|
332 |
+
next_tokens = torch.argmax(next_tokens_soft, dim=-1) # (B, 1)
|
333 |
+
not_finished = next_tokens != self.tokenizer_decoder.encode('<EOS>')[0]
|
334 |
+
if torch.sum(not_finished) == 0:
|
335 |
+
break
|
336 |
+
|
337 |
+
return generated_soft # (B, seq_len, vocab_size)
|
338 |
+
|
339 |
+
|
340 |
+
### Gumbel Softmax
|
341 |
+
def gumbel_softmax(logits, temperature, hard=False):
|
342 |
+
"""Sample from the Gumbel-Softmax distribution and optionally discretize.
|
343 |
+
Args:
|
344 |
+
logits: [..., n_class] unnormalized log-probs
|
345 |
+
temperature: non-negative scalar
|
346 |
+
hard: if True, take argmax, but differentiate w.r.t. soft sample y
|
347 |
+
Returns:
|
348 |
+
[..., n_class] sample from the Gumbel-Softmax distribution.
|
349 |
+
If hard=True, then the returned sample will be one-hot, otherwise it will be a probabilitiy distribution that sums to 1 across classes
|
350 |
+
"""
|
351 |
+
y = gumbel_softmax_sample(logits, temperature) # (..., n_class)
|
352 |
+
|
353 |
+
if hard: # return onehot
|
354 |
+
shape = y.size()
|
355 |
+
_, ind = y.max(dim=-1)
|
356 |
+
y_hard = torch.zeros_like(y).view(-1, shape[-1])
|
357 |
+
y_hard.scatter_(1, ind.view(-1, 1), 1) # one hot
|
358 |
+
y_hard = y_hard.view(*shape)
|
359 |
+
# Set gradients w.r.t. y_hard gradients w.r.t. y
|
360 |
+
y = (y_hard - y).detach() + y
|
361 |
+
|
362 |
+
return y # (..., n_class)
|
363 |
+
|
364 |
+
from torch.nn import functional as F
|
365 |
+
def gumbel_softmax_sample(logits, temperature):
|
366 |
+
y = logits + sample_gumbel(logits.size(), logits.device)
|
367 |
+
return F.softmax(y / temperature, dim=-1)
|
368 |
+
|
369 |
+
def sample_gumbel(shape, device, eps=1e-20):
|
370 |
+
U = torch.rand(shape).to(device=device)
|
371 |
+
return -torch.log(-torch.log(U + eps) + eps)
|
Optimus/code/examples/big_ae/modules/decoders/dec_gpt2.py
ADDED
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import torch
|
2 |
+
|
3 |
+
import time
|
4 |
+
import argparse
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
from .decoder import DecoderBase
|
15 |
+
|
16 |
+
class LSTMDecoder(DecoderBase):
|
17 |
+
"""LSTM decoder with constant-length data"""
|
18 |
+
def __init__(self, args, vocab, model_init, emb_init):
|
19 |
+
super(LSTMDecoder, self).__init__()
|
20 |
+
self.ni = args.ni
|
21 |
+
self.nh = args.dec_nh
|
22 |
+
self.nz = args.nz
|
23 |
+
self.vocab = vocab
|
24 |
+
self.device = args.device
|
25 |
+
|
26 |
+
# no padding when setting padding_idx to -1
|
27 |
+
self.embed = nn.Embedding(len(vocab), args.ni, padding_idx=-1)
|
28 |
+
|
29 |
+
self.dropout_in = nn.Dropout(args.dec_dropout_in)
|
30 |
+
self.dropout_out = nn.Dropout(args.dec_dropout_out)
|
31 |
+
|
32 |
+
# for initializing hidden state and cell
|
33 |
+
self.trans_linear = nn.Linear(args.nz, args.dec_nh, bias=False)
|
34 |
+
|
35 |
+
# concatenate z with input
|
36 |
+
self.lstm = nn.LSTM(input_size=args.ni + args.nz,
|
37 |
+
hidden_size=args.dec_nh,
|
38 |
+
num_layers=1,
|
39 |
+
batch_first=True)
|
40 |
+
|
41 |
+
# prediction layer
|
42 |
+
self.pred_linear = nn.Linear(args.dec_nh, len(vocab), bias=False)
|
43 |
+
|
44 |
+
vocab_mask = torch.ones(len(vocab))
|
45 |
+
# vocab_mask[vocab['<pad>']] = 0
|
46 |
+
self.loss = nn.CrossEntropyLoss(weight=vocab_mask, reduce=False)
|
47 |
+
|
48 |
+
self.reset_parameters(model_init, emb_init)
|
49 |
+
|
50 |
+
def reset_parameters(self, model_init, emb_init):
|
51 |
+
# for name, param in self.lstm.named_parameters():
|
52 |
+
# # self.initializer(param)
|
53 |
+
# if 'bias' in name:
|
54 |
+
# nn.init.constant_(param, 0.0)
|
55 |
+
# # model_init(param)
|
56 |
+
# elif 'weight' in name:
|
57 |
+
# model_init(param)
|
58 |
+
|
59 |
+
# model_init(self.trans_linear.weight)
|
60 |
+
# model_init(self.pred_linear.weight)
|
61 |
+
for param in self.parameters():
|
62 |
+
model_init(param)
|
63 |
+
emb_init(self.embed.weight)
|
64 |
+
|
65 |
+
def sample_text(self, input, z, EOS, device):
|
66 |
+
sentence = [input]
|
67 |
+
max_index = 0
|
68 |
+
|
69 |
+
input_word = input
|
70 |
+
batch_size, n_sample, _ = z.size()
|
71 |
+
seq_len = 1
|
72 |
+
z_ = z.expand(batch_size, seq_len, self.nz)
|
73 |
+
seq_len = input.size(1)
|
74 |
+
softmax = torch.nn.Softmax(dim=0)
|
75 |
+
while max_index != EOS and len(sentence) < 100:
|
76 |
+
# (batch_size, seq_len, ni)
|
77 |
+
word_embed = self.embed(input_word)
|
78 |
+
word_embed = torch.cat((word_embed, z_), -1)
|
79 |
+
c_init = self.trans_linear(z).unsqueeze(0)
|
80 |
+
h_init = torch.tanh(c_init)
|
81 |
+
if len(sentence) == 1:
|
82 |
+
h_init = h_init.squeeze(dim=1)
|
83 |
+
c_init = c_init.squeeze(dim=1)
|
84 |
+
output, hidden = self.lstm.forward(word_embed, (h_init, c_init))
|
85 |
+
else:
|
86 |
+
output, hidden = self.lstm.forward(word_embed, hidden)
|
87 |
+
# (batch_size * n_sample, seq_len, vocab_size)
|
88 |
+
output_logits = self.pred_linear(output)
|
89 |
+
output_logits = output_logits.view(-1)
|
90 |
+
probs = softmax(output_logits)
|
91 |
+
# max_index = torch.argmax(output_logits)
|
92 |
+
max_index = torch.multinomial(probs, num_samples=1)
|
93 |
+
input_word = torch.tensor([[max_index]]).to(device)
|
94 |
+
sentence.append(max_index)
|
95 |
+
return sentence
|
96 |
+
|
97 |
+
def decode(self, input, z):
|
98 |
+
"""
|
99 |
+
Args:
|
100 |
+
input: (batch_size, seq_len)
|
101 |
+
z: (batch_size, n_sample, nz)
|
102 |
+
"""
|
103 |
+
|
104 |
+
# not predicting start symbol
|
105 |
+
# sents_len -= 1
|
106 |
+
|
107 |
+
batch_size, n_sample, _ = z.size()
|
108 |
+
seq_len = input.size(1)
|
109 |
+
|
110 |
+
# (batch_size, seq_len, ni)
|
111 |
+
word_embed = self.embed(input)
|
112 |
+
word_embed = self.dropout_in(word_embed)
|
113 |
+
|
114 |
+
if n_sample == 1:
|
115 |
+
z_ = z.expand(batch_size, seq_len, self.nz)
|
116 |
+
|
117 |
+
else:
|
118 |
+
word_embed = word_embed.unsqueeze(1).expand(batch_size, n_sample, seq_len, self.ni) \
|
119 |
+
.contiguous()
|
120 |
+
|
121 |
+
# (batch_size * n_sample, seq_len, ni)
|
122 |
+
word_embed = word_embed.view(batch_size * n_sample, seq_len, self.ni)
|
123 |
+
|
124 |
+
z_ = z.unsqueeze(2).expand(batch_size, n_sample, seq_len, self.nz).contiguous()
|
125 |
+
z_ = z_.view(batch_size * n_sample, seq_len, self.nz)
|
126 |
+
|
127 |
+
# (batch_size * n_sample, seq_len, ni + nz)
|
128 |
+
word_embed = torch.cat((word_embed, z_), -1)
|
129 |
+
|
130 |
+
z = z.view(batch_size * n_sample, self.nz)
|
131 |
+
c_init = self.trans_linear(z).unsqueeze(0)
|
132 |
+
h_init = torch.tanh(c_init)
|
133 |
+
# h_init = self.trans_linear(z).unsqueeze(0)
|
134 |
+
# c_init = h_init.new_zeros(h_init.size())
|
135 |
+
output, _ = self.lstm(word_embed, (h_init, c_init))
|
136 |
+
|
137 |
+
output = self.dropout_out(output)
|
138 |
+
|
139 |
+
# (batch_size * n_sample, seq_len, vocab_size)
|
140 |
+
output_logits = self.pred_linear(output)
|
141 |
+
|
142 |
+
return output_logits
|
143 |
+
|
144 |
+
def reconstruct_error(self, x, z):
|
145 |
+
"""Cross Entropy in the language case
|
146 |
+
Args:
|
147 |
+
x: (batch_size, seq_len)
|
148 |
+
z: (batch_size, n_sample, nz)
|
149 |
+
Returns:
|
150 |
+
loss: (batch_size, n_sample). Loss
|
151 |
+
across different sentence and z
|
152 |
+
"""
|
153 |
+
|
154 |
+
#remove end symbol
|
155 |
+
src = x[:, :-1]
|
156 |
+
|
157 |
+
# remove start symbol
|
158 |
+
tgt = x[:, 1:]
|
159 |
+
|
160 |
+
batch_size, seq_len = src.size()
|
161 |
+
n_sample = z.size(1)
|
162 |
+
|
163 |
+
# (batch_size * n_sample, seq_len, vocab_size)
|
164 |
+
output_logits = self.decode(src, z)
|
165 |
+
|
166 |
+
if n_sample == 1:
|
167 |
+
tgt = tgt.contiguous().view(-1)
|
168 |
+
else:
|
169 |
+
# (batch_size * n_sample * seq_len)
|
170 |
+
tgt = tgt.unsqueeze(1).expand(batch_size, n_sample, seq_len) \
|
171 |
+
.contiguous().view(-1)
|
172 |
+
|
173 |
+
# (batch_size * n_sample * seq_len)
|
174 |
+
loss = self.loss(output_logits.view(-1, output_logits.size(2)),
|
175 |
+
tgt)
|
176 |
+
|
177 |
+
|
178 |
+
# (batch_size, n_sample)
|
179 |
+
return loss.view(batch_size, n_sample, -1).sum(-1)
|
180 |
+
|
181 |
+
|
182 |
+
def log_probability(self, x, z):
|
183 |
+
"""Cross Entropy in the language case
|
184 |
+
Args:
|
185 |
+
x: (batch_size, seq_len)
|
186 |
+
z: (batch_size, n_sample, nz)
|
187 |
+
Returns:
|
188 |
+
log_p: (batch_size, n_sample).
|
189 |
+
log_p(x|z) across different x and z
|
190 |
+
"""
|
191 |
+
|
192 |
+
return -self.reconstruct_error(x, z)
|
193 |
+
|
194 |
+
|
195 |
+
|
196 |
+
|
197 |
+
def greedy_decode(self, z):
|
198 |
+
return self.sample_decode(z, greedy=True)
|
199 |
+
|
200 |
+
def sample_decode(self, z, greedy=False):
|
201 |
+
"""sample/greedy decoding from z
|
202 |
+
Args:
|
203 |
+
z: (batch_size, nz)
|
204 |
+
Returns: List1
|
205 |
+
List1: the decoded word sentence list
|
206 |
+
"""
|
207 |
+
|
208 |
+
batch_size = z.size(0)
|
209 |
+
decoded_batch = [[] for _ in range(batch_size)]
|
210 |
+
|
211 |
+
# (batch_size, 1, nz)
|
212 |
+
c_init = self.trans_linear(z).unsqueeze(0)
|
213 |
+
h_init = torch.tanh(c_init)
|
214 |
+
|
215 |
+
decoder_hidden = (h_init, c_init)
|
216 |
+
decoder_input = torch.tensor([self.vocab["<s>"]] * batch_size, dtype=torch.long, device=self.device).unsqueeze(1)
|
217 |
+
end_symbol = torch.tensor([self.vocab["</s>"]] * batch_size, dtype=torch.long, device=self.device)
|
218 |
+
|
219 |
+
mask = torch.ones((batch_size), dtype=torch.uint8, device=self.device)
|
220 |
+
length_c = 1
|
221 |
+
while mask.sum().item() != 0 and length_c < 100:
|
222 |
+
|
223 |
+
# (batch_size, 1, ni) --> (batch_size, 1, ni+nz)
|
224 |
+
word_embed = self.embed(decoder_input)
|
225 |
+
word_embed = torch.cat((word_embed, z.unsqueeze(1)), dim=-1)
|
226 |
+
|
227 |
+
output, decoder_hidden = self.lstm(word_embed, decoder_hidden)
|
228 |
+
|
229 |
+
# (batch_size, 1, vocab_size) --> (batch_size, vocab_size)
|
230 |
+
decoder_output = self.pred_linear(output)
|
231 |
+
output_logits = decoder_output.squeeze(1)
|
232 |
+
|
233 |
+
# (batch_size)
|
234 |
+
if greedy:
|
235 |
+
max_index = torch.argmax(output_logits, dim=1)
|
236 |
+
else:
|
237 |
+
probs = F.softmax(output_logits, dim=1)
|
238 |
+
max_index = torch.multinomial(probs, num_samples=1).squeeze(1)
|
239 |
+
|
240 |
+
decoder_input = max_index.unsqueeze(1)
|
241 |
+
length_c += 1
|
242 |
+
|
243 |
+
for i in range(batch_size):
|
244 |
+
word = self.vocab.id2word(max_index[i].item())
|
245 |
+
if mask[i].item():
|
246 |
+
decoded_batch[i].append(self.vocab.id2word(max_index[i].item()))
|
247 |
+
|
248 |
+
mask = torch.mul((max_index != end_symbol), mask)
|
249 |
+
|
250 |
+
return decoded_batch
|
251 |
+
|
252 |
+
class VarLSTMDecoder(LSTMDecoder):
|
253 |
+
"""LSTM decoder with constant-length data"""
|
254 |
+
def __init__(self, args, vocab, model_init, emb_init):
|
255 |
+
super(VarLSTMDecoder, self).__init__(args, vocab, model_init, emb_init)
|
256 |
+
|
257 |
+
self.embed = nn.Embedding(len(vocab), args.ni, padding_idx=vocab['<pad>'])
|
258 |
+
vocab_mask = torch.ones(len(vocab))
|
259 |
+
vocab_mask[vocab['<pad>']] = 0
|
260 |
+
self.loss = nn.CrossEntropyLoss(weight=vocab_mask, reduce=False)
|
261 |
+
|
262 |
+
self.reset_parameters(model_init, emb_init)
|
263 |
+
|
264 |
+
def decode(self, input, z):
|
265 |
+
"""
|
266 |
+
Args:
|
267 |
+
input: tuple which contains x and sents_len
|
268 |
+
x: (batch_size, seq_len)
|
269 |
+
sents_len: long tensor of sentence lengths
|
270 |
+
z: (batch_size, n_sample, nz)
|
271 |
+
"""
|
272 |
+
|
273 |
+
input, sents_len = input
|
274 |
+
|
275 |
+
# not predicting start symbol
|
276 |
+
sents_len = sents_len - 1
|
277 |
+
|
278 |
+
batch_size, n_sample, _ = z.size()
|
279 |
+
seq_len = input.size(1)
|
280 |
+
|
281 |
+
# (batch_size, seq_len, ni)
|
282 |
+
word_embed = self.embed(input)
|
283 |
+
word_embed = self.dropout_in(word_embed)
|
284 |
+
|
285 |
+
if n_sample == 1:
|
286 |
+
z_ = z.expand(batch_size, seq_len, self.nz)
|
287 |
+
|
288 |
+
else:
|
289 |
+
word_embed = word_embed.unsqueeze(1).expand(batch_size, n_sample, seq_len, self.ni) \
|
290 |
+
.contiguous()
|
291 |
+
|
292 |
+
# (batch_size * n_sample, seq_len, ni)
|
293 |
+
word_embed = word_embed.view(batch_size * n_sample, seq_len, self.ni)
|
294 |
+
|
295 |
+
z_ = z.unsqueeze(2).expand(batch_size, n_sample, seq_len, self.nz).contiguous()
|
296 |
+
z_ = z_.view(batch_size * n_sample, seq_len, self.nz)
|
297 |
+
|
298 |
+
# (batch_size * n_sample, seq_len, ni + nz)
|
299 |
+
word_embed = torch.cat((word_embed, z_), -1)
|
300 |
+
|
301 |
+
sents_len = sents_len.unsqueeze(1).expand(batch_size, n_sample).contiguous().view(-1)
|
302 |
+
packed_embed = pack_padded_sequence(word_embed, sents_len.tolist(), batch_first=True)
|
303 |
+
|
304 |
+
z = z.view(batch_size * n_sample, self.nz)
|
305 |
+
# h_init = self.trans_linear(z).unsqueeze(0)
|
306 |
+
# c_init = h_init.new_zeros(h_init.size())
|
307 |
+
c_init = self.trans_linear(z).unsqueeze(0)
|
308 |
+
h_init = torch.tanh(c_init)
|
309 |
+
output, _ = self.lstm(packed_embed, (h_init, c_init))
|
310 |
+
output, _ = pad_packed_sequence(output, batch_first=True)
|
311 |
+
|
312 |
+
output = self.dropout_out(output)
|
313 |
+
|
314 |
+
# (batch_size * n_sample, seq_len, vocab_size)
|
315 |
+
output_logits = self.pred_linear(output)
|
316 |
+
|
317 |
+
return output_logits
|
318 |
+
|
319 |
+
def reconstruct_error(self, x, z):
|
320 |
+
"""Cross Entropy in the language case
|
321 |
+
Args:
|
322 |
+
x: tuple which contains x_ and sents_len
|
323 |
+
x_: (batch_size, seq_len)
|
324 |
+
sents_len: long tensor of sentence lengths
|
325 |
+
z: (batch_size, n_sample, nz)
|
326 |
+
Returns:
|
327 |
+
loss: (batch_size, n_sample). Loss
|
328 |
+
across different sentence and z
|
329 |
+
"""
|
330 |
+
|
331 |
+
x, sents_len = x
|
332 |
+
|
333 |
+
#remove end symbol
|
334 |
+
src = x[:, :-1]
|
335 |
+
|
336 |
+
# remove start symbol
|
337 |
+
tgt = x[:, 1:]
|
338 |
+
|
339 |
+
batch_size, seq_len = src.size()
|
340 |
+
n_sample = z.size(1)
|
341 |
+
|
342 |
+
# (batch_size * n_sample, seq_len, vocab_size)
|
343 |
+
output_logits = self.decode((src, sents_len), z)
|
344 |
+
|
345 |
+
if n_sample == 1:
|
346 |
+
tgt = tgt.contiguous().view(-1)
|
347 |
+
else:
|
348 |
+
# (batch_size * n_sample * seq_len)
|
349 |
+
tgt = tgt.unsqueeze(1).expand(batch_size, n_sample, seq_len) \
|
350 |
+
.contiguous().view(-1)
|
351 |
+
|
352 |
+
# (batch_size * n_sample * seq_len)
|
353 |
+
loss = self.loss(output_logits.view(-1, output_logits.size(2)),
|
354 |
+
tgt)
|
355 |
+
|
356 |
+
|
357 |
+
# (batch_size, n_sample)
|
358 |
+
return loss.view(batch_size, n_sample, -1).sum(-1)
|
Optimus/code/examples/big_ae/modules/decoders/decoder.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
class DecoderBase(nn.Module):
|
6 |
+
"""docstring for Decoder"""
|
7 |
+
def __init__(self):
|
8 |
+
super(DecoderBase, self).__init__()
|
9 |
+
|
10 |
+
|
11 |
+
def freeze(self):
|
12 |
+
for param in self.parameters():
|
13 |
+
param.requires_grad = False
|
14 |
+
|
15 |
+
def decode(self, x, z):
|
16 |
+
"""
|
17 |
+
Args:
|
18 |
+
x: (batch_size, seq_len)
|
19 |
+
z: (batch_size, n_sample, nz)
|
20 |
+
Returns: Tensor1
|
21 |
+
Tensor1: the output logits with size (batch_size * n_sample, seq_len, vocab_size)
|
22 |
+
"""
|
23 |
+
|
24 |
+
raise NotImplementedError
|
25 |
+
|
26 |
+
def reconstruct_error(self, x, z):
|
27 |
+
"""reconstruction loss
|
28 |
+
Args:
|
29 |
+
x: (batch_size, *)
|
30 |
+
z: (batch_size, n_sample, nz)
|
31 |
+
Returns:
|
32 |
+
loss: (batch_size, n_sample). Loss
|
33 |
+
across different sentence and z
|
34 |
+
"""
|
35 |
+
|
36 |
+
raise NotImplementedError
|
37 |
+
|
38 |
+
def beam_search_decode(self, z, K):
|
39 |
+
"""beam search decoding
|
40 |
+
Args:
|
41 |
+
z: (batch_size, nz)
|
42 |
+
K: the beam size
|
43 |
+
Returns: List1
|
44 |
+
List1: the decoded word sentence list
|
45 |
+
"""
|
46 |
+
|
47 |
+
raise NotImplementedError
|
48 |
+
|
49 |
+
def sample_decode(self, z):
|
50 |
+
"""sampling from z
|
51 |
+
Args:
|
52 |
+
z: (batch_size, nz)
|
53 |
+
Returns: List1
|
54 |
+
List1: the decoded word sentence list
|
55 |
+
"""
|
56 |
+
|
57 |
+
raise NotImplementedError
|
58 |
+
|
59 |
+
def greedy_decode(self, z):
|
60 |
+
"""greedy decoding from z
|
61 |
+
Args:
|
62 |
+
z: (batch_size, nz)
|
63 |
+
Returns: List1
|
64 |
+
List1: the decoded word sentence list
|
65 |
+
"""
|
66 |
+
|
67 |
+
raise NotImplementedError
|
68 |
+
|
69 |
+
def log_probability(self, x, z):
|
70 |
+
"""
|
71 |
+
Args:
|
72 |
+
x: (batch_size, *)
|
73 |
+
z: (batch_size, n_sample, nz)
|
74 |
+
Returns:
|
75 |
+
log_p: (batch_size, n_sample).
|
76 |
+
log_p(x|z) across different x and z
|
77 |
+
"""
|
78 |
+
|
79 |
+
raise NotImplementedError
|
Optimus/code/examples/big_ae/modules/encoders/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .enc_lstm import *
|
Optimus/code/examples/big_ae/modules/encoders/enc_lstm.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from itertools import chain
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
|
7 |
+
from .gaussian_encoder import GaussianEncoderBase
|
8 |
+
from ..utils import log_sum_exp
|
9 |
+
|
10 |
+
class GaussianLSTMEncoder(GaussianEncoderBase):
|
11 |
+
"""Gaussian LSTM Encoder with constant-length input"""
|
12 |
+
def __init__(self, args, vocab_size, model_init, emb_init):
|
13 |
+
super(GaussianLSTMEncoder, self).__init__()
|
14 |
+
self.ni = args.ni
|
15 |
+
self.nh = args.enc_nh
|
16 |
+
self.nz = args.nz
|
17 |
+
self.args = args
|
18 |
+
|
19 |
+
self.embed = nn.Embedding(vocab_size, args.ni)
|
20 |
+
|
21 |
+
self.lstm = nn.LSTM(input_size=args.ni,
|
22 |
+
hidden_size=args.enc_nh,
|
23 |
+
num_layers=1,
|
24 |
+
batch_first=True,
|
25 |
+
dropout=0)
|
26 |
+
|
27 |
+
self.linear = nn.Linear(args.enc_nh, 2 * args.nz, bias=False)
|
28 |
+
|
29 |
+
self.reset_parameters(model_init, emb_init)
|
30 |
+
|
31 |
+
def reset_parameters(self, model_init, emb_init):
|
32 |
+
# for name, param in self.lstm.named_parameters():
|
33 |
+
# # self.initializer(param)
|
34 |
+
# if 'bias' in name:
|
35 |
+
# nn.init.constant_(param, 0.0)
|
36 |
+
# # model_init(param)
|
37 |
+
# elif 'weight' in name:
|
38 |
+
# model_init(param)
|
39 |
+
|
40 |
+
# model_init(self.linear.weight)
|
41 |
+
# emb_init(self.embed.weight)
|
42 |
+
for param in self.parameters():
|
43 |
+
model_init(param)
|
44 |
+
emb_init(self.embed.weight)
|
45 |
+
|
46 |
+
|
47 |
+
def forward(self, input):
|
48 |
+
"""
|
49 |
+
Args:
|
50 |
+
x: (batch_size, seq_len)
|
51 |
+
Returns: Tensor1, Tensor2
|
52 |
+
Tensor1: the mean tensor, shape (batch, nz)
|
53 |
+
Tensor2: the logvar tensor, shape (batch, nz)
|
54 |
+
"""
|
55 |
+
|
56 |
+
# (batch_size, seq_len-1, args.ni)
|
57 |
+
word_embed = self.embed(input)
|
58 |
+
|
59 |
+
_, (last_state, last_cell) = self.lstm(word_embed)
|
60 |
+
|
61 |
+
mean, logvar = self.linear(last_state).chunk(2, -1)
|
62 |
+
|
63 |
+
# fix variance as a pre-defined value
|
64 |
+
if self.args.fix_var > 0:
|
65 |
+
logvar = mean.new_tensor([[[math.log(self.args.fix_var)]]]).expand_as(mean)
|
66 |
+
|
67 |
+
return mean.squeeze(0), logvar.squeeze(0)
|
68 |
+
|
69 |
+
# def eval_inference_mode(self, x):
|
70 |
+
# """compute the mode points in the inference distribution
|
71 |
+
# (in Gaussian case)
|
72 |
+
# Returns: Tensor
|
73 |
+
# Tensor: the posterior mode points with shape (*, nz)
|
74 |
+
# """
|
75 |
+
|
76 |
+
# # (batch_size, nz)
|
77 |
+
# mu, logvar = self.forward(x)
|
78 |
+
|
79 |
+
|
80 |
+
class VarLSTMEncoder(GaussianLSTMEncoder):
|
81 |
+
"""Gaussian LSTM Encoder with variable-length input"""
|
82 |
+
def __init__(self, args, vocab_size, model_init, emb_init):
|
83 |
+
super(VarLSTMEncoder, self).__init__(args, vocab_size, model_init, emb_init)
|
84 |
+
|
85 |
+
|
86 |
+
def forward(self, input):
|
87 |
+
"""
|
88 |
+
Args:
|
89 |
+
input: tuple which contains x and sents_len
|
90 |
+
x: (batch_size, seq_len)
|
91 |
+
sents_len: long tensor of sentence lengths
|
92 |
+
Returns: Tensor1, Tensor2
|
93 |
+
Tensor1: the mean tensor, shape (batch, nz)
|
94 |
+
Tensor2: the logvar tensor, shape (batch, nz)
|
95 |
+
"""
|
96 |
+
|
97 |
+
input, sents_len = input
|
98 |
+
# (batch_size, seq_len, args.ni)
|
99 |
+
word_embed = self.embed(input)
|
100 |
+
|
101 |
+
packed_embed = pack_padded_sequence(word_embed, sents_len.tolist(), batch_first=True)
|
102 |
+
|
103 |
+
_, (last_state, last_cell) = self.lstm(packed_embed)
|
104 |
+
|
105 |
+
mean, logvar = self.linear(last_state).chunk(2, -1)
|
106 |
+
|
107 |
+
return mean.squeeze(0), logvar.squeeze(0)
|
108 |
+
|
109 |
+
def encode(self, input, nsamples):
|
110 |
+
"""perform the encoding and compute the KL term
|
111 |
+
Args:
|
112 |
+
input: tuple which contains x and sents_len
|
113 |
+
Returns: Tensor1, Tensor2
|
114 |
+
Tensor1: the tensor latent z with shape [batch, nsamples, nz]
|
115 |
+
Tensor2: the tenor of KL for each x with shape [batch]
|
116 |
+
"""
|
117 |
+
|
118 |
+
# (batch_size, nz)
|
119 |
+
mu, logvar = self.forward(input)
|
120 |
+
|
121 |
+
# (batch, nsamples, nz)
|
122 |
+
z = self.reparameterize(mu, logvar, nsamples)
|
123 |
+
|
124 |
+
KL = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)
|
125 |
+
|
126 |
+
return z, KL
|
Optimus/code/examples/big_ae/modules/encoders/encoder.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
from ..utils import log_sum_exp
|
6 |
+
|
7 |
+
class EncoderBase(nn.Module):
|
8 |
+
"""docstring for EncoderBase"""
|
9 |
+
def __init__(self):
|
10 |
+
super(EncoderBase, self).__init__()
|
11 |
+
|
12 |
+
def forward(self, x):
|
13 |
+
"""
|
14 |
+
Args:
|
15 |
+
x: (batch_size, *)
|
16 |
+
Returns: the tensors required to parameterize a distribution.
|
17 |
+
E.g. for Gaussian encoder it returns the mean and variance tensors
|
18 |
+
"""
|
19 |
+
|
20 |
+
raise NotImplementedError
|
21 |
+
|
22 |
+
def sample(self, input, nsamples):
|
23 |
+
"""sampling from the encoder
|
24 |
+
Returns: Tensor1
|
25 |
+
Tensor1: the tensor latent z with shape [batch, nsamples, nz]
|
26 |
+
"""
|
27 |
+
|
28 |
+
raise NotImplementedError
|
29 |
+
|
30 |
+
def encode(self, input, nsamples):
|
31 |
+
"""perform the encoding and compute the KL term
|
32 |
+
Returns: Tensor1, Tensor2
|
33 |
+
Tensor1: the tensor latent z with shape [batch, nsamples, nz]
|
34 |
+
Tensor2: the tenor of KL for each x with shape [batch]
|
35 |
+
"""
|
36 |
+
|
37 |
+
raise NotImplementedError
|
38 |
+
|
39 |
+
|
40 |
+
def eval_inference_dist(self, x, z, param=None):
|
41 |
+
"""this function computes log q(z | x)
|
42 |
+
Args:
|
43 |
+
z: tensor
|
44 |
+
different z points that will be evaluated, with
|
45 |
+
shape [batch, nsamples, nz]
|
46 |
+
Returns: Tensor1
|
47 |
+
Tensor1: log q(z|x) with shape [batch, nsamples]
|
48 |
+
"""
|
49 |
+
|
50 |
+
raise NotImplementedError
|
51 |
+
|
52 |
+
def calc_mi(self, x):
|
53 |
+
"""Approximate the mutual information between x and z
|
54 |
+
I(x, z) = E_xE_{q(z|x)}log(q(z|x)) - E_xE_{q(z|x)}log(q(z))
|
55 |
+
Returns: Float
|
56 |
+
"""
|
57 |
+
|
58 |
+
raise NotImplementedError
|
Optimus/code/examples/big_ae/modules/encoders/gaussian_encoder.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
from .encoder import EncoderBase
|
6 |
+
from ..utils import log_sum_exp
|
7 |
+
|
8 |
+
class GaussianEncoderBase(EncoderBase):
|
9 |
+
"""docstring for EncoderBase"""
|
10 |
+
def __init__(self):
|
11 |
+
super(GaussianEncoderBase, self).__init__()
|
12 |
+
|
13 |
+
def freeze(self):
|
14 |
+
for param in self.parameters():
|
15 |
+
param.requires_grad = False
|
16 |
+
|
17 |
+
def forward(self, x):
|
18 |
+
"""
|
19 |
+
Args:
|
20 |
+
x: (batch_size, *)
|
21 |
+
Returns: Tensor1, Tensor2
|
22 |
+
Tensor1: the mean tensor, shape (batch, nz)
|
23 |
+
Tensor2: the logvar tensor, shape (batch, nz)
|
24 |
+
"""
|
25 |
+
|
26 |
+
raise NotImplementedError
|
27 |
+
|
28 |
+
def encode_stats(self, x):
|
29 |
+
|
30 |
+
return self.forward(x)
|
31 |
+
|
32 |
+
def sample(self, input, nsamples):
|
33 |
+
"""sampling from the encoder
|
34 |
+
Returns: Tensor1
|
35 |
+
Tensor1: the tensor latent z with shape [batch, nsamples, nz]
|
36 |
+
"""
|
37 |
+
|
38 |
+
# (batch_size, nz)
|
39 |
+
mu, logvar = self.forward(input)
|
40 |
+
|
41 |
+
# (batch, nsamples, nz)
|
42 |
+
z = self.reparameterize(mu, logvar, nsamples)
|
43 |
+
|
44 |
+
return z, (mu, logvar)
|
45 |
+
|
46 |
+
def encode(self, input, nsamples):
|
47 |
+
"""perform the encoding and compute the KL term
|
48 |
+
Returns: Tensor1, Tensor2
|
49 |
+
Tensor1: the tensor latent z with shape [batch, nsamples, nz]
|
50 |
+
Tensor2: the tenor of KL for each x with shape [batch]
|
51 |
+
"""
|
52 |
+
|
53 |
+
# (batch_size, nz)
|
54 |
+
mu, logvar = self.forward(input)
|
55 |
+
|
56 |
+
# (batch, nsamples, nz)
|
57 |
+
z = self.reparameterize(mu, logvar, nsamples)
|
58 |
+
|
59 |
+
KL = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)
|
60 |
+
|
61 |
+
return z, KL
|
62 |
+
|
63 |
+
def reparameterize(self, mu, logvar, nsamples=1):
|
64 |
+
"""sample from posterior Gaussian family
|
65 |
+
Args:
|
66 |
+
mu: Tensor
|
67 |
+
Mean of gaussian distribution with shape (batch, nz)
|
68 |
+
logvar: Tensor
|
69 |
+
logvar of gaussian distibution with shape (batch, nz)
|
70 |
+
Returns: Tensor
|
71 |
+
Sampled z with shape (batch, nsamples, nz)
|
72 |
+
"""
|
73 |
+
batch_size, nz = mu.size()
|
74 |
+
std = logvar.mul(0.5).exp()
|
75 |
+
|
76 |
+
mu_expd = mu.unsqueeze(1).expand(batch_size, nsamples, nz)
|
77 |
+
std_expd = std.unsqueeze(1).expand(batch_size, nsamples, nz)
|
78 |
+
|
79 |
+
eps = torch.zeros_like(std_expd).normal_()
|
80 |
+
|
81 |
+
return mu_expd + torch.mul(eps, std_expd)
|
82 |
+
|
83 |
+
def eval_inference_dist(self, x, z, param=None):
|
84 |
+
"""this function computes log q(z | x)
|
85 |
+
Args:
|
86 |
+
z: tensor
|
87 |
+
different z points that will be evaluated, with
|
88 |
+
shape [batch, nsamples, nz]
|
89 |
+
Returns: Tensor1
|
90 |
+
Tensor1: log q(z|x) with shape [batch, nsamples]
|
91 |
+
"""
|
92 |
+
|
93 |
+
nz = z.size(2)
|
94 |
+
|
95 |
+
if not param:
|
96 |
+
mu, logvar = self.forward(x)
|
97 |
+
else:
|
98 |
+
mu, logvar = param
|
99 |
+
|
100 |
+
# (batch_size, 1, nz)
|
101 |
+
mu, logvar = mu.unsqueeze(1), logvar.unsqueeze(1)
|
102 |
+
var = logvar.exp()
|
103 |
+
|
104 |
+
# (batch_size, nsamples, nz)
|
105 |
+
dev = z - mu
|
106 |
+
|
107 |
+
# (batch_size, nsamples)
|
108 |
+
log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \
|
109 |
+
0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1))
|
110 |
+
|
111 |
+
return log_density
|
112 |
+
|
113 |
+
|
114 |
+
|
115 |
+
def calc_mi(self, x):
|
116 |
+
"""Approximate the mutual information between x and z
|
117 |
+
I(x, z) = E_xE_{q(z|x)}log(q(z|x)) - E_xE_{q(z|x)}log(q(z))
|
118 |
+
Returns: Float
|
119 |
+
"""
|
120 |
+
|
121 |
+
# [x_batch, nz]
|
122 |
+
mu, logvar = self.forward(x)
|
123 |
+
|
124 |
+
x_batch, nz = mu.size()
|
125 |
+
|
126 |
+
# E_{q(z|x)}log(q(z|x)) = -0.5*nz*log(2*\pi) - 0.5*(1+logvar).sum(-1)
|
127 |
+
neg_entropy = (-0.5 * nz * math.log(2 * math.pi)- 0.5 * (1 + logvar).sum(-1)).mean()
|
128 |
+
|
129 |
+
# [z_batch, 1, nz]
|
130 |
+
z_samples = self.reparameterize(mu, logvar, 1)
|
131 |
+
|
132 |
+
# [1, x_batch, nz]
|
133 |
+
mu, logvar = mu.unsqueeze(0), logvar.unsqueeze(0)
|
134 |
+
var = logvar.exp()
|
135 |
+
|
136 |
+
# (z_batch, x_batch, nz)
|
137 |
+
dev = z_samples - mu
|
138 |
+
|
139 |
+
# (z_batch, x_batch)
|
140 |
+
log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \
|
141 |
+
0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1))
|
142 |
+
|
143 |
+
# log q(z): aggregate posterior
|
144 |
+
# [z_batch]
|
145 |
+
log_qz = log_sum_exp(log_density, dim=1) - math.log(x_batch)
|
146 |
+
|
147 |
+
return (neg_entropy - log_qz.mean(-1)).item()
|
Optimus/code/examples/big_ae/modules/spacefusion.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .vae import VAE
|
2 |
+
import numpy as np
|
3 |
+
import torch, copy, pdb
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
import pdb
|
9 |
+
|
10 |
+
|
11 |
+
def set_trainable(module, value):
|
12 |
+
for param in module.parameters():
|
13 |
+
param.requires_grad = value
|
14 |
+
|
15 |
+
class SpaceFusion(VAE):
|
16 |
+
def __init__(self, encoder, decoder, tokenizer_encoder, tokenizer_decoder, args):
|
17 |
+
super(SpaceFusion, self).__init__(encoder, decoder, tokenizer_encoder, tokenizer_decoder, args)
|
18 |
+
children = [v for v in encoder.encoder.layer.children()] # list of 12 BertLayer
|
19 |
+
|
20 |
+
self.num_s2s_bert_layer = args.num_s2s_bert_layer
|
21 |
+
self.S2S_layers = nn.ModuleList([copy.deepcopy(c) for c in children[-args.num_s2s_bert_layer:] ]) # the last layer of encoder
|
22 |
+
self.S2S_pooler = copy.deepcopy(encoder.pooler)
|
23 |
+
self.ix_turn_sep = tokenizer_encoder.convert_tokens_to_ids('[SEP]')
|
24 |
+
if args.freeze_bert:
|
25 |
+
print('@'*20 + f' freezing BERT {args.num_frozen_bert_layer} layers')
|
26 |
+
for child in children[:args.num_frozen_bert_layer]:
|
27 |
+
set_trainable(child, False)
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
def ids2speaker(self, ids):
|
32 |
+
# 0 for speaker A, 1 for speaker B
|
33 |
+
N, T = ids.shape
|
34 |
+
speaker = np.zeros((N, T))
|
35 |
+
sep = ids == self.ix_turn_sep
|
36 |
+
for i in range(N):
|
37 |
+
is_B = False # start with speaker A
|
38 |
+
for t in range(T):
|
39 |
+
speaker[i,t] = int(is_B)
|
40 |
+
if sep[i,t].item():
|
41 |
+
is_B = not is_B
|
42 |
+
|
43 |
+
# make sure the final speaker is speaker B (so response is always speaker A)
|
44 |
+
if not is_B:
|
45 |
+
speaker = 1 - speaker
|
46 |
+
|
47 |
+
return torch.LongTensor(speaker).to(ids.device)
|
48 |
+
|
49 |
+
def forward(self, inputs_src, inputs_tgt, labels_tgt, return_vec=False): # [batch, time]
|
50 |
+
# toggle config to get desired encoder output
|
51 |
+
self.encoder.encoder.output_attentions = False
|
52 |
+
self.encoder.encoder.output_hidden_states = True
|
53 |
+
|
54 |
+
|
55 |
+
# AE encoder
|
56 |
+
mask = (inputs_tgt > 0).float().to(inputs_src.device)
|
57 |
+
outputs = self.encoder(inputs_tgt, attention_mask=mask)
|
58 |
+
z_AE, _ = self.connect(outputs[1])
|
59 |
+
z_AE = z_AE.squeeze(1)
|
60 |
+
|
61 |
+
# S2S encoder
|
62 |
+
mask = (inputs_src > 0).float()
|
63 |
+
speaker = self.ids2speaker(inputs_src)
|
64 |
+
outputs = self.encoder(inputs_src, attention_mask=mask, token_type_ids=speaker)
|
65 |
+
_, _, all_layer_attn = outputs # last_layer_attn, pooled, all_layer_attn = outputs
|
66 |
+
seq_z_prev = all_layer_attn[-self.num_s2s_bert_layer-1] # seq of z at layer 11 ()
|
67 |
+
|
68 |
+
for s2s in self.S2S_layers:
|
69 |
+
layer_outputs = s2s(seq_z_prev, attention_mask=mask.unsqueeze(1).unsqueeze(1))
|
70 |
+
seq_z_prev = layer_outputs[0]
|
71 |
+
|
72 |
+
z_S2S = self.encoder.pooler(layer_outputs[0])
|
73 |
+
z_S2S, _ = self.connect(z_S2S)
|
74 |
+
z_S2S = z_S2S.squeeze(1)
|
75 |
+
|
76 |
+
if return_vec:
|
77 |
+
return z_AE, z_S2S
|
78 |
+
|
79 |
+
# interpolation/smoothness
|
80 |
+
u = torch.FloatTensor(np.random.random((z_AE.shape[0], 1))).to(inputs_tgt.device)
|
81 |
+
z_interp = u * z_AE + (1 - u) * z_S2S
|
82 |
+
std = 0.1
|
83 |
+
noise = torch.FloatTensor(np.random.normal(size=z_interp.shape) * std).to(z_interp.device)
|
84 |
+
z_interp = z_interp + noise
|
85 |
+
|
86 |
+
loss_rec = 0
|
87 |
+
z_idx = 0
|
88 |
+
for z in [z_AE, z_S2S, z_interp]:
|
89 |
+
#pdb.set_trace()
|
90 |
+
past = z # past = self.decoder.linear(z)
|
91 |
+
outputs = self.decoder(input_ids=labels_tgt, past=past, labels=labels_tgt, label_ignore=self.pad_token_id)
|
92 |
+
if z_idx == 1:
|
93 |
+
loss_rec = loss_rec + 1.0 * outputs[0]
|
94 |
+
else:
|
95 |
+
loss_rec = loss_rec + outputs[0]
|
96 |
+
z_idx += 1
|
97 |
+
loss_rec = loss_rec/3
|
98 |
+
|
99 |
+
# fusion/regularization
|
100 |
+
L_pull = self.dist_pair(z_AE, z_S2S)
|
101 |
+
L_push = torch.stack([self.dist_batch(z) for z in [z_AE, z_S2S]]).min()
|
102 |
+
loss_reg = (L_pull - L_push * 2) / np.sqrt(z.shape[-1])
|
103 |
+
|
104 |
+
loss = loss_rec + self.args.beta * loss_reg
|
105 |
+
return loss_rec, loss_reg, loss
|
106 |
+
|
107 |
+
def sent2latent(self, inputs_src):
|
108 |
+
# toggle config to get desired encoder output
|
109 |
+
self.encoder.encoder.output_attentions = False
|
110 |
+
self.encoder.encoder.output_hidden_states = True
|
111 |
+
|
112 |
+
# S2S encoder
|
113 |
+
mask = (inputs_src > 0).float()
|
114 |
+
speaker = self.ids2speaker(inputs_src)
|
115 |
+
outputs = self.encoder(inputs_src, attention_mask=mask, token_type_ids=speaker)
|
116 |
+
|
117 |
+
_, _, all_layer_attn = outputs # last_layer_attn, pooled, all_layer_attn = outputs
|
118 |
+
# seq_z_prev = all_layer_attn[-2] # seq of z at layer 11 ()
|
119 |
+
# layer_outputs = self.S2S_layer(seq_z_prev, attention_mask=mask.unsqueeze(1).unsqueeze(1))
|
120 |
+
|
121 |
+
seq_z_prev = all_layer_attn[-self.num_s2s_bert_layer-1] # seq of z at layer 11 ()
|
122 |
+
for s2s in self.S2S_layers:
|
123 |
+
layer_outputs = s2s(seq_z_prev, attention_mask=mask.unsqueeze(1).unsqueeze(1))
|
124 |
+
seq_z_prev = layer_outputs[0]
|
125 |
+
|
126 |
+
z_S2S = self.encoder.pooler(layer_outputs[0])
|
127 |
+
z_S2S, _ = self.connect(z_S2S)
|
128 |
+
z_S2S = z_S2S.squeeze(1)
|
129 |
+
|
130 |
+
return z_S2S
|
131 |
+
|
132 |
+
|
133 |
+
def dist_pair(self, a, b):
|
134 |
+
return F.pairwise_distance(a, b).mean()
|
135 |
+
|
136 |
+
|
137 |
+
def dist_batch(self, vec):
|
138 |
+
n = vec.shape[0]
|
139 |
+
dmin = []
|
140 |
+
for i in range(n):
|
141 |
+
dd = F.pairwise_distance(vec[i:i+1,:].repeat(n,1), vec)
|
142 |
+
dmin.append(dd.min())
|
143 |
+
return torch.stack(dmin).mean()
|
Optimus/code/examples/big_ae/modules/utils.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
def safe_log(z):
|
4 |
+
return torch.log(z + 1e-7)
|
5 |
+
|
6 |
+
def log_sum_exp(value, dim=None, keepdim=False):
|
7 |
+
"""Numerically stable implementation of the operation
|
8 |
+
value.exp().sum(dim, keepdim).log()
|
9 |
+
"""
|
10 |
+
if dim is not None:
|
11 |
+
m, _ = torch.max(value, dim=dim, keepdim=True)
|
12 |
+
value0 = value - m
|
13 |
+
if keepdim is False:
|
14 |
+
m = m.squeeze(dim)
|
15 |
+
return m + torch.log(torch.sum(torch.exp(value0), dim=dim, keepdim=keepdim))
|
16 |
+
else:
|
17 |
+
m = torch.max(value)
|
18 |
+
sum_exp = torch.sum(torch.exp(value - m))
|
19 |
+
return m + torch.log(sum_exp)
|
20 |
+
|
21 |
+
|
22 |
+
def generate_grid(zmin, zmax, dz, device, ndim=2):
|
23 |
+
"""generate a 1- or 2-dimensional grid
|
24 |
+
Returns: Tensor, int
|
25 |
+
Tensor: The grid tensor with shape (k^2, 2),
|
26 |
+
where k=(zmax - zmin)/dz
|
27 |
+
int: k
|
28 |
+
"""
|
29 |
+
|
30 |
+
if ndim == 2:
|
31 |
+
x = torch.arange(zmin, zmax, dz)
|
32 |
+
k = x.size(0)
|
33 |
+
|
34 |
+
x1 = x.unsqueeze(1).repeat(1, k).view(-1)
|
35 |
+
x2 = x.repeat(k)
|
36 |
+
|
37 |
+
return torch.cat((x1.unsqueeze(-1), x2.unsqueeze(-1)), dim=-1).to(device), k
|
38 |
+
|
39 |
+
elif ndim == 1:
|
40 |
+
return torch.arange(zmin, zmax, dz).unsqueeze(1).to(device)
|
Optimus/code/examples/big_ae/modules/vae.py
ADDED
@@ -0,0 +1,638 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
from .utils import log_sum_exp
|
6 |
+
|
7 |
+
import pdb
|
8 |
+
|
9 |
+
import logging
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
|
13 |
+
class VAE(nn.Module):
|
14 |
+
"""VAE with normal prior"""
|
15 |
+
def __init__(self, encoder, decoder, tokenizer_encoder, tokenizer_decoder, args): #
|
16 |
+
super(VAE, self).__init__()
|
17 |
+
self.encoder = encoder
|
18 |
+
self.decoder = decoder
|
19 |
+
|
20 |
+
self.args = args
|
21 |
+
self.nz = args.latent_size
|
22 |
+
|
23 |
+
self.eos_token_id = tokenizer_decoder.convert_tokens_to_ids([tokenizer_decoder.eos_token])[0]
|
24 |
+
self.pad_token_id = tokenizer_decoder.convert_tokens_to_ids([tokenizer_decoder.pad_token])[0]
|
25 |
+
|
26 |
+
|
27 |
+
# connector: from Bert hidden units to the latent space
|
28 |
+
# self.linear = nn.Linear(args.nz, 2 * args.nz, bias=False)
|
29 |
+
|
30 |
+
# Standard Normal prior
|
31 |
+
loc = torch.zeros(self.nz, device=args.device)
|
32 |
+
scale = torch.ones(self.nz, device=args.device)
|
33 |
+
self.prior = torch.distributions.normal.Normal(loc, scale)
|
34 |
+
|
35 |
+
def connect(self, bert_fea, nsamples=1):
|
36 |
+
"""
|
37 |
+
Returns: Tensor1, Tensor2
|
38 |
+
Tensor1: the tensor latent z with shape [batch, nsamples, nz]
|
39 |
+
Tensor2: the tenor of KL for each x with shape [batch]
|
40 |
+
"""
|
41 |
+
|
42 |
+
# (batch_size, nz)
|
43 |
+
|
44 |
+
mean, logvar = self.encoder.linear(bert_fea).chunk(2, -1)
|
45 |
+
# pdb.set_trace()
|
46 |
+
# mean, logvar = mean.squeeze(0), logvar.squeeze(0)
|
47 |
+
|
48 |
+
# (batch, nsamples, nz)
|
49 |
+
z = self.reparameterize(mean, logvar, nsamples)
|
50 |
+
KL = 0.5 * (mean.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)
|
51 |
+
|
52 |
+
return z, KL
|
53 |
+
|
54 |
+
def connect_deterministic(self, bert_fea, nsamples=1):
|
55 |
+
"""
|
56 |
+
Returns: Tensor1, Tensor2
|
57 |
+
Tensor1: the tensor latent z with shape [batch, nsamples, nz]
|
58 |
+
Tensor2: the tenor of KL for each x with shape [batch]
|
59 |
+
"""
|
60 |
+
|
61 |
+
# (batch_size, nz)
|
62 |
+
|
63 |
+
mean, logvar = self.encoder.linear(bert_fea).chunk(2, -1)
|
64 |
+
# pdb.set_trace()
|
65 |
+
# mean, logvar = mean.squeeze(0), logvar.squeeze(0)
|
66 |
+
|
67 |
+
logvar.fill_(.0)
|
68 |
+
# (batch, nsamples, nz)
|
69 |
+
z = self.reparameterize(mean, logvar, nsamples)
|
70 |
+
KL = 0.5 * (mean.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)
|
71 |
+
|
72 |
+
return z, KL
|
73 |
+
|
74 |
+
|
75 |
+
|
76 |
+
def reparameterize(self, mu, logvar, nsamples=1):
|
77 |
+
"""sample from posterior Gaussian family
|
78 |
+
Args:
|
79 |
+
mu: Tensor
|
80 |
+
Mean of gaussian distribution with shape (batch, nz)
|
81 |
+
logvar: Tensor
|
82 |
+
logvar of gaussian distibution with shape (batch, nz)
|
83 |
+
Returns: Tensor
|
84 |
+
Sampled z with shape (batch, nsamples, nz)
|
85 |
+
"""
|
86 |
+
batch_size, nz = mu.size()
|
87 |
+
std = logvar.mul(0.5).exp()
|
88 |
+
|
89 |
+
mu_expd = mu.unsqueeze(1).expand(batch_size, nsamples, nz)
|
90 |
+
std_expd = std.unsqueeze(1).expand(batch_size, nsamples, nz)
|
91 |
+
|
92 |
+
eps = torch.zeros_like(std_expd).normal_()
|
93 |
+
|
94 |
+
return mu_expd + torch.mul(eps, std_expd)
|
95 |
+
|
96 |
+
def forward(self, inputs, labels):
|
97 |
+
|
98 |
+
# pdb.set_trace()
|
99 |
+
|
100 |
+
attention_mask=(inputs > 0).float()
|
101 |
+
# logger.info(inputs)
|
102 |
+
# logger.info(attention_mask)
|
103 |
+
# logger.info(labels)
|
104 |
+
reconstrution_mask=(labels != 50257).float() # 50257 is the padding token for GPT2
|
105 |
+
sent_length = torch.sum(reconstrution_mask, dim=1)
|
106 |
+
|
107 |
+
|
108 |
+
outputs = self.encoder(inputs, attention_mask)
|
109 |
+
pooled_hidden_fea = outputs[1] # model outputs are always tuple in pytorch-transformers (see doc)
|
110 |
+
|
111 |
+
if self.args.fb_mode==0:
|
112 |
+
# Connect hidden feature to the latent space
|
113 |
+
latent_z, loss_kl = self.connect(pooled_hidden_fea)
|
114 |
+
latent_z = latent_z.squeeze(1)
|
115 |
+
|
116 |
+
|
117 |
+
# Decoding
|
118 |
+
outputs = self.decoder(input_ids=labels, past=latent_z, labels=labels, label_ignore=self.pad_token_id)
|
119 |
+
loss_rec = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
120 |
+
|
121 |
+
elif self.args.fb_mode==1:
|
122 |
+
# Connect hidden feature to the latent space
|
123 |
+
mu, logvar = self.encoder.linear(pooled_hidden_fea).chunk(2, -1)
|
124 |
+
latent_z = self.reparameterize(mu, logvar, nsamples=1)
|
125 |
+
latent_z = latent_z.squeeze(1)
|
126 |
+
loss_kl = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1)
|
127 |
+
kl_mask = (loss_kl > self.args.dim_target_kl).float()
|
128 |
+
loss_kl = (kl_mask * loss_kl).sum(dim=1)
|
129 |
+
|
130 |
+
# pdb.set_trace()
|
131 |
+
# past = self.decoder.linear(latent_z)
|
132 |
+
# Decoding
|
133 |
+
outputs = self.decoder(input_ids=labels, past=latent_z, labels=labels, label_ignore=self.pad_token_id)
|
134 |
+
loss_rec = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
135 |
+
|
136 |
+
elif self.args.fb_mode==2:
|
137 |
+
# Connect hidden feature to the latent space
|
138 |
+
latent_z, loss_kl = self.connect_deterministic(pooled_hidden_fea)
|
139 |
+
latent_z = latent_z.squeeze(1)
|
140 |
+
|
141 |
+
# past = self.decoder.linear(latent_z)
|
142 |
+
# Decoding
|
143 |
+
outputs = self.decoder(input_ids=labels, past=latent_z, labels=labels, label_ignore=self.pad_token_id)
|
144 |
+
loss_rec = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
145 |
+
|
146 |
+
|
147 |
+
# pdb.set_trace()
|
148 |
+
if self.args.length_weighted_loss:
|
149 |
+
loss = loss_rec / sent_length + self.args.beta * loss_kl
|
150 |
+
else:
|
151 |
+
loss = loss_rec + self.args.beta * loss_kl
|
152 |
+
|
153 |
+
|
154 |
+
return loss_rec, loss_kl, loss
|
155 |
+
|
156 |
+
|
157 |
+
|
158 |
+
def encoder_sample(self, bert_fea, nsamples):
|
159 |
+
"""sampling from the encoder
|
160 |
+
Returns: Tensor1
|
161 |
+
Tensor1: the tensor latent z with shape [batch, nsamples, nz]
|
162 |
+
"""
|
163 |
+
|
164 |
+
# (batch_size, nz)
|
165 |
+
|
166 |
+
mu, logvar = self.encoder.linear(bert_fea).chunk(2, -1)
|
167 |
+
mu, logvar = mu.squeeze(0), logvar.squeeze(0)
|
168 |
+
|
169 |
+
# (batch, nsamples, nz)
|
170 |
+
z = self.reparameterize(mu, logvar, nsamples)
|
171 |
+
|
172 |
+
return z, (mu, logvar)
|
173 |
+
|
174 |
+
|
175 |
+
def encode_stats(self, x):
|
176 |
+
"""
|
177 |
+
Returns: Tensor1, Tensor2
|
178 |
+
Tensor1: the mean of latent z with shape [batch, nz]
|
179 |
+
Tensor2: the logvar of latent z with shape [batch, nz]
|
180 |
+
"""
|
181 |
+
|
182 |
+
return self.encoder.encode_stats(x)
|
183 |
+
|
184 |
+
def decode(self, z, strategy, K=10):
|
185 |
+
"""generate samples from z given strategy
|
186 |
+
Args:
|
187 |
+
z: [batch, nsamples, nz]
|
188 |
+
strategy: "beam" or "greedy" or "sample"
|
189 |
+
K: the beam width parameter
|
190 |
+
Returns: List1
|
191 |
+
List1: a list of decoded word sequence
|
192 |
+
"""
|
193 |
+
|
194 |
+
if strategy == "beam":
|
195 |
+
return self.decoder.beam_search_decode(z, K)
|
196 |
+
elif strategy == "greedy":
|
197 |
+
return self.decoder.greedy_decode(z)
|
198 |
+
elif strategy == "sample":
|
199 |
+
return self.decoder.sample_decode(z)
|
200 |
+
else:
|
201 |
+
raise ValueError("the decoding strategy is not supported")
|
202 |
+
|
203 |
+
|
204 |
+
def reconstruct(self, x, decoding_strategy="greedy", K=5):
|
205 |
+
"""reconstruct from input x
|
206 |
+
Args:
|
207 |
+
x: (batch, *)
|
208 |
+
decoding_strategy: "beam" or "greedy" or "sample"
|
209 |
+
K: the beam width parameter
|
210 |
+
Returns: List1
|
211 |
+
List1: a list of decoded word sequence
|
212 |
+
"""
|
213 |
+
z = self.sample_from_inference(x).squeeze(1)
|
214 |
+
|
215 |
+
return self.decode(z, decoding_strategy, K)
|
216 |
+
|
217 |
+
def log_probability(self, x, z):
|
218 |
+
"""Cross Entropy in the language case
|
219 |
+
Args:
|
220 |
+
x: (batch_size, seq_len)
|
221 |
+
z: (batch_size, n_sample, nz)
|
222 |
+
Returns:
|
223 |
+
log_p: (batch_size, n_sample).
|
224 |
+
log_p(x|z) across different x and z
|
225 |
+
"""
|
226 |
+
outputs = self.decoder(input_ids=x, past=z, labels=x, label_ignore=self.pad_token_id)
|
227 |
+
loss_rec = outputs[0]
|
228 |
+
return -loss_rec
|
229 |
+
|
230 |
+
|
231 |
+
|
232 |
+
def loss_iw(self, x0, x1, nsamples=50, ns=1):
|
233 |
+
"""
|
234 |
+
Args:
|
235 |
+
x: if the data is constant-length, x is the data tensor with
|
236 |
+
shape (batch, *). Otherwise x is a tuple that contains
|
237 |
+
the data tensor and length list
|
238 |
+
Returns: Tensor1, Tensor2, Tensor3
|
239 |
+
Tensor1: total loss [batch]
|
240 |
+
Tensor2: reconstruction loss shape [batch]
|
241 |
+
Tensor3: KL loss shape [batch]
|
242 |
+
"""
|
243 |
+
|
244 |
+
# encoding into bert features
|
245 |
+
bert_fea = self.encoder(x0)[1]
|
246 |
+
|
247 |
+
# (batch_size, nz)
|
248 |
+
|
249 |
+
mu, logvar = self.encoder.linear(bert_fea).chunk(2, -1)
|
250 |
+
|
251 |
+
|
252 |
+
##################
|
253 |
+
# compute KL
|
254 |
+
##################
|
255 |
+
# pdb.set_trace()
|
256 |
+
KL = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)
|
257 |
+
|
258 |
+
# mu, logvar = mu.squeeze(0), logvar.squeeze(0)
|
259 |
+
ll_tmp, rc_tmp = [], []
|
260 |
+
for _ in range(int(nsamples / ns)):
|
261 |
+
|
262 |
+
# (batch, nsamples, nz)
|
263 |
+
z = self.reparameterize(mu, logvar, ns)
|
264 |
+
# past = self.decoder.linear(z)
|
265 |
+
past = z
|
266 |
+
|
267 |
+
# [batch, nsamples]
|
268 |
+
log_prior = self.eval_prior_dist(z)
|
269 |
+
log_gen = self.eval_cond_ll(x1, past)
|
270 |
+
log_infer = self.eval_inference_dist(z, (mu, logvar))
|
271 |
+
|
272 |
+
# pdb.set_trace()
|
273 |
+
log_gen = log_gen.unsqueeze(0).contiguous().view(z.shape[0],-1)
|
274 |
+
|
275 |
+
|
276 |
+
# pdb.set_trace()
|
277 |
+
rc_tmp.append(log_gen)
|
278 |
+
ll_tmp.append(log_gen + log_prior - log_infer)
|
279 |
+
|
280 |
+
|
281 |
+
|
282 |
+
log_prob_iw = log_sum_exp(torch.cat(ll_tmp, dim=-1), dim=-1) - math.log(nsamples)
|
283 |
+
log_gen_iw = torch.mean(torch.cat(rc_tmp, dim=-1), dim=-1)
|
284 |
+
|
285 |
+
return log_prob_iw, log_gen_iw , KL
|
286 |
+
|
287 |
+
|
288 |
+
def nll_iw(self, x0, x1, nsamples, ns=1):
|
289 |
+
"""compute the importance weighting estimate of the log-likelihood
|
290 |
+
Args:
|
291 |
+
x0, x1: two different tokenization results of x, where x is the data tensor with shape (batch, *).
|
292 |
+
nsamples: Int
|
293 |
+
the number of samples required to estimate marginal data likelihood
|
294 |
+
Returns: Tensor1
|
295 |
+
Tensor1: the estimate of log p(x), shape [batch]
|
296 |
+
"""
|
297 |
+
|
298 |
+
# compute iw every ns samples to address the memory issue
|
299 |
+
# nsamples = 500, ns = 100
|
300 |
+
# nsamples = 500, ns = 10
|
301 |
+
|
302 |
+
# TODO: note that x is forwarded twice in self.encoder.sample(x, ns) and self.eval_inference_dist(x, z, param)
|
303 |
+
#. this problem is to be solved in order to speed up
|
304 |
+
|
305 |
+
tmp = []
|
306 |
+
for _ in range(int(nsamples / ns)):
|
307 |
+
# [batch, ns, nz]
|
308 |
+
|
309 |
+
# Chunyuan:
|
310 |
+
# encoding into bert features
|
311 |
+
pooled_hidden_fea = self.encoder(x0)[1]
|
312 |
+
|
313 |
+
# param is the parameters required to evaluate q(z|x)
|
314 |
+
z, param = self.encoder_sample(pooled_hidden_fea, ns)
|
315 |
+
|
316 |
+
# [batch, ns]
|
317 |
+
log_comp_ll = self.eval_complete_ll(x1, z)
|
318 |
+
log_infer_ll = self.eval_inference_dist(z, param)
|
319 |
+
|
320 |
+
tmp.append(log_comp_ll - log_infer_ll)
|
321 |
+
|
322 |
+
ll_iw = log_sum_exp(torch.cat(tmp, dim=-1), dim=-1) - math.log(nsamples)
|
323 |
+
|
324 |
+
return ll_iw
|
325 |
+
|
326 |
+
def KL(self, x):
|
327 |
+
_, KL = self.encode(x, 1)
|
328 |
+
|
329 |
+
return KL
|
330 |
+
|
331 |
+
def eval_prior_dist(self, zrange):
|
332 |
+
"""perform grid search to calculate the true posterior
|
333 |
+
Args:
|
334 |
+
zrange: tensor
|
335 |
+
different z points that will be evaluated, with
|
336 |
+
shape (k^2, nz), where k=(zmax - zmin)/space
|
337 |
+
"""
|
338 |
+
|
339 |
+
# (k^2)
|
340 |
+
return self.prior.log_prob(zrange).sum(dim=-1)
|
341 |
+
|
342 |
+
def eval_complete_ll(self, x, z):
|
343 |
+
"""compute log p(z,x)
|
344 |
+
Args:
|
345 |
+
x: Tensor
|
346 |
+
input with shape [batch, seq_len]
|
347 |
+
z: Tensor
|
348 |
+
evaluation points with shape [batch, nsamples, nz]
|
349 |
+
Returns: Tensor1
|
350 |
+
Tensor1: log p(z,x) Tensor with shape [batch, nsamples]
|
351 |
+
"""
|
352 |
+
|
353 |
+
# [batch, nsamples]
|
354 |
+
log_prior = self.eval_prior_dist(z)
|
355 |
+
log_gen = self.eval_cond_ll(x, z)
|
356 |
+
|
357 |
+
return log_prior + log_gen
|
358 |
+
|
359 |
+
|
360 |
+
|
361 |
+
def eval_cond_ll(self, x, z):
|
362 |
+
"""compute log p(x|z)
|
363 |
+
"""
|
364 |
+
x_shape = list(x.size())
|
365 |
+
z_shape = list(z.size())
|
366 |
+
if len(z_shape) == 3:
|
367 |
+
x = x.unsqueeze(1).repeat(1, z_shape[1], 1).contiguous().view(x_shape[0]*z_shape[1], x_shape[-1])
|
368 |
+
z = z.contiguous().view(x_shape[0]*z_shape[1], z_shape[-1])
|
369 |
+
|
370 |
+
return self.log_probability(x, z)
|
371 |
+
|
372 |
+
|
373 |
+
|
374 |
+
def eval_log_model_posterior(self, x, grid_z):
|
375 |
+
"""perform grid search to calculate the true posterior
|
376 |
+
this function computes p(z|x)
|
377 |
+
Args:
|
378 |
+
grid_z: tensor
|
379 |
+
different z points that will be evaluated, with
|
380 |
+
shape (k^2, nz), where k=(zmax - zmin)/pace
|
381 |
+
Returns: Tensor
|
382 |
+
Tensor: the log posterior distribution log p(z|x) with
|
383 |
+
shape [batch_size, K^2]
|
384 |
+
"""
|
385 |
+
try:
|
386 |
+
batch_size = x.size(0)
|
387 |
+
except:
|
388 |
+
batch_size = x[0].size(0)
|
389 |
+
|
390 |
+
# (batch_size, k^2, nz)
|
391 |
+
grid_z = grid_z.unsqueeze(0).expand(batch_size, *grid_z.size()).contiguous()
|
392 |
+
|
393 |
+
# (batch_size, k^2)
|
394 |
+
log_comp = self.eval_complete_ll(x, grid_z)
|
395 |
+
|
396 |
+
# normalize to posterior
|
397 |
+
log_posterior = log_comp - log_sum_exp(log_comp, dim=1, keepdim=True)
|
398 |
+
|
399 |
+
return log_posterior
|
400 |
+
|
401 |
+
def sample_from_inference(self, x, nsamples=1):
|
402 |
+
"""perform sampling from inference net
|
403 |
+
Returns: Tensor
|
404 |
+
Tensor: samples from infernece nets with
|
405 |
+
shape (batch_size, nsamples, nz)
|
406 |
+
"""
|
407 |
+
z, _ = self.encoder.sample(x, nsamples)
|
408 |
+
|
409 |
+
return z
|
410 |
+
|
411 |
+
|
412 |
+
def sample_from_posterior(self, x, nsamples):
|
413 |
+
"""perform MH sampling from model posterior
|
414 |
+
Returns: Tensor
|
415 |
+
Tensor: samples from model posterior with
|
416 |
+
shape (batch_size, nsamples, nz)
|
417 |
+
"""
|
418 |
+
|
419 |
+
# use the samples from inference net as initial points
|
420 |
+
# for MCMC sampling. [batch_size, nsamples, nz]
|
421 |
+
cur = self.encoder.sample_from_inference(x, 1)
|
422 |
+
cur_ll = self.eval_complete_ll(x, cur)
|
423 |
+
total_iter = self.args.mh_burn_in + nsamples * self.args.mh_thin
|
424 |
+
samples = []
|
425 |
+
for iter_ in range(total_iter):
|
426 |
+
next = torch.normal(mean=cur,
|
427 |
+
std=cur.new_full(size=cur.size(), fill_value=self.args.mh_std))
|
428 |
+
# [batch_size, 1]
|
429 |
+
next_ll = self.eval_complete_ll(x, next)
|
430 |
+
ratio = next_ll - cur_ll
|
431 |
+
|
432 |
+
accept_prob = torch.min(ratio.exp(), ratio.new_ones(ratio.size()))
|
433 |
+
|
434 |
+
uniform_t = accept_prob.new_empty(accept_prob.size()).uniform_()
|
435 |
+
|
436 |
+
# [batch_size, 1]
|
437 |
+
mask = (uniform_t < accept_prob).float()
|
438 |
+
mask_ = mask.unsqueeze(2)
|
439 |
+
|
440 |
+
cur = mask_ * next + (1 - mask_) * cur
|
441 |
+
cur_ll = mask * next_ll + (1 - mask) * cur_ll
|
442 |
+
|
443 |
+
if iter_ >= self.args.mh_burn_in and (iter_ - self.args.mh_burn_in) % self.args.mh_thin == 0:
|
444 |
+
samples.append(cur.unsqueeze(1))
|
445 |
+
|
446 |
+
return torch.cat(samples, dim=1)
|
447 |
+
|
448 |
+
|
449 |
+
def calc_model_posterior_mean(self, x, grid_z):
|
450 |
+
"""compute the mean value of model posterior, i.e. E_{z ~ p(z|x)}[z]
|
451 |
+
Args:
|
452 |
+
grid_z: different z points that will be evaluated, with
|
453 |
+
shape (k^2, nz), where k=(zmax - zmin)/pace
|
454 |
+
x: [batch, *]
|
455 |
+
Returns: Tensor1
|
456 |
+
Tensor1: the mean value tensor with shape [batch, nz]
|
457 |
+
"""
|
458 |
+
|
459 |
+
# [batch, K^2]
|
460 |
+
log_posterior = self.eval_log_model_posterior(x, grid_z)
|
461 |
+
posterior = log_posterior.exp()
|
462 |
+
|
463 |
+
# [batch, nz]
|
464 |
+
return torch.mul(posterior.unsqueeze(2), grid_z.unsqueeze(0)).sum(1)
|
465 |
+
|
466 |
+
def calc_infer_mean(self, x):
|
467 |
+
"""
|
468 |
+
Returns: Tensor1
|
469 |
+
Tensor1: the mean of inference distribution, with shape [batch, nz]
|
470 |
+
"""
|
471 |
+
|
472 |
+
mean, logvar = self.encoder.forward(x)
|
473 |
+
|
474 |
+
return mean
|
475 |
+
|
476 |
+
|
477 |
+
|
478 |
+
|
479 |
+
def eval_inference_dist(self, z, param):
|
480 |
+
"""this function computes log q(z | x)
|
481 |
+
Args:
|
482 |
+
z: tensor
|
483 |
+
different z points that will be evaluated, with
|
484 |
+
shape [batch, nsamples, nz]
|
485 |
+
Returns: Tensor1
|
486 |
+
Tensor1: log q(z|x) with shape [batch, nsamples]
|
487 |
+
"""
|
488 |
+
|
489 |
+
nz = z.size(2)
|
490 |
+
mu, logvar = param
|
491 |
+
|
492 |
+
# (batch_size, 1, nz)
|
493 |
+
mu, logvar = mu.unsqueeze(1), logvar.unsqueeze(1)
|
494 |
+
var = logvar.exp()
|
495 |
+
|
496 |
+
# (batch_size, nsamples, nz)
|
497 |
+
dev = z - mu
|
498 |
+
|
499 |
+
# (batch_size, nsamples)
|
500 |
+
log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \
|
501 |
+
0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1))
|
502 |
+
|
503 |
+
return log_density
|
504 |
+
|
505 |
+
|
506 |
+
|
507 |
+
def calc_mi(self, test_data_batch, args):
|
508 |
+
# calc_mi_v3
|
509 |
+
import math
|
510 |
+
from modules.utils import log_sum_exp
|
511 |
+
|
512 |
+
mi = 0
|
513 |
+
num_examples = 0
|
514 |
+
|
515 |
+
mu_batch_list, logvar_batch_list = [], []
|
516 |
+
neg_entropy = 0.
|
517 |
+
for batch_data in test_data_batch:
|
518 |
+
|
519 |
+
x0, _, _ = batch_data
|
520 |
+
x0 = x0.to(args.device)
|
521 |
+
|
522 |
+
# encoding into bert features
|
523 |
+
bert_fea = self.encoder(x0)[1]
|
524 |
+
|
525 |
+
(batch_size, nz)
|
526 |
+
mu, logvar = self.encoder.linear(bert_fea).chunk(2, -1)
|
527 |
+
|
528 |
+
x_batch, nz = mu.size()
|
529 |
+
|
530 |
+
#print(x_batch, end=' ')
|
531 |
+
|
532 |
+
num_examples += x_batch
|
533 |
+
|
534 |
+
# E_{q(z|x)}log(q(z|x)) = -0.5*nz*log(2*\pi) - 0.5*(1+logvar).sum(-1)
|
535 |
+
|
536 |
+
neg_entropy += (-0.5 * nz * math.log(2 * math.pi)- 0.5 * (1 + logvar).sum(-1)).sum().item()
|
537 |
+
mu_batch_list += [mu.cpu()]
|
538 |
+
logvar_batch_list += [logvar.cpu()]
|
539 |
+
|
540 |
+
pdb.set_trace()
|
541 |
+
|
542 |
+
neg_entropy = neg_entropy / num_examples
|
543 |
+
##print()
|
544 |
+
|
545 |
+
num_examples = 0
|
546 |
+
log_qz = 0.
|
547 |
+
for i in range(len(mu_batch_list)):
|
548 |
+
###############
|
549 |
+
# get z_samples
|
550 |
+
###############
|
551 |
+
mu, logvar = mu_batch_list[i].cuda(), logvar_batch_list[i].cuda()
|
552 |
+
|
553 |
+
# [z_batch, 1, nz]
|
554 |
+
|
555 |
+
z_samples = self.reparameterize(mu, logvar, 1)
|
556 |
+
|
557 |
+
z_samples = z_samples.view(-1, 1, nz)
|
558 |
+
num_examples += z_samples.size(0)
|
559 |
+
|
560 |
+
###############
|
561 |
+
# compute density
|
562 |
+
###############
|
563 |
+
# [1, x_batch, nz]
|
564 |
+
#mu, logvar = mu_batch_list[i].cuda(), logvar_batch_list[i].cuda()
|
565 |
+
#indices = list(np.random.choice(np.arange(len(mu_batch_list)), 10)) + [i]
|
566 |
+
indices = np.arange(len(mu_batch_list))
|
567 |
+
mu = torch.cat([mu_batch_list[_] for _ in indices], dim=0).cuda()
|
568 |
+
logvar = torch.cat([logvar_batch_list[_] for _ in indices], dim=0).cuda()
|
569 |
+
x_batch, nz = mu.size()
|
570 |
+
|
571 |
+
mu, logvar = mu.unsqueeze(0), logvar.unsqueeze(0)
|
572 |
+
var = logvar.exp()
|
573 |
+
|
574 |
+
# (z_batch, x_batch, nz)
|
575 |
+
dev = z_samples - mu
|
576 |
+
|
577 |
+
# (z_batch, x_batch)
|
578 |
+
log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \
|
579 |
+
0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1))
|
580 |
+
|
581 |
+
# log q(z): aggregate posterior
|
582 |
+
# [z_batch]
|
583 |
+
log_qz += (log_sum_exp(log_density, dim=1) - math.log(x_batch)).sum(-1)
|
584 |
+
|
585 |
+
log_qz /= num_examples
|
586 |
+
mi = neg_entropy - log_qz
|
587 |
+
|
588 |
+
return mi
|
589 |
+
|
590 |
+
|
591 |
+
|
592 |
+
def calc_au(self, eval_dataloader, args, delta=0.01):
|
593 |
+
"""compute the number of active units
|
594 |
+
"""
|
595 |
+
cnt = 0
|
596 |
+
for batch_data in eval_dataloader:
|
597 |
+
|
598 |
+
x0, _, _ = batch_data
|
599 |
+
x0 = x0.to(args.device)
|
600 |
+
|
601 |
+
# encoding into bert features
|
602 |
+
bert_fea = self.encoder(x0)[1]
|
603 |
+
|
604 |
+
# (batch_size, nz)
|
605 |
+
mean, logvar = self.encoder.linear(bert_fea).chunk(2, -1)
|
606 |
+
|
607 |
+
if cnt == 0:
|
608 |
+
means_sum = mean.sum(dim=0, keepdim=True)
|
609 |
+
else:
|
610 |
+
means_sum = means_sum + mean.sum(dim=0, keepdim=True)
|
611 |
+
cnt += mean.size(0)
|
612 |
+
|
613 |
+
# (1, nz)
|
614 |
+
mean_mean = means_sum / cnt
|
615 |
+
|
616 |
+
cnt = 0
|
617 |
+
for batch_data in eval_dataloader:
|
618 |
+
|
619 |
+
x0, _, _ = batch_data
|
620 |
+
x0 = x0.to(args.device)
|
621 |
+
|
622 |
+
# encoding into bert features
|
623 |
+
bert_fea = self.encoder(x0)[1]
|
624 |
+
|
625 |
+
# (batch_size, nz)
|
626 |
+
mean, _ = self.encoder.linear(bert_fea).chunk(2, -1)
|
627 |
+
|
628 |
+
if cnt == 0:
|
629 |
+
var_sum = ((mean - mean_mean) ** 2).sum(dim=0)
|
630 |
+
else:
|
631 |
+
var_sum = var_sum + ((mean - mean_mean) ** 2).sum(dim=0)
|
632 |
+
cnt += mean.size(0)
|
633 |
+
|
634 |
+
# (nz)
|
635 |
+
au_var = var_sum / (cnt - 1)
|
636 |
+
|
637 |
+
return (au_var >= delta).sum().item(), au_var
|
638 |
+
|
Optimus/code/examples/big_ae/run_data_filtering.py
ADDED
@@ -0,0 +1,507 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
|
18 |
+
GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
|
19 |
+
using a masked language modeling (MLM) loss.
|
20 |
+
"""
|
21 |
+
|
22 |
+
from __future__ import absolute_import, division, print_function
|
23 |
+
|
24 |
+
|
25 |
+
import pdb
|
26 |
+
import argparse
|
27 |
+
import glob
|
28 |
+
import logging
|
29 |
+
|
30 |
+
import os
|
31 |
+
import pickle
|
32 |
+
import json
|
33 |
+
import random
|
34 |
+
from pathlib import Path
|
35 |
+
|
36 |
+
import numpy as np
|
37 |
+
import torch
|
38 |
+
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
|
39 |
+
from torch.utils.data.distributed import DistributedSampler
|
40 |
+
from tensorboardX import SummaryWriter
|
41 |
+
from tqdm import tqdm, trange
|
42 |
+
from collections import defaultdict
|
43 |
+
|
44 |
+
# from azure.cosmosdb.table.tableservice import TableService
|
45 |
+
# from azure.cosmosdb.table.models import Entity
|
46 |
+
from datetime import datetime
|
47 |
+
|
48 |
+
|
49 |
+
from pytorch_transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
|
50 |
+
BertConfig, BertForLatentConnector, BertTokenizer,
|
51 |
+
GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer,
|
52 |
+
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
|
53 |
+
RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
|
54 |
+
|
55 |
+
from utils import (calc_iwnll, calc_mi, calc_au, BucketingDataLoader, MultipleFiles_DataLoader, BucketingMultipleFiles_DataLoader, frange_cycle_linear, frange_cycle_zero_linear)
|
56 |
+
|
57 |
+
from modules import VAE
|
58 |
+
|
59 |
+
|
60 |
+
# logging.getLogger("azure").setLevel(logging.WARNING)
|
61 |
+
# logging.getLogger("TableService").setLevel(logging.WARNING)
|
62 |
+
|
63 |
+
logger = logging.getLogger(__name__)
|
64 |
+
|
65 |
+
|
66 |
+
MODEL_CLASSES = {
|
67 |
+
'gpt2': (GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer),
|
68 |
+
'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
|
69 |
+
'bert': (BertConfig, BertForLatentConnector, BertTokenizer),
|
70 |
+
'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
|
71 |
+
}
|
72 |
+
|
73 |
+
|
74 |
+
storage_name="textae"
|
75 |
+
key=r"6yBCXlblof8DVFJ4BD3eNFTrGQCej6cKfCf5z308cKnevyHaG+yl/m+ITVErB9yt0kvN3ToqxLIh0knJEfFmPA=="
|
76 |
+
# ts = TableService(account_name=storage_name, account_key=key)
|
77 |
+
|
78 |
+
|
79 |
+
|
80 |
+
def build_dataload_and_cache_examples(args, tokenizer, evaluate=False):
|
81 |
+
if isinstance(tokenizer, list):
|
82 |
+
args.batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
83 |
+
file_path=args.input_file_path
|
84 |
+
dataloader = MultipleFiles_DataLoader(file_path, args.batch_size, args.max_seq_length, tokenizer, args, bucket=100, shuffle=True, use_tensor=False)
|
85 |
+
else:
|
86 |
+
pass
|
87 |
+
return dataloader
|
88 |
+
|
89 |
+
|
90 |
+
def set_seed(args):
|
91 |
+
random.seed(args.seed)
|
92 |
+
np.random.seed(args.seed)
|
93 |
+
torch.manual_seed(args.seed)
|
94 |
+
if args.n_gpu > 0:
|
95 |
+
torch.cuda.manual_seed_all(args.seed)
|
96 |
+
|
97 |
+
|
98 |
+
def mask_tokens(inputs, tokenizer, args):
|
99 |
+
""" Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
|
100 |
+
labels = inputs.clone()
|
101 |
+
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
|
102 |
+
|
103 |
+
masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).to(torch.uint8)
|
104 |
+
labels[masked_indices==1] = -1 # We only compute loss on masked tokens
|
105 |
+
|
106 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
107 |
+
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).to(torch.uint8) & masked_indices
|
108 |
+
inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
|
109 |
+
|
110 |
+
# 10% of the time, we replace masked input tokens with random word
|
111 |
+
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).to(torch.uint8) & masked_indices & ~indices_replaced
|
112 |
+
indices_random = indices_random
|
113 |
+
random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
|
114 |
+
inputs[indices_random] = random_words[indices_random]
|
115 |
+
|
116 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
117 |
+
return inputs, labels
|
118 |
+
|
119 |
+
|
120 |
+
def train(args, train_dataloader, model_vae, encoder_tokenizer, decoder_tokenizer, table_name):
|
121 |
+
""" Train the model """
|
122 |
+
if args.local_rank in [-1, 0]:
|
123 |
+
tb_writer = SummaryWriter()
|
124 |
+
|
125 |
+
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
126 |
+
# train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
127 |
+
# train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
128 |
+
|
129 |
+
if args.max_steps > 0:
|
130 |
+
t_total = args.max_steps
|
131 |
+
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
132 |
+
else:
|
133 |
+
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
134 |
+
|
135 |
+
# Prepare optimizer and schedule (linear warmup and decay)
|
136 |
+
|
137 |
+
|
138 |
+
# model_encoder, model_decoder, model_connector = model_vae.encoder, model_vae.decoder, model_vae.linear
|
139 |
+
no_decay = ['bias', 'LayerNorm.weight']
|
140 |
+
optimizer_grouped_parameters = [
|
141 |
+
{'params': [p for n, p in model_vae.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
142 |
+
{'params': [p for n, p in model_vae.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
143 |
+
]
|
144 |
+
|
145 |
+
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
146 |
+
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
|
147 |
+
|
148 |
+
|
149 |
+
if args.fp16:
|
150 |
+
try:
|
151 |
+
from apex import amp
|
152 |
+
except ImportError:
|
153 |
+
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
154 |
+
model_vae, optimizer = amp.initialize(model_vae, optimizer, opt_level=args.fp16_opt_level)
|
155 |
+
|
156 |
+
# multi-gpu training (should be after apex fp16 initialization)
|
157 |
+
if args.n_gpu > 1:
|
158 |
+
model_vae = torch.nn.DataParallel(model_vae, device_ids=range(args.n_gpu)).to(args.device)
|
159 |
+
|
160 |
+
# Distributed training (should be after apex fp16 initialization)
|
161 |
+
if args.local_rank != -1:
|
162 |
+
model_vae = torch.nn.parallel.DistributedDataParallel(model_vae, device_ids=[args.local_rank],
|
163 |
+
output_device=args.local_rank,
|
164 |
+
find_unused_parameters=True)
|
165 |
+
|
166 |
+
|
167 |
+
|
168 |
+
files = Path(args.input_file_path)
|
169 |
+
num_files = len(list(files.glob('*seq64*.json')))
|
170 |
+
|
171 |
+
# create output file folder
|
172 |
+
if not os.path.exists(args.output_file_path) and args.local_rank in [-1, 0]:
|
173 |
+
os.makedirs(args.output_file_path)
|
174 |
+
|
175 |
+
|
176 |
+
# Train!
|
177 |
+
logger.info("***** Running training *****")
|
178 |
+
logger.info(" Num files = %d", num_files)
|
179 |
+
logger.info(" Num examples of first file = %d", train_dataloader.num_examples)
|
180 |
+
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
181 |
+
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
182 |
+
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
183 |
+
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
184 |
+
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
185 |
+
logger.info(" Total optimization steps = %d", t_total)
|
186 |
+
|
187 |
+
|
188 |
+
num_collected, num_dropped = 0, 0
|
189 |
+
|
190 |
+
model_vae.zero_grad()
|
191 |
+
num_train_epochs_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
192 |
+
|
193 |
+
n_iter = int(args.num_train_epochs) * len(train_dataloader)
|
194 |
+
|
195 |
+
tmp_list = []
|
196 |
+
dict_token_length = defaultdict(int)
|
197 |
+
|
198 |
+
|
199 |
+
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
200 |
+
os.makedirs(args.output_dir)
|
201 |
+
|
202 |
+
dict_file = os.path.join(args.output_dir, args.dataset.lower()+f'.length_freq.json' )
|
203 |
+
|
204 |
+
set_seed(args) # Added here for reproducibility (even between python 2 and 3)
|
205 |
+
for epoch in num_train_epochs_iterator:
|
206 |
+
|
207 |
+
for idx_file in range(num_files):
|
208 |
+
|
209 |
+
examples = []
|
210 |
+
cached_features_file = os.path.join(args.output_file_path, args.dataset.lower()+f'.segmented.nltk.split.seq64.{train_dataloader.file_idx}.json' )
|
211 |
+
logger.info(f"Epoch {epoch}, File idx {train_dataloader.file_idx}")
|
212 |
+
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
213 |
+
|
214 |
+
# if idx_file > 11:
|
215 |
+
# break
|
216 |
+
|
217 |
+
for step, batch in enumerate(epoch_iterator):
|
218 |
+
|
219 |
+
inst, token_lengths = batch
|
220 |
+
dict_token_length[ token_lengths[0,0].item() ] += 1
|
221 |
+
|
222 |
+
if ( token_lengths> 256 ).sum().item()>0:
|
223 |
+
over_length_tensor = ( token_lengths> 256 ).sum(-1)
|
224 |
+
inst_ = [inst[i] for i in range(len(inst)) if over_length_tensor[i]==0 ]
|
225 |
+
examples += inst_
|
226 |
+
num_collected += len(inst_)
|
227 |
+
num_dropped += len(inst) - len(inst_)
|
228 |
+
logger.info(f"{num_dropped} files filtered.")
|
229 |
+
else:
|
230 |
+
examples += inst
|
231 |
+
num_collected += len(inst)
|
232 |
+
|
233 |
+
# Good practice: save your data multiple times on Philly
|
234 |
+
|
235 |
+
if args.use_philly:
|
236 |
+
save_solid = False
|
237 |
+
while not save_solid:
|
238 |
+
try:
|
239 |
+
with open(cached_features_file, 'w') as fp:
|
240 |
+
json.dump(examples, fp)
|
241 |
+
save_solid = True
|
242 |
+
except:
|
243 |
+
pass
|
244 |
+
else:
|
245 |
+
with open(cached_features_file, 'w') as fp:
|
246 |
+
json.dump(examples, fp)
|
247 |
+
logger.info(f"Saving features in the cached file at {cached_features_file}")
|
248 |
+
|
249 |
+
train_dataloader.reset()
|
250 |
+
|
251 |
+
if args.local_rank in [-1, 0]:
|
252 |
+
tb_writer.close()
|
253 |
+
|
254 |
+
logger.info(dict_token_length)
|
255 |
+
# Good practice: save your dict multiple times on Philly
|
256 |
+
if args.use_philly:
|
257 |
+
save_solid = False
|
258 |
+
while not save_solid:
|
259 |
+
try:
|
260 |
+
with open(dict_file, 'w') as fp:
|
261 |
+
json.dump(dict_token_length, fp)
|
262 |
+
save_solid = True
|
263 |
+
except:
|
264 |
+
pass
|
265 |
+
else:
|
266 |
+
with open(dict_file, 'w') as fp:
|
267 |
+
json.dump(dict_token_length, fp)
|
268 |
+
|
269 |
+
return num_collected, num_dropped
|
270 |
+
|
271 |
+
|
272 |
+
def main():
|
273 |
+
parser = argparse.ArgumentParser()
|
274 |
+
|
275 |
+
## Required parameters
|
276 |
+
parser.add_argument("--input_file_path", default=None, type=str, required=True,
|
277 |
+
help="The output directory where the input files will be written.")
|
278 |
+
parser.add_argument("--output_file_path", default=None, type=str, required=True,
|
279 |
+
help="The output directory where the output files will be written.")
|
280 |
+
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
281 |
+
help="The output directory where the logs and results will be saved.")
|
282 |
+
parser.add_argument("--dataset", default=None, type=str, help="The dataset.")
|
283 |
+
|
284 |
+
|
285 |
+
|
286 |
+
## Other parameters
|
287 |
+
parser.add_argument("--ExpName", default="", type=str,
|
288 |
+
help="The experiment name used in Azure Table.")
|
289 |
+
|
290 |
+
## Encoder options
|
291 |
+
parser.add_argument("--encoder_model_type", default="bert", type=str,
|
292 |
+
help="The encoder model architecture to be fine-tuned.")
|
293 |
+
parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
|
294 |
+
help="The encoder model checkpoint for weights initialization.")
|
295 |
+
parser.add_argument("--encoder_config_name", default="", type=str,
|
296 |
+
help="Optional pretrained config name or path if not the same as model_name_or_path")
|
297 |
+
parser.add_argument("--encoder_tokenizer_name", default="", type=str,
|
298 |
+
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
|
299 |
+
|
300 |
+
## Decoder options
|
301 |
+
parser.add_argument("--decoder_model_type", default="gpt2", type=str,
|
302 |
+
help="The decoder model architecture to be fine-tuned.")
|
303 |
+
parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
|
304 |
+
help="The decoder model checkpoint for weights initialization.")
|
305 |
+
parser.add_argument("--decoder_config_name", default="", type=str,
|
306 |
+
help="Optional pretrained config name or path if not the same as model_name_or_path")
|
307 |
+
parser.add_argument("--decoder_tokenizer_name", default="", type=str,
|
308 |
+
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
|
309 |
+
|
310 |
+
## Variational auto-encoder
|
311 |
+
parser.add_argument("--latent_size", default=32, type=int, help="Latent space dimension.")
|
312 |
+
parser.add_argument("--use_deterministic_connect", action='store_true',
|
313 |
+
help="Use deterministic inference to generate latent codes, i.e., standard auto-encoders.")
|
314 |
+
|
315 |
+
## Objective functions
|
316 |
+
parser.add_argument("--mlm", action='store_true',
|
317 |
+
help="Train with masked-language modeling loss instead of language modeling.")
|
318 |
+
parser.add_argument("--mlm_probability", type=float, default=0.15,
|
319 |
+
help="Ratio of tokens to mask for masked language modeling loss")
|
320 |
+
parser.add_argument("--beta", type=float, default=1.0,
|
321 |
+
help="The weighting hyper-parameter of the KL term in VAE")
|
322 |
+
|
323 |
+
|
324 |
+
parser.add_argument("--cache_dir", default="", type=str,
|
325 |
+
help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
|
326 |
+
parser.add_argument("--max_seq_length", default=512, type=int,
|
327 |
+
help="Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length")
|
328 |
+
parser.add_argument("--block_size", default=-1, type=int,
|
329 |
+
help="Optional input sequence length after tokenization."
|
330 |
+
"The training dataset will be truncated in block of this size for training."
|
331 |
+
"Default to the model max input length for single sentence inputs (take into account special tokens).")
|
332 |
+
parser.add_argument("--do_train", action='store_true',
|
333 |
+
help="Whether to run training.")
|
334 |
+
parser.add_argument("--do_eval", action='store_true',
|
335 |
+
help="Whether to run eval on the dev set.")
|
336 |
+
parser.add_argument("--evaluate_during_training", action='store_true',
|
337 |
+
help="Run evaluation during training at each logging step.")
|
338 |
+
parser.add_argument("--do_lower_case", action='store_true',
|
339 |
+
help="Set this flag if you are using an uncased model.")
|
340 |
+
|
341 |
+
|
342 |
+
# Training Schedules
|
343 |
+
parser.add_argument("--ratio_increase", default=0.25, type=float,
|
344 |
+
help="Learning schedule, the percentage for the annealing stage.")
|
345 |
+
parser.add_argument("--ratio_zero", default=0.25, type=float,
|
346 |
+
help="Learning schedule, the percentage for the pure auto-encoding stage.")
|
347 |
+
parser.add_argument("--fb_mode", default=0, type=int,
|
348 |
+
help="free bit training mode.")
|
349 |
+
parser.add_argument("--dim_target_kl", default=3.0, type=float,
|
350 |
+
help="dim_target_kl free bit training mode.")
|
351 |
+
parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
|
352 |
+
help="Batch size per GPU/CPU for training.")
|
353 |
+
parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
|
354 |
+
help="Batch size per GPU/CPU for evaluation.")
|
355 |
+
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
356 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
357 |
+
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
358 |
+
help="The initial learning rate for Adam.")
|
359 |
+
parser.add_argument("--weight_decay", default=0.0, type=float,
|
360 |
+
help="Weight deay if we apply some.")
|
361 |
+
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
362 |
+
help="Epsilon for Adam optimizer.")
|
363 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
364 |
+
help="Max gradient norm.")
|
365 |
+
parser.add_argument("--num_train_epochs", default=1.0, type=float,
|
366 |
+
help="Total number of training epochs to perform.")
|
367 |
+
parser.add_argument("--max_steps", default=-1, type=int,
|
368 |
+
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
369 |
+
parser.add_argument("--warmup_steps", default=0, type=int,
|
370 |
+
help="Linear warmup over warmup_steps.")
|
371 |
+
parser.add_argument("--use_philly", action='store_true',
|
372 |
+
help="Use Philly for computing.")
|
373 |
+
|
374 |
+
## IO: Logging and Saving
|
375 |
+
parser.add_argument('--logging_steps', type=int, default=50,
|
376 |
+
help="Log every X updates steps.")
|
377 |
+
parser.add_argument('--save_steps', type=int, default=50,
|
378 |
+
help="Save checkpoint every X updates steps.")
|
379 |
+
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
380 |
+
help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
|
381 |
+
parser.add_argument("--no_cuda", action='store_true',
|
382 |
+
help="Avoid using CUDA when available")
|
383 |
+
parser.add_argument('--overwrite_output_dir', action='store_true',
|
384 |
+
help="Overwrite the content of the output directory")
|
385 |
+
parser.add_argument('--overwrite_cache', action='store_true',
|
386 |
+
help="Overwrite the cached training and evaluation sets")
|
387 |
+
parser.add_argument('--seed', type=int, default=42,
|
388 |
+
help="random seed for initialization")
|
389 |
+
parser.add_argument('--gloabl_step_eval', type=int, default=661,
|
390 |
+
help="Evaluate the results at the given global step")
|
391 |
+
|
392 |
+
# Precision & Distributed Training
|
393 |
+
parser.add_argument('--fp16', action='store_true',
|
394 |
+
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
395 |
+
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
396 |
+
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
397 |
+
"See details at https://nvidia.github.io/apex/amp.html")
|
398 |
+
parser.add_argument("--local_rank", type=int, default=-1,
|
399 |
+
help="For distributed training: local_rank")
|
400 |
+
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
|
401 |
+
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
402 |
+
args = parser.parse_args()
|
403 |
+
|
404 |
+
if args.decoder_model_type in ["bert", "roberta"] and not args.mlm:
|
405 |
+
raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
|
406 |
+
"flag (masked language modeling).")
|
407 |
+
|
408 |
+
if os.path.exists(args.output_file_path) and os.listdir(args.output_file_path) and args.do_train and not args.overwrite_output_dir:
|
409 |
+
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_file_path))
|
410 |
+
|
411 |
+
# Setup distant debugging if needed
|
412 |
+
if args.server_ip and args.server_port:
|
413 |
+
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
414 |
+
import ptvsd
|
415 |
+
print("Waiting for debugger attach")
|
416 |
+
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
417 |
+
ptvsd.wait_for_attach()
|
418 |
+
|
419 |
+
# Setup CUDA, GPU & distributed training
|
420 |
+
logger.info(f'Local rank is {args.local_rank}')
|
421 |
+
if args.local_rank == -1 or args.no_cuda:
|
422 |
+
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
423 |
+
args.n_gpu = torch.cuda.device_count()
|
424 |
+
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
425 |
+
torch.cuda.set_device(args.local_rank)
|
426 |
+
device = torch.device("cuda", args.local_rank)
|
427 |
+
torch.distributed.init_process_group(backend='nccl')
|
428 |
+
args.n_gpu = 1
|
429 |
+
args.device = device
|
430 |
+
|
431 |
+
# Setup logging
|
432 |
+
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
433 |
+
datefmt = '%m/%d/%Y %H:%M:%S',
|
434 |
+
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
435 |
+
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
436 |
+
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
437 |
+
|
438 |
+
args.ExpName = 'Vae_' + args.dataset + '_Nz_' + str(args.latent_size) + '_Beta_' + str(args.beta) + '_Dkl_' + str(args.dim_target_kl) + '_Ra_' + str(args.ratio_increase) + '_R0_' + str(args.ratio_zero)
|
439 |
+
table_name = 'Vae' + args.dataset + 'Nz' + str(args.latent_size)
|
440 |
+
try:
|
441 |
+
ts.create_table(table_name)
|
442 |
+
except:
|
443 |
+
pass
|
444 |
+
|
445 |
+
|
446 |
+
# Set seed
|
447 |
+
set_seed(args)
|
448 |
+
|
449 |
+
# Load pretrained model and tokenizer
|
450 |
+
if args.local_rank not in [-1, 0]:
|
451 |
+
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
|
452 |
+
|
453 |
+
## Encoder
|
454 |
+
encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
|
455 |
+
encoder_config = encoder_config_class.from_pretrained(args.encoder_config_name if args.encoder_config_name else args.encoder_model_name_or_path)
|
456 |
+
tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)
|
457 |
+
if args.block_size <= 0:
|
458 |
+
args.block_size = tokenizer_encoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
459 |
+
args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
|
460 |
+
model_encoder = encoder_model_class.from_pretrained(args.encoder_model_name_or_path, from_tf=bool('.ckpt' in args.encoder_model_name_or_path), config=encoder_config, latent_size=args.latent_size)
|
461 |
+
# model_encoder.to(args.device)
|
462 |
+
|
463 |
+
## Decoder
|
464 |
+
decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
|
465 |
+
decoder_config = decoder_config_class.from_pretrained(args.decoder_config_name if args.decoder_config_name else args.decoder_model_name_or_path)
|
466 |
+
tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
|
467 |
+
if args.block_size <= 0:
|
468 |
+
args.block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
469 |
+
args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
|
470 |
+
model_decoder = decoder_model_class.from_pretrained(args.decoder_model_name_or_path, from_tf=bool('.ckpt' in args.decoder_model_name_or_path), config=decoder_config, latent_size=args.latent_size)
|
471 |
+
|
472 |
+
# Chunyuan: Add Padding token to GPT2
|
473 |
+
special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
|
474 |
+
num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
|
475 |
+
print('We have added', num_added_toks, 'tokens to GPT2')
|
476 |
+
model_decoder.resize_token_embeddings(len(tokenizer_decoder)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
|
477 |
+
assert tokenizer_decoder.pad_token == '<PAD>'
|
478 |
+
|
479 |
+
# model_decoder.to(args.device)
|
480 |
+
|
481 |
+
model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, args).to(args.device) #
|
482 |
+
|
483 |
+
# on_gpu = next(model_vae.parameters()).is_cuda
|
484 |
+
|
485 |
+
|
486 |
+
|
487 |
+
if args.local_rank == 0:
|
488 |
+
torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training download model & vocab
|
489 |
+
|
490 |
+
logger.info("Training/evaluation parameters %s", args)
|
491 |
+
|
492 |
+
global_step= 0
|
493 |
+
# Training
|
494 |
+
if args.do_train:
|
495 |
+
if args.local_rank not in [-1, 0]:
|
496 |
+
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache
|
497 |
+
|
498 |
+
train_dataloader = build_dataload_and_cache_examples(args, [tokenizer_encoder, tokenizer_decoder], evaluate=False)
|
499 |
+
|
500 |
+
if args.local_rank == 0:
|
501 |
+
torch.distributed.barrier()
|
502 |
+
|
503 |
+
num_collected, num_dropped = train(args, train_dataloader, model_vae, tokenizer_encoder, tokenizer_decoder, table_name)
|
504 |
+
logger.info(" num_collected = %s, num_dropped = %s", num_collected, num_dropped)
|
505 |
+
|
506 |
+
if __name__ == "__main__":
|
507 |
+
main()
|
Optimus/code/examples/big_ae/run_dialog_dataloader.py
ADDED
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
|
18 |
+
GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
|
19 |
+
using a masked language modeling (MLM) loss.
|
20 |
+
"""
|
21 |
+
|
22 |
+
from __future__ import absolute_import, division, print_function
|
23 |
+
|
24 |
+
|
25 |
+
import pdb
|
26 |
+
import argparse
|
27 |
+
import glob
|
28 |
+
import logging
|
29 |
+
|
30 |
+
import os
|
31 |
+
import pickle
|
32 |
+
import random
|
33 |
+
|
34 |
+
import numpy as np
|
35 |
+
import torch
|
36 |
+
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
|
37 |
+
from torch.utils.data.distributed import DistributedSampler
|
38 |
+
from tensorboardX import SummaryWriter
|
39 |
+
from tqdm import tqdm, trange
|
40 |
+
from collections import defaultdict
|
41 |
+
|
42 |
+
# from azure.cosmosdb.table.tableservice import TableService
|
43 |
+
# from azure.cosmosdb.table.models import Entity
|
44 |
+
from datetime import datetime
|
45 |
+
|
46 |
+
|
47 |
+
|
48 |
+
from pytorch_transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
|
49 |
+
BertConfig, BertForLatentConnector, BertTokenizer,
|
50 |
+
GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer,
|
51 |
+
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
|
52 |
+
RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
|
53 |
+
|
54 |
+
from utils import (calc_iwnll, calc_mi, calc_au, Dialog_BucketingDataLoader, TextDataset_Split, TextDataset_2Tokenizers, frange_cycle_linear, frange_cycle_zero_linear)
|
55 |
+
|
56 |
+
|
57 |
+
from modules import VAE
|
58 |
+
|
59 |
+
|
60 |
+
# logging.getLogger("azure").setLevel(logging.WARNING)
|
61 |
+
# logging.getLogger("TableService").setLevel(logging.WARNING)
|
62 |
+
|
63 |
+
logger = logging.getLogger(__name__)
|
64 |
+
|
65 |
+
|
66 |
+
MODEL_CLASSES = {
|
67 |
+
'gpt2': (GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer),
|
68 |
+
'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
|
69 |
+
'bert': (BertConfig, BertForLatentConnector, BertTokenizer),
|
70 |
+
'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
|
71 |
+
}
|
72 |
+
|
73 |
+
|
74 |
+
storage_name="textae"
|
75 |
+
key=r"6yBCXlblof8DVFJ4BD3eNFTrGQCej6cKfCf5z308cKnevyHaG+yl/m+ITVErB9yt0kvN3ToqxLIh0knJEfFmPA=="
|
76 |
+
# ts = TableService(account_name=storage_name, account_key=key)
|
77 |
+
|
78 |
+
|
79 |
+
def build_dataload_and_cache_examples(args, tokenizer, evaluate=False):
|
80 |
+
if isinstance(tokenizer, list):
|
81 |
+
if not evaluate:
|
82 |
+
args.batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
83 |
+
file_path=args.train_data_file
|
84 |
+
else:
|
85 |
+
args.batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
86 |
+
file_path=args.eval_data_file
|
87 |
+
dataloader = Dialog_BucketingDataLoader(file_path, args.batch_size, args.max_seq_length, tokenizer, args, bucket=100, shuffle=True)
|
88 |
+
else:
|
89 |
+
pass
|
90 |
+
return dataloader
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
|
95 |
+
def set_seed(args):
|
96 |
+
random.seed(args.seed)
|
97 |
+
np.random.seed(args.seed)
|
98 |
+
torch.manual_seed(args.seed)
|
99 |
+
if args.n_gpu > 0:
|
100 |
+
torch.cuda.manual_seed_all(args.seed)
|
101 |
+
|
102 |
+
|
103 |
+
|
104 |
+
def train(args, train_dataloader, model_vae, encoder_tokenizer, decoder_tokenizer, table_name):
|
105 |
+
""" Train the model """
|
106 |
+
if args.local_rank in [-1, 0]:
|
107 |
+
tb_writer = SummaryWriter()
|
108 |
+
|
109 |
+
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
110 |
+
# train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
111 |
+
# train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
112 |
+
|
113 |
+
if args.max_steps > 0:
|
114 |
+
t_total = args.max_steps
|
115 |
+
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
116 |
+
else:
|
117 |
+
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
118 |
+
|
119 |
+
# Prepare optimizer and schedule (linear warmup and decay)
|
120 |
+
|
121 |
+
|
122 |
+
# model_encoder, model_decoder, model_connector = model_vae.encoder, model_vae.decoder, model_vae.linear
|
123 |
+
no_decay = ['bias', 'LayerNorm.weight']
|
124 |
+
optimizer_grouped_parameters = [
|
125 |
+
{'params': [p for n, p in model_vae.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
126 |
+
{'params': [p for n, p in model_vae.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
127 |
+
]
|
128 |
+
|
129 |
+
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
130 |
+
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
|
131 |
+
|
132 |
+
|
133 |
+
if args.fp16:
|
134 |
+
try:
|
135 |
+
from apex import amp
|
136 |
+
except ImportError:
|
137 |
+
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
138 |
+
model_vae, optimizer = amp.initialize(model_vae, optimizer, opt_level=args.fp16_opt_level)
|
139 |
+
|
140 |
+
# multi-gpu training (should be after apex fp16 initialization)
|
141 |
+
if args.n_gpu > 1:
|
142 |
+
model_vae = torch.nn.DataParallel(model_vae, device_ids=range(args.n_gpu)).to(args.device)
|
143 |
+
|
144 |
+
# Distributed training (should be after apex fp16 initialization)
|
145 |
+
if args.local_rank != -1:
|
146 |
+
model_vae = torch.nn.parallel.DistributedDataParallel(model_vae, device_ids=[args.local_rank],
|
147 |
+
output_device=args.local_rank,
|
148 |
+
find_unused_parameters=True)
|
149 |
+
|
150 |
+
|
151 |
+
# Train!
|
152 |
+
logger.info("***** Running training *****")
|
153 |
+
logger.info(" Num examples = %d", train_dataloader.num_examples)
|
154 |
+
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
155 |
+
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
156 |
+
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
157 |
+
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
158 |
+
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
159 |
+
logger.info(" Total optimization steps = %d", t_total)
|
160 |
+
|
161 |
+
global_step = 0
|
162 |
+
tr_loss, logging_loss = 0.0, 0.0
|
163 |
+
|
164 |
+
|
165 |
+
model_vae.zero_grad()
|
166 |
+
|
167 |
+
# model_vae = model_vae.module if hasattr(model_vae, 'module') else model_vae # Take care of distributed/parallel training
|
168 |
+
|
169 |
+
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
170 |
+
|
171 |
+
n_iter = int(args.num_train_epochs) * len(train_dataloader)
|
172 |
+
beta_t_list = frange_cycle_zero_linear(n_iter, start=0.0, stop=args.beta, n_cycle=1, ratio_increase=args.ratio_increase, ratio_zero=args.ratio_zero)
|
173 |
+
|
174 |
+
tmp_list = []
|
175 |
+
set_seed(args) # Added here for reproducibility (even between python 2 and 3)
|
176 |
+
for epoch in train_iterator:
|
177 |
+
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
178 |
+
for step, batch in enumerate(epoch_iterator):
|
179 |
+
|
180 |
+
input_ids_bert_ctx, input_ids_bert, input_ids_gpt, token_lengths = batch
|
181 |
+
|
182 |
+
logger.info(f'Conxtext in Bert, Length {token_lengths[0]} ; Tokens: {input_ids_bert_ctx}')
|
183 |
+
logger.info(f'Response in Bert, Length {token_lengths[1]} ; Tokens: {input_ids_bert}')
|
184 |
+
logger.info(f'Response in GPT2, Length {token_lengths[2]} ; Tokens: {input_ids_gpt}')
|
185 |
+
# TODO: write donw training scripts for dialog response generation
|
186 |
+
|
187 |
+
|
188 |
+
if (step + 1) % args.gradient_accumulation_steps == 0:
|
189 |
+
|
190 |
+
global_step += 1
|
191 |
+
|
192 |
+
|
193 |
+
if args.max_steps > 0 and global_step > args.max_steps:
|
194 |
+
epoch_iterator.close()
|
195 |
+
break
|
196 |
+
|
197 |
+
|
198 |
+
if args.max_steps > 0 and global_step > args.max_steps:
|
199 |
+
train_iterator.close()
|
200 |
+
break
|
201 |
+
|
202 |
+
if args.local_rank in [-1, 0]:
|
203 |
+
tb_writer.close()
|
204 |
+
|
205 |
+
return global_step
|
206 |
+
|
207 |
+
|
208 |
+
|
209 |
+
|
210 |
+
|
211 |
+
|
212 |
+
def main():
|
213 |
+
parser = argparse.ArgumentParser()
|
214 |
+
|
215 |
+
## Required parameters
|
216 |
+
parser.add_argument("--train_data_file", default=None, type=str, required=True,
|
217 |
+
help="The input training data file (a text file).")
|
218 |
+
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
219 |
+
help="The output directory where the model predictions and checkpoints will be written.")
|
220 |
+
parser.add_argument("--dataset", default=None, type=str, help="The dataset.")
|
221 |
+
|
222 |
+
## Other parameters
|
223 |
+
parser.add_argument("--eval_data_file", default=None, type=str,
|
224 |
+
help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
|
225 |
+
parser.add_argument("--ExpName", default="", type=str,
|
226 |
+
help="The experiment name used in Azure Table.")
|
227 |
+
|
228 |
+
## Encoder options
|
229 |
+
parser.add_argument("--encoder_model_type", default="bert", type=str,
|
230 |
+
help="The encoder model architecture to be fine-tuned.")
|
231 |
+
parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
|
232 |
+
help="The encoder model checkpoint for weights initialization.")
|
233 |
+
parser.add_argument("--encoder_config_name", default="", type=str,
|
234 |
+
help="Optional pretrained config name or path if not the same as model_name_or_path")
|
235 |
+
parser.add_argument("--encoder_tokenizer_name", default="", type=str,
|
236 |
+
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
|
237 |
+
|
238 |
+
## Decoder options
|
239 |
+
parser.add_argument("--decoder_model_type", default="gpt2", type=str,
|
240 |
+
help="The decoder model architecture to be fine-tuned.")
|
241 |
+
parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
|
242 |
+
help="The decoder model checkpoint for weights initialization.")
|
243 |
+
parser.add_argument("--decoder_config_name", default="", type=str,
|
244 |
+
help="Optional pretrained config name or path if not the same as model_name_or_path")
|
245 |
+
parser.add_argument("--decoder_tokenizer_name", default="", type=str,
|
246 |
+
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
|
247 |
+
|
248 |
+
## Variational auto-encoder
|
249 |
+
parser.add_argument("--latent_size", default=32, type=int, help="Latent space dimension.")
|
250 |
+
parser.add_argument("--use_deterministic_connect", action='store_true',
|
251 |
+
help="Use deterministic inference to generate latent codes, i.e., standard auto-encoders.")
|
252 |
+
parser.add_argument("--use_pretrained_model", action='store_true',
|
253 |
+
help="Use pre-trained auto-encoder models as the initialization")
|
254 |
+
|
255 |
+
## Objective functions
|
256 |
+
parser.add_argument("--mlm", action='store_true',
|
257 |
+
help="Train with masked-language modeling loss instead of language modeling.")
|
258 |
+
parser.add_argument("--mlm_probability", type=float, default=0.15,
|
259 |
+
help="Ratio of tokens to mask for masked language modeling loss")
|
260 |
+
parser.add_argument("--beta", type=float, default=1.0,
|
261 |
+
help="The weighting hyper-parameter of the KL term in VAE")
|
262 |
+
|
263 |
+
|
264 |
+
parser.add_argument("--cache_dir", default="", type=str,
|
265 |
+
help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
|
266 |
+
parser.add_argument("--max_seq_length", default=512, type=int,
|
267 |
+
help="Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length")
|
268 |
+
parser.add_argument("--block_size", default=-1, type=int,
|
269 |
+
help="Optional input sequence length after tokenization."
|
270 |
+
"The training dataset will be truncated in block of this size for training."
|
271 |
+
"Default to the model max input length for single sentence inputs (take into account special tokens).")
|
272 |
+
parser.add_argument("--do_train", action='store_true',
|
273 |
+
help="Whether to run training.")
|
274 |
+
parser.add_argument("--do_eval", action='store_true',
|
275 |
+
help="Whether to run eval on the dev set.")
|
276 |
+
parser.add_argument("--evaluate_during_training", action='store_true',
|
277 |
+
help="Run evaluation during training at each logging step.")
|
278 |
+
parser.add_argument("--do_lower_case", action='store_true',
|
279 |
+
help="Set this flag if you are using an uncased model.")
|
280 |
+
|
281 |
+
|
282 |
+
# Training Schedules
|
283 |
+
parser.add_argument("--ratio_increase", default=0.25, type=float,
|
284 |
+
help="Learning schedule, the percentage for the annealing stage.")
|
285 |
+
parser.add_argument("--ratio_zero", default=0.25, type=float,
|
286 |
+
help="Learning schedule, the percentage for the pure auto-encoding stage.")
|
287 |
+
parser.add_argument("--fb_mode", default=0, type=int,
|
288 |
+
help="free bit training mode.")
|
289 |
+
parser.add_argument("--dim_target_kl", default=3.0, type=float,
|
290 |
+
help="dim_target_kl free bit training mode.")
|
291 |
+
parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
|
292 |
+
help="Batch size per GPU/CPU for training.")
|
293 |
+
parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
|
294 |
+
help="Batch size per GPU/CPU for evaluation.")
|
295 |
+
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
296 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
297 |
+
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
298 |
+
help="The initial learning rate for Adam.")
|
299 |
+
parser.add_argument("--weight_decay", default=0.0, type=float,
|
300 |
+
help="Weight deay if we apply some.")
|
301 |
+
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
302 |
+
help="Epsilon for Adam optimizer.")
|
303 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
304 |
+
help="Max gradient norm.")
|
305 |
+
parser.add_argument("--num_train_epochs", default=1.0, type=float,
|
306 |
+
help="Total number of training epochs to perform.")
|
307 |
+
parser.add_argument("--max_steps", default=-1, type=int,
|
308 |
+
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
309 |
+
parser.add_argument("--warmup_steps", default=0, type=int,
|
310 |
+
help="Linear warmup over warmup_steps.")
|
311 |
+
parser.add_argument("--use_philly", action='store_true',
|
312 |
+
help="Use Philly for computing.")
|
313 |
+
|
314 |
+
## IO: Logging and Saving
|
315 |
+
parser.add_argument('--logging_steps', type=int, default=50,
|
316 |
+
help="Log every X updates steps.")
|
317 |
+
parser.add_argument('--save_steps', type=int, default=50,
|
318 |
+
help="Save checkpoint every X updates steps.")
|
319 |
+
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
320 |
+
help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
|
321 |
+
parser.add_argument("--no_cuda", action='store_true',
|
322 |
+
help="Avoid using CUDA when available")
|
323 |
+
parser.add_argument('--overwrite_output_dir', action='store_true',
|
324 |
+
help="Overwrite the content of the output directory")
|
325 |
+
parser.add_argument('--overwrite_cache', action='store_true',
|
326 |
+
help="Overwrite the cached training and evaluation sets")
|
327 |
+
parser.add_argument('--seed', type=int, default=42,
|
328 |
+
help="random seed for initialization")
|
329 |
+
parser.add_argument('--gloabl_step_eval', type=int, default=661,
|
330 |
+
help="Evaluate the results at the given global step")
|
331 |
+
|
332 |
+
# Precision & Distributed Training
|
333 |
+
parser.add_argument('--fp16', action='store_true',
|
334 |
+
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
335 |
+
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
336 |
+
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
337 |
+
"See details at https://nvidia.github.io/apex/amp.html")
|
338 |
+
parser.add_argument("--local_rank", type=int, default=-1,
|
339 |
+
help="For distributed training: local_rank")
|
340 |
+
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
|
341 |
+
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
342 |
+
args = parser.parse_args()
|
343 |
+
|
344 |
+
if args.decoder_model_type in ["bert", "roberta"] and not args.mlm:
|
345 |
+
raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
|
346 |
+
"flag (masked language modeling).")
|
347 |
+
if args.eval_data_file is None and args.do_eval:
|
348 |
+
raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
|
349 |
+
"or remove the --do_eval argument.")
|
350 |
+
|
351 |
+
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
352 |
+
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
353 |
+
|
354 |
+
# Setup distant debugging if needed
|
355 |
+
if args.server_ip and args.server_port:
|
356 |
+
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
357 |
+
import ptvsd
|
358 |
+
print("Waiting for debugger attach")
|
359 |
+
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
360 |
+
ptvsd.wait_for_attach()
|
361 |
+
|
362 |
+
# Setup CUDA, GPU & distributed training
|
363 |
+
if args.local_rank == -1 or args.no_cuda:
|
364 |
+
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
365 |
+
args.n_gpu = torch.cuda.device_count()
|
366 |
+
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
367 |
+
torch.cuda.set_device(args.local_rank)
|
368 |
+
device = torch.device("cuda", args.local_rank)
|
369 |
+
torch.distributed.init_process_group(backend='nccl')
|
370 |
+
args.n_gpu = 1
|
371 |
+
args.device = device
|
372 |
+
|
373 |
+
# Setup logging
|
374 |
+
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
375 |
+
datefmt = '%m/%d/%Y %H:%M:%S',
|
376 |
+
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
377 |
+
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
378 |
+
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
379 |
+
|
380 |
+
args.ExpName = 'Vae_' + args.dataset + '_Nz_' + str(args.latent_size) + '_Beta_' + str(args.beta) + '_Dkl_' + str(args.dim_target_kl) + '_Ra_' + str(args.ratio_increase) + '_R0_' + str(args.ratio_zero)
|
381 |
+
table_name = 'Vae' + args.dataset + 'Nz' + str(args.latent_size)
|
382 |
+
try:
|
383 |
+
ts.create_table(table_name)
|
384 |
+
except:
|
385 |
+
pass
|
386 |
+
|
387 |
+
|
388 |
+
# Set seed
|
389 |
+
set_seed(args)
|
390 |
+
|
391 |
+
# Load pretrained model and tokenizer
|
392 |
+
if args.local_rank not in [-1, 0]:
|
393 |
+
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
|
394 |
+
|
395 |
+
if args.use_pretrained_model:
|
396 |
+
|
397 |
+
args.encoder_model_type = args.encoder_model_type.lower()
|
398 |
+
args.decoder_model_type = args.decoder_model_type.lower()
|
399 |
+
|
400 |
+
global_step = args.gloabl_step_eval
|
401 |
+
|
402 |
+
output_encoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-encoder-{}'.format(global_step))
|
403 |
+
output_decoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-decoder-{}'.format(global_step))
|
404 |
+
checkpoints = [ [output_encoder_dir, output_decoder_dir] ]
|
405 |
+
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
406 |
+
|
407 |
+
# Load a trained Encoder model and vocabulary
|
408 |
+
encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
|
409 |
+
model_encoder = encoder_model_class.from_pretrained(output_encoder_dir, latent_size=args.latent_size)
|
410 |
+
tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)
|
411 |
+
|
412 |
+
model_encoder.to(args.device)
|
413 |
+
if args.block_size <= 0:
|
414 |
+
args.block_size = tokenizer_encoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
415 |
+
args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
|
416 |
+
|
417 |
+
# Load a trained Decoder model and vocabulary
|
418 |
+
decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
|
419 |
+
model_decoder = decoder_model_class.from_pretrained(output_decoder_dir, latent_size=args.latent_size)
|
420 |
+
tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
|
421 |
+
model_decoder.to(args.device)
|
422 |
+
if args.block_size <= 0:
|
423 |
+
args.block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
424 |
+
args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
|
425 |
+
|
426 |
+
else:
|
427 |
+
## Encoder
|
428 |
+
encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
|
429 |
+
encoder_config = encoder_config_class.from_pretrained(args.encoder_config_name if args.encoder_config_name else args.encoder_model_name_or_path)
|
430 |
+
tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)
|
431 |
+
if args.block_size <= 0:
|
432 |
+
args.block_size = tokenizer_encoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
433 |
+
args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
|
434 |
+
model_encoder = encoder_model_class.from_pretrained(args.encoder_model_name_or_path, from_tf=bool('.ckpt' in args.encoder_model_name_or_path), config=encoder_config, latent_size=args.latent_size)
|
435 |
+
# model_encoder.to(args.device)
|
436 |
+
|
437 |
+
## Decoder
|
438 |
+
decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
|
439 |
+
decoder_config = decoder_config_class.from_pretrained(args.decoder_config_name if args.decoder_config_name else args.decoder_model_name_or_path)
|
440 |
+
tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
|
441 |
+
if args.block_size <= 0:
|
442 |
+
args.block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
443 |
+
args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
|
444 |
+
model_decoder = decoder_model_class.from_pretrained(args.decoder_model_name_or_path, from_tf=bool('.ckpt' in args.decoder_model_name_or_path), config=decoder_config, latent_size=args.latent_size)
|
445 |
+
|
446 |
+
pdb.set_trace()
|
447 |
+
|
448 |
+
# Chunyuan: Add Padding token to GPT2
|
449 |
+
special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
|
450 |
+
num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
|
451 |
+
print('We have added', num_added_toks, 'tokens to GPT2')
|
452 |
+
model_decoder.resize_token_embeddings(len(tokenizer_decoder)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
|
453 |
+
assert tokenizer_decoder.pad_token == '<PAD>'
|
454 |
+
|
455 |
+
# model_decoder.to(args.device)
|
456 |
+
|
457 |
+
model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, args).to(args.device) #
|
458 |
+
|
459 |
+
# on_gpu = next(model_vae.parameters()).is_cuda
|
460 |
+
|
461 |
+
|
462 |
+
|
463 |
+
if args.local_rank == 0:
|
464 |
+
torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training download model & vocab
|
465 |
+
|
466 |
+
logger.info("Training/evaluation parameters %s", args)
|
467 |
+
|
468 |
+
global_step= 0
|
469 |
+
# Training
|
470 |
+
if args.do_train:
|
471 |
+
if args.local_rank not in [-1, 0]:
|
472 |
+
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache
|
473 |
+
|
474 |
+
train_dataloader = build_dataload_and_cache_examples(args, [tokenizer_encoder, tokenizer_decoder], evaluate=False)
|
475 |
+
|
476 |
+
if args.local_rank == 0:
|
477 |
+
torch.distributed.barrier()
|
478 |
+
|
479 |
+
global_step = train(args, train_dataloader, model_vae, tokenizer_encoder, tokenizer_decoder, table_name)
|
480 |
+
logger.info(" global_step = %s", global_step)
|
481 |
+
|
482 |
+
if __name__ == "__main__":
|
483 |
+
main()
|
Optimus/code/examples/big_ae/run_encoding_generation.py
ADDED
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
|
4 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/Transformer-XL/XLNet)
|
18 |
+
"""
|
19 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
20 |
+
|
21 |
+
import argparse
|
22 |
+
import glob
|
23 |
+
import logging
|
24 |
+
import os
|
25 |
+
import pickle
|
26 |
+
import random
|
27 |
+
|
28 |
+
|
29 |
+
import torch
|
30 |
+
import torch.nn.functional as F
|
31 |
+
import numpy as np
|
32 |
+
|
33 |
+
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
|
34 |
+
from torch.utils.data.distributed import DistributedSampler
|
35 |
+
from tqdm import tqdm, trange
|
36 |
+
|
37 |
+
|
38 |
+
from pytorch_transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, BertConfig
|
39 |
+
from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2ForLatentConnector
|
40 |
+
from pytorch_transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
|
41 |
+
from pytorch_transformers import XLNetLMHeadModel, XLNetTokenizer
|
42 |
+
from pytorch_transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
|
43 |
+
from pytorch_transformers import BertForLatentConnector, BertTokenizer
|
44 |
+
|
45 |
+
from collections import defaultdict
|
46 |
+
from modules import VAE
|
47 |
+
from utils import (TextDataset_Split, TextDataset_2Tokenizers, BucketingDataLoader)
|
48 |
+
|
49 |
+
|
50 |
+
import pdb
|
51 |
+
|
52 |
+
|
53 |
+
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
54 |
+
datefmt = '%m/%d/%Y %H:%M:%S',
|
55 |
+
level = logging.INFO)
|
56 |
+
logger = logging.getLogger(__name__)
|
57 |
+
|
58 |
+
MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
|
59 |
+
|
60 |
+
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig)), ())
|
61 |
+
|
62 |
+
MODEL_CLASSES = {
|
63 |
+
'gpt2': (GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer),
|
64 |
+
'bert': (BertConfig, BertForLatentConnector, BertTokenizer)
|
65 |
+
}
|
66 |
+
|
67 |
+
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
|
68 |
+
# in https://github.com/rusiaaman/XLNet-gen#methodology
|
69 |
+
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
|
70 |
+
PADDING_TEXT = """ In 1991, the remains of Russian Tsar Nicholas II and his family
|
71 |
+
(except for Alexei and Maria) are discovered.
|
72 |
+
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
|
73 |
+
remainder of the story. 1883 Western Siberia,
|
74 |
+
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
|
75 |
+
Rasputin has a vision and denounces one of the men as a horse thief. Although his
|
76 |
+
father initially slaps him for making such an accusation, Rasputin watches as the
|
77 |
+
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
|
78 |
+
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
|
79 |
+
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
|
80 |
+
|
81 |
+
|
82 |
+
def set_seed(args):
|
83 |
+
np.random.seed(args.seed)
|
84 |
+
torch.manual_seed(args.seed)
|
85 |
+
if args.n_gpu > 0:
|
86 |
+
torch.cuda.manual_seed_all(args.seed)
|
87 |
+
|
88 |
+
|
89 |
+
def load_and_cache_examples(args, tokenizer, evaluate=False):
|
90 |
+
if isinstance(tokenizer, list):
|
91 |
+
dataset = TextDataset_2Tokenizers(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
|
92 |
+
else:
|
93 |
+
dataset = TextDataset_Split(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
|
94 |
+
return dataset
|
95 |
+
|
96 |
+
def build_dataload_and_cache_examples(args, tokenizer, evaluate=False):
|
97 |
+
if isinstance(tokenizer, list):
|
98 |
+
if not evaluate:
|
99 |
+
args.batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
100 |
+
file_path=args.train_data_file
|
101 |
+
else:
|
102 |
+
args.batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
103 |
+
file_path=args.eval_data_file
|
104 |
+
dataloader = BucketingDataLoader(file_path, args.batch_size, args.max_seq_length, tokenizer, args, bucket=100, shuffle=False)
|
105 |
+
else:
|
106 |
+
pass
|
107 |
+
return dataloader
|
108 |
+
|
109 |
+
|
110 |
+
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
|
111 |
+
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
112 |
+
Args:
|
113 |
+
logits: logits distribution shape (vocabulary size)
|
114 |
+
top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
115 |
+
top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
116 |
+
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
117 |
+
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
118 |
+
"""
|
119 |
+
assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
|
120 |
+
top_k = min(top_k, logits.size(-1)) # Safety check
|
121 |
+
if top_k > 0:
|
122 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
123 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
124 |
+
logits[indices_to_remove] = filter_value
|
125 |
+
|
126 |
+
if top_p > 0.0:
|
127 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
128 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
129 |
+
|
130 |
+
# Remove tokens with cumulative probability above the threshold
|
131 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
132 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
133 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
134 |
+
sorted_indices_to_remove[..., 0] = 0
|
135 |
+
|
136 |
+
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
137 |
+
logits[indices_to_remove] = filter_value
|
138 |
+
return logits
|
139 |
+
|
140 |
+
|
141 |
+
def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, is_xlnet=False, device='cpu'):
|
142 |
+
context = torch.tensor(context, dtype=torch.long, device=device)
|
143 |
+
context = context.unsqueeze(0).repeat(num_samples, 1)
|
144 |
+
generated = context
|
145 |
+
with torch.no_grad():
|
146 |
+
for _ in trange(length):
|
147 |
+
|
148 |
+
inputs = {'input_ids': generated}
|
149 |
+
if is_xlnet:
|
150 |
+
# XLNet is a direct (predict same token, not next token) and bi-directional model by default
|
151 |
+
# => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
|
152 |
+
input_ids = torch.cat((generated, torch.zeros((1, 1), dtype=torch.long, device=device)), dim=1)
|
153 |
+
perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=device)
|
154 |
+
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
|
155 |
+
target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=device)
|
156 |
+
target_mapping[0, 0, -1] = 1.0 # predict last token
|
157 |
+
inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
|
158 |
+
|
159 |
+
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
|
160 |
+
next_token_logits = outputs[0][0, -1, :] / temperature
|
161 |
+
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
162 |
+
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
|
163 |
+
generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
|
164 |
+
return generated
|
165 |
+
|
166 |
+
def sample_sequence_conditional(model, length, context, past=None, num_samples=1, temperature=1, top_k=0, top_p=0.0, device='cpu', decoder_tokenizer=None):
|
167 |
+
|
168 |
+
context = torch.tensor(context, dtype=torch.long, device=device)
|
169 |
+
context = context.unsqueeze(0).repeat(num_samples, 1)
|
170 |
+
generated = context
|
171 |
+
with torch.no_grad():
|
172 |
+
while True:
|
173 |
+
# for _ in trange(length):
|
174 |
+
inputs = {'input_ids': generated, 'past': past}
|
175 |
+
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
|
176 |
+
next_token_logits = outputs[0][0, -1, :] / temperature
|
177 |
+
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
178 |
+
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
|
179 |
+
generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
|
180 |
+
|
181 |
+
# pdb.set_trace()
|
182 |
+
if next_token.unsqueeze(0)[0,0].item() == decoder_tokenizer.encode('<EOS>')[0]:
|
183 |
+
break
|
184 |
+
|
185 |
+
return generated
|
186 |
+
|
187 |
+
|
188 |
+
|
189 |
+
# a wrapper function to choose between different play modes
|
190 |
+
def evaluate_latent_space(args, model_vae, encoder_tokenizer, decoder_tokenizer, prefix=""):
|
191 |
+
|
192 |
+
eval_dataloader = build_dataload_and_cache_examples(args, [encoder_tokenizer, decoder_tokenizer], evaluate=False)
|
193 |
+
|
194 |
+
# Eval!
|
195 |
+
logger.info("***** Running recontruction evaluation {} *****".format(prefix))
|
196 |
+
logger.info(" Num examples = %d", len(eval_dataloader))
|
197 |
+
logger.info(" Batch size = %d", args.per_gpu_eval_batch_size)
|
198 |
+
|
199 |
+
model_vae.eval()
|
200 |
+
|
201 |
+
model_vae = model_vae.module if hasattr(model_vae, 'module') else model_vae # Take care of distributed/parallel training
|
202 |
+
|
203 |
+
if args.play_mode == 'reconstrction':
|
204 |
+
result = calc_rec(model_vae, eval_dataloader, encoder_tokenizer, decoder_tokenizer, args, ns=100)
|
205 |
+
result_file_name = "eval_recontruction_results.txt"
|
206 |
+
elif args.play_mode == 'interpolation':
|
207 |
+
result = calc_interpolate(model_vae, eval_dataloader, encoder_tokenizer, decoder_tokenizer, args, ns=100)
|
208 |
+
result_file_name = "eval_interpolation_results.txt"
|
209 |
+
else:
|
210 |
+
logger.info("Please specify the corrent play mode [reconstrction, interpolation]")
|
211 |
+
|
212 |
+
|
213 |
+
eval_output_dir = args.output_dir
|
214 |
+
output_eval_file = os.path.join(eval_output_dir, result_file_name)
|
215 |
+
|
216 |
+
with open(output_eval_file, "w") as writer:
|
217 |
+
logger.info("***** Eval {} results *****".format(args.play_mode))
|
218 |
+
for key in sorted(result.keys()):
|
219 |
+
logger.info(" %s \n %s", key, str(result[key]))
|
220 |
+
writer.write("%s \n %s\n" % (key, str(result[key])))
|
221 |
+
|
222 |
+
return result
|
223 |
+
|
224 |
+
|
225 |
+
def calc_rec(model_vae, eval_dataloader, encoder_tokenizer, decoder_tokenizer, args, ns=1):
|
226 |
+
|
227 |
+
count = 0
|
228 |
+
result = defaultdict(str)
|
229 |
+
for batch in tqdm(eval_dataloader, desc="Evaluating recontruction"):
|
230 |
+
# pdb.set_trace()
|
231 |
+
x0, x1, x_lengths = batch
|
232 |
+
|
233 |
+
max_len_values, _ = x_lengths.max(0)
|
234 |
+
x0 = x0[:,:max_len_values[0]]
|
235 |
+
x1 = x1[:,:max_len_values[1]]
|
236 |
+
|
237 |
+
x0 = x0.to(args.device)
|
238 |
+
x1 = x1.to(args.device)
|
239 |
+
x_lengths = x_lengths.to(args.device)
|
240 |
+
|
241 |
+
context_tokens = decoder_tokenizer.encode('<BOS>')
|
242 |
+
|
243 |
+
with torch.no_grad():
|
244 |
+
|
245 |
+
text_x0 = encoder_tokenizer.decode(x0[0,:x_lengths[0,0]].tolist(), clean_up_tokenization_spaces=True)[0]
|
246 |
+
# result["INPUT TEXT " + str(count)].append(text_x0)
|
247 |
+
|
248 |
+
pooled_hidden_fea = model_vae.encoder(x0, attention_mask=(x0 > 0).float())[1]
|
249 |
+
|
250 |
+
# Connect hidden feature to the latent space
|
251 |
+
# latent_z, loss_kl = model_vae.connect(pooled_hidden_fea)
|
252 |
+
mean, logvar = model_vae.encoder.linear(pooled_hidden_fea).chunk(2, -1)
|
253 |
+
latent_z = mean.squeeze(1)
|
254 |
+
|
255 |
+
past = latent_z
|
256 |
+
out = sample_sequence_conditional(
|
257 |
+
model=model_vae.decoder,
|
258 |
+
context=context_tokens,
|
259 |
+
past=past,
|
260 |
+
length=x_lengths[0,1], # Chunyuan: Fix length; or use <EOS> to complete a sentence
|
261 |
+
temperature=args.temperature,
|
262 |
+
top_k=args.top_k,
|
263 |
+
top_p=args.top_p,
|
264 |
+
device=args.device,
|
265 |
+
decoder_tokenizer = decoder_tokenizer
|
266 |
+
)
|
267 |
+
text_x1 = decoder_tokenizer.decode(out[0,:].tolist(), clean_up_tokenization_spaces=True)
|
268 |
+
text_x1 = text_x1.split()[1:-1]
|
269 |
+
text_x1 = ' '.join(text_x1) + '\n'
|
270 |
+
result[text_x0] = text_x1
|
271 |
+
|
272 |
+
count += 1
|
273 |
+
if count>args.total_sents:
|
274 |
+
break
|
275 |
+
|
276 |
+
|
277 |
+
return result
|
278 |
+
|
279 |
+
|
280 |
+
|
281 |
+
|
282 |
+
def calc_interpolate(model_vae, eval_dataloader, encoder_tokenizer, decoder_tokenizer, args, ns=1):
|
283 |
+
|
284 |
+
count = 0
|
285 |
+
latent_codes = []
|
286 |
+
sample_interval = 0
|
287 |
+
for batch in tqdm(eval_dataloader, desc="Evaluating interpolation"):
|
288 |
+
# pdb.set_trace()
|
289 |
+
x0, x1, x_lengths = batch
|
290 |
+
|
291 |
+
max_len_values, _ = x_lengths.max(0)
|
292 |
+
x0 = x0[:,:max_len_values[0]]
|
293 |
+
x0 = x0.to(args.device)
|
294 |
+
x_lengths = x_lengths.to(args.device)
|
295 |
+
|
296 |
+
|
297 |
+
with torch.no_grad():
|
298 |
+
if sample_interval == 0 or sample_interval == args.total_sents:
|
299 |
+
text_x0 = encoder_tokenizer.decode(x0[0,:x_lengths[0,0]].tolist(), clean_up_tokenization_spaces=True)[0]
|
300 |
+
pooled_hidden_fea = model_vae.encoder(x0, attention_mask=(x0 > 0).float())[1]
|
301 |
+
|
302 |
+
# Connect hidden feature to the latent space
|
303 |
+
mean, logvar = model_vae.encoder.linear(pooled_hidden_fea).chunk(2, -1)
|
304 |
+
latent_z = mean.squeeze(1)
|
305 |
+
|
306 |
+
latent_codes.append(latent_z)
|
307 |
+
|
308 |
+
if sample_interval == 5:
|
309 |
+
latent_codes.append(latent_z)
|
310 |
+
sample_interval = 0
|
311 |
+
continue
|
312 |
+
else:
|
313 |
+
sample_interval += 1
|
314 |
+
continue
|
315 |
+
|
316 |
+
count += 1
|
317 |
+
if count>args.total_sents:
|
318 |
+
break
|
319 |
+
|
320 |
+
context_tokens = decoder_tokenizer.encode('<BOS>')
|
321 |
+
result = defaultdict(str)
|
322 |
+
latent_codes_interpolation = []
|
323 |
+
num_steps = args.num_interpolation_steps
|
324 |
+
for step in range(num_steps+1):
|
325 |
+
latent_z = latent_codes[0] + (latent_codes[1] - latent_codes[0]) * step * 1.0/num_steps
|
326 |
+
|
327 |
+
past = latent_z
|
328 |
+
out = sample_sequence_conditional(
|
329 |
+
model=model_vae.decoder,
|
330 |
+
context=context_tokens,
|
331 |
+
past=past,
|
332 |
+
length=x_lengths[0,1], # Chunyuan: Fix length; or use <EOS> to complete a sentence
|
333 |
+
temperature=args.temperature,
|
334 |
+
top_k=args.top_k,
|
335 |
+
top_p=args.top_p,
|
336 |
+
device=args.device,
|
337 |
+
decoder_tokenizer = decoder_tokenizer
|
338 |
+
)
|
339 |
+
text_x1 = decoder_tokenizer.decode(out[0,:].tolist(), clean_up_tokenization_spaces=True)
|
340 |
+
text_x1 = text_x1.split()[1:-1]
|
341 |
+
text_x1 = ' '.join(text_x1)
|
342 |
+
result[step] = text_x1
|
343 |
+
|
344 |
+
return result
|
345 |
+
|
346 |
+
|
347 |
+
|
348 |
+
|
349 |
+
def main():
|
350 |
+
parser = argparse.ArgumentParser()
|
351 |
+
|
352 |
+
parser.add_argument("--train_data_file", default=None, type=str, required=True,
|
353 |
+
help="The input training data file (a text file).")
|
354 |
+
parser.add_argument("--eval_data_file", default=None, type=str,
|
355 |
+
help="An input evaluation data file to evaluate the perplexity on (a text file).")
|
356 |
+
parser.add_argument("--checkpoint_dir", default=None, type=str, required=True,
|
357 |
+
help="The directory where checkpoints are saved.")
|
358 |
+
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
359 |
+
help="The output directory where the model predictions and checkpoints will be written.")
|
360 |
+
parser.add_argument("--dataset", default='Snli', type=str, help="The dataset.")
|
361 |
+
|
362 |
+
## Variational auto-encoder
|
363 |
+
parser.add_argument("--latent_size", default=32, type=int, help="Latent space dimension.")
|
364 |
+
parser.add_argument("--total_sents", default=10, type=int, help="Total sentences to test recontruction.")
|
365 |
+
parser.add_argument("--num_interpolation_steps", default=10, type=int, help="Total sentences to test recontruction.")
|
366 |
+
parser.add_argument("--play_mode", default="interpolation", type=str,
|
367 |
+
help="interpolation or reconstruction.")
|
368 |
+
|
369 |
+
|
370 |
+
## Encoder options
|
371 |
+
parser.add_argument("--encoder_model_type", default="bert", type=str,
|
372 |
+
help="The encoder model architecture to be fine-tuned.")
|
373 |
+
parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
|
374 |
+
help="The encoder model checkpoint for weights initialization.")
|
375 |
+
parser.add_argument("--encoder_config_name", default="", type=str,
|
376 |
+
help="Optional pretrained config name or path if not the same as model_name_or_path")
|
377 |
+
parser.add_argument("--encoder_tokenizer_name", default="", type=str,
|
378 |
+
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
|
379 |
+
|
380 |
+
## Decoder options
|
381 |
+
parser.add_argument("--decoder_model_type", default="gpt2", type=str,
|
382 |
+
help="The decoder model architecture to be fine-tuned.")
|
383 |
+
parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
|
384 |
+
help="The decoder model checkpoint for weights initialization.")
|
385 |
+
parser.add_argument("--decoder_config_name", default="", type=str,
|
386 |
+
help="Optional pretrained config name or path if not the same as model_name_or_path")
|
387 |
+
parser.add_argument("--decoder_tokenizer_name", default="", type=str,
|
388 |
+
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
|
389 |
+
|
390 |
+
|
391 |
+
parser.add_argument("--per_gpu_train_batch_size", default=1, type=int,
|
392 |
+
help="Batch size per GPU/CPU for training.")
|
393 |
+
parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
|
394 |
+
help="Batch size per GPU/CPU for evaluation.")
|
395 |
+
parser.add_argument('--gloabl_step_eval', type=int, default=661,
|
396 |
+
help="Evaluate the results at the given global step")
|
397 |
+
|
398 |
+
parser.add_argument("--max_seq_length", default=512, type=int,
|
399 |
+
help="Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length")
|
400 |
+
|
401 |
+
|
402 |
+
## Variational auto-encoder
|
403 |
+
parser.add_argument("--nz", default=32, type=int,
|
404 |
+
help="Latent space dimension.")
|
405 |
+
|
406 |
+
parser.add_argument("--prompt", type=str, default="")
|
407 |
+
parser.add_argument("--padding_text", type=str, default="")
|
408 |
+
parser.add_argument("--length", type=int, default=20)
|
409 |
+
parser.add_argument("--temperature", type=float, default=1.0)
|
410 |
+
parser.add_argument("--top_k", type=int, default=0)
|
411 |
+
parser.add_argument("--top_p", type=float, default=0.9)
|
412 |
+
parser.add_argument("--no_cuda", action='store_true',
|
413 |
+
help="Avoid using CUDA when available")
|
414 |
+
parser.add_argument('--seed', type=int, default=42,
|
415 |
+
help="random seed for initialization")
|
416 |
+
|
417 |
+
parser.add_argument("--block_size", default=-1, type=int,
|
418 |
+
help="Optional input sequence length after tokenization."
|
419 |
+
"The training dataset will be truncated in block of this size for training."
|
420 |
+
"Default to the model max input length for single sentence inputs (take into account special tokens).")
|
421 |
+
parser.add_argument("--do_lower_case", action='store_true',
|
422 |
+
help="Set this flag if you are using an uncased model.")
|
423 |
+
|
424 |
+
parser.add_argument("--use_philly", action='store_true',
|
425 |
+
help="Use Philly for computing.")
|
426 |
+
|
427 |
+
args = parser.parse_args()
|
428 |
+
|
429 |
+
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
430 |
+
args.n_gpu = torch.cuda.device_count()
|
431 |
+
|
432 |
+
set_seed(args)
|
433 |
+
|
434 |
+
|
435 |
+
args.encoder_model_type = args.encoder_model_type.lower()
|
436 |
+
args.decoder_model_type = args.decoder_model_type.lower()
|
437 |
+
|
438 |
+
|
439 |
+
global_step = args.gloabl_step_eval
|
440 |
+
|
441 |
+
output_encoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-encoder-{}'.format(global_step))
|
442 |
+
output_decoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-decoder-{}'.format(global_step))
|
443 |
+
checkpoints = [ [output_encoder_dir, output_decoder_dir] ]
|
444 |
+
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
445 |
+
|
446 |
+
# Load a trained Encoder model and vocabulary that you have fine-tuned
|
447 |
+
encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
|
448 |
+
model_encoder = encoder_model_class.from_pretrained(output_encoder_dir, latent_size=args.latent_size)
|
449 |
+
tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)
|
450 |
+
|
451 |
+
model_encoder.to(args.device)
|
452 |
+
if args.block_size <= 0:
|
453 |
+
args.block_size = tokenizer_encoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
454 |
+
args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
|
455 |
+
|
456 |
+
# Load a trained Decoder model and vocabulary that you have fine-tuned
|
457 |
+
decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
|
458 |
+
model_decoder = decoder_model_class.from_pretrained(output_decoder_dir, latent_size=args.latent_size)
|
459 |
+
tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
|
460 |
+
model_decoder.to(args.device)
|
461 |
+
if args.block_size <= 0:
|
462 |
+
args.block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
463 |
+
args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
|
464 |
+
|
465 |
+
# Load full model
|
466 |
+
output_full_dir = os.path.join(args.checkpoint_dir, 'checkpoint-full-{}'.format(global_step))
|
467 |
+
checkpoint = torch.load(os.path.join(output_full_dir, 'training.bin'))
|
468 |
+
|
469 |
+
# Chunyuan: Add Padding token to GPT2
|
470 |
+
special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
|
471 |
+
num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
|
472 |
+
print('We have added', num_added_toks, 'tokens to GPT2')
|
473 |
+
model_decoder.resize_token_embeddings(len(tokenizer_decoder)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
|
474 |
+
assert tokenizer_decoder.pad_token == '<PAD>'
|
475 |
+
|
476 |
+
|
477 |
+
# Evaluation
|
478 |
+
model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, args)
|
479 |
+
model_vae.load_state_dict(checkpoint['model_state_dict'])
|
480 |
+
logger.info("Pre-trained Optimus is successfully loaded")
|
481 |
+
model_vae.to(args.device)
|
482 |
+
|
483 |
+
result = evaluate_latent_space(args, model_vae, tokenizer_encoder, tokenizer_decoder, prefix=global_step)
|
484 |
+
|
485 |
+
|
486 |
+
if __name__ == '__main__':
|
487 |
+
main()
|
Optimus/code/examples/big_ae/run_generation_from_prior.py
ADDED
@@ -0,0 +1,414 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
|
4 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/Transformer-XL/XLNet)
|
18 |
+
"""
|
19 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
20 |
+
|
21 |
+
import argparse
|
22 |
+
import glob
|
23 |
+
import logging
|
24 |
+
import os
|
25 |
+
import pickle
|
26 |
+
import random
|
27 |
+
|
28 |
+
|
29 |
+
cwd = os.getcwd()
|
30 |
+
print(f"Current working dir is {cwd}")
|
31 |
+
|
32 |
+
import sys
|
33 |
+
sys.path.append('./')
|
34 |
+
pt_path = os.path.join( cwd, 'pytorch_transformers')
|
35 |
+
sys.path.append(pt_path)
|
36 |
+
print(f"Pytorch Transformer {pt_path}")
|
37 |
+
|
38 |
+
import torch
|
39 |
+
import torch.nn.functional as F
|
40 |
+
import numpy as np
|
41 |
+
|
42 |
+
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
|
43 |
+
from torch.utils.data.distributed import DistributedSampler
|
44 |
+
from tqdm import tqdm, trange
|
45 |
+
|
46 |
+
|
47 |
+
from pytorch_transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, BertConfig
|
48 |
+
from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2ForLatentConnector
|
49 |
+
from pytorch_transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
|
50 |
+
from pytorch_transformers import XLNetLMHeadModel, XLNetTokenizer
|
51 |
+
from pytorch_transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
|
52 |
+
from pytorch_transformers import BertForLatentConnector, BertTokenizer
|
53 |
+
|
54 |
+
import pytorch_transformers
|
55 |
+
|
56 |
+
from collections import defaultdict
|
57 |
+
from modules import VAE
|
58 |
+
from utils import (TextDataset_Split, TextDataset_2Tokenizers, BucketingDataLoader)
|
59 |
+
from metrics import Bleu, SelfBleu
|
60 |
+
|
61 |
+
|
62 |
+
|
63 |
+
import pdb
|
64 |
+
|
65 |
+
|
66 |
+
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
67 |
+
datefmt = '%m/%d/%Y %H:%M:%S',
|
68 |
+
level = logging.INFO)
|
69 |
+
logger = logging.getLogger(__name__)
|
70 |
+
|
71 |
+
MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
|
72 |
+
|
73 |
+
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig)), ())
|
74 |
+
|
75 |
+
MODEL_CLASSES = {
|
76 |
+
'gpt2': (GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer),
|
77 |
+
'bert': (BertConfig, BertForLatentConnector, BertTokenizer)
|
78 |
+
}
|
79 |
+
|
80 |
+
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
|
81 |
+
# in https://github.com/rusiaaman/XLNet-gen#methodology
|
82 |
+
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
|
83 |
+
PADDING_TEXT = """ In 1991, the remains of Russian Tsar Nicholas II and his family
|
84 |
+
(except for Alexei and Maria) are discovered.
|
85 |
+
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
|
86 |
+
remainder of the story. 1883 Western Siberia,
|
87 |
+
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
|
88 |
+
Rasputin has a vision and denounces one of the men as a horse thief. Although his
|
89 |
+
father initially slaps him for making such an accusation, Rasputin watches as the
|
90 |
+
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
|
91 |
+
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
|
92 |
+
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
|
93 |
+
|
94 |
+
|
95 |
+
def set_seed(args):
|
96 |
+
np.random.seed(args.seed)
|
97 |
+
torch.manual_seed(args.seed)
|
98 |
+
if args.n_gpu > 0:
|
99 |
+
torch.cuda.manual_seed_all(args.seed)
|
100 |
+
|
101 |
+
|
102 |
+
def load_and_cache_examples(args, tokenizer, evaluate=False):
|
103 |
+
if isinstance(tokenizer, list):
|
104 |
+
dataset = TextDataset_2Tokenizers(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
|
105 |
+
else:
|
106 |
+
dataset = TextDataset_Split(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
|
107 |
+
return dataset
|
108 |
+
|
109 |
+
def build_dataload_and_cache_examples(args, tokenizer, evaluate=False):
|
110 |
+
if isinstance(tokenizer, list):
|
111 |
+
if not evaluate:
|
112 |
+
args.batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
113 |
+
file_path=args.train_data_file
|
114 |
+
else:
|
115 |
+
args.batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
116 |
+
file_path=args.eval_data_file
|
117 |
+
dataloader = BucketingDataLoader(file_path, args.batch_size, args.max_seq_length, tokenizer, args, bucket=100, shuffle=False)
|
118 |
+
else:
|
119 |
+
pass
|
120 |
+
return dataloader
|
121 |
+
|
122 |
+
|
123 |
+
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
|
124 |
+
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
125 |
+
Args:
|
126 |
+
logits: logits distribution shape (vocabulary size)
|
127 |
+
top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
128 |
+
top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
129 |
+
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
130 |
+
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
131 |
+
"""
|
132 |
+
assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
|
133 |
+
|
134 |
+
# top-k
|
135 |
+
top_k = min(top_k, logits.size(-1)) # Safety check
|
136 |
+
if top_k > 0:
|
137 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
138 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
139 |
+
logits[indices_to_remove] = filter_value
|
140 |
+
|
141 |
+
# top-p
|
142 |
+
if top_p > 0.0:
|
143 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
144 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
145 |
+
|
146 |
+
# Remove tokens with cumulative probability above the threshold
|
147 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
148 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
149 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
150 |
+
sorted_indices_to_remove[..., 0] = 0
|
151 |
+
|
152 |
+
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
153 |
+
logits[indices_to_remove] = filter_value
|
154 |
+
return logits
|
155 |
+
|
156 |
+
|
157 |
+
def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, is_xlnet=False, device='cpu'):
|
158 |
+
context = torch.tensor(context, dtype=torch.long, device=device)
|
159 |
+
context = context.unsqueeze(0).repeat(num_samples, 1)
|
160 |
+
generated = context
|
161 |
+
with torch.no_grad():
|
162 |
+
for _ in trange(length):
|
163 |
+
|
164 |
+
inputs = {'input_ids': generated}
|
165 |
+
if is_xlnet:
|
166 |
+
# XLNet is a direct (predict same token, not next token) and bi-directional model by default
|
167 |
+
# => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
|
168 |
+
input_ids = torch.cat((generated, torch.zeros((1, 1), dtype=torch.long, device=device)), dim=1)
|
169 |
+
perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=device)
|
170 |
+
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
|
171 |
+
target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=device)
|
172 |
+
target_mapping[0, 0, -1] = 1.0 # predict last token
|
173 |
+
inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
|
174 |
+
|
175 |
+
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
|
176 |
+
next_token_logits = outputs[0][0, -1, :] / temperature
|
177 |
+
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
178 |
+
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
|
179 |
+
generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
|
180 |
+
return generated
|
181 |
+
|
182 |
+
def sample_sequence_conditional(model, length, context, past=None, num_samples=1, temperature=1, top_k=0, top_p=0.0, device='cpu', decoder_tokenizer=None, max_seq_length=-1):
|
183 |
+
|
184 |
+
context = torch.tensor(context, dtype=torch.long, device=device)
|
185 |
+
context = context.unsqueeze(0).repeat(num_samples, 1)
|
186 |
+
generated = context
|
187 |
+
gen_seq_length = 0
|
188 |
+
with torch.no_grad():
|
189 |
+
while True:
|
190 |
+
inputs = {'input_ids': generated, 'past': past}
|
191 |
+
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
|
192 |
+
next_token_logits = outputs[0][0, -1, :] / temperature
|
193 |
+
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
194 |
+
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
|
195 |
+
generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
|
196 |
+
gen_seq_length += 1
|
197 |
+
# pdb.set_trace()
|
198 |
+
if next_token.unsqueeze(0)[0,0].item() == decoder_tokenizer.encode('<EOS>')[0]:
|
199 |
+
break
|
200 |
+
if max_seq_length>0 and gen_seq_length>max_seq_length:
|
201 |
+
break
|
202 |
+
|
203 |
+
return generated
|
204 |
+
|
205 |
+
|
206 |
+
def evaluate_generation_fromp_prior(model_vae, decoder_tokenizer, args, ns=1):
|
207 |
+
|
208 |
+
loc = torch.zeros([args.nz]).to(args.device)
|
209 |
+
scale = torch.ones([args.nz]).to(args.device)
|
210 |
+
prior = torch.distributions.normal.Normal(loc, scale)
|
211 |
+
|
212 |
+
context_tokens = decoder_tokenizer.encode('<BOS>')
|
213 |
+
|
214 |
+
count = 0
|
215 |
+
result = defaultdict(str)
|
216 |
+
for i in tqdm(range(args.num_sents)):
|
217 |
+
|
218 |
+
with torch.no_grad():
|
219 |
+
latent_z = prior.sample()
|
220 |
+
# pdb.set_trace()
|
221 |
+
past = model_vae.decoder.linear(latent_z.unsqueeze(0))
|
222 |
+
|
223 |
+
# pdb.set_trace()
|
224 |
+
out = sample_sequence_conditional(
|
225 |
+
model=model_vae.decoder,
|
226 |
+
context=context_tokens,
|
227 |
+
past=past,
|
228 |
+
length=args.max_seq_length, # Chunyuan: Fix length; or use <EOS> to complete a sentence
|
229 |
+
temperature=args.temperature,
|
230 |
+
top_k=args.top_k,
|
231 |
+
top_p=args.top_p,
|
232 |
+
device=args.device,
|
233 |
+
decoder_tokenizer = decoder_tokenizer,
|
234 |
+
max_seq_length = args.max_seq_length
|
235 |
+
)
|
236 |
+
text_x1 = decoder_tokenizer.decode(out[0,:].tolist(), clean_up_tokenization_spaces=True)
|
237 |
+
text_x1 = text_x1.split()[1:-1]
|
238 |
+
text_x1 = ' '.join(text_x1) + '\n'
|
239 |
+
result[i] = text_x1
|
240 |
+
|
241 |
+
if args.use_philly:
|
242 |
+
print("PROGRESS: {}%".format( round(100 * i /args.num_sents , 4)))
|
243 |
+
|
244 |
+
with open(args.output_generation_file, "w") as writer:
|
245 |
+
logger.info("***** SHOW generated sentences from prior *****")
|
246 |
+
for key in sorted(result.keys()):
|
247 |
+
# logger.info(" %s \n %s", key, str(result[key]))
|
248 |
+
# writer.write("%s \n %s\n" % (key, str(result[key])))
|
249 |
+
writer.write("%s" % str(result[key]))
|
250 |
+
|
251 |
+
return result
|
252 |
+
|
253 |
+
|
254 |
+
# bleu = evaluate_bleu(results, args)
|
255 |
+
|
256 |
+
|
257 |
+
|
258 |
+
|
259 |
+
|
260 |
+
|
261 |
+
def main():
|
262 |
+
parser = argparse.ArgumentParser()
|
263 |
+
|
264 |
+
parser.add_argument("--train_data_file", default=None, type=str, required=True,
|
265 |
+
help="The input training data file (a text file).")
|
266 |
+
parser.add_argument("--eval_data_file", default=None, type=str,
|
267 |
+
help="An input evaluation data file to evaluate the perplexity on (a text file).")
|
268 |
+
parser.add_argument("--checkpoint_dir", default=None, type=str, required=True,
|
269 |
+
help="The directory where checkpoints are saved.")
|
270 |
+
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
271 |
+
help="The output directory where the model predictions and checkpoints will be written.")
|
272 |
+
parser.add_argument("--dataset", default='Snli', type=str, help="The dataset.")
|
273 |
+
|
274 |
+
## Variational auto-encoder
|
275 |
+
parser.add_argument("--latent_size", default=32, type=int, help="Latent space dimension.")
|
276 |
+
parser.add_argument("--total_sents", default=10, type=int, help="Total sentences to test recontruction.")
|
277 |
+
parser.add_argument("--num_sents", default=10, type=int, help="Total sentences to generate.")
|
278 |
+
|
279 |
+
|
280 |
+
## Encoder options
|
281 |
+
parser.add_argument("--encoder_model_type", default="bert", type=str,
|
282 |
+
help="The encoder model architecture to be fine-tuned.")
|
283 |
+
parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
|
284 |
+
help="The encoder model checkpoint for weights initialization.")
|
285 |
+
parser.add_argument("--encoder_config_name", default="", type=str,
|
286 |
+
help="Optional pretrained config name or path if not the same as model_name_or_path")
|
287 |
+
parser.add_argument("--encoder_tokenizer_name", default="", type=str,
|
288 |
+
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
|
289 |
+
|
290 |
+
## Decoder options
|
291 |
+
parser.add_argument("--decoder_model_type", default="gpt2", type=str,
|
292 |
+
help="The decoder model architecture to be fine-tuned.")
|
293 |
+
parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
|
294 |
+
help="The decoder model checkpoint for weights initialization.")
|
295 |
+
parser.add_argument("--decoder_config_name", default="", type=str,
|
296 |
+
help="Optional pretrained config name or path if not the same as model_name_or_path")
|
297 |
+
parser.add_argument("--decoder_tokenizer_name", default="", type=str,
|
298 |
+
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
|
299 |
+
|
300 |
+
|
301 |
+
parser.add_argument("--per_gpu_train_batch_size", default=1, type=int,
|
302 |
+
help="Batch size per GPU/CPU for training.")
|
303 |
+
parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
|
304 |
+
help="Batch size per GPU/CPU for evaluation.")
|
305 |
+
parser.add_argument('--gloabl_step_eval', type=int, default=661,
|
306 |
+
help="Evaluate the results at the given global step")
|
307 |
+
|
308 |
+
parser.add_argument("--max_seq_length", default=512, type=int,
|
309 |
+
help="Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length")
|
310 |
+
|
311 |
+
|
312 |
+
## Variational auto-encoder
|
313 |
+
parser.add_argument("--nz", default=32, type=int,
|
314 |
+
help="Latent space dimension.")
|
315 |
+
|
316 |
+
parser.add_argument("--prompt", type=str, default="")
|
317 |
+
parser.add_argument("--padding_text", type=str, default="")
|
318 |
+
parser.add_argument("--length", type=int, default=20)
|
319 |
+
parser.add_argument("--temperature", type=float, default=1.0)
|
320 |
+
parser.add_argument("--top_k", type=int, default=0)
|
321 |
+
parser.add_argument("--top_p", type=float, default=0.9)
|
322 |
+
parser.add_argument("--no_cuda", action='store_true',
|
323 |
+
help="Avoid using CUDA when available")
|
324 |
+
parser.add_argument('--seed', type=int, default=42,
|
325 |
+
help="random seed for initialization")
|
326 |
+
|
327 |
+
parser.add_argument("--block_size", default=-1, type=int,
|
328 |
+
help="Optional input sequence length after tokenization."
|
329 |
+
"The training dataset will be truncated in block of this size for training."
|
330 |
+
"Default to the model max input length for single sentence inputs (take into account special tokens).")
|
331 |
+
parser.add_argument("--do_lower_case", action='store_true',
|
332 |
+
help="Set this flag if you are using an uncased model.")
|
333 |
+
|
334 |
+
parser.add_argument("--use_philly", action='store_true',
|
335 |
+
help="Use Philly for computing.")
|
336 |
+
|
337 |
+
args = parser.parse_args()
|
338 |
+
|
339 |
+
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
340 |
+
args.n_gpu = torch.cuda.device_count()
|
341 |
+
|
342 |
+
set_seed(args)
|
343 |
+
|
344 |
+
|
345 |
+
args.encoder_model_type = args.encoder_model_type.lower()
|
346 |
+
args.decoder_model_type = args.decoder_model_type.lower()
|
347 |
+
|
348 |
+
|
349 |
+
global_step = args.gloabl_step_eval
|
350 |
+
|
351 |
+
output_encoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-encoder-{}'.format(global_step))
|
352 |
+
output_decoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-decoder-{}'.format(global_step))
|
353 |
+
checkpoints = [ [output_encoder_dir, output_decoder_dir] ]
|
354 |
+
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
355 |
+
|
356 |
+
# Load a trained Encoder model and vocabulary that you have fine-tuned
|
357 |
+
encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
|
358 |
+
model_encoder = encoder_model_class.from_pretrained(output_encoder_dir, latent_size=args.latent_size)
|
359 |
+
tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)
|
360 |
+
|
361 |
+
model_encoder.to(args.device)
|
362 |
+
if args.block_size <= 0:
|
363 |
+
args.block_size = tokenizer_encoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
364 |
+
args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
|
365 |
+
|
366 |
+
# Load a trained Decoder model and vocabulary that you have fine-tuned
|
367 |
+
decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
|
368 |
+
model_decoder = decoder_model_class.from_pretrained(output_decoder_dir, latent_size=args.latent_size)
|
369 |
+
tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
|
370 |
+
model_decoder.to(args.device)
|
371 |
+
if args.block_size <= 0:
|
372 |
+
args.block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
373 |
+
args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
|
374 |
+
|
375 |
+
# pdb.set_trace()
|
376 |
+
# Chunyuan: Add Padding token to GPT2
|
377 |
+
special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
|
378 |
+
num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
|
379 |
+
print('We have added', num_added_toks, 'tokens to GPT2')
|
380 |
+
model_decoder.resize_token_embeddings(len(tokenizer_decoder)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
|
381 |
+
assert tokenizer_decoder.pad_token == '<PAD>'
|
382 |
+
|
383 |
+
|
384 |
+
# Evaluation
|
385 |
+
model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, args).to(args.device)
|
386 |
+
|
387 |
+
if not os.path.exists(args.output_dir): os.makedirs(args.output_dir)
|
388 |
+
args.output_generation_file = os.path.join(args.output_dir, f"generation_from_vae_prior_t{args.temperature}_p{args.top_p}.txt")
|
389 |
+
# args.output_generation_file = args.train_data_file
|
390 |
+
result = evaluate_generation_fromp_prior(model_vae, tokenizer_decoder, args)
|
391 |
+
|
392 |
+
|
393 |
+
bleu5 = Bleu(test_text= args.output_generation_file,
|
394 |
+
real_text=args.eval_data_file,
|
395 |
+
num_real_sentences=args.num_sents,
|
396 |
+
num_fake_sentences=args.num_sents,
|
397 |
+
gram=5).get_score()
|
398 |
+
logger.info(f'The bleu score is {bleu5}')
|
399 |
+
|
400 |
+
sbleu5 = SelfBleu(test_text= args.output_generation_file,
|
401 |
+
num_sentences=args.num_sents,
|
402 |
+
gram=5).get_score()
|
403 |
+
logger.info(f'The self-bleu score is {sbleu5}')
|
404 |
+
|
405 |
+
args.eval_results_file = os.path.join(args.output_dir, f"eval_results_t{args.temperature}_p{args.top_p}.txt")
|
406 |
+
eval_results = {'bleu5':bleu5 , 'sbleu5':sbleu5}
|
407 |
+
with open(args.eval_results_file, "w") as writer:
|
408 |
+
logger.info("***** SHOW the quantative evalution results *****")
|
409 |
+
for key in sorted(eval_results.keys()):
|
410 |
+
writer.write("%s %s" % (key, str(eval_results[key])) )
|
411 |
+
|
412 |
+
|
413 |
+
if __name__ == '__main__':
|
414 |
+
main()
|
Optimus/code/examples/big_ae/run_gpt2_generation.py
ADDED
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
|
4 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/Transformer-XL/XLNet)
|
18 |
+
"""
|
19 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
20 |
+
|
21 |
+
import argparse
|
22 |
+
import glob
|
23 |
+
import logging
|
24 |
+
import os
|
25 |
+
import pickle
|
26 |
+
import random
|
27 |
+
|
28 |
+
|
29 |
+
cwd = os.getcwd()
|
30 |
+
print(f"Current working dir is {cwd}")
|
31 |
+
|
32 |
+
import sys
|
33 |
+
sys.path.append('./')
|
34 |
+
pt_path = os.path.join( cwd, 'pytorch_transformers')
|
35 |
+
sys.path.append(pt_path)
|
36 |
+
print(f"Pytorch Transformer {pt_path}")
|
37 |
+
|
38 |
+
import torch
|
39 |
+
import torch.nn.functional as F
|
40 |
+
import numpy as np
|
41 |
+
|
42 |
+
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
|
43 |
+
from torch.utils.data.distributed import DistributedSampler
|
44 |
+
from tqdm import tqdm, trange
|
45 |
+
|
46 |
+
|
47 |
+
from pytorch_transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, BertConfig
|
48 |
+
from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2ForLatentConnector
|
49 |
+
from pytorch_transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
|
50 |
+
from pytorch_transformers import XLNetLMHeadModel, XLNetTokenizer
|
51 |
+
from pytorch_transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
|
52 |
+
from pytorch_transformers import BertForLatentConnector, BertTokenizer
|
53 |
+
|
54 |
+
import pytorch_transformers
|
55 |
+
|
56 |
+
from collections import defaultdict
|
57 |
+
from modules import VAE
|
58 |
+
from utils import (TextDataset_Split, TextDataset_2Tokenizers, BucketingDataLoader)
|
59 |
+
from metrics import Bleu, SelfBleu
|
60 |
+
|
61 |
+
|
62 |
+
|
63 |
+
import pdb
|
64 |
+
|
65 |
+
|
66 |
+
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
67 |
+
datefmt = '%m/%d/%Y %H:%M:%S',
|
68 |
+
level = logging.INFO)
|
69 |
+
logger = logging.getLogger(__name__)
|
70 |
+
|
71 |
+
MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
|
72 |
+
|
73 |
+
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig)), ())
|
74 |
+
|
75 |
+
MODEL_CLASSES = {
|
76 |
+
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
|
77 |
+
'bert': (BertConfig, BertForLatentConnector, BertTokenizer)
|
78 |
+
}
|
79 |
+
|
80 |
+
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
|
81 |
+
# in https://github.com/rusiaaman/XLNet-gen#methodology
|
82 |
+
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
|
83 |
+
PADDING_TEXT = """ In 1991, the remains of Russian Tsar Nicholas II and his family
|
84 |
+
(except for Alexei and Maria) are discovered.
|
85 |
+
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
|
86 |
+
remainder of the story. 1883 Western Siberia,
|
87 |
+
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
|
88 |
+
Rasputin has a vision and denounces one of the men as a horse thief. Although his
|
89 |
+
father initially slaps him for making such an accusation, Rasputin watches as the
|
90 |
+
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
|
91 |
+
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
|
92 |
+
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
|
93 |
+
|
94 |
+
|
95 |
+
def set_seed(args):
|
96 |
+
np.random.seed(args.seed)
|
97 |
+
torch.manual_seed(args.seed)
|
98 |
+
if args.n_gpu > 0:
|
99 |
+
torch.cuda.manual_seed_all(args.seed)
|
100 |
+
|
101 |
+
|
102 |
+
def load_and_cache_examples(args, tokenizer, evaluate=False):
|
103 |
+
if isinstance(tokenizer, list):
|
104 |
+
dataset = TextDataset_2Tokenizers(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
|
105 |
+
else:
|
106 |
+
dataset = TextDataset_Split(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
|
107 |
+
return dataset
|
108 |
+
|
109 |
+
def build_dataload_and_cache_examples(args, tokenizer, evaluate=False):
|
110 |
+
if isinstance(tokenizer, list):
|
111 |
+
if not evaluate:
|
112 |
+
args.batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
113 |
+
file_path=args.train_data_file
|
114 |
+
else:
|
115 |
+
args.batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
116 |
+
file_path=args.eval_data_file
|
117 |
+
dataloader = BucketingDataLoader(file_path, args.batch_size, args.max_seq_length, tokenizer, args, bucket=100, shuffle=False)
|
118 |
+
else:
|
119 |
+
pass
|
120 |
+
return dataloader
|
121 |
+
|
122 |
+
|
123 |
+
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
|
124 |
+
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
125 |
+
Args:
|
126 |
+
logits: logits distribution shape (vocabulary size)
|
127 |
+
top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
128 |
+
top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
129 |
+
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
130 |
+
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
131 |
+
"""
|
132 |
+
assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
|
133 |
+
|
134 |
+
# top-k
|
135 |
+
top_k = min(top_k, logits.size(-1)) # Safety check
|
136 |
+
if top_k > 0:
|
137 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
138 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
139 |
+
logits[indices_to_remove] = filter_value
|
140 |
+
|
141 |
+
# top-p
|
142 |
+
if top_p > 0.0:
|
143 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
144 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
145 |
+
|
146 |
+
# Remove tokens with cumulative probability above the threshold
|
147 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
148 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
149 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
150 |
+
sorted_indices_to_remove[..., 0] = 0
|
151 |
+
|
152 |
+
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
153 |
+
logits[indices_to_remove] = filter_value
|
154 |
+
return logits
|
155 |
+
|
156 |
+
|
157 |
+
def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, is_xlnet=False, device='cpu', decoder_tokenizer=None, max_seq_length=-1):
|
158 |
+
context = torch.tensor(context, dtype=torch.long, device=device)
|
159 |
+
context = context.unsqueeze(0).repeat(num_samples, 1)
|
160 |
+
generated = context
|
161 |
+
gen_seq_length = 0
|
162 |
+
with torch.no_grad():
|
163 |
+
while True:
|
164 |
+
|
165 |
+
inputs = {'input_ids': generated}
|
166 |
+
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
|
167 |
+
next_token_logits = outputs[0][0, -1, :] / temperature
|
168 |
+
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
169 |
+
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
|
170 |
+
generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
|
171 |
+
gen_seq_length += 1
|
172 |
+
if next_token.unsqueeze(0)[0,0].item() == decoder_tokenizer.encode('<EOS>')[0]:
|
173 |
+
break
|
174 |
+
if max_seq_length>0 and gen_seq_length>max_seq_length:
|
175 |
+
break
|
176 |
+
|
177 |
+
|
178 |
+
return generated
|
179 |
+
|
180 |
+
def sample_sequence_conditional(model, length, context, past=None, num_samples=1, temperature=1, top_k=0, top_p=0.0, device='cpu', decoder_tokenizer=None, max_seq_length=-1):
|
181 |
+
|
182 |
+
context = torch.tensor(context, dtype=torch.long, device=device)
|
183 |
+
context = context.unsqueeze(0).repeat(num_samples, 1)
|
184 |
+
generated = context
|
185 |
+
gen_seq_length = 0
|
186 |
+
with torch.no_grad():
|
187 |
+
while True:
|
188 |
+
inputs = {'input_ids': generated, 'past': past}
|
189 |
+
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
|
190 |
+
next_token_logits = outputs[0][0, -1, :] / temperature
|
191 |
+
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
192 |
+
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
|
193 |
+
generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
|
194 |
+
gen_seq_length += 1
|
195 |
+
# pdb.set_trace()
|
196 |
+
if next_token.unsqueeze(0)[0,0].item() == decoder_tokenizer.encode('<EOS>')[0]:
|
197 |
+
break
|
198 |
+
if max_seq_length>0 and gen_seq_length>max_seq_length:
|
199 |
+
break
|
200 |
+
|
201 |
+
return generated
|
202 |
+
|
203 |
+
|
204 |
+
def evaluate_generation_from_gpt2(model, decoder_tokenizer, args, ns=1):
|
205 |
+
|
206 |
+
loc = torch.zeros([args.nz]).to(args.device)
|
207 |
+
scale = torch.ones([args.nz]).to(args.device)
|
208 |
+
prior = torch.distributions.normal.Normal(loc, scale)
|
209 |
+
|
210 |
+
context_tokens = decoder_tokenizer.encode('<BOS>')
|
211 |
+
|
212 |
+
count = 0
|
213 |
+
result = defaultdict(str)
|
214 |
+
for i in tqdm(range(args.num_sents)):
|
215 |
+
|
216 |
+
with torch.no_grad():
|
217 |
+
|
218 |
+
out = sample_sequence(
|
219 |
+
model=model,
|
220 |
+
context=context_tokens,
|
221 |
+
length=args.max_seq_length, # Chunyuan: Fix length; or use <EOS> to complete a sentence
|
222 |
+
temperature=args.temperature,
|
223 |
+
top_k=args.top_k,
|
224 |
+
top_p=args.top_p,
|
225 |
+
device=args.device,
|
226 |
+
decoder_tokenizer = decoder_tokenizer,
|
227 |
+
max_seq_length = args.max_seq_length
|
228 |
+
)
|
229 |
+
text_x1 = decoder_tokenizer.decode(out[0,:].tolist(), clean_up_tokenization_spaces=True)
|
230 |
+
text_x1 = text_x1.split()[1:-1]
|
231 |
+
text_x1 = ' '.join(text_x1) + '\n'
|
232 |
+
result[i] = text_x1
|
233 |
+
|
234 |
+
if args.use_philly:
|
235 |
+
print("PROGRESS: {}%".format( round(100 * i /args.num_sents , 4)))
|
236 |
+
|
237 |
+
with open(args.output_generation_file, "w") as writer:
|
238 |
+
logger.info("***** SHOW generated sentences from prior *****")
|
239 |
+
for key in sorted(result.keys()):
|
240 |
+
# logger.info(" %s \n %s", key, str(result[key]))
|
241 |
+
# writer.write("%s \n %s\n" % (key, str(result[key])))
|
242 |
+
writer.write("%s" % str(result[key]))
|
243 |
+
|
244 |
+
return result
|
245 |
+
|
246 |
+
|
247 |
+
# bleu = evaluate_bleu(results, args)
|
248 |
+
|
249 |
+
|
250 |
+
|
251 |
+
|
252 |
+
|
253 |
+
|
254 |
+
def main():
|
255 |
+
parser = argparse.ArgumentParser()
|
256 |
+
|
257 |
+
parser.add_argument("--train_data_file", default=None, type=str, required=True,
|
258 |
+
help="The input training data file (a text file).")
|
259 |
+
parser.add_argument("--eval_data_file", default=None, type=str,
|
260 |
+
help="An input evaluation data file to evaluate the perplexity on (a text file).")
|
261 |
+
parser.add_argument("--checkpoint_dir", default=None, type=str, required=True,
|
262 |
+
help="The directory where checkpoints are saved.")
|
263 |
+
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
264 |
+
help="The output directory where the model predictions and checkpoints will be written.")
|
265 |
+
parser.add_argument("--dataset", default='Snli', type=str, help="The dataset.")
|
266 |
+
|
267 |
+
## Variational auto-encoder
|
268 |
+
parser.add_argument("--latent_size", default=32, type=int, help="Latent space dimension.")
|
269 |
+
parser.add_argument("--total_sents", default=10, type=int, help="Total sentences to test recontruction.")
|
270 |
+
parser.add_argument("--num_sents", default=10, type=int, help="Total sentences to generate.")
|
271 |
+
|
272 |
+
|
273 |
+
## Encoder options
|
274 |
+
parser.add_argument("--encoder_model_type", default="bert", type=str,
|
275 |
+
help="The encoder model architecture to be fine-tuned.")
|
276 |
+
parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
|
277 |
+
help="The encoder model checkpoint for weights initialization.")
|
278 |
+
parser.add_argument("--encoder_config_name", default="", type=str,
|
279 |
+
help="Optional pretrained config name or path if not the same as model_name_or_path")
|
280 |
+
parser.add_argument("--encoder_tokenizer_name", default="", type=str,
|
281 |
+
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
|
282 |
+
|
283 |
+
## Decoder options
|
284 |
+
parser.add_argument("--decoder_model_type", default="gpt2", type=str,
|
285 |
+
help="The decoder model architecture to be fine-tuned.")
|
286 |
+
parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
|
287 |
+
help="The decoder model checkpoint for weights initialization.")
|
288 |
+
parser.add_argument("--decoder_config_name", default="", type=str,
|
289 |
+
help="Optional pretrained config name or path if not the same as model_name_or_path")
|
290 |
+
parser.add_argument("--decoder_tokenizer_name", default="", type=str,
|
291 |
+
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
|
292 |
+
|
293 |
+
|
294 |
+
parser.add_argument("--per_gpu_train_batch_size", default=1, type=int,
|
295 |
+
help="Batch size per GPU/CPU for training.")
|
296 |
+
parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
|
297 |
+
help="Batch size per GPU/CPU for evaluation.")
|
298 |
+
parser.add_argument('--gloabl_step_eval', type=int, default=661,
|
299 |
+
help="Evaluate the results at the given global step")
|
300 |
+
|
301 |
+
parser.add_argument("--max_seq_length", default=512, type=int,
|
302 |
+
help="Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length")
|
303 |
+
|
304 |
+
|
305 |
+
## Variational auto-encoder
|
306 |
+
parser.add_argument("--nz", default=32, type=int,
|
307 |
+
help="Latent space dimension.")
|
308 |
+
|
309 |
+
parser.add_argument("--prompt", type=str, default="")
|
310 |
+
parser.add_argument("--padding_text", type=str, default="")
|
311 |
+
parser.add_argument("--length", type=int, default=20)
|
312 |
+
parser.add_argument("--temperature", type=float, default=1.0)
|
313 |
+
parser.add_argument("--top_k", type=int, default=0)
|
314 |
+
parser.add_argument("--top_p", type=float, default=0.9)
|
315 |
+
parser.add_argument("--no_cuda", action='store_true',
|
316 |
+
help="Avoid using CUDA when available")
|
317 |
+
parser.add_argument('--seed', type=int, default=42,
|
318 |
+
help="random seed for initialization")
|
319 |
+
|
320 |
+
parser.add_argument("--block_size", default=-1, type=int,
|
321 |
+
help="Optional input sequence length after tokenization."
|
322 |
+
"The training dataset will be truncated in block of this size for training."
|
323 |
+
"Default to the model max input length for single sentence inputs (take into account special tokens).")
|
324 |
+
parser.add_argument("--do_lower_case", action='store_true',
|
325 |
+
help="Set this flag if you are using an uncased model.")
|
326 |
+
|
327 |
+
parser.add_argument("--use_philly", action='store_true',
|
328 |
+
help="Use Philly for computing.")
|
329 |
+
|
330 |
+
args = parser.parse_args()
|
331 |
+
|
332 |
+
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
333 |
+
args.n_gpu = torch.cuda.device_count()
|
334 |
+
|
335 |
+
set_seed(args)
|
336 |
+
args.decoder_model_type = args.decoder_model_type.lower()
|
337 |
+
|
338 |
+
|
339 |
+
global_step = args.gloabl_step_eval
|
340 |
+
|
341 |
+
output_decoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-{}'.format(global_step))
|
342 |
+
checkpoints = [ output_decoder_dir ]
|
343 |
+
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
344 |
+
|
345 |
+
# Load a trained Decoder model and vocabulary that you have fine-tuned
|
346 |
+
decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
|
347 |
+
model_decoder = decoder_model_class.from_pretrained(output_decoder_dir)
|
348 |
+
tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
|
349 |
+
model_decoder.to(args.device)
|
350 |
+
if args.block_size <= 0:
|
351 |
+
args.block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
352 |
+
args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
|
353 |
+
|
354 |
+
# pdb.set_trace()
|
355 |
+
# Chunyuan: Add Padding token to GPT2
|
356 |
+
special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
|
357 |
+
num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
|
358 |
+
print('We have added', num_added_toks, 'tokens to GPT2')
|
359 |
+
model_decoder.resize_token_embeddings(len(tokenizer_decoder)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
|
360 |
+
assert tokenizer_decoder.pad_token == '<PAD>'
|
361 |
+
|
362 |
+
|
363 |
+
# Evaluation
|
364 |
+
if not os.path.exists(args.output_dir): os.makedirs(args.output_dir)
|
365 |
+
args.output_generation_file = os.path.join(args.output_dir, f"generation_from_gpt2_t{args.temperature}_p{args.top_p}.txt")
|
366 |
+
# args.output_generation_file = args.train_data_file
|
367 |
+
result = evaluate_generation_from_gpt2(model_decoder, tokenizer_decoder, args)
|
368 |
+
|
369 |
+
bleu5 = Bleu(test_text= args.output_generation_file,
|
370 |
+
real_text=args.eval_data_file,
|
371 |
+
num_real_sentences=args.num_sents,
|
372 |
+
num_fake_sentences=args.num_sents,
|
373 |
+
gram=5).get_score()
|
374 |
+
logger.info(f'The bleu score is {bleu5}')
|
375 |
+
|
376 |
+
sbleu5 = SelfBleu(test_text= args.output_generation_file,
|
377 |
+
num_sentences=args.num_sents,
|
378 |
+
gram=5).get_score()
|
379 |
+
logger.info(f'The self-bleu score is {sbleu5}')
|
380 |
+
|
381 |
+
args.eval_results_file = os.path.join(args.output_dir, f"eval_results_t{args.temperature}_p{args.top_p}.txt")
|
382 |
+
eval_results = {'bleu5':bleu5 , 'sbleu5':sbleu5}
|
383 |
+
with open(args.eval_results_file, "w") as writer:
|
384 |
+
logger.info("***** SHOW the quantative evalution results *****")
|
385 |
+
for key in sorted(eval_results.keys()):
|
386 |
+
writer.write("%s %s" % (key, str(eval_results[key])) )
|
387 |
+
|
388 |
+
|
389 |
+
if __name__ == '__main__':
|
390 |
+
main()
|
Optimus/code/examples/big_ae/run_latent_generation.py
ADDED
@@ -0,0 +1,577 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
|
4 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/Transformer-XL/XLNet)
|
18 |
+
"""
|
19 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
20 |
+
|
21 |
+
import argparse
|
22 |
+
import glob
|
23 |
+
import logging
|
24 |
+
import os
|
25 |
+
import pickle
|
26 |
+
import random
|
27 |
+
|
28 |
+
|
29 |
+
import torch
|
30 |
+
import torch.nn.functional as F
|
31 |
+
import numpy as np
|
32 |
+
|
33 |
+
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
|
34 |
+
from torch.utils.data.distributed import DistributedSampler
|
35 |
+
from tqdm import tqdm, trange
|
36 |
+
|
37 |
+
|
38 |
+
from pytorch_transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, BertConfig
|
39 |
+
from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2ForLatentConnector
|
40 |
+
from pytorch_transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
|
41 |
+
from pytorch_transformers import XLNetLMHeadModel, XLNetTokenizer
|
42 |
+
from pytorch_transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
|
43 |
+
from pytorch_transformers import BertForLatentConnector, BertTokenizer
|
44 |
+
|
45 |
+
from collections import defaultdict
|
46 |
+
from modules import VAE
|
47 |
+
from utils import (TextDataset_Split, TextDataset_2Tokenizers, BucketingDataLoader)
|
48 |
+
|
49 |
+
|
50 |
+
import pdb
|
51 |
+
|
52 |
+
|
53 |
+
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
54 |
+
datefmt = '%m/%d/%Y %H:%M:%S',
|
55 |
+
level = logging.INFO)
|
56 |
+
logger = logging.getLogger(__name__)
|
57 |
+
|
58 |
+
MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
|
59 |
+
|
60 |
+
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig)), ())
|
61 |
+
|
62 |
+
MODEL_CLASSES = {
|
63 |
+
'gpt2': (GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer),
|
64 |
+
'bert': (BertConfig, BertForLatentConnector, BertTokenizer)
|
65 |
+
}
|
66 |
+
|
67 |
+
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
|
68 |
+
# in https://github.com/rusiaaman/XLNet-gen#methodology
|
69 |
+
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
|
70 |
+
PADDING_TEXT = """ In 1991, the remains of Russian Tsar Nicholas II and his family
|
71 |
+
(except for Alexei and Maria) are discovered.
|
72 |
+
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
|
73 |
+
remainder of the story. 1883 Western Siberia,
|
74 |
+
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
|
75 |
+
Rasputin has a vision and denounces one of the men as a horse thief. Although his
|
76 |
+
father initially slaps him for making such an accusation, Rasputin watches as the
|
77 |
+
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
|
78 |
+
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
|
79 |
+
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
|
80 |
+
|
81 |
+
|
82 |
+
def set_seed(args):
|
83 |
+
np.random.seed(args.seed)
|
84 |
+
torch.manual_seed(args.seed)
|
85 |
+
if args.n_gpu > 0:
|
86 |
+
torch.cuda.manual_seed_all(args.seed)
|
87 |
+
|
88 |
+
|
89 |
+
def load_and_cache_examples(args, tokenizer, evaluate=False):
|
90 |
+
if isinstance(tokenizer, list):
|
91 |
+
dataset = TextDataset_2Tokenizers(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
|
92 |
+
else:
|
93 |
+
dataset = TextDataset_Split(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
|
94 |
+
return dataset
|
95 |
+
|
96 |
+
def build_dataload_and_cache_examples(args, tokenizer, evaluate=False):
|
97 |
+
if isinstance(tokenizer, list):
|
98 |
+
if not evaluate:
|
99 |
+
args.batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
100 |
+
file_path=args.train_data_file
|
101 |
+
else:
|
102 |
+
args.batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
103 |
+
file_path=args.eval_data_file
|
104 |
+
dataloader = BucketingDataLoader(file_path, args.batch_size, args.max_seq_length, tokenizer, args, bucket=100, shuffle=False)
|
105 |
+
else:
|
106 |
+
pass
|
107 |
+
return dataloader
|
108 |
+
|
109 |
+
|
110 |
+
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
|
111 |
+
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
112 |
+
Args:
|
113 |
+
logits: logits distribution shape (vocabulary size)
|
114 |
+
top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
115 |
+
top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
116 |
+
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
117 |
+
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
118 |
+
"""
|
119 |
+
assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
|
120 |
+
top_k = min(top_k, logits.size(-1)) # Safety check
|
121 |
+
if top_k > 0:
|
122 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
123 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
124 |
+
logits[indices_to_remove] = filter_value
|
125 |
+
|
126 |
+
if top_p > 0.0:
|
127 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
128 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
129 |
+
|
130 |
+
# Remove tokens with cumulative probability above the threshold
|
131 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
132 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
133 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
134 |
+
sorted_indices_to_remove[..., 0] = 0
|
135 |
+
|
136 |
+
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
137 |
+
logits[indices_to_remove] = filter_value
|
138 |
+
return logits
|
139 |
+
|
140 |
+
|
141 |
+
def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, is_xlnet=False, device='cpu'):
|
142 |
+
context = torch.tensor(context, dtype=torch.long, device=device)
|
143 |
+
context = context.unsqueeze(0).repeat(num_samples, 1)
|
144 |
+
generated = context
|
145 |
+
with torch.no_grad():
|
146 |
+
for _ in trange(length):
|
147 |
+
|
148 |
+
inputs = {'input_ids': generated}
|
149 |
+
if is_xlnet:
|
150 |
+
# XLNet is a direct (predict same token, not next token) and bi-directional model by default
|
151 |
+
# => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
|
152 |
+
input_ids = torch.cat((generated, torch.zeros((1, 1), dtype=torch.long, device=device)), dim=1)
|
153 |
+
perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=device)
|
154 |
+
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
|
155 |
+
target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=device)
|
156 |
+
target_mapping[0, 0, -1] = 1.0 # predict last token
|
157 |
+
inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
|
158 |
+
|
159 |
+
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
|
160 |
+
next_token_logits = outputs[0][0, -1, :] / temperature
|
161 |
+
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
162 |
+
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
|
163 |
+
generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
|
164 |
+
return generated
|
165 |
+
|
166 |
+
def sample_sequence_conditional(model, length, context, past=None, num_samples=1, temperature=1, top_k=0, top_p=0.0, device='cpu', decoder_tokenizer=None):
|
167 |
+
|
168 |
+
context = torch.tensor(context, dtype=torch.long, device=device)
|
169 |
+
context = context.unsqueeze(0).repeat(num_samples, 1)
|
170 |
+
generated = context
|
171 |
+
with torch.no_grad():
|
172 |
+
while True:
|
173 |
+
# for _ in trange(length):
|
174 |
+
inputs = {'input_ids': generated, 'past': past}
|
175 |
+
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
|
176 |
+
next_token_logits = outputs[0][0, -1, :] / temperature
|
177 |
+
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
178 |
+
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
|
179 |
+
generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
|
180 |
+
|
181 |
+
# pdb.set_trace()
|
182 |
+
if next_token.unsqueeze(0)[0,0].item() == decoder_tokenizer.encode('<EOS>')[0]:
|
183 |
+
break
|
184 |
+
|
185 |
+
return generated
|
186 |
+
|
187 |
+
|
188 |
+
def latent_code_from_text(text, tokenizer_encoder, model_vae, args):
|
189 |
+
tokenized1 = tokenizer_encoder.encode(text)
|
190 |
+
tokenized1 = [101] + tokenized1 + [102]
|
191 |
+
coded1 = torch.Tensor([tokenized1])
|
192 |
+
coded1 =torch.Tensor.long(coded1)
|
193 |
+
with torch.no_grad():
|
194 |
+
x0 = coded1
|
195 |
+
x0 = x0.to(args.device)
|
196 |
+
pooled_hidden_fea = model_vae.encoder(x0, attention_mask=(x0 > 0).float())[1]
|
197 |
+
mean, logvar = model_vae.encoder.linear(pooled_hidden_fea).chunk(2, -1)
|
198 |
+
latent_z = mean.squeeze(1)
|
199 |
+
coded_length = len(tokenized1)
|
200 |
+
return latent_z, coded_length
|
201 |
+
|
202 |
+
def text_from_latent_code(latent_z, model_vae, args, tokenizer_decoder):
|
203 |
+
past = latent_z
|
204 |
+
context_tokens = tokenizer_decoder.encode('<BOS>')
|
205 |
+
|
206 |
+
length = 128 # maximum length, but not used
|
207 |
+
out = sample_sequence_conditional(
|
208 |
+
model=model_vae.decoder,
|
209 |
+
context=context_tokens,
|
210 |
+
past=past,
|
211 |
+
length= length, # Chunyuan: Fix length; or use <EOS> to complete a sentence
|
212 |
+
temperature=args.temperature,
|
213 |
+
top_k=args.top_k,
|
214 |
+
top_p=args.top_p,
|
215 |
+
device=args.device,
|
216 |
+
decoder_tokenizer = tokenizer_decoder
|
217 |
+
)
|
218 |
+
text_x1 = tokenizer_decoder.decode(out[0,:].tolist(), clean_up_tokenization_spaces=True)
|
219 |
+
text_x1 = text_x1.split()[1:-1]
|
220 |
+
text_x1 = ' '.join(text_x1)
|
221 |
+
return text_x1
|
222 |
+
|
223 |
+
|
224 |
+
# a wrapper function to choose between different play modes
|
225 |
+
def evaluate_latent_space(args, model_vae, encoder_tokenizer, decoder_tokenizer, prefix=""):
|
226 |
+
|
227 |
+
eval_dataloader = build_dataload_and_cache_examples(args, [encoder_tokenizer, decoder_tokenizer], evaluate=False)
|
228 |
+
|
229 |
+
# Eval!
|
230 |
+
logger.info("***** Running recontruction evaluation {} *****".format(prefix))
|
231 |
+
logger.info(" Num examples = %d", len(eval_dataloader))
|
232 |
+
logger.info(" Batch size = %d", args.per_gpu_eval_batch_size)
|
233 |
+
|
234 |
+
model_vae.eval()
|
235 |
+
|
236 |
+
model_vae = model_vae.module if hasattr(model_vae, 'module') else model_vae # Take care of distributed/parallel training
|
237 |
+
|
238 |
+
if args.play_mode == 'reconstrction':
|
239 |
+
result = calc_rec(model_vae, eval_dataloader, encoder_tokenizer, decoder_tokenizer, args, ns=100)
|
240 |
+
result_file_name = "eval_recontruction_results.txt"
|
241 |
+
elif args.play_mode == 'interpolation':
|
242 |
+
result = calc_interpolate(model_vae, eval_dataloader, encoder_tokenizer, decoder_tokenizer, args, ns=100)
|
243 |
+
result_file_name = "eval_interpolation_results.txt"
|
244 |
+
else:
|
245 |
+
logger.info("Please specify the corrent play mode [reconstrction, interpolation]")
|
246 |
+
|
247 |
+
|
248 |
+
eval_output_dir = args.output_dir
|
249 |
+
output_eval_file = os.path.join(eval_output_dir, result_file_name)
|
250 |
+
|
251 |
+
with open(output_eval_file, "w") as writer:
|
252 |
+
logger.info("***** Eval {} results *****".format(args.play_mode))
|
253 |
+
for key in sorted(result.keys()):
|
254 |
+
logger.info(" %s \n %s", key, str(result[key]))
|
255 |
+
writer.write("%s \n %s\n" % (key, str(result[key])))
|
256 |
+
|
257 |
+
return result
|
258 |
+
|
259 |
+
|
260 |
+
def calc_rec(model_vae, eval_dataloader, encoder_tokenizer, decoder_tokenizer, args, ns=1):
|
261 |
+
|
262 |
+
count = 0
|
263 |
+
result = defaultdict(str)
|
264 |
+
for batch in tqdm(eval_dataloader, desc="Evaluating recontruction"):
|
265 |
+
# pdb.set_trace()
|
266 |
+
x0, x1, x_lengths = batch
|
267 |
+
|
268 |
+
max_len_values, _ = x_lengths.max(0)
|
269 |
+
x0 = x0[:,:max_len_values[0]]
|
270 |
+
x1 = x1[:,:max_len_values[1]]
|
271 |
+
|
272 |
+
x0 = x0.to(args.device)
|
273 |
+
x1 = x1.to(args.device)
|
274 |
+
x_lengths = x_lengths.to(args.device)
|
275 |
+
|
276 |
+
context_tokens = decoder_tokenizer.encode('<BOS>')
|
277 |
+
|
278 |
+
with torch.no_grad():
|
279 |
+
|
280 |
+
text_x0 = encoder_tokenizer.decode(x0[0,:x_lengths[0,0]].tolist(), clean_up_tokenization_spaces=True)[0]
|
281 |
+
# result["INPUT TEXT " + str(count)].append(text_x0)
|
282 |
+
|
283 |
+
pooled_hidden_fea = model_vae.encoder(x0, attention_mask=(x0 > 0).float())[1]
|
284 |
+
|
285 |
+
# Connect hidden feature to the latent space
|
286 |
+
# latent_z, loss_kl = model_vae.connect(pooled_hidden_fea)
|
287 |
+
mean, logvar = model_vae.encoder.linear(pooled_hidden_fea).chunk(2, -1)
|
288 |
+
latent_z = mean.squeeze(1)
|
289 |
+
|
290 |
+
past = latent_z
|
291 |
+
out = sample_sequence_conditional(
|
292 |
+
model=model_vae.decoder,
|
293 |
+
context=context_tokens,
|
294 |
+
past=past,
|
295 |
+
length=x_lengths[0,1], # Chunyuan: Fix length; or use <EOS> to complete a sentence
|
296 |
+
temperature=args.temperature,
|
297 |
+
top_k=args.top_k,
|
298 |
+
top_p=args.top_p,
|
299 |
+
device=args.device,
|
300 |
+
decoder_tokenizer = decoder_tokenizer
|
301 |
+
)
|
302 |
+
text_x1 = decoder_tokenizer.decode(out[0,:].tolist(), clean_up_tokenization_spaces=True)
|
303 |
+
text_x1 = text_x1.split()[1:-1]
|
304 |
+
text_x1 = ' '.join(text_x1) + '\n'
|
305 |
+
result[text_x0] = text_x1
|
306 |
+
|
307 |
+
count += 1
|
308 |
+
if count>args.total_sents:
|
309 |
+
break
|
310 |
+
|
311 |
+
|
312 |
+
return result
|
313 |
+
|
314 |
+
|
315 |
+
|
316 |
+
|
317 |
+
def calc_interpolate(model_vae, eval_dataloader, encoder_tokenizer, decoder_tokenizer, args, ns=1):
|
318 |
+
|
319 |
+
count = 0
|
320 |
+
latent_codes = []
|
321 |
+
sample_interval = 0
|
322 |
+
for batch in tqdm(eval_dataloader, desc="Evaluating interpolation"):
|
323 |
+
# pdb.set_trace()
|
324 |
+
x0, x1, x_lengths = batch
|
325 |
+
|
326 |
+
max_len_values, _ = x_lengths.max(0)
|
327 |
+
x0 = x0[:,:max_len_values[0]]
|
328 |
+
x0 = x0.to(args.device)
|
329 |
+
x_lengths = x_lengths.to(args.device)
|
330 |
+
|
331 |
+
|
332 |
+
with torch.no_grad():
|
333 |
+
if sample_interval == 0 or sample_interval == args.total_sents:
|
334 |
+
text_x0 = encoder_tokenizer.decode(x0[0,:x_lengths[0,0]].tolist(), clean_up_tokenization_spaces=True)[0]
|
335 |
+
pooled_hidden_fea = model_vae.encoder(x0, attention_mask=(x0 > 0).float())[1]
|
336 |
+
|
337 |
+
# Connect hidden feature to the latent space
|
338 |
+
mean, logvar = model_vae.encoder.linear(pooled_hidden_fea).chunk(2, -1)
|
339 |
+
latent_z = mean.squeeze(1)
|
340 |
+
|
341 |
+
latent_codes.append(latent_z)
|
342 |
+
|
343 |
+
if sample_interval == 5:
|
344 |
+
latent_codes.append(latent_z)
|
345 |
+
sample_interval = 0
|
346 |
+
continue
|
347 |
+
else:
|
348 |
+
sample_interval += 1
|
349 |
+
continue
|
350 |
+
|
351 |
+
count += 1
|
352 |
+
if count>args.total_sents:
|
353 |
+
break
|
354 |
+
|
355 |
+
context_tokens = decoder_tokenizer.encode('<BOS>')
|
356 |
+
result = defaultdict(str)
|
357 |
+
latent_codes_interpolation = []
|
358 |
+
num_steps = args.num_interpolation_steps
|
359 |
+
for step in range(num_steps+1):
|
360 |
+
latent_z = latent_codes[0] + (latent_codes[1] - latent_codes[0]) * step * 1.0/num_steps
|
361 |
+
|
362 |
+
past = latent_z
|
363 |
+
out = sample_sequence_conditional(
|
364 |
+
model=model_vae.decoder,
|
365 |
+
context=context_tokens,
|
366 |
+
past=past,
|
367 |
+
length=x_lengths[0,1], # Chunyuan: Fix length; or use <EOS> to complete a sentence
|
368 |
+
temperature=args.temperature,
|
369 |
+
top_k=args.top_k,
|
370 |
+
top_p=args.top_p,
|
371 |
+
device=args.device,
|
372 |
+
decoder_tokenizer = decoder_tokenizer
|
373 |
+
)
|
374 |
+
text_x1 = decoder_tokenizer.decode(out[0,:].tolist(), clean_up_tokenization_spaces=True)
|
375 |
+
text_x1 = text_x1.split()[1:-1]
|
376 |
+
text_x1 = ' '.join(text_x1)
|
377 |
+
result[step] = text_x1
|
378 |
+
|
379 |
+
return result
|
380 |
+
|
381 |
+
|
382 |
+
def interpolate(model_vae, tokenizer_encoder, tokenizer_decoder, args):
|
383 |
+
# and then in the main function
|
384 |
+
latent_z1, coded_length1 = latent_code_from_text(args.sent_source, tokenizer_encoder, model_vae, args)
|
385 |
+
latent_z2, coded_length2 = latent_code_from_text(args.sent_target, tokenizer_encoder, model_vae, args)
|
386 |
+
|
387 |
+
result = defaultdict(str)
|
388 |
+
|
389 |
+
num_steps = args.num_interpolation_steps + 1
|
390 |
+
for step in range(num_steps+1):
|
391 |
+
latent_z = latent_z1 + (latent_z2 - latent_z1) * step * 1.0/num_steps
|
392 |
+
|
393 |
+
text_interpolate = text_from_latent_code(latent_z, model_vae, args, tokenizer_decoder)
|
394 |
+
result[step] = text_interpolate
|
395 |
+
print(text_interpolate)
|
396 |
+
|
397 |
+
return result
|
398 |
+
|
399 |
+
|
400 |
+
def analogy(model_vae, tokenizer_encoder, tokenizer_decoder, args):
|
401 |
+
|
402 |
+
latent_z1, coded_length1 = latent_code_from_text(args.sent_source, tokenizer_encoder, model_vae, args)
|
403 |
+
latent_z2, coded_length2 = latent_code_from_text(args.sent_target, tokenizer_encoder, model_vae, args)
|
404 |
+
latent_z3, coded_length3 = latent_code_from_text(args.sent_input, tokenizer_encoder, model_vae, args)
|
405 |
+
|
406 |
+
result = defaultdict(str)
|
407 |
+
|
408 |
+
latent_z = latent_z3 + args.degree_to_target * (latent_z2 - latent_z1)
|
409 |
+
|
410 |
+
text_analogy = text_from_latent_code(latent_z, model_vae, args, tokenizer_decoder)
|
411 |
+
result[0] = text_analogy
|
412 |
+
print(text_analogy)
|
413 |
+
|
414 |
+
return result
|
415 |
+
|
416 |
+
|
417 |
+
def main():
|
418 |
+
parser = argparse.ArgumentParser()
|
419 |
+
|
420 |
+
parser.add_argument("--train_data_file", default=None, type=str, required=True,
|
421 |
+
help="The input training data file (a text file).")
|
422 |
+
parser.add_argument("--eval_data_file", default=None, type=str,
|
423 |
+
help="An input evaluation data file to evaluate the perplexity on (a text file).")
|
424 |
+
parser.add_argument("--checkpoint_dir", default=None, type=str, required=True,
|
425 |
+
help="The directory where checkpoints are saved.")
|
426 |
+
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
427 |
+
help="The output directory where the model predictions and checkpoints will be written.")
|
428 |
+
parser.add_argument("--dataset", default='Snli', type=str, help="The dataset.")
|
429 |
+
|
430 |
+
## Variational auto-encoder
|
431 |
+
parser.add_argument("--latent_size", default=32, type=int, help="Latent space dimension.")
|
432 |
+
parser.add_argument("--total_sents", default=10, type=int, help="Total sentences to test recontruction.")
|
433 |
+
parser.add_argument("--num_interpolation_steps", default=10, type=int, help="Total sentences to test recontruction.")
|
434 |
+
parser.add_argument("--play_mode", default="interpolation", type=str,
|
435 |
+
help="interpolation or reconstruction.")
|
436 |
+
|
437 |
+
|
438 |
+
## Encoder options
|
439 |
+
parser.add_argument("--encoder_model_type", default="bert", type=str,
|
440 |
+
help="The encoder model architecture to be fine-tuned.")
|
441 |
+
parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
|
442 |
+
help="The encoder model checkpoint for weights initialization.")
|
443 |
+
parser.add_argument("--encoder_config_name", default="", type=str,
|
444 |
+
help="Optional pretrained config name or path if not the same as model_name_or_path")
|
445 |
+
parser.add_argument("--encoder_tokenizer_name", default="", type=str,
|
446 |
+
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
|
447 |
+
|
448 |
+
## Decoder options
|
449 |
+
parser.add_argument("--decoder_model_type", default="gpt2", type=str,
|
450 |
+
help="The decoder model architecture to be fine-tuned.")
|
451 |
+
parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
|
452 |
+
help="The decoder model checkpoint for weights initialization.")
|
453 |
+
parser.add_argument("--decoder_config_name", default="", type=str,
|
454 |
+
help="Optional pretrained config name or path if not the same as model_name_or_path")
|
455 |
+
parser.add_argument("--decoder_tokenizer_name", default="", type=str,
|
456 |
+
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
|
457 |
+
|
458 |
+
|
459 |
+
parser.add_argument("--per_gpu_train_batch_size", default=1, type=int,
|
460 |
+
help="Batch size per GPU/CPU for training.")
|
461 |
+
parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
|
462 |
+
help="Batch size per GPU/CPU for evaluation.")
|
463 |
+
parser.add_argument('--gloabl_step_eval', type=int, default=661,
|
464 |
+
help="Evaluate the results at the given global step")
|
465 |
+
|
466 |
+
parser.add_argument("--max_seq_length", default=512, type=int,
|
467 |
+
help="Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length")
|
468 |
+
|
469 |
+
# Interact with users
|
470 |
+
parser.add_argument("--interact_with_user_input", action='store_true', help="Use user input to interact_with.")
|
471 |
+
parser.add_argument("--sent_source", type=str, default="")
|
472 |
+
parser.add_argument("--sent_target", type=str, default="")
|
473 |
+
parser.add_argument("--sent_input", type=str, default="")
|
474 |
+
parser.add_argument("--degree_to_target", type=float, default="1.0")
|
475 |
+
|
476 |
+
## Variational auto-encoder
|
477 |
+
parser.add_argument("--nz", default=32, type=int,
|
478 |
+
help="Latent space dimension.")
|
479 |
+
|
480 |
+
parser.add_argument("--prompt", type=str, default="")
|
481 |
+
parser.add_argument("--padding_text", type=str, default="")
|
482 |
+
parser.add_argument("--length", type=int, default=20)
|
483 |
+
parser.add_argument("--temperature", type=float, default=1.0)
|
484 |
+
parser.add_argument("--top_k", type=int, default=0)
|
485 |
+
parser.add_argument("--top_p", type=float, default=1.0)
|
486 |
+
parser.add_argument("--no_cuda", action='store_true',
|
487 |
+
help="Avoid using CUDA when available")
|
488 |
+
parser.add_argument('--seed', type=int, default=42,
|
489 |
+
help="random seed for initialization")
|
490 |
+
|
491 |
+
parser.add_argument("--block_size", default=-1, type=int,
|
492 |
+
help="Optional input sequence length after tokenization."
|
493 |
+
"The training dataset will be truncated in block of this size for training."
|
494 |
+
"Default to the model max input length for single sentence inputs (take into account special tokens).")
|
495 |
+
parser.add_argument("--do_lower_case", action='store_true',
|
496 |
+
help="Set this flag if you are using an uncased model.")
|
497 |
+
|
498 |
+
parser.add_argument("--use_philly", action='store_true',
|
499 |
+
help="Use Philly for computing.")
|
500 |
+
|
501 |
+
args = parser.parse_args()
|
502 |
+
|
503 |
+
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
504 |
+
args.n_gpu = torch.cuda.device_count()
|
505 |
+
|
506 |
+
set_seed(args)
|
507 |
+
|
508 |
+
|
509 |
+
args.encoder_model_type = args.encoder_model_type.lower()
|
510 |
+
args.decoder_model_type = args.decoder_model_type.lower()
|
511 |
+
|
512 |
+
|
513 |
+
global_step = args.gloabl_step_eval
|
514 |
+
|
515 |
+
output_encoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-encoder-{}'.format(global_step))
|
516 |
+
output_decoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-decoder-{}'.format(global_step))
|
517 |
+
checkpoints = [ [output_encoder_dir, output_decoder_dir] ]
|
518 |
+
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
519 |
+
|
520 |
+
# Load a trained Encoder model and vocabulary that you have fine-tuned
|
521 |
+
encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
|
522 |
+
model_encoder = encoder_model_class.from_pretrained(output_encoder_dir, latent_size=args.latent_size)
|
523 |
+
tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)
|
524 |
+
|
525 |
+
model_encoder.to(args.device)
|
526 |
+
if args.block_size <= 0:
|
527 |
+
args.block_size = tokenizer_encoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
528 |
+
args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
|
529 |
+
|
530 |
+
# Load a trained Decoder model and vocabulary that you have fine-tuned
|
531 |
+
decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
|
532 |
+
model_decoder = decoder_model_class.from_pretrained(output_decoder_dir, latent_size=args.latent_size)
|
533 |
+
tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
|
534 |
+
model_decoder.to(args.device)
|
535 |
+
if args.block_size <= 0:
|
536 |
+
args.block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
537 |
+
args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
|
538 |
+
|
539 |
+
# Load full model
|
540 |
+
output_full_dir = os.path.join(args.checkpoint_dir, 'checkpoint-full-{}'.format(global_step))
|
541 |
+
checkpoint = torch.load(os.path.join(output_full_dir, 'training.bin'))
|
542 |
+
|
543 |
+
# Chunyuan: Add Padding token to GPT2
|
544 |
+
special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
|
545 |
+
num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
|
546 |
+
print('We have added', num_added_toks, 'tokens to GPT2')
|
547 |
+
model_decoder.resize_token_embeddings(len(tokenizer_decoder)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
|
548 |
+
assert tokenizer_decoder.pad_token == '<PAD>'
|
549 |
+
|
550 |
+
|
551 |
+
# Evaluation
|
552 |
+
model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, args)
|
553 |
+
model_vae.load_state_dict(checkpoint['model_state_dict'])
|
554 |
+
logger.info("Pre-trained Optimus is successfully loaded")
|
555 |
+
model_vae.to(args.device)
|
556 |
+
|
557 |
+
if args.interact_with_user_input:
|
558 |
+
|
559 |
+
if args.play_mode == 'interpolation':
|
560 |
+
if len(args.sent_source) > 0 and len(args.sent_source) > 0:
|
561 |
+
result = interpolate(model_vae, tokenizer_encoder, tokenizer_decoder, args)
|
562 |
+
else:
|
563 |
+
print('Please check: specify the source and target sentences!')
|
564 |
+
|
565 |
+
if args.play_mode == 'analogy':
|
566 |
+
if len(args.sent_source) > 0 and len(args.sent_source) > 0 and len(args.sent_input) > 0:
|
567 |
+
result = analogy(model_vae, tokenizer_encoder, tokenizer_decoder, args)
|
568 |
+
else:
|
569 |
+
print('Please check: specify the source, target and input analogy sentences!')
|
570 |
+
|
571 |
+
|
572 |
+
else:
|
573 |
+
result = evaluate_latent_space(args, model_vae, tokenizer_encoder, tokenizer_decoder, prefix=global_step)
|
574 |
+
|
575 |
+
|
576 |
+
if __name__ == '__main__':
|
577 |
+
main()
|
Optimus/code/examples/big_ae/run_lm_ae_pretraining.py
ADDED
@@ -0,0 +1,692 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
|
18 |
+
GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
|
19 |
+
using a masked language modeling (MLM) loss.
|
20 |
+
"""
|
21 |
+
|
22 |
+
from __future__ import absolute_import, division, print_function
|
23 |
+
|
24 |
+
|
25 |
+
import pdb
|
26 |
+
import argparse
|
27 |
+
import glob
|
28 |
+
import logging
|
29 |
+
import os
|
30 |
+
import pickle
|
31 |
+
import random
|
32 |
+
|
33 |
+
import numpy as np
|
34 |
+
import torch
|
35 |
+
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
|
36 |
+
from torch.utils.data.distributed import DistributedSampler
|
37 |
+
from tensorboardX import SummaryWriter
|
38 |
+
from tqdm import tqdm, trange
|
39 |
+
|
40 |
+
from pytorch_transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
|
41 |
+
BertConfig, BertModel, BertTokenizer,
|
42 |
+
GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
|
43 |
+
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
|
44 |
+
RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
|
45 |
+
|
46 |
+
|
47 |
+
logger = logging.getLogger(__name__)
|
48 |
+
|
49 |
+
|
50 |
+
MODEL_CLASSES = {
|
51 |
+
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
|
52 |
+
'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
|
53 |
+
'bert': (BertConfig, BertModel, BertTokenizer),
|
54 |
+
'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
|
55 |
+
}
|
56 |
+
|
57 |
+
|
58 |
+
class TextDataset(Dataset):
|
59 |
+
def __init__(self, tokenizer, file_path='train', block_size=512):
|
60 |
+
assert os.path.isfile(file_path)
|
61 |
+
directory, filename = os.path.split(file_path)
|
62 |
+
cached_features_file = os.path.join(directory, f'cached_lm_{block_size}_{filename}')
|
63 |
+
|
64 |
+
if os.path.exists(cached_features_file):
|
65 |
+
logger.info("Loading features from cached file %s", cached_features_file)
|
66 |
+
with open(cached_features_file, 'rb') as handle:
|
67 |
+
self.examples = pickle.load(handle)
|
68 |
+
else:
|
69 |
+
logger.info("Creating features from dataset file at %s", directory)
|
70 |
+
|
71 |
+
self.examples = []
|
72 |
+
with open(file_path, encoding="utf-8") as f:
|
73 |
+
text = f.read()
|
74 |
+
|
75 |
+
|
76 |
+
tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
|
77 |
+
|
78 |
+
while len(tokenized_text) >= block_size: # Truncate in block of block_size
|
79 |
+
self.examples.append(tokenizer.add_special_tokens_single_sentence(tokenized_text[:block_size]))
|
80 |
+
tokenized_text = tokenized_text[block_size:]
|
81 |
+
# Note that we are loosing the last truncated example here for the sake of simplicity (no padding)
|
82 |
+
# If your dataset is small, first you should loook for a bigger one :-) and second you
|
83 |
+
# can change this behavior by adding (model specific) padding.
|
84 |
+
|
85 |
+
logger.info("Saving features into cached file %s", cached_features_file)
|
86 |
+
with open(cached_features_file, 'wb') as handle:
|
87 |
+
pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
88 |
+
|
89 |
+
def __len__(self):
|
90 |
+
return len(self.examples)
|
91 |
+
|
92 |
+
def __getitem__(self, item):
|
93 |
+
return torch.tensor(self.examples[item])
|
94 |
+
|
95 |
+
|
96 |
+
|
97 |
+
class TextDataset_2Tokenizers(Dataset):
|
98 |
+
def __init__(self, tokenizers, file_path='train', block_size=512):
|
99 |
+
assert os.path.isfile(file_path)
|
100 |
+
directory, filename = os.path.split(file_path)
|
101 |
+
cached_features_file = os.path.join(directory, f'cached_lm_gpt_bert_{block_size}_{filename}')
|
102 |
+
|
103 |
+
|
104 |
+
|
105 |
+
if os.path.exists(cached_features_file):
|
106 |
+
logger.info("Loading features from cached file %s", cached_features_file)
|
107 |
+
with open(cached_features_file, 'rb') as handle:
|
108 |
+
self.examples = pickle.load(handle)
|
109 |
+
else:
|
110 |
+
logger.info("Creating features from dataset file at %s", directory)
|
111 |
+
|
112 |
+
|
113 |
+
with open(file_path, encoding="utf-8") as f:
|
114 |
+
text = f.read()
|
115 |
+
|
116 |
+
# pdb.set_trace()
|
117 |
+
self.examples = []
|
118 |
+
# Chunyuan: divide the linguistic text into the same length, then different tokenization schemes are applied
|
119 |
+
while len(text) >= block_size: # Truncate in block of block_size
|
120 |
+
|
121 |
+
tokenized_text0 = tokenizers[0].convert_tokens_to_ids(tokenizers[0].tokenize(text[:block_size]))
|
122 |
+
tokenized_text0 = tokenizers[0].add_special_tokens_single_sentence(tokenized_text0)
|
123 |
+
tokenized_text0_length = len(tokenized_text0)
|
124 |
+
pad_token=tokenizers[0].convert_tokens_to_ids([tokenizers[0].pad_token])[0]
|
125 |
+
tokenized_text0 = tokenized_text0 + ([pad_token] * (block_size - tokenized_text0_length) ) # Pad up to the sequence length.
|
126 |
+
assert len(tokenized_text0) == block_size
|
127 |
+
|
128 |
+
tokenized_text1 = tokenizers[1].convert_tokens_to_ids(tokenizers[1].tokenize(text[:block_size]))
|
129 |
+
tokenized_text1 = tokenizers[1].add_special_tokens_single_sentence(tokenized_text1)
|
130 |
+
tokenized_text1_length = len(tokenized_text1)
|
131 |
+
pad_token=tokenizers[1].convert_tokens_to_ids([tokenizers[1].pad_token])[0]
|
132 |
+
tokenized_text1 = tokenized_text1 + ([pad_token] * (block_size - tokenized_text1_length) ) # Pad up to the sequence length.
|
133 |
+
assert len(tokenized_text1) == block_size
|
134 |
+
|
135 |
+
self.examples.append([tokenized_text0, tokenized_text0_length, tokenized_text1, tokenized_text1_length])
|
136 |
+
|
137 |
+
text = text[block_size:]
|
138 |
+
# Note that we are loosing the last truncated example here for the sake of simplicity (no padding)
|
139 |
+
# If your dataset is small, first you should loook for a bigger one :-) and second you
|
140 |
+
# can change this behavior by adding (model specific) padding.
|
141 |
+
|
142 |
+
logger.info("Saving features into cached file %s", cached_features_file)
|
143 |
+
with open(cached_features_file, 'wb') as handle:
|
144 |
+
pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
145 |
+
|
146 |
+
def __len__(self):
|
147 |
+
return len(self.examples)
|
148 |
+
|
149 |
+
def __getitem__(self, item):
|
150 |
+
# pdb.set_trace()
|
151 |
+
# Convert to Tensors and build dataset
|
152 |
+
tokenized_text0= torch.tensor(self.examples[item][0], dtype=torch.long)
|
153 |
+
tokenized_text1= torch.tensor(self.examples[item][2], dtype=torch.long)
|
154 |
+
tokenized_text_lengths = torch.tensor([self.examples[item][1], self.examples[item][3]], dtype=torch.long)
|
155 |
+
# pdb.set_trace()
|
156 |
+
return (tokenized_text0, tokenized_text1, tokenized_text_lengths)
|
157 |
+
|
158 |
+
def load_and_cache_examples(args, tokenizer, evaluate=False):
|
159 |
+
if isinstance(tokenizer, list):
|
160 |
+
dataset = TextDataset_2Tokenizers(tokenizer, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
|
161 |
+
else:
|
162 |
+
dataset = TextDataset(tokenizer, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
|
163 |
+
return dataset
|
164 |
+
|
165 |
+
|
166 |
+
def set_seed(args):
|
167 |
+
random.seed(args.seed)
|
168 |
+
np.random.seed(args.seed)
|
169 |
+
torch.manual_seed(args.seed)
|
170 |
+
if args.n_gpu > 0:
|
171 |
+
torch.cuda.manual_seed_all(args.seed)
|
172 |
+
|
173 |
+
|
174 |
+
def mask_tokens(inputs, tokenizer, args):
|
175 |
+
""" Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
|
176 |
+
labels = inputs.clone()
|
177 |
+
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
|
178 |
+
|
179 |
+
masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).to(torch.uint8)
|
180 |
+
labels[masked_indices==1] = -1 # We only compute loss on masked tokens
|
181 |
+
|
182 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
183 |
+
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).to(torch.uint8) & masked_indices
|
184 |
+
inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
|
185 |
+
|
186 |
+
# 10% of the time, we replace masked input tokens with random word
|
187 |
+
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).to(torch.uint8) & masked_indices & ~indices_replaced
|
188 |
+
indices_random = indices_random
|
189 |
+
random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
|
190 |
+
inputs[indices_random] = random_words[indices_random]
|
191 |
+
|
192 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
193 |
+
return inputs, labels
|
194 |
+
|
195 |
+
|
196 |
+
def train(args, train_dataset, model_encoder, model_decoder, encoder_tokenizer, decoder_tokenizer):
|
197 |
+
""" Train the model """
|
198 |
+
if args.local_rank in [-1, 0]:
|
199 |
+
tb_writer = SummaryWriter()
|
200 |
+
|
201 |
+
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
202 |
+
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
203 |
+
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
204 |
+
|
205 |
+
if args.max_steps > 0:
|
206 |
+
t_total = args.max_steps
|
207 |
+
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
208 |
+
else:
|
209 |
+
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
210 |
+
|
211 |
+
# Prepare optimizer and schedule (linear warmup and decay)
|
212 |
+
no_decay = ['bias', 'LayerNorm.weight']
|
213 |
+
optimizer_grouped_encoder_parameters = [
|
214 |
+
{'params': [p for n, p in model_encoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
215 |
+
{'params': [p for n, p in model_encoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
216 |
+
]
|
217 |
+
|
218 |
+
optimizer_grouped_decoder_parameters = [
|
219 |
+
{'params': [p for n, p in model_decoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
220 |
+
{'params': [p for n, p in model_decoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
221 |
+
]
|
222 |
+
|
223 |
+
|
224 |
+
optimizer_encoder = AdamW(optimizer_grouped_encoder_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
225 |
+
optimizer_decoder = AdamW(optimizer_grouped_decoder_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
226 |
+
scheduler_encoder = WarmupLinearSchedule(optimizer_encoder, warmup_steps=args.warmup_steps, t_total=t_total)
|
227 |
+
scheduler_decoder = WarmupLinearSchedule(optimizer_decoder, warmup_steps=args.warmup_steps, t_total=t_total)
|
228 |
+
|
229 |
+
if args.fp16:
|
230 |
+
try:
|
231 |
+
from apex import amp
|
232 |
+
except ImportError:
|
233 |
+
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
234 |
+
model_encoder, optimizer_encoder = amp.initialize(model_encoder, optimizer_encoder, opt_level=args.fp16_opt_level)
|
235 |
+
model_decoder, optimizer_decoder = amp.initialize(model_decoder, optimizer_decoder, opt_level=args.fp16_opt_level)
|
236 |
+
|
237 |
+
# multi-gpu training (should be after apex fp16 initialization)
|
238 |
+
if args.n_gpu > 1:
|
239 |
+
model_encoder = torch.nn.DataParallel(model_encoder)
|
240 |
+
model_decoder = torch.nn.DataParallel(model_decoder)
|
241 |
+
|
242 |
+
# Distributed training (should be after apex fp16 initialization)
|
243 |
+
if args.local_rank != -1:
|
244 |
+
model_encoder = torch.nn.parallel.DistributedDataParallel(model_encoder, device_ids=[args.local_rank],
|
245 |
+
output_device=args.local_rank,
|
246 |
+
find_unused_parameters=True)
|
247 |
+
model_decoder = torch.nn.parallel.DistributedDataParallel(model_decoder, device_ids=[args.local_rank],
|
248 |
+
output_device=args.local_rank,
|
249 |
+
find_unused_parameters=True)
|
250 |
+
|
251 |
+
# Train!
|
252 |
+
logger.info("***** Running training *****")
|
253 |
+
logger.info(" Num examples = %d", len(train_dataset))
|
254 |
+
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
255 |
+
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
256 |
+
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
257 |
+
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
258 |
+
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
259 |
+
logger.info(" Total optimization steps = %d", t_total)
|
260 |
+
|
261 |
+
global_step = 0
|
262 |
+
tr_loss, logging_loss = 0.0, 0.0
|
263 |
+
model_encoder.zero_grad()
|
264 |
+
model_decoder.zero_grad()
|
265 |
+
|
266 |
+
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
267 |
+
set_seed(args) # Added here for reproducibility (even between python 2 and 3)
|
268 |
+
for _ in train_iterator:
|
269 |
+
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
270 |
+
for step, batch in enumerate(epoch_iterator):
|
271 |
+
|
272 |
+
tokenized_text0, tokenized_text1, tokenized_text_lengths = batch
|
273 |
+
# tokenized_text0 = tokenized_text0.to(args.device)
|
274 |
+
# tokenized_text1 = tokenized_text1.to(args.device)
|
275 |
+
# prepare input-output data for reconstruction
|
276 |
+
inputs, labels = mask_tokens(tokenized_text0, encoder_tokenizer, args) if args.mlm else (tokenized_text0, tokenized_text1)
|
277 |
+
labels = tokenized_text1
|
278 |
+
|
279 |
+
inputs = inputs.to(args.device)
|
280 |
+
labels = labels.to(args.device)
|
281 |
+
|
282 |
+
model_encoder.train()
|
283 |
+
model_decoder.train()
|
284 |
+
|
285 |
+
|
286 |
+
# Encoding
|
287 |
+
outputs = model_encoder(inputs)
|
288 |
+
pooled_hidden_fea = outputs[1] # model outputs are always tuple in pytorch-transformers (see doc)
|
289 |
+
|
290 |
+
|
291 |
+
# Decoding
|
292 |
+
outputs = model_decoder(input_ids=tokenized_text1, past=pooled_hidden_fea, labels=labels)
|
293 |
+
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
294 |
+
|
295 |
+
|
296 |
+
if args.n_gpu > 1:
|
297 |
+
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
298 |
+
if args.gradient_accumulation_steps > 1:
|
299 |
+
loss = loss / args.gradient_accumulation_steps
|
300 |
+
|
301 |
+
if args.fp16:
|
302 |
+
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
303 |
+
scaled_loss.backward()
|
304 |
+
else:
|
305 |
+
loss.backward()
|
306 |
+
|
307 |
+
tr_loss += loss.item()
|
308 |
+
if (step + 1) % args.gradient_accumulation_steps == 0:
|
309 |
+
if args.fp16:
|
310 |
+
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer_encoder), args.max_grad_norm)
|
311 |
+
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer_decoder), args.max_grad_norm)
|
312 |
+
else:
|
313 |
+
torch.nn.utils.clip_grad_norm_(model_encoder.parameters(), args.max_grad_norm)
|
314 |
+
torch.nn.utils.clip_grad_norm_(model_decoder.parameters(), args.max_grad_norm)
|
315 |
+
optimizer_encoder.step()
|
316 |
+
optimizer_decoder.step()
|
317 |
+
scheduler_encoder.step() # Update learning rate schedule
|
318 |
+
scheduler_decoder.step()
|
319 |
+
model_encoder.zero_grad()
|
320 |
+
model_decoder.zero_grad()
|
321 |
+
global_step += 1
|
322 |
+
|
323 |
+
|
324 |
+
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
325 |
+
# Log metrics
|
326 |
+
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
327 |
+
results = evaluate(args, model_encoder, model_decoder, encoder_tokenizer, decoder_tokenizer)
|
328 |
+
for key, value in results.items():
|
329 |
+
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
330 |
+
tb_writer.add_scalar('lr_encoder', scheduler_encoder.get_lr()[0], global_step)
|
331 |
+
tb_writer.add_scalar('lr_decoder', scheduler_decoder.get_lr()[0], global_step)
|
332 |
+
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
333 |
+
logging_loss = tr_loss
|
334 |
+
|
335 |
+
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
336 |
+
# Save model checkpoint
|
337 |
+
output_encoder_dir = os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(global_step))
|
338 |
+
output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
|
339 |
+
if not os.path.exists(output_encoder_dir):
|
340 |
+
os.makedirs(output_encoder_dir)
|
341 |
+
if not os.path.exists(output_decoder_dir):
|
342 |
+
os.makedirs(output_decoder_dir)
|
343 |
+
|
344 |
+
model_encoder_to_save = model_encoder.module if hasattr(model_encoder, 'module') else model_encoder # Take care of distributed/parallel training
|
345 |
+
model_decoder_to_save = model_decoder.module if hasattr(model_decoder, 'module') else model_decoder # Take care of distributed/parallel training
|
346 |
+
|
347 |
+
model_encoder_to_save.save_pretrained(output_encoder_dir)
|
348 |
+
torch.save(args, os.path.join(output_encoder_dir, 'training_encoder_args.bin'))
|
349 |
+
|
350 |
+
model_decoder_to_save.save_pretrained(output_decoder_dir)
|
351 |
+
torch.save(args, os.path.join(output_decoder_dir, 'training_decoder_args.bin'))
|
352 |
+
|
353 |
+
logger.info("Saving model checkpoint to %s", output_encoder_dir)
|
354 |
+
logger.info("Saving model checkpoint to %s", output_decoder_dir)
|
355 |
+
|
356 |
+
if args.max_steps > 0 and global_step > args.max_steps:
|
357 |
+
epoch_iterator.close()
|
358 |
+
break
|
359 |
+
if args.max_steps > 0 and global_step > args.max_steps:
|
360 |
+
train_iterator.close()
|
361 |
+
break
|
362 |
+
|
363 |
+
if args.local_rank in [-1, 0]:
|
364 |
+
tb_writer.close()
|
365 |
+
|
366 |
+
return global_step, tr_loss / global_step
|
367 |
+
|
368 |
+
|
369 |
+
def evaluate(args, model_encoder, model_decoder, encoder_tokenizer, decoder_tokenizer, prefix=""):
|
370 |
+
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
371 |
+
eval_output_dir = args.output_dir
|
372 |
+
|
373 |
+
eval_dataset = load_and_cache_examples(args, [encoder_tokenizer, decoder_tokenizer], evaluate=True)
|
374 |
+
|
375 |
+
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
376 |
+
os.makedirs(eval_output_dir)
|
377 |
+
|
378 |
+
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
379 |
+
# Note that DistributedSampler samples randomly
|
380 |
+
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
|
381 |
+
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
382 |
+
|
383 |
+
# Eval!
|
384 |
+
logger.info("***** Running evaluation {} *****".format(prefix))
|
385 |
+
logger.info(" Num examples = %d", len(eval_dataset))
|
386 |
+
logger.info(" Batch size = %d", args.eval_batch_size)
|
387 |
+
eval_loss = 0.0
|
388 |
+
nb_eval_steps = 0
|
389 |
+
model_encoder.eval()
|
390 |
+
model_decoder.eval()
|
391 |
+
|
392 |
+
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
393 |
+
# pdb.set_trace()
|
394 |
+
tokenized_text0, tokenized_text1, tokenized_text_lengths = batch
|
395 |
+
# prepare input-output data for evaluation
|
396 |
+
inputs, labels = tokenized_text0, tokenized_text1
|
397 |
+
|
398 |
+
tokenized_text1 = tokenized_text1.to(args.device)
|
399 |
+
inputs = inputs.to(args.device)
|
400 |
+
labels = labels.to(args.device)
|
401 |
+
|
402 |
+
with torch.no_grad():
|
403 |
+
# Encoding
|
404 |
+
outputs = model_encoder(inputs)
|
405 |
+
pooled_hidden_fea = outputs[1] # model outputs are always tuple in pytorch-transformers (see doc)
|
406 |
+
|
407 |
+
# Decoding
|
408 |
+
outputs = model_decoder(input_ids=tokenized_text1, past=pooled_hidden_fea, labels=labels)
|
409 |
+
lm_loss = outputs[0]
|
410 |
+
|
411 |
+
eval_loss += lm_loss.mean().item()
|
412 |
+
nb_eval_steps += 1
|
413 |
+
|
414 |
+
eval_loss = eval_loss / nb_eval_steps
|
415 |
+
perplexity = torch.exp(torch.tensor(eval_loss))
|
416 |
+
|
417 |
+
result = {
|
418 |
+
"perplexity": perplexity
|
419 |
+
}
|
420 |
+
|
421 |
+
output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
|
422 |
+
with open(output_eval_file, "w") as writer:
|
423 |
+
logger.info("***** Eval results {} *****".format(prefix))
|
424 |
+
for key in sorted(result.keys()):
|
425 |
+
logger.info(" %s = %s", key, str(result[key]))
|
426 |
+
writer.write("%s = %s\n" % (key, str(result[key])))
|
427 |
+
|
428 |
+
return result
|
429 |
+
|
430 |
+
|
431 |
+
def main():
|
432 |
+
parser = argparse.ArgumentParser()
|
433 |
+
|
434 |
+
## Required parameters
|
435 |
+
parser.add_argument("--train_data_file", default=None, type=str, required=True,
|
436 |
+
help="The input training data file (a text file).")
|
437 |
+
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
438 |
+
help="The output directory where the model predictions and checkpoints will be written.")
|
439 |
+
|
440 |
+
## Other parameters
|
441 |
+
parser.add_argument("--eval_data_file", default=None, type=str,
|
442 |
+
help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
|
443 |
+
|
444 |
+
## Encoder options
|
445 |
+
parser.add_argument("--encoder_model_type", default="bert", type=str,
|
446 |
+
help="The encoder model architecture to be fine-tuned.")
|
447 |
+
parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
|
448 |
+
help="The encoder model checkpoint for weights initialization.")
|
449 |
+
parser.add_argument("--encoder_config_name", default="", type=str,
|
450 |
+
help="Optional pretrained config name or path if not the same as model_name_or_path")
|
451 |
+
parser.add_argument("--encoder_tokenizer_name", default="", type=str,
|
452 |
+
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
|
453 |
+
|
454 |
+
## Decoder options
|
455 |
+
parser.add_argument("--decoder_model_type", default="gpt2", type=str,
|
456 |
+
help="The decoder model architecture to be fine-tuned.")
|
457 |
+
parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
|
458 |
+
help="The decoder model checkpoint for weights initialization.")
|
459 |
+
parser.add_argument("--decoder_config_name", default="", type=str,
|
460 |
+
help="Optional pretrained config name or path if not the same as model_name_or_path")
|
461 |
+
parser.add_argument("--decoder_tokenizer_name", default="", type=str,
|
462 |
+
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
|
463 |
+
|
464 |
+
## Objective functions
|
465 |
+
parser.add_argument("--mlm", action='store_true',
|
466 |
+
help="Train with masked-language modeling loss instead of language modeling.")
|
467 |
+
parser.add_argument("--mlm_probability", type=float, default=0.15,
|
468 |
+
help="Ratio of tokens to mask for masked language modeling loss")
|
469 |
+
|
470 |
+
|
471 |
+
|
472 |
+
parser.add_argument("--cache_dir", default="", type=str,
|
473 |
+
help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
|
474 |
+
parser.add_argument("--block_size", default=-1, type=int,
|
475 |
+
help="Optional input sequence length after tokenization."
|
476 |
+
"The training dataset will be truncated in block of this size for training."
|
477 |
+
"Default to the model max input length for single sentence inputs (take into account special tokens).")
|
478 |
+
parser.add_argument("--do_train", action='store_true',
|
479 |
+
help="Whether to run training.")
|
480 |
+
parser.add_argument("--do_eval", action='store_true',
|
481 |
+
help="Whether to run eval on the dev set.")
|
482 |
+
parser.add_argument("--evaluate_during_training", action='store_true',
|
483 |
+
help="Run evaluation during training at each logging step.")
|
484 |
+
parser.add_argument("--do_lower_case", action='store_true',
|
485 |
+
help="Set this flag if you are using an uncased model.")
|
486 |
+
|
487 |
+
|
488 |
+
# Training Schedules
|
489 |
+
parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
|
490 |
+
help="Batch size per GPU/CPU for training.")
|
491 |
+
parser.add_argument("--per_gpu_eval_batch_size", default=4, type=int,
|
492 |
+
help="Batch size per GPU/CPU for evaluation.")
|
493 |
+
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
494 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
495 |
+
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
496 |
+
help="The initial learning rate for Adam.")
|
497 |
+
parser.add_argument("--weight_decay", default=0.0, type=float,
|
498 |
+
help="Weight deay if we apply some.")
|
499 |
+
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
500 |
+
help="Epsilon for Adam optimizer.")
|
501 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
502 |
+
help="Max gradient norm.")
|
503 |
+
parser.add_argument("--num_train_epochs", default=1.0, type=float,
|
504 |
+
help="Total number of training epochs to perform.")
|
505 |
+
parser.add_argument("--max_steps", default=-1, type=int,
|
506 |
+
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
507 |
+
parser.add_argument("--warmup_steps", default=0, type=int,
|
508 |
+
help="Linear warmup over warmup_steps.")
|
509 |
+
|
510 |
+
|
511 |
+
## IO: Logging and Saving
|
512 |
+
parser.add_argument('--logging_steps', type=int, default=50,
|
513 |
+
help="Log every X updates steps.")
|
514 |
+
parser.add_argument('--save_steps', type=int, default=50,
|
515 |
+
help="Save checkpoint every X updates steps.")
|
516 |
+
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
517 |
+
help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
|
518 |
+
parser.add_argument("--no_cuda", action='store_true',
|
519 |
+
help="Avoid using CUDA when available")
|
520 |
+
parser.add_argument('--overwrite_output_dir', action='store_true',
|
521 |
+
help="Overwrite the content of the output directory")
|
522 |
+
parser.add_argument('--overwrite_cache', action='store_true',
|
523 |
+
help="Overwrite the cached training and evaluation sets")
|
524 |
+
parser.add_argument('--seed', type=int, default=42,
|
525 |
+
help="random seed for initialization")
|
526 |
+
|
527 |
+
# Precision & Distributed Training
|
528 |
+
parser.add_argument('--fp16', action='store_true',
|
529 |
+
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
530 |
+
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
531 |
+
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
532 |
+
"See details at https://nvidia.github.io/apex/amp.html")
|
533 |
+
parser.add_argument("--local_rank", type=int, default=-1,
|
534 |
+
help="For distributed training: local_rank")
|
535 |
+
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
|
536 |
+
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
537 |
+
args = parser.parse_args()
|
538 |
+
|
539 |
+
if args.decoder_model_type in ["bert", "roberta"] and not args.mlm:
|
540 |
+
raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
|
541 |
+
"flag (masked language modeling).")
|
542 |
+
if args.eval_data_file is None and args.do_eval:
|
543 |
+
raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
|
544 |
+
"or remove the --do_eval argument.")
|
545 |
+
|
546 |
+
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
547 |
+
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
548 |
+
|
549 |
+
# Setup distant debugging if needed
|
550 |
+
if args.server_ip and args.server_port:
|
551 |
+
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
552 |
+
import ptvsd
|
553 |
+
print("Waiting for debugger attach")
|
554 |
+
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
555 |
+
ptvsd.wait_for_attach()
|
556 |
+
|
557 |
+
# Setup CUDA, GPU & distributed training
|
558 |
+
if args.local_rank == -1 or args.no_cuda:
|
559 |
+
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
560 |
+
args.n_gpu = torch.cuda.device_count()
|
561 |
+
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
562 |
+
torch.cuda.set_device(args.local_rank)
|
563 |
+
device = torch.device("cuda", args.local_rank)
|
564 |
+
torch.distributed.init_process_group(backend='nccl')
|
565 |
+
args.n_gpu = 1
|
566 |
+
args.device = device
|
567 |
+
|
568 |
+
# Setup logging
|
569 |
+
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
570 |
+
datefmt = '%m/%d/%Y %H:%M:%S',
|
571 |
+
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
572 |
+
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
573 |
+
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
574 |
+
|
575 |
+
# Set seed
|
576 |
+
set_seed(args)
|
577 |
+
|
578 |
+
# Load pretrained model and tokenizer
|
579 |
+
if args.local_rank not in [-1, 0]:
|
580 |
+
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
|
581 |
+
|
582 |
+
## Encoder
|
583 |
+
encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
|
584 |
+
encoder_config = encoder_config_class.from_pretrained(args.encoder_config_name if args.encoder_config_name else args.encoder_model_name_or_path)
|
585 |
+
tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)
|
586 |
+
if args.block_size <= 0:
|
587 |
+
args.block_size = tokenizer_encoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
588 |
+
args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
|
589 |
+
model_encoder = encoder_model_class.from_pretrained(args.encoder_model_name_or_path, from_tf=bool('.ckpt' in args.encoder_model_name_or_path), config=encoder_config)
|
590 |
+
model_encoder.to(args.device)
|
591 |
+
|
592 |
+
## Decoder
|
593 |
+
decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
|
594 |
+
decoder_config = decoder_config_class.from_pretrained(args.decoder_config_name if args.decoder_config_name else args.decoder_model_name_or_path)
|
595 |
+
tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
|
596 |
+
if args.block_size <= 0:
|
597 |
+
args.block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
598 |
+
args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
|
599 |
+
model_decoder = decoder_model_class.from_pretrained(args.decoder_model_name_or_path, from_tf=bool('.ckpt' in args.decoder_model_name_or_path), config=decoder_config)
|
600 |
+
|
601 |
+
# Chunyuan: Add Padding token to GPT2
|
602 |
+
special_tokens_dict = {'pad_token': '<PAD>'}
|
603 |
+
num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
|
604 |
+
print('We have added', num_added_toks, 'tokens')
|
605 |
+
model_decoder.resize_token_embeddings(len(tokenizer_decoder)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
|
606 |
+
assert tokenizer_decoder.pad_token == '<PAD>'
|
607 |
+
|
608 |
+
model_decoder.to(args.device)
|
609 |
+
|
610 |
+
if args.local_rank == 0:
|
611 |
+
torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training download model & vocab
|
612 |
+
|
613 |
+
logger.info("Training/evaluation parameters %s", args)
|
614 |
+
|
615 |
+
global_step= 0
|
616 |
+
# Training
|
617 |
+
if args.do_train:
|
618 |
+
if args.local_rank not in [-1, 0]:
|
619 |
+
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache
|
620 |
+
|
621 |
+
train_dataset = load_and_cache_examples(args, [tokenizer_encoder, tokenizer_decoder], evaluate=False)
|
622 |
+
|
623 |
+
if args.local_rank == 0:
|
624 |
+
torch.distributed.barrier()
|
625 |
+
|
626 |
+
global_step, tr_loss = train(args, train_dataset, model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder)
|
627 |
+
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
628 |
+
|
629 |
+
|
630 |
+
# Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
|
631 |
+
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
632 |
+
# Create output directory if needed
|
633 |
+
# Save model checkpoint
|
634 |
+
output_encoder_dir = os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(global_step))
|
635 |
+
output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
|
636 |
+
if not os.path.exists(output_encoder_dir) and args.local_rank in [-1, 0]:
|
637 |
+
os.makedirs(output_encoder_dir)
|
638 |
+
if not os.path.exists(output_decoder_dir) and args.local_rank in [-1, 0]:
|
639 |
+
os.makedirs(output_decoder_dir)
|
640 |
+
|
641 |
+
logger.info("Saving encoder model checkpoint to %s", output_encoder_dir)
|
642 |
+
logger.info("Saving decoder model checkpoint to %s", output_decoder_dir)
|
643 |
+
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
644 |
+
# They can then be reloaded using `from_pretrained()`
|
645 |
+
|
646 |
+
model_encoder_to_save = model_encoder.module if hasattr(model_encoder, 'module') else model_encoder # Take care of distributed/parallel training
|
647 |
+
model_decoder_to_save = model_decoder.module if hasattr(model_decoder, 'module') else model_decoder # Take care of distributed/parallel training
|
648 |
+
|
649 |
+
# Good practice: save your training arguments together with the trained model
|
650 |
+
model_encoder_to_save.save_pretrained(output_encoder_dir)
|
651 |
+
torch.save(args, os.path.join(output_encoder_dir, 'training_encoder_args.bin'))
|
652 |
+
|
653 |
+
model_decoder_to_save.save_pretrained(output_decoder_dir)
|
654 |
+
torch.save(args, os.path.join(output_decoder_dir, 'training_decoder_args.bin'))
|
655 |
+
|
656 |
+
|
657 |
+
# Load a trained model and vocabulary that you have fine-tuned
|
658 |
+
model_encoder = encoder_model_class.from_pretrained(output_encoder_dir)
|
659 |
+
tokenizer_encoder = encoder_tokenizer_class.from_pretrained(output_encoder_dir, do_lower_case=args.do_lower_case)
|
660 |
+
model_encoder.to(args.device)
|
661 |
+
|
662 |
+
# Load a trained model and vocabulary that you have fine-tuned
|
663 |
+
model_decoder = decoder_model_class.from_pretrained(output_decoder_dir)
|
664 |
+
tokenizer_decoder = decoder_tokenizer_class.from_pretrained(output_decoder_dir, do_lower_case=args.do_lower_case)
|
665 |
+
model_decoder.to(args.device)
|
666 |
+
|
667 |
+
|
668 |
+
# Evaluation
|
669 |
+
results = {}
|
670 |
+
if args.do_eval and args.local_rank in [-1, 0]:
|
671 |
+
global_step= 881
|
672 |
+
output_encoder_dir = os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(global_step))
|
673 |
+
output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
|
674 |
+
checkpoints = [ [output_encoder_dir, output_decoder_dir] ]
|
675 |
+
|
676 |
+
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
677 |
+
for checkpoint in checkpoints:
|
678 |
+
global_step = checkpoint[0].split('-')[-1] if len(checkpoints) > 1 else ""
|
679 |
+
|
680 |
+
model_encoder = encoder_model_class.from_pretrained(checkpoint[0])
|
681 |
+
model_encoder.to(args.device)
|
682 |
+
model_decoder = decoder_model_class.from_pretrained(checkpoint[1])
|
683 |
+
model_decoder.to(args.device)
|
684 |
+
result = evaluate(args, model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, prefix=global_step)
|
685 |
+
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
686 |
+
results.update(result)
|
687 |
+
|
688 |
+
return results
|
689 |
+
|
690 |
+
|
691 |
+
if __name__ == "__main__":
|
692 |
+
main()
|
Optimus/code/examples/big_ae/run_lm_causal_pretraining.py
ADDED
@@ -0,0 +1,692 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
|
18 |
+
GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
|
19 |
+
using a masked language modeling (MLM) loss.
|
20 |
+
"""
|
21 |
+
|
22 |
+
from __future__ import absolute_import, division, print_function
|
23 |
+
|
24 |
+
|
25 |
+
import pdb
|
26 |
+
import argparse
|
27 |
+
import glob
|
28 |
+
import logging
|
29 |
+
import os
|
30 |
+
import pickle
|
31 |
+
import random
|
32 |
+
|
33 |
+
import numpy as np
|
34 |
+
import torch
|
35 |
+
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
|
36 |
+
from torch.utils.data.distributed import DistributedSampler
|
37 |
+
from tensorboardX import SummaryWriter
|
38 |
+
from tqdm import tqdm, trange
|
39 |
+
|
40 |
+
from pytorch_transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
|
41 |
+
BertConfig, BertModel, BertTokenizer,
|
42 |
+
GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
|
43 |
+
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
|
44 |
+
RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
|
45 |
+
|
46 |
+
|
47 |
+
logger = logging.getLogger(__name__)
|
48 |
+
|
49 |
+
|
50 |
+
MODEL_CLASSES = {
|
51 |
+
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
|
52 |
+
'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
|
53 |
+
'bert': (BertConfig, BertModel, BertTokenizer),
|
54 |
+
'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
|
55 |
+
}
|
56 |
+
|
57 |
+
|
58 |
+
class TextDataset(Dataset):
|
59 |
+
def __init__(self, tokenizer, file_path='train', block_size=512):
|
60 |
+
assert os.path.isfile(file_path)
|
61 |
+
directory, filename = os.path.split(file_path)
|
62 |
+
cached_features_file = os.path.join(directory, f'cached_lm_{block_size}_{filename}')
|
63 |
+
|
64 |
+
if os.path.exists(cached_features_file):
|
65 |
+
logger.info("Loading features from cached file %s", cached_features_file)
|
66 |
+
with open(cached_features_file, 'rb') as handle:
|
67 |
+
self.examples = pickle.load(handle)
|
68 |
+
else:
|
69 |
+
logger.info("Creating features from dataset file at %s", directory)
|
70 |
+
|
71 |
+
self.examples = []
|
72 |
+
with open(file_path, encoding="utf-8") as f:
|
73 |
+
text = f.read()
|
74 |
+
|
75 |
+
|
76 |
+
tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
|
77 |
+
|
78 |
+
while len(tokenized_text) >= block_size: # Truncate in block of block_size
|
79 |
+
self.examples.append(tokenizer.add_special_tokens_single_sentence(tokenized_text[:block_size]))
|
80 |
+
tokenized_text = tokenized_text[block_size:]
|
81 |
+
# Note that we are loosing the last truncated example here for the sake of simplicity (no padding)
|
82 |
+
# If your dataset is small, first you should loook for a bigger one :-) and second you
|
83 |
+
# can change this behavior by adding (model specific) padding.
|
84 |
+
|
85 |
+
logger.info("Saving features into cached file %s", cached_features_file)
|
86 |
+
with open(cached_features_file, 'wb') as handle:
|
87 |
+
pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
88 |
+
|
89 |
+
def __len__(self):
|
90 |
+
return len(self.examples)
|
91 |
+
|
92 |
+
def __getitem__(self, item):
|
93 |
+
return torch.tensor(self.examples[item])
|
94 |
+
|
95 |
+
|
96 |
+
|
97 |
+
class TextDataset_2Tokenizers(Dataset):
|
98 |
+
def __init__(self, tokenizers, file_path='train', block_size=512):
|
99 |
+
assert os.path.isfile(file_path)
|
100 |
+
directory, filename = os.path.split(file_path)
|
101 |
+
cached_features_file = os.path.join(directory, f'cached_lm_gpt_bert_{block_size}_{filename}')
|
102 |
+
|
103 |
+
|
104 |
+
|
105 |
+
if os.path.exists(cached_features_file):
|
106 |
+
logger.info("Loading features from cached file %s", cached_features_file)
|
107 |
+
with open(cached_features_file, 'rb') as handle:
|
108 |
+
self.examples = pickle.load(handle)
|
109 |
+
else:
|
110 |
+
logger.info("Creating features from dataset file at %s", directory)
|
111 |
+
|
112 |
+
|
113 |
+
with open(file_path, encoding="utf-8") as f:
|
114 |
+
text = f.read()
|
115 |
+
|
116 |
+
# pdb.set_trace()
|
117 |
+
self.examples = []
|
118 |
+
# Chunyuan: divide the linguistic text into the same length, then different tokenization schemes are applied
|
119 |
+
while len(text) >= block_size: # Truncate in block of block_size
|
120 |
+
|
121 |
+
tokenized_text0 = tokenizers[0].convert_tokens_to_ids(tokenizers[0].tokenize(text[:block_size]))
|
122 |
+
tokenized_text0 = tokenizers[0].add_special_tokens_single_sentence(tokenized_text0)
|
123 |
+
tokenized_text0_length = len(tokenized_text0)
|
124 |
+
pad_token=tokenizers[0].convert_tokens_to_ids([tokenizers[0].pad_token])[0]
|
125 |
+
tokenized_text0 = tokenized_text0 + ([pad_token] * (block_size - tokenized_text0_length) ) # Pad up to the sequence length.
|
126 |
+
assert len(tokenized_text0) == block_size
|
127 |
+
|
128 |
+
tokenized_text1 = tokenizers[1].convert_tokens_to_ids(tokenizers[1].tokenize(text[:block_size]))
|
129 |
+
tokenized_text1 = tokenizers[1].add_special_tokens_single_sentence(tokenized_text1)
|
130 |
+
tokenized_text1_length = len(tokenized_text1)
|
131 |
+
pad_token=tokenizers[1].convert_tokens_to_ids([tokenizers[1].pad_token])[0]
|
132 |
+
tokenized_text1 = tokenized_text1 + ([pad_token] * (block_size - tokenized_text1_length) ) # Pad up to the sequence length.
|
133 |
+
assert len(tokenized_text1) == block_size
|
134 |
+
|
135 |
+
self.examples.append([tokenized_text0, tokenized_text0_length, tokenized_text1, tokenized_text1_length])
|
136 |
+
|
137 |
+
text = text[block_size:]
|
138 |
+
# Note that we are loosing the last truncated example here for the sake of simplicity (no padding)
|
139 |
+
# If your dataset is small, first you should loook for a bigger one :-) and second you
|
140 |
+
# can change this behavior by adding (model specific) padding.
|
141 |
+
|
142 |
+
logger.info("Saving features into cached file %s", cached_features_file)
|
143 |
+
with open(cached_features_file, 'wb') as handle:
|
144 |
+
pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
145 |
+
|
146 |
+
def __len__(self):
|
147 |
+
return len(self.examples)
|
148 |
+
|
149 |
+
def __getitem__(self, item):
|
150 |
+
# pdb.set_trace()
|
151 |
+
# Convert to Tensors and build dataset
|
152 |
+
tokenized_text0= torch.tensor(self.examples[item][0], dtype=torch.long)
|
153 |
+
tokenized_text1= torch.tensor(self.examples[item][2], dtype=torch.long)
|
154 |
+
tokenized_text_lengths = torch.tensor([self.examples[item][1], self.examples[item][3]], dtype=torch.long)
|
155 |
+
# pdb.set_trace()
|
156 |
+
return (tokenized_text0, tokenized_text1, tokenized_text_lengths)
|
157 |
+
|
158 |
+
def load_and_cache_examples(args, tokenizer, evaluate=False):
|
159 |
+
if isinstance(tokenizer, list):
|
160 |
+
dataset = TextDataset_2Tokenizers(tokenizer, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
|
161 |
+
else:
|
162 |
+
dataset = TextDataset(tokenizer, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
|
163 |
+
return dataset
|
164 |
+
|
165 |
+
|
166 |
+
def set_seed(args):
|
167 |
+
random.seed(args.seed)
|
168 |
+
np.random.seed(args.seed)
|
169 |
+
torch.manual_seed(args.seed)
|
170 |
+
if args.n_gpu > 0:
|
171 |
+
torch.cuda.manual_seed_all(args.seed)
|
172 |
+
|
173 |
+
|
174 |
+
def mask_tokens(inputs, tokenizer, args):
|
175 |
+
""" Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
|
176 |
+
labels = inputs.clone()
|
177 |
+
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
|
178 |
+
|
179 |
+
masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).to(torch.uint8)
|
180 |
+
labels[masked_indices==1] = -1 # We only compute loss on masked tokens
|
181 |
+
|
182 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
183 |
+
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).to(torch.uint8) & masked_indices
|
184 |
+
inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
|
185 |
+
|
186 |
+
# 10% of the time, we replace masked input tokens with random word
|
187 |
+
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).to(torch.uint8) & masked_indices & ~indices_replaced
|
188 |
+
indices_random = indices_random
|
189 |
+
random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
|
190 |
+
inputs[indices_random] = random_words[indices_random]
|
191 |
+
|
192 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
193 |
+
return inputs, labels
|
194 |
+
|
195 |
+
|
196 |
+
def train(args, train_dataset, model_encoder, model_decoder, encoder_tokenizer, decoder_tokenizer):
|
197 |
+
""" Train the model """
|
198 |
+
if args.local_rank in [-1, 0]:
|
199 |
+
tb_writer = SummaryWriter()
|
200 |
+
|
201 |
+
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
202 |
+
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
203 |
+
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
204 |
+
|
205 |
+
if args.max_steps > 0:
|
206 |
+
t_total = args.max_steps
|
207 |
+
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
208 |
+
else:
|
209 |
+
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
210 |
+
|
211 |
+
# Prepare optimizer and schedule (linear warmup and decay)
|
212 |
+
no_decay = ['bias', 'LayerNorm.weight']
|
213 |
+
optimizer_grouped_encoder_parameters = [
|
214 |
+
{'params': [p for n, p in model_encoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
215 |
+
{'params': [p for n, p in model_encoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
216 |
+
]
|
217 |
+
|
218 |
+
optimizer_grouped_decoder_parameters = [
|
219 |
+
{'params': [p for n, p in model_decoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
220 |
+
{'params': [p for n, p in model_decoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
221 |
+
]
|
222 |
+
|
223 |
+
|
224 |
+
optimizer_encoder = AdamW(optimizer_grouped_encoder_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
225 |
+
optimizer_decoder = AdamW(optimizer_grouped_decoder_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
226 |
+
scheduler_encoder = WarmupLinearSchedule(optimizer_encoder, warmup_steps=args.warmup_steps, t_total=t_total)
|
227 |
+
scheduler_decoder = WarmupLinearSchedule(optimizer_decoder, warmup_steps=args.warmup_steps, t_total=t_total)
|
228 |
+
|
229 |
+
if args.fp16:
|
230 |
+
try:
|
231 |
+
from apex import amp
|
232 |
+
except ImportError:
|
233 |
+
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
234 |
+
model_encoder, optimizer_encoder = amp.initialize(model_encoder, optimizer_encoder, opt_level=args.fp16_opt_level)
|
235 |
+
model_decoder, optimizer_decoder = amp.initialize(model_decoder, optimizer_decoder, opt_level=args.fp16_opt_level)
|
236 |
+
|
237 |
+
# multi-gpu training (should be after apex fp16 initialization)
|
238 |
+
if args.n_gpu > 1:
|
239 |
+
model_encoder = torch.nn.DataParallel(model_encoder)
|
240 |
+
model_decoder = torch.nn.DataParallel(model_decoder)
|
241 |
+
|
242 |
+
# Distributed training (should be after apex fp16 initialization)
|
243 |
+
if args.local_rank != -1:
|
244 |
+
model_encoder = torch.nn.parallel.DistributedDataParallel(model_encoder, device_ids=[args.local_rank],
|
245 |
+
output_device=args.local_rank,
|
246 |
+
find_unused_parameters=True)
|
247 |
+
model_decoder = torch.nn.parallel.DistributedDataParallel(model_decoder, device_ids=[args.local_rank],
|
248 |
+
output_device=args.local_rank,
|
249 |
+
find_unused_parameters=True)
|
250 |
+
|
251 |
+
# Train!
|
252 |
+
logger.info("***** Running training *****")
|
253 |
+
logger.info(" Num examples = %d", len(train_dataset))
|
254 |
+
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
255 |
+
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
256 |
+
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
257 |
+
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
258 |
+
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
259 |
+
logger.info(" Total optimization steps = %d", t_total)
|
260 |
+
|
261 |
+
global_step = 0
|
262 |
+
tr_loss, logging_loss = 0.0, 0.0
|
263 |
+
model_encoder.zero_grad()
|
264 |
+
model_decoder.zero_grad()
|
265 |
+
|
266 |
+
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
267 |
+
set_seed(args) # Added here for reproducibility (even between python 2 and 3)
|
268 |
+
for _ in train_iterator:
|
269 |
+
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
270 |
+
for step, batch in enumerate(epoch_iterator):
|
271 |
+
|
272 |
+
tokenized_text0, tokenized_text1, tokenized_text_lengths = batch
|
273 |
+
# tokenized_text0 = tokenized_text0.to(args.device)
|
274 |
+
# tokenized_text1 = tokenized_text1.to(args.device)
|
275 |
+
# prepare input-output data for reconstruction
|
276 |
+
inputs, labels = mask_tokens(tokenized_text0, encoder_tokenizer, args) if args.mlm else (tokenized_text0, tokenized_text1)
|
277 |
+
labels = tokenized_text1
|
278 |
+
|
279 |
+
inputs = inputs.to(args.device)
|
280 |
+
labels = labels.to(args.device)
|
281 |
+
|
282 |
+
model_encoder.train()
|
283 |
+
model_decoder.train()
|
284 |
+
|
285 |
+
|
286 |
+
# Encoding
|
287 |
+
outputs = model_encoder(inputs)
|
288 |
+
pooled_hidden_fea = outputs[1] # model outputs are always tuple in pytorch-transformers (see doc)
|
289 |
+
|
290 |
+
|
291 |
+
# Decoding
|
292 |
+
outputs = model_decoder(input_ids=tokenized_text1, past=None, labels=labels)
|
293 |
+
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
294 |
+
|
295 |
+
|
296 |
+
if args.n_gpu > 1:
|
297 |
+
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
298 |
+
if args.gradient_accumulation_steps > 1:
|
299 |
+
loss = loss / args.gradient_accumulation_steps
|
300 |
+
|
301 |
+
if args.fp16:
|
302 |
+
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
303 |
+
scaled_loss.backward()
|
304 |
+
else:
|
305 |
+
loss.backward()
|
306 |
+
|
307 |
+
tr_loss += loss.item()
|
308 |
+
if (step + 1) % args.gradient_accumulation_steps == 0:
|
309 |
+
if args.fp16:
|
310 |
+
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer_encoder), args.max_grad_norm)
|
311 |
+
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer_decoder), args.max_grad_norm)
|
312 |
+
else:
|
313 |
+
torch.nn.utils.clip_grad_norm_(model_encoder.parameters(), args.max_grad_norm)
|
314 |
+
torch.nn.utils.clip_grad_norm_(model_decoder.parameters(), args.max_grad_norm)
|
315 |
+
optimizer_encoder.step()
|
316 |
+
optimizer_decoder.step()
|
317 |
+
scheduler_encoder.step() # Update learning rate schedule
|
318 |
+
scheduler_decoder.step()
|
319 |
+
model_encoder.zero_grad()
|
320 |
+
model_decoder.zero_grad()
|
321 |
+
global_step += 1
|
322 |
+
|
323 |
+
|
324 |
+
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
325 |
+
# Log metrics
|
326 |
+
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
327 |
+
results = evaluate(args, model_encoder, model_decoder, encoder_tokenizer, decoder_tokenizer)
|
328 |
+
for key, value in results.items():
|
329 |
+
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
330 |
+
tb_writer.add_scalar('lr_encoder', scheduler_encoder.get_lr()[0], global_step)
|
331 |
+
tb_writer.add_scalar('lr_decoder', scheduler_decoder.get_lr()[0], global_step)
|
332 |
+
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
333 |
+
logging_loss = tr_loss
|
334 |
+
|
335 |
+
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
336 |
+
# Save model checkpoint
|
337 |
+
output_encoder_dir = os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(global_step))
|
338 |
+
output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
|
339 |
+
if not os.path.exists(output_encoder_dir):
|
340 |
+
os.makedirs(output_encoder_dir)
|
341 |
+
if not os.path.exists(output_decoder_dir):
|
342 |
+
os.makedirs(output_decoder_dir)
|
343 |
+
|
344 |
+
model_encoder_to_save = model_encoder.module if hasattr(model_encoder, 'module') else model_encoder # Take care of distributed/parallel training
|
345 |
+
model_decoder_to_save = model_decoder.module if hasattr(model_decoder, 'module') else model_decoder # Take care of distributed/parallel training
|
346 |
+
|
347 |
+
model_encoder_to_save.save_pretrained(output_encoder_dir)
|
348 |
+
torch.save(args, os.path.join(output_encoder_dir, 'training_encoder_args.bin'))
|
349 |
+
|
350 |
+
model_decoder_to_save.save_pretrained(output_decoder_dir)
|
351 |
+
torch.save(args, os.path.join(output_decoder_dir, 'training_decoder_args.bin'))
|
352 |
+
|
353 |
+
logger.info("Saving model checkpoint to %s", output_encoder_dir)
|
354 |
+
logger.info("Saving model checkpoint to %s", output_decoder_dir)
|
355 |
+
|
356 |
+
if args.max_steps > 0 and global_step > args.max_steps:
|
357 |
+
epoch_iterator.close()
|
358 |
+
break
|
359 |
+
if args.max_steps > 0 and global_step > args.max_steps:
|
360 |
+
train_iterator.close()
|
361 |
+
break
|
362 |
+
|
363 |
+
if args.local_rank in [-1, 0]:
|
364 |
+
tb_writer.close()
|
365 |
+
|
366 |
+
return global_step, tr_loss / global_step
|
367 |
+
|
368 |
+
|
369 |
+
def evaluate(args, model_encoder, model_decoder, encoder_tokenizer, decoder_tokenizer, prefix=""):
|
370 |
+
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
371 |
+
eval_output_dir = args.output_dir
|
372 |
+
|
373 |
+
eval_dataset = load_and_cache_examples(args, [encoder_tokenizer, decoder_tokenizer], evaluate=True)
|
374 |
+
|
375 |
+
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
376 |
+
os.makedirs(eval_output_dir)
|
377 |
+
|
378 |
+
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
379 |
+
# Note that DistributedSampler samples randomly
|
380 |
+
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
|
381 |
+
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
382 |
+
|
383 |
+
# Eval!
|
384 |
+
logger.info("***** Running evaluation {} *****".format(prefix))
|
385 |
+
logger.info(" Num examples = %d", len(eval_dataset))
|
386 |
+
logger.info(" Batch size = %d", args.eval_batch_size)
|
387 |
+
eval_loss = 0.0
|
388 |
+
nb_eval_steps = 0
|
389 |
+
model_encoder.eval()
|
390 |
+
model_decoder.eval()
|
391 |
+
|
392 |
+
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
393 |
+
# pdb.set_trace()
|
394 |
+
tokenized_text0, tokenized_text1, tokenized_text_lengths = batch
|
395 |
+
# prepare input-output data for evaluation
|
396 |
+
inputs, labels = tokenized_text0, tokenized_text1
|
397 |
+
|
398 |
+
tokenized_text1 = tokenized_text1.to(args.device)
|
399 |
+
inputs = inputs.to(args.device)
|
400 |
+
labels = labels.to(args.device)
|
401 |
+
|
402 |
+
with torch.no_grad():
|
403 |
+
# Encoding
|
404 |
+
outputs = model_encoder(inputs)
|
405 |
+
pooled_hidden_fea = outputs[1] # model outputs are always tuple in pytorch-transformers (see doc)
|
406 |
+
|
407 |
+
# Decoding
|
408 |
+
outputs = model_decoder(input_ids=tokenized_text1, past=None, labels=labels)
|
409 |
+
lm_loss = outputs[0]
|
410 |
+
|
411 |
+
eval_loss += lm_loss.mean().item()
|
412 |
+
nb_eval_steps += 1
|
413 |
+
|
414 |
+
eval_loss = eval_loss / nb_eval_steps
|
415 |
+
perplexity = torch.exp(torch.tensor(eval_loss))
|
416 |
+
|
417 |
+
result = {
|
418 |
+
"perplexity": perplexity
|
419 |
+
}
|
420 |
+
|
421 |
+
output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
|
422 |
+
with open(output_eval_file, "w") as writer:
|
423 |
+
logger.info("***** Eval results {} *****".format(prefix))
|
424 |
+
for key in sorted(result.keys()):
|
425 |
+
logger.info(" %s = %s", key, str(result[key]))
|
426 |
+
writer.write("%s = %s\n" % (key, str(result[key])))
|
427 |
+
|
428 |
+
return result
|
429 |
+
|
430 |
+
|
431 |
+
def main():
|
432 |
+
parser = argparse.ArgumentParser()
|
433 |
+
|
434 |
+
## Required parameters
|
435 |
+
parser.add_argument("--train_data_file", default=None, type=str, required=True,
|
436 |
+
help="The input training data file (a text file).")
|
437 |
+
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
438 |
+
help="The output directory where the model predictions and checkpoints will be written.")
|
439 |
+
|
440 |
+
## Other parameters
|
441 |
+
parser.add_argument("--eval_data_file", default=None, type=str,
|
442 |
+
help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
|
443 |
+
|
444 |
+
## Encoder options
|
445 |
+
parser.add_argument("--encoder_model_type", default="bert", type=str,
|
446 |
+
help="The encoder model architecture to be fine-tuned.")
|
447 |
+
parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
|
448 |
+
help="The encoder model checkpoint for weights initialization.")
|
449 |
+
parser.add_argument("--encoder_config_name", default="", type=str,
|
450 |
+
help="Optional pretrained config name or path if not the same as model_name_or_path")
|
451 |
+
parser.add_argument("--encoder_tokenizer_name", default="", type=str,
|
452 |
+
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
|
453 |
+
|
454 |
+
## Decoder options
|
455 |
+
parser.add_argument("--decoder_model_type", default="gpt2", type=str,
|
456 |
+
help="The decoder model architecture to be fine-tuned.")
|
457 |
+
parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
|
458 |
+
help="The decoder model checkpoint for weights initialization.")
|
459 |
+
parser.add_argument("--decoder_config_name", default="", type=str,
|
460 |
+
help="Optional pretrained config name or path if not the same as model_name_or_path")
|
461 |
+
parser.add_argument("--decoder_tokenizer_name", default="", type=str,
|
462 |
+
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
|
463 |
+
|
464 |
+
## Objective functions
|
465 |
+
parser.add_argument("--mlm", action='store_true',
|
466 |
+
help="Train with masked-language modeling loss instead of language modeling.")
|
467 |
+
parser.add_argument("--mlm_probability", type=float, default=0.15,
|
468 |
+
help="Ratio of tokens to mask for masked language modeling loss")
|
469 |
+
|
470 |
+
|
471 |
+
|
472 |
+
parser.add_argument("--cache_dir", default="", type=str,
|
473 |
+
help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
|
474 |
+
parser.add_argument("--block_size", default=-1, type=int,
|
475 |
+
help="Optional input sequence length after tokenization."
|
476 |
+
"The training dataset will be truncated in block of this size for training."
|
477 |
+
"Default to the model max input length for single sentence inputs (take into account special tokens).")
|
478 |
+
parser.add_argument("--do_train", action='store_true',
|
479 |
+
help="Whether to run training.")
|
480 |
+
parser.add_argument("--do_eval", action='store_true',
|
481 |
+
help="Whether to run eval on the dev set.")
|
482 |
+
parser.add_argument("--evaluate_during_training", action='store_true',
|
483 |
+
help="Run evaluation during training at each logging step.")
|
484 |
+
parser.add_argument("--do_lower_case", action='store_true',
|
485 |
+
help="Set this flag if you are using an uncased model.")
|
486 |
+
|
487 |
+
|
488 |
+
# Training Schedules
|
489 |
+
parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
|
490 |
+
help="Batch size per GPU/CPU for training.")
|
491 |
+
parser.add_argument("--per_gpu_eval_batch_size", default=4, type=int,
|
492 |
+
help="Batch size per GPU/CPU for evaluation.")
|
493 |
+
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
494 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
495 |
+
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
496 |
+
help="The initial learning rate for Adam.")
|
497 |
+
parser.add_argument("--weight_decay", default=0.0, type=float,
|
498 |
+
help="Weight deay if we apply some.")
|
499 |
+
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
500 |
+
help="Epsilon for Adam optimizer.")
|
501 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
502 |
+
help="Max gradient norm.")
|
503 |
+
parser.add_argument("--num_train_epochs", default=1.0, type=float,
|
504 |
+
help="Total number of training epochs to perform.")
|
505 |
+
parser.add_argument("--max_steps", default=-1, type=int,
|
506 |
+
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
507 |
+
parser.add_argument("--warmup_steps", default=0, type=int,
|
508 |
+
help="Linear warmup over warmup_steps.")
|
509 |
+
|
510 |
+
|
511 |
+
## IO: Logging and Saving
|
512 |
+
parser.add_argument('--logging_steps', type=int, default=50,
|
513 |
+
help="Log every X updates steps.")
|
514 |
+
parser.add_argument('--save_steps', type=int, default=50,
|
515 |
+
help="Save checkpoint every X updates steps.")
|
516 |
+
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
517 |
+
help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
|
518 |
+
parser.add_argument("--no_cuda", action='store_true',
|
519 |
+
help="Avoid using CUDA when available")
|
520 |
+
parser.add_argument('--overwrite_output_dir', action='store_true',
|
521 |
+
help="Overwrite the content of the output directory")
|
522 |
+
parser.add_argument('--overwrite_cache', action='store_true',
|
523 |
+
help="Overwrite the cached training and evaluation sets")
|
524 |
+
parser.add_argument('--seed', type=int, default=42,
|
525 |
+
help="random seed for initialization")
|
526 |
+
|
527 |
+
# Precision & Distributed Training
|
528 |
+
parser.add_argument('--fp16', action='store_true',
|
529 |
+
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
530 |
+
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
531 |
+
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
532 |
+
"See details at https://nvidia.github.io/apex/amp.html")
|
533 |
+
parser.add_argument("--local_rank", type=int, default=-1,
|
534 |
+
help="For distributed training: local_rank")
|
535 |
+
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
|
536 |
+
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
537 |
+
args = parser.parse_args()
|
538 |
+
|
539 |
+
if args.decoder_model_type in ["bert", "roberta"] and not args.mlm:
|
540 |
+
raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
|
541 |
+
"flag (masked language modeling).")
|
542 |
+
if args.eval_data_file is None and args.do_eval:
|
543 |
+
raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
|
544 |
+
"or remove the --do_eval argument.")
|
545 |
+
|
546 |
+
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
547 |
+
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
548 |
+
|
549 |
+
# Setup distant debugging if needed
|
550 |
+
if args.server_ip and args.server_port:
|
551 |
+
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
552 |
+
import ptvsd
|
553 |
+
print("Waiting for debugger attach")
|
554 |
+
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
555 |
+
ptvsd.wait_for_attach()
|
556 |
+
|
557 |
+
# Setup CUDA, GPU & distributed training
|
558 |
+
if args.local_rank == -1 or args.no_cuda:
|
559 |
+
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
560 |
+
args.n_gpu = torch.cuda.device_count()
|
561 |
+
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
562 |
+
torch.cuda.set_device(args.local_rank)
|
563 |
+
device = torch.device("cuda", args.local_rank)
|
564 |
+
torch.distributed.init_process_group(backend='nccl')
|
565 |
+
args.n_gpu = 1
|
566 |
+
args.device = device
|
567 |
+
|
568 |
+
# Setup logging
|
569 |
+
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
570 |
+
datefmt = '%m/%d/%Y %H:%M:%S',
|
571 |
+
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
572 |
+
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
573 |
+
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
574 |
+
|
575 |
+
# Set seed
|
576 |
+
set_seed(args)
|
577 |
+
|
578 |
+
# Load pretrained model and tokenizer
|
579 |
+
if args.local_rank not in [-1, 0]:
|
580 |
+
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
|
581 |
+
|
582 |
+
## Encoder
|
583 |
+
encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
|
584 |
+
encoder_config = encoder_config_class.from_pretrained(args.encoder_config_name if args.encoder_config_name else args.encoder_model_name_or_path)
|
585 |
+
tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)
|
586 |
+
if args.block_size <= 0:
|
587 |
+
args.block_size = tokenizer_encoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
588 |
+
args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
|
589 |
+
model_encoder = encoder_model_class.from_pretrained(args.encoder_model_name_or_path, from_tf=bool('.ckpt' in args.encoder_model_name_or_path), config=encoder_config)
|
590 |
+
model_encoder.to(args.device)
|
591 |
+
|
592 |
+
## Decoder
|
593 |
+
decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
|
594 |
+
decoder_config = decoder_config_class.from_pretrained(args.decoder_config_name if args.decoder_config_name else args.decoder_model_name_or_path)
|
595 |
+
tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
|
596 |
+
if args.block_size <= 0:
|
597 |
+
args.block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
598 |
+
args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
|
599 |
+
model_decoder = decoder_model_class.from_pretrained(args.decoder_model_name_or_path, from_tf=bool('.ckpt' in args.decoder_model_name_or_path), config=decoder_config)
|
600 |
+
|
601 |
+
# Chunyuan: Add Padding token to GPT2
|
602 |
+
special_tokens_dict = {'pad_token': '<PAD>'}
|
603 |
+
num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
|
604 |
+
print('We have added', num_added_toks, 'tokens')
|
605 |
+
model_decoder.resize_token_embeddings(len(tokenizer_decoder)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
|
606 |
+
assert tokenizer_decoder.pad_token == '<PAD>'
|
607 |
+
|
608 |
+
model_decoder.to(args.device)
|
609 |
+
|
610 |
+
if args.local_rank == 0:
|
611 |
+
torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training download model & vocab
|
612 |
+
|
613 |
+
logger.info("Training/evaluation parameters %s", args)
|
614 |
+
|
615 |
+
global_step= 0
|
616 |
+
# Training
|
617 |
+
if args.do_train:
|
618 |
+
if args.local_rank not in [-1, 0]:
|
619 |
+
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache
|
620 |
+
|
621 |
+
train_dataset = load_and_cache_examples(args, [tokenizer_encoder, tokenizer_decoder], evaluate=False)
|
622 |
+
|
623 |
+
if args.local_rank == 0:
|
624 |
+
torch.distributed.barrier()
|
625 |
+
|
626 |
+
global_step, tr_loss = train(args, train_dataset, model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder)
|
627 |
+
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
628 |
+
|
629 |
+
|
630 |
+
# Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
|
631 |
+
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
632 |
+
# Create output directory if needed
|
633 |
+
# Save model checkpoint
|
634 |
+
output_encoder_dir = os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(global_step))
|
635 |
+
output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
|
636 |
+
if not os.path.exists(output_encoder_dir) and args.local_rank in [-1, 0]:
|
637 |
+
os.makedirs(output_encoder_dir)
|
638 |
+
if not os.path.exists(output_decoder_dir) and args.local_rank in [-1, 0]:
|
639 |
+
os.makedirs(output_decoder_dir)
|
640 |
+
|
641 |
+
logger.info("Saving encoder model checkpoint to %s", output_encoder_dir)
|
642 |
+
logger.info("Saving decoder model checkpoint to %s", output_decoder_dir)
|
643 |
+
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
644 |
+
# They can then be reloaded using `from_pretrained()`
|
645 |
+
|
646 |
+
model_encoder_to_save = model_encoder.module if hasattr(model_encoder, 'module') else model_encoder # Take care of distributed/parallel training
|
647 |
+
model_decoder_to_save = model_decoder.module if hasattr(model_decoder, 'module') else model_decoder # Take care of distributed/parallel training
|
648 |
+
|
649 |
+
# Good practice: save your training arguments together with the trained model
|
650 |
+
model_encoder_to_save.save_pretrained(output_encoder_dir)
|
651 |
+
torch.save(args, os.path.join(output_encoder_dir, 'training_encoder_args.bin'))
|
652 |
+
|
653 |
+
model_decoder_to_save.save_pretrained(output_decoder_dir)
|
654 |
+
torch.save(args, os.path.join(output_decoder_dir, 'training_decoder_args.bin'))
|
655 |
+
|
656 |
+
|
657 |
+
# Load a trained model and vocabulary that you have fine-tuned
|
658 |
+
model_encoder = encoder_model_class.from_pretrained(output_encoder_dir)
|
659 |
+
tokenizer_encoder = encoder_tokenizer_class.from_pretrained(output_encoder_dir, do_lower_case=args.do_lower_case)
|
660 |
+
model_encoder.to(args.device)
|
661 |
+
|
662 |
+
# Load a trained model and vocabulary that you have fine-tuned
|
663 |
+
model_decoder = decoder_model_class.from_pretrained(output_decoder_dir)
|
664 |
+
tokenizer_decoder = decoder_tokenizer_class.from_pretrained(output_decoder_dir, do_lower_case=args.do_lower_case)
|
665 |
+
model_decoder.to(args.device)
|
666 |
+
|
667 |
+
|
668 |
+
# Evaluation
|
669 |
+
results = {}
|
670 |
+
if args.do_eval and args.local_rank in [-1, 0]:
|
671 |
+
global_step= 881
|
672 |
+
output_encoder_dir = os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(global_step))
|
673 |
+
output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
|
674 |
+
checkpoints = [ [output_encoder_dir, output_decoder_dir] ]
|
675 |
+
|
676 |
+
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
677 |
+
for checkpoint in checkpoints:
|
678 |
+
global_step = checkpoint[0].split('-')[-1] if len(checkpoints) > 1 else ""
|
679 |
+
|
680 |
+
model_encoder = encoder_model_class.from_pretrained(checkpoint[0])
|
681 |
+
model_encoder.to(args.device)
|
682 |
+
model_decoder = decoder_model_class.from_pretrained(checkpoint[1])
|
683 |
+
model_decoder.to(args.device)
|
684 |
+
result = evaluate(args, model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, prefix=global_step)
|
685 |
+
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
686 |
+
results.update(result)
|
687 |
+
|
688 |
+
return results
|
689 |
+
|
690 |
+
|
691 |
+
if __name__ == "__main__":
|
692 |
+
main()
|
Optimus/code/examples/big_ae/run_lm_finetuning_baseline.py
ADDED
@@ -0,0 +1,573 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
|
18 |
+
GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
|
19 |
+
using a masked language modeling (MLM) loss.
|
20 |
+
"""
|
21 |
+
|
22 |
+
from __future__ import absolute_import, division, print_function
|
23 |
+
|
24 |
+
import pdb
|
25 |
+
|
26 |
+
import sys
|
27 |
+
sys.path.insert(0, '.')
|
28 |
+
|
29 |
+
import argparse
|
30 |
+
import glob
|
31 |
+
import logging
|
32 |
+
import os
|
33 |
+
import pickle
|
34 |
+
import random
|
35 |
+
|
36 |
+
import numpy as np
|
37 |
+
import torch
|
38 |
+
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler
|
39 |
+
from torch.utils.data.distributed import DistributedSampler
|
40 |
+
from tensorboardX import SummaryWriter
|
41 |
+
from tqdm import tqdm, trange
|
42 |
+
|
43 |
+
from pytorch_transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
|
44 |
+
BertConfig, BertForMaskedLM, BertTokenizer,
|
45 |
+
GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
|
46 |
+
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
|
47 |
+
RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
|
48 |
+
|
49 |
+
from utils import (calc_iwnll, calc_mi, calc_au, TextDataset_Split, TextDataset_2Tokenizers)
|
50 |
+
|
51 |
+
import pdb
|
52 |
+
|
53 |
+
logger = logging.getLogger(__name__)
|
54 |
+
|
55 |
+
|
56 |
+
MODEL_CLASSES = {
|
57 |
+
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
|
58 |
+
'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
|
59 |
+
'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
|
60 |
+
'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
|
61 |
+
}
|
62 |
+
|
63 |
+
|
64 |
+
class TextDataset(Dataset):
|
65 |
+
def __init__(self, tokenizer, file_path='train', block_size=512):
|
66 |
+
assert os.path.isfile(file_path)
|
67 |
+
directory, filename = os.path.split(file_path)
|
68 |
+
cached_features_file = os.path.join(directory, f'cached_lm_{block_size}_{filename}')
|
69 |
+
|
70 |
+
if os.path.exists(cached_features_file):
|
71 |
+
logger.info("Loading features from cached file %s", cached_features_file)
|
72 |
+
with open(cached_features_file, 'rb') as handle:
|
73 |
+
self.examples = pickle.load(handle)
|
74 |
+
else:
|
75 |
+
logger.info("Creating features from dataset file at %s", directory)
|
76 |
+
|
77 |
+
self.examples = []
|
78 |
+
with open(file_path, encoding="utf-8") as f:
|
79 |
+
text = f.read()
|
80 |
+
|
81 |
+
tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
|
82 |
+
|
83 |
+
while len(tokenized_text) >= block_size: # Truncate in block of block_size
|
84 |
+
self.examples.append(tokenizer.add_special_tokens_single_sentence(tokenized_text[:block_size]))
|
85 |
+
tokenized_text = tokenized_text[block_size:]
|
86 |
+
# Note that we are loosing the last truncated example here for the sake of simplicity (no padding)
|
87 |
+
# If your dataset is small, first you should loook for a bigger one :-) and second you
|
88 |
+
# can change this behavior by adding (model specific) padding.
|
89 |
+
|
90 |
+
logger.info("Saving features into cached file %s", cached_features_file)
|
91 |
+
with open(cached_features_file, 'wb') as handle:
|
92 |
+
pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
93 |
+
|
94 |
+
def __len__(self):
|
95 |
+
return len(self.examples)
|
96 |
+
|
97 |
+
def __getitem__(self, item):
|
98 |
+
return torch.tensor(self.examples[item])
|
99 |
+
|
100 |
+
|
101 |
+
def load_and_cache_examples(args, tokenizer, evaluate=False):
|
102 |
+
if isinstance(tokenizer, list):
|
103 |
+
dataset = TextDataset_2Tokenizers(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
|
104 |
+
else:
|
105 |
+
dataset = TextDataset_Split(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
|
106 |
+
return dataset
|
107 |
+
|
108 |
+
|
109 |
+
def set_seed(args):
|
110 |
+
random.seed(args.seed)
|
111 |
+
np.random.seed(args.seed)
|
112 |
+
torch.manual_seed(args.seed)
|
113 |
+
if args.n_gpu > 0:
|
114 |
+
torch.cuda.manual_seed_all(args.seed)
|
115 |
+
|
116 |
+
|
117 |
+
def mask_tokens(inputs, tokenizer, args):
|
118 |
+
""" Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
|
119 |
+
labels = inputs.clone()
|
120 |
+
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
|
121 |
+
|
122 |
+
masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).to(torch.uint8)
|
123 |
+
labels[masked_indices==1] = -1 # We only compute loss on masked tokens
|
124 |
+
|
125 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
126 |
+
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).to(torch.uint8) & masked_indices
|
127 |
+
inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
|
128 |
+
|
129 |
+
# 10% of the time, we replace masked input tokens with random word
|
130 |
+
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).to(torch.uint8) & masked_indices & ~indices_replaced
|
131 |
+
random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
|
132 |
+
inputs[indices_random] = random_words[indices_random]
|
133 |
+
|
134 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
135 |
+
return inputs, labels
|
136 |
+
|
137 |
+
|
138 |
+
def train(args, train_dataset, model, tokenizer):
|
139 |
+
""" Train the model """
|
140 |
+
if args.local_rank in [-1, 0]:
|
141 |
+
tb_writer = SummaryWriter()
|
142 |
+
|
143 |
+
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
144 |
+
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
145 |
+
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
146 |
+
|
147 |
+
if args.max_steps > 0:
|
148 |
+
t_total = args.max_steps
|
149 |
+
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
150 |
+
else:
|
151 |
+
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
152 |
+
|
153 |
+
# Prepare optimizer and schedule (linear warmup and decay)
|
154 |
+
no_decay = ['bias', 'LayerNorm.weight']
|
155 |
+
optimizer_grouped_parameters = [
|
156 |
+
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
157 |
+
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
158 |
+
]
|
159 |
+
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
160 |
+
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
|
161 |
+
if args.fp16:
|
162 |
+
try:
|
163 |
+
from apex import amp
|
164 |
+
except ImportError:
|
165 |
+
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
166 |
+
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
167 |
+
|
168 |
+
# multi-gpu training (should be after apex fp16 initialization)
|
169 |
+
if args.n_gpu > 1:
|
170 |
+
model = torch.nn.DataParallel(model)
|
171 |
+
|
172 |
+
# Distributed training (should be after apex fp16 initialization)
|
173 |
+
if args.local_rank != -1:
|
174 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
175 |
+
output_device=args.local_rank,
|
176 |
+
find_unused_parameters=True)
|
177 |
+
|
178 |
+
|
179 |
+
# Train!
|
180 |
+
logger.info("***** Running training *****")
|
181 |
+
logger.info(" Num examples = %d", len(train_dataset))
|
182 |
+
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
183 |
+
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
184 |
+
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
185 |
+
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
186 |
+
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
187 |
+
logger.info(" Total optimization steps = %d", t_total)
|
188 |
+
|
189 |
+
global_step = 0
|
190 |
+
tr_loss, logging_loss = 0.0, 0.0
|
191 |
+
model.zero_grad()
|
192 |
+
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
193 |
+
set_seed(args) # Added here for reproducibility (even between python 2 and 3)
|
194 |
+
for _ in train_iterator:
|
195 |
+
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
196 |
+
for step, batch in enumerate(epoch_iterator):
|
197 |
+
|
198 |
+
tokenized_text1, tokenized_text_lengths = batch
|
199 |
+
|
200 |
+
inputs, labels = tokenized_text1, tokenized_text1
|
201 |
+
|
202 |
+
inputs = inputs.to(args.device)
|
203 |
+
labels = labels.to(args.device)
|
204 |
+
|
205 |
+
model.train()
|
206 |
+
|
207 |
+
outputs = model(inputs, labels=labels, label_ignore=tokenizer.pad_token_id)
|
208 |
+
|
209 |
+
# pdb.set_trace()
|
210 |
+
loss = outputs[0].mean() # model outputs are always tuple in pytorch-transformers (see doc)
|
211 |
+
|
212 |
+
if args.use_philly:
|
213 |
+
print("PROGRESS: {}%".format(round(100 * (step + epoch*len(epoch_iterator) ) /(int(args.num_train_epochs) * len(epoch_iterator)) , 4)))
|
214 |
+
print("EVALERR: {}%".format(loss))
|
215 |
+
|
216 |
+
|
217 |
+
|
218 |
+
if args.n_gpu > 1:
|
219 |
+
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
220 |
+
if args.gradient_accumulation_steps > 1:
|
221 |
+
loss = loss / args.gradient_accumulation_steps
|
222 |
+
|
223 |
+
if args.fp16:
|
224 |
+
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
225 |
+
scaled_loss.backward()
|
226 |
+
else:
|
227 |
+
loss.backward()
|
228 |
+
|
229 |
+
tr_loss += loss.item()
|
230 |
+
if (step + 1) % args.gradient_accumulation_steps == 0:
|
231 |
+
if args.fp16:
|
232 |
+
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
233 |
+
else:
|
234 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
235 |
+
optimizer.step()
|
236 |
+
scheduler.step() # Update learning rate schedule
|
237 |
+
model.zero_grad()
|
238 |
+
global_step += 1
|
239 |
+
|
240 |
+
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
241 |
+
# Log metrics
|
242 |
+
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
243 |
+
results = evaluate(args, model, tokenizer)
|
244 |
+
for key, value in results.items():
|
245 |
+
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
246 |
+
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
|
247 |
+
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
248 |
+
logging_loss = tr_loss
|
249 |
+
|
250 |
+
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
251 |
+
# Save model checkpoint
|
252 |
+
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
253 |
+
if not os.path.exists(output_dir):
|
254 |
+
os.makedirs(output_dir)
|
255 |
+
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
256 |
+
model_to_save.save_pretrained(output_dir)
|
257 |
+
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
258 |
+
logger.info("Saving model checkpoint to %s", output_dir)
|
259 |
+
|
260 |
+
if args.max_steps > 0 and global_step > args.max_steps:
|
261 |
+
epoch_iterator.close()
|
262 |
+
break
|
263 |
+
if args.max_steps > 0 and global_step > args.max_steps:
|
264 |
+
train_iterator.close()
|
265 |
+
break
|
266 |
+
|
267 |
+
if args.local_rank in [-1, 0]:
|
268 |
+
tb_writer.close()
|
269 |
+
|
270 |
+
return global_step, tr_loss / global_step
|
271 |
+
|
272 |
+
|
273 |
+
def evaluate(args, model, tokenizer, prefix=""):
|
274 |
+
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
275 |
+
eval_output_dir = args.output_dir
|
276 |
+
|
277 |
+
eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)
|
278 |
+
|
279 |
+
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
280 |
+
os.makedirs(eval_output_dir)
|
281 |
+
|
282 |
+
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
283 |
+
# Note that DistributedSampler samples randomly
|
284 |
+
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
|
285 |
+
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
286 |
+
|
287 |
+
# Eval!
|
288 |
+
logger.info("***** Running evaluation {} *****".format(prefix))
|
289 |
+
logger.info(" Num examples = %d", len(eval_dataset))
|
290 |
+
logger.info(" Batch size = %d", args.eval_batch_size)
|
291 |
+
eval_loss = 0.0
|
292 |
+
eval_loss_sum = 0.0
|
293 |
+
nb_eval_steps = 0
|
294 |
+
report_num_words = 0
|
295 |
+
|
296 |
+
model.eval()
|
297 |
+
|
298 |
+
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
299 |
+
|
300 |
+
tokenized_text1, x_lengths = batch
|
301 |
+
x_lengths = x_lengths.to(args.device)
|
302 |
+
report_num_words += x_lengths.sum().item()
|
303 |
+
|
304 |
+
inputs, labels = tokenized_text1, tokenized_text1
|
305 |
+
|
306 |
+
inputs = inputs.to(args.device)
|
307 |
+
labels = labels.to(args.device)
|
308 |
+
|
309 |
+
|
310 |
+
with torch.no_grad():
|
311 |
+
outputs = model(inputs, labels=labels, label_ignore=tokenizer.pad_token_id)
|
312 |
+
lm_loss = outputs[0]
|
313 |
+
|
314 |
+
|
315 |
+
eval_loss += lm_loss.mean().item()/x_lengths.sum().item()
|
316 |
+
eval_loss_sum += lm_loss.sum().item()
|
317 |
+
|
318 |
+
|
319 |
+
nb_eval_steps += 1
|
320 |
+
|
321 |
+
# pdb.set_trace()
|
322 |
+
|
323 |
+
eval_loss = eval_loss / nb_eval_steps
|
324 |
+
perplexity1 = torch.exp(torch.tensor(eval_loss))
|
325 |
+
perplexity2 = torch.exp(torch.tensor(eval_loss_sum / report_num_words))
|
326 |
+
|
327 |
+
|
328 |
+
|
329 |
+
result = {
|
330 |
+
"perplexity1": perplexity1, "perplexity2": perplexity2
|
331 |
+
}
|
332 |
+
|
333 |
+
output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
|
334 |
+
with open(output_eval_file, "w") as writer:
|
335 |
+
logger.info("***** Eval results {} *****".format(prefix))
|
336 |
+
for key in sorted(result.keys()):
|
337 |
+
logger.info(" %s = %s", key, str(result[key]))
|
338 |
+
writer.write("%s = %s\n" % (key, str(result[key])))
|
339 |
+
|
340 |
+
return result
|
341 |
+
|
342 |
+
|
343 |
+
def main():
|
344 |
+
parser = argparse.ArgumentParser()
|
345 |
+
|
346 |
+
## Required parameters
|
347 |
+
parser.add_argument("--train_data_file", default=None, type=str, required=True,
|
348 |
+
help="The input training data file (a text file).")
|
349 |
+
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
350 |
+
help="The output directory where the model predictions and checkpoints will be written.")
|
351 |
+
parser.add_argument("--dataset", default=None, type=str, help="The dataset.")
|
352 |
+
|
353 |
+
|
354 |
+
## Other parameters
|
355 |
+
parser.add_argument("--eval_data_file", default=None, type=str,
|
356 |
+
help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
|
357 |
+
|
358 |
+
parser.add_argument("--model_type", default="bert", type=str,
|
359 |
+
help="The model architecture to be fine-tuned.")
|
360 |
+
parser.add_argument("--model_name_or_path", default="bert-base-cased", type=str,
|
361 |
+
help="The model checkpoint for weights initialization.")
|
362 |
+
|
363 |
+
|
364 |
+
parser.add_argument("--use_philly", action='store_true',
|
365 |
+
help="Use Philly for computing.")
|
366 |
+
|
367 |
+
parser.add_argument("--mlm", action='store_true',
|
368 |
+
help="Train with masked-language modeling loss instead of language modeling.")
|
369 |
+
parser.add_argument("--mlm_probability", type=float, default=0.15,
|
370 |
+
help="Ratio of tokens to mask for masked language modeling loss")
|
371 |
+
|
372 |
+
parser.add_argument("--config_name", default="", type=str,
|
373 |
+
help="Optional pretrained config name or path if not the same as model_name_or_path")
|
374 |
+
parser.add_argument("--tokenizer_name", default="", type=str,
|
375 |
+
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
|
376 |
+
parser.add_argument("--cache_dir", default="", type=str,
|
377 |
+
help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
|
378 |
+
parser.add_argument("--block_size", default=-1, type=int,
|
379 |
+
help="Optional input sequence length after tokenization."
|
380 |
+
"The training dataset will be truncated in block of this size for training."
|
381 |
+
"Default to the model max input length for single sentence inputs (take into account special tokens).")
|
382 |
+
parser.add_argument("--do_train", action='store_true',
|
383 |
+
help="Whether to run training.")
|
384 |
+
parser.add_argument("--do_eval", action='store_true',
|
385 |
+
help="Whether to run eval on the dev set.")
|
386 |
+
parser.add_argument("--evaluate_during_training", action='store_true',
|
387 |
+
help="Run evaluation during training at each logging step.")
|
388 |
+
parser.add_argument("--do_lower_case", action='store_true',
|
389 |
+
help="Set this flag if you are using an uncased model.")
|
390 |
+
|
391 |
+
parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
|
392 |
+
help="Batch size per GPU/CPU for training.")
|
393 |
+
parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
|
394 |
+
help="Batch size per GPU/CPU for evaluation.")
|
395 |
+
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
396 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
397 |
+
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
398 |
+
help="The initial learning rate for Adam.")
|
399 |
+
parser.add_argument("--weight_decay", default=0.0, type=float,
|
400 |
+
help="Weight deay if we apply some.")
|
401 |
+
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
402 |
+
help="Epsilon for Adam optimizer.")
|
403 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
404 |
+
help="Max gradient norm.")
|
405 |
+
parser.add_argument("--num_train_epochs", default=1.0, type=float,
|
406 |
+
help="Total number of training epochs to perform.")
|
407 |
+
parser.add_argument("--max_steps", default=-1, type=int,
|
408 |
+
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
409 |
+
parser.add_argument("--warmup_steps", default=0, type=int,
|
410 |
+
help="Linear warmup over warmup_steps.")
|
411 |
+
|
412 |
+
parser.add_argument('--gloabl_step_eval', type=int, default=661,
|
413 |
+
help="Evaluate the results at the given global step")
|
414 |
+
|
415 |
+
|
416 |
+
|
417 |
+
parser.add_argument('--logging_steps', type=int, default=100,
|
418 |
+
help="Log every X updates steps.")
|
419 |
+
parser.add_argument('--save_steps', type=int, default=100,
|
420 |
+
help="Save checkpoint every X updates steps.")
|
421 |
+
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
422 |
+
help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
|
423 |
+
parser.add_argument("--no_cuda", action='store_true',
|
424 |
+
help="Avoid using CUDA when available")
|
425 |
+
parser.add_argument('--overwrite_output_dir', action='store_true',
|
426 |
+
help="Overwrite the content of the output directory")
|
427 |
+
parser.add_argument('--overwrite_cache', action='store_true',
|
428 |
+
help="Overwrite the cached training and evaluation sets")
|
429 |
+
parser.add_argument('--seed', type=int, default=42,
|
430 |
+
help="random seed for initialization")
|
431 |
+
|
432 |
+
parser.add_argument('--fp16', action='store_true',
|
433 |
+
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
434 |
+
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
435 |
+
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
436 |
+
"See details at https://nvidia.github.io/apex/amp.html")
|
437 |
+
parser.add_argument("--local_rank", type=int, default=-1,
|
438 |
+
help="For distributed training: local_rank")
|
439 |
+
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
|
440 |
+
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
441 |
+
args = parser.parse_args()
|
442 |
+
|
443 |
+
if args.model_type in ["bert", "roberta"] and not args.mlm:
|
444 |
+
raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
|
445 |
+
"flag (masked language modeling).")
|
446 |
+
if args.eval_data_file is None and args.do_eval:
|
447 |
+
raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
|
448 |
+
"or remove the --do_eval argument.")
|
449 |
+
|
450 |
+
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
451 |
+
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
452 |
+
|
453 |
+
# Setup distant debugging if needed
|
454 |
+
if args.server_ip and args.server_port:
|
455 |
+
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
456 |
+
import ptvsd
|
457 |
+
print("Waiting for debugger attach")
|
458 |
+
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
459 |
+
ptvsd.wait_for_attach()
|
460 |
+
|
461 |
+
# Setup CUDA, GPU & distributed training
|
462 |
+
if args.local_rank == -1 or args.no_cuda:
|
463 |
+
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
464 |
+
args.n_gpu = torch.cuda.device_count()
|
465 |
+
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
466 |
+
torch.cuda.set_device(args.local_rank)
|
467 |
+
device = torch.device("cuda", args.local_rank)
|
468 |
+
torch.distributed.init_process_group(backend='nccl')
|
469 |
+
args.n_gpu = 1
|
470 |
+
args.device = device
|
471 |
+
|
472 |
+
# Setup logging
|
473 |
+
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
474 |
+
datefmt = '%m/%d/%Y %H:%M:%S',
|
475 |
+
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
476 |
+
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
477 |
+
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
478 |
+
|
479 |
+
# Set seed
|
480 |
+
set_seed(args)
|
481 |
+
|
482 |
+
# Load pretrained model and tokenizer
|
483 |
+
if args.local_rank not in [-1, 0]:
|
484 |
+
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
|
485 |
+
|
486 |
+
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
487 |
+
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
|
488 |
+
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
|
489 |
+
if args.block_size <= 0:
|
490 |
+
args.block_size = tokenizer.max_len_single_sentence # Our input block size will be the max possible for the model
|
491 |
+
args.block_size = min(args.block_size, tokenizer.max_len_single_sentence)
|
492 |
+
model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
|
493 |
+
model.to(args.device)
|
494 |
+
|
495 |
+
# Chunyuan: Add Padding token to GPT2
|
496 |
+
special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
|
497 |
+
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
|
498 |
+
print('We have added', num_added_toks, 'tokens to GPT2')
|
499 |
+
model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
|
500 |
+
assert tokenizer.pad_token == '<PAD>'
|
501 |
+
|
502 |
+
|
503 |
+
# pdb.set_trace()
|
504 |
+
|
505 |
+
if args.local_rank == 0:
|
506 |
+
torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training download model & vocab
|
507 |
+
|
508 |
+
logger.info("Training/evaluation parameters %s", args)
|
509 |
+
|
510 |
+
# Training
|
511 |
+
global_step= 0
|
512 |
+
if args.do_train:
|
513 |
+
if args.local_rank not in [-1, 0]:
|
514 |
+
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache
|
515 |
+
|
516 |
+
train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)
|
517 |
+
|
518 |
+
if args.local_rank == 0:
|
519 |
+
torch.distributed.barrier()
|
520 |
+
|
521 |
+
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
522 |
+
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
523 |
+
|
524 |
+
|
525 |
+
# Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
|
526 |
+
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
527 |
+
# Create output directory if needed
|
528 |
+
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
529 |
+
os.makedirs(args.output_dir)
|
530 |
+
|
531 |
+
logger.info("Saving model checkpoint to %s", args.output_dir)
|
532 |
+
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
533 |
+
# They can then be reloaded using `from_pretrained()`
|
534 |
+
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
535 |
+
model_to_save.save_pretrained(args.output_dir)
|
536 |
+
tokenizer.save_pretrained(args.output_dir)
|
537 |
+
|
538 |
+
# Good practice: save your training arguments together with the trained model
|
539 |
+
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
540 |
+
|
541 |
+
# Load a trained model and vocabulary that you have fine-tuned
|
542 |
+
model = model_class.from_pretrained(args.output_dir)
|
543 |
+
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
544 |
+
model.to(args.device)
|
545 |
+
|
546 |
+
|
547 |
+
# Evaluation
|
548 |
+
results = {}
|
549 |
+
if args.do_eval and args.local_rank in [-1, 0]:
|
550 |
+
|
551 |
+
if global_step == 0:
|
552 |
+
global_step = args.gloabl_step_eval
|
553 |
+
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
554 |
+
|
555 |
+
checkpoints = [args.output_dir]
|
556 |
+
if args.eval_all_checkpoints:
|
557 |
+
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
558 |
+
logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
559 |
+
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
560 |
+
print("Evaluate the following checkpoints: %s", checkpoints)
|
561 |
+
for checkpoint in checkpoints:
|
562 |
+
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
563 |
+
model = model_class.from_pretrained(checkpoint)
|
564 |
+
model.to(args.device)
|
565 |
+
result = evaluate(args, model, tokenizer, prefix=global_step)
|
566 |
+
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
567 |
+
results.update(result)
|
568 |
+
|
569 |
+
return results
|
570 |
+
|
571 |
+
|
572 |
+
if __name__ == "__main__":
|
573 |
+
main()
|
Optimus/code/examples/big_ae/run_lm_gpt2_training.py
ADDED
@@ -0,0 +1,658 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
|
18 |
+
GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
|
19 |
+
using a masked language modeling (MLM) loss.
|
20 |
+
"""
|
21 |
+
|
22 |
+
from __future__ import absolute_import, division, print_function
|
23 |
+
|
24 |
+
|
25 |
+
import pdb
|
26 |
+
import argparse
|
27 |
+
import glob
|
28 |
+
import logging
|
29 |
+
|
30 |
+
import os
|
31 |
+
import pickle
|
32 |
+
import random
|
33 |
+
|
34 |
+
import numpy as np
|
35 |
+
import torch
|
36 |
+
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
|
37 |
+
from torch.utils.data.distributed import DistributedSampler
|
38 |
+
from tensorboardX import SummaryWriter
|
39 |
+
from tqdm import tqdm, trange
|
40 |
+
from collections import defaultdict
|
41 |
+
|
42 |
+
# from azure.cosmosdb.table.tableservice import TableService
|
43 |
+
# from azure.cosmosdb.table.models import Entity
|
44 |
+
from datetime import datetime
|
45 |
+
|
46 |
+
|
47 |
+
|
48 |
+
from pytorch_transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
|
49 |
+
BertConfig, BertForLatentConnector, BertTokenizer,
|
50 |
+
GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
|
51 |
+
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
|
52 |
+
RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
|
53 |
+
|
54 |
+
from utils import (BucketingDataLoader, TextDataset_Split, TextDataset_2Tokenizers)
|
55 |
+
|
56 |
+
|
57 |
+
logger = logging.getLogger(__name__)
|
58 |
+
|
59 |
+
|
60 |
+
MODEL_CLASSES = {
|
61 |
+
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
|
62 |
+
'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
|
63 |
+
'bert': (BertConfig, BertForLatentConnector, BertTokenizer),
|
64 |
+
'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
|
65 |
+
}
|
66 |
+
|
67 |
+
|
68 |
+
storage_name="textae"
|
69 |
+
key=r"6yBCXlblof8DVFJ4BD3eNFTrGQCej6cKfCf5z308cKnevyHaG+yl/m+ITVErB9yt0kvN3ToqxLIh0knJEfFmPA=="
|
70 |
+
# ts = TableService(account_name=storage_name, account_key=key)
|
71 |
+
|
72 |
+
|
73 |
+
def load_and_cache_examples(args, tokenizer, evaluate=False):
|
74 |
+
if isinstance(tokenizer, list):
|
75 |
+
dataset = TextDataset_2Tokenizers(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
|
76 |
+
else:
|
77 |
+
dataset = TextDataset_Split(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
|
78 |
+
return dataset
|
79 |
+
|
80 |
+
def build_dataload_and_cache_examples(args, tokenizer, evaluate=False):
|
81 |
+
if isinstance(tokenizer, list):
|
82 |
+
if not evaluate:
|
83 |
+
args.batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
84 |
+
file_path=args.train_data_file
|
85 |
+
else:
|
86 |
+
args.batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
87 |
+
file_path=args.eval_data_file
|
88 |
+
dataloader = BucketingDataLoader(file_path, args.batch_size, args.max_seq_length, tokenizer, args, bucket=100, shuffle=True)
|
89 |
+
else:
|
90 |
+
pass
|
91 |
+
return dataloader
|
92 |
+
|
93 |
+
|
94 |
+
|
95 |
+
|
96 |
+
def set_seed(args):
|
97 |
+
random.seed(args.seed)
|
98 |
+
np.random.seed(args.seed)
|
99 |
+
torch.manual_seed(args.seed)
|
100 |
+
if args.n_gpu > 0:
|
101 |
+
torch.cuda.manual_seed_all(args.seed)
|
102 |
+
|
103 |
+
|
104 |
+
def mask_tokens(inputs, tokenizer, args):
|
105 |
+
""" Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
|
106 |
+
labels = inputs.clone()
|
107 |
+
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
|
108 |
+
|
109 |
+
masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).to(torch.uint8)
|
110 |
+
labels[masked_indices==1] = -1 # We only compute loss on masked tokens
|
111 |
+
|
112 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
113 |
+
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).to(torch.uint8) & masked_indices
|
114 |
+
inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
|
115 |
+
|
116 |
+
# 10% of the time, we replace masked input tokens with random word
|
117 |
+
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).to(torch.uint8) & masked_indices & ~indices_replaced
|
118 |
+
indices_random = indices_random
|
119 |
+
random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
|
120 |
+
inputs[indices_random] = random_words[indices_random]
|
121 |
+
|
122 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
123 |
+
return inputs, labels
|
124 |
+
|
125 |
+
|
126 |
+
def train(args, train_dataloader, model, encoder_tokenizer, decoder_tokenizer, table_name):
|
127 |
+
""" Train the model """
|
128 |
+
if args.local_rank in [-1, 0]:
|
129 |
+
tb_writer = SummaryWriter()
|
130 |
+
|
131 |
+
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
132 |
+
# train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
133 |
+
# train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
134 |
+
|
135 |
+
if args.max_steps > 0:
|
136 |
+
t_total = args.max_steps
|
137 |
+
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
138 |
+
else:
|
139 |
+
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
140 |
+
|
141 |
+
# Prepare optimizer and schedule (linear warmup and decay)
|
142 |
+
|
143 |
+
|
144 |
+
no_decay = ['bias', 'LayerNorm.weight']
|
145 |
+
optimizer_grouped_parameters = [
|
146 |
+
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
147 |
+
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
148 |
+
]
|
149 |
+
|
150 |
+
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
151 |
+
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
|
152 |
+
|
153 |
+
|
154 |
+
if args.fp16:
|
155 |
+
try:
|
156 |
+
from apex import amp
|
157 |
+
except ImportError:
|
158 |
+
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
159 |
+
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
160 |
+
|
161 |
+
# multi-gpu training (should be after apex fp16 initialization)
|
162 |
+
if args.n_gpu > 1:
|
163 |
+
model = torch.nn.DataParallel(model, device_ids=range(args.n_gpu)).to(args.device)
|
164 |
+
|
165 |
+
# Distributed training (should be after apex fp16 initialization)
|
166 |
+
if args.local_rank != -1:
|
167 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
168 |
+
output_device=args.local_rank,
|
169 |
+
find_unused_parameters=True)
|
170 |
+
|
171 |
+
|
172 |
+
# Train!
|
173 |
+
logger.info("***** Running training *****")
|
174 |
+
logger.info(" Num examples = %d", train_dataloader.num_examples)
|
175 |
+
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
176 |
+
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
177 |
+
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
178 |
+
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
179 |
+
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
180 |
+
logger.info(" Total optimization steps = %d", t_total)
|
181 |
+
|
182 |
+
global_step = 0
|
183 |
+
tr_loss, logging_loss = 0.0, 0.0
|
184 |
+
|
185 |
+
|
186 |
+
model.zero_grad()
|
187 |
+
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
188 |
+
|
189 |
+
n_iter = int(args.num_train_epochs) * len(train_dataloader)
|
190 |
+
|
191 |
+
tmp_list = []
|
192 |
+
set_seed(args) # Added here for reproducibility (even between python 2 and 3)
|
193 |
+
for epoch in train_iterator:
|
194 |
+
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
195 |
+
for step, batch in enumerate(epoch_iterator):
|
196 |
+
|
197 |
+
tokenized_text0, tokenized_text1, tokenized_text_lengths = batch
|
198 |
+
inputs, labels = tokenized_text1.to(args.device), tokenized_text1.to(args.device)
|
199 |
+
|
200 |
+
model.train()
|
201 |
+
|
202 |
+
outputs = model(inputs, labels=labels, label_ignore=decoder_tokenizer.pad_token_id)
|
203 |
+
loss = outputs[0].mean() # model outputs are always tuple in pytorch-transformers (see doc)
|
204 |
+
|
205 |
+
if args.n_gpu > 1:
|
206 |
+
loss = loss.mean()
|
207 |
+
|
208 |
+
if args.use_philly:
|
209 |
+
print("PROGRESS: {}%".format(round(100 * (step + epoch*len(epoch_iterator) ) /(int(args.num_train_epochs) * len(epoch_iterator)) , 4)))
|
210 |
+
print("EVALERR: {}%".format(loss))
|
211 |
+
|
212 |
+
epoch_iterator.set_description(
|
213 |
+
(
|
214 |
+
f'iter: {step + epoch*len(epoch_iterator) }; loss: {loss.item():.3f}; '
|
215 |
+
)
|
216 |
+
)
|
217 |
+
|
218 |
+
if args.gradient_accumulation_steps > 1:
|
219 |
+
loss = loss / args.gradient_accumulation_steps
|
220 |
+
|
221 |
+
if args.fp16:
|
222 |
+
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
223 |
+
scaled_loss.backward()
|
224 |
+
else:
|
225 |
+
loss.backward()
|
226 |
+
|
227 |
+
tr_loss += loss.item()
|
228 |
+
if (step + 1) % args.gradient_accumulation_steps == 0:
|
229 |
+
if args.fp16:
|
230 |
+
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
231 |
+
else:
|
232 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
233 |
+
|
234 |
+
optimizer.step()
|
235 |
+
|
236 |
+
scheduler.step() # Update learning rate schedule
|
237 |
+
|
238 |
+
model.zero_grad()
|
239 |
+
|
240 |
+
global_step += 1
|
241 |
+
|
242 |
+
|
243 |
+
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
244 |
+
# Log metrics
|
245 |
+
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
246 |
+
results = evaluate(args, model_vae, encoder_tokenizer, decoder_tokenizer)
|
247 |
+
for key, value in results.items():
|
248 |
+
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
249 |
+
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
|
250 |
+
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
251 |
+
logging_loss = tr_loss
|
252 |
+
|
253 |
+
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
254 |
+
|
255 |
+
# Save decoder model checkpoint
|
256 |
+
output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
|
257 |
+
|
258 |
+
if not os.path.exists(output_decoder_dir):
|
259 |
+
os.makedirs(output_decoder_dir)
|
260 |
+
|
261 |
+
model_decoder_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
262 |
+
if args.use_philly:
|
263 |
+
save_solid = False
|
264 |
+
while not save_solid:
|
265 |
+
try:
|
266 |
+
model_decoder_to_save.save_pretrained(output_decoder_dir)
|
267 |
+
torch.save(args, os.path.join(output_decoder_dir, 'training_args.bin'))
|
268 |
+
logger.info("Saving model checkpoint to %s", output_decoder_dir)
|
269 |
+
save_solid = True
|
270 |
+
except:
|
271 |
+
pass
|
272 |
+
else:
|
273 |
+
model_decoder_to_save.save_pretrained(output_decoder_dir)
|
274 |
+
torch.save(args, os.path.join(output_decoder_dir, 'training_args.bin'))
|
275 |
+
logger.info("Saving model checkpoint to %s", output_decoder_dir)
|
276 |
+
|
277 |
+
|
278 |
+
if args.max_steps > 0 and global_step > args.max_steps:
|
279 |
+
epoch_iterator.close()
|
280 |
+
break
|
281 |
+
|
282 |
+
|
283 |
+
if args.max_steps > 0 and global_step > args.max_steps:
|
284 |
+
train_iterator.close()
|
285 |
+
break
|
286 |
+
|
287 |
+
if args.local_rank in [-1, 0]:
|
288 |
+
tb_writer.close()
|
289 |
+
|
290 |
+
return global_step, tr_loss / global_step
|
291 |
+
|
292 |
+
|
293 |
+
def evaluate(args, model, encoder_tokenizer, decoder_tokenizer, table_name, prefix="", subset="test"):
|
294 |
+
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
295 |
+
eval_output_dir = args.output_dir
|
296 |
+
|
297 |
+
logger.info("***** Running evaluation on {} dataset *****".format(subset))
|
298 |
+
|
299 |
+
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
300 |
+
os.makedirs(eval_output_dir)
|
301 |
+
|
302 |
+
args.per_gpu_eval_batch_size = 1
|
303 |
+
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
304 |
+
|
305 |
+
eval_dataloader = build_dataload_and_cache_examples(args, [encoder_tokenizer, decoder_tokenizer], evaluate=True)
|
306 |
+
|
307 |
+
# Eval!
|
308 |
+
logger.info("***** Running evaluation {} *****".format(prefix))
|
309 |
+
logger.info(" Num examples = %d", len(eval_dataloader))
|
310 |
+
logger.info(" Batch size = %d", args.eval_batch_size)
|
311 |
+
eval_loss = 0.0
|
312 |
+
eval_loss_sum = 0.0
|
313 |
+
nb_eval_steps = 0
|
314 |
+
report_num_words = 0
|
315 |
+
|
316 |
+
model.eval()
|
317 |
+
|
318 |
+
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
319 |
+
|
320 |
+
_, tokenized_text1, tokenized_text_lengths = batch
|
321 |
+
inputs, labels = tokenized_text1.to(args.device), tokenized_text1.to(args.device)
|
322 |
+
|
323 |
+
x_lengths = tokenized_text_lengths[:,1].to(args.device)
|
324 |
+
report_num_words += x_lengths.sum().item()
|
325 |
+
|
326 |
+
|
327 |
+
with torch.no_grad():
|
328 |
+
outputs = model(inputs, labels=labels, label_ignore=decoder_tokenizer.pad_token_id)
|
329 |
+
lm_loss = outputs[0]
|
330 |
+
|
331 |
+
eval_loss += lm_loss.mean().item()/x_lengths.sum().item()
|
332 |
+
eval_loss_sum += lm_loss.sum().item()
|
333 |
+
|
334 |
+
nb_eval_steps += 1
|
335 |
+
|
336 |
+
eval_loss = eval_loss / nb_eval_steps
|
337 |
+
perplexity1 = torch.exp(torch.tensor(eval_loss))
|
338 |
+
perplexity2 = torch.exp(torch.tensor(eval_loss_sum / report_num_words))
|
339 |
+
|
340 |
+
|
341 |
+
result = {
|
342 |
+
"perplexity1": perplexity1, "perplexity2": perplexity2
|
343 |
+
}
|
344 |
+
|
345 |
+
output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
|
346 |
+
with open(output_eval_file, "w") as writer:
|
347 |
+
logger.info("***** Eval results {} *****".format(prefix))
|
348 |
+
for key in sorted(result.keys()):
|
349 |
+
logger.info(" %s = %s", key, str(result[key]))
|
350 |
+
writer.write("%s = %s\n" % (key, str(result[key])))
|
351 |
+
|
352 |
+
|
353 |
+
|
354 |
+
|
355 |
+
return result
|
356 |
+
|
357 |
+
|
358 |
+
def main():
|
359 |
+
parser = argparse.ArgumentParser()
|
360 |
+
|
361 |
+
## Required parameters
|
362 |
+
parser.add_argument("--train_data_file", default=None, type=str, required=True,
|
363 |
+
help="The input training data file (a text file).")
|
364 |
+
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
365 |
+
help="The output directory where the model predictions and checkpoints will be written.")
|
366 |
+
parser.add_argument("--dataset", default=None, type=str, help="The dataset.")
|
367 |
+
|
368 |
+
## Other parameters
|
369 |
+
parser.add_argument("--eval_data_file", default=None, type=str,
|
370 |
+
help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
|
371 |
+
parser.add_argument("--ExpName", default="", type=str,
|
372 |
+
help="The experiment name used in Azure Table.")
|
373 |
+
parser.add_argument("--save_bert_gpt_init", action='store_true',
|
374 |
+
help="Use Philly for computing.")
|
375 |
+
|
376 |
+
|
377 |
+
## Encoder options
|
378 |
+
parser.add_argument("--encoder_model_type", default="bert", type=str,
|
379 |
+
help="The encoder model architecture to be fine-tuned.")
|
380 |
+
parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
|
381 |
+
help="The encoder model checkpoint for weights initialization.")
|
382 |
+
parser.add_argument("--encoder_config_name", default="", type=str,
|
383 |
+
help="Optional pretrained config name or path if not the same as model_name_or_path")
|
384 |
+
parser.add_argument("--encoder_tokenizer_name", default="", type=str,
|
385 |
+
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
|
386 |
+
|
387 |
+
## Decoder options
|
388 |
+
parser.add_argument("--decoder_model_type", default="gpt2", type=str,
|
389 |
+
help="The decoder model architecture to be fine-tuned.")
|
390 |
+
parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
|
391 |
+
help="The decoder model checkpoint for weights initialization.")
|
392 |
+
parser.add_argument("--decoder_config_name", default="", type=str,
|
393 |
+
help="Optional pretrained config name or path if not the same as model_name_or_path")
|
394 |
+
parser.add_argument("--decoder_tokenizer_name", default="", type=str,
|
395 |
+
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
|
396 |
+
|
397 |
+
## Variational auto-encoder
|
398 |
+
parser.add_argument("--latent_size", default=32, type=int, help="Latent space dimension.")
|
399 |
+
parser.add_argument("--use_deterministic_connect", action='store_true',
|
400 |
+
help="Use deterministic inference to generate latent codes, i.e., standard auto-encoders.")
|
401 |
+
parser.add_argument("--use_pretrained_model", action='store_true',
|
402 |
+
help="Use pre-trained auto-encoder models as the initialization")
|
403 |
+
|
404 |
+
## Objective functions
|
405 |
+
parser.add_argument("--mlm", action='store_true',
|
406 |
+
help="Train with masked-language modeling loss instead of language modeling.")
|
407 |
+
parser.add_argument("--mlm_probability", type=float, default=0.15,
|
408 |
+
help="Ratio of tokens to mask for masked language modeling loss")
|
409 |
+
parser.add_argument("--beta", type=float, default=1.0,
|
410 |
+
help="The weighting hyper-parameter of the KL term in VAE")
|
411 |
+
|
412 |
+
|
413 |
+
parser.add_argument("--cache_dir", default="", type=str,
|
414 |
+
help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
|
415 |
+
parser.add_argument("--max_seq_length", default=512, type=int,
|
416 |
+
help="Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length")
|
417 |
+
parser.add_argument("--block_size", default=-1, type=int,
|
418 |
+
help="Optional input sequence length after tokenization."
|
419 |
+
"The training dataset will be truncated in block of this size for training."
|
420 |
+
"Default to the model max input length for single sentence inputs (take into account special tokens).")
|
421 |
+
parser.add_argument("--do_train", action='store_true',
|
422 |
+
help="Whether to run training.")
|
423 |
+
parser.add_argument("--do_eval", action='store_true',
|
424 |
+
help="Whether to run eval on the dev set.")
|
425 |
+
parser.add_argument("--evaluate_during_training", action='store_true',
|
426 |
+
help="Run evaluation during training at each logging step.")
|
427 |
+
parser.add_argument("--do_lower_case", action='store_true',
|
428 |
+
help="Set this flag if you are using an uncased model.")
|
429 |
+
|
430 |
+
|
431 |
+
# Training Schedules
|
432 |
+
parser.add_argument("--ratio_increase", default=0.25, type=float,
|
433 |
+
help="Learning schedule, the percentage for the annealing stage.")
|
434 |
+
parser.add_argument("--ratio_zero", default=0.25, type=float,
|
435 |
+
help="Learning schedule, the percentage for the pure auto-encoding stage.")
|
436 |
+
parser.add_argument("--fb_mode", default=0, type=int,
|
437 |
+
help="free bit training mode.")
|
438 |
+
parser.add_argument("--dim_target_kl", default=3.0, type=float,
|
439 |
+
help="dim_target_kl free bit training mode.")
|
440 |
+
parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
|
441 |
+
help="Batch size per GPU/CPU for training.")
|
442 |
+
parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
|
443 |
+
help="Batch size per GPU/CPU for evaluation.")
|
444 |
+
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
445 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
446 |
+
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
447 |
+
help="The initial learning rate for Adam.")
|
448 |
+
parser.add_argument("--weight_decay", default=0.0, type=float,
|
449 |
+
help="Weight deay if we apply some.")
|
450 |
+
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
451 |
+
help="Epsilon for Adam optimizer.")
|
452 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
453 |
+
help="Max gradient norm.")
|
454 |
+
parser.add_argument("--num_train_epochs", default=1.0, type=float,
|
455 |
+
help="Total number of training epochs to perform.")
|
456 |
+
parser.add_argument("--max_steps", default=-1, type=int,
|
457 |
+
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
458 |
+
parser.add_argument("--warmup_steps", default=0, type=int,
|
459 |
+
help="Linear warmup over warmup_steps.")
|
460 |
+
parser.add_argument("--use_philly", action='store_true',
|
461 |
+
help="Use Philly for computing.")
|
462 |
+
|
463 |
+
|
464 |
+
## IO: Logging and Saving
|
465 |
+
parser.add_argument('--logging_steps', type=int, default=50,
|
466 |
+
help="Log every X updates steps.")
|
467 |
+
parser.add_argument('--save_steps', type=int, default=50,
|
468 |
+
help="Save checkpoint every X updates steps.")
|
469 |
+
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
470 |
+
help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
|
471 |
+
parser.add_argument("--no_cuda", action='store_true',
|
472 |
+
help="Avoid using CUDA when available")
|
473 |
+
parser.add_argument('--overwrite_output_dir', action='store_true',
|
474 |
+
help="Overwrite the content of the output directory")
|
475 |
+
parser.add_argument('--overwrite_cache', action='store_true',
|
476 |
+
help="Overwrite the cached training and evaluation sets")
|
477 |
+
parser.add_argument('--seed', type=int, default=42,
|
478 |
+
help="random seed for initialization")
|
479 |
+
parser.add_argument('--gloabl_step_eval', type=int, default=661,
|
480 |
+
help="Evaluate the results at the given global step")
|
481 |
+
|
482 |
+
# Precision & Distributed Training
|
483 |
+
parser.add_argument('--fp16', action='store_true',
|
484 |
+
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
485 |
+
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
486 |
+
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
487 |
+
"See details at https://nvidia.github.io/apex/amp.html")
|
488 |
+
parser.add_argument("--local_rank", type=int, default=-1,
|
489 |
+
help="For distributed training: local_rank")
|
490 |
+
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
|
491 |
+
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
492 |
+
args = parser.parse_args()
|
493 |
+
|
494 |
+
if args.decoder_model_type in ["bert", "roberta"] and not args.mlm:
|
495 |
+
raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
|
496 |
+
"flag (masked language modeling).")
|
497 |
+
if args.eval_data_file is None and args.do_eval:
|
498 |
+
raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
|
499 |
+
"or remove the --do_eval argument.")
|
500 |
+
|
501 |
+
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
502 |
+
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
503 |
+
|
504 |
+
# Setup distant debugging if needed
|
505 |
+
if args.server_ip and args.server_port:
|
506 |
+
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
507 |
+
import ptvsd
|
508 |
+
print("Waiting for debugger attach")
|
509 |
+
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
510 |
+
ptvsd.wait_for_attach()
|
511 |
+
|
512 |
+
# Setup CUDA, GPU & distributed training
|
513 |
+
if args.local_rank == -1 or args.no_cuda:
|
514 |
+
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
515 |
+
args.n_gpu = torch.cuda.device_count()
|
516 |
+
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
517 |
+
torch.cuda.set_device(args.local_rank)
|
518 |
+
device = torch.device("cuda", args.local_rank)
|
519 |
+
torch.distributed.init_process_group(backend='nccl')
|
520 |
+
args.n_gpu = 1
|
521 |
+
args.device = device
|
522 |
+
|
523 |
+
# Setup logging
|
524 |
+
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
525 |
+
datefmt = '%m/%d/%Y %H:%M:%S',
|
526 |
+
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
527 |
+
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
528 |
+
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
529 |
+
|
530 |
+
args.ExpName = 'Vae_' + args.dataset + '_Nz_' + str(args.latent_size) + '_Beta_' + str(args.beta) + '_Dkl_' + str(args.dim_target_kl) + '_Ra_' + str(args.ratio_increase) + '_R0_' + str(args.ratio_zero)
|
531 |
+
table_name = 'Vae' + args.dataset + 'Nz' + str(args.latent_size)
|
532 |
+
try:
|
533 |
+
ts.create_table(table_name)
|
534 |
+
except:
|
535 |
+
pass
|
536 |
+
|
537 |
+
|
538 |
+
# Set seed
|
539 |
+
set_seed(args)
|
540 |
+
|
541 |
+
# Load pretrained model and tokenizer
|
542 |
+
if args.local_rank not in [-1, 0]:
|
543 |
+
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
|
544 |
+
|
545 |
+
|
546 |
+
## Encoder
|
547 |
+
encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
|
548 |
+
encoder_config = encoder_config_class.from_pretrained(args.encoder_config_name if args.encoder_config_name else args.encoder_model_name_or_path)
|
549 |
+
tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)
|
550 |
+
if args.block_size <= 0:
|
551 |
+
args.block_size = tokenizer_encoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
552 |
+
args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
|
553 |
+
model_encoder = encoder_model_class.from_pretrained(args.encoder_model_name_or_path, from_tf=bool('.ckpt' in args.encoder_model_name_or_path), config=encoder_config, latent_size=args.latent_size)
|
554 |
+
# model_encoder.to(args.device)
|
555 |
+
|
556 |
+
## Decoder
|
557 |
+
decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
|
558 |
+
decoder_config = decoder_config_class.from_pretrained(args.decoder_config_name if args.decoder_config_name else args.decoder_model_name_or_path)
|
559 |
+
tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
|
560 |
+
if args.block_size <= 0:
|
561 |
+
args.block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
562 |
+
args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
|
563 |
+
model_decoder = decoder_model_class.from_pretrained(args.decoder_model_name_or_path, from_tf=bool('.ckpt' in args.decoder_model_name_or_path), config=decoder_config)
|
564 |
+
|
565 |
+
# Chunyuan: Add Padding token to GPT2
|
566 |
+
special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
|
567 |
+
num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
|
568 |
+
print('We have added', num_added_toks, 'tokens to GPT2')
|
569 |
+
model_decoder.resize_token_embeddings(len(tokenizer_decoder)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
|
570 |
+
assert tokenizer_decoder.pad_token == '<PAD>'
|
571 |
+
|
572 |
+
model_decoder.to(args.device)
|
573 |
+
|
574 |
+
|
575 |
+
if args.local_rank == 0:
|
576 |
+
torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training download model & vocab
|
577 |
+
|
578 |
+
logger.info("Training/evaluation parameters %s", args)
|
579 |
+
|
580 |
+
global_step= 0
|
581 |
+
# Training
|
582 |
+
if args.do_train:
|
583 |
+
if args.local_rank not in [-1, 0]:
|
584 |
+
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache
|
585 |
+
|
586 |
+
train_dataloader = build_dataload_and_cache_examples(args, [tokenizer_encoder, tokenizer_decoder], evaluate=False)
|
587 |
+
|
588 |
+
if args.local_rank == 0:
|
589 |
+
torch.distributed.barrier()
|
590 |
+
|
591 |
+
global_step, tr_loss = train(args, train_dataloader, model_decoder, tokenizer_encoder, tokenizer_decoder, table_name)
|
592 |
+
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
593 |
+
|
594 |
+
|
595 |
+
# Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
|
596 |
+
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
597 |
+
# Create output directory if needed
|
598 |
+
# Save model checkpoint
|
599 |
+
output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
|
600 |
+
if not os.path.exists(output_decoder_dir) and args.local_rank in [-1, 0]:
|
601 |
+
os.makedirs(output_decoder_dir)
|
602 |
+
|
603 |
+
|
604 |
+
logger.info("Saving decoder model checkpoint to %s", output_decoder_dir)
|
605 |
+
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
606 |
+
# They can then be reloaded using `from_pretrained()`
|
607 |
+
|
608 |
+
model_decoder_to_save = model_decoder.module if hasattr(model_decoder, 'module') else model_decoder # Take care of distributed/parallel training
|
609 |
+
|
610 |
+
# Good practice: save your training arguments together with the trained model
|
611 |
+
|
612 |
+
if args.use_philly:
|
613 |
+
save_solid = False
|
614 |
+
while not save_solid:
|
615 |
+
try:
|
616 |
+
model_decoder_to_save.save_pretrained(output_decoder_dir)
|
617 |
+
torch.save(args, os.path.join(output_decoder_dir, 'training_decoder_args.bin'))
|
618 |
+
save_solid = True
|
619 |
+
except:
|
620 |
+
pass
|
621 |
+
else:
|
622 |
+
model_decoder_to_save.save_pretrained(output_decoder_dir)
|
623 |
+
torch.save(args, os.path.join(output_decoder_dir, 'training_encoder_args.bin'))
|
624 |
+
|
625 |
+
# Load a trained model and vocabulary that you have fine-tuned
|
626 |
+
model_decoder = decoder_model_class.from_pretrained(output_decoder_dir)
|
627 |
+
model_decoder.to(args.device)
|
628 |
+
|
629 |
+
|
630 |
+
# Evaluation
|
631 |
+
results = {}
|
632 |
+
if args.do_eval and args.local_rank in [-1, 0]:
|
633 |
+
if global_step == 0:
|
634 |
+
global_step = args.gloabl_step_eval
|
635 |
+
|
636 |
+
output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
|
637 |
+
checkpoints = [ output_decoder_dir ]
|
638 |
+
|
639 |
+
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
640 |
+
for checkpoint in checkpoints:
|
641 |
+
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
642 |
+
|
643 |
+
model_decoder = decoder_model_class.from_pretrained(checkpoint)
|
644 |
+
model_decoder.to(args.device)
|
645 |
+
|
646 |
+
result = evaluate(args, model_decoder, tokenizer_encoder, tokenizer_decoder, table_name, prefix=global_step, subset='test')
|
647 |
+
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
648 |
+
results.update(result)
|
649 |
+
|
650 |
+
# result = evaluate(args, model_vae, tokenizer_encoder, tokenizer_decoder, table_name, prefix=global_step, subset='train')
|
651 |
+
# result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
652 |
+
# results.update(result)
|
653 |
+
|
654 |
+
return results
|
655 |
+
|
656 |
+
|
657 |
+
if __name__ == "__main__":
|
658 |
+
main()
|
Optimus/code/examples/big_ae/run_lm_vae_label_ctrl_gen.py
ADDED
@@ -0,0 +1,875 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
|
18 |
+
GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
|
19 |
+
using a masked language modeling (MLM) loss.
|
20 |
+
"""
|
21 |
+
|
22 |
+
from __future__ import absolute_import, division, print_function
|
23 |
+
import pdb
|
24 |
+
import argparse
|
25 |
+
import glob
|
26 |
+
import logging
|
27 |
+
import os
|
28 |
+
import pickle
|
29 |
+
import random
|
30 |
+
import numpy as np
|
31 |
+
import torch
|
32 |
+
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
|
33 |
+
from torch.utils.data.distributed import DistributedSampler
|
34 |
+
from tensorboardX import SummaryWriter
|
35 |
+
from tqdm import tqdm, trange
|
36 |
+
from collections import defaultdict
|
37 |
+
# from azure.cosmosdb.table.tableservice import TableService
|
38 |
+
# from azure.cosmosdb.table.models import Entity
|
39 |
+
from datetime import datetime
|
40 |
+
import sys
|
41 |
+
import json
|
42 |
+
import nltk
|
43 |
+
nltk.download('punkt')
|
44 |
+
|
45 |
+
sys.path.append('../../')
|
46 |
+
from pytorch_transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
|
47 |
+
BertConfig, BertForLatentConnector, BertTokenizer,
|
48 |
+
GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer,
|
49 |
+
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
|
50 |
+
RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
|
51 |
+
from utils import (TextDataset_Split, TextDataset_2Tokenizers_LCtrlG,
|
52 |
+
frange_cycle_linear, frange_cycle_zero_linear, AverageValueMeter)
|
53 |
+
# from modules import ARAE
|
54 |
+
from modules import CARA
|
55 |
+
# logging.getLogger("azure").setLevel(logging.WARNING)
|
56 |
+
# logging.getLogger("TableService").setLevel(logging.WARNING)
|
57 |
+
logger = logging.getLogger(__name__)
|
58 |
+
import time
|
59 |
+
def get_time_str():
|
60 |
+
return time.ctime().replace(' ', '_').replace(':', '-')
|
61 |
+
|
62 |
+
MODEL_CLASSES = {
|
63 |
+
'gpt2': (GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer),
|
64 |
+
'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
|
65 |
+
'bert': (BertConfig, BertForLatentConnector, BertTokenizer),
|
66 |
+
'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
|
67 |
+
}
|
68 |
+
|
69 |
+
|
70 |
+
storage_name="textae"
|
71 |
+
key=r"6yBCXlblof8DVFJ4BD3eNFTrGQCej6cKfCf5z308cKnevyHaG+yl/m+ITVErB9yt0kvN3ToqxLIh0knJEfFmPA=="
|
72 |
+
# ts = TableService(account_name=storage_name, account_key=key)
|
73 |
+
|
74 |
+
def load_and_cache_examples(args, tokenizer, evaluate=False):
|
75 |
+
if isinstance(tokenizer, list):
|
76 |
+
dataset = TextDataset_2Tokenizers_LCtrlG(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file,
|
77 |
+
block_size=args.block_size, create_new=args.create_new)
|
78 |
+
else:
|
79 |
+
raise NotImplementedError
|
80 |
+
# dataset = TextDataset_Split(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
|
81 |
+
return dataset
|
82 |
+
|
83 |
+
def set_seed(args):
|
84 |
+
random.seed(args.seed)
|
85 |
+
np.random.seed(args.seed)
|
86 |
+
torch.manual_seed(args.seed)
|
87 |
+
if args.n_gpu > 0:
|
88 |
+
torch.cuda.manual_seed_all(args.seed)
|
89 |
+
|
90 |
+
def mask_tokens(inputs, tokenizer, args):
|
91 |
+
""" Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
|
92 |
+
labels = inputs.clone()
|
93 |
+
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
|
94 |
+
|
95 |
+
masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).to(torch.uint8)
|
96 |
+
labels[masked_indices==1] = -1 # We only compute loss on masked tokens
|
97 |
+
|
98 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
99 |
+
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).to(torch.uint8) & masked_indices
|
100 |
+
inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
|
101 |
+
|
102 |
+
# 10% of the time, we replace masked input tokens with random word
|
103 |
+
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).to(torch.uint8) & masked_indices & ~indices_replaced
|
104 |
+
indices_random = indices_random
|
105 |
+
random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
|
106 |
+
inputs[indices_random] = random_words[indices_random]
|
107 |
+
|
108 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
109 |
+
return inputs, labels
|
110 |
+
|
111 |
+
def train(args, train_dataset, model_vae, encoder_tokenizer, decoder_tokenizer, table_name, logff):
|
112 |
+
""" Train the model """
|
113 |
+
if args.local_rank in [-1, 0]:
|
114 |
+
tb_writer = SummaryWriter()
|
115 |
+
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
116 |
+
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
117 |
+
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
118 |
+
if args.max_steps > 0:
|
119 |
+
t_total = args.max_steps
|
120 |
+
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
121 |
+
else:
|
122 |
+
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
123 |
+
# Prepare optimizer and schedule (linear warmup and decay)
|
124 |
+
# model_encoder, model_decoder, model_connector = model_vae.encoder, model_vae.decoder, model_vae.linear
|
125 |
+
no_decay = ['bias', 'LayerNorm.weight']
|
126 |
+
optimizer_grouped_parameters = [
|
127 |
+
{'params': [p for n, p in model_vae.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
128 |
+
{'params': [p for n, p in model_vae.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
129 |
+
]
|
130 |
+
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
131 |
+
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
|
132 |
+
if args.fp16:
|
133 |
+
try:
|
134 |
+
from apex import amp
|
135 |
+
except ImportError:
|
136 |
+
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
137 |
+
model_vae, optimizer = amp.initialize(model_vae, optimizer, opt_level=args.fp16_opt_level)
|
138 |
+
# multi-gpu training (should be after apex fp16 initialization)
|
139 |
+
if args.n_gpu > 1:
|
140 |
+
model_vae = torch.nn.DataParallel(model_vae, device_ids=range(args.n_gpu)).to(args.device)
|
141 |
+
# Distributed training (should be after apex fp16 initialization)
|
142 |
+
if args.local_rank != -1:
|
143 |
+
model_vae = torch.nn.parallel.DistributedDataParallel(model_vae, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
|
144 |
+
# model_vae = model_vae.module if hasattr(model_vae, 'module') else model_vae # Take care of distributed/parallel training
|
145 |
+
|
146 |
+
# Train!
|
147 |
+
logger.info("***** Running training *****")
|
148 |
+
logff.write("***** Running training *****\n")
|
149 |
+
logger.info(" Num examples = {}".format(len(train_dataset)))
|
150 |
+
logff.write(" Num examples = {}\n".format(len(train_dataset)))
|
151 |
+
logger.info(" Num Epochs = {}".format(args.num_train_epochs))
|
152 |
+
logff.write(" Num Epochs = {}\n".format(args.num_train_epochs))
|
153 |
+
logger.info(" Instantaneous batch size per GPU = {}".format(args.per_gpu_train_batch_size))
|
154 |
+
logff.write(" Instantaneous batch size per GPU = {}\n".format(args.per_gpu_train_batch_size))
|
155 |
+
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
156 |
+
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
157 |
+
logff.write(" Total train batch size (w. parallel, distributed & accumulation) = {}\n".format(
|
158 |
+
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)))
|
159 |
+
logger.info(" Gradient Accumulation steps = {}".format(args.gradient_accumulation_steps))
|
160 |
+
logff.write(" Gradient Accumulation steps = {}\n".format(args.gradient_accumulation_steps))
|
161 |
+
logger.info(" Total optimization steps = {}".format( t_total))
|
162 |
+
logff.write(" Total optimization steps = {}\n".format(t_total))
|
163 |
+
logff.flush()
|
164 |
+
global_step = 0
|
165 |
+
tr_loss, logging_loss = 0.0, 0.0
|
166 |
+
model_vae.zero_grad()
|
167 |
+
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
168 |
+
n_iter = int(args.num_train_epochs) * len(train_dataloader)
|
169 |
+
beta_t_list = frange_cycle_zero_linear(n_iter, start=1.0, stop=args.beta_cls, n_cycle=1, ratio_increase=args.ratio_increase, ratio_zero=args.ratio_zero)
|
170 |
+
|
171 |
+
set_seed(args) # Added here for reproducibility (even between python 2 and 3)
|
172 |
+
accmeter = {
|
173 |
+
'acc_encode_z_dis': AverageValueMeter(),
|
174 |
+
'acc_gen_z_dis': AverageValueMeter(),
|
175 |
+
'acc_encode_z_cls': AverageValueMeter(),
|
176 |
+
'acc_cls': AverageValueMeter(),
|
177 |
+
# 'acc_at_soft_cls': AverageValueMeter(),
|
178 |
+
}
|
179 |
+
lossmeter = {
|
180 |
+
'loss': AverageValueMeter(),
|
181 |
+
'loss_rec': AverageValueMeter(),
|
182 |
+
'loss_encoder': AverageValueMeter(),
|
183 |
+
'loss_lsc': AverageValueMeter(),
|
184 |
+
'loss_lsd': AverageValueMeter(),
|
185 |
+
'loss_lsg': AverageValueMeter(),
|
186 |
+
'loss_cls': AverageValueMeter(),
|
187 |
+
# 'loss_at_soft_cls': AverageValueMeter(),
|
188 |
+
}
|
189 |
+
for epoch in train_iterator:
|
190 |
+
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
191 |
+
# pbar = tqdm(total=(len(train_dataloader)+1) // args.gradient_accumulation_steps)
|
192 |
+
for step, batch in enumerate(train_dataloader):
|
193 |
+
|
194 |
+
# if step > 100:
|
195 |
+
# break
|
196 |
+
|
197 |
+
# Data
|
198 |
+
input_seq_ids, tgt_seq_ids, tokenized_text_lengths, cond_labels = batch
|
199 |
+
max_len_values, _ = tokenized_text_lengths.max(0)
|
200 |
+
input_seq_ids = input_seq_ids[:,:max_len_values[0]]
|
201 |
+
tgt_seq_ids = tgt_seq_ids[:,:max_len_values[1]]
|
202 |
+
input_seq_ids, tgt_seq_ids = mask_tokens(input_seq_ids, encoder_tokenizer, args) if args.mlm else (input_seq_ids, tgt_seq_ids)
|
203 |
+
input_seq_ids = input_seq_ids.to(args.device)
|
204 |
+
tgt_seq_ids = tgt_seq_ids.to(args.device)
|
205 |
+
cond_labels = cond_labels.to(args.device)
|
206 |
+
input_mask = torch.where(torch.arange(max_len_values[0].item()).unsqueeze(0).repeat(input_seq_ids.size(0), 1).type_as(tokenized_text_lengths).to(args.device)
|
207 |
+
< tokenized_text_lengths[:, 0].unsqueeze(1).to(args.device), torch.ones_like(input_seq_ids), torch.zeros_like(input_seq_ids))
|
208 |
+
|
209 |
+
# Configs
|
210 |
+
model_vae.train()
|
211 |
+
beta_t = beta_t_list[step + epoch*len(epoch_iterator)]
|
212 |
+
model_vae.module.args.beta_cls = beta_t
|
213 |
+
# if beta_t == 0.0:
|
214 |
+
# model_vae.args.fb_mode = 0
|
215 |
+
# else:
|
216 |
+
# model_vae.args.fb_mode = 1
|
217 |
+
# if args.use_deterministic_connect:
|
218 |
+
# model_vae.args.fb_mode = 2
|
219 |
+
|
220 |
+
# Model
|
221 |
+
loss_dict, acc_dict = model_vae(input_seq_ids=input_seq_ids, tgt_seq_ids=tgt_seq_ids, cond_labels=cond_labels, attention_mask=input_mask)
|
222 |
+
|
223 |
+
# Loss
|
224 |
+
for key, value in loss_dict.items():
|
225 |
+
loss_dict[key] = value.mean()
|
226 |
+
|
227 |
+
loss = loss_dict['loss']
|
228 |
+
if args.gradient_accumulation_steps > 1:
|
229 |
+
loss = loss / args.gradient_accumulation_steps
|
230 |
+
if args.fp16:
|
231 |
+
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
232 |
+
scaled_loss.backward()
|
233 |
+
else:
|
234 |
+
loss.backward()
|
235 |
+
tr_loss += loss.item()
|
236 |
+
|
237 |
+
# Log
|
238 |
+
for key, value in loss_dict.items():
|
239 |
+
lossmeter[key].add(value.item())
|
240 |
+
|
241 |
+
for key, value in acc_dict.items():
|
242 |
+
value = value.cpu().tolist()
|
243 |
+
for v in value:
|
244 |
+
accmeter[key].add(float(v))
|
245 |
+
|
246 |
+
# Optimize
|
247 |
+
if (step + 1) % args.gradient_accumulation_steps == 0:
|
248 |
+
# Optimize
|
249 |
+
if args.fp16:
|
250 |
+
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
251 |
+
else:
|
252 |
+
torch.nn.utils.clip_grad_norm_(model_vae.parameters(), args.max_grad_norm)
|
253 |
+
optimizer.step()
|
254 |
+
scheduler.step() # Update learning rate schedule
|
255 |
+
model_vae.zero_grad()
|
256 |
+
global_step += 1
|
257 |
+
# pbar.update(1)
|
258 |
+
|
259 |
+
# Log
|
260 |
+
if global_step % args.logging_steps == 0:
|
261 |
+
logger.info("\n")
|
262 |
+
logger.info("global_step: {}, avg loss: {:3f}".format(global_step, tr_loss/global_step))
|
263 |
+
logff.write("global_step: {}, avg loss: {:3f}\n".format(global_step, tr_loss/global_step))
|
264 |
+
logger.info("loss: {}".format(', '.join(key + ': ' + str(round(meter.mean, 3)) for key, meter in lossmeter.items())))
|
265 |
+
logff.write("loss: {}\n".format(', '.join(key + ': ' + str(round(meter.mean, 3)) for key, meter in lossmeter.items())))
|
266 |
+
logger.info("acc: {}".format(', '.join(key + ': ' + str(round(meter.mean, 3)) for key, meter in accmeter.items())))
|
267 |
+
logff.write("acc: {}\n".format(', '.join(key + ': ' + str(round(meter.mean, 3)) for key, meter in accmeter.items())))
|
268 |
+
logff.flush()
|
269 |
+
|
270 |
+
|
271 |
+
if args.use_philly:
|
272 |
+
#if args.local_rank in [-1, 0]:
|
273 |
+
if args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
274 |
+
logger.info("PROGRESS: {}%".format(round(100 * (step + epoch*len(train_dataloader) ) /(int(args.num_train_epochs) * len(train_dataloader)) , 4)))
|
275 |
+
logger.info("EVALERR: {}%".format(tr_loss / global_step))
|
276 |
+
|
277 |
+
|
278 |
+
if args.local_rank in [-1, 0] and args.eval_steps > 0 and global_step % args.eval_steps == 0:
|
279 |
+
# Log metrics
|
280 |
+
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
281 |
+
results = evaluate(args, model_vae, encoder_tokenizer, decoder_tokenizer, table_name, epoch=epoch)
|
282 |
+
for key, value in results.items():
|
283 |
+
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
284 |
+
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
|
285 |
+
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.eval_steps, global_step)
|
286 |
+
logging_loss = tr_loss
|
287 |
+
|
288 |
+
# Save checkpoints
|
289 |
+
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
290 |
+
# Save encoder model checkpoint
|
291 |
+
output_encoder_dir = os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(global_step))
|
292 |
+
if not os.path.exists(output_encoder_dir):
|
293 |
+
os.makedirs(output_encoder_dir)
|
294 |
+
model_encoder_to_save = model_vae.module.encoder if hasattr(model_vae, 'module') else model_vae.encoder # Take care of distributed/parallel training
|
295 |
+
if args.use_philly:
|
296 |
+
save_solid = False
|
297 |
+
while not save_solid:
|
298 |
+
try:
|
299 |
+
model_encoder_to_save.save_pretrained(output_encoder_dir)
|
300 |
+
torch.save(args, os.path.join(output_encoder_dir, 'training_args.bin'))
|
301 |
+
logger.info("Saving model checkpoint to %s", output_encoder_dir)
|
302 |
+
save_solid = True
|
303 |
+
except:
|
304 |
+
pass
|
305 |
+
else:
|
306 |
+
model_encoder_to_save.save_pretrained(output_encoder_dir)
|
307 |
+
torch.save(args, os.path.join(output_encoder_dir, 'training_args.bin'))
|
308 |
+
logger.info("Saving model checkpoint to %s", output_encoder_dir)
|
309 |
+
|
310 |
+
# Save decoder model checkpoint
|
311 |
+
output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
|
312 |
+
if not os.path.exists(output_decoder_dir):
|
313 |
+
os.makedirs(output_decoder_dir)
|
314 |
+
model_decoder_to_save = model_vae.module.decoder if hasattr(model_vae, 'module') else model_vae.decoder # Take care of distributed/parallel training
|
315 |
+
if args.use_philly:
|
316 |
+
save_solid = False
|
317 |
+
while not save_solid:
|
318 |
+
try:
|
319 |
+
model_decoder_to_save.save_pretrained(output_decoder_dir)
|
320 |
+
torch.save(args, os.path.join(output_decoder_dir, 'training_args.bin'))
|
321 |
+
logger.info("Saving model checkpoint to %s", output_decoder_dir)
|
322 |
+
save_solid = True
|
323 |
+
except:
|
324 |
+
pass
|
325 |
+
else:
|
326 |
+
model_decoder_to_save.save_pretrained(output_decoder_dir)
|
327 |
+
torch.save(args, os.path.join(output_decoder_dir, 'training_args.bin'))
|
328 |
+
logger.info("Saving model checkpoint to %s", output_decoder_dir)
|
329 |
+
|
330 |
+
if args.max_steps > 0 and global_step > args.max_steps:
|
331 |
+
break
|
332 |
+
|
333 |
+
if args.max_steps > 0 and global_step > args.max_steps:
|
334 |
+
train_iterator.close()
|
335 |
+
break
|
336 |
+
|
337 |
+
if args.local_rank in [-1, 0]:
|
338 |
+
tb_writer.close()
|
339 |
+
|
340 |
+
return global_step, tr_loss / global_step
|
341 |
+
|
342 |
+
|
343 |
+
def evaluate(args, model_vae, encoder_tokenizer, decoder_tokenizer, table_name, prefix="", subset="test", epoch=None):
|
344 |
+
|
345 |
+
eval_output_dir = args.output_dir
|
346 |
+
|
347 |
+
if subset == 'test':
|
348 |
+
eval_dataset = load_and_cache_examples(args, [encoder_tokenizer, decoder_tokenizer], evaluate=True)
|
349 |
+
elif subset == 'train':
|
350 |
+
eval_dataset = load_and_cache_examples(args, [encoder_tokenizer, decoder_tokenizer], evaluate=False)
|
351 |
+
else:
|
352 |
+
raise ValueError
|
353 |
+
|
354 |
+
args.label_size = len(eval_dataset.get_labels())
|
355 |
+
|
356 |
+
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
357 |
+
os.makedirs(eval_output_dir)
|
358 |
+
|
359 |
+
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
360 |
+
# Note that DistributedSampler samples randomly
|
361 |
+
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
|
362 |
+
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
363 |
+
|
364 |
+
# Eval!
|
365 |
+
logger.info("***** Running evaluation {} *****".format(prefix))
|
366 |
+
logger.info(" Num examples = %d", len(eval_dataset))
|
367 |
+
logger.info(" Batch size = %d", args.eval_batch_size)
|
368 |
+
logger.info(" Num steps = %d", len(eval_dataset) // args.eval_batch_size)
|
369 |
+
logger.info(" eval_output_dir = %s", eval_output_dir)
|
370 |
+
|
371 |
+
model_vae.eval()
|
372 |
+
model_vae_module = model_vae.module if hasattr(model_vae, 'module') else model_vae # Take care of distributed/parallel training
|
373 |
+
|
374 |
+
outputs = {
|
375 |
+
'sampled_cond_labels': None,
|
376 |
+
'cond_labels': None,
|
377 |
+
'tgt_seq_ids': None,
|
378 |
+
'generated': None,
|
379 |
+
'at_generated': None,
|
380 |
+
'cg_generated': None,
|
381 |
+
'pred_cls': None,
|
382 |
+
'pred_ge_cls': None,
|
383 |
+
'pred_at_cls': None,
|
384 |
+
'pred_cg_cls': None,
|
385 |
+
}
|
386 |
+
|
387 |
+
for bi, batch in enumerate(tqdm(eval_dataloader, desc="#Sentences", disable=args.local_rank not in [-1, 0]) ):
|
388 |
+
# if bi == 3:
|
389 |
+
# break
|
390 |
+
|
391 |
+
# Data
|
392 |
+
input_seq_ids, tgt_seq_ids, tokenized_text_lengths, cond_labels = batch
|
393 |
+
max_len_values, _ = tokenized_text_lengths.max(0)
|
394 |
+
input_seq_ids = input_seq_ids[:,:max_len_values[0]]
|
395 |
+
tgt_seq_ids = tgt_seq_ids[:,:max_len_values[1]]
|
396 |
+
input_seq_ids = input_seq_ids.to(args.device)
|
397 |
+
tgt_seq_ids = tgt_seq_ids.to(args.device)
|
398 |
+
cond_labels = cond_labels.to(args.device)
|
399 |
+
input_mask = torch.where(torch.arange(max_len_values[0].item()).unsqueeze(0).repeat(input_seq_ids.size(0), 1).type_as(tokenized_text_lengths).to(args.device)
|
400 |
+
< tokenized_text_lengths[:, 0].unsqueeze(1).to(args.device), torch.ones_like(input_seq_ids), torch.zeros_like(input_seq_ids))
|
401 |
+
|
402 |
+
# Model
|
403 |
+
with torch.no_grad():
|
404 |
+
result = model_vae(input_seq_ids=input_seq_ids, tgt_seq_ids=tgt_seq_ids, cond_labels=cond_labels, attention_mask=input_mask)
|
405 |
+
if bi == 0:
|
406 |
+
for key in outputs.keys():
|
407 |
+
outputs[key] = result[key].cpu().tolist()
|
408 |
+
else:
|
409 |
+
for key in outputs.keys():
|
410 |
+
outputs[key].extend(result[key].cpu().tolist())
|
411 |
+
|
412 |
+
# compute accuracies and store in results
|
413 |
+
acc = np.mean(np.array(np.array(outputs['pred_cls']) == np.array(outputs['cond_labels']), dtype=np.float))
|
414 |
+
acc_ge = np.mean(np.array(np.array(outputs['pred_ge_cls']) == np.array(outputs['cond_labels']), dtype=np.float))
|
415 |
+
acc_at = np.mean(np.array(np.array(outputs['pred_at_cls']) == np.array(outputs['sampled_cond_labels']), dtype=np.float))
|
416 |
+
acc_cg = np.mean(np.array(np.array(outputs['pred_cg_cls']) == np.array(outputs['sampled_cond_labels']), dtype=np.float))
|
417 |
+
metrics = {'acc': acc, 'acc_ge': acc_ge, 'acc_at': acc_at, 'acc_cg': acc_cg}
|
418 |
+
|
419 |
+
# dump generated outputs to file.
|
420 |
+
json.dump(outputs, open(os.path.join(eval_output_dir, "outputs_{}.json".format(epoch) if epoch is not None else "outputs.json"), 'w'))
|
421 |
+
|
422 |
+
# compute BLEU
|
423 |
+
bos_token_id = model_vae_module.tokenizer_decoder.encode('<BOS>')[0]
|
424 |
+
eos_token_id = model_vae_module.tokenizer_decoder.encode('<EOS>')[0]
|
425 |
+
pad_token_id = model_vae_module.tokenizer_decoder.encode('<PAD>')[0]
|
426 |
+
|
427 |
+
generated_ids = []
|
428 |
+
generated_text = []
|
429 |
+
for g in outputs['generated']:
|
430 |
+
if g and g[0] in [eos_token_id, bos_token_id]:
|
431 |
+
g = g[1:]
|
432 |
+
if g and g[0] in [eos_token_id, bos_token_id]:
|
433 |
+
g = g[1:]
|
434 |
+
g = g[:g.index(eos_token_id)] if eos_token_id in g else g
|
435 |
+
g = g[:g.index(pad_token_id)] if pad_token_id in g else g
|
436 |
+
g_text = model_vae_module.tokenizer_decoder.decode(g, clean_up_tokenization_spaces=True)
|
437 |
+
generated_ids.append(g)
|
438 |
+
generated_text.append(g_text)
|
439 |
+
|
440 |
+
tgt_seq_ids = []
|
441 |
+
tgt_seq_text = []
|
442 |
+
for g in outputs['tgt_seq_ids']:
|
443 |
+
if g and g[0] in [eos_token_id, bos_token_id]:
|
444 |
+
g = g[1:]
|
445 |
+
if g and g[0] in [eos_token_id, bos_token_id]:
|
446 |
+
g = g[1:]
|
447 |
+
g = g[:g.index(eos_token_id)] if eos_token_id in g else g
|
448 |
+
g = g[:g.index(pad_token_id)] if pad_token_id in g else g
|
449 |
+
g_text = model_vae_module.tokenizer_decoder.decode(g, clean_up_tokenization_spaces=True)
|
450 |
+
tgt_seq_ids.append(g)
|
451 |
+
tgt_seq_text.append(g_text)
|
452 |
+
|
453 |
+
at_generated_ids = []
|
454 |
+
at_generated_text = []
|
455 |
+
for g in outputs['at_generated']:
|
456 |
+
if g and g[0] in [eos_token_id, bos_token_id]:
|
457 |
+
g = g[1:]
|
458 |
+
if g and g[0] in [eos_token_id, bos_token_id]:
|
459 |
+
g = g[1:]
|
460 |
+
g = g[:g.index(eos_token_id)] if eos_token_id in g else g
|
461 |
+
g = g[:g.index(pad_token_id)] if pad_token_id in g else g
|
462 |
+
g_text = model_vae_module.tokenizer_decoder.decode(g, clean_up_tokenization_spaces=True)
|
463 |
+
at_generated_ids.append(g)
|
464 |
+
at_generated_text.append(g_text)
|
465 |
+
|
466 |
+
cg_generated_ids = []
|
467 |
+
cg_generated_text = []
|
468 |
+
for g in outputs['cg_generated']:
|
469 |
+
if g and g[0] in [eos_token_id, bos_token_id]:
|
470 |
+
g = g[1:]
|
471 |
+
if g and g[0] in [eos_token_id, bos_token_id]:
|
472 |
+
g = g[1:]
|
473 |
+
g = g[:g.index(eos_token_id)] if eos_token_id in g else g
|
474 |
+
g = g[:g.index(pad_token_id)] if pad_token_id in g else g
|
475 |
+
g_text = model_vae_module.tokenizer_decoder.decode(g, clean_up_tokenization_spaces=True)
|
476 |
+
cg_generated_ids.append(g)
|
477 |
+
cg_generated_text.append(g_text)
|
478 |
+
|
479 |
+
f = open(os.path.join(eval_output_dir, "reconstruction{}.txt".format(('_'+str(epoch)) if epoch is not None else '')), 'w')
|
480 |
+
f.write('\n'.join([g + '\n' + t for g, t in zip(generated_text, tgt_seq_text)]))
|
481 |
+
fat = open(os.path.join(eval_output_dir, "attribute_transfer{}.txt".format(('_'+str(epoch)) if epoch is not None else '')), 'w')
|
482 |
+
fat.write('\n'.join([g + '\n' + t for g, t in zip(at_generated_text, tgt_seq_text)]))
|
483 |
+
fcg = open(os.path.join(eval_output_dir, "conditional_generation{}.txt".format(('_'+str(epoch)) if epoch is not None else '')), 'w')
|
484 |
+
fcg.write('\n'.join(cg_generated_text))
|
485 |
+
|
486 |
+
rec_bleu = nltk.translate.bleu_score.corpus_bleu(list_of_references=[[nltk.word_tokenize(t)] for t in tgt_seq_text],
|
487 |
+
hypotheses=[nltk.word_tokenize(g) for g in generated_text])
|
488 |
+
|
489 |
+
at_bleu = nltk.translate.bleu_score.corpus_bleu(list_of_references=[[nltk.word_tokenize(t)] for t in tgt_seq_text],
|
490 |
+
hypotheses=[nltk.word_tokenize(g) for g in at_generated_text])
|
491 |
+
|
492 |
+
cg_generated_text_subset = cg_generated_text[:500] # use a subset, otherwise it takes a long time to compute.
|
493 |
+
cg_bleu = nltk.translate.bleu_score.corpus_bleu(list_of_references=[[nltk.word_tokenize(t) for t in tgt_seq_text] for _ in range(len(cg_generated_text_subset))],
|
494 |
+
hypotheses=[nltk.word_tokenize(g) for g in cg_generated_text_subset])
|
495 |
+
|
496 |
+
cg_self_bleu = nltk.translate.bleu_score.corpus_bleu(list_of_references=[[nltk.word_tokenize(t) for t in cg_generated_text_subset[:i]+cg_generated_text_subset[i+1:]]
|
497 |
+
for i in range(len(cg_generated_text_subset))],
|
498 |
+
hypotheses=[nltk.word_tokenize(g) for g in cg_generated_text_subset])
|
499 |
+
|
500 |
+
metrics['rec_bleu'] = rec_bleu
|
501 |
+
metrics['at_bleu'] = at_bleu
|
502 |
+
metrics['cg_bleu'] = cg_bleu
|
503 |
+
metrics['cg_self_bleu'] = cg_self_bleu
|
504 |
+
|
505 |
+
output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
|
506 |
+
writer = open(output_eval_file, "w")
|
507 |
+
logger.info("***** Eval results, global steps: {} *****".format(prefix))
|
508 |
+
for key, value in metrics.items():
|
509 |
+
logger.info(" %s = %s", key, str(value))
|
510 |
+
writer.write("%s = %s\n" % (key, str(value)))
|
511 |
+
|
512 |
+
return metrics
|
513 |
+
|
514 |
+
def main():
|
515 |
+
parser = argparse.ArgumentParser()
|
516 |
+
|
517 |
+
## Required parameters
|
518 |
+
parser.add_argument("--output_dir", default='results_cara', type=str, help="The output directory where the model predictions and checkpoints will be written.")
|
519 |
+
parser.add_argument("--temperature", type=float, default=1.0)
|
520 |
+
parser.add_argument("--soft_temperature", type=float, default=0.5)
|
521 |
+
parser.add_argument("--top_k", type=int, default=5)
|
522 |
+
parser.add_argument("--top_p", type=float, default=0.0)
|
523 |
+
parser.add_argument("--num_train_epochs", default=10.0, type=float, help="Total number of training epochs to perform.")
|
524 |
+
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
525 |
+
parser.add_argument("--lambda", default=0, type=float, help="")
|
526 |
+
|
527 |
+
## Data parameters
|
528 |
+
parser.add_argument("--dataset", default='yelp', type=str, help="The dataset.")
|
529 |
+
# parser.add_argument("--train_data_file", default='../../../data/yelp/sentiment.train.tiny.text', type=str, help="The input training data file (a text file).")
|
530 |
+
parser.add_argument("--train_data_file", default='../../../data/yelp/sentiment.train.text', type=str, help="The input training data file (a text file).")
|
531 |
+
# parser.add_argument("--eval_data_file", default='../../../data/yelp/sentiment.dev.tiny.text', type=str, help="")
|
532 |
+
parser.add_argument("--eval_data_file", default='../../../data/yelp/sentiment.dev.small.text', type=str, help="2000 samples.")
|
533 |
+
parser.add_argument("--ExpName", default="local_lctrlg_yelp", type=str, help="The experiment name used in Azure Table.")
|
534 |
+
parser.add_argument("--create_new", default=0, type=int, help="")
|
535 |
+
|
536 |
+
# Training parameters
|
537 |
+
parser.add_argument("--checkpoint_dir", default='results_arae/checkpoint-47501/pytorch_model.bin', type=str, help='results/checkpoint-1212/pytorch_model.bin')
|
538 |
+
# parser.add_argument("--checkpoint", default='', type=str, help='results/checkpoint-1212/pytorch_model.bin')
|
539 |
+
parser.add_argument("--start_global_step", default=1001, type=int, help='')
|
540 |
+
parser.add_argument("--do_train", action='store_true',
|
541 |
+
help="Whether to run training.")
|
542 |
+
parser.add_argument("--do_eval", action='store_true',
|
543 |
+
help="Whether to run eval on the dev set.")
|
544 |
+
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
545 |
+
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation.")
|
546 |
+
parser.add_argument('--gradient_accumulation_steps', type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.")
|
547 |
+
parser.add_argument("--evaluate_during_training", action='store_true', help="Run evaluation during training at each logging step.")
|
548 |
+
parser.add_argument('--gloabl_step_eval', type=int, default=0, help="Evaluate the results at the given global step")
|
549 |
+
# parser.add_argument('--logging_steps', type=int, default=2000, help="ARAE")
|
550 |
+
parser.add_argument('--logging_steps', type=int, default=10, help="CARA")
|
551 |
+
parser.add_argument('--eval_steps', type=int, default=500, help="CARA")
|
552 |
+
# parser.add_argument('--save_steps', type=int, default=5000, help="ARAE")
|
553 |
+
parser.add_argument('--save_steps', type=int, default=1000, help="CARA")
|
554 |
+
parser.add_argument("--eval_all_checkpoints", action='store_true', help="")
|
555 |
+
|
556 |
+
## Encoder options
|
557 |
+
# parser.add_argument("--encoder_model_name_or_path", default="bert-base-uncased", type=str, )
|
558 |
+
parser.add_argument("--encoder_model_name_or_path", default="results_cara/checkpoint-encoder-1000", type=str)
|
559 |
+
# parser.add_argument("--encoder_model_name_or_path", default="results/checkpoint-encoder-55000", type=str")
|
560 |
+
parser.add_argument("--encoder_config_name", default="", type=str, help="Optional pretrained config name or path if not the same as model_name_or_path")
|
561 |
+
parser.add_argument("--encoder_tokenizer_name", default="", type=str, help="Keep empty. Will default to decoder_model_name_or_path")
|
562 |
+
parser.add_argument("--encoder_model_type", default="bert", type=str, help="The encoder model architecture to be fine-tuned.")
|
563 |
+
|
564 |
+
## Decoder options
|
565 |
+
# parser.add_argument("--decoder_model_name_or_path", default="gpt2", type=str)
|
566 |
+
parser.add_argument("--decoder_model_name_or_path", default="results_cara/checkpoint-decoder-1000", type=str)
|
567 |
+
# parser.add_argument("--decoder_model_name_or_path", default="results/checkpoint-decoder-55000", type=str)
|
568 |
+
parser.add_argument("--decoder_config_name", default="", type=str, help="Optional pretrained config name or path if not the same as model_name_or_path")
|
569 |
+
parser.add_argument("--decoder_tokenizer_name", default="", type=str, help="Keep empty. Will default to decoder_model_name_or_path")
|
570 |
+
parser.add_argument("--decoder_model_type", default="gpt2", type=str, help="The decoder model architecture to be fine-tuned.")
|
571 |
+
|
572 |
+
## Variational auto-encoder
|
573 |
+
parser.add_argument("--latent_size", default=32, type=int, help="Latent space dimension.")
|
574 |
+
parser.add_argument("--use_deterministic_connect", action='store_true', help="Use deterministic inference to generate latent codes, i.e., standard auto-encoders.")
|
575 |
+
|
576 |
+
## Objective functions
|
577 |
+
parser.add_argument("--mlm", action='store_true', help="Train with masked-language modeling loss instead of language modeling.")
|
578 |
+
parser.add_argument("--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss")
|
579 |
+
parser.add_argument("--cache_dir", default="", type=str, help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
|
580 |
+
parser.add_argument("--block_size", default=21, type=int, help="21 for Yelp and Yahoo on label-conditional text generation")
|
581 |
+
parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
|
582 |
+
|
583 |
+
# Training Schedules
|
584 |
+
parser.add_argument("--ratio_increase", default=0.25, type=float, help="Learning schedule, the percentage for the annealing stage.")
|
585 |
+
parser.add_argument("--ratio_zero", default=0.5, type=float, help="Learning schedule, the percentage for the pure auto-encoding stage.")
|
586 |
+
parser.add_argument("--fb_mode", default=1, type=int, help="free bit training mode.")
|
587 |
+
parser.add_argument("--dim_target_kl", default=3.0, type=float, help="dim_target_kl free bit training mode.")
|
588 |
+
parser.add_argument("--learning_rate", default=5e-6, type=float, help="The initial learning rate for Adam.")
|
589 |
+
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
|
590 |
+
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
591 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
592 |
+
parser.add_argument("--max_steps", default=-1, type=int, help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
593 |
+
parser.add_argument("--use_philly", action='store_true', help="Use Philly for computing.")
|
594 |
+
parser.add_argument("--use_pretrained_model", action='store_true',
|
595 |
+
help="Use pre-trained auto-encoder models as the initialization")
|
596 |
+
parser.add_argument("--use_pretrained_vae", action='store_true',
|
597 |
+
help="Use use_pretrained_vae as initialization, where beta value is specified in the folder")
|
598 |
+
|
599 |
+
parser.add_argument("--beta", type=float, default=1.0, help="The weighting hyper-parameter of the KL term in VAE")
|
600 |
+
parser.add_argument("--beta_cls", type=float, default=1.0, help="The weighting hyper-parameter for the classifier on the generated sentences")
|
601 |
+
|
602 |
+
## IO: Logging and Saving
|
603 |
+
parser.add_argument("--no_cuda", action='store_true', help="Avoid using CUDA when available")
|
604 |
+
parser.add_argument('--overwrite_output_dir', type=int, default=1, help="Overwrite the content of the output directory")
|
605 |
+
parser.add_argument('--overwrite_cache', action='store_true', help="Overwrite the cached training and evaluation sets")
|
606 |
+
parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
|
607 |
+
|
608 |
+
# Precision & Distributed Training
|
609 |
+
parser.add_argument('--fp16', action='store_true', help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
610 |
+
parser.add_argument('--fp16_opt_level', type=str, default='O1', help="")
|
611 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
612 |
+
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
|
613 |
+
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
614 |
+
|
615 |
+
# New parameters
|
616 |
+
parser.add_argument('--label_size', type=int, default=2, help="This depends on which dataset is used.")
|
617 |
+
args = parser.parse_args()
|
618 |
+
if args.decoder_model_type in ["bert", "roberta"] and not args.mlm:
|
619 |
+
raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm flag (masked language modeling).")
|
620 |
+
if args.eval_data_file is None and args.do_eval:
|
621 |
+
raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file or remove the --do_eval argument.")
|
622 |
+
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
623 |
+
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
624 |
+
# Setup distant debugging if needed
|
625 |
+
if args.server_ip and args.server_port:
|
626 |
+
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
627 |
+
import ptvsd
|
628 |
+
logger.info("Waiting for debugger attach")
|
629 |
+
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
630 |
+
ptvsd.wait_for_attach()
|
631 |
+
# Setup CUDA, GPU & distributed training
|
632 |
+
if args.local_rank == -1 or args.no_cuda:
|
633 |
+
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
634 |
+
args.n_gpu = torch.cuda.device_count()
|
635 |
+
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
636 |
+
torch.cuda.set_device(args.local_rank)
|
637 |
+
device = torch.device("cuda", args.local_rank)
|
638 |
+
torch.distributed.init_process_group(backend='nccl')
|
639 |
+
args.n_gpu = 1
|
640 |
+
args.device = device
|
641 |
+
# pdb.set_trace()
|
642 |
+
# Setup logging
|
643 |
+
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt = '%m/%d/%Y %H:%M:%S',
|
644 |
+
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
645 |
+
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
646 |
+
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
647 |
+
|
648 |
+
args.ExpName = 'Vae_' + args.dataset + '_Nz_' + str(args.latent_size) + '_Beta_' + str(args.beta) + '_Dkl_' + str(args.dim_target_kl) + \
|
649 |
+
'_Ra_' + str(args.ratio_increase) + '_R0_' + str(args.ratio_zero)
|
650 |
+
table_name = 'Vae' + args.dataset + 'Nz' + str(args.latent_size)
|
651 |
+
set_seed(args)
|
652 |
+
|
653 |
+
# Load pretrained model and tokenizer
|
654 |
+
if args.local_rank not in [-1, 0]:
|
655 |
+
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
|
656 |
+
|
657 |
+
|
658 |
+
|
659 |
+
|
660 |
+
if args.use_pretrained_model:
|
661 |
+
args.encoder_model_type = args.encoder_model_type.lower()
|
662 |
+
args.decoder_model_type = args.decoder_model_type.lower()
|
663 |
+
|
664 |
+
global_step = args.gloabl_step_eval
|
665 |
+
|
666 |
+
if args.use_pretrained_vae:
|
667 |
+
output_encoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-encoder-{}-1.0'.format(global_step))
|
668 |
+
output_decoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-decoder-{}-1.0'.format(global_step))
|
669 |
+
else:
|
670 |
+
output_encoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-encoder-{}'.format(global_step))
|
671 |
+
output_decoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-decoder-{}'.format(global_step))
|
672 |
+
|
673 |
+
checkpoints = [ [output_encoder_dir, output_decoder_dir] ]
|
674 |
+
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
675 |
+
|
676 |
+
# Load a trained Encoder model and vocabulary
|
677 |
+
encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
|
678 |
+
model_encoder = encoder_model_class.from_pretrained(output_encoder_dir, latent_size=args.latent_size)
|
679 |
+
tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)
|
680 |
+
|
681 |
+
model_encoder.to(args.device)
|
682 |
+
if args.block_size <= 0:
|
683 |
+
args.block_size = tokenizer_encoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
684 |
+
args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
|
685 |
+
|
686 |
+
# Load a trained Decoder model and vocabulary
|
687 |
+
decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
|
688 |
+
model_decoder = decoder_model_class.from_pretrained(output_decoder_dir, latent_size=args.latent_size)
|
689 |
+
tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
|
690 |
+
model_decoder.to(args.device)
|
691 |
+
if args.block_size <= 0:
|
692 |
+
args.block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
693 |
+
args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
|
694 |
+
|
695 |
+
else:
|
696 |
+
## Encoder
|
697 |
+
encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
|
698 |
+
encoder_config = encoder_config_class.from_pretrained(args.encoder_config_name if args.encoder_config_name else args.encoder_model_name_or_path)
|
699 |
+
tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)
|
700 |
+
if args.block_size <= 0:
|
701 |
+
args.block_size = tokenizer_encoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
702 |
+
args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
|
703 |
+
model_encoder = encoder_model_class.from_pretrained(args.encoder_model_name_or_path, from_tf=bool('.ckpt' in args.encoder_model_name_or_path), config=encoder_config, latent_size=args.latent_size)
|
704 |
+
# model_encoder = encoder_model_class(config=encoder_config, latent_size=args.latent_size)
|
705 |
+
|
706 |
+
## Decoder
|
707 |
+
decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
|
708 |
+
decoder_config = decoder_config_class.from_pretrained(args.decoder_config_name if args.decoder_config_name else args.decoder_model_name_or_path)
|
709 |
+
tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
|
710 |
+
if args.block_size <= 0:
|
711 |
+
args.block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
712 |
+
args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
|
713 |
+
setattr(decoder_config, "latent_size", args.latent_size)
|
714 |
+
model_decoder = decoder_model_class.from_pretrained(args.decoder_model_name_or_path, from_tf=bool('.ckpt' in args.decoder_model_name_or_path), config=decoder_config, latent_size=args.latent_size)
|
715 |
+
# model_decoder = decoder_model_class(config=decoder_config, latent_size=args.latent_size)
|
716 |
+
|
717 |
+
# Chunyuan: Add Padding token to GPT2
|
718 |
+
special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
|
719 |
+
num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
|
720 |
+
logger.info('We have added {} tokens to GPT2'.format(num_added_toks))
|
721 |
+
model_decoder.resize_token_embeddings(len(tokenizer_decoder)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
|
722 |
+
assert tokenizer_decoder.pad_token == '<PAD>'
|
723 |
+
|
724 |
+
|
725 |
+
# on_gpu = next(model_vae.parameters()).is_cuda
|
726 |
+
if args.local_rank == 0:
|
727 |
+
torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training download model & vocab
|
728 |
+
logger.info("Training/evaluation parameters %s", args)
|
729 |
+
|
730 |
+
if not os.path.exists(args.output_dir): os.makedirs(args.output_dir)
|
731 |
+
# Training
|
732 |
+
|
733 |
+
logff = open(os.path.join(args.output_dir, 'log_{}'.format(get_time_str())), 'a')
|
734 |
+
|
735 |
+
if args.do_train:
|
736 |
+
global_step = args.start_global_step
|
737 |
+
model_vae = CARA(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, args).to(args.device)
|
738 |
+
|
739 |
+
# if args.checkpoint:
|
740 |
+
# logger.info("Loading checkpoint from {}".format(args.checkpoint))
|
741 |
+
# model_vae.load_state_dict(torch.load(args.checkpoint))
|
742 |
+
|
743 |
+
if args.local_rank not in [-1, 0]:
|
744 |
+
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache
|
745 |
+
if args.local_rank == 0:
|
746 |
+
torch.distributed.barrier()
|
747 |
+
|
748 |
+
train_dataset = load_and_cache_examples(args, [tokenizer_encoder, tokenizer_decoder], evaluate=False)
|
749 |
+
|
750 |
+
# logger.info("Test evaluate before training.")
|
751 |
+
# evaluate(args, model_vae, tokenizer_encoder, tokenizer_decoder, table_name, prefix=0, subset='test')
|
752 |
+
|
753 |
+
# Train
|
754 |
+
global_step, tr_loss = train(args, train_dataset, model_vae, tokenizer_encoder, tokenizer_decoder, table_name, logff=logff)
|
755 |
+
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
756 |
+
|
757 |
+
# Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
|
758 |
+
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
759 |
+
# Create output directory if needed
|
760 |
+
# Save model checkpoint
|
761 |
+
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
762 |
+
output_encoder_dir = os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(global_step))
|
763 |
+
output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
|
764 |
+
if not os.path.exists(output_dir) and args.local_rank in [-1, 0]:
|
765 |
+
os.makedirs(output_dir)
|
766 |
+
if not os.path.exists(output_encoder_dir) and args.local_rank in [-1, 0]:
|
767 |
+
os.makedirs(output_encoder_dir)
|
768 |
+
if not os.path.exists(output_decoder_dir) and args.local_rank in [-1, 0]:
|
769 |
+
os.makedirs(output_decoder_dir)
|
770 |
+
|
771 |
+
logger.info("Saving encoder model checkpoint to %s", output_encoder_dir)
|
772 |
+
logger.info("Saving decoder model checkpoint to %s", output_decoder_dir)
|
773 |
+
|
774 |
+
model_encoder_to_save = model_vae.module.encoder if hasattr(model_vae, 'module') else model_vae.encoder # Take care of distributed/parallel training
|
775 |
+
model_decoder_to_save = model_vae.module.decoder if hasattr(model_vae, 'module') else model_vae.decoder # Take care of distributed/parallel training
|
776 |
+
model_to_save = model_vae.module if hasattr(model_vae, "module") else model_vae
|
777 |
+
|
778 |
+
# Good practice: save your training arguments together with the trained model
|
779 |
+
if args.use_philly:
|
780 |
+
save_solid = False
|
781 |
+
while not save_solid:
|
782 |
+
try:
|
783 |
+
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
784 |
+
torch.save(model_to_save.state_dict(), os.path.join(output_dir, 'pytorch_model.bin'))
|
785 |
+
save_solid = True
|
786 |
+
except:
|
787 |
+
pass
|
788 |
+
else:
|
789 |
+
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
790 |
+
torch.save(model_to_save.state_dict(), os.path.join(output_dir, 'pytorch_model.bin'))
|
791 |
+
args.checkpoint = os.path.join(output_dir, 'pytorch_model.bin')
|
792 |
+
|
793 |
+
if args.use_philly:
|
794 |
+
save_solid = False
|
795 |
+
while not save_solid:
|
796 |
+
try:
|
797 |
+
model_encoder_to_save.save_pretrained(output_encoder_dir)
|
798 |
+
torch.save(args, os.path.join(output_encoder_dir, 'training_encoder_args.bin'))
|
799 |
+
save_solid = True
|
800 |
+
except:
|
801 |
+
pass
|
802 |
+
else:
|
803 |
+
model_encoder_to_save.save_pretrained(output_encoder_dir)
|
804 |
+
torch.save(args, os.path.join(output_encoder_dir, 'training_encoder_args.bin'))
|
805 |
+
|
806 |
+
if args.use_philly:
|
807 |
+
save_solid = False
|
808 |
+
while not save_solid:
|
809 |
+
try:
|
810 |
+
model_decoder_to_save.save_pretrained(output_decoder_dir)
|
811 |
+
torch.save(args, os.path.join(output_decoder_dir, 'training_decoder_args.bin'))
|
812 |
+
save_solid = True
|
813 |
+
except:
|
814 |
+
pass
|
815 |
+
else:
|
816 |
+
model_decoder_to_save.save_pretrained(output_decoder_dir)
|
817 |
+
torch.save(args, os.path.join(output_decoder_dir, 'training_decoder_args.bin'))
|
818 |
+
|
819 |
+
# Load a trained model and vocabulary that you have fine-tuned
|
820 |
+
# model_encoder = encoder_model_class.from_pretrained(output_encoder_dir, latent_size=args.latent_size)
|
821 |
+
# model_encoder.to(args.device)
|
822 |
+
#
|
823 |
+
# # Load a trained model and vocabulary that you have fine-tuned
|
824 |
+
# model_decoder = decoder_model_class.from_pretrained(output_decoder_dir, latent_size=args.latent_size)
|
825 |
+
# model_decoder.to(args.device)
|
826 |
+
|
827 |
+
# Evaluation
|
828 |
+
results = {}
|
829 |
+
if args.do_eval and args.local_rank in [-1, 0]:
|
830 |
+
# if global_step == 0:
|
831 |
+
# global_step = args.gloabl_step_eval
|
832 |
+
|
833 |
+
# output_encoder_dir = os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(global_step))
|
834 |
+
# output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
|
835 |
+
# checkpoints = [ [output_encoder_dir, output_decoder_dir] ]
|
836 |
+
|
837 |
+
# logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
838 |
+
# for checkpoint in checkpoints:
|
839 |
+
|
840 |
+
# global_step = args.checkpoint_dir.split('/')[-2].split('-')[-1] if args.checkpoint_dir else ""
|
841 |
+
|
842 |
+
# model_encoder = encoder_model_class.from_pretrained(checkpoint[0], latent_size=args.latent_size)
|
843 |
+
# model_encoder.to(args.device)
|
844 |
+
# model_decoder = decoder_model_class.from_pretrained(checkpoint[1], latent_size=args.latent_size)
|
845 |
+
# model_decoder.to(args.device)
|
846 |
+
|
847 |
+
model_vae = CARA(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, args).to(args.device)
|
848 |
+
|
849 |
+
if args.gloabl_step_eval < 1:
|
850 |
+
args.gloabl_step_eval = global_step
|
851 |
+
args.checkpoint_dir = os.path.join(args.output_dir, 'checkpoint-{}/pytorch_model.bin'.format(args.gloabl_step_eval))
|
852 |
+
else:
|
853 |
+
global_step = args.gloabl_step_eval
|
854 |
+
args.checkpoint_dir = os.path.join(args.checkpoint_dir, 'checkpoint-{}/pytorch_model.bin'.format(args.gloabl_step_eval))
|
855 |
+
|
856 |
+
|
857 |
+
# if args.checkpoint_dir and os.path.exists(args.checkpoint_dir):
|
858 |
+
# logger.info("Loading checkpoint from {}".format(args.checkpoint_dir))
|
859 |
+
# model_vae.load_state_dict(torch.load(args.checkpoint_dir))
|
860 |
+
# else:
|
861 |
+
# raise ValueError("Cannot find checkpoint at: {}".format(args.checkpoint))
|
862 |
+
|
863 |
+
metrics = evaluate(args, model_vae, tokenizer_encoder, tokenizer_decoder, table_name, prefix=global_step, subset='test')
|
864 |
+
metrics = dict((k + '_{}'.format(global_step), v) for k, v in metrics.items())
|
865 |
+
results.update(metrics)
|
866 |
+
|
867 |
+
# result = evaluate(args, model_vae, tokenizer_encoder, tokenizer_decoder, table_name, prefix=global_step, subset='train')
|
868 |
+
# result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
869 |
+
# results.update(result)
|
870 |
+
|
871 |
+
return results
|
872 |
+
|
873 |
+
|
874 |
+
if __name__ == "__main__":
|
875 |
+
main()
|
Optimus/code/examples/big_ae/run_lm_vae_pretraining.py
ADDED
@@ -0,0 +1,669 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
|
18 |
+
GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
|
19 |
+
using a masked language modeling (MLM) loss.
|
20 |
+
"""
|
21 |
+
|
22 |
+
from __future__ import absolute_import, division, print_function
|
23 |
+
|
24 |
+
|
25 |
+
import pdb
|
26 |
+
import argparse
|
27 |
+
import glob
|
28 |
+
import logging
|
29 |
+
|
30 |
+
import os
|
31 |
+
import pickle
|
32 |
+
import random
|
33 |
+
from pathlib import Path
|
34 |
+
|
35 |
+
import numpy as np
|
36 |
+
import torch
|
37 |
+
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
|
38 |
+
from torch.utils.data.distributed import DistributedSampler
|
39 |
+
from tensorboardX import SummaryWriter
|
40 |
+
from tqdm import tqdm, trange
|
41 |
+
from collections import defaultdict
|
42 |
+
|
43 |
+
# from azure.cosmosdb.table.tableservice import TableService
|
44 |
+
# from azure.cosmosdb.table.models import Entity
|
45 |
+
from datetime import datetime
|
46 |
+
|
47 |
+
|
48 |
+
|
49 |
+
from pytorch_transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
|
50 |
+
BertConfig, BertForLatentConnector, BertTokenizer,
|
51 |
+
GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer,
|
52 |
+
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
|
53 |
+
RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
|
54 |
+
|
55 |
+
from utils import (calc_iwnll, calc_mi, calc_au, BucketingDataLoader, BucketingMultipleFiles_DataLoader, frange_cycle_linear, frange_cycle_zero_linear)
|
56 |
+
|
57 |
+
from modules import VAE
|
58 |
+
|
59 |
+
|
60 |
+
# logging.getLogger("azure").setLevel(logging.WARNING)
|
61 |
+
# logging.getLogger("TableService").setLevel(logging.WARNING)
|
62 |
+
|
63 |
+
logger = logging.getLogger(__name__)
|
64 |
+
|
65 |
+
|
66 |
+
MODEL_CLASSES = {
|
67 |
+
'gpt2': (GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer),
|
68 |
+
'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
|
69 |
+
'bert': (BertConfig, BertForLatentConnector, BertTokenizer),
|
70 |
+
'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
|
71 |
+
}
|
72 |
+
|
73 |
+
|
74 |
+
storage_name="textae"
|
75 |
+
key=r"6yBCXlblof8DVFJ4BD3eNFTrGQCej6cKfCf5z308cKnevyHaG+yl/m+ITVErB9yt0kvN3ToqxLIh0knJEfFmPA=="
|
76 |
+
# ts = TableService(account_name=storage_name, account_key=key)
|
77 |
+
|
78 |
+
|
79 |
+
|
80 |
+
def build_dataload_and_cache_examples(args, tokenizer, evaluate=False):
|
81 |
+
if isinstance(tokenizer, list):
|
82 |
+
args.batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
83 |
+
file_path=args.train_data_file
|
84 |
+
dataloader = BucketingMultipleFiles_DataLoader(file_path, args.batch_size, args.max_seq_length, tokenizer, args, bucket=100, shuffle=True)
|
85 |
+
else:
|
86 |
+
pass
|
87 |
+
return dataloader
|
88 |
+
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
def set_seed(args):
|
93 |
+
random.seed(args.seed)
|
94 |
+
np.random.seed(args.seed)
|
95 |
+
torch.manual_seed(args.seed)
|
96 |
+
if args.n_gpu > 0:
|
97 |
+
torch.cuda.manual_seed_all(args.seed)
|
98 |
+
|
99 |
+
|
100 |
+
def mask_tokens(inputs, tokenizer, args):
|
101 |
+
""" Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
|
102 |
+
labels = inputs.clone()
|
103 |
+
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
|
104 |
+
|
105 |
+
masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).to(torch.uint8)
|
106 |
+
labels[masked_indices==1] = -1 # We only compute loss on masked tokens
|
107 |
+
|
108 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
109 |
+
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).to(torch.uint8) & masked_indices
|
110 |
+
inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
|
111 |
+
|
112 |
+
# 10% of the time, we replace masked input tokens with random word
|
113 |
+
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).to(torch.uint8) & masked_indices & ~indices_replaced
|
114 |
+
indices_random = indices_random
|
115 |
+
random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
|
116 |
+
inputs[indices_random] = random_words[indices_random]
|
117 |
+
|
118 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
119 |
+
return inputs, labels
|
120 |
+
|
121 |
+
|
122 |
+
def train(args, train_dataloader, model_vae, encoder_tokenizer, decoder_tokenizer, table_name):
|
123 |
+
""" Train the model """
|
124 |
+
if args.local_rank in [-1, 0]:
|
125 |
+
tb_writer = SummaryWriter()
|
126 |
+
|
127 |
+
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
128 |
+
# train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
129 |
+
# train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
130 |
+
|
131 |
+
if args.max_steps > 0:
|
132 |
+
t_total = args.max_steps
|
133 |
+
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
134 |
+
else:
|
135 |
+
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
136 |
+
|
137 |
+
# Prepare optimizer and schedule (linear warmup and decay)
|
138 |
+
|
139 |
+
|
140 |
+
# model_encoder, model_decoder, model_connector = model_vae.encoder, model_vae.decoder, model_vae.linear
|
141 |
+
no_decay = ['bias', 'LayerNorm.weight']
|
142 |
+
optimizer_grouped_parameters = [
|
143 |
+
{'params': [p for n, p in model_vae.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
144 |
+
{'params': [p for n, p in model_vae.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
145 |
+
]
|
146 |
+
|
147 |
+
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
148 |
+
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
|
149 |
+
|
150 |
+
|
151 |
+
if args.fp16:
|
152 |
+
try:
|
153 |
+
from apex import amp
|
154 |
+
except ImportError:
|
155 |
+
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
156 |
+
model_vae, optimizer = amp.initialize(model_vae, optimizer, opt_level=args.fp16_opt_level)
|
157 |
+
|
158 |
+
# multi-gpu training (should be after apex fp16 initialization)
|
159 |
+
if args.n_gpu > 1:
|
160 |
+
model_vae = torch.nn.DataParallel(model_vae, device_ids=range(args.n_gpu)).to(args.device)
|
161 |
+
|
162 |
+
# Distributed training (should be after apex fp16 initialization)
|
163 |
+
if args.local_rank != -1:
|
164 |
+
model_vae = torch.nn.parallel.DistributedDataParallel(model_vae, device_ids=[args.local_rank],
|
165 |
+
output_device=args.local_rank,
|
166 |
+
find_unused_parameters=True)
|
167 |
+
|
168 |
+
|
169 |
+
|
170 |
+
|
171 |
+
files = Path(args.train_data_file)
|
172 |
+
num_files = len(list(files.glob('*seq64*.json')))
|
173 |
+
|
174 |
+
|
175 |
+
# Train!
|
176 |
+
logger.info("***** Running training *****")
|
177 |
+
logger.info(" Num files = %d", num_files)
|
178 |
+
logger.info(" Num examples of first file = %d", train_dataloader.num_examples)
|
179 |
+
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
180 |
+
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
181 |
+
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
182 |
+
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
183 |
+
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
184 |
+
logger.info(" Total optimization steps = %d", t_total)
|
185 |
+
|
186 |
+
|
187 |
+
global_step = 0
|
188 |
+
tr_loss, logging_loss = 0.0, 0.0
|
189 |
+
|
190 |
+
model_vae.zero_grad()
|
191 |
+
num_train_epochs_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
192 |
+
|
193 |
+
n_iter = int(args.num_train_epochs) * len(train_dataloader)
|
194 |
+
beta_t_list = frange_cycle_zero_linear(n_iter, start=0.0, stop=args.beta, n_cycle=1, ratio_increase=args.ratio_increase, ratio_zero=args.ratio_zero)
|
195 |
+
|
196 |
+
tmp_list = []
|
197 |
+
dict_token_length = defaultdict(int)
|
198 |
+
|
199 |
+
set_seed(args) # Added here for reproducibility (even between python 2 and 3)
|
200 |
+
for epoch in num_train_epochs_iterator:
|
201 |
+
train_dataloader.reset()
|
202 |
+
for idx_file in range(num_files-1):
|
203 |
+
logger.info(f"Epoch {epoch}, File idx {train_dataloader.file_idx}")
|
204 |
+
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
205 |
+
for step, batch in enumerate(epoch_iterator):
|
206 |
+
|
207 |
+
tokenized_text0, tokenized_text1, tokenized_text_lengths = batch
|
208 |
+
|
209 |
+
dict_token_length[ tokenized_text_lengths[0,0].item() ] += 1
|
210 |
+
|
211 |
+
# continue
|
212 |
+
|
213 |
+
|
214 |
+
# tokenized_text0 = tokenized_text0.to(args.device)
|
215 |
+
# tokenized_text1 = tokenized_text1.to(args.device)
|
216 |
+
# prepare input-output data for reconstruction
|
217 |
+
|
218 |
+
|
219 |
+
|
220 |
+
inputs, labels = mask_tokens(tokenized_text0, encoder_tokenizer, args) if args.mlm else (tokenized_text0, tokenized_text1)
|
221 |
+
labels = tokenized_text1
|
222 |
+
|
223 |
+
tokenized_text1 = tokenized_text1.to(args.device)
|
224 |
+
inputs = inputs.to(args.device)
|
225 |
+
labels = labels.to(args.device)
|
226 |
+
|
227 |
+
model_vae.train()
|
228 |
+
|
229 |
+
beta_t = 0.0 # beta_t_list[step + epoch*len(epoch_iterator)]
|
230 |
+
model_vae.module.args.beta = beta_t
|
231 |
+
|
232 |
+
if beta_t == 0.0:
|
233 |
+
model_vae.module.args.fb_mode = 0
|
234 |
+
else:
|
235 |
+
model_vae.module.args.fb_mode = 1
|
236 |
+
|
237 |
+
if args.use_deterministic_connect:
|
238 |
+
model_vae.module.args.fb_mode = 2
|
239 |
+
|
240 |
+
loss_rec, loss_kl, loss = model_vae(inputs, labels)
|
241 |
+
|
242 |
+
loss_rec = loss_rec.mean() # mean() to average on multi-gpu parallel training
|
243 |
+
loss_kl = loss_kl.mean()
|
244 |
+
loss = loss.mean()
|
245 |
+
|
246 |
+
if args.use_philly:
|
247 |
+
print("PROGRESS: {}%".format(round(100 * (step + epoch*len(epoch_iterator) ) /(int(args.num_train_epochs) * len(epoch_iterator)) , 4)))
|
248 |
+
print("EVALERR: {}%".format(loss_rec))
|
249 |
+
|
250 |
+
epoch_iterator.set_description(
|
251 |
+
(
|
252 |
+
f'iter: {step + epoch*len(epoch_iterator) }; file:{idx_file}; loss: {loss.item():.3f}; '
|
253 |
+
f'loss_rec: {loss_rec.item():.3f}; loss_kl: {loss_kl.item():.3f}; '
|
254 |
+
f'beta: {model_vae.module.args.beta:.3f}'
|
255 |
+
)
|
256 |
+
)
|
257 |
+
|
258 |
+
# if global_step % 5 == 0:
|
259 |
+
# row = {
|
260 |
+
# 'PartitionKey': 'MILU_Rule_Rule_Template',
|
261 |
+
# 'RowKey': str(datetime.now()),
|
262 |
+
# 'ExpName' : args.ExpName,
|
263 |
+
# 'iter': str( step + epoch*len(epoch_iterator) ),
|
264 |
+
# 'loss': str( loss.item()),
|
265 |
+
# 'loss_rec': str(loss_rec.item()),
|
266 |
+
# 'loss_kl': str(loss_kl.item()),
|
267 |
+
# 'beta': str(model_vae.args.beta)
|
268 |
+
# }
|
269 |
+
# # pdb.set_trace()
|
270 |
+
# ts.insert_entity(table_name, row)
|
271 |
+
|
272 |
+
# pdb.set_trace()
|
273 |
+
|
274 |
+
if args.gradient_accumulation_steps > 1:
|
275 |
+
loss = loss / args.gradient_accumulation_steps
|
276 |
+
|
277 |
+
if args.fp16:
|
278 |
+
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
279 |
+
scaled_loss.backward()
|
280 |
+
else:
|
281 |
+
loss.backward()
|
282 |
+
|
283 |
+
tr_loss += loss.item()
|
284 |
+
if (step + 1) % args.gradient_accumulation_steps == 0:
|
285 |
+
if args.fp16:
|
286 |
+
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
287 |
+
else:
|
288 |
+
torch.nn.utils.clip_grad_norm_(model_vae.parameters(), args.max_grad_norm)
|
289 |
+
|
290 |
+
optimizer.step()
|
291 |
+
|
292 |
+
scheduler.step() # Update learning rate schedule
|
293 |
+
|
294 |
+
model_vae.zero_grad()
|
295 |
+
|
296 |
+
global_step += 1
|
297 |
+
|
298 |
+
|
299 |
+
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
300 |
+
# Log metrics
|
301 |
+
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
302 |
+
results = evaluate(args, model_vae, encoder_tokenizer, decoder_tokenizer)
|
303 |
+
for key, value in results.items():
|
304 |
+
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
305 |
+
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
|
306 |
+
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
307 |
+
logging_loss = tr_loss
|
308 |
+
|
309 |
+
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
310 |
+
|
311 |
+
# Save encoder model checkpoint
|
312 |
+
output_encoder_dir = os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(global_step))
|
313 |
+
|
314 |
+
if not os.path.exists(output_encoder_dir):
|
315 |
+
os.makedirs(output_encoder_dir)
|
316 |
+
|
317 |
+
model_encoder_to_save = model_vae.module.encoder if hasattr(model_vae, 'module') else model_vae.encoder # Take care of distributed/parallel training
|
318 |
+
if args.use_philly:
|
319 |
+
save_solid = False
|
320 |
+
while not save_solid:
|
321 |
+
try:
|
322 |
+
model_encoder_to_save.save_pretrained(output_encoder_dir)
|
323 |
+
torch.save(args, os.path.join(output_encoder_dir, 'training_args.bin'))
|
324 |
+
logger.info("Saving model checkpoint to %s", output_encoder_dir)
|
325 |
+
save_solid = True
|
326 |
+
except:
|
327 |
+
pass
|
328 |
+
else:
|
329 |
+
model_encoder_to_save.save_pretrained(output_encoder_dir)
|
330 |
+
torch.save(args, os.path.join(output_encoder_dir, 'training_args.bin'))
|
331 |
+
logger.info("Saving model checkpoint to %s", output_encoder_dir)
|
332 |
+
|
333 |
+
# Save decoder model checkpoint
|
334 |
+
output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
|
335 |
+
|
336 |
+
if not os.path.exists(output_decoder_dir):
|
337 |
+
os.makedirs(output_decoder_dir)
|
338 |
+
|
339 |
+
model_decoder_to_save = model_vae.module.decoder if hasattr(model_vae, 'module') else model_vae.decoder # Take care of distributed/parallel training
|
340 |
+
if args.use_philly:
|
341 |
+
save_solid = False
|
342 |
+
while not save_solid:
|
343 |
+
try:
|
344 |
+
model_decoder_to_save.save_pretrained(output_decoder_dir)
|
345 |
+
torch.save(args, os.path.join(output_decoder_dir, 'training_args.bin'))
|
346 |
+
logger.info("Saving model checkpoint to %s", output_decoder_dir)
|
347 |
+
save_solid = True
|
348 |
+
except:
|
349 |
+
pass
|
350 |
+
else:
|
351 |
+
model_decoder_to_save.save_pretrained(output_decoder_dir)
|
352 |
+
torch.save(args, os.path.join(output_decoder_dir, 'training_args.bin'))
|
353 |
+
logger.info("Saving model checkpoint to %s", output_decoder_dir)
|
354 |
+
|
355 |
+
|
356 |
+
if args.max_steps > 0 and global_step > args.max_steps:
|
357 |
+
epoch_iterator.close()
|
358 |
+
break
|
359 |
+
|
360 |
+
if args.max_steps > 0 and global_step > args.max_steps:
|
361 |
+
train_iterator.close()
|
362 |
+
break
|
363 |
+
|
364 |
+
|
365 |
+
# print(dict_token_length)
|
366 |
+
# with open('wikipedia_stats.json', 'w') as fp:
|
367 |
+
# json.dump(dict_token_length, fp)
|
368 |
+
|
369 |
+
if args.local_rank in [-1, 0]:
|
370 |
+
tb_writer.close()
|
371 |
+
|
372 |
+
return global_step, tr_loss / global_step
|
373 |
+
|
374 |
+
|
375 |
+
def main():
|
376 |
+
parser = argparse.ArgumentParser()
|
377 |
+
|
378 |
+
## Required parameters
|
379 |
+
parser.add_argument("--train_data_file", default=None, type=str, required=True,
|
380 |
+
help="The input training data file (a text file).")
|
381 |
+
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
382 |
+
help="The output directory where the model predictions and checkpoints will be written.")
|
383 |
+
parser.add_argument("--dataset", default=None, type=str, help="The dataset.")
|
384 |
+
|
385 |
+
## Other parameters
|
386 |
+
parser.add_argument("--eval_data_file", default=None, type=str,
|
387 |
+
help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
|
388 |
+
parser.add_argument("--ExpName", default="", type=str,
|
389 |
+
help="The experiment name used in Azure Table.")
|
390 |
+
|
391 |
+
## Encoder options
|
392 |
+
parser.add_argument("--encoder_model_type", default="bert", type=str,
|
393 |
+
help="The encoder model architecture to be fine-tuned.")
|
394 |
+
parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
|
395 |
+
help="The encoder model checkpoint for weights initialization.")
|
396 |
+
parser.add_argument("--encoder_config_name", default="", type=str,
|
397 |
+
help="Optional pretrained config name or path if not the same as model_name_or_path")
|
398 |
+
parser.add_argument("--encoder_tokenizer_name", default="", type=str,
|
399 |
+
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
|
400 |
+
|
401 |
+
## Decoder options
|
402 |
+
parser.add_argument("--decoder_model_type", default="gpt2", type=str,
|
403 |
+
help="The decoder model architecture to be fine-tuned.")
|
404 |
+
parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
|
405 |
+
help="The decoder model checkpoint for weights initialization.")
|
406 |
+
parser.add_argument("--decoder_config_name", default="", type=str,
|
407 |
+
help="Optional pretrained config name or path if not the same as model_name_or_path")
|
408 |
+
parser.add_argument("--decoder_tokenizer_name", default="", type=str,
|
409 |
+
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
|
410 |
+
|
411 |
+
## Variational auto-encoder
|
412 |
+
parser.add_argument("--latent_size", default=32, type=int, help="Latent space dimension.")
|
413 |
+
parser.add_argument("--use_deterministic_connect", action='store_true',
|
414 |
+
help="Use deterministic inference to generate latent codes, i.e., standard auto-encoders.")
|
415 |
+
|
416 |
+
## Objective functions
|
417 |
+
parser.add_argument("--mlm", action='store_true',
|
418 |
+
help="Train with masked-language modeling loss instead of language modeling.")
|
419 |
+
parser.add_argument("--mlm_probability", type=float, default=0.15,
|
420 |
+
help="Ratio of tokens to mask for masked language modeling loss")
|
421 |
+
parser.add_argument("--beta", type=float, default=1.0,
|
422 |
+
help="The weighting hyper-parameter of the KL term in VAE")
|
423 |
+
|
424 |
+
|
425 |
+
parser.add_argument("--cache_dir", default="", type=str,
|
426 |
+
help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
|
427 |
+
parser.add_argument("--max_seq_length", default=512, type=int,
|
428 |
+
help="Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length")
|
429 |
+
parser.add_argument("--block_size", default=-1, type=int,
|
430 |
+
help="Optional input sequence length after tokenization."
|
431 |
+
"The training dataset will be truncated in block of this size for training."
|
432 |
+
"Default to the model max input length for single sentence inputs (take into account special tokens).")
|
433 |
+
parser.add_argument("--do_train", action='store_true',
|
434 |
+
help="Whether to run training.")
|
435 |
+
parser.add_argument("--do_eval", action='store_true',
|
436 |
+
help="Whether to run eval on the dev set.")
|
437 |
+
parser.add_argument("--evaluate_during_training", action='store_true',
|
438 |
+
help="Run evaluation during training at each logging step.")
|
439 |
+
parser.add_argument("--do_lower_case", action='store_true',
|
440 |
+
help="Set this flag if you are using an uncased model.")
|
441 |
+
|
442 |
+
|
443 |
+
# Training Schedules
|
444 |
+
parser.add_argument("--ratio_increase", default=0.25, type=float,
|
445 |
+
help="Learning schedule, the percentage for the annealing stage.")
|
446 |
+
parser.add_argument("--ratio_zero", default=0.25, type=float,
|
447 |
+
help="Learning schedule, the percentage for the pure auto-encoding stage.")
|
448 |
+
parser.add_argument("--fb_mode", default=0, type=int,
|
449 |
+
help="free bit training mode.")
|
450 |
+
parser.add_argument("--dim_target_kl", default=3.0, type=float,
|
451 |
+
help="dim_target_kl free bit training mode.")
|
452 |
+
parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
|
453 |
+
help="Batch size per GPU/CPU for training.")
|
454 |
+
parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
|
455 |
+
help="Batch size per GPU/CPU for evaluation.")
|
456 |
+
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
457 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
458 |
+
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
459 |
+
help="The initial learning rate for Adam.")
|
460 |
+
parser.add_argument("--weight_decay", default=0.0, type=float,
|
461 |
+
help="Weight deay if we apply some.")
|
462 |
+
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
463 |
+
help="Epsilon for Adam optimizer.")
|
464 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
465 |
+
help="Max gradient norm.")
|
466 |
+
parser.add_argument("--num_train_epochs", default=1.0, type=float,
|
467 |
+
help="Total number of training epochs to perform.")
|
468 |
+
parser.add_argument("--max_steps", default=-1, type=int,
|
469 |
+
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
470 |
+
parser.add_argument("--warmup_steps", default=0, type=int,
|
471 |
+
help="Linear warmup over warmup_steps.")
|
472 |
+
parser.add_argument("--use_philly", action='store_true',
|
473 |
+
help="Use Philly for computing.")
|
474 |
+
|
475 |
+
## IO: Logging and Saving
|
476 |
+
parser.add_argument('--logging_steps', type=int, default=50,
|
477 |
+
help="Log every X updates steps.")
|
478 |
+
parser.add_argument('--save_steps', type=int, default=50,
|
479 |
+
help="Save checkpoint every X updates steps.")
|
480 |
+
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
481 |
+
help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
|
482 |
+
parser.add_argument("--no_cuda", action='store_true',
|
483 |
+
help="Avoid using CUDA when available")
|
484 |
+
parser.add_argument('--overwrite_output_dir', action='store_true',
|
485 |
+
help="Overwrite the content of the output directory")
|
486 |
+
parser.add_argument('--overwrite_cache', action='store_true',
|
487 |
+
help="Overwrite the cached training and evaluation sets")
|
488 |
+
parser.add_argument('--seed', type=int, default=42,
|
489 |
+
help="random seed for initialization")
|
490 |
+
parser.add_argument('--gloabl_step_eval', type=int, default=661,
|
491 |
+
help="Evaluate the results at the given global step")
|
492 |
+
|
493 |
+
# Precision & Distributed Training
|
494 |
+
parser.add_argument('--fp16', action='store_true',
|
495 |
+
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
496 |
+
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
497 |
+
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
498 |
+
"See details at https://nvidia.github.io/apex/amp.html")
|
499 |
+
parser.add_argument("--local_rank", type=int, default=-1,
|
500 |
+
help="For distributed training: local_rank")
|
501 |
+
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
|
502 |
+
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
503 |
+
args = parser.parse_args()
|
504 |
+
|
505 |
+
if args.decoder_model_type in ["bert", "roberta"] and not args.mlm:
|
506 |
+
raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
|
507 |
+
"flag (masked language modeling).")
|
508 |
+
if args.eval_data_file is None and args.do_eval:
|
509 |
+
raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
|
510 |
+
"or remove the --do_eval argument.")
|
511 |
+
|
512 |
+
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
513 |
+
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
514 |
+
|
515 |
+
# Setup distant debugging if needed
|
516 |
+
if args.server_ip and args.server_port:
|
517 |
+
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
518 |
+
import ptvsd
|
519 |
+
print("Waiting for debugger attach")
|
520 |
+
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
521 |
+
ptvsd.wait_for_attach()
|
522 |
+
|
523 |
+
# Setup CUDA, GPU & distributed training
|
524 |
+
if args.local_rank == -1 or args.no_cuda:
|
525 |
+
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
526 |
+
args.n_gpu = torch.cuda.device_count()
|
527 |
+
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
528 |
+
torch.cuda.set_device(args.local_rank)
|
529 |
+
device = torch.device("cuda", args.local_rank)
|
530 |
+
torch.distributed.init_process_group(backend='nccl')
|
531 |
+
args.n_gpu = 1
|
532 |
+
args.device = device
|
533 |
+
|
534 |
+
# Setup logging
|
535 |
+
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
536 |
+
datefmt = '%m/%d/%Y %H:%M:%S',
|
537 |
+
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
538 |
+
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
539 |
+
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
540 |
+
|
541 |
+
args.ExpName = 'Vae_' + args.dataset + '_Nz_' + str(args.latent_size) + '_Beta_' + str(args.beta) + '_Dkl_' + str(args.dim_target_kl) + '_Ra_' + str(args.ratio_increase) + '_R0_' + str(args.ratio_zero)
|
542 |
+
table_name = 'Vae' + args.dataset + 'Nz' + str(args.latent_size)
|
543 |
+
try:
|
544 |
+
ts.create_table(table_name)
|
545 |
+
except:
|
546 |
+
pass
|
547 |
+
|
548 |
+
|
549 |
+
# Set seed
|
550 |
+
set_seed(args)
|
551 |
+
|
552 |
+
# Load pretrained model and tokenizer
|
553 |
+
if args.local_rank not in [-1, 0]:
|
554 |
+
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
|
555 |
+
|
556 |
+
## Encoder
|
557 |
+
encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
|
558 |
+
encoder_config = encoder_config_class.from_pretrained(args.encoder_config_name if args.encoder_config_name else args.encoder_model_name_or_path)
|
559 |
+
tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)
|
560 |
+
if args.block_size <= 0:
|
561 |
+
args.block_size = tokenizer_encoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
562 |
+
args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
|
563 |
+
model_encoder = encoder_model_class.from_pretrained(args.encoder_model_name_or_path, from_tf=bool('.ckpt' in args.encoder_model_name_or_path), config=encoder_config, latent_size=args.latent_size)
|
564 |
+
# model_encoder.to(args.device)
|
565 |
+
|
566 |
+
## Decoder
|
567 |
+
decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
|
568 |
+
decoder_config = decoder_config_class.from_pretrained(args.decoder_config_name if args.decoder_config_name else args.decoder_model_name_or_path)
|
569 |
+
tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
|
570 |
+
if args.block_size <= 0:
|
571 |
+
args.block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
|
572 |
+
args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
|
573 |
+
model_decoder = decoder_model_class.from_pretrained(args.decoder_model_name_or_path, from_tf=bool('.ckpt' in args.decoder_model_name_or_path), config=decoder_config, latent_size=args.latent_size)
|
574 |
+
|
575 |
+
# Chunyuan: Add Padding token to GPT2
|
576 |
+
special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
|
577 |
+
num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
|
578 |
+
print('We have added', num_added_toks, 'tokens to GPT2')
|
579 |
+
model_decoder.resize_token_embeddings(len(tokenizer_decoder)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
|
580 |
+
assert tokenizer_decoder.pad_token == '<PAD>'
|
581 |
+
|
582 |
+
# model_decoder.to(args.device)
|
583 |
+
|
584 |
+
model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, args).to(args.device) #
|
585 |
+
|
586 |
+
# on_gpu = next(model_vae.parameters()).is_cuda
|
587 |
+
|
588 |
+
|
589 |
+
|
590 |
+
if args.local_rank == 0:
|
591 |
+
torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training download model & vocab
|
592 |
+
|
593 |
+
logger.info("Training/evaluation parameters %s", args)
|
594 |
+
|
595 |
+
global_step= 0
|
596 |
+
# Training
|
597 |
+
if args.do_train:
|
598 |
+
if args.local_rank not in [-1, 0]:
|
599 |
+
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache
|
600 |
+
|
601 |
+
train_dataloader = build_dataload_and_cache_examples(args, [tokenizer_encoder, tokenizer_decoder], evaluate=False)
|
602 |
+
|
603 |
+
if args.local_rank == 0:
|
604 |
+
torch.distributed.barrier()
|
605 |
+
|
606 |
+
global_step, tr_loss = train(args, train_dataloader, model_vae, tokenizer_encoder, tokenizer_decoder, table_name)
|
607 |
+
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
608 |
+
|
609 |
+
|
610 |
+
# Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
|
611 |
+
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
612 |
+
# Create output directory if needed
|
613 |
+
# Save model checkpoint
|
614 |
+
output_encoder_dir = os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(global_step))
|
615 |
+
output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
|
616 |
+
if not os.path.exists(output_encoder_dir) and args.local_rank in [-1, 0]:
|
617 |
+
os.makedirs(output_encoder_dir)
|
618 |
+
if not os.path.exists(output_decoder_dir) and args.local_rank in [-1, 0]:
|
619 |
+
os.makedirs(output_decoder_dir)
|
620 |
+
|
621 |
+
logger.info("Saving encoder model checkpoint to %s", output_encoder_dir)
|
622 |
+
logger.info("Saving decoder model checkpoint to %s", output_decoder_dir)
|
623 |
+
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
624 |
+
# They can then be reloaded using `from_pretrained()`
|
625 |
+
|
626 |
+
model_encoder_to_save = model_vae.module.encoder if hasattr(model_vae, 'module') else model_vae.encoder # Take care of distributed/parallel training
|
627 |
+
model_decoder_to_save = model_vae.module.decoder if hasattr(model_vae, 'module') else model_vae.decoder # Take care of distributed/parallel training
|
628 |
+
|
629 |
+
# Good practice: save your training arguments together with the trained model
|
630 |
+
if args.use_philly:
|
631 |
+
save_solid = False
|
632 |
+
while not save_solid:
|
633 |
+
try:
|
634 |
+
model_encoder_to_save.save_pretrained(output_encoder_dir)
|
635 |
+
torch.save(args, os.path.join(output_encoder_dir, 'training_encoder_args.bin'))
|
636 |
+
save_solid = True
|
637 |
+
except:
|
638 |
+
pass
|
639 |
+
else:
|
640 |
+
model_encoder_to_save.save_pretrained(output_encoder_dir)
|
641 |
+
torch.save(args, os.path.join(output_encoder_dir, 'training_encoder_args.bin'))
|
642 |
+
|
643 |
+
|
644 |
+
if args.use_philly:
|
645 |
+
save_solid = False
|
646 |
+
while not save_solid:
|
647 |
+
try:
|
648 |
+
model_decoder_to_save.save_pretrained(output_decoder_dir)
|
649 |
+
torch.save(args, os.path.join(output_decoder_dir, 'training_decoder_args.bin'))
|
650 |
+
save_solid = True
|
651 |
+
except:
|
652 |
+
pass
|
653 |
+
else:
|
654 |
+
model_decoder_to_save.save_pretrained(output_decoder_dir)
|
655 |
+
torch.save(args, os.path.join(output_decoder_dir, 'training_encoder_args.bin'))
|
656 |
+
|
657 |
+
|
658 |
+
# Load a trained model and vocabulary that you have fine-tuned
|
659 |
+
model_encoder = encoder_model_class.from_pretrained(output_encoder_dir, latent_size=args.latent_size)
|
660 |
+
model_encoder.to(args.device)
|
661 |
+
|
662 |
+
# Load a trained model and vocabulary that you have fine-tuned
|
663 |
+
model_decoder = decoder_model_class.from_pretrained(output_decoder_dir, latent_size=args.latent_size)
|
664 |
+
model_decoder.to(args.device)
|
665 |
+
|
666 |
+
|
667 |
+
|
668 |
+
if __name__ == "__main__":
|
669 |
+
main()
|