# Source: image-gen/tld/app.py (commit 3b3a783, "Trying sum", by BeveledCube)
import asyncio
import io
import logging
import os
import secrets
import threading
from typing import Optional

import torch
import torchvision.transforms as transforms
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.responses import StreamingResponse
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel

from tld.diffusion import DiffusionTransformer, LTDConfig
# Use the first CUDA GPU when one is available, otherwise the CPU.
# NOTE(review): `device` is not referenced anywhere else in this file's visible
# code, and the model below is not moved to it here — confirm whether
# DiffusionTransformer selects its own device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Tensor -> PIL image converter. NOTE(review): unused in the visible code.
to_pil = transforms.ToPILImage()
# Model configuration built with LTDConfig's default settings.
ltdconfig = LTDConfig()
# One model instance, built once at import time and shared by every request.
# Presumably this loads weights, making import slow — TODO confirm.
diffusion_transformer = DiffusionTransformer(ltdconfig)
app = FastAPI()
# Extracts the bearer token from incoming requests for `validate_token`.
# NOTE(review): no "token" route is defined in this file, so the tokenUrl
# advertised here does not point at a real endpoint — confirm intended.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def validate_token(token: str = Depends(oauth2_scheme)):
    """Reject the request unless its bearer token matches ``API_TOKEN``.

    Used as a FastAPI dependency; ``token`` is supplied by ``oauth2_scheme``.

    Raises:
        HTTPException: 401 when the token does not match, or when the
            ``API_TOKEN`` environment variable is unset or empty (fail closed,
            so a misconfigured server never accepts an empty bearer token).
    """
    expected = os.getenv("API_TOKEN")
    # compare_digest takes time independent of where the inputs differ, so
    # response timing does not reveal how much of the secret was guessed.
    # Compare UTF-8 bytes: compare_digest rejects non-ASCII str arguments.
    if not expected or not secrets.compare_digest(
        token.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )
class ImageRequest(BaseModel):
    """JSON body accepted by ``POST /generate-image/``.

    Every field is forwarded as-is to
    ``DiffusionTransformer.generate_image_from_text``.
    """

    prompt: str  # text prompt the generated image is conditioned on
    # The Optional[...] annotations let a client send an explicit null, which
    # is then passed to the model as None.
    # NOTE(review): confirm the model accepts None for these parameters.
    # NOTE(review): guidance scales are often fractional; depending on the
    # pydantic version an `int` field rejects or truncates e.g. 7.5 — confirm.
    class_guidance: Optional[int] = 6
    seed: Optional[int] = 11
    # NOTE(review): num_imgs and img_size have no upper bound, so one request
    # can ask for arbitrarily expensive work.
    num_imgs: Optional[int] = 1
    img_size: Optional[int] = 32
@app.get("/")
def read_root():
    """Landing endpoint: return a fixed JSON welcome message."""
    welcome_text = "Welcome to Image Generator"
    return dict(message=welcome_text)
_logger = logging.getLogger(__name__)

# The original code ran generation directly on the event-loop thread, so
# requests were implicitly served one at a time. Generation now runs in a
# worker thread; this lock keeps that one-at-a-time behaviour, so several
# requests never run the shared model concurrently.
_generation_lock = threading.Lock()


def _render_jpeg(request: ImageRequest) -> io.BytesIO:
    """Run the model for *request* and return the image as a JPEG buffer.

    This call blocks. It is meant to run in a worker thread, not on the event loop.
    The returned buffer is rewound to position 0, ready to be streamed.
    """
    with _generation_lock:
        img = diffusion_transformer.generate_image_from_text(
            prompt=request.prompt,
            class_guidance=request.class_guidance,
            seed=request.seed,
            num_imgs=request.num_imgs,
            img_size=request.img_size,
        )
    # Convert the PIL image to a byte stream suitable for an HTTP response.
    buffer = io.BytesIO()
    img.save(buffer, format="JPEG")
    buffer.seek(0)
    return buffer


@app.post("/generate-image/")
async def generate_image(request: ImageRequest, token: str = Depends(validate_token)):
    """Generate an image from the request's prompt and stream it back as JPEG.

    ``token`` exists only so that FastAPI runs the ``validate_token`` auth
    dependency before this handler.

    The model call is synchronous and slow. It is offloaded to the default
    thread-pool executor so the event loop keeps serving other requests
    (e.g. ``/`` or auth failures) while an image is being generated.

    Raises:
        HTTPException: 500 with the underlying error message if generation
            or JPEG encoding fails.
    """
    loop = asyncio.get_running_loop()
    try:
        buffer = await loop.run_in_executor(None, _render_jpeg, request)
    except Exception as e:
        # Record the full traceback server-side. Previously the error was
        # visible only as the response detail.
        _logger.exception("Image generation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return StreamingResponse(buffer, media_type="image/jpeg")
# TODO: add a build job that tests this API and deploys it as a Docker image (possibly on Azure).