"""FastAPI service that generates images with FLUX.1-dev and uploads them to S3.

On import this module reads its S3 settings from the environment, builds a
boto3 S3 client, loads the FLUX pipeline with the Hyper-SD 8-step LoRA fused
in, and exposes two routes: ``POST /infer`` and ``GET /``.
"""
import logging
import os
import random
import threading
import time
import uuid
from io import BytesIO

import boto3
import numpy as np
import torch
from diffusers import FluxPipeline
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from huggingface_hub import hf_hub_download
from pydantic import BaseModel, Field
from safetensors.torch import load_file  # noqa: F401  (kept: was imported by the original module)

logger = logging.getLogger(__name__)

# S3 configuration from environment variables
S3_BUCKET = os.getenv("S3_BUCKET")
S3_REGION = os.getenv("S3_REGION")
S3_ACCESS_KEY_ID = os.getenv("S3_ACCESS_KEY_ID")
S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY")

# Fail fast at import time if any S3 setting is missing or empty
if not all([S3_BUCKET, S3_REGION, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY]):
    raise ValueError("Missing required S3 environment variables")

# Set up S3 client
s3_client = boto3.client(
    's3',
    region_name=S3_REGION,
    aws_access_key_id=S3_ACCESS_KEY_ID,
    aws_secret_access_key=S3_SECRET_ACCESS_KEY,
)

# Set up cache path (a "models" directory next to this file)
# NOTE(review): huggingface_hub and diffusers are already imported above when
# these variables are set. If those libraries resolve their cache location at
# import time, the assignments below may not take effect — confirm, and if so
# either set them before the imports or pass cache_dir explicitly.
cache_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["HF_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path
os.makedirs(cache_path, exist_ok=True)

# Set up CUDA and model
torch.backends.cuda.matmul.allow_tf32 = True
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize FluxPipeline and fuse the Hyper-SD 8-step LoRA into it
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
pipe.fuse_lora(lora_scale=0.125)
pipe.to(device=device, dtype=torch.bfloat16)

# Inference now runs in worker threads (see ``infer``); this lock keeps
# generations serialized on the single shared pipeline, as they were when the
# handler ran them directly on the event loop.
_PIPE_LOCK = threading.Lock()

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

app = FastAPI()


class InferenceRequest(BaseModel):
    """Request body for ``POST /infer``."""

    prompt: str
    # Used only when randomize_seed is False.
    seed: int = 42
    randomize_seed: bool = True
    # Image dimensions in pixels, capped at MAX_IMAGE_SIZE per side.
    width: int = Field(1024, gt=0, le=MAX_IMAGE_SIZE)
    height: int = Field(1024, gt=0, le=MAX_IMAGE_SIZE)
    guidance_scale: float = 3.5
    num_inference_steps: int = Field(8, gt=0)


class Timer:
    """Context manager that prints when a named step starts and how long it took."""

    def __init__(self, method_name="timed process"):
        self.method = method_name

    def __enter__(self):
        # perf_counter is monotonic, so the duration is immune to clock changes.
        self.start = time.perf_counter()
        print(f"{self.method} starts")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        end = time.perf_counter()
        print(f"{self.method} took {str(round(end - self.start, 2))}s")


def save_image_to_s3(image):
    """Encode *image* as PNG, upload it to the configured bucket, return its URL.

    Args:
        image: Object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        The ``https://<bucket>.s3.<region>.amazonaws.com/<key>`` URL of the
        uploaded object.
    """
    buffer = BytesIO()
    image.save(buffer, format='PNG')
    png_bytes = buffer.getvalue()

    # The timestamp alone has one-second resolution, so two images saved in the
    # same second would overwrite each other; the UUID makes each key unique.
    filename = f"generated_image_{int(time.time())}_{uuid.uuid4().hex}.png"

    s3_client.put_object(Bucket=S3_BUCKET, Key=filename, Body=png_bytes, ContentType='image/png')

    return f"https://{S3_BUCKET}.s3.{S3_REGION}.amazonaws.com/{filename}"


def process_image(height, width, steps, scales, prompt, seed):
    """Run the pipeline once and return the first generated image.

    Args:
        height: Output height in pixels.
        width: Output width in pixels.
        steps: Number of inference steps.
        scales: Guidance scale.
        prompt: Text prompt.
        seed: Seed for the torch random generator.

    Returns:
        The first element of the pipeline result's ``images``.
    """
    # Autocast targets the device the pipeline was moved to, not a hard-coded
    # "cuda", so the CPU fallback path is handled consistently.
    with _PIPE_LOCK, torch.inference_mode(), \
            torch.autocast(device, dtype=torch.bfloat16), Timer("inference"):
        return pipe(
            prompt=[prompt],
            generator=torch.Generator().manual_seed(int(seed)),
            num_inference_steps=int(steps),
            guidance_scale=float(scales),
            height=int(height),
            width=int(width),
            max_sequence_length=256
        ).images[0]


def _generate_and_upload(request, seed):
    """Generate the image for *request* with *seed*, upload it, return its URL."""
    image = process_image(
        height=request.height,
        width=request.width,
        steps=request.num_inference_steps,
        scales=request.guidance_scale,
        prompt=request.prompt,
        seed=seed
    )
    return save_image_to_s3(image)


@app.post("/infer")
async def infer(request: InferenceRequest):
    """Generate an image for the request and return its S3 URL and the seed used.

    Raises:
        HTTPException: status 500, with the underlying error message as detail,
            if generation or upload fails.
    """
    seed = random.randint(0, MAX_SEED) if request.randomize_seed else request.seed

    try:
        # Generation and upload are blocking calls; run them in a worker thread
        # so the event loop keeps serving other requests (e.g. GET /).
        image_url = await run_in_threadpool(_generate_and_upload, request, seed)
    except Exception as e:
        logger.exception("Inference request failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"image_url": image_url, "seed": seed}


@app.get("/")
async def root():
    """Return a static welcome message."""
    return {"message": "Welcome to the IG API"}