# -*- coding: utf-8 -*-
from __future__ import annotations
import logging
import time
from abc import ABCMeta, abstractmethod
from datetime import datetime
from typing import Any, Dict, List, Optional, Union, cast
import botocore.exceptions
from pyathena import NotSupportedError, OperationalError
from pyathena.common import BaseCursor
from pyathena.model import (
AthenaCalculationExecution,
AthenaCalculationExecutionStatus,
AthenaQueryExecution,
AthenaSessionStatus,
)
from pyathena.util import parse_output_location, retry_api_call
_logger = logging.getLogger(__name__) # type: ignore
class SparkBaseCursor(BaseCursor, metaclass=ABCMeta):
"""Abstract base class for Spark-enabled cursor implementations.
    This class provides the foundational functionality for executing PySpark code
    on Amazon Athena for Apache Spark. It manages the Spark session lifecycle,
    handles calculation execution, and provides utilities for reading
    results from S3.
Features:
- Automatic Spark session management and lifecycle
- Configurable engine resources (DPU allocation)
- Session idle timeout and automatic cleanup
- Standard output and error stream access via S3
- Calculation execution status monitoring
- Session validation and error handling
Attributes:
session_id: The Athena Spark session identifier.
calculation_id: ID of the current calculation being executed.
engine_configuration: DPU and resource configuration for Spark.
Note:
This is an abstract base class used by concrete Spark cursor implementations
like SparkCursor and AsyncSparkCursor. It should not be instantiated directly.
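    Example:
        A minimal usage sketch assuming the concrete ``SparkCursor`` implementation
        and an existing Spark-enabled Athena work group (the work group name, region,
        and the submitted PySpark snippet below are illustrative only)::

            from pyathena import connect
            from pyathena.spark.cursor import SparkCursor

            cursor = connect(
                work_group="my-spark-workgroup",  # assumed Spark-enabled work group
                region_name="us-west-2",
                cursor_class=SparkCursor,
            ).cursor()
            try:
                cursor.execute("spark.sql('SELECT 1').show()")
                print(cursor.std_out_s3_uri)
            finally:
                cursor.close()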
"""
def __init__(
self,
session_id: Optional[str] = None,
description: Optional[str] = None,
engine_configuration: Optional[Dict[str, Any]] = None,
notebook_version: Optional[str] = None,
session_idle_timeout_minutes: Optional[int] = None,
**kwargs,
) -> None:
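        """Initialize the cursor and attach it to an Athena Spark session.

        Args:
            session_id: ID of an existing Spark session to reuse. If omitted,
                a new session is started.
            description: Optional description for a newly started session.
            engine_configuration: DPU configuration for the session. Defaults to
                ``get_default_engine_configuration()``.
            notebook_version: Optional notebook version to request for the session.
            session_idle_timeout_minutes: Idle timeout in minutes after which
                Athena terminates the session.
            **kwargs: Additional keyword arguments passed through to ``BaseCursor``.
        """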
super().__init__(**kwargs)
self._engine_configuration = (
engine_configuration
if engine_configuration
else self.get_default_engine_configuration()
)
self._notebook_version = notebook_version
self._session_description = description
self._session_idle_timeout_minutes = session_idle_timeout_minutes
if session_id:
if self._exists_session(session_id):
self._session_id = session_id
else:
raise OperationalError(f"Session: {session_id} not found.")
else:
self._session_id = self._start_session()
self._calculation_id: Optional[str] = None
self._calculation_execution: Optional[AthenaCalculationExecution] = None
self._client = self.connection.session.client(
"s3",
region_name=self.connection.region_name,
config=self.connection.config,
**self.connection._client_kwargs,
)
@property
def session_id(self) -> str:
return self._session_id
@property
def calculation_id(self) -> Optional[str]:
return self._calculation_id
@staticmethod
def get_default_engine_configuration() -> Dict[str, Any]:
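        """Return the default ``EngineConfiguration`` for new Spark sessions.

        The keys correspond to the ``EngineConfiguration`` structure of the
        Athena ``StartSession`` API: a 1-DPU coordinator, 1-DPU executors by
        default, and a maximum of 2 concurrent DPUs.
        """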
return {
"CoordinatorDpuSize": 1,
"MaxConcurrentDpus": 2,
"DefaultExecutorDpuSize": 1,
}
    def _read_s3_file_as_text(self, uri: str) -> str:
bucket, key = parse_output_location(uri)
response = retry_api_call(
self._client.get_object,
config=self._retry_config,
logger=_logger,
Bucket=bucket,
Key=key,
)
return cast(str, response["Body"].read().decode("utf-8").strip())
    def _get_session_status(self, session_id: str) -> AthenaSessionStatus:
request: Dict[str, Any] = {"SessionId": session_id}
try:
response = retry_api_call(
self._connection.client.get_session_status,
config=self._retry_config,
logger=_logger,
**request,
)
except Exception as e:
_logger.exception("Failed to get session status.")
raise OperationalError(*e.args) from e
else:
return AthenaSessionStatus(response)
    def _wait_for_idle_session(self, session_id: str) -> None:
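        # Poll until the session reaches IDLE; terminal states
        # (TERMINATED, DEGRADED, FAILED) abort with an OperationalError.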
while True:
session_status = self._get_session_status(session_id)
if session_status.state in [AthenaSessionStatus.STATE_IDLE]:
break
            if session_status.state in [
AthenaSessionStatus.STATE_TERMINATED,
AthenaSessionStatus.STATE_DEGRADED,
AthenaSessionStatus.STATE_FAILED,
]:
raise OperationalError(session_status.state_change_reason)
time.sleep(self._poll_interval)
def _exists_session(self, session_id: str) -> bool:
request = {"SessionId": session_id}
try:
retry_api_call(
self._connection.client.get_session,
config=self._retry_config,
logger=_logger,
**request,
)
except Exception as e:
if (
isinstance(e, botocore.exceptions.ClientError)
and e.response["Error"]["Code"] == "InvalidRequestException"
):
_logger.exception(f"Session: {session_id} not found.")
return False
raise OperationalError(*e.args) from e
else:
self._wait_for_idle_session(session_id)
return True
def _start_session(self) -> str:
request: Dict[str, Any] = {
"WorkGroup": self._work_group,
"EngineConfiguration": self._engine_configuration,
}
if self._session_description:
request.update({"Description": self._session_description})
if self._notebook_version:
request.update({"NotebookVersion": self._notebook_version})
if self._session_idle_timeout_minutes:
request.update({"SessionIdleTimeoutInMinutes": self._session_idle_timeout_minutes})
try:
session_id: str = retry_api_call(
self._connection.client.start_session,
config=self._retry_config,
logger=_logger,
**request,
)["SessionId"]
except Exception as e:
_logger.exception("Failed to start session.")
raise OperationalError(*e.args) from e
else:
self._wait_for_idle_session(session_id)
return session_id
def _terminate_session(self) -> None:
request = {"SessionId": self._session_id}
try:
retry_api_call(
self._connection.client.terminate_session,
config=self._retry_config,
logger=_logger,
**request,
)
except Exception as e:
_logger.exception("Failed to terminate session.")
raise OperationalError(*e.args) from e
def __poll(self, query_id: str) -> Union[AthenaQueryExecution, AthenaCalculationExecution]:
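        # Block until the calculation reaches a terminal state
        # (COMPLETED, FAILED, or CANCELED), then return its final metadata.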
while True:
calculation_status = self._get_calculation_execution_status(query_id)
if calculation_status.state in [
AthenaCalculationExecutionStatus.STATE_COMPLETED,
AthenaCalculationExecutionStatus.STATE_FAILED,
AthenaCalculationExecutionStatus.STATE_CANCELED,
]:
return self._get_calculation_execution(query_id)
time.sleep(self._poll_interval)
def _poll(self, query_id: str) -> Union[AthenaQueryExecution, AthenaCalculationExecution]:
try:
query_execution = self.__poll(query_id)
except KeyboardInterrupt as e:
if self._kill_on_interrupt:
_logger.warning("Query canceled by user.")
self._cancel(query_id)
query_execution = self.__poll(query_id)
else:
raise e
return query_execution
def _cancel(self, query_id: str) -> None:
request = {"CalculationExecutionId": query_id}
try:
retry_api_call(
self._connection.client.stop_calculation_execution,
config=self._retry_config,
logger=_logger,
**request,
)
except Exception as e:
_logger.exception("Failed to cancel calculation.")
raise OperationalError(*e.args) from e
def close(self) -> None:
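        """Terminate the Athena Spark session associated with this cursor."""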
self._terminate_session()
def executemany(
self,
operation: str,
seq_of_parameters: List[Optional[Union[Dict[str, Any], List[str]]]],
**kwargs,
) -> None:
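        """Not supported for Spark calculations; always raises ``NotSupportedError``."""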
raise NotSupportedError
class WithCalculationExecution(metaclass=ABCMeta):
"""Mixin class providing access to Spark calculation execution properties.
    This mixin provides property accessors for calculation execution metadata
    and status information. It is designed to be mixed into cursor classes
    that execute Spark calculations on Athena.
Properties:
- description: Human-readable description of the calculation
- working_directory: S3 path where calculation files are stored
- state: Current execution state (COMPLETED, FAILED, etc.)
- state_change_reason: Explanation for state changes
- submission_date_time: When the calculation was submitted
- completion_date_time: When the calculation completed
- dpu_execution_in_millis: DPU execution time in milliseconds
- progress: Current execution progress information
- std_out_s3_uri: S3 URI for standard output
- std_error_s3_uri: S3 URI for standard error
- result_s3_uri: S3 URI for calculation results
- result_type: Type of result produced by the calculation
Note:
This class requires that the implementing class provides
calculation_execution, session_id, and calculation_id properties.
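    Example:
        A sketch of reading these properties after a calculation has finished,
        assuming ``cursor`` is a concrete Spark cursor that uses this mixin and
        has already executed a calculation::

            from pyathena.model import AthenaCalculationExecutionStatus

            if cursor.state == AthenaCalculationExecutionStatus.STATE_COMPLETED:
                print(cursor.dpu_execution_in_millis)
                print(cursor.std_out_s3_uri)
            else:
                print(cursor.state_change_reason)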
"""
def __init__(self):
super().__init__()
@property
@abstractmethod
def calculation_execution(self) -> Optional[AthenaCalculationExecution]:
raise NotImplementedError # pragma: no cover
@property
@abstractmethod
def session_id(self) -> str:
raise NotImplementedError # pragma: no cover
@property
@abstractmethod
def calculation_id(self) -> Optional[str]:
raise NotImplementedError # pragma: no cover
@property
def description(self) -> Optional[str]:
if not self.calculation_execution:
return None
return self.calculation_execution.description
@property
def working_directory(self) -> Optional[str]:
if not self.calculation_execution:
return None
return self.calculation_execution.working_directory
@property
def state(self) -> Optional[str]:
if not self.calculation_execution:
return None
return self.calculation_execution.state
@property
def state_change_reason(self) -> Optional[str]:
if not self.calculation_execution:
return None
return self.calculation_execution.state_change_reason
@property
def submission_date_time(self) -> Optional[datetime]:
if not self.calculation_execution:
return None
return self.calculation_execution.submission_date_time
@property
def completion_date_time(self) -> Optional[datetime]:
if not self.calculation_execution:
return None
return self.calculation_execution.completion_date_time
@property
def dpu_execution_in_millis(self) -> Optional[int]:
if not self.calculation_execution:
return None
return self.calculation_execution.dpu_execution_in_millis
@property
def progress(self) -> Optional[str]:
if not self.calculation_execution:
return None
return self.calculation_execution.progress
@property
def std_out_s3_uri(self) -> Optional[str]:
if not self.calculation_execution:
return None
return self.calculation_execution.std_out_s3_uri
@property
def std_error_s3_uri(self) -> Optional[str]:
if not self.calculation_execution:
return None
return self.calculation_execution.std_error_s3_uri
@property
def result_s3_uri(self) -> Optional[str]:
if not self.calculation_execution:
return None
return self.calculation_execution.result_s3_uri
@property
def result_type(self) -> Optional[str]:
if not self.calculation_execution:
return None
return self.calculation_execution.result_type