Dimitri commited on
Commit
92cf4eb
1 Parent(s): 61d0d14
Files changed (1) hide show
  1. app.py +59 -28
app.py CHANGED
@@ -10,13 +10,33 @@ from fabric.generator import AttentionBasedGenerator
10
  model_name = ""
11
  model_ckpt = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_7_pruned.safetensors"
12
 
13
- dtype = torch.float16 if torch.cuda.is_available() else torch.float32
14
- device = "cuda" if torch.cuda.is_available() else "cpu"
15
- generator = AttentionBasedGenerator(
16
- model_name=model_name if model_name else None,
17
- model_ckpt=model_ckpt if model_ckpt else None,
18
- torch_dtype=dtype,
19
- ).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
 
22
  css = """
@@ -96,33 +116,44 @@ def generate_fn(
96
  liked = []
97
  disliked = disliked[-max_feedback_imgs:]
98
  # else: keep all feedback images
99
-
100
- images = generator.generate(
101
- prompt=prompt,
102
- negative_prompt=neg_prompt,
103
- liked=liked,
104
- disliked=disliked,
105
- denoising_steps=denoising_steps,
106
- guidance_scale=guidance_scale,
107
- feedback_start=feedback_start,
108
- feedback_end=feedback_end,
109
- min_weight=min_weight,
110
- max_weight=max_weight,
111
- neg_scale=neg_scale,
112
- seed=seed,
113
- n_images=batch_size,
114
- )
 
 
 
 
 
 
 
115
  return [(img, f"Image {i+1}") for i, img in enumerate(images)], images
116
  except Exception as err:
117
  raise gr.Error(str(err))
118
 
119
 
120
  def add_img_from_list(i, curr_imgs, all_imgs):
 
 
121
  if i >= 0 and i < len(curr_imgs):
122
  all_imgs.append(curr_imgs[i])
123
  return all_imgs, all_imgs # return (gallery, state)
124
 
125
  def add_img(img, all_imgs):
 
 
126
  all_imgs.append(img)
127
  return None, all_imgs, all_imgs
128
 
@@ -148,7 +179,7 @@ with gr.Blocks(css=css) as demo:
148
  with gr.Column():
149
  denoising_steps = gr.Slider(1, 100, value=20, step=1, label="Sampling steps")
150
  guidance_scale = gr.Slider(0.0, 30.0, value=6, step=0.25, label="CFG scale")
151
- batch_size = gr.Slider(1, 10, value=4, step=1, label="Batch size")
152
  seed = gr.Number(-1, minimum=-1, precision=0, label="Seed")
153
  max_feedback_imgs = gr.Slider(0, 20, value=6, step=1, label="Max. feedback images", info="Maximum number of liked/disliked images to be used. If exceeded, only the most recent images will be used as feedback. (NOTE: large number of feedback imgs => high VRAM requirements)")
154
  feedback_enabled = gr.Checkbox(True, label="Enable feedback", interactive=True)
@@ -222,8 +253,8 @@ with gr.Blocks(css=css) as demo:
222
  liked_img_input.upload(add_img, [liked_img_input, liked_imgs], [liked_img_input, like_gallery, liked_imgs], queue=False)
223
  disliked_img_input.upload(add_img, [disliked_img_input, disliked_imgs], [disliked_img_input, dislike_gallery, disliked_imgs], queue=False)
224
 
225
- clear_liked_btn.click(lambda: [None, None], None, [liked_imgs, like_gallery], queue=False)
226
- clear_disliked_btn.click(lambda: [None, None], None, [disliked_imgs, dislike_gallery], queue=False)
227
 
228
- demo.queue(8)
229
- demo.launch()
 
10
  model_name = ""
11
  model_ckpt = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_7_pruned.safetensors"
12
 
13
+ class GeneratorWrapper:
14
+ def __init__(self, model_name=None, model_ckpt=None):
15
+ self.model_name = model_name if model_name else None
16
+ self.model_ckpt = model_ckpt if model_ckpt else None
17
+ self.dtype = torch.float16 if torch.cuda.is_available() else torch.float32
18
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
19
+
20
+ self.reload()
21
+
22
+ def generate(self, *args, **kwargs):
23
+ return self.generator.generate(*args, **kwargs)
24
+
25
+ def to(self, device):
26
+ return self.generator.to(device)
27
+
28
+ def reload(self):
29
+ if hasattr(self, "generator"):
30
+ del self.generator
31
+ if self.device == "cuda":
32
+ torch.cuda.empty_cache()
33
+ self.generator = AttentionBasedGenerator(
34
+ model_name=self.model_name,
35
+ model_ckpt=self.model_ckpt,
36
+ torch_dtype=self.dtype,
37
+ ).to(self.device)
38
+
39
+ generator = GeneratorWrapper(model_name, model_ckpt)
40
 
41
 
42
  css = """
 
