64 lines
1.5 KiB
Python
64 lines
1.5 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from collections.abc import Iterator
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from fastapi.testclient import TestClient
|
||
|
|
from sqlalchemy import create_engine, event
|
||
|
|
from sqlalchemy.orm import Session, sessionmaker
|
||
|
|
from sqlalchemy.pool import StaticPool
|
||
|
|
|
||
|
|
from quartermaster.db import get_session
|
||
|
|
from quartermaster.main import create_app
|
||
|
|
from quartermaster.models import Base
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def engine():
    """Yield a fresh in-memory SQLite engine with the ORM schema created.

    ``StaticPool`` hands out the same single DBAPI connection on every
    checkout, so the ``:memory:`` database is shared by all sessions made
    from this engine during one test. ``check_same_thread=False`` lets that
    connection be used from threads other than the one that opened it.
    The engine is disposed after the test finishes.
    """
    test_engine = create_engine(
        "sqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
        future=True,
    )

    def _enable_foreign_keys(dbapi_connection, _connection_record):  # type: ignore[no-untyped-def]
        # SQLite leaves FK enforcement off by default; it has to be switched
        # on for every new DBAPI connection.
        cursor = dbapi_connection.cursor()
        cursor.execute("PRAGMA foreign_keys=ON")
        cursor.close()

    # Register before create_all() so the very first connection gets the pragma.
    event.listen(test_engine, "connect", _enable_foreign_keys)

    Base.metadata.create_all(test_engine)
    yield test_engine
    test_engine.dispose()
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def session_factory(engine):
    """Return a ``sessionmaker`` bound to the per-test ``engine``.

    ``autoflush=False``: queries do not automatically flush pending changes.
    ``expire_on_commit=False``: loaded attribute values survive ``commit()``,
    so objects remain readable in assertions after committing.
    """
    factory = sessionmaker(
        bind=engine,
        expire_on_commit=False,
        autoflush=False,
    )
    return factory
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def db(session_factory) -> Iterator[Session]:
    """Yield a ``Session`` for direct database access inside a test.

    Leaving the ``with`` block calls ``Session.close()``, so the session is
    released once the test is done.
    """
    with session_factory() as session:
        yield session
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def client(session_factory) -> Iterator[TestClient]:
    """Yield a ``TestClient`` for a fresh app wired to the test database.

    The app's ``get_session`` dependency is overridden so that each request
    depending on it receives its own session from ``session_factory``. That
    session is closed when FastAPI tears the dependency down. Using the
    client as a context manager runs the app's startup/shutdown handlers
    around the test.
    """
    app = create_app()

    def _session_for_request() -> Iterator[Session]:
        with session_factory() as request_session:
            yield request_session

    app.dependency_overrides[get_session] = _session_for_request

    with TestClient(app) as test_client:
        yield test_client
|