from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import numpy as np
import random
import torch
import boto3
from io import BytesIO
import time
import os
from huggingface_hub import hf_hub_download
from diffusers import FluxPipeline

# S3 Configuration from environment variables
S3_BUCKET = os.getenv("S3_BUCKET")
S3_REGION = os.getenv("S3_REGION")
S3_ACCESS_KEY_ID = os.getenv("S3_ACCESS_KEY_ID")
S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY")

# Validate S3 environment variables
if not all([S3_BUCKET, S3_REGION, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY]):
    raise ValueError("Missing required S3 environment variables")

# Set up S3 client
s3_client = boto3.client('s3', 
                         region_name=S3_REGION,
                         aws_access_key_id=S3_ACCESS_KEY_ID,
                         aws_secret_access_key=S3_SECRET_ACCESS_KEY)

# Set up cache path.
# huggingface_hub may read these variables when it is first imported, so the
# cache directory is also passed explicitly as cache_dir when loading the model.
cache_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["HF_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path
os.makedirs(cache_path, exist_ok=True)

# Set up CUDA and model
torch.backends.cuda.matmul.allow_tf32 = True
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize FluxPipeline
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, cache_dir=cache_path
)
pipe.load_lora_weights(
    hf_hub_download(
        "ByteDance/Hyper-SD",
        "Hyper-FLUX.1-dev-8steps-lora.safetensors",
        cache_dir=cache_path,
    )
)
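# A lora_scale of 0.125 (1/8) matches the value recommended in the Hyper-SD
# repository for its 8-step FLUX LoRA.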
pipe.fuse_lora(lora_scale=0.125)
pipe.to(device=device, dtype=torch.bfloat16)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

app = FastAPI()

class InferenceRequest(BaseModel):
    prompt: str
    seed: int = 42
    randomize_seed: bool = True
    width: int = 1024
    height: int = 1024
    guidance_scale: float = 3.5
    num_inference_steps: int = 8
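    # Example request body (illustrative values, not from the original):
    # {
    #     "prompt": "a watercolor fox in a snowy forest",
    #     "randomize_seed": false,
    #     "width": 1024,
    #     "height": 1024,
    #     "guidance_scale": 3.5,
    #     "num_inference_steps": 8
    # }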

class Timer:
    def __init__(self, method_name="timed process"):
        self.method = method_name
    
    def __enter__(self):
        self.start = time.time()
        print(f"{self.method} starts")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        end = time.time()
        print(f"{self.method} took {round(end - self.start, 2)}s")

def save_image_to_s3(image):
    img_byte_arr = BytesIO()
    image.save(img_byte_arr, format='PNG')
    img_byte_arr = img_byte_arr.getvalue()
    filename = f"generated_image_{int(time.time())}.png"
    s3_client.put_object(Bucket=S3_BUCKET, 
                         Key=filename, 
                         Body=img_byte_arr, 
                         ContentType='image/png')
    url = f"https://{S3_BUCKET}.s3.{S3_REGION}.amazonaws.com/{filename}"
    return url
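
# Note: the URL returned above assumes the bucket allows public object reads.
# If it does not, save_image_to_s3 could return a pre-signed URL instead
# (sketch, not part of the original code):
#   url = s3_client.generate_presigned_url(
#       "get_object",
#       Params={"Bucket": S3_BUCKET, "Key": filename},
#       ExpiresIn=3600,
#   )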

def process_image(height, width, steps, scales, prompt, seed):
    global pipe
    with torch.inference_mode(), torch.autocast(device_type=device, dtype=torch.bfloat16), Timer("inference"):
        return pipe(
            prompt=[prompt],
            generator=torch.Generator().manual_seed(int(seed)),
            num_inference_steps=int(steps),
            guidance_scale=float(scales),
            height=int(height),
            width=int(width),
            max_sequence_length=256
        ).images[0]

@app.post("/infer")
async def infer(request: InferenceRequest):
    if request.randomize_seed:
        seed = random.randint(0, MAX_SEED)
    else:
        seed = request.seed
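
    # Bounds check using MAX_IMAGE_SIZE (assumed intent, since the constant is
    # otherwise unused; adjust if larger sizes should be allowed).
    if not (0 < request.width <= MAX_IMAGE_SIZE and 0 < request.height <= MAX_IMAGE_SIZE):
        raise HTTPException(
            status_code=400,
            detail=f"width and height must be between 1 and {MAX_IMAGE_SIZE}",
        )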
    
    try:
        image = process_image(
            height=request.height,
            width=request.width,
            steps=request.num_inference_steps,
            scales=request.guidance_scale,
            prompt=request.prompt,
            seed=seed
        )
        
        image_url = save_image_to_s3(image)
        
        return {"image_url": image_url, "seed": seed}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/")
async def root():
    return {"message": "Welcome to the IG API"}