116
  liked = []
117
  disliked = disliked[-max_feedback_imgs:]
118
  # else: keep all feedback images
119
+
120
+ generate_kwargs = {
121
+ "prompt": prompt,
122
+ "negative_prompt": neg_prompt,
123
+ "liked": liked,
124
+ "disliked": disliked,
125
+ "denoising_steps": denoising_steps,
126
+ "guidance_scale": guidance_scale,
127
+ "feedback_start": feedback_start,
128
+ "feedback_end": feedback_end,
129
+ "min_weight": min_weight,
130
+ "max_weight": max_weight,
131
+ "neg_scale": neg_scale,
132
+ "seed": seed,
133
+ "n_images": batch_size,
134
+ }
135
+
136
+ try:
137
+ images = generator.generate(**generate_kwargs)
138
+ except RuntimeError as err:
139
+ if 'out of memory' in str(err):
140
+ generator.reload()
141
+ raise
142
  return [(img, f"Image {i+1}") for i, img in enumerate(images)], images
143
  except Exception as err:
144
  raise gr.Error(str(err))
145
 
146
 
147
  def add_img_from_list(i, curr_imgs, all_imgs):
148
+ if all_imgs is None:
149
+ all_imgs = []
150
  if i >= 0 and i < len(curr_imgs):
151
  all_imgs.append(curr_imgs[i])
152
  return all_imgs, all_imgs # return (gallery, state)
153
 
154
  def add_img(img, all_imgs):
155
+ if all_imgs is None:
156
+ all_imgs = []
157
  all_imgs.append(img)
158
  return None, all_imgs, all_imgs
159
 
 
179
  with gr.Column():
180
  denoising_steps = gr.Slider(1, 100, value=20, step=1, label="Sampling steps")
181
  guidance_scale = gr.Slider(0.0, 30.0, value=6, step=0.25, label="CFG scale")
182
+ batch_size = gr.Slider(1, 10, value=4, step=1, label="Batch size", interactive=False)
183
  seed = gr.Number(-1, minimum=-1, precision=0, label="Seed")
184
  max_feedback_imgs = gr.Slider(0, 20, value=6, step=1, label="Max. feedback images", info="Maximum number of liked/disliked images to be used. If exceeded, only the most recent images will be used as feedback. (NOTE: large number of feedback imgs => high VRAM requirements)")
185
  feedback_enabled = gr.Checkbox(True, label="Enable feedback", interactive=True)
 
253
  liked_img_input.upload(add_img, [liked_img_input, liked_imgs], [liked_img_input, like_gallery, liked_imgs], queue=False)
254
  disliked_img_input.upload(add_img, [disliked_img_input, disliked_imgs], [disliked_img_input, dislike_gallery, disliked_imgs], queue=False)
255
 
256
+ clear_liked_btn.click(lambda: [[], []], None, [liked_imgs, like_gallery], queue=False)
257
+ clear_disliked_btn.click(lambda: [[], []], None, [disliked_imgs, dislike_gallery], queue=False)
258
 
259
+ demo.queue(1)
260
+ demo.launch(debug=True)