Source code for pyathena.formatter
# -*- coding: utf-8 -*-
from __future__ import annotations

import logging
import textwrap
import uuid
from abc import ABCMeta, abstractmethod
from copy import deepcopy
from datetime import date, datetime, timezone
from decimal import Decimal
from typing import Any, Callable, Dict, Optional, Tuple, Type

from pyathena.error import ProgrammingError
from pyathena.model import AthenaCompression, AthenaFileFormat
_logger = logging.getLogger(__name__) # type: ignore
[docs]
class Formatter(metaclass=ABCMeta):
    """Abstract base class for formatting Python values for SQL queries.

    Formatters handle the conversion of Python objects to SQL-compatible
    string representations for use in parameterized queries. They ensure
    proper escaping and formatting of values based on their types.

    This class provides a framework for mapping Python types to formatting
    functions and handles the formatting process during query preparation.

    Attributes:
        mappings: Dictionary mapping Python types to formatting functions.
        default: Default formatting function for unmapped types.
    """

    def __init__(
        self,
        mappings: Dict[Type[Any], Callable[[Formatter, Callable[[str], str], Any], Any]],
        default: Optional[Callable[[Formatter, Callable[[str], str], Any], Any]] = None,
    ) -> None:
        """Store the type-to-formatter table and the optional fallback.

        Args:
            mappings: Dictionary mapping Python types to formatting functions.
                The dictionary is stored as-is (not copied), so ``set`` and
                ``update`` mutate the object passed in.
            default: Formatting function returned by ``get`` for types that
                have no entry in ``mappings``. ``None`` means "no fallback".
        """
        self._mappings = mappings
        self._default = default

    @property
    def mappings(
        self,
    ) -> Dict[Type[Any], Callable[[Formatter, Callable[[str], str], Any], Any]]:
        """Get the current parameter formatting mappings.

        Returns:
            Dictionary mapping Python types to formatting functions.
        """
        return self._mappings

    def get(
        self, type_: Any
    ) -> Optional[Callable[[Formatter, Callable[[str], str], Any], Any]]:
        """Get the formatting function for a Python value.

        Args:
            type_: The Python *value* to find a formatter for. Despite the
                parameter name this is a value, not a type: the lookup key
                is ``type(type_)``.

        Returns:
            The formatting function registered for the value's exact type,
            or the default formatter (possibly ``None``) if there is none.
            Subclasses of a registered type do not match.
        """
        return self.mappings.get(type(type_), self._default)

    def set(
        self,
        type_: Type[Any],
        formatter: Callable[[Formatter, Callable[[str], str], Any], Any],
    ) -> None:
        """Register (or replace) the formatting function for ``type_``.

        Args:
            type_: The Python type to register.
            formatter: Callable invoked as ``formatter(self, escaper, value)``.
        """
        self.mappings[type_] = formatter

    def update(
        self, mappings: Dict[Type[Any], Callable[[Formatter, Callable[[str], str], Any], Any]]
    ) -> None:
        """Merge several type-to-formatter entries into the current mappings.

        Args:
            mappings: Entries to add; existing entries for the same type are
                overwritten.
        """
        self.mappings.update(mappings)

    @abstractmethod
    def format(self, operation: str, parameters: Optional[Dict[str, Any]] = None) -> str:
        """Render ``operation`` with ``parameters`` substituted in.

        Args:
            operation: SQL text containing parameter placeholders.
            parameters: Values to substitute, keyed by placeholder name.

        Returns:
            The SQL text ready for execution.
        """
        raise NotImplementedError  # pragma: no cover

    @staticmethod
    def wrap_unload(
        operation: str,
        s3_staging_dir: str,
        format_: str = AthenaFileFormat.FILE_FORMAT_PARQUET,
        compression: str = AthenaCompression.COMPRESSION_SNAPPY,
    ) -> Tuple[str, Optional[str]]:
        """Wrap a SELECT query with UNLOAD statement for high-performance result retrieval.

        Transforms SELECT or WITH queries into UNLOAD statements that export results
        directly to S3 in optimized formats (Parquet, ORC) with compression. This
        approach is significantly faster than standard CSV-based result retrieval
        for large datasets and preserves data types more accurately.

        Args:
            operation: SQL query to wrap. Must be a SELECT or WITH statement.
            s3_staging_dir: Base S3 directory for storing UNLOAD results. A
                trailing ``/`` is appended when it is missing.
            format_: Output file format. Defaults to Parquet for optimal performance.
            compression: Compression algorithm. Defaults to Snappy for balanced
                compression ratio and speed.

        Returns:
            Tuple containing:
                - Modified UNLOAD query string (the original ``operation``,
                  unchanged, if it is not SELECT/WITH)
                - S3 location where results will be stored (None if not SELECT/WITH)

        Raises:
            ProgrammingError: If ``operation`` is ``None``, empty or whitespace only.

        Example:
            >>> query = "SELECT * FROM sales WHERE year = 2023"
            >>> unload_query, location = Formatter.wrap_unload(
            ...     query, "s3://my-bucket/results/"
            ... )
            >>> print(unload_query)
            <BLANKLINE>
            UNLOAD (
            	SELECT * FROM sales WHERE year = 2023
            )
            TO 's3://my-bucket/results/unload/20231215/<uuid>/'
            WITH (
            	format = 'PARQUET',
            	compression = 'SNAPPY'
            )

        Note:
            Only SELECT and WITH statements are wrapped. Other statement types
            are returned unchanged with location=None.
        """
        if not operation or not operation.strip():
            raise ProgrammingError("Query is none or empty.")

        query = operation.strip()
        if not query.upper().startswith(("SELECT", "WITH")):
            # Not a query that UNLOAD can wrap: hand the input back untouched.
            return operation, None

        # Make sure "unload/" becomes a path component of its own instead of
        # being glued onto the last component of the staging directory.
        base_dir = s3_staging_dir
        if base_dir and not base_dir.endswith("/"):
            base_dir = f"{base_dir}/"

        today = datetime.now(timezone.utc).strftime("%Y%m%d")
        location = f"{base_dir}unload/{today}/{uuid.uuid4()}/"

        # The statement is assembled from fixed lines rather than by running
        # textwrap.dedent() over an interpolated template: dedent would also
        # look at (and rewrite) the lines of the caller's query, stripping
        # leading whitespace or blanking whitespace-only lines inside
        # multi-line SQL and string literals.
        unload = (
            "\n"
            "UNLOAD (\n"
            f"\t{query}\n"
            ")\n"
            f"TO '{location}'\n"
            "WITH (\n"
            f"\tformat = '{format_}',\n"
            f"\tcompression = '{compression}'\n"
            ")\n"
        )
        return unload, location
