Source code for pyathena.converter

# -*- coding: utf-8 -*-
from __future__ import annotations

import binascii
import json
import logging
from abc import ABCMeta, abstractmethod
from copy import deepcopy
from datetime import date, datetime, time
from decimal import Decimal
from typing import Any, Callable, Dict, List, Optional, Type

from dateutil.tz import gettz

from pyathena.util import strtobool

# Module-level logger, named after this module via ``__name__``.
_logger = logging.getLogger(__name__)  # type: ignore


def _to_date(varchar_value: Optional[str]) -> Optional[date]:
    if varchar_value is None:
        return None
    return datetime.strptime(varchar_value, "%Y-%m-%d").date()


def _to_datetime(varchar_value: Optional[str]) -> Optional[datetime]:
    if varchar_value is None:
        return None
    return datetime.strptime(varchar_value, "%Y-%m-%d %H:%M:%S.%f")


def _to_datetime_with_tz(varchar_value: Optional[str]) -> Optional[datetime]:
    if varchar_value is None:
        return None
    datetime_, _, tz = varchar_value.rpartition(" ")
    return datetime.strptime(datetime_, "%Y-%m-%d %H:%M:%S.%f").replace(tzinfo=gettz(tz))


def _to_time(varchar_value: Optional[str]) -> Optional[time]:
    if varchar_value is None:
        return None
    return datetime.strptime(varchar_value, "%H:%M:%S.%f").time()


def _to_float(varchar_value: Optional[str]) -> Optional[float]:
    if varchar_value is None:
        return None
    return float(varchar_value)


def _to_int(varchar_value: Optional[str]) -> Optional[int]:
    if varchar_value is None:
        return None
    return int(varchar_value)


def _to_decimal(varchar_value: Optional[str]) -> Optional[Decimal]:
    if not varchar_value:
        return None
    return Decimal(varchar_value)


def _to_boolean(varchar_value: Optional[str]) -> Optional[bool]:
    if not varchar_value:
        return None
    return bool(strtobool(varchar_value))


def _to_binary(varchar_value: Optional[str]) -> Optional[bytes]:
    if varchar_value is None:
        return None
    return binascii.a2b_hex("".join(varchar_value.split(" ")))


def _to_json(varchar_value: Optional[str]) -> Optional[Any]:
    if varchar_value is None:
        return None
    return json.loads(varchar_value)


def _to_array(varchar_value: Optional[str]) -> Optional[List[Any]]:
    """Convert array data to Python list.

    Supports two formats:
    1. JSON format: '[1, 2, 3]' or '["a", "b", "c"]' (recommended)
    2. Athena native format: '[1, 2, 3]' (basic cases only)

    For complex arrays, use CAST(array_column AS JSON) in your SQL query.

    Args:
        varchar_value: String representation of array data

    Returns:
        List representation of array, or None if parsing fails
    """
    if varchar_value is None:
        return None

    # Quick check: if it doesn't look like an array, return None
    if not (varchar_value.startswith("[") and varchar_value.endswith("]")):
        return None

    # Optimize: Try JSON parsing first (most reliable)
    try:
        result = json.loads(varchar_value)
        if isinstance(result, list):
            return result
    except json.JSONDecodeError:
        # If JSON parsing fails, fall back to basic parsing for simple cases
        pass

    inner = varchar_value[1:-1].strip()
    if not inner:
        return []

    try:
        # For nested arrays, too complex for basic parsing
        if "[" in inner:
            # Contains nested arrays - too complex for basic parsing
            return None
        # Try native parsing (including struct arrays)
        return _parse_array_native(inner)
    except Exception:
        return None


