mirror of https://github.com/tiangolo/fastapi.git
Merge eb2cef2aa7 into 272204c0c7
This commit is contained in:
commit
dd91796b42
|
|
@ -0,0 +1,82 @@
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from django.core.asgi import ASGIHandler
|
||||||
|
from django.http import HttpRequest, HttpResponse, StreamingHttpResponse
|
||||||
|
from fastapi import Depends, Request, Response
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
|
||||||
|
_django_request = ContextVar[HttpRequest | None]("fastapi_django_request", default=None)
|
||||||
|
|
||||||
|
|
||||||
|
def get_django_request():
|
||||||
|
django_request = _django_request.get()
|
||||||
|
|
||||||
|
if not django_request:
|
||||||
|
raise ValueError(
|
||||||
|
"Django Request not found, did you forget to add the Django Middleware?"
|
||||||
|
)
|
||||||
|
|
||||||
|
return django_request
|
||||||
|
|
||||||
|
|
||||||
|
DjangoRequestDep = Annotated[HttpRequest, Depends(get_django_request)]
|
||||||
|
|
||||||
|
|
||||||
|
class DjangoMiddleware(BaseHTTPMiddleware, ASGIHandler):
|
||||||
|
"""A FastAPI Middleware that runs the Django HTTP Request lifecycle.
|
||||||
|
|
||||||
|
This middleware is responsible for running the Django HTTP Request lifecycle
|
||||||
|
in the FastAPI application. It is useful when you want to use Django's
|
||||||
|
authentication system, or any other Django feature that requires the
|
||||||
|
Django Request object to be available."""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
ASGIHandler.__init__(self)
|
||||||
|
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
async def _get_response_async(self, request):
|
||||||
|
fastapi_response = await self._call_next(request)
|
||||||
|
|
||||||
|
assert isinstance(fastapi_response, StreamingResponse)
|
||||||
|
|
||||||
|
return StreamingHttpResponse(
|
||||||
|
streaming_content=fastapi_response.body_iterator,
|
||||||
|
headers=fastapi_response.headers,
|
||||||
|
status=fastapi_response.status_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def __call__(self, scope, receive, send):
|
||||||
|
self._django_request, _ = self.create_request(scope, "")
|
||||||
|
|
||||||
|
_django_request.set(self._django_request)
|
||||||
|
|
||||||
|
await BaseHTTPMiddleware.__call__(self, scope, receive, send)
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
self._call_next = call_next
|
||||||
|
|
||||||
|
django_response = await self.get_response_async(self._django_request)
|
||||||
|
|
||||||
|
if isinstance(django_response, HttpResponse):
|
||||||
|
return Response(
|
||||||
|
status_code=django_response.status_code,
|
||||||
|
content=django_response.content,
|
||||||
|
headers=django_response.headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(django_response, StreamingHttpResponse):
|
||||||
|
|
||||||
|
async def streaming():
|
||||||
|
async for chunk in django_response.streaming_content:
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
status_code=django_response.status_code,
|
||||||
|
content=streaming(),
|
||||||
|
headers=django_response.headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
return Response(status_code=500)
|
||||||
|
|
@ -14,3 +14,6 @@ inline-snapshot>=0.21.1
|
||||||
# types
|
# types
|
||||||
types-ujson ==5.10.0.20240515
|
types-ujson ==5.10.0.20240515
|
||||||
types-orjson ==3.6.2
|
types-orjson ==3.6.2
|
||||||
|
|
||||||
|
# django
|
||||||
|
django ==5.0.6
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,55 @@
|
||||||
|
import django
|
||||||
|
import pytest
|
||||||
|
from django.conf import settings
|
||||||
|
from django.core.management.color import no_style
|
||||||
|
from django.core.management.sql import sql_flush
|
||||||
|
from django.db import connection
|
||||||
|
|
||||||
|
settings.configure(
|
||||||
|
SECRET_KEY="not_very",
|
||||||
|
ROOT_URLCONF="tests.django.proj.urls",
|
||||||
|
INSTALLED_APPS=[
|
||||||
|
"django.contrib.auth",
|
||||||
|
"django.contrib.contenttypes",
|
||||||
|
"django.contrib.sessions",
|
||||||
|
],
|
||||||
|
MIDDLEWARE=[
|
||||||
|
"django.contrib.sessions.middleware.SessionMiddleware",
|
||||||
|
"django.contrib.auth.middleware.AuthenticationMiddleware",
|
||||||
|
],
|
||||||
|
DATABASES={
|
||||||
|
"default": {
|
||||||
|
"ENGINE": "django.db.backends.sqlite3",
|
||||||
|
"NAME": ":memory:",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
django.setup()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def django_db_setup():
|
||||||
|
connection.creation.create_test_db(verbosity=0, autoclobber=True)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
connection.creation.destroy_test_db("default", verbosity=0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def flush_db():
|
||||||
|
sql_list = sql_flush(no_style(), connection, allow_cascade=False)
|
||||||
|
|
||||||
|
connection.ops.execute_sql_flush(sql_list)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def authenticated_session_id():
|
||||||
|
from django.contrib.auth.models import User
|
||||||
|
|
||||||
|
from tests.django.utils import create_authenticated_session
|
||||||
|
|
||||||
|
user = User.objects.create_user(username="test", password="test")
|
||||||
|
|
||||||
|
return create_authenticated_session(user)
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
urlpatterns = []
|
||||||
|
|
@ -0,0 +1,51 @@
|
||||||
|
from django.contrib.auth import aget_user
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.django import DjangoMiddleware, DjangoRequestDep
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
app.add_middleware(DjangoMiddleware)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def root():
|
||||||
|
return {"message": "Hello World"}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/current-user")
|
||||||
|
async def django_user(django_request: DjangoRequestDep):
|
||||||
|
user = await aget_user(django_request)
|
||||||
|
|
||||||
|
if not user.is_authenticated:
|
||||||
|
return {"error": "User not authenticated"}
|
||||||
|
|
||||||
|
return {"username": user.username}
|
||||||
|
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_unauthenticated():
|
||||||
|
response = client.get("/current-user")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
assert response.json() == {"error": "User not authenticated"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_authenticated(authenticated_session_id: str):
|
||||||
|
client.cookies.set("sessionid", authenticated_session_id)
|
||||||
|
|
||||||
|
response = client.get("/current-user")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
assert response.json() == {"username": "test"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_route_with_no_django_request():
|
||||||
|
response = client.get("/")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
assert response.json() == {"message": "Hello World"}
|
||||||
|
|
@ -0,0 +1,24 @@
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.django import DjangoRequestDep
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def django_user(django_request: DjangoRequestDep):
|
||||||
|
user = django_request.user
|
||||||
|
|
||||||
|
if not user.is_authenticated:
|
||||||
|
return {"error": "User not authenticated"}
|
||||||
|
|
||||||
|
return {"username": user.username}
|
||||||
|
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_returns_an_error():
|
||||||
|
with pytest.raises(ValueError, match="Django Request not found"):
|
||||||
|
client.get("/")
|
||||||
|
|
@ -0,0 +1,24 @@
|
||||||
|
from importlib import import_module
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
|
from django.contrib.auth import BACKEND_SESSION_KEY, HASH_SESSION_KEY, SESSION_KEY
|
||||||
|
|
||||||
|
|
||||||
|
def create_authenticated_session(user):
|
||||||
|
"""Creates an authenticated session for the given user."""
|
||||||
|
|
||||||
|
engine = import_module(settings.SESSION_ENGINE)
|
||||||
|
session = engine.SessionStore()
|
||||||
|
session.create()
|
||||||
|
|
||||||
|
session[SESSION_KEY] = str(user.id)
|
||||||
|
session[BACKEND_SESSION_KEY] = (
|
||||||
|
user.backend
|
||||||
|
if hasattr(user, "backend")
|
||||||
|
else "django.contrib.auth.backends.ModelBackend"
|
||||||
|
)
|
||||||
|
session[HASH_SESSION_KEY] = user.get_session_auth_hash()
|
||||||
|
|
||||||
|
session.save()
|
||||||
|
|
||||||
|
return session.session_key
|
||||||
Loading…
Reference in New Issue