Edit on GitHub

sqlmesh.core.engine_adapter

 1import typing as t
 2
 3from sqlmesh.core.engine_adapter._typing import PySparkDataFrame
 4from sqlmesh.core.engine_adapter.base import (
 5    EngineAdapter,
 6    EngineAdapterWithIndexSupport,
 7)
 8from sqlmesh.core.engine_adapter.bigquery import BigQueryEngineAdapter
 9from sqlmesh.core.engine_adapter.databricks import DatabricksSparkSessionEngineAdapter
10from sqlmesh.core.engine_adapter.databricks_api import DatabricksSQLEngineAdapter
11from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter
12from sqlmesh.core.engine_adapter.redshift import RedshiftEngineAdapter
13from sqlmesh.core.engine_adapter.shared import TransactionType
14from sqlmesh.core.engine_adapter.snowflake import SnowflakeEngineAdapter
15from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter
16
# Maps a lower-case dialect name to the adapter class used for it.
DIALECT_TO_ENGINE_ADAPTER = {
    # Dialects with a dedicated adapter class.
    "spark": SparkEngineAdapter,
    "bigquery": BigQueryEngineAdapter,
    "duckdb": DuckDBEngineAdapter,
    "snowflake": SnowflakeEngineAdapter,
    "databricks": DatabricksSparkSessionEngineAdapter,
    "redshift": RedshiftEngineAdapter,
    # Dialects that all share the generic index-aware adapter.
    **{name: EngineAdapterWithIndexSupport for name in ("postgres", "mysql", "mssql")},
}
28
29
def create_engine_adapter(
    connection_factory: t.Callable[[], t.Any], dialect: str, multithreaded: bool = False
) -> EngineAdapter:
    """Instantiate the engine adapter registered for ``dialect``.

    Args:
        connection_factory: Zero-argument callable forwarded to the adapter's constructor.
        dialect: Dialect name, matched case-insensitively; "postgresql" is treated as "postgres".
        multithreaded: Forwarded to the adapter's constructor.

    Returns:
        For "databricks": a DatabricksSparkSessionEngineAdapter when pyspark can be imported
        and reports an active SparkSession, otherwise a DatabricksSQLEngineAdapter.
        For other dialects: the class from DIALECT_TO_ENGINE_ADAPTER, or a plain EngineAdapter
        when the dialect is not registered. The generic adapters (EngineAdapter and
        EngineAdapterWithIndexSupport) also receive the normalized dialect name.
    """
    normalized = dialect.lower()
    if normalized == "postgresql":
        normalized = "postgres"

    adapter_cls: t.Optional[t.Type[EngineAdapter]]
    if normalized != "databricks":
        adapter_cls = DIALECT_TO_ENGINE_ADAPTER.get(normalized)
    else:
        # Default to the SQL-connector adapter; only switch to the Spark-session adapter
        # when pyspark is importable and a session is already active in this process.
        adapter_cls = DatabricksSQLEngineAdapter
        try:
            from pyspark.sql import SparkSession

            if SparkSession.getActiveSession():
                adapter_cls = DatabricksSparkSessionEngineAdapter
        except ImportError:
            # pyspark is optional: keep the SQL-connector adapter.
            pass

    if adapter_cls is None:
        # Unregistered dialect: fall back to the base adapter, told which dialect to speak.
        return EngineAdapter(connection_factory, normalized, multithreaded=multithreaded)
    if adapter_cls is EngineAdapterWithIndexSupport:
        # Shared by several dialects, so it must be given the dialect explicitly.
        return EngineAdapterWithIndexSupport(
            connection_factory, normalized, multithreaded=multithreaded
        )
    return adapter_cls(connection_factory, multithreaded=multithreaded)
def create_engine_adapter(connection_factory: Callable[[], Any], dialect: str, multithreaded: bool = False) -> sqlmesh.core.engine_adapter.base.EngineAdapter:
def create_engine_adapter(
    connection_factory: t.Callable[[], t.Any], dialect: str, multithreaded: bool = False
) -> EngineAdapter:
    """Instantiate the engine adapter registered for ``dialect``.

    Args:
        connection_factory: Zero-argument callable forwarded to the adapter's constructor.
        dialect: Dialect name, matched case-insensitively; "postgresql" is treated as "postgres".
        multithreaded: Forwarded to the adapter's constructor.

    Returns:
        For "databricks": a DatabricksSparkSessionEngineAdapter when pyspark can be imported
        and reports an active SparkSession, otherwise a DatabricksSQLEngineAdapter.
        For other dialects: the class from DIALECT_TO_ENGINE_ADAPTER, or a plain EngineAdapter
        when the dialect is not registered. The generic adapters (EngineAdapter and
        EngineAdapterWithIndexSupport) also receive the normalized dialect name.
    """
    normalized = dialect.lower()
    if normalized == "postgresql":
        normalized = "postgres"

    adapter_cls: t.Optional[t.Type[EngineAdapter]]
    if normalized != "databricks":
        adapter_cls = DIALECT_TO_ENGINE_ADAPTER.get(normalized)
    else:
        # Default to the SQL-connector adapter; only switch to the Spark-session adapter
        # when pyspark is importable and a session is already active in this process.
        adapter_cls = DatabricksSQLEngineAdapter
        try:
            from pyspark.sql import SparkSession

            if SparkSession.getActiveSession():
                adapter_cls = DatabricksSparkSessionEngineAdapter
        except ImportError:
            # pyspark is optional: keep the SQL-connector adapter.
            pass

    if adapter_cls is None:
        # Unregistered dialect: fall back to the base adapter, told which dialect to speak.
        return EngineAdapter(connection_factory, normalized, multithreaded=multithreaded)
    if adapter_cls is EngineAdapterWithIndexSupport:
        # Shared by several dialects, so it must be given the dialect explicitly.
        return EngineAdapterWithIndexSupport(
            connection_factory, normalized, multithreaded=multithreaded
        )
    return adapter_cls(connection_factory, multithreaded=multithreaded)