def _escape_presto(val: str) -> str:
    """Quote ``val`` as a SQL string literal, doubling embedded single quotes.

    Args:
        val: Raw string value.

    Returns:
        ``val`` wrapped in single quotes with every ``'`` written as ``''``.
    """
    doubled = val.replace("'", "''")
    return "'" + doubled + "'"
def _escape_hive(val: str) -> str:
    """Quote ``val`` as a string literal using backslash escapes.

    Backslash, single quote, carriage return, newline and tab are each
    replaced by a two-character backslash sequence; everything else is
    copied through.

    Args:
        val: Raw string value.

    Returns:
        The escaped text wrapped in single quotes.
    """
    # One pass over the input: each source character is mapped on its own,
    # so backslashes introduced by one replacement are never escaped again.
    replacements = str.maketrans(
        {
            "\\": "\\\\",
            "'": "\\'",
            "\r": "\\r",
            "\n": "\\n",
            "\t": "\\t",
        }
    )
    return "'" + val.translate(replacements) + "'"
def _format_none(formatter: Formatter, escaper: Callable[[str], str], val: Any) -> Any:
    """Render ``None`` as the SQL ``null`` keyword.

    All three arguments exist only to match the common formatter-callable
    signature; none of them is inspected.
    """
    sql_null = "null"
    return sql_null
def _format_default(formatter: Formatter, escaper: Callable[[str], str], val: Any) -> Any:
    """Pass ``val`` through untouched.

    No quoting or conversion happens here; the caller receives the very same
    object it supplied. ``formatter`` and ``escaper`` are unused.
    """
    unchanged = val
    return unchanged
def _format_date(formatter: Formatter, escaper: Callable[[str], str], val: Any) -> Any:
    """Render ``val`` as a ``DATE '<year>-<month>-<day>'`` literal.

    ``val`` is formatted with the ``%Y-%m-%d`` format spec; ``formatter`` and
    ``escaper`` are unused.
    """
    day = format(val, "%Y-%m-%d")
    return f"DATE '{day}'"
def _format_datetime(formatter: Formatter, escaper: Callable[[str], str], val: Any) -> Any:
    """Render ``val`` as a ``TIMESTAMP '...'`` literal with millisecond precision.

    ``formatter`` and ``escaper`` are unused.
    """
    with_micros = val.strftime("%Y-%m-%d %H:%M:%S.%f")
    # %f yields six digits; dropping the last three keeps milliseconds
    # (truncated, not rounded).
    with_millis = with_micros[:-3]
    return "TIMESTAMP '" + with_millis + "'"
def _format_bool(formatter: Formatter, escaper: Callable[[str], str], val: Any) -> Any:
    """Render ``val`` through ``str()``, i.e. ``True`` / ``False`` for booleans.

    ``formatter`` and ``escaper`` are unused.
    """
    return f"{val!s}"
def _format_str(formatter: Formatter, escaper: Callable[[str], str], val: Any) -> Any:
    """Render a string by delegating entirely to ``escaper``.

    Args:
        formatter: Unused; present for the common formatter signature.
        escaper: Callable that quotes and escapes a raw string.
        val: The string to render.

    Returns:
        Whatever ``escaper(val)`` returns.
    """
    quoted = escaper(val)
    return quoted
