AISimplyExplained commited on
Commit
c311b0d
1 Parent(s): 4fa87d4

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +40 -1
main.py CHANGED
@@ -5,7 +5,7 @@ import torch
5
  from detoxify import Detoxify
6
  import asyncio
7
  from fastapi.concurrency import run_in_threadpool
8
- from typing import List
9
 
10
  class Guardrail:
11
  def __init__(self):
@@ -80,6 +80,16 @@ class TopicBannerResult(BaseModel):
80
  labels: list
81
  scores: list
82
 
 
 
 
 
 
 
 
 
 
 
83
  app = FastAPI()
84
  guardrail = Guardrail()
85
  toxicity_classifier = Detoxify('original')
@@ -123,6 +133,35 @@ async def classify_topic_banner(request: TopicBannerRequest):
123
  except Exception as e:
124
  raise HTTPException(status_code=500, detail=str(e))
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  if __name__ == "__main__":
127
  import uvicorn
128
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
5
  from detoxify import Detoxify
6
  import asyncio
7
  from fastapi.concurrency import run_in_threadpool
8
+ from typing import List, Optional
9
 
10
  class Guardrail:
11
  def __init__(self):
 
80
  labels: list
81
  scores: list
82
 
83
+ class GuardrailsRequest(BaseModel):
84
+ prompt: str
85
+ guardrails: List[str]
86
+ labels: Optional[List[str]] = None
87
+
88
+ class GuardrailsResponse(BaseModel):
89
+ prompt_injection: Optional[ClassificationResult] = None
90
+ toxicity: Optional[ToxicityResult] = None
91
+ topic_banner: Optional[TopicBannerResult] = None
92
+
93
  app = FastAPI()
94
  guardrail = Guardrail()
95
  toxicity_classifier = Detoxify('original')
 
133
  except Exception as e:
134
  raise HTTPException(status_code=500, detail=str(e))
135
 
136
+ @app.post("/api/guardrails", response_model=GuardrailsResponse)
137
+ async def evaluate_guardrails(request: GuardrailsRequest):
138
+ tasks = []
139
+ response = GuardrailsResponse()
140
+
141
+ if "pi" in request.guardrails:
142
+ tasks.append(classify_text(TextPrompt(prompt=request.prompt)))
143
+ if "tox" in request.guardrails:
144
+ tasks.append(classify_toxicity(TextPrompt(prompt=request.prompt)))
145
+ if "top" in request.guardrails:
146
+ if not request.labels:
147
+ raise HTTPException(status_code=400, detail="Labels are required for topic banner classification")
148
+ tasks.append(classify_topic_banner(TopicBannerRequest(prompt=request.prompt, labels=request.labels)))
149
+
150
+ results = await asyncio.gather(*tasks, return_exceptions=True)
151
+
152
+ for result, guardrail in zip(results, request.guardrails):
153
+ if isinstance(result, Exception):
154
+ # Handle the exception as needed
155
+ continue
156
+ if guardrail == "pi":
157
+ response.prompt_injection = result
158
+ elif guardrail == "tox":
159
+ response.toxicity = result
160
+ elif guardrail == "top":
161
+ response.topic_banner = result
162
+
163
+ return response
164
+
165
  if __name__ == "__main__":
166
  import uvicorn
167
  uvicorn.run(app, host="0.0.0.0", port=8000)