180 lines
5.5 KiB
Python
180 lines
5.5 KiB
Python
|
|
import time
|
||
|
|
|
||
|
|
import jwt
|
||
|
|
import pytest
|
||
|
|
from httpx import ASGITransport, AsyncClient
|
||
|
|
|
||
|
|
from app.core.auth import AuthUser, get_current_user, require_auth
|
||
|
|
from app.core.config import settings
|
||
|
|
from app.main import app
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture()
def jwt_secret():
    """Return the symmetric HS256 key used to sign and verify test tokens."""
    signing_key = "test-jwt-secret-for-testing-only"
    return signing_key
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def valid_token(jwt_secret):
    """Return an HS256 token signed with ``jwt_secret`` that expires in one hour.

    Claims: subject ``user-123``, email ``test@example.com``, role and
    audience both ``authenticated``.
    """
    now = int(time.time())
    claims = dict(
        sub="user-123",
        email="test@example.com",
        role="authenticated",
        aud="authenticated",
        exp=now + 3600,
    )
    return jwt.encode(claims, jwt_secret, algorithm="HS256")
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def expired_token(jwt_secret):
    """Return a correctly signed HS256 token whose ``exp`` is already in the past.

    The claims match the valid-token case; only the expiry differs.
    """
    now = int(time.time())
    claims = dict(
        sub="user-123",
        email="test@example.com",
        role="authenticated",
        aud="authenticated",
        exp=now - 3600,  # one hour before "now", i.e. already expired
    )
    return jwt.encode(claims, jwt_secret, algorithm="HS256")
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def invalid_token():
    """Return an unexpired HS256 token signed with a key other than the test secret.

    The claims are well-formed. Only the signature is wrong, because the key
    ``"wrong-secret"`` differs from the secret the tests patch into settings.
    """
    now = int(time.time())
    claims = dict(
        sub="user-123",
        email="test@example.com",
        role="authenticated",
        aud="authenticated",
        exp=now + 3600,
    )
    return jwt.encode(claims, "wrong-secret", algorithm="HS256")
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def auth_client(db_session, jwt_secret, valid_token, monkeypatch):
    """Provide a factory for an ``AsyncClient`` that sends a valid bearer token.

    ``settings.supabase_jwt_secret`` is patched to ``jwt_secret`` at fixture
    setup. ``db_session`` is requested but not used here; presumably it exists
    so the database fixture is set up first (TODO confirm in conftest).

    The returned callable is an async generator function, not an async
    context manager. Consume it with ``async for ac in auth_client(): ...``.
    It yields one client bound to the ASGI app, and that client is closed when
    the generator's ``async with`` block exits.
    """
    monkeypatch.setattr(settings, "supabase_jwt_secret", jwt_secret)
    bearer_headers = {"Authorization": f"Bearer {valid_token}"}

    async def _client_factory():
        transport = ASGITransport(app=app)
        async with AsyncClient(
            transport=transport, base_url="http://test", headers=bearer_headers
        ) as client:
            yield client

    return _client_factory
|
||
|
|
|
||
|
|
|
||
|
|
async def test_get_current_user_valid_token(jwt_secret, valid_token, monkeypatch):
    """A correctly signed, unexpired bearer token is decoded into its user."""
    monkeypatch.setattr(settings, "supabase_jwt_secret", jwt_secret)

    class StubRequest:
        # Only the ``headers`` mapping is provided to get_current_user.
        headers = {"Authorization": f"Bearer {valid_token}"}

    request = StubRequest()
    user = get_current_user(request)

    assert user is not None
    assert (user.id, user.email, user.role) == (
        "user-123",
        "test@example.com",
        "authenticated",
    )
|
||
|
|
|
||
|
|
|
||
|
|
async def test_get_current_user_no_token(jwt_secret, monkeypatch):
    """Without an Authorization header, get_current_user yields ``None``."""
    monkeypatch.setattr(settings, "supabase_jwt_secret", jwt_secret)

    class StubRequest:
        # Empty header mapping: no Authorization header at all.
        headers = {}

    request = StubRequest()
    assert get_current_user(request) is None
