from fastapi import FastAPI, HTTPException from pydantic import BaseModel from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline import torch from detoxify import Detoxify import asyncio from fastapi.concurrency import run_in_threadpool from typing import List, Optional class Guardrail: def __init__(self): tokenizer = AutoTokenizer.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection") model = AutoModelForSequenceClassification.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection") self.classifier = pipeline( "text-classification", model=model, tokenizer=tokenizer, truncation=True, max_length=512, device=torch.device("cuda" if torch.cuda.is_available() else "cpu") ) async def guard(self, prompt): return await run_in_threadpool(self.classifier, prompt) def determine_level(self, label, score): if label == "SAFE": return 0, "safe" else: if score > 0.9: return 4, "high" elif score > 0.75: return 3, "medium" elif score > 0.5: return 2, "low" else: return 1, "very low" class TextPrompt(BaseModel): prompt: str class ClassificationResult(BaseModel): label: str score: float level: int severity_label: str class ToxicityResult(BaseModel): toxicity: float severe_toxicity: float obscene: float threat: float insult: float identity_attack: float @classmethod def from_dict(cls, data: dict): return cls(**{k: float(v) for k, v in data.items()}) class TopicBannerClassifier: def __init__(self): self.classifier = pipeline( "zero-shot-classification", model="MoritzLaurer/deberta-v3-large-zeroshot-v2.0", device=torch.device("cuda" if torch.cuda.is_available() else "cpu") ) self.hypothesis_template = "This text is about {}" async def classify(self, text, labels): return await run_in_threadpool( self.classifier, text, labels, hypothesis_template=self.hypothesis_template, multi_label=False ) class TopicBannerRequest(BaseModel): prompt: str labels: List[str] class TopicBannerResult(BaseModel): sequence: str labels: list scores: list class GuardrailsRequest(BaseModel): prompt: str guardrails: List[str] labels: Optional[List[str]] = None class GuardrailsResponse(BaseModel): prompt_injection: Optional[ClassificationResult] = None toxicity: Optional[ToxicityResult] = None topic_banner: Optional[TopicBannerResult] = None app = FastAPI() guardrail = Guardrail() toxicity_classifier = Detoxify('original') topic_banner_classifier = TopicBannerClassifier() @app.post("/api/models/toxicity/classify", response_model=ToxicityResult) async def classify_toxicity(text_prompt: TextPrompt): try: result = await run_in_threadpool(toxicity_classifier.predict, text_prompt.prompt) return ToxicityResult.from_dict(result) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @app.post("/api/models/PromptInjection/classify", response_model=ClassificationResult) async def classify_text(text_prompt: TextPrompt): try: result = await guardrail.guard(text_prompt.prompt) label = result[0]['label'] score = result[0]['score'] level, severity_label = guardrail.determine_level(label, score) return {"label": label, "score": score, "level": level, "severity_label": severity_label} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @app.post("/api/models/TopicBanner/classify", response_model=TopicBannerResult) async def classify_topic_banner(request: TopicBannerRequest): try: result = await topic_banner_classifier.classify(request.prompt, request.labels) return { "sequence": result["sequence"], "labels": result["labels"], "scores": result["scores"] } except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @app.post("/api/guardrails", response_model=GuardrailsResponse) async def evaluate_guardrails(request: GuardrailsRequest): tasks = [] response = GuardrailsResponse() if "pi" in request.guardrails: tasks.append(classify_text(TextPrompt(prompt=request.prompt))) if "tox" in request.guardrails: tasks.append(classify_toxicity(TextPrompt(prompt=request.prompt))) if "top" in request.guardrails: if not request.labels: raise HTTPException(status_code=400, detail="Labels are required for topic banner classification") tasks.append(classify_topic_banner(TopicBannerRequest(prompt=request.prompt, labels=request.labels))) results = await asyncio.gather(*tasks, return_exceptions=True) for result, guardrail in zip(results, request.guardrails): if isinstance(result, Exception): # Handle the exception as needed continue if guardrail == "pi": response.prompt_injection = result elif guardrail == "tox": response.toxicity = result elif guardrail == "top": response.topic_banner = result return response if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000)