Audio-to-Audio
Transformers
PyTorch
perceiver-ar-symbolic-audio-model
krasserm cstub commited on
Commit
ee4d750
1 Parent(s): b09dff5

Create README.md (#1)

Browse files

- Create README.md (f57d0b8ce9a1f44a6e7f890c9043746fc93f09b2)


Co-authored-by: Christoph Stumpf <[email protected]>

Files changed (1) hide show
  1. README.md +235 -0
README.md ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ inference: false
4
+ pipeline_tag: audio-to-audio
5
+ ---
6
+
7
+ # Perceiver AR symbolic audio model
8
+
9
+ This model is a [Perceiver AR](https://arxiv.org/abs/2202.07765) symbolic audio model (134M parameters) pretrained on
10
+ the [GiantMIDI-Piano](https://github.com/bytedance/GiantMIDI-Piano) dataset for 27 epochs (157M tokens). It uses [rotary embedding](https://arxiv.org/abs/2104.09864)
11
+ for relative position encoding. It is a [training example](https://github.com/krasserm/perceiver-io/blob/main/docs/training-examples.md#giantmidi-piano)
12
+ of the [perceiver-io](https://github.com/krasserm/perceiver-io) library.
13
+
14
+ ## Model description
15
+
16
+ Perceiver AR is a simple extension of a plain decoder-only transformer such as GPT-2, for example. A core building block
17
+ of both is the *decoder layer* consisting of a self-attention layer followed by a position-wise MLP. Self-attention uses
18
+ a causal attention mask.
19
+
20
+ Perceiver AR additionally cross-attends to a longer prefix of the input sequence in its first attention layer. This layer
21
+ is a hybrid self- and cross-attention layer. Self-attention is over the last n positions of the input sequence, with a
22
+ causal attention mask, cross-attention is from the last n positions to the first m positions. The length of the input
23
+ sequence is m + n. This allows a Perceiver AR to process a much larger context than decoder-only transformers which are
24
+ based on self-attention only.
25
+
26
+ <p align="center">
27
+ <img src="https://krasserm.github.io/img/2023-01-23/perceiver-ar.png" alt="Perceiver AR" width="600"/><br/>
28
+ <i>Fig. 1</i>. Attention in Perceiver AR with m=8 prefix tokens and n=3 latent tokens.
29
+ <p/>
30
+
31
+ The output of the hybrid attention layer are n latent arrays corresponding to the last n tokens of the input sequence.
32
+ These are further processed by a stack of L-1 decoder layers where the total number of attention layers is L. A final
33
+ layer (not shown in Fig. 1) predicts the target token for each latent position. The weights of the final layer are
34
+ shared with the input embedding layer. Except for the initial cross-attention to the prefix sequence, a Perceiver AR
35
+ is architecturally identical to a decoder-only transformer.
36
+
37
+ ## Model training
38
+
39
+ The model was [trained](https://github.com/krasserm/perceiver-io/blob/main/docs/training-examples.md#giantmidi-piano) with
40
+ the task of symbolic audio modeling on the [GiantMIDI-Piano](https://github.com/bytedance/GiantMIDI-Piano) dataset
41
+ for 27 epochs (157M tokens). This dataset consists of [MIDI](https://en.wikipedia.org/wiki/MIDI) files, tokenized using the
42
+ approach from the [Perceiver AR paper](https://arxiv.org/pdf/2202.07765.pdf), which is described
43
+ in detail in Section A.2 of [Huang et al (2019)](https://arxiv.org/abs/1809.04281).
44
+ All hyperparameters are summarized in the [training script](https://github.com/krasserm/perceiver-io/blob/main/examples/training/sam/giantmidi/train.sh).
45
+ The context length was set to 6144 tokens with 2048 latent positions, resulting in a maximal prefix length of 4096. The
46
+ actual prefix length per example was randomly chosen between 0 and 4096. Training was done with [PyTorch Lightning](https://www.pytorchlightning.ai/index.html)
47
+ and the resulting checkpoint was converted to this 🤗 model with a library-specific [conversion utility](#checkpoint-conversion).
48
+
49
+ ## Intended use and limitations
50
+
51
+ This model can be used for audio generation with a user-defined initial number of latent tokens. It mainly exists for
52
+ demonstration purposes on how to train Perceiver AR models with the [perceiver-io library](https://github.com/krasserm/perceiver-io).
53
+ To improve on the quality of the generated audio samples a much larger dataset than
54
+ [GiantMIDI-Piano](https://github.com/bytedance/GiantMIDI-Piano) is required for training.
55
+
56
+ ## Usage examples
57
+
58
+ To use this model you first need to [install](https://github.com/krasserm/perceiver-io/blob/main/README.md#installation)
59
+ the `perceiver-io` library with extension `audio`.
60
+
61
+ ```shell
62
+ pip install perceiver-io[audio]
63
+ ```
64
+
65
+ Then the model can be used with PyTorch. Either use the model directly to generate MIDI files:
66
+
67
+ ```python
68
+ import torch
69
+
70
+ from perceiver.model.audio.symbolic import PerceiverSymbolicAudioModel
71
+ from perceiver.data.audio.midi_processor import decode_midi, encode_midi
72
+ from pretty_midi import PrettyMIDI
73
+
74
+ repo_id = "krasserm/perceiver-ar-sam-giant-midi"
75
+
76
+ model = PerceiverSymbolicAudioModel.from_pretrained(repo_id)
77
+
78
+ prompt_midi = PrettyMIDI("prompt.mid")
79
+ prompt = torch.tensor(encode_midi(prompt_midi)).unsqueeze(0)
80
+
81
+ output = model.generate(prompt, max_new_tokens=64, num_latents=1, do_sample=True, top_p=0.95, temperature=1.0)
82
+
83
+ output_midi = decode_midi(output[0].cpu().numpy())
84
+ type(output_midi)
85
+ ```
86
+ ```
87
+ pretty_midi.pretty_midi.PrettyMIDI
88
+ ```
89
+
90
+ use a `symbolic-audio-generation` pipeline to generate a MIDI output:
91
+
92
+ ```python
93
+ from transformers import pipeline
94
+ from pretty_midi import PrettyMIDI
95
+ from perceiver.model.audio import symbolic # auto-class registration
96
+
97
+ repo_id = "krasserm/perceiver-ar-sam-giant-midi"
98
+
99
+ prompt = PrettyMIDI("prompt.mid")
100
+ audio_generator = pipeline("symbolic-audio-generation", model=repo_id)
101
+
102
+ output = audio_generator(prompt, max_new_tokens=64, num_latents=1, do_sample=True, top_p=0.95, temperature=1.0)
103
+ type(output["generated_audio_midi"])
104
+ ```
105
+ ```
106
+ pretty_midi.pretty_midi.PrettyMIDI
107
+ ```
108
+
109
+ or generate WAV output by rendering the MIDI symbols using [fluidsynth](https://www.fluidsynth.org/) (Note: fluidsynth must be installed
110
+ in order for the following example to work):
111
+
112
+ ```python
113
+ from transformers import pipeline
114
+ from pretty_midi import PrettyMIDI
115
+ from perceiver.model.audio import symbolic # auto-class registration
116
+
117
+ repo_id = "krasserm/perceiver-ar-sam-giant-midi"
118
+
119
+ prompt = PrettyMIDI("prompt.mid")
120
+ audio_generator = pipeline("symbolic-audio-generation", model=repo_id)
121
+
122
+ output = audio_generator(prompt, max_new_tokens=64, num_latents=1, do_sample=True, top_p=0.95, temperature=1.0, render=True)
123
+
124
+ with open("generated_audio.wav", "wb") as f:
125
+ f.write(output["generated_audio_wav"])
126
+ ```
127
+
128
+ ## Audio samples
129
+
130
+ The following (hand-picked) audio samples were generated using various prompts from the validation subset of
131
+ the [GiantMIDI-Piano](https://github.com/bytedance/GiantMIDI-Piano) dataset. The input prompts are
132
+ not included in the audio output.
133
+
134
+ <table>
135
+ <tr>
136
+ <th>Audio sample</th>
137
+ <th>Top-K</th>
138
+ <th>Top-p</th>
139
+ <th>Temperature</th>
140
+ <th>Prefix length</th>
141
+ <th>Latents</th>
142
+ </tr>
143
+ <tr>
144
+ <td>
145
+ <audio controls>
146
+ <source src="https://martin-krasser.com/perceiver/data/midi/01_nehrlich_continuation.wav" type="audio/wav">
147
+ Your browser does not support the audio element.
148
+ </audio>
149
+ </td>
150
+ <td style="vertical-align: top;">-</td>
151
+ <td style="vertical-align: top;">0.95</td>
152
+ <td style="vertical-align: top;">0.95</td>
153
+ <td style="vertical-align: top;">4096</td>
154
+ <td style="vertical-align: top;">1</td>
155
+ </tr>
156
+ <tr>
157
+ <td>
158
+ <audio controls>
159
+ <source src="https://martin-krasser.com/perceiver/data/midi/02_eduardo_continuation.wav" type="audio/wav">
160
+ Your browser does not support the audio element.
161
+ </audio>
162
+ </td>
163
+ <td style="vertical-align: top;">-</td>
164
+ <td style="vertical-align: top;">0.95</td>
165
+ <td style="vertical-align: top;">1.0</td>
166
+ <td style="vertical-align: top;">4096</td>
167
+ <td style="vertical-align: top;">64</td>
168
+ </tr>
169
+ <tr>
170
+ <td>
171
+ <audio controls>
172
+ <source src="https://martin-krasser.com/perceiver/data/midi/03_membree_continuation.wav" type="audio/wav">
173
+ Your browser does not support the audio element.
174
+ </audio>
175
+ </td>
176
+ <td style="vertical-align: top;">-</td>
177
+ <td style="vertical-align: top;">0.95</td>
178
+ <td style="vertical-align: top;">1.0</td>
179
+ <td style="vertical-align: top;">1024</td>
180
+ <td style="vertical-align: top;">1</td>
181
+ </tr>
182
+ <tr>
183
+ <td>
184
+ <audio controls>
185
+ <source src="https://martin-krasser.com/perceiver/data/midi/04_membree_continuation.wav" type="audio/wav">
186
+ Your browser does not support the audio element.
187
+ </audio>
188
+ </td>
189
+ <td style="vertical-align: top;">15</td>
190
+ <td style="vertical-align: top;">-</td>
191
+ <td style="vertical-align: top;">1.0</td>
192
+ <td style="vertical-align: top;">4096</td>
193
+ <td style="vertical-align: top;">16</td>
194
+ </tr>
195
+ <tr>
196
+ <td>
197
+ <audio controls>
198
+ <source src="https://martin-krasser.com/perceiver/data/midi/05_kinscella_continuation.wav" type="audio/wav">
199
+ Your browser does not support the audio element.
200
+ </audio>
201
+ </td>
202
+ <td style="vertical-align: top;">-</td>
203
+ <td style="vertical-align: top;">0.95</td>
204
+ <td style="vertical-align: top;">1.0</td>
205
+ <td style="vertical-align: top;">4096</td>
206
+ <td style="vertical-align: top;">1</td>
207
+ </tr>
208
+ </table>
209
+
210
+ ## Checkpoint conversion
211
+
212
+ The `krasserm/perceiver-ar-sam-giant-midi` model has been created from a training checkpoint with:
213
+
214
+ ```python
215
+ from perceiver.model.audio.symbolic import convert_checkpoint
216
+
217
+ convert_checkpoint(
218
+ save_dir="krasserm/perceiver-ar-sam-giant-midi",
219
+ ckpt_url="https://martin-krasser.com/perceiver/logs-0.8.0/sam/version_1/checkpoints/epoch=027-val_loss=1.944.ckpt",
220
+ push_to_hub=True,
221
+ )
222
+ ```
223
+
224
+ ## Citation
225
+
226
+ ```bibtex
227
+ @inproceedings{hawthorne2022general,
228
+ title={General-purpose, long-context autoregressive modeling with perceiver ar},
229
+ author={Hawthorne, Curtis and Jaegle, Andrew and Cangea, C{\u{a}}t{\u{a}}lina and Borgeaud, Sebastian and Nash, Charlie and Malinowski, Mateusz and Dieleman, Sander and Vinyals, Oriol and Botvinick, Matthew and Simon, Ian and others},
230
+ booktitle={International Conference on Machine Learning},
231
+ pages={8535--8558},
232
+ year={2022},
233
+ organization={PMLR}
234
+ }
235
+ ```