mirror of https://github.com/tiangolo/fastapi.git
✨ Add Django middleware
This commit is contained in:
parent
eca465f4c9
commit
828c71ab82
|
|
@ -0,0 +1,82 @@
|
|||
from contextvars import ContextVar
|
||||
from typing import Annotated
|
||||
|
||||
from django.core.asgi import ASGIHandler
|
||||
from django.http import HttpRequest, HttpResponse, StreamingHttpResponse
|
||||
from fastapi import Depends, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
_django_request = ContextVar[HttpRequest | None]("fastapi_django_request", default=None)
|
||||
|
||||
|
||||
def get_django_request():
|
||||
django_request = _django_request.get()
|
||||
|
||||
if not django_request:
|
||||
raise ValueError(
|
||||
"Django Request not found, did you forget to add the Django Middleware?"
|
||||
)
|
||||
|
||||
return django_request
|
||||
|
||||
|
||||
DjangoRequestDep = Annotated[HttpRequest, Depends(get_django_request)]
|
||||
|
||||
|
||||
class DjangoMiddleware(BaseHTTPMiddleware, ASGIHandler):
|
||||
"""A FastAPI Middleware that runs the Django HTTP Request lifecycle.
|
||||
|
||||
This middleware is responsible for running the Django HTTP Request lifecycle
|
||||
in the FastAPI application. It is useful when you want to use Django's
|
||||
authentication system, or any other Django feature that requires the
|
||||
Django Request object to be available."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
ASGIHandler.__init__(self)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
async def _get_response_async(self, request):
|
||||
fastapi_response = await self._call_next(request)
|
||||
|
||||
assert isinstance(fastapi_response, StreamingResponse)
|
||||
|
||||
return StreamingHttpResponse(
|
||||
streaming_content=fastapi_response.body_iterator,
|
||||
headers=fastapi_response.headers,
|
||||
status=fastapi_response.status_code,
|
||||
)
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
self._django_request, _ = self.create_request(scope, "")
|
||||
|
||||
_django_request.set(self._django_request)
|
||||
|
||||
await BaseHTTPMiddleware.__call__(self, scope, receive, send)
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
self._call_next = call_next
|
||||
|
||||
django_response = await self.get_response_async(self._django_request)
|
||||
|
||||
if isinstance(django_response, HttpResponse):
|
||||
return Response(
|
||||
status_code=django_response.status_code,
|
||||
content=django_response.content,
|
||||
headers=django_response.headers,
|
||||
)
|
||||
|
||||
if isinstance(django_response, StreamingHttpResponse):
|
||||
|
||||
async def streaming():
|
||||
async for chunk in django_response.streaming_content:
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(
|
||||
status_code=django_response.status_code,
|
||||
content=streaming(),
|
||||
headers=django_response.headers,
|
||||
)
|
||||
|
||||
return Response(status_code=500)
|
||||
|
|
@ -18,3 +18,6 @@ passlib[bcrypt] >=1.7.2,<2.0.0
|
|||
# types
|
||||
types-ujson ==5.7.0.1
|
||||
types-orjson ==3.6.2
|
||||
|
||||
# django
|
||||
django ==5.0.6
|
||||
|
|
|
|||
|
|
@ -0,0 +1,55 @@
|
|||
import django
|
||||
import pytest
|
||||
from django.conf import settings
|
||||
from django.core.management.color import no_style
|
||||
from django.core.management.sql import sql_flush
|
||||
from django.db import connection
|
||||
|
||||
settings.configure(
|
||||
SECRET_KEY="not_very",
|
||||
ROOT_URLCONF="tests.django.proj.urls",
|
||||
INSTALLED_APPS=[
|
||||
"django.contrib.auth",
|
||||
"django.contrib.contenttypes",
|
||||
"django.contrib.sessions",
|
||||
],
|
||||
MIDDLEWARE=[
|
||||
"django.contrib.sessions.middleware.SessionMiddleware",
|
||||
"django.contrib.auth.middleware.AuthenticationMiddleware",
|
||||
],
|
||||
DATABASES={
|
||||
"default": {
|
||||
"ENGINE": "django.db.backends.sqlite3",
|
||||
"NAME": ":memory:",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
django.setup()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def django_db_setup():
|
||||
connection.creation.create_test_db(verbosity=0, autoclobber=True)
|
||||
|
||||
yield
|
||||
|
||||
connection.creation.destroy_test_db("default", verbosity=0)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def flush_db():
|
||||
sql_list = sql_flush(no_style(), connection, allow_cascade=False)
|
||||
|
||||
connection.ops.execute_sql_flush(sql_list)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def authenticated_session_id():
|
||||
from django.contrib.auth.models import User
|
||||
|
||||
from tests.django.utils import create_authenticated_session
|
||||
|
||||
user = User.objects.create_user(username="test", password="test")
|
||||
|
||||
return create_authenticated_session(user)
|
||||
|
|
@ -0,0 +1 @@
|
|||
urlpatterns = []
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
from django.contrib.auth import (
|
||||
aget_user,
|
||||
)
|
||||
from fastapi import FastAPI
|
||||
from fastapi.django import DjangoMiddleware, DjangoRequestDep
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(DjangoMiddleware)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "Hello World"}
|
||||
|
||||
|
||||
@app.get("/current-user")
|
||||
async def django_user(django_request: DjangoRequestDep):
|
||||
user = await aget_user(django_request)
|
||||
|
||||
if not user.is_authenticated:
|
||||
return {"error": "User not authenticated"}
|
||||
|
||||
return {"username": user.username}
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_unauthenticated():
|
||||
response = client.get("/current-user")
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
assert response.json() == {"error": "User not authenticated"}
|
||||
|
||||
|
||||
def test_authenticated(authenticated_session_id: str):
|
||||
client.cookies.set("sessionid", authenticated_session_id)
|
||||
|
||||
response = client.get("/current-user")
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
assert response.json() == {"username": "test"}
|
||||
|
||||
|
||||
def test_route_with_no_django_request():
|
||||
response = client.get("/")
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
assert response.json() == {"message": "Hello World"}
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.django import DjangoRequestDep
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def django_user(django_request: DjangoRequestDep):
|
||||
user = django_request.user
|
||||
|
||||
if not user.is_authenticated:
|
||||
return {"error": "User not authenticated"}
|
||||
|
||||
return {"username": user.username}
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_returns_an_error():
|
||||
with pytest.raises(ValueError, match="Django Request not found"):
|
||||
client.get("/")
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
from importlib import import_module
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import BACKEND_SESSION_KEY, HASH_SESSION_KEY, SESSION_KEY
|
||||
|
||||
|
||||
def create_authenticated_session(user):
|
||||
"""Creates an authenticated session for the given user."""
|
||||
|
||||
engine = import_module(settings.SESSION_ENGINE)
|
||||
session = engine.SessionStore()
|
||||
session.create()
|
||||
|
||||
session[SESSION_KEY] = str(user.id)
|
||||
session[BACKEND_SESSION_KEY] = (
|
||||
user.backend
|
||||
if hasattr(user, "backend")
|
||||
else "django.contrib.auth.backends.ModelBackend"
|
||||
)
|
||||
session[HASH_SESSION_KEY] = user.get_session_auth_hash()
|
||||
|
||||
session.save()
|
||||
|
||||
return session.session_key
|
||||
Loading…
Reference in New Issue