Source code for brunns.row.rowwrapper

# encoding=utf-8
import logging
import re
from collections import OrderedDict
from typing import Any, Iterable, Mapping, Sequence, Tuple, Union

# Prefer dataclasses.make_dataclass for building the row types.
try:
    from dataclasses import make_dataclass
except ImportError:  # pragma: no cover
    # Fallback for interpreters without the dataclasses module: namedtuple is
    # bound under the same name, so callers use make_dataclass(name, fields)
    # either way.
    from collections import namedtuple as make_dataclass

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class RowWrapper:
    """
    Build lightweight row tuples for DB API and csv.DictReader rows.

    Inspired by Greg Stein's lovely
    `dtuple module <https://code.activestate.com/recipes/81252-using-dtuple-for-flexible-query-result-access>`_,
    which I can't find online any longer, isn't on pypi, and doesn't support Python 3 without some fixes.

    Initializer takes an iterable of column descriptions, either names, or sequences of names and
    other metadata (which will be ignored). For instance, it's happy to take a DB API cursor
    description, or a csv.DictReader's fieldnames property.

    Provides a wrap(row) method for wrapping rows.

    Characters which are illegal in identifiers (for example " ", "-", "." or "(") will be replaced
    with "_"s when building the row tuples, names starting with a digit are prefixed with "a_", and
    duplicate identifiers get a numeric suffix.

    >>> cursor = conn.cursor()
    >>> cursor.execute("SELECT kind, rating FROM sausages ORDER BY rating DESC;")
    >>> wrapper = RowWrapper(cursor.description)
    >>> rows = [wrapper.wrap(row) for row in cursor.fetchall()]

    >>> reader = csv.DictReader(csv_file)
    >>> wrapper = RowWrapper(reader.fieldnames)
    >>> rows = [wrapper.wrap(row) for row in reader]
    """

    def __init__(
        self,
        description: Iterable[Union[str, Sequence[Any]]],
        force_lower_case_ids: bool = False,
    ) -> None:
        """Build the row type for the given column descriptions.

        :param description: Column names, or sequences whose first item is the column name
            (e.g. a DB API cursor description). Any iterable is accepted, including an
            empty one or a generator.
        :param force_lower_case_ids: If true, lower-case the generated attribute names.
        """
        # Decide str-vs-sequence per column and consume the iterable exactly once, rather
        # than indexing description[0], which fails for empty input and for generators.
        column_names = [col if isinstance(col, str) else col[0] for col in description]
        # Ordered mapping of attribute identifier -> original column name.
        self.ids_and_column_names = self._ids_and_column_names(
            column_names, force_lower_case=force_lower_case_ids
        )
        self.dataclass = make_dataclass("RowTuple", list(self.ids_and_column_names))

    @staticmethod
    def _ids_and_column_names(names, force_lower_case=False):
        """Ensure all column names are unique identifiers.

        Returns an OrderedDict mapping each generated identifier to its original column
        name, in column order.
        """
        fixed = OrderedDict()
        for name in names:
            identifier = RowWrapper._make_identifier(name)
            if force_lower_case:
                identifier = identifier.lower()
            # Keep bumping the suffix until the identifier no longer clashes with an
            # earlier column.
            while identifier in fixed:
                identifier = RowWrapper._increment_numeric_suffix(identifier)
            fixed[identifier] = name
        return fixed

    @staticmethod
    def _make_identifier(string):
        """Attempt to convert string into a valid identifier by replacing invalid characters
        with "_"s, and prefixing with "a_" if necessary."""
        # A character is kept only if it may legally follow the first character of an
        # identifier; everything else (spaces, punctuation, quotes, brackets...) becomes "_".
        string = "".join(ch if ("a" + ch).isidentifier() else "_" for ch in string)
        # Identifiers cannot start with a digit.
        if re.match(r"^\d", string):
            string = "a_{0}".format(string)
        return string

    @staticmethod
    def _increment_numeric_suffix(s):
        """Increment (or add) numeric suffix to identifier."""
        bumped, count = re.subn(r"\d+$", lambda n: str(int(n.group(0)) + 1), s)
        # No trailing digits to increment: start numbering at 2.
        return bumped if count else s + "_2"

    def wrap(self, row: Union[Mapping[str, Any], Sequence[Any]]) -> Any:
        """Return row tuple for row.

        Mappings are looked up by original column name; any other row is matched to the
        columns by position.
        """
        if isinstance(row, Mapping):
            values = {
                ident: row[column_name]
                for ident, column_name in self.ids_and_column_names.items()
            }
        else:
            values = dict(zip(self.ids_and_column_names, row))
        return self.dataclass(**values)

    def wrap_all(
        self, rows: Iterable[Union[Mapping[str, Any], Sequence[Any]]]
    ) -> Iterable[Any]:
        """Return row tuple for each row in rows (lazily, as a generator)."""
        return (self.wrap(r) for r in rows)

    def __call__(self, row):
        """Shorthand for wrap(row)."""
        return self.wrap(row)