dtyago commited on
Commit
c672e11
1 Parent(s): 2dae6f6

Cleanup token auth

Browse files
app/api/userlogout.py CHANGED
@@ -1,14 +1,21 @@
 
1
  from fastapi import APIRouter, Depends, HTTPException
2
  from ..utils.db import tinydb_helper # Ensure this import is correct based on our project structure
3
- from ..dependencies import oauth2_scheme
4
 
5
  router = APIRouter()
6
 
7
  @router.post("/user/logout")
8
- async def user_logout(token: str = Depends(oauth2_scheme)):
9
  try:
 
 
10
  # Invalidate the token by removing it from the database
 
 
11
  tinydb_helper.remove_token_by_value(token)
 
 
12
  return {"message": "User logged out successfully"}
13
  except Exception as e:
14
  raise HTTPException(status_code=400, detail=f"Error during logout: {str(e)}")
 
1
+ from typing import Any
2
  from fastapi import APIRouter, Depends, HTTPException
3
  from ..utils.db import tinydb_helper # Ensure this import is correct based on our project structure
4
+ from ..dependencies import get_current_user, oauth2_scheme
5
 
6
  router = APIRouter()
7
 
8
  @router.post("/user/logout")
9
+ async def user_logout(token: str = Depends(oauth2_scheme), current_user: Any = Depends(get_current_user)):
10
  try:
11
+ # Assuming `get_current_user` now also ensures and returns the full payload including `user_id`
12
+ user_id = current_user["user_id"]
13
  # Invalidate the token by removing it from the database
14
+ if not tinydb_helper.query_token(user_id, token):
15
+ raise HTTPException(status_code=404, detail="Token not found.")
16
  tinydb_helper.remove_token_by_value(token)
17
+ if tinydb_helper.query_token(user_id, token):
18
+ raise HTTPException(status_code=404, detail="Logout unsuccessful.")
19
  return {"message": "User logged out successfully"}
20
  except Exception as e:
21
  raise HTTPException(status_code=400, detail=f"Error during logout: {str(e)}")
app/dependencies.py CHANGED
@@ -1,28 +1,33 @@
1
  from fastapi import Depends, HTTPException, status
2
  from fastapi.security import OAuth2PasswordBearer
3
- from jose import jwt, JWTError
4
- from .utils.db import tinydb_helper # Ensure correct import path
5
- from .utils.jwt_utils import SECRET_KEY, ALGORITHM # Ensure these are defined in our jwt_utils.py
6
 
7
  oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
8
 
9
- def decode_access_token(token: str, credentials_exception) -> dict:
10
- try:
11
- payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
12
- user_id: str = payload.get("sub")
13
- name: str = payload.get("name")
14
- role: str = payload.get("role")
15
- if user_id is None or name is None or role is None:
16
- raise credentials_exception
17
- return {"user_id": user_id, "name": name, "role": role}
18
- except jwt.PyJWTError:
19
- raise credentials_exception
20
-
21
- async def get_current_user(token: str = Depends(oauth2_scheme)) -> dict:
22
  credentials_exception = HTTPException(
23
  status_code=status.HTTP_401_UNAUTHORIZED,
24
  detail="Could not validate credentials",
25
  headers={"WWW-Authenticate": "Bearer"},
26
  )
27
- return decode_access_token(token, credentials_exception)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
 
1
  from fastapi import Depends, HTTPException, status
2
  from fastapi.security import OAuth2PasswordBearer
3
+ from jose import jwt, JWTError # Ensure this is correctly imported
4
+ from .utils.db import tinydb_helper # Ensure this instance is correctly initialized elsewhere
5
+ from .utils.jwt_utils import SECRET_KEY, ALGORITHM, decode_jwt
6
 
7
  oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
8
 
9
+ async def get_current_user(token: str = Depends(oauth2_scheme)):
 
 
 
 
 
 
 
 
 
 
 
 
10
  credentials_exception = HTTPException(
11
  status_code=status.HTTP_401_UNAUTHORIZED,
12
  detail="Could not validate credentials",
13
  headers={"WWW-Authenticate": "Bearer"},
14
  )
