Source code for pyathena.spark.cursor

# -*- coding: utf-8 -*-
from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional, Union, cast

from pyathena import OperationalError, ProgrammingError
from pyathena.model import AthenaCalculationExecution, AthenaCalculationExecutionStatus
from pyathena.spark.common import SparkBaseCursor, WithCalculationExecution

_logger = logging.getLogger(__name__)  # type: ignore


class SparkCursor(SparkBaseCursor, WithCalculationExecution):
    """Cursor for executing PySpark code on Amazon Athena for Apache Spark.

    This cursor allows you to execute PySpark code directly on Athena's managed
    Spark environment. It's designed for big data processing, ETL operations,
    and machine learning workloads that require Spark's distributed computing
    capabilities.

    The cursor manages Spark sessions automatically and provides an interface
    similar to other PyAthena cursors but optimized for Spark calculations
    rather than SQL queries.

    Attributes:
        session_id: The Athena Spark session ID.
        description: Optional description for the Spark session.
        engine_configuration: Spark engine configuration settings.
        calculation_id: ID of the current calculation being executed.

    Example:
        >>> from pyathena.spark.cursor import SparkCursor
        >>> cursor = connection.cursor(SparkCursor)
        >>>
        >>> # Execute PySpark code
        >>> spark_code = '''
        ... df = spark.read.table("my_database.my_table")
        ... result = df.groupBy("category").count()
        ... result.show()
        ... '''
        >>> cursor.execute(spark_code)
        >>> result = cursor.fetchall()

        >>> # Configure Spark session
        >>> cursor = connection.cursor(
        ...     SparkCursor,
        ...     engine_configuration={
        ...         'CoordinatorDpuSize': 1,
        ...         'MaxConcurrentDpus': 20,
        ...         'DefaultExecutorDpuSize': 1
        ...     }
        ... )

    Note:
        Requires an Athena workgroup configured for Spark calculations.
        Spark sessions have associated costs and idle timeout settings.
    """
    def __init__(
        self,
        session_id: Optional[str] = None,
        description: Optional[str] = None,
        engine_configuration: Optional[Dict[str, Any]] = None,
        notebook_version: Optional[str] = None,
        session_idle_timeout_minutes: Optional[int] = None,
        **kwargs,
    ) -> None:
        super().__init__(
            session_id=session_id,
            description=description,
            engine_configuration=engine_configuration,
            notebook_version=notebook_version,
            session_idle_timeout_minutes=session_idle_timeout_minutes,
            **kwargs,
        )
    @property
    def calculation_execution(self) -> Optional[AthenaCalculationExecution]:
        return self._calculation_execution
    def get_std_out(self) -> Optional[str]:
        """Get the standard output from the Spark calculation execution.

        Retrieves and returns the contents of the standard output generated
        during the Spark calculation execution, if available.

        Returns:
            The standard output as a string, or None if no output is available
            or the calculation has not been executed.
        """
        if not self._calculation_execution or not self._calculation_execution.std_out_s3_uri:
            return None
        return self._read_s3_file_as_text(self._calculation_execution.std_out_s3_uri)
    def get_std_error(self) -> Optional[str]:
        """Get the standard error from the Spark calculation execution.

        Retrieves and returns the contents of the standard error generated
        during the Spark calculation execution, if available. This is useful
        for debugging failed or problematic Spark operations.

        Returns:
            The standard error as a string, or None if no error output is
            available or the calculation has not been executed.
        """
        if not self._calculation_execution or not self._calculation_execution.std_error_s3_uri:
            return None
        return self._read_s3_file_as_text(self._calculation_execution.std_error_s3_uri)
    def execute(
        self,
        operation: str,
        parameters: Optional[Union[Dict[str, Any], List[str]]] = None,
        session_id: Optional[str] = None,
        description: Optional[str] = None,
        client_request_token: Optional[str] = None,
        work_group: Optional[str] = None,
        **kwargs,
    ) -> SparkCursor:
        """Execute a block of PySpark code as an Athena calculation.

        Submits the code block to the Spark session, polls until the calculation
        finishes, and raises OperationalError with the calculation's standard
        error output if the calculation does not complete successfully.
        """
        self._calculation_id = self._calculate(
            session_id=session_id if session_id else self._session_id,
            code_block=operation,
            description=description,
            client_request_token=client_request_token,
        )
        self._calculation_execution = cast(
            AthenaCalculationExecution, self._poll(self._calculation_id)
        )
        if self._calculation_execution.state != AthenaCalculationExecutionStatus.STATE_COMPLETED:
            std_error = self.get_std_error()
            raise OperationalError(std_error)
        return self
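    # A minimal error-handling sketch (illustrative only; ``conn`` is assumed to
    # be a pyathena Connection bound to a Spark-enabled workgroup). It shows how
    # get_std_out()/get_std_error() complement execute(), which raises
    # OperationalError when the calculation does not complete:
    #
    #   cursor = conn.cursor(SparkCursor)
    #   try:
    #       cursor.execute("spark.sql('SELECT 1').show()")
    #       print(cursor.get_std_out())
    #   except OperationalError:
    #       # The calculation's stderr stream is also available directly.
    #       print(cursor.get_std_error())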
    def cancel(self) -> None:
        """Cancel the currently running calculation.

        Raises ProgrammingError if no calculation has been started.
        """
        if not self.calculation_id:
            raise ProgrammingError("CalculationExecutionId is none or empty.")
        self._cancel(self.calculation_id)
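# A minimal cancellation sketch (illustrative only; ``conn`` and the timeout are
# assumptions, not part of this module). Because execute() blocks while polling
# the calculation, cancel() would have to be invoked from another thread, for
# example via a timer:
#
#   import threading
#
#   cursor = conn.cursor(SparkCursor)
#   timer = threading.Timer(60.0, cursor.cancel)
#   timer.start()
#   try:
#       cursor.execute(long_running_spark_code)
#   finally:
#       timer.cancel()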