Source code for luigi.contrib.pyspark_runner

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
The pyspark program.

This module will be run by spark-submit for PySparkTask jobs.

The first argument is a path to the pickled instance of the PySparkTask,
other arguments are the ones returned by PySparkTask.app_options().
"""

from __future__ import print_function

import abc

try:
    import cPickle as pickle
except ImportError:
    import pickle
import logging
import sys
import os

from luigi import configuration
from luigi import six

# this prevents the modules in the directory of this script from shadowing global packages
sys.path.append(sys.path.pop(0))

@six.add_metaclass(abc.ABCMeta)
class _SparkEntryPoint(object):
    def __init__(self, conf):
        self.conf = conf

    @abc.abstractmethod
    def __enter__(self):
        pass

    @abc.abstractmethod
    def __exit__(self, exc_type, exc_val, exc_tb):
        pass


class SparkContextEntryPoint(_SparkEntryPoint):
    sc = None

    def __enter__(self):
        from pyspark import SparkContext
        self.sc = SparkContext(conf=self.conf)
        return self.sc, self.sc

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.sc.stop()


class SparkSessionEntryPoint(_SparkEntryPoint):
    spark = None

    def _check_major_spark_version(self):
        from pyspark import __version__ as spark_version
        major_version = int(spark_version.split('.')[0])
        if major_version < 2:
            raise RuntimeError(
                '''
                Apache Spark {} does not support SparkSession entrypoint.
                Try to set 'pyspark_runner.use_spark_session' to 'False' and switch to old-style syntax
                '''.format(spark_version)
            )

    def __enter__(self):
        self._check_major_spark_version()
        from pyspark.sql import SparkSession
        self.spark = SparkSession \
            .builder \
            .config(conf=self.conf) \
            .enableHiveSupport() \
            .getOrCreate()

        return self.spark, self.spark.sparkContext

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.spark.stop()
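
# Note: both entry points return a 2-tuple from __enter__ -- (SparkContext, SparkContext)
# or (SparkSession, SparkContext) -- which AbstractPySparkRunner.run() below unpacks as
# (entry_point, sc) before calling the task's setup_remote() and main().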

class AbstractPySparkRunner(object):
    _entry_point_class = None

    def __init__(self, job, *args):
        # Append job directory to PYTHON_PATH to enable dynamic import
        # of the module in which the class resides on unpickling
        sys.path.append(os.path.dirname(job))
        with open(job, "rb") as fd:
            self.job = pickle.load(fd)
        self.args = args

    def run(self):
        from pyspark import SparkConf
        conf = SparkConf()
        self.job.setup(conf)
        with self._entry_point_class(conf=conf) as (entry_point, sc):
            self.job.setup_remote(sc)
            self.job.main(entry_point, *self.args)
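
# Illustrative sketch only (WordCount and its arguments are hypothetical, not part of
# this module): a task run through PySparkRunner receives the SparkContext as the
# first argument of main(), followed by its app_options():
#
#   class WordCount(PySparkTask):
#       def main(self, sc, *args):
#           counts = (sc.textFile(args[0])
#                       .flatMap(lambda line: line.split())
#                       .map(lambda word: (word, 1))
#                       .reduceByKey(lambda a, b: a + b))
#           counts.saveAsTextFile(args[1])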

def _pyspark_runner_with(name, entry_point_class):
    return type(name, (AbstractPySparkRunner,), {'_entry_point_class': entry_point_class})


PySparkRunner = _pyspark_runner_with('PySparkRunner', SparkContextEntryPoint)
PySparkSessionRunner = _pyspark_runner_with('PySparkSessionRunner', SparkSessionEntryPoint)


def _use_spark_session():
    return bool(configuration.get_config().get('pyspark_runner', "use_spark_session", False))


def _get_runner_class():
    if _use_spark_session():
        return PySparkSessionRunner
    return PySparkRunner


if __name__ == '__main__':
    logging.basicConfig(level=logging.WARN)
    _get_runner_class()(*sys.argv[1:]).run()
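
# A minimal sketch of the client-side configuration (e.g. in luigi.cfg) that makes
# _get_runner_class() above pick PySparkSessionRunner instead of PySparkRunner:
#
#   [pyspark_runner]
#   use_spark_session: true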