Spaces:
Running
Running
import os

import chromadb
from fastapi import Depends, FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from .admin import admin_functions as admin
from .utils.db import ChromaDBFaceHelper, UserFaceEmbeddingFunction, tinydb_helper
from .api import userlogin, userlogout, userchat, userupload

# Root directory of the persistent ChromaDB store (on the /home/user/data volume).
CHROMADB_LOC = "/home/user/data/chromadb"

app = FastAPI()

# Allowed CORS origins, read from the comma-separated CORS_ALLOW_ORIGINS
# environment variable (e.g. "https://example.org,https://admin.example.org").
# When the variable is unset or contains no origins, every origin is allowed
# ("*"), which keeps the previous wide-open behaviour. A wildcard combined with
# allow_credentials=True is not appropriate for production, so set the variable
# to the real front-end origin(s) there.
CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Persistent storage for ChromaDB, rooted at CHROMADB_LOC.
ec_client = chromadb.PersistentClient(CHROMADB_LOC)
# Created at import time because admin.register_user (called by the user
# registration handler) needs this collection to register faces.
user_faces_db = ec_client.get_or_create_collection(
    name="user_faces_db",
    embedding_function=UserFaceEmbeddingFunction(),
)
async def startup_event():
    """Build the shared ChromaDB face helper used by the API routers.

    Rebinds the module-level ``chromadb_face_helper`` to a new
    ``ChromaDBFaceHelper`` opened on ``CHROMADB_LOC``. It then prints the
    ``MODEL_PATH`` environment variable as a startup diagnostic.
    """
    # NOTE(review): no @app.on_event("startup") decorator (or other startup
    # registration) is visible in this view. Confirm it exists; otherwise
    # chromadb_face_helper is never created.
    global chromadb_face_helper
    chromadb_face_helper = ChromaDBFaceHelper(CHROMADB_LOC)  # consumed by the APIs
    # Further startup work can be added below.
    model_path = os.getenv('MODEL_PATH')
    print(f"MODEL_PATH in main.py = {model_path} ")
# Serve everything under the ./static directory (resolved against the process
# working directory) at the /static URL prefix.
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")
# Jinja2 renderer for the admin HTML pages (admin login, user registration).
templates = Jinja2Templates(directory="app/admin/templates")
async def get_admin_login(request: Request):
    """Serve the admin login page.

    Renders ``admin_login.html`` with only the incoming request in the
    template context (no error message).
    """
    login_context = {"request": request}
    return templates.TemplateResponse("admin_login.html", login_context)
# Admin login handler
async def handle_admin_login(request: Request, username: str = Form(...), password: str = Form(...)):
    """Validate the submitted admin credentials.

    Args:
        request: The incoming request, passed through to the template.
        username: Admin user name from the login form.
        password: Admin password from the login form.

    Returns:
        On success, a 303 See Other redirect to ``/admin/register_user``, so the
        browser follows it with a GET. On failure, the re-rendered
        ``admin_login.html`` carrying an error message.
    """
    if not admin.verify_admin_password(username, password):
        failure_context = {"request": request, "error": "Invalid password"}
        return templates.TemplateResponse("admin_login.html", failure_context)
    # NOTE(review): nothing here records the successful login (no session or
    # cookie is set), so the registration page cannot tell whether the caller
    # authenticated. Confirm access control is enforced elsewhere.
    return RedirectResponse(url="/admin/register_user", status_code=303)
# Display the user registration page
async def get_user_registration(request: Request):
    """Serve the blank user registration form (``user_registration.html``)."""
    form_context = {"request": request}
    return templates.TemplateResponse("user_registration.html", form_context)
# User registration handler
async def handle_user_registration(
    request: Request,
    email: str = Form(...),
    name: str = Form(...),
    role: str = Form(...),
    file: UploadFile = File(...),
):
    """Register a new user from the admin registration form.

    The form fields and the uploaded file are forwarded to
    ``admin.register_user`` together with the ``user_faces_db`` collection.

    Args:
        request: The incoming request, passed through to the template.
        email: User e-mail address from the form.
        name: User display name from the form.
        role: User role from the form.
        file: Uploaded file; presumably the user's face image (TODO confirm
            against admin.register_user).

    Returns:
        ``registration_success.html`` with the disk usage of /home/user/data
        when ``register_user`` returns a truthy id. Otherwise the registration
        form re-rendered with an error message.
    """
    # NOTE(review): no authentication check is visible in this handler.
    user_id = await admin.register_user(user_faces_db, email, name, role, file)
    if not user_id:
        error_context = {"request": request, "error": "Registration failed"}
        return templates.TemplateResponse("user_registration.html", error_context)

    # Report storage consumption of the data volume that holds the ChromaDB store.
    usage = admin.get_disk_usage("/home/user/data")
    success_context = {"request": request, "disk_usage": usage}
    return templates.TemplateResponse("registration_success.html", success_context)
# Attach the API routers (login, logout, chat, upload) in this order.
for _api_module in (userlogin, userlogout, userchat, userupload):
    app.include_router(_api_module.router)
del _api_module  # keep the loop variable out of the module namespace