from fastapi import FastAPI
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import uvicorn

app = FastAPI()

# Initialize the InferenceClient with the specified model
client = InferenceClient("nvidia/Llama-3.1-Nemotron-70B-Instruct-HF")

# Define the structure of the request body
class CourseRequest(BaseModel):
    course_name: str
    history: list = []  # Optional conversation history as (user, assistant) pairs
    temperature: float = 0.0
    max_new_tokens: int = 1048
    top_p: float = 0.15
    repetition_penalty: float = 1.0
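
# Example request body (only course_name is required; the other fields keep their
# defaults; the values below are purely illustrative):
#   {"course_name": "Data Structures", "temperature": 0.7, "max_new_tokens": 512}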

# Format the prompt for the model.
# Note: the <s>/[INST] tags follow the Mistral-style chat format; Llama-3.1
# models normally expect their own chat template.
def format_prompt(course_name, history):
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST] {bot_response} </s> "
    prompt += f"[INST] Generate a roadmap for the course: {course_name} [/INST]"
    return prompt
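
# For example, with an empty history, format_prompt("Python Basics", []) returns:
#   "<s>[INST] Generate a roadmap for the course: Python Basics [/INST]"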

# Generate text using the specified parameters
def generate(course_request: CourseRequest):
    # Sampling requires a strictly positive temperature, so clamp it at 1e-2
    temperature = max(float(course_request.temperature), 1e-2)
    top_p = float(course_request.top_p)
    generate_kwargs = {
        'temperature': temperature,
        'max_new_tokens': course_request.max_new_tokens,
        'top_p': top_p,
        'repetition_penalty': course_request.repetition_penalty,
        'do_sample': True,
        'seed': 42,
    }
    formatted_prompt = format_prompt(course_request.course_name, course_request.history)
    # Stream tokens from the Inference API and accumulate them into one string
    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    output = ""
    for response in stream:
        output += response.token.text
    return output
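
# Design note: since the streamed tokens are only joined into a single string, a
# plain non-streaming call would give the same result, e.g.:
#   output = client.text_generation(formatted_prompt, **generate_kwargs)
# Streaming is kept so the loop could later feed an incremental response.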

# Define the API endpoint for generating course roadmaps.
# Declared with plain `def` (not `async def`) so FastAPI runs the blocking
# generate() call in a worker thread instead of stalling the event loop.
@app.post("/generate-roadmap/")
def generate_roadmap(course_request: CourseRequest):
    return {"roadmap": generate(course_request)}

# Run the application (uncomment the next two lines if running this as a standalone script)
# if __name__ == "__main__":
#     uvicorn.run(app, host="0.0.0.0", port=8000)
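
# Example request (a sketch, assuming the server is running on localhost:8000):
#   curl -X POST http://localhost:8000/generate-roadmap/ \
#        -H "Content-Type: application/json" \
#        -d '{"course_name": "Machine Learning"}'
# The response is JSON of the form {"roadmap": "<generated text>"}.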