Edmond7 committed
Commit aa732b3
1 Parent(s): e82f7e2

Update app.py

Files changed (1)
  1. app.py +38 -43
app.py CHANGED
@@ -1,8 +1,15 @@
-import gradio as gr
+import os
+from fastapi import FastAPI, File, UploadFile, HTTPException, Depends
+from fastapi.security.api_key import APIKeyHeader
+from starlette.status import HTTP_403_FORBIDDEN
+from pydantic import BaseModel
 from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
 from PIL import Image
 import torch
-import spaces
+import base64
+import io
+
+app = FastAPI()
 
 # Load the processor and model
 processor = AutoProcessor.from_pretrained(
@@ -11,7 +18,6 @@ processor = AutoProcessor.from_pretrained(
     torch_dtype='auto',
     device_map='auto'
 )
-
 model = AutoModelForCausalLM.from_pretrained(
     'allenai/Molmo-7B-D-0924',
     trust_remote_code=True,
@@ -19,62 +25,51 @@ model = AutoModelForCausalLM.from_pretrained(
     device_map='auto'
 )
 
+# API Key setup
+API_KEY = os.environ.get("API_KEY")
+api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
+
+async def get_api_key(api_key_header: str = Depends(api_key_header)):
+    if api_key_header == API_KEY:
+        return api_key_header
+    else:
+        raise HTTPException(
+            status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials"
+        )
 
-@spaces.GPU(duration=120)
 def process_image_and_text(image, text):
-    # Process the image and text
     inputs = processor.process(
-        images=[Image.fromarray(image)],
+        images=[image],
         text=text
     )
-
-    # Move inputs to the correct device and make a batch of size 1
     inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
-
-    # Generate output
     output = model.generate_from_batch(
        inputs,
        GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
        tokenizer=processor.tokenizer
    )
-
-    # Only get generated tokens; decode them to text
     generated_tokens = output[0, inputs['input_ids'].size(1):]
     generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
-
     return generated_text
 
-def chatbot(image, text, history):
-    if image is None:
-        return history + [("Please upload an image first.", None)]
+class Base64Request(BaseModel):
+    image: str
+    text: str
 
+@app.post("/upload")
+async def upload_image(file: UploadFile = File(...), text: str = "", api_key: str = Depends(get_api_key)):
+    contents = await file.read()
+    image = Image.open(io.BytesIO(contents))
     response = process_image_and_text(image, text)
-    history.append((text, response))
-    return history
+    return {"response": response}
 
-# Define the Gradio interface
-with gr.Blocks() as demo:
-    gr.Markdown("# Image Chatbot with Molmo-7B-D-0924")
+@app.post("/base64")
+async def process_base64(request: Base64Request, api_key: str = Depends(get_api_key)):
+    try:
+        image_data = base64.b64decode(request.image)
+        image = Image.open(io.BytesIO(image_data))
+    except:
+        raise HTTPException(status_code=400, detail="Invalid base64 image")
 
-    with gr.Row():
-        image_input = gr.Image(type="numpy")
-        chatbot_output = gr.Chatbot()
-
-    text_input = gr.Textbox(placeholder="Ask a question about the image...")
-    submit_button = gr.Button("Submit")
-
-    state = gr.State([])
-
-    submit_button.click(
-        chatbot,
-        inputs=[image_input, text_input, state],
-        outputs=[chatbot_output]
-    )
-
-    text_input.submit(
-        chatbot,
-        inputs=[image_input, text_input, state],
-        outputs=[chatbot_output]
-    )
-
-demo.launch()
+    response = process_image_and_text(image, request.text)
+    return {"response": response}