|
import os |
|
from fastapi import FastAPI, File, UploadFile, HTTPException, Depends |
|
from fastapi.security.api_key import APIKeyHeader |
|
from starlette.status import HTTP_403_FORBIDDEN |
|
from pydantic import BaseModel |
|
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig |
|
from PIL import Image |
|
import torch |
|
import base64 |
|
import io |
|
|
|
# FastAPI application that exposes the Molmo inference endpoints.
app = FastAPI()

# Hugging Face checkpoint shared by the processor and the model; keeping it in
# one place guarantees the two are always loaded from the same checkpoint.
MODEL_ID = 'allenai/Molmo-7B-D-0924'

# Loading options applied identically to both objects. trust_remote_code lets
# transformers run the checkpoint's own Python code, which provides the
# non-standard `processor.process` / `model.generate_from_batch` methods used
# elsewhere in this module.
_hf_load_kwargs = dict(
    trust_remote_code=True,
    torch_dtype='auto',
    device_map='auto',
)

processor = AutoProcessor.from_pretrained(MODEL_ID, **_hf_load_kwargs)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_hf_load_kwargs)
|
|
|
|
|
# Shared secret that clients must send; None when the variable is not set.
API_KEY = os.getenv("API_KEY")

# Reads the X-API-Key request header. With auto_error=False a missing header
# is passed on as None instead of FastAPI rejecting the request itself, so the
# get_api_key dependency decides how to respond.
api_key_header = APIKeyHeader(
    name="X-API-Key",
    auto_error=False,
)
|
|
|
async def get_api_key(api_key_header: str = Depends(api_key_header)):
    """Validate the ``X-API-Key`` header against the configured ``API_KEY``.

    Used as a FastAPI dependency. The default ``Depends(api_key_header)`` is
    evaluated when this function is defined, so it refers to the module-level
    ``APIKeyHeader`` instance even though the parameter shares its name.

    Args:
        api_key_header: Header value supplied by the ``APIKeyHeader``
            dependency. It is ``None`` when the header is absent, because that
            extractor is created with ``auto_error=False``.

    Returns:
        The validated API key.

    Raises:
        HTTPException: 403 when no server key is configured, the header is
            missing or empty, or the supplied key does not match.
    """
    import secrets  # function-scope import keeps this block self-contained

    # Fail closed: require both a configured key and a supplied one. Without
    # this, an unset API_KEY (None) would equal a missing header (None) and
    # every header-less request would be accepted.
    # compare_digest gives a constant-time comparison; encoding to bytes makes
    # it accept non-ASCII keys instead of raising TypeError.
    if API_KEY and api_key_header and secrets.compare_digest(
        api_key_header.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        return api_key_header
    raise HTTPException(
        status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials"
    )
|
|
|
def process_image_and_text(image, text, max_new_tokens=200):
    """Run the Molmo model on one image/prompt pair and return the generated text.

    Args:
        image: A PIL image. Any mode (RGBA, grayscale, palette, ...) is
            converted to RGB before it reaches the processor.
        text: Prompt passed to ``processor.process``.
        max_new_tokens: Upper bound on the number of generated tokens
            (default 200, the previously hard-coded value).

    Returns:
        The decoded continuation, with the prompt tokens removed and special
        tokens skipped.
    """
    # Uploaded files keep whatever mode they were saved in; normalize to
    # 3-channel RGB so RGBA/grayscale/palette images do not reach the
    # processor unchanged (the Molmo model card expects RGB input).
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor.process(images=[image], text=text)

    # Move every tensor to the model's device and add a leading batch dim of 1.
    inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}

    output = model.generate_from_batch(
        inputs,
        GenerationConfig(max_new_tokens=max_new_tokens, stop_strings="<|endoftext|>"),
        tokenizer=processor.tokenizer,
    )

    # The first row of `output` starts with the prompt tokens; slice them off
    # so only newly generated tokens are decoded.
    generated_tokens = output[0, inputs["input_ids"].size(1):]
    return processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
|
|
# JSON body accepted by the /base64 endpoint.
# (Deliberately no class docstring: pydantic would publish it as the schema
# description in the generated OpenAPI document.)
class Base64Request(BaseModel):

    # Base64-encoded image bytes; decoded with base64.b64decode by the handler.
    image: str

    # Prompt text forwarded to the model together with the image.
    text: str
|
|
|
@app.post("/upload")
async def upload_image(
    file: UploadFile = File(...),
    text: str = "",
    api_key: str = Depends(get_api_key),
):
    """Run the model on an uploaded image file and a prompt.

    Args:
        file: Multipart-uploaded image file.
        text: Prompt for the model (a plain ``str`` parameter, so FastAPI reads
            it from the query string).
        api_key: Result of the ``get_api_key`` dependency; the endpoint is only
            reached when authentication succeeds.

    Returns:
        ``{"response": <generated text>}``.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents))
        # Image.open only reads the header; decode the pixel data now so that
        # corrupt or truncated uploads are rejected here as a client error.
        image.load()
    except Exception as err:  # untrusted input: any decode failure is a 400
        raise HTTPException(status_code=400, detail="Invalid image") from err

    # NOTE(review): generation runs synchronously inside this async handler
    # and blocks the event loop for its duration; consider
    # fastapi.concurrency.run_in_threadpool if concurrency matters.
    response = process_image_and_text(image, text)
    return {"response": response}
|
|
|
@app.post("/base64")
async def process_base64(request: Base64Request, api_key: str = Depends(get_api_key)):
    """Run the model on a base64-encoded image and a prompt from a JSON body.

    Args:
        request: Body with ``image`` (base64 image bytes) and ``text`` (prompt).
        api_key: Result of the ``get_api_key`` dependency; the endpoint is only
            reached when authentication succeeds.

    Returns:
        ``{"response": <generated text>}``.

    Raises:
        HTTPException: 400 when the payload is not valid base64 or does not
            decode to a readable image.
    """
    try:
        image_data = base64.b64decode(request.image)
        image = Image.open(io.BytesIO(image_data))
        # Image.open is lazy; force decoding so truncated or corrupt data is
        # caught here rather than failing later as a 500.
        image.load()
    except Exception as err:
        # `Exception` rather than a bare `except:` so BaseException subclasses
        # such as asyncio.CancelledError / KeyboardInterrupt still propagate.
        raise HTTPException(status_code=400, detail="Invalid base64 image") from err

    # NOTE(review): generation runs synchronously inside this async handler
    # and blocks the event loop for its duration; consider
    # fastapi.concurrency.run_in_threadpool if concurrency matters.
    response = process_image_and_text(image, request.text)
    return {"response": response}
|
|