This example shows how to use Dependency Injector
with FastAPI and SQLAlchemy.
The source code is available on GitHub.
Thanks to @ShvetsovYura for providing the initial example: FastAPI_DI_SqlAlchemy.
Application structure

The application has the following structure:
./
├── webapp/
│   ├── __init__.py
│   ├── application.py
│   ├── containers.py
│   ├── database.py
│   ├── endpoints.py
│   ├── models.py
│   ├── repositories.py
│   ├── services.py
│   └── tests.py
├── config.yml
├── docker-compose.yml
├── Dockerfile
└── requirements.txt

Application factory
The application factory creates the container, wires it with the endpoints module, creates the FastAPI app, and sets up the routes.
The application factory also creates the database if it does not exist.
Listing of webapp/application.py:
"""Application module."""

from fastapi import FastAPI

from .containers import Container
from . import endpoints


def create_app() -> FastAPI:
    """Assemble and return the FastAPI application.

    Builds the dependency container, makes sure the database tables exist,
    exposes the container as ``app.container`` (tests use it to override
    providers) and registers the endpoint router.
    """
    container = Container()
    # ``db`` is a Singleton provider, so this is the same Database instance
    # the repositories will receive later.
    container.db().create_database()

    application = FastAPI()
    application.container = container
    application.include_router(endpoints.router)
    return application


app = create_app()

# Endpoints
Module endpoints contains example endpoints. The endpoints depend on the user service, which is injected using the Wiring feature. See webapp/endpoints.py:
"""Endpoints module."""

from typing import Annotated

from fastapi import APIRouter, Depends, Response, status
from dependency_injector.wiring import Provide, inject

from .containers import Container
from .repositories import NotFoundError
from .services import UserService

router = APIRouter()

# NOTE: handlers carry ``#`` comments rather than docstrings on purpose —
# FastAPI copies handler docstrings into the generated OpenAPI schema.


@router.get("/users")
@inject
def get_list(
    user_service: Annotated[UserService, Depends(Provide[Container.user_service])],
):
    # List every user known to the service.
    users = user_service.get_users()
    return users


@router.get("/users/{user_id}")
@inject
def get_by_id(
    user_id: int,
    user_service: Annotated[UserService, Depends(Provide[Container.user_service])],
):
    # Fetch one user; an unknown id becomes an empty 404 response.
    try:
        user = user_service.get_user_by_id(user_id)
    except NotFoundError:
        return Response(status_code=status.HTTP_404_NOT_FOUND)
    return user


@router.post("/users", status_code=status.HTTP_201_CREATED)
@inject
def add(
    user_service: Annotated[UserService, Depends(Provide[Container.user_service])],
):
    # Create a user; the service generates the e-mail and password.
    return user_service.create_user()


@router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
@inject
def remove(
    user_id: int,
    user_service: Annotated[UserService, Depends(Provide[Container.user_service])],
) -> Response:
    # Delete a user: 204 on success, 404 when the id does not exist.
    try:
        user_service.delete_user_by_id(user_id)
    except NotFoundError:
        code = status.HTTP_404_NOT_FOUND
    else:
        code = status.HTTP_204_NO_CONTENT
    return Response(status_code=code)


@router.get("/status")
def get_status():
    # Liveness probe; needs no injected dependencies.
    return {"status": "OK"}

# Container
The declarative container wires the example user service, the user repository, and the utility database class. See webapp/containers.py:
"""Containers module."""

from dependency_injector import containers, providers

from .database import Database
from .repositories import UserRepository
from .services import UserService


class Container(containers.DeclarativeContainer):
    """Dependency container for the web application.

    Wires the ``.endpoints`` module so its ``Provide[...]`` markers resolve
    against this container, and loads settings from ``config.yml``.
    """

    wiring_config = containers.WiringConfiguration(
        modules=[".endpoints"],
    )

    config = providers.Configuration(
        yaml_files=["config.yml"],
    )

    # A single Database (engine + session factory) shared by everything.
    db = providers.Singleton(
        Database,
        db_url=config.db.url,
    )

    # A fresh repository per injection. ``db.provided.session`` passes the
    # Database instance's ``session`` context-manager method, not a session.
    user_repository = providers.Factory(
        UserRepository,
        session_factory=db.provided.session,
    )

    # A fresh service per injection, built on top of a fresh repository.
    user_service = providers.Factory(
        UserService,
        user_repository=user_repository,
    )

# Services
Module services contains the example user service. See webapp/services.py:
"""Services module."""

from uuid import uuid4

from .repositories import UserRepository
from .models import User


