Source code for pyathena

# -*- coding: utf-8 -*-
from __future__ import annotations

import datetime
from typing import TYPE_CHECKING, Any, FrozenSet, Type, overload

from pyathena.error import *  # noqa

if TYPE_CHECKING:
    from pyathena.connection import Connection, ConnectionCursor
    from pyathena.cursor import Cursor

# Resolve the package version at import time, trying two sources in order and
# never letting a failure here prevent the package from being imported.
try:
    # First choice: the version module shipped inside the package.
    # NOTE(review): presumably generated at build time — confirm.
    from pyathena._version import __version__
except ImportError:
    try:
        # Second choice: ask the installed distribution's metadata.
        from importlib.metadata import version

        __version__ = version("PyAthena")
    except Exception:
        # Deliberately broad: whatever goes wrong while importing or querying
        # the metadata, fall back to a placeholder instead of raising.
        __version__ = "unknown"
# "PyAthena/<version>" token built from the version resolved above.
# NOTE(review): presumably appended to the AWS client's user agent — confirm.
user_agent_extra: str = f"PyAthena/{__version__}"

# Globals https://www.python.org/dev/peps/pep-0249/#globals
# Module-level attributes required by the DB API 2.0 specification (PEP 249).

# DB API level this module declares.
apilevel: str = "2.0"
# PEP 249 level 2: threads may share the module and connections.
threadsafety: int = 2
# PEP 249 "pyformat": Python extended format codes, e.g. ``%(name)s``.
paramstyle: str = "pyformat"


class DBAPITypeObject(FrozenSet[str]):
    """Type Objects and Constructors

    https://www.python.org/dev/peps/pep-0249/#type-objects-and-constructors

    A frozen set of type-name strings that additionally compares equal to
    any single value it contains, so a type name can be tested with
    ``type_name == STRING``-style comparisons.

    Comparison rules:
        * Against another ``frozenset`` (including another
          ``DBAPITypeObject``): ordinary set equality.
        * Against anything else: membership in this set. Operands that
          cannot be hashed (lists, dicts, ...) can never be members, so they
          compare unequal instead of raising ``TypeError``.
    """

    def __eq__(self, other: object) -> bool:
        """Return True for an equal frozenset or for a member of this set."""
        if isinstance(other, frozenset):
            return frozenset.__eq__(self, other)
        return self._has_member(other)

    def __ne__(self, other: object) -> bool:
        """Return the exact negation of ``__eq__`` for the same operand."""
        if isinstance(other, frozenset):
            return frozenset.__ne__(self, other)
        return not self._has_member(other)

    def __hash__(self) -> int:
        """Hash like a plain frozenset.

        Defining ``__eq__`` would otherwise set ``__hash__`` to ``None`` and
        make instances unhashable, so the frozenset hash is restored here.
        """
        return frozenset.__hash__(self)

    def _has_member(self, other: object) -> bool:
        """Return whether ``other`` is an element, tolerating unhashable values."""
        try:
            return other in self
        except TypeError:
            # The membership test hashes ``other``; an unhashable operand
            # cannot be stored in a frozenset, so it is simply not a member.
            return False
# DB API type objects, grouped by the Athena type names they match.
# https://docs.aws.amazon.com/athena/latest/ug/data-types.html
STRING: DBAPITypeObject = DBAPITypeObject(
    (
        "char",
        "varchar",
        "map",
        "array",
        "row",
    )
)
BINARY: DBAPITypeObject = DBAPITypeObject(("varbinary",))
BOOLEAN: DBAPITypeObject = DBAPITypeObject(("boolean",))
NUMBER: DBAPITypeObject = DBAPITypeObject(
    (
        "tinyint",
        "smallint",
        "bigint",
        "integer",
        "real",
        "double",
        "float",
        "decimal",
    )
)
DATE: DBAPITypeObject = DBAPITypeObject(("date",))
TIME: DBAPITypeObject = DBAPITypeObject(
    (
        "time",
        "time with time zone",
    )
)
DATETIME: DBAPITypeObject = DBAPITypeObject(
    (
        "timestamp",
        "timestamp with time zone",
    )
)
JSON: DBAPITypeObject = DBAPITypeObject(("json",))

# DB API constructor names, bound directly to the standard-library classes.
Date: Type[datetime.date] = datetime.date
Time: Type[datetime.time] = datetime.time
Timestamp: Type[datetime.datetime] = datetime.datetime


# Typing-only overloads: they let a type checker tie the connection's cursor
# type to the ``cursor_class`` argument. The runtime implementation of
# ``connect`` follows them.
@overload
def connect(*args, cursor_class: None = ..., **kwargs) -> "Connection[Cursor]": ...


@overload
def connect(
    *args, cursor_class: Type[ConnectionCursor], **kwargs
) -> "Connection[ConnectionCursor]": ...


def connect(*args, **kwargs) -> "Connection[Any]":
    """Open a new connection to Amazon Athena.

    Main entry point of the package, following the DB API 2.0
    specification. Every positional and keyword argument is forwarded
    unchanged to the ``Connection`` constructor; the resulting connection is
    used to create cursors for executing SQL queries.

    Args:
        s3_staging_dir: S3 location where query results are stored. Required
            unless a workgroup with a configured result location is used.
        region_name: AWS region name. When omitted, the default region from
            the AWS configuration is used.
        schema_name: Athena database/schema name. Defaults to "default".
        catalog_name: Athena data catalog name. Defaults to
            "awsdatacatalog".
        work_group: Athena workgroup name. May replace s3_staging_dir when
            the workgroup has a result location configured.
        poll_interval: Seconds to wait between polls for query completion.
            Defaults to 1.0.
        encryption_option: S3 encryption option for query results: "SSE_S3",
            "SSE_KMS", or "CSE_KMS".
        kms_key: KMS key ID used with SSE_KMS or CSE_KMS encryption.
        profile_name: AWS profile name used for authentication.
        role_arn: ARN of the IAM role to assume for authentication.
        role_session_name: Session name used when assuming a role.
        cursor_class: Custom cursor class. When omitted, the default Cursor
            class is used.
        kill_on_interrupt: Whether running queries are cancelled on
            interrupt. Defaults to True.
        **kwargs: Any further keyword arguments accepted by the Connection
            constructor.

    Returns:
        A Connection object for creating cursors and executing queries.

    Raises:
        AssertionError: If neither s3_staging_dir nor work_group is provided.

    Example:
        >>> import pyathena
        >>> conn = pyathena.connect(
        ...     s3_staging_dir='s3://my-bucket/staging/',
        ...     region_name='us-east-1',
        ...     schema_name='mydatabase'
        ... )
        >>> cursor = conn.cursor()
        >>> cursor.execute("SELECT * FROM mytable LIMIT 10")
        >>> results = cursor.fetchall()
    """
    # Imported on each call rather than at module import time.
    # NOTE(review): presumably to avoid a circular import — confirm.
    from pyathena.connection import Connection

    connection = Connection(*args, **kwargs)
    return connection