Litestar
To inject dependencies into Litestar request handlers, make sure all of the following requirements are met:
- use the `@inject` decorator before the Litestar HTTP-method decorator, i.e. place `@inject` directly on the handler function, below `@get`/`@post`/...;
- add to the request handler, for each injected parameter, a type annotation of the form `Annotated[<your type>, Dependency(skip_validation=True)]`, where `Dependency` is an object from `litestar.params`;
- use the `Provide` marker from the `injection` package (not from Litestar) to indicate the provider.
Example
from functools import partial
from random import Random
from typing import Any, Dict, Iterator
from unittest.mock import Mock

import pytest
from litestar import Litestar, get
from litestar.params import Dependency
from litestar.testing import TestClient
from typing_extensions import Annotated

from injection import DeclarativeContainer, Provide, inject, providers
# Shorthand for ``Dependency(skip_validation=True)``: calling it with no
# arguments yields the marker placed inside ``Annotated[...]`` on every
# injected handler parameter below.
_NoValidationDependency = partial(Dependency, skip_validation=True)
class Redis:
    """Minimal stand-in for a Redis client used by the example.

    It only records the connection settings it is given; nothing in this
    class opens a connection.
    """

    def __init__(self, *, url: str, port: int) -> None:
        """Store the connection settings.

        Args:
            url: Base address of the server, e.g. ``"redis://..."``.
            port: Port number; joined to ``url`` with a colon to build ``uri``.
        """
        self.url = url
        self.port = port
        self.uri = f"{url}:{port}"

    def get(self, key: Any) -> Any:
        """Return ``key`` unchanged (stub in place of a real lookup)."""
        return key
class Container(DeclarativeContainer):
    """Container declaring the providers that the handlers below inject."""

    # ``Redis`` wrapped in a Singleton provider with fixed keyword arguments.
    redis = providers.Singleton(Redis, url="redis://...", port=9873)
    # Constant value wrapped in an Object provider.
    num = providers.Object(9402)
@get("/some_resource/{redis_key:int}", status_code=200)
@inject
async def litestar_endpoint(
    redis_key: int,
    num: Annotated[int, _NoValidationDependency()] = Provide[Container.num],
    redis: Annotated[Redis, _NoValidationDependency()] = Provide[Container.redis],
) -> Dict[str, Any]:
    """Look up ``redis_key`` through the injected ``redis`` and echo ``num``.

    ``redis_key`` comes from the URL path; ``num`` and ``redis`` are supplied
    by ``@inject`` from ``Container``.
    """
    return {"key": redis.get(redis_key), "num2": num}
@get("/num_endpoint", status_code=200)
@inject
async def litestar_endpoint_object_provider(
    num: Annotated[int, _NoValidationDependency()] = Provide[Container.num],
) -> Dict[str, Any]:
    """Return the value injected from ``Container.num`` under ``detail``."""
    payload: Dict[str, Any] = {"detail": num}
    return payload
# Both example handlers, registered on the application exercised by the tests.
_handlers = [litestar_endpoint, litestar_endpoint_object_provider]
app = Litestar(route_handlers=_handlers)
##################################################### TESTS ############################################################
@pytest.fixture(scope="session")
def test_client() -> Iterator[TestClient[Any]]:
    """Yield one ``TestClient`` bound to ``app`` for the whole test session.

    The client is entered as a context manager so that it is properly opened
    before the first request and closed again once the session finishes,
    instead of being created and never released.
    """
    with TestClient(app=app) as client:
        yield client
def test_num_endpoint(test_client: TestClient[Any]) -> None:
    """``/num_endpoint`` answers 200 with the container's ``num`` value."""
    expected_body = {"detail": 9402}
    reply = test_client.get("/num_endpoint")
    assert reply.status_code == 200
    assert reply.json() == expected_body
def test_some_resource_endpoint(test_client: TestClient[Any]) -> None:
    """``/some_resource/<key>`` echoes the key and the container's ``num``."""
    redis_key = Random().randint(0, 100)  # noqa: S311
    expected_body = {"key": redis_key, "num2": 9402}
    reply = test_client.get(f"/some_resource/{redis_key}")
    assert reply.status_code == 200
    assert reply.json() == expected_body
def test_overriding(test_client: TestClient[Any]) -> None:
    """Overridden providers replace both ``redis`` and ``num`` in the handler."""
    key = 192342526
    replacement_num = -2999999999
    # The stub's ``get`` ignores its argument and always returns ``key``.
    redis_stub = Mock(get=lambda _: key)
    replacements = {"redis": redis_stub, "num": replacement_num}

    with Container.override_providers(replacements):
        reply = test_client.get(f"/some_resource/{key}")

    assert reply.status_code == 200
    assert reply.json() == {"key": key, "num2": replacement_num}