# Source code for sql_executor.base

"""Backend-agnostic base for stored-procedure / SQL executors.

Template Method pattern: this class owns everything that's the same
across DB backends (cursor lifecycle, commit/rollback, fetch-to-dict,
chunked streaming, error wrapping). A subclass only has to implement:

  - _connect()              -> open a DB-API 2.0 connection
  - _call_procedure_sql()   -> dialect-specific "call this SP" string

Everything else (call_procedure, execute, stream, stream_procedure) is
inherited for free. See sqlserver.py for a ~15-line concrete example.
"""

from __future__ import annotations

import abc
import logging
from contextlib import contextmanager
from typing import Any, Dict, Iterator, List, Sequence

logger = logging.getLogger(__name__)


class DatabaseError(Exception):
    """Signal that running SQL or a stored procedure failed.

    The executors raise this on every failure (chaining the driver's own
    exception as ``__cause__``) rather than returning None/False, so
    callers should catch it instead of inspecting return values.
    """
def _rows_to_dicts(cursor) -> List[Dict[str, Any]]:
    """Fetch all remaining rows from *cursor* as dicts keyed by lower-cased column name.

    Returns an empty list when the last statement produced no result set
    (``cursor.description is None`` per DB-API 2.0, e.g. an UPDATE run with
    ``fetch=True``) instead of failing with ``TypeError`` on ``None``.

    Note: column names that differ only by case collapse into a single key;
    the right-most such column wins.
    """
    description = cursor.description
    if description is None:
        return []
    columns = [col[0].lower() for col in description]
    return [dict(zip(columns, row)) for row in cursor.fetchall()]
