Source code for luigi.contrib.pyspark_runner
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2012-2020 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
The PySpark runner program.

This module is run by spark-submit for PySparkTask jobs.
The first argument is the path to the pickled PySparkTask instance;
the remaining arguments are those returned by PySparkTask.app_options().
"""
import abc
import logging
import os
import pickle
import sys
from luigi import configuration
# This prevents modules in this script's directory from shadowing globally installed packages
sys.path.append(sys.path.pop(0))
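
# What the line above does, illustrated with hypothetical values: if sys.path
# starts as
#   ['/opt/spark/work/driver', '/usr/lib/python3/dist-packages', ...]
# it becomes
#   ['/usr/lib/python3/dist-packages', ..., '/opt/spark/work/driver']
# i.e. the script's own directory is moved to the end of the search path, so
# files shipped alongside this runner cannot shadow standard or site packages.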
class _SparkEntryPoint(metaclass=abc.ABCMeta):
def __init__(self, conf):
self.conf = conf
@abc.abstractmethod
def __enter__(self):
pass
@abc.abstractmethod
def __exit__(self, exc_type, exc_val, exc_tb):
pass
class SparkContextEntryPoint(_SparkEntryPoint):
sc = None
def __enter__(self):
from pyspark import SparkContext
self.sc = SparkContext(conf=self.conf)
return self.sc, self.sc
def __exit__(self, exc_type, exc_val, exc_tb):
self.sc.stop()
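
# Minimal usage sketch for the entry point above (commented out so the module's
# behaviour is unchanged; the master and app name are example assumptions):
#
#   from pyspark import SparkConf
#   conf = SparkConf().setMaster('local[1]').setAppName('entry-point-demo')
#   with SparkContextEntryPoint(conf) as (entry_point, sc):
#       # both values are the same SparkContext for this entry point
#       print(sc.parallelize(range(10)).sum())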
class SparkSessionEntryPoint(_SparkEntryPoint):
spark = None
def _check_major_spark_version(self):
from pyspark import __version__ as spark_version
major_version = int(spark_version.split('.')[0])
if major_version < 2:
raise RuntimeError(
'''
                Apache Spark {} does not support the SparkSession entry point.
                Set 'pyspark_runner.use_spark_session' to 'False' to fall back to the SparkContext-based runner.
'''.format(spark_version)
)
def __enter__(self):
self._check_major_spark_version()
from pyspark.sql import SparkSession
self.spark = SparkSession \
.builder \
.config(conf=self.conf) \
.enableHiveSupport() \
.getOrCreate()
return self.spark, self.spark.sparkContext
def __exit__(self, exc_type, exc_val, exc_tb):
self.spark.stop()
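
# Usage sketch for the session entry point (commented out; the builder settings
# are example assumptions, and enableHiveSupport() above requires Hive support
# in the Spark build):
#
#   from pyspark import SparkConf
#   conf = SparkConf().setMaster('local[1]').setAppName('session-demo')
#   with SparkSessionEntryPoint(conf) as (spark, sc):
#       # spark is the SparkSession, sc is its underlying SparkContext
#       print(spark.range(10).count())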
class AbstractPySparkRunner(object):
_entry_point_class = None
def __init__(self, job, *args):
        # Append the job's directory to sys.path so that, on unpickling,
        # the module in which the task class resides can be imported dynamically
sys.path.append(os.path.dirname(job))
with open(job, "rb") as fd:
self.job = pickle.load(fd)
self.args = args
def run(self):
from pyspark import SparkConf
conf = SparkConf()
self.job.setup(conf)
with self._entry_point_class(conf=conf) as (entry_point, sc):
self.job.setup_remote(sc)
self.job.main(entry_point, *self.args)
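
# The unpickled job is expected to provide the three hooks used in run() above,
# matching luigi.contrib.spark.PySparkTask. A bare-bones sketch (class name and
# bodies are illustrative assumptions):
#
#   class MyJob:
#       def setup(self, conf):          # tweak the SparkConf before start-up
#           conf.setAppName('my-job')
#
#       def setup_remote(self, sc):     # e.g. ship extra .py files via sc.addPyFile
#           pass
#
#       def main(self, sc, *args):      # the actual Spark job
#           print(sc.parallelize(range(10)).count())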
def _pyspark_runner_with(name, entry_point_class):
return type(name, (AbstractPySparkRunner,), {'_entry_point_class': entry_point_class})
PySparkRunner = _pyspark_runner_with('PySparkRunner', SparkContextEntryPoint)
PySparkSessionRunner = _pyspark_runner_with('PySparkSessionRunner', SparkSessionEntryPoint)
def _use_spark_session():
return bool(configuration.get_config().get('pyspark_runner', "use_spark_session", False))
def _get_runner_class():
if _use_spark_session():
return PySparkSessionRunner
return PySparkRunner
if __name__ == '__main__':
logging.basicConfig(level=logging.WARN)
_get_runner_class()(*sys.argv[1:]).run()
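
# Selecting the runner via luigi configuration (a sketch of luigi.cfg; the
# section and option names are those read by _use_spark_session above):
#
#   [pyspark_runner]
#   use_spark_session: true
#
# Note that _use_spark_session() applies bool() to the raw config value, so with
# a cfg-style config any non-empty string appears to select PySparkSessionRunner;
# omit the option entirely to keep the SparkContext-based PySparkRunner.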