sqlmesh.engines.spark.app
import argparse
import logging
import os
import tempfile

from pyspark import SparkFiles
from pyspark.sql import SparkSession

from sqlmesh.core.engine_adapter import create_engine_adapter
from sqlmesh.core.snapshot import SnapshotEvaluator
from sqlmesh.engines import commands
from sqlmesh.engines.spark.db_api import spark_session as spark_session_db
from sqlmesh.engines.spark.db_api.errors import NotSupportedError
from sqlmesh.utils.errors import SQLMeshError

logger = logging.getLogger(__name__)


def get_or_create_spark_session(dialect: str) -> SparkSession:
    if dialect == "databricks":
        spark = SparkSession.getActiveSession()
        if not spark:
            raise SQLMeshError("Could not find an active SparkSession.")
        return spark
    return (
        SparkSession.builder.config("spark.scheduler.mode", "FAIR")
        .enableHiveSupport()
        .getOrCreate()
    )


def main(
    dialect: str, command_type: commands.CommandType, ddl_concurrent_tasks: int, payload_path: str
) -> None:
    if dialect not in ("databricks", "spark"):
        raise NotSupportedError(
            f"Dialect '{dialect}' not supported. Must be either 'databricks' or 'spark'"
        )
    logging.basicConfig(
        format="%(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)",
        level=logging.INFO,
    )
    command_handler = commands.COMMAND_HANDLERS.get(command_type)
    if not command_handler:
        raise NotSupportedError(f"Command '{command_type.value}' not supported")

    spark = get_or_create_spark_session(dialect)

    evaluator = SnapshotEvaluator(
        create_engine_adapter(
            lambda: spark_session_db.connection(spark),
            dialect,
            multithreaded=ddl_concurrent_tasks > 1,
        ),
        ddl_concurrent_tasks=ddl_concurrent_tasks,
    )
    if dialect == "spark":
        with open(SparkFiles.get(payload_path), "r", encoding="utf-8") as payload_fd:
            command_payload = payload_fd.read()
    else:
        from pyspark.dbutils import DBUtils  # type: ignore

        dbutils = DBUtils(spark)
        with tempfile.TemporaryDirectory() as tmp:
            local_payload_path = os.path.join(tmp, commands.COMMAND_PAYLOAD_FILE_NAME)
            dbutils.fs.cp(payload_path, f"file://{local_payload_path}")
            with open(local_payload_path, "r", encoding="utf-8") as payload_fd:
                command_payload = payload_fd.read()
    logger.info("Command payload:\n %s", command_payload)
    command_handler(evaluator, command_payload)

    evaluator.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="SQLMesh Spark Submit App")
    parser.add_argument(
        "--dialect",
        help="The dialect to use when creating the engine adapter.",
    )
    parser.add_argument(
        "--command_type",
        type=commands.CommandType,
        choices=list(commands.CommandType),
        help="The type of command that is being run",
    )
    parser.add_argument(
        "--ddl_concurrent_tasks",
        type=int,
        default=1,
        help="The number of ddl concurrent tasks to use. Default to 1.",
    )
    parser.add_argument(
        "--payload_path",
        help="Path to the payload object. Can be a local or remote path.",
    )
    args = parser.parse_args()
    main(args.dialect, args.command_type, args.ddl_concurrent_tasks, args.payload_path)
def get_or_create_spark_session(dialect: str) -> pyspark.sql.session.SparkSession:
def get_or_create_spark_session(dialect: str) -> SparkSession:
    if dialect == "databricks":
        spark = SparkSession.getActiveSession()
        if not spark:
            raise SQLMeshError("Could not find an active SparkSession.")
        return spark
    return (
        SparkSession.builder.config("spark.scheduler.mode", "FAIR")
        .enableHiveSupport()
        .getOrCreate()
    )
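A minimal sketch of calling this helper outside Databricks, assuming a local Spark installation with Hive support is available. The "spark" branch builds (or reuses) a session configured with the FAIR scheduler; on Databricks an active session must already exist or a SQLMeshError is raised.

# Minimal sketch, assuming a local Spark installation with Hive support.
from sqlmesh.engines.spark.app import get_or_create_spark_session

spark = get_or_create_spark_session("spark")
# The builder sets the FAIR scheduler, so the resulting context should report it.
print(spark.sparkContext.getConf().get("spark.scheduler.mode"))  # expected: FAIR
spark.stop()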
def main(dialect: str, command_type: sqlmesh.engines.commands.CommandType, ddl_concurrent_tasks: int, payload_path: str) -> None:
def main(
    dialect: str, command_type: commands.CommandType, ddl_concurrent_tasks: int, payload_path: str
) -> None:
    if dialect not in ("databricks", "spark"):
        raise NotSupportedError(
            f"Dialect '{dialect}' not supported. Must be either 'databricks' or 'spark'"
        )
    logging.basicConfig(
        format="%(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)",
        level=logging.INFO,
    )
    command_handler = commands.COMMAND_HANDLERS.get(command_type)
    if not command_handler:
        raise NotSupportedError(f"Command '{command_type.value}' not supported")

    spark = get_or_create_spark_session(dialect)

    evaluator = SnapshotEvaluator(
        create_engine_adapter(
            lambda: spark_session_db.connection(spark),
            dialect,
            multithreaded=ddl_concurrent_tasks > 1,
        ),
        ddl_concurrent_tasks=ddl_concurrent_tasks,
    )
    if dialect == "spark":
        with open(SparkFiles.get(payload_path), "r", encoding="utf-8") as payload_fd:
            command_payload = payload_fd.read()
    else:
        from pyspark.dbutils import DBUtils  # type: ignore

        dbutils = DBUtils(spark)
        with tempfile.TemporaryDirectory() as tmp:
            local_payload_path = os.path.join(tmp, commands.COMMAND_PAYLOAD_FILE_NAME)
            dbutils.fs.cp(payload_path, f"file://{local_payload_path}")
            with open(local_payload_path, "r", encoding="utf-8") as payload_fd:
                command_payload = payload_fd.read()
    logger.info("Command payload:\n %s", command_payload)
    command_handler(evaluator, command_payload)

    evaluator.close()
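As a rough illustration, the sketch below drives the entry point programmatically rather than through the argparse block at the bottom of the module. The command type is picked generically from the enum and the payload path is a placeholder; both are assumptions, and the "spark" branch expects the payload file to have already been distributed (for example via SparkContext.addFile) so that SparkFiles.get can resolve it.

# Illustrative sketch only: the command type and payload path are placeholders,
# not values mandated by the module. Assumes a serialized command payload was
# already distributed under commands.COMMAND_PAYLOAD_FILE_NAME, e.g. via
# spark.sparkContext.addFile, which is what the "spark" branch expects.
from sqlmesh.engines import commands
from sqlmesh.engines.spark.app import main

command_type = next(iter(commands.CommandType))  # any supported command, for illustration
main(
    dialect="spark",
    command_type=command_type,
    ddl_concurrent_tasks=1,
    payload_path=commands.COMMAND_PAYLOAD_FILE_NAME,  # hypothetical payload file name
)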