Spaces:
Sleeping
Sleeping
Cleanup token auth
Browse files- app/api/userlogout.py +9 -2
- app/dependencies.py +22 -17
- app/utils/db.py +7 -3
- app/utils/jwt_utils.py +15 -17
app/api/userlogout.py
CHANGED
@@ -1,14 +1,21 @@
|
|
|
|
1 |
from fastapi import APIRouter, Depends, HTTPException
|
2 |
from ..utils.db import tinydb_helper # Ensure this import is correct based on our project structure
|
3 |
-
from ..dependencies import oauth2_scheme
|
4 |
|
5 |
router = APIRouter()
|
6 |
|
7 |
@router.post("/user/logout")
|
8 |
-
async def user_logout(token: str = Depends(oauth2_scheme)):
|
9 |
try:
|
|
|
|
|
10 |
# Invalidate the token by removing it from the database
|
|
|
|
|
11 |
tinydb_helper.remove_token_by_value(token)
|
|
|
|
|
12 |
return {"message": "User logged out successfully"}
|
13 |
except Exception as e:
|
14 |
raise HTTPException(status_code=400, detail=f"Error during logout: {str(e)}")
|
|
|
1 |
+
from typing import Any
|
2 |
from fastapi import APIRouter, Depends, HTTPException
|
3 |
from ..utils.db import tinydb_helper # Ensure this import is correct based on our project structure
|
4 |
+
from ..dependencies import get_current_user, oauth2_scheme
|
5 |
|
6 |
router = APIRouter()
|
7 |
|
8 |
@router.post("/user/logout")
|
9 |
+
async def user_logout(token: str = Depends(oauth2_scheme), current_user: Any = Depends(get_current_user)):
|
10 |
try:
|
11 |
+
# Assuming `get_current_user` now also ensures and returns the full payload including `user_id`
|
12 |
+
user_id = current_user["user_id"]
|
13 |
# Invalidate the token by removing it from the database
|
14 |
+
if not tinydb_helper.query_token(user_id, token):
|
15 |
+
raise HTTPException(status_code=404, detail="Token not found.")
|
16 |
tinydb_helper.remove_token_by_value(token)
|
17 |
+
if tinydb_helper.query_token(user_id, token):
|
18 |
+
raise HTTPException(status_code=404, detail="Logout unsuccessful.")
|
19 |
return {"message": "User logged out successfully"}
|
20 |
except Exception as e:
|
21 |
raise HTTPException(status_code=400, detail=f"Error during logout: {str(e)}")
|
app/dependencies.py
CHANGED
@@ -1,28 +1,33 @@
|
|
1 |
from fastapi import Depends, HTTPException, status
|
2 |
from fastapi.security import OAuth2PasswordBearer
|
3 |
-
from jose import jwt, JWTError
|
4 |
-
from .utils.db import tinydb_helper # Ensure
|
5 |
-
from .utils.jwt_utils import SECRET_KEY, ALGORITHM
|
6 |
|
7 |
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
8 |
|
9 |
-
def
|
10 |
-
try:
|
11 |
-
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
12 |
-
user_id: str = payload.get("sub")
|
13 |
-
name: str = payload.get("name")
|
14 |
-
role: str = payload.get("role")
|
15 |
-
if user_id is None or name is None or role is None:
|
16 |
-
raise credentials_exception
|
17 |
-
return {"user_id": user_id, "name": name, "role": role}
|
18 |
-
except jwt.PyJWTError:
|
19 |
-
raise credentials_exception
|
20 |
-
|
21 |
-
async def get_current_user(token: str = Depends(oauth2_scheme)) -> dict:
|
22 |
credentials_exception = HTTPException(
|
23 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
24 |
detail="Could not validate credentials",
|
25 |
headers={"WWW-Authenticate": "Bearer"},
|
26 |
)
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
|
|
1 |
from fastapi import Depends, HTTPException, status
|
2 |
from fastapi.security import OAuth2PasswordBearer
|
3 |
+
from jose import jwt, JWTError # Ensure this is correctly imported
|
4 |
+
from .utils.db import tinydb_helper # Ensure this instance is correctly initialized elsewhere
|
5 |
+
from .utils.jwt_utils import SECRET_KEY, ALGORITHM, decode_jwt
|
6 |
|
7 |
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
8 |
|
9 |
+
async def get_current_user(token: str = Depends(oauth2_scheme)):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
credentials_exception = HTTPException(
|
11 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
12 |
detail="Could not validate credentials",
|
13 |
headers={"WWW-Authenticate": "Bearer"},
|
14 |
)
|
15 |
+
|
16 |
+
# Utilize the centralized JWT decoding and catch any JWT-related errors
|
17 |
+
try:
|
18 |
+
payload = decode_jwt(token)
|
19 |
+
except JWTError:
|
20 |
+
raise credentials_exception
|
21 |
+
|
22 |
+
user_id: str = payload.get("sub")
|
23 |
+
if user_id is None:
|
24 |
+
raise credentials_exception
|
25 |
+
|
26 |
+
# Verify if the token is stored and valid
|
27 |
+
if not tinydb_helper.query_token(user_id, token):
|
28 |
+
raise credentials_exception
|
29 |
+
|
30 |
+
# Payload is already obtained and validated, so just return it or its specific parts as needed
|
31 |
+
return {"user_id": user_id, "name": payload.get("name"), "role": payload.get("role")}
|
32 |
+
|
33 |
|
app/utils/db.py
CHANGED
@@ -20,14 +20,18 @@ class TinyDBHelper:
|
|
20 |
def query_token(self, user_id: str, token: str) -> bool:
|
21 |
"""Query to check if the token exists and is valid."""
|
22 |
User = Query()
|
23 |
-
# Assuming our tokens table contains 'user_id', 'token', and 'expires_at'
|
24 |
result = self.tokens_table.search((User.user_id == user_id) & (User.token == token))
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
26 |
expires_at = datetime.fromisoformat(result[0]['expires_at'])
|
27 |
if datetime.utcnow() > expires_at:
|
28 |
return False
|
29 |
|
30 |
-
return
|
31 |
|
32 |
def remove_token_by_value(self, token: str):
|
33 |
"""Remove a token based on its value."""
|
|
|
20 |
def query_token(self, user_id: str, token: str) -> bool:
|
21 |
"""Query to check if the token exists and is valid."""
|
22 |
User = Query()
|
|
|
23 |
result = self.tokens_table.search((User.user_id == user_id) & (User.token == token))
|
24 |
+
|
25 |
+
# Check if the result is empty (i.e., no matching token found)
|
26 |
+
if not result:
|
27 |
+
return False
|
28 |
+
|
29 |
+
# Check if the token is expired
|
30 |
expires_at = datetime.fromisoformat(result[0]['expires_at'])
|
31 |
if datetime.utcnow() > expires_at:
|
32 |
return False
|
33 |
|
34 |
+
return True
|
35 |
|
36 |
def remove_token_by_value(self, token: str):
|
37 |
"""Remove a token based on its value."""
|
app/utils/jwt_utils.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
from datetime import datetime, timedelta
|
2 |
from jose import JWTError, jwt
|
|
|
3 |
from typing import Any, Union
|
4 |
#from tinydb import TinyDB, Query
|
5 |
#from tinydb.storages import MemoryStorage
|
@@ -10,23 +11,20 @@ SECRET_KEY = "a_very_secret_key"
|
|
10 |
ALGORITHM = "HS256"
|
11 |
ACCESS_TOKEN_EXPIRE_MINUTES = 30 # The expiration time for the access token
|
12 |
|
13 |
-
|
14 |
-
#
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
#
|
23 |
-
|
24 |
-
|
25 |
-
#
|
26 |
-
|
27 |
-
# if datetime.utcnow() > expires_at:
|
28 |
-
# return False
|
29 |
-
# return True
|
30 |
|
31 |
def create_access_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
|
32 |
"""
|
|
|
1 |
from datetime import datetime, timedelta
|
2 |
from jose import JWTError, jwt
|
3 |
+
from fastapi import HTTPException
|
4 |
from typing import Any, Union
|
5 |
#from tinydb import TinyDB, Query
|
6 |
#from tinydb.storages import MemoryStorage
|
|
|
11 |
ALGORITHM = "HS256"
|
12 |
ACCESS_TOKEN_EXPIRE_MINUTES = 30 # The expiration time for the access token
|
13 |
|
14 |
+
def encode_jwt(data: dict):
|
15 |
+
# Encode a JWT token
|
16 |
+
return jwt.encode(data, SECRET_KEY, algorithm=ALGORITHM)
|
17 |
+
|
18 |
+
def decode_jwt(token: str):
|
19 |
+
try:
|
20 |
+
# Decode a JWT token
|
21 |
+
return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
22 |
+
except JWTError as e:
|
23 |
+
# Handle specific JWT errors (e.g., token expired, invalid token)
|
24 |
+
raise HTTPException(status_code=401, detail="Token is invalid or expired")
|
25 |
+
except Exception as e:
|
26 |
+
# Handle unexpected errors
|
27 |
+
raise HTTPException(status_code=500, detail="An error occurred while decoding token")
|
|
|
|
|
|
|
28 |
|
29 |
def create_access_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
|
30 |
"""
|