Prgckwb commited on
Commit
cc8ba3d
1 Parent(s): 750ec07

:tada: change process

Browse files
Files changed (1) hide show
  1. app.py +11 -13
app.py CHANGED
@@ -17,23 +17,14 @@ DIFFUSERS_MODEL_IDS = [
17
  # Other Models
18
  "Prgckwb/trpfrog-diffusion",
19
  ]
20
-
21
  EXTERNAL_MODEL_MAPPING = {
22
  "Beautiful Realistic Asians": "checkpoints/diffusers/Beautiful Realistic Asians v7",
23
  }
24
-
25
  MODEL_CHOICES = DIFFUSERS_MODEL_IDS + list(EXTERNAL_MODEL_MAPPING.keys())
26
 
27
  current_model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
28
-
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
- if device == 'cuda':
31
- dtype = torch.float16
32
- pipe = DiffusionPipeline.from_pretrained(
33
- current_model_id,
34
- torch_dtype=dtype,
35
- )
36
- pipe.enable_model_cpu_offload()
37
 
38
 
39
  @dataclasses.dataclass
@@ -86,6 +77,7 @@ def inference(
86
  num_inference_steps: int = 50,
87
  num_images: int = 4,
88
  safety_checker: bool = True,
 
89
  progress=gr.Progress(track_tqdm=True),
90
  ) -> Image.Image:
91
  progress(0, "Starting inference...")
@@ -100,9 +92,8 @@ def inference(
100
 
101
  pipe = DiffusionPipeline.from_pretrained(
102
  model_id,
103
- torch_dtype=dtype,
104
  )
105
- pipe.enable_model_cpu_offload()
106
 
107
  current_model_id = model_id
108
 
@@ -118,6 +109,11 @@ def inference(
118
 
119
  # Generation
120
  progress(0.4, 'Generating images...')
 
 
 
 
 
121
  images = pipe(
122
  prompt,
123
  negative_prompt=negative_prompt,
@@ -168,6 +164,7 @@ if __name__ == "__main__":
168
 
169
  with gr.Row():
170
  safety_checker = gr.Checkbox(value=True, label='Use Safety Checker')
 
171
 
172
  with gr.Column():
173
  output_image = gr.Image(label="Image", type="pil")
@@ -181,7 +178,8 @@ if __name__ == "__main__":
181
  guidance_scale,
182
  num_inference_step,
183
  num_images,
184
- safety_checker
 
185
  ]
186
 
187
  btn = gr.Button("Generate")
 
17
  # Other Models
18
  "Prgckwb/trpfrog-diffusion",
19
  ]
 
20
  EXTERNAL_MODEL_MAPPING = {
21
  "Beautiful Realistic Asians": "checkpoints/diffusers/Beautiful Realistic Asians v7",
22
  }
 
23
  MODEL_CHOICES = DIFFUSERS_MODEL_IDS + list(EXTERNAL_MODEL_MAPPING.keys())
24
 
25
  current_model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
 
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
27
+ pipe = None
 
 
 
 
 
 
28
 
29
 
30
  @dataclasses.dataclass
 
77
  num_inference_steps: int = 50,
78
  num_images: int = 4,
79
  safety_checker: bool = True,
80
+ use_model_offload:bool = False,
81
  progress=gr.Progress(track_tqdm=True),
82
  ) -> Image.Image:
83
  progress(0, "Starting inference...")
 
92
 
93
  pipe = DiffusionPipeline.from_pretrained(
94
  model_id,
95
+ torch_dtype=torch.float16,
96
  )
 
97
 
98
  current_model_id = model_id
99
 
 
109
 
110
  # Generation
111
  progress(0.4, 'Generating images...')
112
+ if use_model_offload:
113
+ pipe.enable_model_cpu_offload()
114
+ else:
115
+ pipe = pipe.to('cuda')
116
+
117
  images = pipe(
118
  prompt,
119
  negative_prompt=negative_prompt,
 
164
 
165
  with gr.Row():
166
  safety_checker = gr.Checkbox(value=True, label='Use Safety Checker')
167
+ model_offload = gr.Checkbox(value=False, label='Use Model Offload')
168
 
169
  with gr.Column():
170
  output_image = gr.Image(label="Image", type="pil")
 
178
  guidance_scale,
179
  num_inference_step,
180
  num_images,
181
+ safety_checker,
182
+ model_offload,
183
  ]
184
 
185
  btn = gr.Button("Generate")