15
+
16
+ # Utilize the centralized JWT decoding and catch any JWT-related errors
17
+ try:
18
+ payload = decode_jwt(token)
19
+ except JWTError:
20
+ raise credentials_exception
21
+
22
+ user_id: str = payload.get("sub")
23
+ if user_id is None:
24
+ raise credentials_exception
25
+
26
+ # Verify if the token is stored and valid
27
+ if not tinydb_helper.query_token(user_id, token):
28
+ raise credentials_exception
29
+
30
+ # Payload is already obtained and validated, so just return it or its specific parts as needed
31
+ return {"user_id": user_id, "name": payload.get("name"), "role": payload.get("role")}
32
+
33
 
app/utils/db.py CHANGED
@@ -20,14 +20,18 @@ class TinyDBHelper:
20
  def query_token(self, user_id: str, token: str) -> bool:
21
  """Query to check if the token exists and is valid."""
22
  User = Query()
23
- # Assuming our tokens table contains 'user_id', 'token', and 'expires_at'
24
  result = self.tokens_table.search((User.user_id == user_id) & (User.token == token))
25
- # Optionally, check if the token is expired
 
 
 
 
 
26
  expires_at = datetime.fromisoformat(result[0]['expires_at'])
27
  if datetime.utcnow() > expires_at:
28
  return False
29
 
30
- return bool(result)
31
 
32
  def remove_token_by_value(self, token: str):
33
  """Remove a token based on its value."""
 
20
  def query_token(self, user_id: str, token: str) -> bool:
21
  """Query to check if the token exists and is valid."""
22
  User = Query()
 
23
  result = self.tokens_table.search((User.user_id == user_id) & (User.token == token))
24
+
25
+ # Check if the result is empty (i.e., no matching token found)
26
+ if not result:
27
+ return False
28
+
29
+ # Check if the token is expired
30
  expires_at = datetime.fromisoformat(result[0]['expires_at'])
31
  if datetime.utcnow() > expires_at:
32
  return False
33
 
34
+ return True
35
 
36
  def remove_token_by_value(self, token: str):
37
  """Remove a token based on its value."""
app/utils/jwt_utils.py CHANGED
@@ -1,5 +1,6 @@
1
  from datetime import datetime, timedelta
2
  from jose import JWTError, jwt
 
3
  from typing import Any, Union
4
  #from tinydb import TinyDB, Query
5
  #from tinydb.storages import MemoryStorage
@@ -10,23 +11,20 @@ SECRET_KEY = "a_very_secret_key"
10
  ALGORITHM = "HS256"
11
  ACCESS_TOKEN_EXPIRE_MINUTES = 30 # The expiration time for the access token
12
 
13
- # db = TinyDB(storage=MemoryStorage)
14
- # tokens_table = db.table('tokens')
15
-
16
- # def insert_token(user_id: str, token: str, expires_in: timedelta):
17
- # expiration = datetime.utcnow() + expires_in
18
- # tokens_table.insert({'user_id': user_id, 'token': token, 'expires_at': expiration.isoformat()})
19
-
20
- # def validate_token(user_id: str, token: str) -> bool:
21
- # User = Query()
22
- # result = tokens_table.search((User.user_id == user_id) & (User.token == token))
23
- # if not result:
24
- # return False
25
- # # Check token expiration
26
- # expires_at = datetime.fromisoformat(result[0]['expires_at'])
27
- # if datetime.utcnow() > expires_at:
28
- # return False
29
- # return True
30
 
31
  def create_access_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
32
  """
 
1
  from datetime import datetime, timedelta
2
  from jose import JWTError, jwt
3
+ from fastapi import HTTPException
4
  from typing import Any, Union
5
  #from tinydb import TinyDB, Query
6
  #from tinydb.storages import MemoryStorage
 
11
  ALGORITHM = "HS256"
12
  ACCESS_TOKEN_EXPIRE_MINUTES = 30 # The expiration time for the access token
13
 
14
+ def encode_jwt(data: dict):
15
+ # Encode a JWT token
16
+ return jwt.encode(data, SECRET_KEY, algorithm=ALGORITHM)
17
+
18
+ def decode_jwt(token: str):
19
+ try:
20
+ # Decode a JWT token
21
+ return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
22
+ except JWTError as e:
23
+ # Handle specific JWT errors (e.g., token expired, invalid token)
24
+ raise HTTPException(status_code=401, detail="Token is invalid or expired")
25
+ except Exception as e:
26
+ # Handle unexpected errors
27
+ raise HTTPException(status_code=500, detail="An error occurred while decoding token")
 
 
 
28
 
29
  def create_access_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
30
  """