Edit on GitHub

sqlmesh.core.engine_adapter.snowflake

 1from __future__ import annotations
 2
 3import typing as t
 4
 5import pandas as pd
 6from sqlglot import exp, parse_one
 7
 8from sqlmesh.core.engine_adapter.base import EngineAdapter
 9from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType
10from sqlmesh.utils import nullsafe_join
11
12if t.TYPE_CHECKING:
13    from sqlmesh.core.engine_adapter._typing import DF
14
15
class SnowflakeEngineAdapter(EngineAdapter):
    """Engine adapter for Snowflake.

    SQL is generated in the "snowflake" dialect with `identify=False` by default,
    and JSON values are escaped (`ESCAPE_JSON = True`).
    """

    DEFAULT_SQL_GEN_KWARGS = {"identify": False}
    DIALECT = "snowflake"
    ESCAPE_JSON = True

    def _fetch_native_df(self, query: t.Union[exp.Expression, str]) -> DF:
        """Execute `query` and return its result set as a pandas DataFrame.

        Args:
            query: The query to run, either as a SQLGlot expression or a raw SQL string.

        Returns:
            A DataFrame holding the query results. If the query is a select-like
            expression whose projection names map one-to-one onto the returned
            columns, the columns are renamed to the names used in the query.
        """
        from snowflake.connector.errors import NotSupportedError
        from sqlglot.errors import ParseError

        self.execute(query)

        try:
            df = self.cursor.fetch_pandas_all()
        except NotSupportedError:
            # Sometimes Snowflake will not return results as an Arrow result and the fetch into
            # pandas fails (Ex: `SHOW TERSE OBJECTS IN SCHEMA`). Fall back to a plain DB-API fetch
            # and build the DataFrame manually. Column names come from the standard
            # `cursor.description` (not the connector's private result set), and passing them
            # via `columns=` keeps the columns present even when zero rows are returned.
            rows = self.cursor.fetchall()
            columns = [column[0] for column in self.cursor.description]
            df = pd.DataFrame(rows, columns=columns)

        # Snowflake returns uppercase column names if the columns are not quoted (so case-insensitive)
        # so replace the column names returned by Snowflake with the column names in the expression
        # if the expression was a select expression
        if isinstance(query, str):
            try:
                parsed_query = parse_one(query, read=self.dialect)
            except ParseError:
                # The query already ran successfully; failing to re-parse it only means we
                # cannot restore the original column names, so keep the DataFrame as-is.
                return df
            if parsed_query is None:
                # If we didn't get a result from parsing we will just optimistically assume that the df is fine
                return df
            query = parsed_query
        if isinstance(query, exp.Subqueryable):
            named_selects = query.named_selects
            # `named_selects` cannot be mapped one-to-one onto the result columns for `*`
            # projections or for unaliased expressions it omits; renaming in those cases
            # would raise a pandas length-mismatch error or produce a bogus "*" column.
            if len(named_selects) == len(df.columns) and "*" not in named_selects:
                df.columns = named_selects
        return df

    def _get_data_objects(
        self, schema_name: str, catalog_name: t.Optional[str] = None
    ) -> t.List[DataObject]:
        """
        Returns all the data objects that exist in the given schema and optionally catalog.

        Args:
            schema_name: The schema to list objects from.
            catalog_name: Optional catalog (database) qualifying the schema.

        Returns:
            One `DataObject` per row returned by `SHOW TERSE OBJECTS`; an empty list when
            the schema contains no objects.
        """
        target = nullsafe_join(".", catalog_name, schema_name)
        sql = f"SHOW TERSE OBJECTS IN {target}"
        df = self.fetchdf(sql)
        if df.empty:
            # An empty result may come back without any columns (e.g. when built from zero
            # rows), in which case the column selection below would raise a KeyError.
            return []
        return [
            DataObject(
                catalog=row.database_name,  # type: ignore
                schema=row.schema_name,  # type: ignore
                name=row.name,  # type: ignore
                type=DataObjectType.from_str(row.kind),  # type: ignore
            )
            for row in df[["database_name", "schema_name", "name", "kind"]].itertuples()
        ]
