sqlmesh.engines.spark.db_api.spark_session

import typing as t
from threading import get_ident

from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import Row

from sqlmesh.engines.spark.db_api.errors import NotSupportedError, ProgrammingError


class SparkSessionCursor:
    def __init__(self, spark: SparkSession):
        self._spark = spark
        self._last_df: t.Optional[DataFrame] = None
        self._last_output: t.Optional[t.List[t.Tuple]] = None
        self._last_output_cursor: int = 0

    def execute(self, query: str, parameters: t.Optional[t.Any] = None) -> None:
        if parameters:
            raise NotSupportedError("Parameterized queries are not supported")

        self._last_df = self._spark.sql(query)
        self._last_output = None
        self._last_output_cursor = 0

    def fetchone(self) -> t.Optional[t.Tuple]:
        result = self._fetch(size=1)
        return result[0] if result else None

    def fetchmany(self, size: int = 1) -> t.List[t.Tuple]:
        return self._fetch(size=size)

    def fetchall(self) -> t.List[t.Tuple]:
        return self._fetch()

    def close(self) -> None:
        pass

    def fetchdf(self) -> t.Optional[DataFrame]:
        return self._last_df

    def _fetch(self, size: t.Optional[int] = None) -> t.List[t.Tuple]:
        if size and size < 0:
            raise ProgrammingError("The size argument can't be negative")

        if self._last_df is None:
            raise ProgrammingError("No call to .execute() has been issued")

        if self._last_output is None:
            # Collect the DataFrame once and cache the rows for subsequent fetches.
            self._last_output = _normalize_rows(self._last_df.collect())

        if self._last_output_cursor >= len(self._last_output):
            return []

        if size is None:
            size = len(self._last_output) - self._last_output_cursor

        output = self._last_output[self._last_output_cursor : self._last_output_cursor + size]
        self._last_output_cursor += size

        return output


class SparkSessionConnection:
    def __init__(self, spark: SparkSession):
        self.spark = spark

    def cursor(self) -> SparkSessionCursor:
        # Use a per-thread scheduler pool name so jobs issued from different threads can be scheduled separately.
        self.spark.sparkContext.setLocalProperty("spark.scheduler.pool", f"pool_{get_ident()}")
        return SparkSessionCursor(self.spark)

    def commit(self) -> None:
        pass

    def rollback(self) -> None:
        pass

    def close(self) -> None:
        pass


def connection(spark: SparkSession) -> SparkSessionConnection:
    return SparkSessionConnection(spark)


def _normalize_rows(rows: t.Sequence[Row]) -> t.List[t.Tuple]:
    return [tuple(r) for r in rows]

class SparkSessionCursor:
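A minimal DB-API-style cursor backed by a SparkSession. execute() runs a statement with SparkSession.sql(); the result is collected lazily on the first fetch call and cached, and the fetch methods then read from that cached list of rows. Calling a fetch method before execute() raises ProgrammingError.
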
SparkSessionCursor(spark: pyspark.sql.session.SparkSession)
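Wraps an existing SparkSession; no new session or connection is created.
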
def execute(self, query: str, parameters: Optional[Any] = None) -> None:
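Runs the statement with SparkSession.sql() and resets any previously cached output. Parameterized queries are not supported: passing a non-empty parameters value raises NotSupportedError, so values must already be rendered into the SQL text. A minimal sketch of the difference, assuming an existing cursor and an illustrative table name:

cursor.execute("SELECT * FROM example_table WHERE id = 42")         # supported: values inlined in the SQL text
# cursor.execute("SELECT * FROM example_table WHERE id = ?", [42])  # raises NotSupportedError
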
def fetchone(self) -> Optional[Tuple]:
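Returns the next row of the last result set as a tuple, or None once the result set is exhausted.
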
def fetchmany(self, size: int = 1) -> List[Tuple]:
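Returns up to size rows (default 1) from the last result set and advances the internal cursor; returns an empty list once the result set is exhausted. A negative size raises ProgrammingError.
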
def fetchall(self) -> List[Tuple]:
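Returns all remaining rows of the last result set as a list of tuples.
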
def close(self) -> None:
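No-op; the wrapped SparkSession is left running.
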
def fetchdf(self) -> Optional[pyspark.sql.dataframe.DataFrame]:
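Returns the DataFrame produced by the last execute() call (or None if execute() has not been called) without collecting it, so callers can keep working with the lazily evaluated result.
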
class SparkSessionConnection:
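A minimal DB-API-style connection wrapper around a SparkSession. Its only real job is to hand out SparkSessionCursor instances; the transactional methods are no-ops.
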
SparkSessionConnection(spark: pyspark.sql.session.SparkSession)
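Wraps the given SparkSession, exposing it as the spark attribute.
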
def cursor(self) -> SparkSessionCursor:
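Creates a new SparkSessionCursor. Before doing so it sets the spark.scheduler.pool local property to pool_<current thread id>, so jobs issued from cursors created on different threads land in separate scheduler pools; the practical effect depends on the cluster's scheduler configuration (for example, a FAIR scheduler).
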
def commit(self) -> None:
def rollback(self) -> None:
def close(self) -> None:
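commit(), rollback(), and close() are all no-ops: they exist only to satisfy the DB-API connection interface, and the wrapped SparkSession is never stopped.
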
def connection(spark: pyspark.sql.session.SparkSession) -> sqlmesh.engines.spark.db_api.spark_session.SparkSessionConnection:
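Convenience factory that wraps a SparkSession in a SparkSessionConnection.

A minimal end-to-end usage sketch, assuming an already-built SparkSession and an illustrative table name (neither is part of this module):

from pyspark.sql import SparkSession
from sqlmesh.engines.spark.db_api.spark_session import connection

spark = SparkSession.builder.getOrCreate()

conn = connection(spark)
cursor = conn.cursor()
cursor.execute("SELECT * FROM example_table")  # plain SQL text; parameters are not supported

first = cursor.fetchone()   # next row as a tuple, or None
rest = cursor.fetchall()    # remaining rows as a list of tuples
df = cursor.fetchdf()       # the DataFrame produced by the last execute()

conn.commit()  # no-op
conn.close()   # no-op; the SparkSession is left running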