|
||
|
|
|
||
|
|
|
||
|
|
async def test_get_current_user_expired_token(jwt_secret, expired_token, monkeypatch):
    """A correctly signed but expired token is rejected (``None`` is returned)."""
    monkeypatch.setattr(settings, "supabase_jwt_secret", jwt_secret)

    class StubRequest:
        headers = {"Authorization": f"Bearer {expired_token}"}

    request = StubRequest()
    assert get_current_user(request) is None
|
||
|
|
|
||
|
|
|
||
|
|
async def test_get_current_user_invalid_token(jwt_secret, invalid_token, monkeypatch):
    """A token signed with the wrong key is rejected (``None`` is returned)."""
    monkeypatch.setattr(settings, "supabase_jwt_secret", jwt_secret)

    class StubRequest:
        headers = {"Authorization": f"Bearer {invalid_token}"}

    request = StubRequest()
    assert get_current_user(request) is None
|
||
|
|
|
||
|
|
|
||
|
|
async def test_get_current_user_malformed_header(jwt_secret, monkeypatch):
    """An Authorization header without the ``Bearer`` scheme yields ``None``."""
    monkeypatch.setattr(settings, "supabase_jwt_secret", jwt_secret)

    class StubRequest:
        # Wrong scheme prefix: "NotBearer" instead of "Bearer".
        headers = {"Authorization": "NotBearer token"}

    request = StubRequest()
    assert get_current_user(request) is None
|
||
|
|
|
||
|
|
|
||
|
|
async def test_require_auth_valid_user():
    """require_auth returns the exact AuthUser instance it is given."""
    authenticated = AuthUser(id="user-123", email="test@example.com")
    assert require_auth(authenticated) is authenticated
|
||
|
|
|
||
|
|
|
||
|
|
async def test_require_auth_no_user():
    """require_auth rejects a missing user with a 401 HTTPException."""
    from fastapi import HTTPException

    with pytest.raises(HTTPException) as raised:
        require_auth(None)

    # Inspect the exception only after the ``with`` block; statements placed
    # after the raising call inside it would never run.
    error = raised.value
    assert error.status_code == 401
    assert error.detail == "Authentication required"
|
||
|
|
|
||
|
|
|
||
|
|
async def test_protected_endpoint_without_token(db_session):
    """POST /runs with no Authorization header is rejected with 401.

    ``db_session`` is requested but not used directly; presumably it sets up
    the database for the app (TODO confirm in conftest).
    """
    transport = ASGITransport(app=app)
    new_run = {"game_id": 1, "name": "Test Run"}
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        response = await client.post("/runs", json=new_run)
        assert response.status_code == 401
        assert response.json()["detail"] == "Authentication required"
|
||
|
|
|
||
|
|
|
||
|
|
async def test_protected_endpoint_with_expired_token(
    db_session, jwt_secret, expired_token, monkeypatch
):
    """POST /runs with an expired (but correctly signed) token is rejected with 401.

    The app's JWT secret is patched to ``jwt_secret``, so the only failing
    check is expiry, not the signature.
    """
    monkeypatch.setattr(settings, "supabase_jwt_secret", jwt_secret)

    bearer = {"Authorization": f"Bearer {expired_token}"}
    new_run = {"game_id": 1, "name": "Test Run"}
    transport = ASGITransport(app=app)
    async with AsyncClient(
        transport=transport, base_url="http://test", headers=bearer
    ) as client:
        response = await client.post("/runs", json=new_run)
        assert response.status_code == 401
|
||
|
|
|
||
|
|
|
||
|
|
async def test_read_endpoint_without_token(db_session):
    """GET /runs succeeds without any Authorization header."""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        response = await client.get("/runs")
        assert response.status_code == 200
|