Spaces:
Runtime error
Runtime error
File size: 4,002 Bytes
4559e89 8ccf632 0a30c9c 3b76ae1 9237c74 0a30c9c e822eea 0a30c9c 06f0278 9237c74 8ccf632 9237c74 8ccf632 06f0278 8ccf632 4559e89 9237c74 4559e89 9237c74 e822eea 9237c74 e822eea 9237c74 4559e89 0a30c9c 8291039 3b76ae1 8291039 0a30c9c 27495d6 e6f25a5 4559e89 54192f0 4559e89 0a30c9c 4559e89 e6f25a5 8ccf632 4559e89 8ccf632 4559e89 8ccf632 4559e89 9237c74 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import os
import random
import threading
import time
import uuid
from io import BytesIO

import boto3
import numpy as np
import torch
from diffusers import FluxPipeline
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from huggingface_hub import hf_hub_download
from pydantic import BaseModel, Field
from safetensors.torch import load_file
# S3 configuration, read from environment variables.
S3_BUCKET = os.getenv("S3_BUCKET")
S3_REGION = os.getenv("S3_REGION")
S3_ACCESS_KEY_ID = os.getenv("S3_ACCESS_KEY_ID")
S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY")

# Fail fast at import time, naming every variable that is unset or empty so a
# misconfigured deployment can be diagnosed from the error message alone.
_missing_s3_vars = [
    name
    for name, value in (
        ("S3_BUCKET", S3_BUCKET),
        ("S3_REGION", S3_REGION),
        ("S3_ACCESS_KEY_ID", S3_ACCESS_KEY_ID),
        ("S3_SECRET_ACCESS_KEY", S3_SECRET_ACCESS_KEY),
    )
    if not value
]
if _missing_s3_vars:
    raise ValueError(
        "Missing required S3 environment variables: " + ", ".join(_missing_s3_vars)
    )

# Shared S3 client; save_image_to_s3 uploads through it.
s3_client = boto3.client(
    "s3",
    region_name=S3_REGION,
    aws_access_key_id=S3_ACCESS_KEY_ID,
    aws_secret_access_key=S3_SECRET_ACCESS_KEY,
)
# Model weights are cached in a "models" directory that sits beside this file.
cache_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "models",
)
# Route every Hugging Face cache setting to that directory.
# NOTE(review): huggingface_hub is imported earlier in this file and reads these
# variables at import time, so setting them here may not change where downloads
# land -- confirm, or pass cache_dir explicitly to the download calls.
os.environ.update(
    {var: cache_path for var in ("TRANSFORMERS_CACHE", "HF_HUB_CACHE", "HF_HOME")}
)
# Create the directory on first run.
if not os.path.exists(cache_path):
    os.makedirs(cache_path, exist_ok=True)
# Allow TF32 math for CUDA matmuls (trades some precision for speed).
torch.backends.cuda.matmul.allow_tf32 = True
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load FLUX.1-dev and fuse the Hyper-SD 8-step LoRA into it.
# cache_dir is passed explicitly: the HF_* cache environment variables are set
# only after huggingface_hub/diffusers were imported, so relying on them alone
# may download into the default cache instead of cache_path.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    cache_dir=cache_path,
)
pipe.load_lora_weights(
    hf_hub_download(
        "ByteDance/Hyper-SD",
        "Hyper-FLUX.1-dev-8steps-lora.safetensors",
        cache_dir=cache_path,
    )
)
pipe.fuse_lora(lora_scale=0.125)
pipe.to(device=device, dtype=torch.bfloat16)

# Upper bound for randomly drawn seeds (largest signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max
# Maximum image side length in pixels (presumably for request validation).
MAX_IMAGE_SIZE = 2048

app = FastAPI()
class InferenceRequest(BaseModel):
    """Request body for ``POST /infer``.

    ``width`` and ``height`` are bounded by ``MAX_IMAGE_SIZE`` and
    ``num_inference_steps`` must be positive, so out-of-range requests are
    rejected by validation before any inference runs instead of exhausting
    GPU memory or crashing inside the pipeline.
    """

    prompt: str
    seed: int = 42
    # When True, the /infer handler ignores ``seed`` and draws a random one.
    randomize_seed: bool = True
    width: int = Field(1024, gt=0, le=MAX_IMAGE_SIZE)
    height: int = Field(1024, gt=0, le=MAX_IMAGE_SIZE)
    guidance_scale: float = 3.5
    num_inference_steps: int = Field(8, gt=0)
