
sqlmesh.engines.spark.app

import argparse
import logging
import os
import tempfile

from pyspark import SparkFiles
from pyspark.sql import SparkSession

from sqlmesh.core.engine_adapter import create_engine_adapter
from sqlmesh.core.snapshot import SnapshotEvaluator
from sqlmesh.engines import commands
from sqlmesh.engines.spark.db_api import spark_session as spark_session_db
from sqlmesh.engines.spark.db_api.errors import NotSupportedError
from sqlmesh.utils.errors import SQLMeshError

logger = logging.getLogger(__name__)


def get_or_create_spark_session(dialect: str) -> SparkSession:
    if dialect == "databricks":
        spark = SparkSession.getActiveSession()
        if not spark:
            raise SQLMeshError("Could not find an active SparkSession.")
        return spark
    return (
        SparkSession.builder.config("spark.scheduler.mode", "FAIR")
        .enableHiveSupport()
        .getOrCreate()
    )


def main(
    dialect: str, command_type: commands.CommandType, ddl_concurrent_tasks: int, payload_path: str
) -> None:
    if dialect not in ("databricks", "spark"):
        raise NotSupportedError(
            f"Dialect '{dialect}' not supported. Must be either 'databricks' or 'spark'"
        )
    logging.basicConfig(
        format="%(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)",
        level=logging.INFO,
    )
    command_handler = commands.COMMAND_HANDLERS.get(command_type)
    if not command_handler:
        raise NotSupportedError(f"Command '{command_type.value}' not supported")

    spark = get_or_create_spark_session(dialect)

    evaluator = SnapshotEvaluator(
        create_engine_adapter(
            lambda: spark_session_db.connection(spark),
            dialect,
            multithreaded=ddl_concurrent_tasks > 1,
        ),
        ddl_concurrent_tasks=ddl_concurrent_tasks,
    )
    if dialect == "spark":
        with open(SparkFiles.get(payload_path), "r", encoding="utf-8") as payload_fd:
            command_payload = payload_fd.read()
    else:
        from pyspark.dbutils import DBUtils  # type: ignore

        dbutils = DBUtils(spark)
        with tempfile.TemporaryDirectory() as tmp:
            local_payload_path = os.path.join(tmp, commands.COMMAND_PAYLOAD_FILE_NAME)
            dbutils.fs.cp(payload_path, f"file://{local_payload_path}")
            with open(local_payload_path, "r", encoding="utf-8") as payload_fd:
                command_payload = payload_fd.read()
    logger.info("Command payload:\n %s", command_payload)
    command_handler(evaluator, command_payload)

    evaluator.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="SQLMesh Spark Submit App")
    parser.add_argument(
        "--dialect",
        help="The dialect to use when creating the engine adapter.",
    )
    parser.add_argument(
        "--command_type",
        type=commands.CommandType,
        choices=list(commands.CommandType),
        help="The type of command that is being run",
    )
    parser.add_argument(
        "--ddl_concurrent_tasks",
        type=int,
        default=1,
        help="The number of ddl concurrent tasks to use. Default to 1.",
    )
    parser.add_argument(
        "--payload_path",
        help="Path to the payload object. Can be a local or remote path.",
    )
    args = parser.parse_args()
    main(args.dialect, args.command_type, args.ddl_concurrent_tasks, args.payload_path)
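
The app is normally launched through spark-submit with the flags defined under __main__ above, but main() can also be called directly. A minimal sketch of a programmatic invocation, assuming the "spark" dialect; the CommandType member and payload file name below are illustrative assumptions, not values confirmed by this module:

from sqlmesh.engines import commands
from sqlmesh.engines.spark.app import main

main(
    dialect="spark",
    command_type=commands.CommandType.EVALUATE,  # hypothetical member name, for illustration
    ddl_concurrent_tasks=1,
    payload_path="payload.json",  # resolved via SparkFiles.get() for the "spark" dialect
)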
def get_or_create_spark_session(dialect: str) -> pyspark.sql.session.SparkSession:

Returns a SparkSession for the given dialect. For "databricks" the already-active session is returned, and a SQLMeshError is raised if no session is active. For any other dialect, a Hive-enabled session configured with the FAIR scheduler is created (or reused) via the session builder.
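
A minimal usage sketch, assuming a local Spark installation with Hive support; for the "databricks" dialect the function instead expects a session that the runtime has already created:

from sqlmesh.engines.spark.app import get_or_create_spark_session

spark = get_or_create_spark_session("spark")
print(spark.conf.get("spark.scheduler.mode"))  # "FAIR", set by the builder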
def main(dialect: str, command_type: sqlmesh.engines.commands.CommandType, ddl_concurrent_tasks: int, payload_path: str) -> None:

Entry point for the Spark submit app. Validates that the dialect is either "databricks" or "spark", configures logging, and looks up the handler for the given command type, raising NotSupportedError if either check fails. It then obtains a SparkSession, wraps it in a SnapshotEvaluator backed by a Spark engine adapter, reads the command payload (through SparkFiles for "spark", or by copying the remote file to a local temporary directory with DBUtils for "databricks"), runs the handler, and closes the evaluator.
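
For the "spark" dialect, main() assumes the payload file was distributed with the job (for example via the --files option of spark-submit), so SparkFiles.get() can resolve the node-local copy by name. A minimal sketch of that mechanism; the file path and name are illustrative assumptions:

from pyspark import SparkFiles
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Ship a file to the cluster; the local path is an illustrative assumption.
spark.sparkContext.addFile("/tmp/payload.json")

# Resolve the node-local copy by file name, as main() does with payload_path.
with open(SparkFiles.get("payload.json"), "r", encoding="utf-8") as fd:
    command_payload = fd.read()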