class UserService:
    """Application-level user operations, delegating storage to a repository."""

    def __init__(self, user_repository: UserRepository) -> None:
        self._repository: UserRepository = user_repository

    def get_users(self) -> list[User]:
        """Return all users.

        The repository materializes the result with ``Query.all()``, so this
        is a list (not a lazy iterator, as the old annotation claimed).
        """
        return self._repository.get_all()

    def get_user_by_id(self, user_id: int) -> User:
        """Return the user with the given id.

        Raises:
            UserNotFoundError: propagated from the repository when no user
                has that id.
        """
        return self._repository.get_by_id(user_id)

    def create_user(self) -> User:
        """Create a demo user with a random e-mail and the password ``"pwd"``."""
        uid = uuid4()
        return self._repository.add(email=f"{uid}@email.com", password="pwd")

    def delete_user_by_id(self, user_id: int) -> None:
        """Delete the user with the given id.

        Raises:
            UserNotFoundError: propagated from the repository when no user
                has that id.
        """
        self._repository.delete_by_id(user_id)

# Repositories
Module repositories contains the example user repository. See webapp/repositories.py:
"""Repositories module."""

from contextlib import AbstractContextManager
from typing import Callable, Optional

from sqlalchemy.orm import Session

from .models import User


class UserRepository:
    """Data-access object for ``User`` rows.

    Each method opens its own session by calling ``session_factory()``, which
    must return a context manager yielding a SQLAlchemy ``Session``. Cleanup
    (closing, rollback on error) is the context manager's job.
    """

    def __init__(self, session_factory: Callable[..., AbstractContextManager[Session]]) -> None:
        self.session_factory = session_factory

    def get_all(self) -> list[User]:
        """Return every user as a list (``Query.all()`` materializes it)."""
        with self.session_factory() as session:
            return session.query(User).all()

    def get_by_id(self, user_id: int) -> User:
        """Return the user whose primary key is ``user_id``.

        Raises:
            UserNotFoundError: if no such user exists.
        """
        with self.session_factory() as session:
            user = session.query(User).filter(User.id == user_id).first()
        # Raise after leaving the session context: "not found" is an expected
        # outcome, and raising inside the ``with`` would make the session
        # context manager treat it as a failure (log a traceback, roll back).
        if user is None:
            raise UserNotFoundError(user_id)
        return user

    def add(self, email: str, password: str, is_active: bool = True) -> User:
        """Insert a new user, commit, and return it with generated fields loaded.

        NOTE(review): ``password`` is stored verbatim in ``hashed_password``;
        hashing, if required, must happen before this call.
        """
        with self.session_factory() as session:
            user = User(email=email, hashed_password=password, is_active=is_active)
            session.add(user)
            session.commit()
            # Reload so ``id`` and column defaults are populated while the
            # session is still open.
            session.refresh(user)
            return user

    def delete_by_id(self, user_id: int) -> None:
        """Delete the user whose primary key is ``user_id`` and commit.

        Raises:
            UserNotFoundError: if no such user exists.
        """
        with self.session_factory() as session:
            entity: Optional[User] = session.query(User).filter(User.id == user_id).first()
            if entity is not None:
                session.delete(entity)
                session.commit()
        # Same reasoning as in ``get_by_id``: report "not found" outside the
        # session context so it is not handled as a session failure.
        if entity is None:
            raise UserNotFoundError(user_id)


class NotFoundError(Exception):
    """Raised when a requested entity does not exist.

    Subclasses set ``entity_name`` to name the entity type in the message.
    The id that was looked up is kept on ``entity_id``.
    """

    # Default so the base class can be raised directly without an
    # AttributeError while formatting the message.
    entity_name: str = "Entity"

    def __init__(self, entity_id):
        self.entity_id = entity_id
        super().__init__(f"{self.entity_name} not found, id: {entity_id}")


class UserNotFoundError(NotFoundError):
    """Raised when no ``User`` has the requested id."""

    entity_name: str = "User"

# Models
Module models contains the example SQLAlchemy user model. See webapp/models.py:
"""Models module."""

from sqlalchemy import Column, String, Boolean, Integer

from .database import Base


class User(Base):
    """ORM mapping for rows of the ``users`` table."""

    __tablename__ = "users"

    id = Column(Integer, primary_key=True)
    email = Column(String, unique=True)
    hashed_password = Column(String)
    # Application-side default applied on INSERT when no value is given.
    is_active = Column(Boolean, default=True)

    def __repr__(self):
        # Renders as: <User(id=1, email="a@b", hashed_password="x", is_active=True)>
        parts = (
            f"id={self.id}",
            f'email="{self.email}"',
            f'hashed_password="{self.hashed_password}"',
            f"is_active={self.is_active}",
        )
        return "<User(" + ", ".join(parts) + ")>"

# Database
Module database defines the declarative base and a utility class that holds the engine and the session factory. See webapp/database.py:
"""Database module."""

from contextlib import contextmanager
from typing import Iterator
import logging

from sqlalchemy import create_engine, orm
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session

logger = logging.getLogger(__name__)

Base = declarative_base()


