ig1 / app.py
Afrinetwork7's picture
Update app.py
e822eea verified
import os
import random
import threading
import time
import uuid
from io import BytesIO

import boto3
import numpy as np
import torch
from diffusers import FluxPipeline
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from huggingface_hub import hf_hub_download
from pydantic import BaseModel, Field
from safetensors.torch import load_file
# S3 settings are supplied through the process environment.
S3_BUCKET = os.environ.get("S3_BUCKET")
S3_REGION = os.environ.get("S3_REGION")
S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID")
S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY")

# Refuse to start when any of the four settings is unset or empty.
if not (S3_BUCKET and S3_REGION and S3_ACCESS_KEY_ID and S3_SECRET_ACCESS_KEY):
    raise ValueError("Missing required S3 environment variables")

# Single client object reused by every upload in this module.
s3_client = boto3.client(
    's3',
    region_name=S3_REGION,
    aws_access_key_id=S3_ACCESS_KEY_ID,
    aws_secret_access_key=S3_SECRET_ACCESS_KEY,
)
# Model weights are cached in a "models" directory next to this file.
cache_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
# NOTE(review): huggingface_hub and diffusers are already imported by the time
# these assignments run, and those libraries appear to resolve their cache
# location at import time, so the variables alone may not redirect downloads
# in this process. They are kept for anything that reads the environment
# later; the loaders below are additionally given cache_dir explicitly.
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["HF_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path
# exist_ok=True already covers the "directory is there" case.
os.makedirs(cache_path, exist_ok=True)

# Set up CUDA and pick the device the pipeline will run on.
torch.backends.cuda.matmul.allow_tf32 = True
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the base pipeline in bfloat16, then load and fuse the 8-step LoRA
# weights into it before moving everything to the selected device.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    cache_dir=cache_path,
)
pipe.load_lora_weights(
    hf_hub_download(
        "ByteDance/Hyper-SD",
        "Hyper-FLUX.1-dev-8steps-lora.safetensors",
        cache_dir=cache_path,
    )
)
pipe.fuse_lora(lora_scale=0.125)
pipe.to(device=device, dtype=torch.bfloat16)

# Upper bound used when drawing a random seed (largest signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max
# Largest image edge, in pixels, that this service is meant to generate.
MAX_IMAGE_SIZE = 2048

app = FastAPI()
class InferenceRequest(BaseModel):
    """Request body accepted by ``POST /infer``.

    Width and height are limited to ``MAX_IMAGE_SIZE`` and the step count must
    be at least one, so out-of-range requests are rejected during validation
    instead of reaching the pipeline.
    """

    # Text prompt handed to the pipeline.
    prompt: str
    # Seed used when randomize_seed is false.
    seed: int = 42
    # When true, the handler draws a random seed and ignores `seed`.
    randomize_seed: bool = True
    # Output size passed to the pipeline; must be in 1..MAX_IMAGE_SIZE.
    width: int = Field(1024, gt=0, le=MAX_IMAGE_SIZE)
    height: int = Field(1024, gt=0, le=MAX_IMAGE_SIZE)
    guidance_scale: float = 3.5
    # Number of denoising steps; the default matches the 8-step LoRA loaded
    # at start-up.
    num_inference_steps: int = Field(8, ge=1)
class Timer:
    """Context manager that prints how long the enclosed block took.

    On entry it prints ``"<name> starts"``; on exit it prints
    ``"<name> took <seconds>s"`` with the duration rounded to two decimals.
    Exceptions raised inside the block are not suppressed.

    Attributes:
        method: Label used in both printed messages.
        start: ``time.perf_counter()`` reading taken on entry (``None`` before).
        elapsed: Duration in seconds, set on exit (``None`` before).
    """

    def __init__(self, method_name="timed process"):
        self.method = method_name
        self.start = None
        self.elapsed = None

    def __enter__(self):
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by wall-clock adjustments the way time.time() can.
        self.start = time.perf_counter()
        print(f"{self.method} starts")
        # Return the instance so ``with Timer(...) as t`` binds a usable object.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.elapsed = time.perf_counter() - self.start
        print(f"{self.method} took {round(self.elapsed, 2)}s")