class Timer:
    """Context manager that prints how long its ``with`` body took.

    Prints ``"<method_name> starts"`` on entry and
    ``"<method_name> took <seconds>s"`` on exit, with seconds rounded to two
    decimals. The raw duration is also kept on ``self.elapsed`` after exit.
    Exceptions raised in the body are never suppressed.
    """

    def __init__(self, method_name="timed process"):
        self.method = method_name
        self.start = None
        self.elapsed = None

    def __enter__(self):
        # perf_counter is monotonic, so the measurement is immune to
        # wall-clock adjustments, unlike time.time().
        self.start = time.perf_counter()
        print(f"{self.method} starts")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.elapsed = time.perf_counter() - self.start
        print(f"{self.method} took {round(self.elapsed, 2)}s")
        # Falsy return: let any exception from the body propagate.
        return False
def save_image_to_s3(image):
    """Upload *image* to the configured S3 bucket as a PNG and return its URL.

    Args:
        image: object with a PIL-style ``save(fp, format=...)`` method
            (presumably the ``PIL.Image`` produced by the pipeline).

    Returns:
        The ``https://<bucket>.s3.<region>.amazonaws.com/<key>`` URL of the
        uploaded object. No ACL is set here, so whether the URL is publicly
        readable depends on the bucket's own policy.

    Raises:
        Whatever ``s3_client.put_object`` raises when the upload fails.
    """
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    # A timestamp alone collides when two images finish within the same
    # second, silently overwriting the earlier object. The random suffix keeps
    # keys unique, and the timestamp prefix keeps them roughly sortable.
    filename = f"generated_image_{int(time.time())}_{uuid.uuid4().hex}.png"
    s3_client.put_object(
        Bucket=S3_BUCKET,
        Key=filename,
        Body=buffer.getvalue(),
        ContentType="image/png",
    )
    return f"https://{S3_BUCKET}.s3.{S3_REGION}.amazonaws.com/{filename}"
def process_image(height, width, steps, scales, prompt, seed):
    """Run the module-level FLUX pipeline once and return the first image.

    Args:
        height: output height in pixels (coerced with ``int``).
        width: output width in pixels (coerced with ``int``).
        steps: number of inference steps (coerced with ``int``).
        scales: guidance scale (coerced with ``float``).
        prompt: text prompt, sent to the pipeline as a one-element list.
        seed: seed for a fresh CPU ``torch.Generator`` passed to the pipeline.

    Returns:
        ``pipe(...).images[0]``, the single generated image.
    """
    # ``pipe`` is only read, never rebound, so no ``global`` is needed.
    # Requesting CUDA autocast on a CPU-only host makes torch warn and disable
    # it anyway; disable it explicitly there instead.
    autocast = torch.autocast(
        "cuda", dtype=torch.bfloat16, enabled=torch.cuda.is_available()
    )
    with torch.inference_mode(), autocast, Timer("inference"):
        result = pipe(
            prompt=[prompt],
            generator=torch.Generator().manual_seed(int(seed)),
            num_inference_steps=int(steps),
            guidance_scale=float(scales),
            height=int(height),
            width=int(width),
            max_sequence_length=256,
        )
    return result.images[0]
# Serializes use of the shared pipeline: inference runs on worker threads, and
# the single pipeline object is not assumed safe to call concurrently.
_PIPELINE_LOCK = threading.Lock()


def _generate_image_locked(**kwargs):
    """Call ``process_image(**kwargs)`` while holding ``_PIPELINE_LOCK``."""
    with _PIPELINE_LOCK:
        return process_image(**kwargs)


@app.post("/infer")
async def infer(request: InferenceRequest):
    """Generate an image for *request*, upload it to S3, and return its URL.

    Returns:
        ``{"image_url": <S3 URL>, "seed": <seed actually used>}``.

    Raises:
        HTTPException: status 500 if generation or upload fails; the detail
            is the underlying exception's message.
    """
    # Draw a fresh seed unless the caller asked to use their own.
    seed = random.randint(0, MAX_SEED) if request.randomize_seed else request.seed
    try:
        # Both steps block for a long time (model inference, network upload).
        # Running them in the thread pool keeps the event loop free to serve
        # other requests instead of freezing the whole server meanwhile.
        image = await run_in_threadpool(
            _generate_image_locked,
            height=request.height,
            width=request.width,
            steps=request.num_inference_steps,
            scales=request.guidance_scale,
            prompt=request.prompt,
            seed=seed,
        )
        image_url = await run_in_threadpool(save_image_to_s3, image)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"image_url": image_url, "seed": seed}
@app.get("/")
async def root():
    """Landing endpoint: return a static welcome message."""
    return {"message": "Welcome to the IG API"}