nateraw committed on
Commit b7321c6
Parent: 09805a8

Update app.py

Files changed (1)
  1. app.py +11 -149
app.py CHANGED
@@ -1,8 +1,4 @@
-import time
-from pathlib import Path
-
-import gradio as gr
-from stable_diffusion_videos import StableDiffusionWalkPipeline, generate_images
+from stable_diffusion_videos import StableDiffusionWalkPipeline, Interface
 from diffusers.schedulers import LMSDiscreteScheduler
 import torch
 
@@ -11,149 +7,15 @@ from huggingface_hub import HfFolder
 
 HfFolder().save_token(os.environ['HF_TOKEN'])
 
-
-class ImageGenerationInterface:
-    def __init__(self, pipeline):
-        self.pipeline = pipeline
-        self.interface_images = gr.Interface(
-            self.fn,
-            inputs=[
-                gr.Textbox("blueberry spaghetti", label='Prompt'),
-                gr.Slider(1, 24, 16, step=1, label='Batch size'),
-                gr.Slider(1, 16, 1, step=1, label='# Batches'),
-                gr.Slider(10, 100, 50, step=1, label='# Inference Steps'),
-                gr.Slider(5.0, 15.0, 7.5, step=0.5, label='Guidance Scale'),
-                gr.Slider(512, 1024, 512, step=64, label='Height'),
-                gr.Slider(512, 1024, 512, step=64, label='Width'),
-                gr.Checkbox(False, label='Upsample'),
-                gr.Textbox("nateraw/stable-diffusion-gallery", label='(Optional) Repo ID'),
-                gr.Checkbox(False, label='Push to Hub'),
-                gr.Checkbox(False, label='Private'),
-                gr.Textbox("./images", label='Output directory'),
-            ],
-            outputs=gr.Gallery(),
-        )
-
-        self.interface_videos = gr.Interface(
-            self.fn_videos,
-            inputs=[
-                gr.Textbox("blueberry spaghetti\nstrawberry spaghetti", lines=2, label='Prompts, separated by new line'),
-                gr.Textbox("42\n1337", lines=2, label='Seeds, separated by new line'),
-                gr.Textbox("25\n27", lines=2, label='Audio Offsets (seconds in song), separated by new line'),
-                gr.Audio(type="filepath"),
-                gr.Slider(3, 60, 5, step=1, label='FPS'),
-                gr.Slider(1, 24, 16, step=1, label='Batch size'),
-                gr.Slider(10, 100, 50, step=1, label='# Inference Steps'),
-                gr.Slider(5.0, 15.0, 7.5, step=0.5, label='Guidance Scale'),
-                gr.Slider(512, 1024, 512, step=64, label='Height'),
-                gr.Slider(512, 1024, 512, step=64, label='Width'),
-                gr.Checkbox(False, label='Upsample'),
-            ],
-            outputs=gr.Video(),
-        )
-        self.interface = gr.TabbedInterface(
-            [self.interface_images, self.interface_videos],
-            ['Images!', 'Videos!'],
-        )
-
-    def fn_videos(
-        self,
-        prompts,
-        seeds,
-        audio_offsets,
-        audio_filepath,
-        fps,
-        batch_size,
-        num_inference_steps,
-        guidance_scale,
-        height,
-        width,
-        upsample,
-    ):
-        prompts = [x.strip() for x in prompts.split('\n')]
-        seeds = [int(x.strip()) for x in seeds.split('\n')]
-        audio_offsets = [float(x.strip()) for x in audio_offsets.split('\n')]
-        num_interpolation_steps = [(b-a) * fps for a, b in zip(audio_offsets, audio_offsets[1:])]
-
-        return self.pipeline.walk(
-            prompts=prompts,
-            seeds=seeds,
-            num_interpolation_steps=num_interpolation_steps,
-            audio_filepath=audio_filepath,
-            audio_start_sec=audio_offsets[0],
-            fps=fps,
-            height=height,
-            width=width,
-            output_dir='dreams',
-            guidance_scale=guidance_scale,
-            num_inference_steps=num_inference_steps,
-            upsample=upsample,
-            batch_size=batch_size
-        )
-
-    def fn(
-        self,
-        prompt,
-        batch_size,
-        num_batches,
-        num_inference_steps,
-        guidance_scale,
-        height,
-        width,
-        upsample,
-        repo_id,
-        push_to_hub,
-        private,
-        output_dir,
-    ):
-        output_path = Path(output_dir)
-        name = time.strftime("%Y%m%d-%H%M%S")
-        save_path = output_path / name
-        image_filepaths = generate_images(
-            self.pipeline,
-            prompt,
-            batch_size=batch_size,
-            num_batches=num_batches,
-            num_inference_steps=num_inference_steps,
-            guidance_scale=guidance_scale,
-            output_dir=output_dir,
-            name=name,
-            image_file_ext='.jpg',
-            upsample=upsample,
-            height=height,
-            width=width,
-            push_to_hub=push_to_hub,
-            repo_id=repo_id,
-            private=private,
-            create_pr=False,
-        )
-        return [(x, Path(x).stem) for x in sorted(image_filepaths)]
-
-    def launch(self, *args, **kwargs):
-        self.interface.launch(*args, **kwargs)
-
-
-def main(
-    model_id: str = "CompVis/stable-diffusion-v1-4",
-    tiled=False,
-    disable_safety_checker=False,
-):
-    safety_checker_kwargs = {'safety_checker': None} if disable_safety_checker else {}
-    pipeline = StableDiffusionWalkPipeline.from_pretrained(
-        model_id,
-        revision="fp16",
-        torch_dtype=torch.float16,
-        scheduler=LMSDiscreteScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-        ),
-        tiled=tiled,
-        **safety_checker_kwargs
-    ).to("cuda")
-    ImageGenerationInterface(pipeline).launch(debug=True)
-
+pipeline = StableDiffusionWalkPipeline.from_pretrained(
+    "CompVis/stable-diffusion-v1-4",
+    torch_dtype=torch.float16,
+    revision="fp16",
+    scheduler=LMSDiscreteScheduler(
+        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+    )
+).to("cuda")
+interface = Interface(pipeline)
 
 if __name__ == '__main__':
-    #import fire
-
-    #fire.Fire(main)
-    main(disable_safety_checker=True)
+    interface.launch()
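One piece of the removed code is worth a closer look: fn_videos turned the user-supplied audio offsets into per-segment frame counts by multiplying each gap between consecutive offsets by the FPS, so the generated video stays in sync with the song. A minimal standalone sketch of that computation, plugging in the old UI's default values:

fps = 5                       # old default of the FPS slider
audio_offsets = [25.0, 27.0]  # parsed from the default "25\n27" textbox

# Each consecutive pair of offsets becomes one interpolation segment;
# its frame count is the gap in seconds times the frame rate.
num_interpolation_steps = [(b - a) * fps for a, b in zip(audio_offsets, audio_offsets[1:])]

print(num_interpolation_steps)  # [10.0] -> 10 frames between the two prompts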
 
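Putting the added and surviving context lines together, the app.py produced by this commit reads roughly as below. The hunk headers leave two unchanged lines (old 9-10 / new 5-6) out of view; import os and from huggingface_hub import HfFolder are inferred for them from the save_token call, so treat those two lines as an assumption rather than something the diff shows.

from stable_diffusion_videos import StableDiffusionWalkPipeline, Interface
from diffusers.schedulers import LMSDiscreteScheduler
import torch

# Inferred, not shown in the hunks above: `os` and `HfFolder` must be
# in scope for the save_token call below.
import os
from huggingface_hub import HfFolder

HfFolder().save_token(os.environ['HF_TOKEN'])

# fp16 weights on the GPU, with the same LMS scheduler the old main() used.
pipeline = StableDiffusionWalkPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    torch_dtype=torch.float16,
    revision="fp16",
    scheduler=LMSDiscreteScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
    )
).to("cuda")

# The library's bundled Interface replaces the ~150-line
# ImageGenerationInterface class that this commit deletes.
interface = Interface(pipeline)

if __name__ == '__main__':
    interface.launch()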