def _to_map(varchar_value: Optional[str]) -> Optional[Dict[str, Any]]:
    """Convert map data to Python dictionary.

    Supports two formats:
    1. JSON format: '{"key1": "value1", "key2": "value2"}' (recommended)
    2. Athena native format: '{key1=value1, key2=value2}' (basic cases only)

    For complex maps, use CAST(map_column AS JSON) in your SQL query.

    Args:
        varchar_value: String representation of map data

    Returns:
        Dictionary representation of map, or None if parsing fails
    """
    if varchar_value is None:
        return None

    # Quick check: if it doesn't look like a map, return None
    if not (varchar_value.startswith("{") and varchar_value.endswith("}")):
        return None

    # Optimize: Check if it looks like JSON vs Athena native format
    # JSON objects typically have quoted keys: {"key": value}
    # Athena native format has unquoted keys: {key=value}
    inner_preview = varchar_value[1:10] if len(varchar_value) > 10 else varchar_value[1:-1]

    if '"' in inner_preview or varchar_value.startswith('{"'):
        # Likely JSON format - try JSON parsing
        try:
            result = json.loads(varchar_value)
            return result if isinstance(result, dict) else None
        except json.JSONDecodeError:
            # If JSON parsing fails, fall back to native format parsing
            pass

    inner = varchar_value[1:-1].strip()
    if not inner:
        return {}

    try:
        # MAP format is always key=value pairs
        # But for complex structures, return None to keep as string
        if any(char in inner for char in "()[]"):
            # Contains complex structures (arrays, structs), skip parsing
            return None
        return _parse_map_native(inner)
    except Exception:
        return None


def _to_struct(varchar_value: Optional[str]) -> Optional[Dict[str, Any]]:
    """Convert struct data to Python dictionary.

    Supports two formats:
    1. JSON format: '{"key": "value", "num": 123}' (recommended)
    2. Athena native format: '{key=value, num=123}' (basic cases only)

    For complex structs, use CAST(struct_column AS JSON) in your SQL query.

    Args:
        varchar_value: String representation of struct data

    Returns:
        Dictionary representation of struct, or None if parsing fails
    """
    if varchar_value is None:
        return None

    # Quick check: if it doesn't look like a struct, return None
    if not (varchar_value.startswith("{") and varchar_value.endswith("}")):
        return None

    # Optimize: Check if it looks like JSON vs Athena native format
    # JSON objects typically have quoted keys: {"key": value}
    # Athena native format has unquoted keys: {key=value}
    inner_preview = varchar_value[1:10] if len(varchar_value) > 10 else varchar_value[1:-1]

    if '"' in inner_preview or varchar_value.startswith('{"'):
        # Likely JSON format - try JSON parsing
        try:
            result = json.loads(varchar_value)
            return result if isinstance(result, dict) else None
        except json.JSONDecodeError:
            # If JSON parsing fails, fall back to native format parsing
            pass

    inner = varchar_value[1:-1].strip()
    if not inner:
        return {}

    try:
        if "=" in inner:
            # Named struct: {a=1, b=2}
            return _parse_named_struct(inner)
        # Unnamed struct: {Alice, 25}
        return _parse_unnamed_struct(inner)
    except Exception:
        return None


def _parse_array_native(inner: str) -> Optional[List[Any]]:
    """Parse the inside of a native-format array: 1, 2, 3 or {a, b}, {c, d}.

    Items are split on top-level commas. Brace-wrapped items are parsed as
    structs; items containing brackets, "=" or double quotes are dropped;
    everything else goes through scalar conversion.

    Args:
        inner: Array content with the surrounding brackets already removed.

    Returns:
        The list of parsed items, or None when no item could be kept.
    """
    parsed: List[Any] = []

    for item in _split_array_items(inner):
        if not item:
            continue

        trimmed = item.strip()
        if trimmed.startswith("{") and trimmed.endswith("}"):
            # Struct (ROW) element such as {a, b, c} or {key=value, ...};
            # elements that fail to parse are silently left out.
            struct_value = _to_struct(trimmed)
            if struct_value is not None:
                parsed.append(struct_value)
            continue

        # Safety check: skip nested arrays and anything with complex quoting.
        if any(char in item for char in '[]="'):
            continue

        parsed.append(_convert_value(item))

    return parsed or None


def _split_array_items(inner: str) -> List[str]:
    """Split array items by comma, respecting brace and bracket groupings.

    Args:
        inner: Interior content of array without brackets.

    Returns:
        List of item strings.
    """
    items = []
    current_item = ""
    brace_depth = 0
    bracket_depth = 0

    for char in inner:
        if char == "{":
            brace_depth += 1
        elif char == "}":
            brace_depth -= 1
        elif char == "[":
            bracket_depth += 1
        elif char == "]":
            bracket_depth -= 1
        elif char == "," and brace_depth == 0 and bracket_depth == 0:
            # Top-level comma - end current item
            items.append(current_item.strip())
            current_item = ""
            continue

        current_item += char

    # Add the last item
    if current_item.strip():
        items.append(current_item.strip())

    return items


