# Space status: Runtime error
import asyncio
import os
import random
import threading
import time
import uuid
from io import BytesIO

import boto3
import numpy as np
import torch
from diffusers import FluxPipeline
from fastapi import FastAPI, HTTPException
from huggingface_hub import hf_hub_download
from pydantic import BaseModel, Field
from safetensors.torch import load_file
# S3 configuration, read from environment variables.
S3_BUCKET = os.getenv("S3_BUCKET")
S3_REGION = os.getenv("S3_REGION")
S3_ACCESS_KEY_ID = os.getenv("S3_ACCESS_KEY_ID")
S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY")

# Fail fast at import time. Unset and empty values both count as missing,
# and the message names exactly which variables are absent so a
# misconfigured deployment is easy to diagnose.
_missing_s3_vars = [
    name
    for name, value in (
        ("S3_BUCKET", S3_BUCKET),
        ("S3_REGION", S3_REGION),
        ("S3_ACCESS_KEY_ID", S3_ACCESS_KEY_ID),
        ("S3_SECRET_ACCESS_KEY", S3_SECRET_ACCESS_KEY),
    )
    if not value
]
if _missing_s3_vars:
    raise ValueError(
        "Missing required S3 environment variables: "
        + ", ".join(_missing_s3_vars)
    )

# Module-level S3 client built from the explicit credentials above.
s3_client = boto3.client(
    's3',
    region_name=S3_REGION,
    aws_access_key_id=S3_ACCESS_KEY_ID,
    aws_secret_access_key=S3_SECRET_ACCESS_KEY,
)
# Local model cache directory: a "models" folder next to this file.
cache_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
# NOTE(review): these variables are set after huggingface_hub / diffusers
# have already been imported at the top of this file. Those libraries
# typically resolve their cache location when first imported, so these
# settings may not redirect downloads made by this process. Passing
# cache_dir explicitly to download/loading calls is the reliable way to use
# this directory.
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["HF_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path
# exist_ok=True already covers the "directory exists" case. A separate
# os.path.exists() check only adds a check-then-create race.
os.makedirs(cache_path, exist_ok=True)
# Allow TF32 for CUDA matmuls, and fall back to CPU when no GPU is present.
torch.backends.cuda.matmul.allow_tf32 = True
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize FluxPipeline in bfloat16. cache_dir is passed explicitly.
# Setting HF_* environment variables after huggingface_hub has been imported
# does not reliably change where files are cached, but an explicit argument
# does.
# NOTE(review): FLUX.1-dev may require Hub access approval / a token in the
# runtime environment - confirm for this deployment.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    cache_dir=cache_path,
)
# Load the Hyper-SD 8-step LoRA and fuse it into the base weights at
# scale 0.125.
pipe.load_lora_weights(
    hf_hub_download(
        "ByteDance/Hyper-SD",
        "Hyper-FLUX.1-dev-8steps-lora.safetensors",
        cache_dir=cache_path,
    )
)
pipe.fuse_lora(lora_scale=0.125)
pipe.to(device=device, dtype=torch.bfloat16)
# Largest value random seeds are drawn up to: the int32 maximum.
MAX_SEED = int(np.iinfo(np.int32).max)
# Image-size cap; presumably the maximum width/height in pixels - confirm
# where it is enforced.
MAX_IMAGE_SIZE = 2048
# The FastAPI application object.
app = FastAPI()
class InferenceRequest(BaseModel):
    """Request body describing one image-generation job.

    Defaults match the values forwarded unchanged to the pipeline.
    Out-of-range sizes or step counts are rejected at validation time
    instead of reaching the model.
    """

    prompt: str
    # Only honoured when randomize_seed is False.
    seed: int = 42
    randomize_seed: bool = True
    # Each side is capped at MAX_IMAGE_SIZE, so oversized requests fail
    # validation rather than attempting an arbitrarily large generation.
    width: int = Field(1024, gt=0, le=MAX_IMAGE_SIZE)
    height: int = Field(1024, gt=0, le=MAX_IMAGE_SIZE)
    guidance_scale: float = 3.5
    # At least one denoising step is required.
    num_inference_steps: int = Field(8, gt=0)
class Timer:
    """Context manager that prints when a step starts and how long it took.

    Args:
        method_name: Label used in both printed messages.
    """

    def __init__(self, method_name="timed process"):
        self.method = method_name

    def __enter__(self):
        # perf_counter is monotonic, so the measured duration cannot go
        # negative or jump if the wall clock is adjusted mid-run.
        self.start = time.perf_counter()
        print(f"{self.method} starts")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        elapsed = time.perf_counter() - self.start
        print(f"{self.method} took {str(round(elapsed, 2))}s")
        # Implicitly returns None (falsy), so exceptions raised inside the
        # with-body still propagate.