class Database:
    """Owns the SQLAlchemy engine and hands out sessions.

    Args:
        db_url: SQLAlchemy database URL passed to ``create_engine``.
        echo: log every emitted SQL statement (defaults to ``True``, the
            previous hard-coded behavior).
    """

    def __init__(self, db_url: str, echo: bool = True) -> None:
        self._engine = create_engine(db_url, echo=echo)
        # NOTE(review): ``scoped_session`` returns the same thread-local
        # Session on repeated calls within a thread, so nested use of
        # ``session()`` would share (and prematurely close) one Session.
        self._session_factory = orm.scoped_session(
            orm.sessionmaker(
                autocommit=False,
                autoflush=False,
                bind=self._engine,
            ),
        )

    def create_database(self) -> None:
        """Create the tables registered on ``Base.metadata``.

        ``create_all`` checks for existing tables first by default, so this is
        safe to call on every startup.
        """
        Base.metadata.create_all(self._engine)

    @contextmanager
    def session(self) -> Iterator[Session]:
        """Yield a session for the duration of a ``with`` block.

        Any exception raised inside the block is logged, the session is
        rolled back, and the exception is re-raised. The session is always
        closed on exit.

        The function is a generator wrapped by ``@contextmanager``, so its
        return type is ``Iterator[Session]``. The decorated attribute is what
        callers use as a context-manager factory.
        """
        session: Session = self._session_factory()
        try:
            yield session
        except Exception:
            logger.exception("Session rollback because of exception")
            session.rollback()
            raise
        finally:
            session.close()

# Tests
The tests use the Provider overriding feature to replace the repository with a mock. See webapp/tests.py:
"""Tests module."""

from unittest import mock

import pytest
from fastapi.testclient import TestClient

from .repositories import UserRepository, UserNotFoundError
from .models import User
from .application import app


def _repository_mock() -> mock.Mock:
    """Return a mock restricted to the ``UserRepository`` interface."""
    return mock.Mock(spec=UserRepository)


@pytest.fixture
def client():
    yield TestClient(app)


def test_get_list(client):
    repository = _repository_mock()
    repository.get_all.return_value = [
        User(id=1, email="test1@email.com", hashed_password="pwd", is_active=True),
        User(id=2, email="test2@email.com", hashed_password="pwd", is_active=False),
    ]

    with app.container.user_repository.override(repository):
        response = client.get("/users")

    assert response.status_code == 200
    assert response.json() == [
        {"id": 1, "email": "test1@email.com", "hashed_password": "pwd", "is_active": True},
        {"id": 2, "email": "test2@email.com", "hashed_password": "pwd", "is_active": False},
    ]


def test_get_by_id(client):
    repository = _repository_mock()
    repository.get_by_id.return_value = User(
        id=1,
        email="xyz@email.com",
        hashed_password="pwd",
        is_active=True,
    )

    with app.container.user_repository.override(repository):
        response = client.get("/users/1")

    assert response.status_code == 200
    assert response.json() == {
        "id": 1, "email": "xyz@email.com", "hashed_password": "pwd", "is_active": True,
    }
    repository.get_by_id.assert_called_once_with(1)


def test_get_by_id_404(client):
    repository = _repository_mock()
    repository.get_by_id.side_effect = UserNotFoundError(1)

    with app.container.user_repository.override(repository):
        response = client.get("/users/1")

    assert response.status_code == 404


@mock.patch("webapp.services.uuid4", return_value="xyz")
def test_add(_, client):
    # uuid4 is patched so the generated e-mail is deterministic.
    repository = _repository_mock()
    repository.add.return_value = User(
        id=1,
        email="xyz@email.com",
        hashed_password="pwd",
        is_active=True,
    )

    with app.container.user_repository.override(repository):
        response = client.post("/users")

    assert response.status_code == 201
    assert response.json() == {
        "id": 1, "email": "xyz@email.com", "hashed_password": "pwd", "is_active": True,
    }
    repository.add.assert_called_once_with(email="xyz@email.com", password="pwd")


def test_remove(client):
    repository = _repository_mock()

    with app.container.user_repository.override(repository):
        response = client.delete("/users/1")

    assert response.status_code == 204
    repository.delete_by_id.assert_called_once_with(1)


def test_remove_404(client):
    repository = _repository_mock()
    repository.delete_by_id.side_effect = UserNotFoundError(1)

    with app.container.user_repository.override(repository):
        response = client.delete("/users/1")

    assert response.status_code == 404


def test_status(client):
    response = client.get("/status")

    assert response.status_code == 200
    assert response.json() == {"status": "OK"}

# Sources
The source code is available on GitHub.
Sponsor the project on GitHub:
RetroSearch is an open source project built by @garambo | Open a GitHub Issue
Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo
HTML:
3.2
| Encoding:
UTF-8
| Version:
0.7.4