teticio committed on
Commit
da20094
1 Parent(s): c78ba1a

reproducible samples with seed

Files changed (1)
  1. notebooks/test_model.ipynb +20 -17
notebooks/test_model.ipynb CHANGED
@@ -84,17 +84,9 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "audio_diffusion = AudioDiffusion(model_id=model_id)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4dc17ac0",
- "metadata": {},
- "outputs": [],
- "source": [
- "mel = Mel(x_res=256, y_res=256)"
+ "audio_diffusion = AudioDiffusion(model_id=model_id)\n",
+ "mel = Mel(x_res=256, y_res=256)\n",
+ "generator = torch.Generator()"
  ]
  },
  {
@@ -112,10 +104,13 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "generator = torch.Generator()\n",
  "for _ in range(10):\n",
- " print(f'Seed = {generator.seed()}')\n",
- " image, (sample_rate, audio) = audio_diffusion.generate_spectrogram_and_audio(generator)\n",
+ " seed = generator.seed()\n",
+ " print(f'Seed = {seed}')\n",
+ " generator.manual_seed(seed)\n",
+ " image, (sample_rate,\n",
+ " audio) = audio_diffusion.generate_spectrogram_and_audio(\n",
+ " generator=generator)\n",
  " display(image)\n",
  " display(Audio(audio, rate=sample_rate))\n",
  " loop = AudioDiffusion.loop_it(audio, sample_rate)\n",
@@ -149,9 +144,10 @@
  "outputs": [],
  "source": [
  "seed = 16183389798189209330 #@param {type:\"integer\"}\n",
+ "generator.manual_seed(seed)\n",
  "image, (sample_rate,\n",
- " audio) = audio_diffusion.generate_spectrogram_and_audio_from_audio(\n",
- " generator=torch.Generator().manual_seed(seed))\n",
+ " audio) = audio_diffusion.generate_spectrogram_and_audio(\n",
+ " generator=generator)\n",
  "display(image)\n",
  "display(Audio(audio, rate=sample_rate))"
  ]
@@ -258,7 +254,6 @@
  "overlap_samples = overlap_secs * mel.get_sample_rate()\n",
  "slice_size = mel.x_res * mel.hop_length\n",
  "stride = slice_size - overlap_samples\n",
- "generator = torch.Generator()\n",
  "seed = generator.seed()\n",
  "print(f'Seed = {seed}')\n",
  "track = np.array([])\n",
@@ -346,6 +341,14 @@
  "audio = mel.image_to_audio(image)\n",
  "Audio(data=audio, rate=mel.get_sample_rate())"
  ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4deb47f4",
+ "metadata": {},
+ "outputs": [],
+ "source": []
  }
  ],
  "metadata": {