class SnowflakeEngineAdapter(sqlmesh.core.engine_adapter.base.EngineAdapter):
class SnowflakeEngineAdapter(EngineAdapter):
    """Engine adapter for Snowflake.

    SQL is generated in the "snowflake" dialect with `identify=False` by default,
    and JSON values are escaped (`ESCAPE_JSON = True`).
    """

    DEFAULT_SQL_GEN_KWARGS = {"identify": False}
    DIALECT = "snowflake"
    ESCAPE_JSON = True

    def _fetch_native_df(self, query: t.Union[exp.Expression, str]) -> DF:
        """Execute `query` and return its result set as a pandas DataFrame.

        Args:
            query: The query to run, either as a SQLGlot expression or a raw SQL string.

        Returns:
            A DataFrame holding the query results. If the query is a select-like
            expression whose projection names map one-to-one onto the returned
            columns, the columns are renamed to the names used in the query.
        """
        from snowflake.connector.errors import NotSupportedError
        from sqlglot.errors import ParseError

        self.execute(query)

        try:
            df = self.cursor.fetch_pandas_all()
        except NotSupportedError:
            # Sometimes Snowflake will not return results as an Arrow result and the fetch into
            # pandas fails (Ex: `SHOW TERSE OBJECTS IN SCHEMA`). Fall back to a plain DB-API fetch
            # and build the DataFrame manually. Column names come from the standard
            # `cursor.description` (not the connector's private result set), and passing them
            # via `columns=` keeps the columns present even when zero rows are returned.
            rows = self.cursor.fetchall()
            columns = [column[0] for column in self.cursor.description]
            df = pd.DataFrame(rows, columns=columns)

        # Snowflake returns uppercase column names if the columns are not quoted (so case-insensitive)
        # so replace the column names returned by Snowflake with the column names in the expression
        # if the expression was a select expression
        if isinstance(query, str):
            try:
                parsed_query = parse_one(query, read=self.dialect)
            except ParseError:
                # The query already ran successfully; failing to re-parse it only means we
                # cannot restore the original column names, so keep the DataFrame as-is.
                return df
            if parsed_query is None:
                # If we didn't get a result from parsing we will just optimistically assume that the df is fine
                return df
            query = parsed_query
        if isinstance(query, exp.Subqueryable):
            named_selects = query.named_selects
            # `named_selects` cannot be mapped one-to-one onto the result columns for `*`
            # projections or for unaliased expressions it omits; renaming in those cases
            # would raise a pandas length-mismatch error or produce a bogus "*" column.
            if len(named_selects) == len(df.columns) and "*" not in named_selects:
                df.columns = named_selects
        return df

    def _get_data_objects(
        self, schema_name: str, catalog_name: t.Optional[str] = None
    ) -> t.List[DataObject]:
        """
        Returns all the data objects that exist in the given schema and optionally catalog.

        Args:
            schema_name: The schema to list objects from.
            catalog_name: Optional catalog (database) qualifying the schema.

        Returns:
            One `DataObject` per row returned by `SHOW TERSE OBJECTS`; an empty list when
            the schema contains no objects.
        """
        target = nullsafe_join(".", catalog_name, schema_name)
        sql = f"SHOW TERSE OBJECTS IN {target}"
        df = self.fetchdf(sql)
        if df.empty:
            # An empty result may come back without any columns (e.g. when built from zero
            # rows), in which case the column selection below would raise a KeyError.
            return []
        return [
            DataObject(
                catalog=row.database_name,  # type: ignore
                schema=row.schema_name,  # type: ignore
                name=row.name,  # type: ignore
                type=DataObjectType.from_str(row.kind),  # type: ignore
            )
            for row in df[["database_name", "schema_name", "name", "kind"]].itertuples()
        ]

Base class wrapping a Database API-compliant connection.

The EngineAdapter is an easily-subclassable interface that interacts with the underlying engine and data store.

Arguments:
  • connection_factory: A callable that produces a new Database API-compliant connection on every call.
  • dialect: The dialect with which this adapter is associated.
  • multithreaded: Indicates whether this adapter will be used by more than one thread.