Afrinetwork7 commited on
Commit
e6f25a5
1 Parent(s): 9237c74

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -10
app.py CHANGED
@@ -80,24 +80,35 @@ def save_image_to_s3(image):
80
  url = f"https://{S3_BUCKET}.s3.{S3_REGION}.amazonaws.com/{filename}"
81
  return url
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  @app.post("/infer")
84
  async def infer(request: InferenceRequest):
85
  if request.randomize_seed:
86
  seed = random.randint(0, MAX_SEED)
87
  else:
88
  seed = request.seed
89
- generator = torch.Generator().manual_seed(seed)
90
 
91
  try:
92
- with Timer("Image generation"):
93
- image = pipe(
94
- prompt=request.prompt,
95
- width=request.width,
96
- height=request.height,
97
- num_inference_steps=request.num_inference_steps,
98
- generator=generator,
99
- guidance_scale=request.guidance_scale
100
- ).images[0]
101
 
102
  image_url = save_image_to_s3(image)
103
 
 
80
  url = f"https://{S3_BUCKET}.s3.{S3_REGION}.amazonaws.com/{filename}"
81
  return url
82
 
83
+ def process_image(height, width, steps, scales, prompt, seed):
84
+ global pipe
85
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), Timer("inference"):
86
+ return pipe(
87
+ prompt=[prompt],
88
+ generator=torch.Generator().manual_seed(int(seed)),
89
+ num_inference_steps=int(steps),
90
+ guidance_scale=float(scales),
91
+ height=int(height),
92
+ width=int(width),
93
+ max_sequence_length=256
94
+ ).images[0]
95
+
96
  @app.post("/infer")
97
  async def infer(request: InferenceRequest):
98
  if request.randomize_seed:
99
  seed = random.randint(0, MAX_SEED)
100
  else:
101
  seed = request.seed
 
102
 
103
  try:
104
+ image = process_image(
105
+ height=request.height,
106
+ width=request.width,
107
+ steps=request.num_inference_steps,
108
+ scales=request.guidance_scale,
109
+ prompt=request.prompt,
110
+ seed=seed
111
+ )
 
112
 
113
  image_url = save_image_to_s3(image)
114