def _format_seq(formatter: Formatter, escaper: Callable[[str], str], val: Any) -> Any:
    """Render an iterable as a parenthesised, comma-separated list.

    Each element is formatted with the callable that ``formatter.get``
    returns for it, then coerced to text.

    Args:
        formatter: Object whose ``get(element)`` returns the element's
            formatting callable (or a falsy value when there is none).
        escaper: String escaper handed on to every element formatter.
        val: Iterable of elements to render.

    Returns:
        A string such as ``(1, 2.500000, 'a')``; ``()`` for an empty iterable.

    Raises:
        TypeError: If ``formatter.get`` yields no callable for an element.
    """
    parts = []
    for item in val:
        render = formatter.get(item)
        if not render:
            raise TypeError(f"{type(item)} is not defined formatter.")
        piece = render(formatter, escaper, item)
        if isinstance(piece, str):
            # Already text: use as-is.
            parts.append(piece)
        elif isinstance(piece, (float, Decimal)):
            # Fixed-point notation keeps exponent forms out of the SQL.
            parts.append(f"{piece:f}")
        else:
            parts.append(f"{piece}")
    return "(" + ", ".join(parts) + ")"
def _format_decimal(formatter: Formatter, escaper: Callable[[str], str], val: Any) -> Any:
    """Render ``val`` as a ``DECIMAL <escaped>`` literal.

    ``val`` is first written in fixed-point notation (format spec ``f``),
    then passed through ``escaper``. ``formatter`` is unused.
    """
    fixed_point = format(val, "f")
    return f"DECIMAL {escaper(fixed_point)}"
# Built-in table of Python type -> formatting callable. DefaultParameterFormatter
# deep-copies it in its constructor. Formatter.get() looks entries up by the
# exact type of the value, so each type that needs handling has its own key.
_DEFAULT_FORMATTERS: Dict[Type[Any], Callable[[Formatter, Callable[[str], str], Any], Any]] = {
    type(None): _format_none,
    # date and datetime are keyed separately: a datetime value resolves to the
    # datetime entry, not to the entry of its base class date.
    date: _format_date,
    datetime: _format_datetime,
    # int and float are passed through unchanged.
    int: _format_default,
    float: _format_default,
    Decimal: _format_decimal,
    # bool has its own entry, so bool values do not use the int entry.
    bool: _format_bool,
    str: _format_str,
    # All three sequence-like containers share one renderer.
    list: _format_seq,
    set: _format_seq,
    tuple: _format_seq,
}
[docs]
class DefaultParameterFormatter(Formatter):
    """Default implementation of the Formatter for SQL parameter formatting.

    This formatter provides standard formatting for common Python types used
    in SQL parameters. It handles proper escaping and quoting to prevent
    SQL injection and ensure valid SQL syntax.

    Supported types (matched by exact type):
        - None: Rendered as ``null``
        - Strings: Escaped and quoted
        - Numbers: int, float, Decimal
        - Dates and times: date, datetime
        - Booleans: Rendered as ``True`` / ``False``
        - Sequences: list, tuple, set (for IN clauses)

    Example:
        >>> formatter = DefaultParameterFormatter()
        >>> sql = formatter.format(
        ...     "SELECT * FROM users WHERE name = %(name)s AND age > %(age)s",
        ...     {"name": "John's Data", "age": 25}
        ... )
        >>> print(sql)
        SELECT * FROM users WHERE name = 'John''s Data' AND age > 25
    """

    def __init__(self) -> None:
        """Start from a private copy of the built-in mappings, with no fallback."""
        super().__init__(mappings=deepcopy(_DEFAULT_FORMATTERS), default=None)

    def format(self, operation: str, parameters: Optional[Dict[str, Any]] = None) -> str:
        """Substitute ``parameters`` into ``operation`` using ``%``-style placeholders.

        Args:
            operation: SQL text, e.g. with ``%(name)s`` placeholders.
            parameters: Mapping of placeholder name to Python value. ``None``
                skips substitution entirely; an empty value still applies
                the ``%`` operator with an empty mapping.

        Returns:
            The stripped SQL text with every value rendered as a literal.

        Raises:
            ProgrammingError: If ``operation`` is empty or whitespace only, or
                if ``parameters`` is non-empty and not a ``dict``.
            TypeError: If a value's type has no registered formatter.
        """
        if not operation or not operation.strip():
            raise ProgrammingError("Query is none or empty.")

        statement = operation.strip()

        # Statements with these leading keywords get quote-doubling escapes;
        # everything else gets backslash escapes.
        # NOTE(review): other statement kinds (e.g. DELETE) fall through to the
        # backslash escaper — confirm that is intended.
        quote_doubling_prefixes = ("SELECT", "WITH", "INSERT", "UPDATE", "MERGE")
        if statement.upper().startswith(quote_doubling_prefixes):
            escaper = _escape_presto
        else:
            escaper = _escape_hive

        if parameters is None:
            return statement

        rendered: Dict[str, Any] = {}
        if parameters:
            if not isinstance(parameters, dict):
                raise ProgrammingError(
                    f"Unsupported parameter (Support for dict only): {parameters}"
                )
            for key, value in parameters.items():
                render = self.get(value)
                if not render:
                    raise TypeError(f"{type(value)} is not defined formatter.")
                rendered[key] = render(self, escaper, value)
        return (statement % rendered).strip()