def save_image_to_s3(image):
    """Encode *image* as PNG, upload it to ``S3_BUCKET`` and return its URL.

    Args:
        image: Object with a PIL-style ``save(fp, format=...)`` method,
            presumably the PIL image returned by the pipeline.

    Returns:
        The ``https://<bucket>.s3.<region>.amazonaws.com/<key>`` URL of the
        uploaded object. Whether that URL is publicly readable depends on
        bucket policy, which is not visible here.
    """
    buffer = BytesIO()
    image.save(buffer, format='PNG')
    # A per-second timestamp alone collides when two images finish within
    # the same second, and put_object then silently overwrites the first
    # upload. The random uuid suffix keeps every key unique.
    filename = f"generated_image_{int(time.time())}_{uuid.uuid4().hex}.png"
    s3_client.put_object(
        Bucket=S3_BUCKET,
        Key=filename,
        Body=buffer.getvalue(),
        ContentType='image/png',
    )
    return f"https://{S3_BUCKET}.s3.{S3_REGION}.amazonaws.com/{filename}"
def process_image(height, width, steps, scales, prompt, seed):
    """Generate one image from *prompt* using the module-level pipeline.

    Args:
        height: Output height passed to the pipeline (coerced to int).
        width: Output width passed to the pipeline (coerced to int).
        steps: Number of inference steps (coerced to int).
        scales: Guidance scale (coerced to float).
        prompt: Text prompt, wrapped in a one-element list for the pipeline.
        seed: Seed for a freshly created (CPU) torch.Generator.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    # ``pipe`` is only read here, so no ``global`` declaration is needed.
    # Autocast follows the module-level device and is enabled only on CUDA.
    # This avoids torch's "CUDA is not available" autocast warning on
    # CPU-only hosts.
    autocast = torch.autocast(
        device_type=device, dtype=torch.bfloat16, enabled=device == "cuda"
    )
    with torch.inference_mode(), autocast, Timer("inference"):
        return pipe(
            prompt=[prompt],
            generator=torch.Generator().manual_seed(int(seed)),
            num_inference_steps=int(steps),
            guidance_scale=float(scales),
            height=int(height),
            width=int(width),
            max_sequence_length=256,
        ).images[0]
# The pipeline object is shared module state. Concurrent calls on it are
# not assumed to be safe, so this lock lets only one worker thread drive it
# at a time.
_pipe_lock = threading.Lock()


def _generate_and_upload(request, seed):
    """Run inference for *request* with *seed*, upload it, and return the URL.

    Blocking (model inference plus a network upload): call it from a worker
    thread, never directly on the event loop.
    """
    with _pipe_lock:
        image = process_image(
            height=request.height,
            width=request.width,
            steps=request.num_inference_steps,
            scales=request.guidance_scale,
            prompt=request.prompt,
            seed=seed,
        )
    # The upload does not touch the pipeline, so it runs outside the lock.
    return save_image_to_s3(image)


# NOTE(review): no route decorator (e.g. @app.post("/infer")) is visible for
# this handler. Unless it is registered elsewhere (app.add_api_route, ...),
# FastAPI will not expose it - confirm.
async def infer(request: InferenceRequest):
    """Generate an image for *request*, store it in S3 and return its URL.

    Returns:
        ``{"image_url": <str>, "seed": <int>}``, where ``seed`` is the seed
        actually used: random in ``[0, MAX_SEED]`` when
        ``request.randomize_seed`` is set, otherwise ``request.seed``.

    Raises:
        HTTPException: status 500 wrapping any error raised during
            generation or upload, with ``str(error)`` as the detail.
    """
    if request.randomize_seed:
        seed = random.randint(0, MAX_SEED)
    else:
        seed = request.seed
    # Inference and the S3 upload are blocking. Running them in the default
    # executor keeps the event loop free to serve other requests meanwhile.
    loop = asyncio.get_running_loop()
    try:
        image_url = await loop.run_in_executor(
            None, _generate_and_upload, request, seed
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"image_url": image_url, "seed": seed}
# NOTE(review): no route decorator (e.g. @app.get("/")) is visible for this
# handler. Unless it is registered elsewhere, FastAPI will not expose it -
# confirm.
async def root():
    """Return a fixed greeting payload for the service root."""
    greeting = "Welcome to the IG API"
    payload = {"message": greeting}
    return payload