patrickvonplaten commited on
Commit
d42c56f
1 Parent(s): 7302472

Former-commit-id: eef5da90dbee4b22bdd864e53726993f98ae3366

Files changed (1) hide show
  1. scripts/txt2img.py +13 -10
scripts/txt2img.py CHANGED
@@ -19,8 +19,10 @@ from ldm.models.diffusion.plms import PLMSSampler
19
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
20
  from transformers import AutoFeatureExtractor
21
 
22
- feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-v-1-3", use_auth_token=True)
23
- safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-v-1-3", use_auth_token=True)
 
 
24
 
25
  def chunk(it, size):
26
  it = iter(it)
@@ -266,16 +268,23 @@ def main():
266
 
267
  x_samples_ddim = model.decode_first_stage(samples_ddim)
268
  x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
 
 
 
 
 
 
269
 
270
  if not opt.skip_save:
271
- for x_sample in x_samples_ddim:
272
  x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
273
  Image.fromarray(x_sample.astype(np.uint8)).save(
274
  os.path.join(sample_path, f"{base_count:05}.png"))
275
  base_count += 1
276
 
277
  if not opt.skip_grid:
278
- all_samples.append(x_samples_ddim)
279
 
280
  if not opt.skip_grid:
281
  # additionally, save as grid
@@ -288,12 +297,6 @@ def main():
288
  Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
289
  grid_count += 1
290
 
291
- image = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
292
-
293
- # run safety checker
294
- safety_checker_input = pipe.feature_extractor(numpy_to_pil(image), return_tensors="pt")
295
- image, has_nsfw_concept = pipe.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
296
-
297
  print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
298
  f" \nEnjoy.")
299
 
 
19
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
20
  from transformers import AutoFeatureExtractor
21
 
22
+ # load safety model
23
+ safety_model_id = "CompVis/stable-diffusion-v-1-3"
24
+ safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, use_auth_token=True)
25
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, use_auth_token=True)
26
 
27
  def chunk(it, size):
28
  it = iter(it)
 
268
 
269
  x_samples_ddim = model.decode_first_stage(samples_ddim)
270
  x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
271
+ x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
272
+
273
+ x_image = x_samples_ddim
274
+ safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
275
+ x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
276
+
277
+ x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 2, 1)
278
 
279
  if not opt.skip_save:
280
+ for x_sample in x_checked_image_torch:
281
  x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
282
  Image.fromarray(x_sample.astype(np.uint8)).save(
283
  os.path.join(sample_path, f"{base_count:05}.png"))
284
  base_count += 1
285
 
286
  if not opt.skip_grid:
287
+ all_samples.append(x_checked_image_torch)
288
 
289
  if not opt.skip_grid:
290
  # additionally, save as grid
 
297
  Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
298
  grid_count += 1
299
 
 
 
 
 
 
 
300
  print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
301
  f" \nEnjoy.")
302