def _parse_map_native(inner: str) -> Optional[Dict[str, Any]]:
    """Parse map native format: key1=value1, key2=value2.

    Args:
        inner: Interior content of map without braces.

    Returns:
        Dictionary with parsed key-value pairs, or None if no valid pairs found.
    """
    result = {}

    # Simple split by comma for basic cases
    pairs = [pair.strip() for pair in inner.split(",")]

    for pair in pairs:
        if "=" not in pair:
            continue

        key, value = pair.split("=", 1)
        key = key.strip()
        value = value.strip()

        # Skip pairs with special characters (safety check)
        if any(char in key for char in '{}="') or any(char in value for char in '{}="'):
            continue

        # Convert both key and value to appropriate types
        converted_key = _convert_value(key)
        converted_value = _convert_value(value)
        # Always use string keys for consistency with expected test behavior
        result[str(converted_key)] = converted_value

    return result if result else None


def _parse_named_struct(inner: str) -> Optional[Dict[str, Any]]:
    """Parse named struct format: a=1, b=2.

    Args:
        inner: Interior content of struct without braces.

    Returns:
        Dictionary with parsed key-value pairs, or None if no valid pairs found.
    """
    result = {}

    # Simple split by comma for basic cases
    pairs = [pair.strip() for pair in inner.split(",")]

    for pair in pairs:
        if "=" not in pair:
            continue

        key, value = pair.split("=", 1)
        key = key.strip()
        value = value.strip()

        # Skip pairs with special characters (safety check)
        if any(char in key for char in '{}="') or any(char in value for char in '{}="'):
            continue

        # Convert value to appropriate type
        result[key] = _convert_value(value)

    return result if result else None


def _parse_unnamed_struct(inner: str) -> Dict[str, Any]:
    """Parse the inside of an unnamed native-format struct: Alice, 25.

    Args:
        inner: Struct content with the surrounding braces already removed.

    Returns:
        A dictionary whose keys are the field positions as strings
        ("0", "1", ...) and whose values are the converted fields.
    """
    fields: Dict[str, Any] = {}
    for position, raw_field in enumerate(inner.split(",")):
        fields[str(position)] = _convert_value(raw_field.strip())
    return fields


def _convert_value(value: str) -> Any:
    """Convert string value to appropriate Python type.

    Args:
        value: String value to convert.

    Returns:
        Converted value as int, float, bool, None, or string.
    """
    if value.lower() == "null":
        return None
    if value.lower() == "true":
        return True
    if value.lower() == "false":
        return False
    if value.isdigit() or value.startswith("-") and value[1:].isdigit():
        return int(value)
    if "." in value and value.replace(".", "", 1).replace("-", "", 1).isdigit():
        return float(value)
    return value


def _to_default(varchar_value: Optional[str]) -> Optional[str]:
    return varchar_value


# Built-in lookup table from Athena type name to the function that converts
# the raw string value of that type. DefaultTypeConverter deep-copies this
# table, so per-instance changes do not modify it.
_DEFAULT_CONVERTERS: Dict[str, Callable[[Optional[str]], Optional[Any]]] = {
    "boolean": _to_boolean,
    # Integer types of every width share one converter.
    "tinyint": _to_int,
    "smallint": _to_int,
    "integer": _to_int,
    "bigint": _to_int,
    # Floating-point types.
    "float": _to_float,
    "real": _to_float,
    "double": _to_float,
    # Character types are passed through unchanged.
    "char": _to_default,
    "varchar": _to_default,
    "string": _to_default,
    # Date and time types.
    "timestamp": _to_datetime,
    "timestamp with time zone": _to_datetime_with_tz,
    "date": _to_date,
    "time": _to_time,
    "varbinary": _to_binary,
    # Complex types (JSON or basic native text form).
    "array": _to_array,
    "map": _to_map,
    "row": _to_struct,
    "decimal": _to_decimal,
    "json": _to_json,
}


