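"""FastAPI service that serves a local ChatGLM-6B-int4 model through an
OpenAI-style, streaming /chat/completions endpoint (Server-Sent Events),
with optional ngrok tunnelling."""
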
import json
import os
from typing import List

# Redirect the Hugging Face cache to a local directory. This has to happen
# before transformers is imported, because the library resolves
# TRANSFORMERS_CACHE when it is first loaded.
os.environ['TRANSFORMERS_CACHE'] = ".cache"

import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse
from torch.cuda import get_device_properties
from transformers import AutoModel, AutoTokenizer

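# Model configuration: `model_path` points at a local snapshot of
# silver/chatglm-6b-int4-slim, `bits` is the quantization width passed to
# model.quantize(), and `min_memory` is the minimum amount of GPU memory
# (in GiB) required before the model is placed on CUDA.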
bits = 4
kernel_path = "models/models--silver--chatglm-6b-int4-slim/quantization_kernels.so"
model_path = "./models/models--silver--chatglm-6b-int4-slim/snapshots/02e096b3805c579caf5741a6d8eddd5ba7a74e0d"
cache_dir = './models'
model_name = 'chatglm-6b-int4'
min_memory = 5.5
tokenizer = None
model = None

app = FastAPI()

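# Allow cross-origin requests from any origin.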
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


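# Load the tokenizer and model once at application startup.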
@app.on_event('startup')
def init():
    global tokenizer, model
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, cache_dir=cache_dir)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True, cache_dir=cache_dir)

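    # Run on the GPU only when CUDA is available and device 0 has more than
    # `min_memory` GiB of total memory; otherwise fall back to a CPU build.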
    if torch.cuda.is_available() and get_device_properties(0).total_memory / 1024 ** 3 > min_memory:
        model = model.half().quantize(bits=bits).cuda()
        print("Using GPU")
    else:
        model = model.float().quantize(bits=bits)
        if torch.cuda.is_available():
            print("Total Memory: ", get_device_properties(0).total_memory / 1024 ** 3)
        else:
            print("No GPU available")
        print("Using CPU")
    model = model.eval()
    if os.environ.get("ngrok_token") is not None:
        ngrok_connect()


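# Request schema: a minimal subset of the OpenAI chat-completions payload.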
class Message(BaseModel):
    role: str
    content: str


class Body(BaseModel):
    messages: List[Message]
    model: str
    stream: bool
    max_tokens: int


@app.get("/")
def read_root():
    return {"Hello": "World!"}


@app.post("/chat/completions")
async def completions(body: Body, request: Request):
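    # Only streaming requests addressed to the configured model are supported.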
    if not body.stream or body.model != model_name:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "Not Implemented")

    # The last message must be the user's question.
    if not body.messages or body.messages[-1].role != 'user':
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "No Question Found")
    question = body.messages[-1].content

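    # Rebuild ChatGLM chat history: pair each user message with the following
    # assistant (or system) reply. The final user message is the current
    # question and is therefore never added to the history.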
    user_question = ''
    history = []
    for message in body.messages:
        if message.role == 'user':
            user_question = message.content
        elif message.role == 'system' or message.role == 'assistant':
            assistant_answer = message.content
            history.append((user_question, assistant_answer))

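    # stream_chat yields (cumulative_response, history) tuples; forward the text
    # as SSE events, stop early if the client disconnects, and finish with "[DONE]".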
    async def event_generator():
        for response in model.stream_chat(tokenizer, question, history, max_length=max(2048, body.max_tokens)):
            if await request.is_disconnected():
                return
            yield json.dumps({"response": response[0]})
        yield "[DONE]"

    return EventSourceResponse(event_generator())


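# Expose the local server through an ngrok tunnel (enabled from init() when the
# `ngrok_token` environment variable is set).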
def ngrok_connect():
    from pyngrok import ngrok, conf
    conf.set_default(conf.PyngrokConfig(ngrok_path="./ngrok"))
    ngrok.set_auth_token(os.environ["ngrok_token"])
    http_tunnel = ngrok.connect(8000)
    print(http_tunnel.public_url)


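# Local development entry point: uvicorn defaults to 127.0.0.1:8000, the port
# that ngrok_connect() tunnels to.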
if __name__ == "__main__":
    uvicorn.run("main:app", reload=True, app_dir=".")