diff --git a/fastapi/params.py b/fastapi/params.py index 860146531..7c35a571b 100644 --- a/fastapi/params.py +++ b/fastapi/params.py @@ -91,7 +91,7 @@ class Param(FieldInfo): max_length=max_length, discriminator=discriminator, multiple_of=multiple_of, - allow_nan=allow_inf_nan, + allow_inf_nan=allow_inf_nan, max_digits=max_digits, decimal_places=decimal_places, **extra, diff --git a/test_param_allow_inf_nan.py b/test_param_allow_inf_nan.py new file mode 100644 index 000000000..4488eeaea --- /dev/null +++ b/test_param_allow_inf_nan.py @@ -0,0 +1,58 @@ +from typing import Optional + +from fastapi import FastAPI +from fastapi.params import Query +from fastapi.testclient import TestClient + +app = FastAPI() + +@app.get("/") +def get( + x: Optional[float] = Query(default=0, allow_inf_nan=False), + y: Optional[float] = Query(default=0, allow_inf_nan=True), + z: Optional[float] = Query(default=0)) -> str: # type: ignore + return 'OK' + + +client = TestClient(app) + + +def test_allow_inf_nan_false(): + response = client.get('/?x=inf') + assert response.status_code == 422, response.text + + response = client.get('/?x=-inf') + assert response.status_code == 422, response.text + + response = client.get('/?x=nan') + assert response.status_code == 422, response.text + + response = client.get('/?x=0') + assert response.status_code == 200, response.text + +def test_allow_inf_nan_true(): + response = client.get('/?y=inf') + assert response.status_code == 200, response.text + + response = client.get('/?y=-inf') + assert response.status_code == 200, response.text + + response = client.get('/?y=nan') + assert response.status_code == 200, response.text + + response = client.get('/?y=0') + assert response.status_code == 200, response.text + +def test_allow_inf_nan_not_specified(): + response = client.get('/?z=inf') + assert response.status_code == 200, response.text + + response = client.get('/?z=-inf') + assert response.status_code == 200, response.text + + response = client.get('/?z=nan') + assert response.status_code == 200, response.text + + response = client.get('/?z=0') + assert response.status_code == 200, response.text +