class BaseSqlExecutor(abc.ABC):
    """Template-method base class for DB-API 2.0 SQL / stored-procedure executors.

    Subclasses implement :meth:`_connect` and :meth:`_call_procedure_sql`;
    this class supplies the connection/cursor lifecycle, commit/rollback,
    row-to-dict conversion, chunked streaming and error wrapping.

    Every public call opens its own connection via :meth:`_connect` and
    closes it before returning (for the streaming methods: once the
    generator is exhausted or closed). Driver failures are logged and
    re-raised as :class:`DatabaseError`.
    """

    # Subclasses should override with their driver's exception type,
    # e.g. `pyodbc.Error`, `psycopg2.Error`, `cx_Oracle.Error`.
    # Left as plain Exception here so a forgetful subclass still works,
    # just with a less precise catch.
    driver_error: type = Exception

    # ---- subclasses must implement -------------------------------------

    @abc.abstractmethod
    def _connect(self):
        """Return a new DB-API 2.0 connection."""

    @abc.abstractmethod
    def _call_procedure_sql(self, sp_name: str, params: Sequence[Any]) -> str:
        """Return the dialect-specific SQL string to call a stored procedure.

        *params* is always passed as a tuple by this class, so an
        implementation may iterate it (e.g. to emit one placeholder per
        argument) without exhausting the values that are bound afterwards.
        """

    # ---- shared cursor lifecycle ----------------------------------------

    @staticmethod
    def _quietly(action, what: str) -> None:
        """Call *action*, logging instead of raising if it fails.

        Used for cleanup steps (rollback, cursor close) so that a failure
        there cannot replace an exception that is already propagating.
        """
        try:
            action()
        except Exception:
            logger.exception("Ignoring failure during %s", what)

    @contextmanager
    def _get_cursor(self):
        """Yield a cursor on a fresh connection; commit if the body succeeds.

        On any exception -- including ``GeneratorExit`` when a streaming
        consumer stops early, or ``KeyboardInterrupt`` -- the transaction is
        rolled back explicitly (best-effort) before the exception is
        re-raised. The cursor and the connection are closed in every case.
        """
        conn = self._connect()
        cursor = None
        try:
            cursor = conn.cursor()
            yield cursor
            conn.commit()
        except BaseException:
            # A failing rollback must not mask the original error.
            self._quietly(conn.rollback, "rollback")
            raise
        finally:
            if cursor is not None:
                self._quietly(cursor.close, "cursor close")
            conn.close()

    @staticmethod
    def _execute_statement(cursor, sql: str, params: Sequence[Any]) -> None:
        """Execute *sql* on *cursor*, passing *params* only when non-empty."""
        # An empty parameter sequence is omitted entirely, so the driver sees
        # a plain ``execute(sql)`` call for parameterless statements.
        if params:
            cursor.execute(sql, params)
        else:
            cursor.execute(sql)

    def _run(self, sql: str, params: Sequence[Any], fetch: bool):
        """Execute one statement in its own transaction.

        Returns the result rows as dicts when *fetch* is true, otherwise
        ``True``. Driver errors are logged and re-raised as DatabaseError.
        """
        try:
            with self._get_cursor() as cursor:
                self._execute_statement(cursor, sql, params)
                # Rows are fetched inside the block, i.e. before commit and
                # before the connection is closed.
                return _rows_to_dicts(cursor) if fetch else True
        except self.driver_error as db_ex:
            logger.exception("DB error executing: %s", sql)
            raise DatabaseError(str(db_ex)) from db_ex

    # ---- public API ----------------------------------------------------

    def call_procedure(
        self, sp_name: str, params: Sequence[Any] = (), fetch: bool = False
    ):
        """Call stored procedure *sp_name* with positional *params*.

        Returns a list of row dicts (lower-cased column names) when *fetch*
        is true, otherwise ``True``.

        Raises:
            DatabaseError: if the driver raises ``driver_error``.
        """
        # Materialise once: the SQL builder and the driver must see the same
        # values even if the caller passed a one-shot iterator.
        params = tuple(params)
        sql = self._call_procedure_sql(sp_name, params)
        return self._run(sql, params, fetch)

    def execute(self, sql: str, params: Sequence[Any] = (), fetch: bool = False):
        """Run raw SQL (SELECT/INSERT/UPDATE/DELETE).

        Same return value and error behaviour as :meth:`call_procedure`.
        """
        return self._run(sql, tuple(params), fetch)

    def stream(
        self,
        sql: str,
        params: Sequence[Any] = (),
        chunk_size: int = 500,
    ) -> Iterator[Dict[str, Any]]:
        """Yield rows one by one, fetching *chunk_size* rows per round trip.

        Meant for big SELECTs or procedures returning many rows. The
        connection stays open until the generator is exhausted or closed;
        the transaction is committed only after the last row has been
        consumed, and rolled back if iteration fails or is closed early.
        A statement that returns no result set yields nothing.

        Raises:
            ValueError: if *chunk_size* is less than 1 (raised on first
                iteration, before any connection is opened).
            DatabaseError: if the driver raises ``driver_error``.
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size!r}")
        try:
            with self._get_cursor() as cursor:
                cursor.arraysize = chunk_size
                self._execute_statement(cursor, sql, params)
                if cursor.description is None:
                    # DB-API 2.0: no result set (e.g. DDL/DML) -> no rows.
                    return
                columns = [col[0].lower() for col in cursor.description]
                while True:
                    rows = cursor.fetchmany(chunk_size)
                    if not rows:
                        break
                    for row in rows:
                        yield dict(zip(columns, row))
        except self.driver_error as db_ex:
            logger.exception("DB error streaming: %s", sql)
            raise DatabaseError(str(db_ex)) from db_ex

    def stream_procedure(
        self, sp_name: str, params: Sequence[Any] = (), chunk_size: int = 500
    ) -> Iterator[Dict[str, Any]]:
        """Stream the rows returned by stored procedure *sp_name*.

        Same streaming, transaction and error behaviour as :meth:`stream`.
        """
        # Materialise once so _call_procedure_sql cannot exhaust an iterator
        # before the values are bound.
        params = tuple(params)
        sql = self._call_procedure_sql(sp_name, params)
        yield from self.stream(sql, params, chunk_size)