class Converter(metaclass=ABCMeta):
    """Abstract base class for converting Athena data types to Python objects.

    Converters handle the transformation of string values returned by Athena
    into appropriate Python data types. Different cursor implementations may
    use different converters to optimize for their specific use cases.

    This class provides a framework for mapping Athena data type names to
    conversion functions and handles the conversion process during result
    set processing.

    Attributes:
        mappings: Dictionary mapping Athena type names to conversion functions.
        default: Default conversion function for unmapped types.
        types: Optional dictionary mapping type names to Python type objects.
    """

    def __init__(
        self,
        mappings: Dict[str, Callable[[Optional[str]], Optional[Any]]],
        default: Callable[[Optional[str]], Optional[Any]] = _to_default,
        types: Optional[Dict[str, Type[Any]]] = None,
    ) -> None:
        """Store the conversion table, the fallback converter and the type map.

        Args:
            mappings: Athena type name to conversion function. A falsy value
                (None or an empty dict) is replaced by a fresh empty dict.
            default: Converter returned by ``get`` for unmapped type names.
            types: Athena type name to Python type. A falsy value is replaced
                by a fresh empty dict.
        """
        if mappings:
            self._mappings = mappings
        else:
            self._mappings = {}
        self._default = default
        if types:
            self._types = types
        else:
            self._types = {}

    @property
    def mappings(self) -> Dict[str, Callable[[Optional[str]], Optional[Any]]]:
        """Get the current type conversion mappings.

        Returns:
            Dictionary mapping Athena data types to conversion functions.
        """
        return self._mappings

    @property
    def types(self) -> Dict[str, Type[Any]]:
        """Get the current type mappings for result set descriptions.

        Returns:
            Dictionary mapping Athena data types to Python types.
        """
        return self._types

    def get(self, type_: str) -> Callable[[Optional[str]], Optional[Any]]:
        """Get the conversion function for a specific Athena data type.

        Args:
            type_: The Athena data type name.

        Returns:
            The conversion function for the type, or the default converter
            if not found.
        """
        return self.mappings.get(type_, self._default)

    def set(self, type_: str, converter: Callable[[Optional[str]], Optional[Any]]) -> None:
        """Set a custom conversion function for an Athena data type.

        Args:
            type_: The Athena data type name.
            converter: The conversion function to use for this type.
        """
        self.mappings[type_] = converter

    def remove(self, type_: str) -> None:
        """Remove a custom conversion function for an Athena data type.

        Removing a type name that has no mapping is a no-op.

        Args:
            type_: The Athena data type name to remove.
        """
        self.mappings.pop(type_, None)

    def update(self, mappings: Dict[str, Callable[[Optional[str]], Optional[Any]]]) -> None:
        """Update multiple conversion functions at once.

        Args:
            mappings: Dictionary of type names to conversion functions.
        """
        self.mappings.update(mappings)

    @abstractmethod
    def convert(self, type_: str, value: Optional[str]) -> Optional[Any]:
        """Convert a raw string value of the given Athena type.

        Subclasses must implement this method.

        Args:
            type_: The Athena data type name.
            value: The raw string value, or None.

        Returns:
            The converted Python object.
        """
        raise NotImplementedError  # pragma: no cover
class DefaultTypeConverter(Converter):
    """Default implementation of the Converter for standard Python types.

    This converter provides mappings for all standard Athena data types to
    their corresponding Python types using built-in conversion functions.
    It's used by the standard Cursor class by default.

    Supported conversions:
        - Numeric types: integer, bigint, real, double, decimal
        - String types: varchar, char
        - Date/time types: date, timestamp, time (with timezone support)
        - Boolean: boolean
        - Binary: varbinary
        - Complex types: array, map, row/struct
        - JSON: json

    Example:
        >>> converter = DefaultTypeConverter()
        >>> converter.convert('integer', '42')
        42
        >>> converter.convert('date', '2023-01-15')
        datetime.date(2023, 1, 15)
    """

    def __init__(self) -> None:
        """Initialise with a private deep copy of the built-in converter table."""
        # Copying keeps per-instance set/remove/update calls from changing
        # the shared module-level table.
        super().__init__(mappings=deepcopy(_DEFAULT_CONVERTERS), default=_to_default)

    def convert(self, type_: str, value: Optional[str]) -> Optional[Any]:
        """Convert a raw string value using the converter registered for its type.

        Args:
            type_: The Athena data type name.
            value: The raw string value, or None.

        Returns:
            The result of the converter registered for ``type_``, or of the
            default converter when the type has no mapping.
        """
        converter = self.get(type_)
        return converter(value)