Spaces:
Sleeping
Sleeping
AISimplyExplained
commited on
Commit
•
4fa87d4
1
Parent(s):
5c19b8d
added labels as in input
Browse files
main.py
CHANGED
@@ -5,6 +5,7 @@ import torch
|
|
5 |
from detoxify import Detoxify
|
6 |
import asyncio
|
7 |
from fastapi.concurrency import run_in_threadpool
|
|
|
8 |
|
9 |
class Guardrail:
|
10 |
def __init__(self):
|
@@ -60,17 +61,20 @@ class TopicBannerClassifier:
|
|
60 |
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
61 |
)
|
62 |
self.hypothesis_template = "This text is about {}"
|
63 |
-
self.classes_verbalized = ["politics", "economy", "entertainment", "environment"]
|
64 |
|
65 |
-
async def classify(self, text):
|
66 |
return await run_in_threadpool(
|
67 |
self.classifier,
|
68 |
text,
|
69 |
-
|
70 |
hypothesis_template=self.hypothesis_template,
|
71 |
multi_label=False
|
72 |
)
|
73 |
|
|
|
|
|
|
|
|
|
74 |
class TopicBannerResult(BaseModel):
|
75 |
sequence: str
|
76 |
labels: list
|
@@ -108,9 +112,9 @@ async def classify_text(text_prompt: TextPrompt):
|
|
108 |
raise HTTPException(status_code=500, detail=str(e))
|
109 |
|
110 |
@app.post("/api/models/TopicBanner/classify", response_model=TopicBannerResult)
|
111 |
-
async def classify_topic_banner(
|
112 |
try:
|
113 |
-
result = await topic_banner_classifier.classify(
|
114 |
return {
|
115 |
"sequence": result["sequence"],
|
116 |
"labels": result["labels"],
|
|
|
5 |
from detoxify import Detoxify
|
6 |
import asyncio
|
7 |
from fastapi.concurrency import run_in_threadpool
|
8 |
+
from typing import List
|
9 |
|
10 |
class Guardrail:
|
11 |
def __init__(self):
|
|
|
61 |
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
62 |
)
|
63 |
self.hypothesis_template = "This text is about {}"
|
|
|
64 |
|
65 |
+
async def classify(self, text, labels):
|
66 |
return await run_in_threadpool(
|
67 |
self.classifier,
|
68 |
text,
|
69 |
+
labels,
|
70 |
hypothesis_template=self.hypothesis_template,
|
71 |
multi_label=False
|
72 |
)
|
73 |
|
74 |
+
class TopicBannerRequest(BaseModel):
|
75 |
+
prompt: str
|
76 |
+
labels: List[str]
|
77 |
+
|
78 |
class TopicBannerResult(BaseModel):
|
79 |
sequence: str
|
80 |
labels: list
|
|
|
112 |
raise HTTPException(status_code=500, detail=str(e))
|
113 |
|
114 |
@app.post("/api/models/TopicBanner/classify", response_model=TopicBannerResult)
|
115 |
+
async def classify_topic_banner(request: TopicBannerRequest):
|
116 |
try:
|
117 |
+
result = await topic_banner_classifier.classify(request.prompt, request.labels)
|
118 |
return {
|
119 |
"sequence": result["sequence"],
|
120 |
"labels": result["labels"],
|