def save_image_to_s3(image):
    """Encode *image* as PNG, upload it to the S3 bucket, and return its URL.

    Args:
        image: Object exposing ``save(fileobj, format=...)`` (presumably a
            PIL image produced by the pipeline; confirm against the caller).

    Returns:
        The ``https://<bucket>.s3.<region>.amazonaws.com/<key>`` URL string of
        the uploaded object. Whether that URL is readable depends on the
        bucket's access settings, which are not configured here.
    """
    buffer = BytesIO()
    image.save(buffer, format='PNG')

    # A whole-second timestamp alone is not unique: two images saved within
    # the same second would share a key and the later upload would overwrite
    # the earlier one. The random suffix keeps every key distinct.
    filename = f"generated_image_{int(time.time())}_{uuid.uuid4().hex}.png"

    s3_client.put_object(
        Bucket=S3_BUCKET,
        Key=filename,
        Body=buffer.getvalue(),
        ContentType='image/png',
    )
    return f"https://{S3_BUCKET}.s3.{S3_REGION}.amazonaws.com/{filename}"
def process_image(height, width, steps, scales, prompt, seed):
    """Run the module-level pipeline once and return the first image.

    Args:
        height: Output height, converted with ``int``.
        width: Output width, converted with ``int``.
        steps: Number of inference steps, converted with ``int``.
        scales: Guidance scale, converted with ``float``.
        prompt: Text prompt; passed to the pipeline as a one-element list.
        seed: Seed for the random generator, converted with ``int``.

    Returns:
        The first entry of the pipeline result's ``images``.
    """
    # A default-device generator seeded explicitly, as in the original call.
    generator = torch.Generator().manual_seed(int(seed))

    # Autocast follows the device chosen at start-up. Hard-coding "cuda" did
    # not match the CPU fallback that the start-up code explicitly allows.
    with torch.inference_mode(), torch.autocast(device_type=device, dtype=torch.bfloat16), Timer("inference"):
        result = pipe(
            prompt=[prompt],
            generator=generator,
            num_inference_steps=int(steps),
            guidance_scale=float(scales),
            height=int(height),
            width=int(width),
            max_sequence_length=256,
        )
    return result.images[0]
# Generation runs on worker threads (see ``infer``). The pipeline is a single
# module-level object, so this lock keeps it to one request at a time, which
# is the ordering requests had when the handler ran inline on the event loop.
_INFERENCE_LOCK = threading.Lock()


def _generate_and_upload(request, seed):
    """Generate one image for *request* using *seed*, upload it, return its URL."""
    with _INFERENCE_LOCK:
        image = process_image(
            height=request.height,
            width=request.width,
            steps=request.num_inference_steps,
            scales=request.guidance_scale,
            prompt=request.prompt,
            seed=seed,
        )
        return save_image_to_s3(image)


@app.post("/infer")
async def infer(request: InferenceRequest):
    """Generate an image from the request and return its URL and seed.

    Returns:
        ``{"image_url": <str>, "seed": <int>}``.

    Raises:
        HTTPException: Status 500 carrying the error text when generation or
            upload fails.
    """
    if request.randomize_seed:
        seed = random.randint(0, MAX_SEED)
    else:
        seed = request.seed

    try:
        # Both inference and the S3 upload are blocking calls. Running them
        # directly inside this coroutine would stall the event loop, and with
        # it every other request, for the whole generation.
        image_url = await run_in_threadpool(_generate_and_upload, request, seed)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"image_url": image_url, "seed": seed}
@app.get("/")
async def root():
    """Return the static welcome payload served at the API root."""
    greeting = {"message": "Welcome to the IG API"}
    return greeting