# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import abc
import collections
import logging
import operator
import os
import re
import subprocess
import sys
import tempfile
import warnings
from luigi import six
import luigi
import luigi.contrib.hadoop
from luigi.contrib.hdfs import get_autoconfig_client
from luigi.target import FileAlreadyExists, FileSystemTarget
from luigi.task import flatten
if six.PY3:
unicode = str
logger = logging.getLogger('luigi-interface')
class HiveCommandError(RuntimeError):
def __init__(self, message, out=None, err=None):
super(HiveCommandError, self).__init__(message, out, err)
self.message = message
self.out = out
self.err = err
def load_hive_cmd():
return luigi.configuration.get_config().get('hive', 'command', 'hive').split(' ')
def get_hive_syntax():
return luigi.configuration.get_config().get('hive', 'release', 'cdh4')
def get_hive_warehouse_location():
return luigi.configuration.get_config().get('hive', 'warehouse_location', '/user/hive/warehouse')
def get_ignored_file_masks():
return luigi.configuration.get_config().get('hive', 'ignored_file_masks', None)
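# Illustrative sketch (not part of the module): the getters above read a `[hive]`
# section of the luigi config; the keys are the ones used above, the values below
# are examples only.
#
#   [hive]
#   command=hive
#   release=apache
#   warehouse_location=/user/hive/warehouse
#   ignored_file_masks=(\.tmp.*|_SUCCESS)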
def run_hive(args, check_return_code=True):
    """
    Runs `hive` from the command line, passing in the given args and
    returning stdout.

    With the Apache release of Hive, some of the table existence checks
    (which are done using DESCRIBE) do not exit with a return code of 0,
    so we need an option to ignore the return code and just return stdout for parsing.
    """
cmd = load_hive_cmd() + args
p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = p.communicate()
if check_return_code and p.returncode != 0:
raise HiveCommandError("Hive command: {0} failed with error code: {1}".format(" ".join(cmd), p.returncode),
stdout, stderr)
return stdout.decode('utf-8')
def run_hive_cmd(hivecmd, check_return_code=True):
"""
Runs the given hive query and returns stdout.
"""
return run_hive(['-e', hivecmd], check_return_code)
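# Illustrative sketch: run_hive_cmd shells out to `hive -e <query>` and returns
# stdout as text. The database/table names below are hypothetical.
#
#   out = run_hive_cmd('use my_db; show tables like "my_table";')
#   if out and 'my_table' in out.lower():
#       print('table exists')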
def run_hive_script(script):
"""
Runs the contents of the given script in hive and returns stdout.
"""
if not os.path.isfile(script):
raise RuntimeError("Hive script: {0} does not exist.".format(script))
return run_hive(['-f', script])
def _is_ordered_dict(dikt):
if isinstance(dikt, collections.OrderedDict):
return True
if sys.version_info >= (3, 7):
return isinstance(dikt, dict)
return False
def _validate_partition(partition):
"""
If partition is set and its size is more than one and not ordered,
then we're unable to restore its path in the warehouse
"""
if (
partition
and len(partition) > 1
and not _is_ordered_dict(partition)
):
raise ValueError('Unable to restore table/partition location')
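# Illustrative sketch: with more than one partition column, a plain (unordered) dict
# cannot tell us whether the path is .../year=2018/month=01 or .../month=01/year=2018,
# hence the OrderedDict (or Python 3.7+ dict) requirement. Column names are hypothetical.
#
#   from collections import OrderedDict
#   _validate_partition(OrderedDict([('year', '2018'), ('month', '01')]))  # ok
#   _validate_partition({'year': '2018', 'month': '01'})  # ValueError on Python < 3.7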
@six.add_metaclass(abc.ABCMeta)
class HiveClient(object): # interface
    @abc.abstractmethod
def table_location(self, table, database='default', partition=None):
"""
Returns location of db.table (or db.table.partition). partition is a dict of partition key to
value.
"""
pass
    @abc.abstractmethod
def table_schema(self, table, database='default'):
"""
Returns list of [(name, type)] for each column in database.table.
"""
pass
    @abc.abstractmethod
def table_exists(self, table, database='default', partition=None):
"""
Returns true if db.table (or db.table.partition) exists. partition is a dict of partition key to
value.
"""
pass
    @abc.abstractmethod
def partition_spec(self, partition):
""" Turn a dict into a string partition specification """
pass
class HiveCommandClient(HiveClient):
"""
Uses `hive` invocations to find information.
"""
    def table_location(self, table, database='default', partition=None):
cmd = "use {0}; describe formatted {1}".format(database, table)
if partition is not None:
cmd += " PARTITION ({0})".format(self.partition_spec(partition))
stdout = run_hive_cmd(cmd)
for line in stdout.split("\n"):
if "Location:" in line:
return line.split("\t")[1]
    def table_exists(self, table, database='default', partition=None):
if partition is None:
stdout = run_hive_cmd('use {0}; show tables like "{1}";'.format(database, table))
return stdout and table.lower() in stdout
else:
stdout = run_hive_cmd("""use %s; show partitions %s partition
(%s)""" % (database, table, self.partition_spec(partition)))
if stdout:
return True
else:
return False
    def table_schema(self, table, database='default'):
describe = run_hive_cmd("use {0}; describe {1}".format(database, table))
if not describe or "does not exist" in describe:
return None
return [tuple([x.strip() for x in line.strip().split("\t")]) for line in describe.strip().split("\n")]
    def partition_spec(self, partition):
        """
        Turns a dict into a Hive partition specification string.
        """
return ','.join(["`{0}`='{1}'".format(k, v) for (k, v) in
sorted(six.iteritems(partition), key=operator.itemgetter(0))])
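# Illustrative sketch: partition_spec sorts the keys and renders them in HiveQL form,
# e.g. (hypothetical values):
#
#   HiveCommandClient().partition_spec({'year': '2018', 'month': '01'})
#   # -> "`month`='01',`year`='2018'"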
class ApacheHiveCommandClient(HiveCommandClient):
    """
    A subclass of HiveCommandClient that (in some cases) ignores the return code from
    the hive command so that we can still parse the output.
    """
    def table_schema(self, table, database='default'):
describe = run_hive_cmd("use {0}; describe {1}".format(database, table), False)
if not describe or "Table not found" in describe:
return None
return [tuple([x.strip() for x in line.strip().split("\t")]) for line in describe.strip().split("\n")]
class HiveThriftContext(object):
"""
Context manager for hive metastore client.
"""
def __enter__(self):
try:
from thrift.transport import TSocket
from thrift.transport import TTransport
from thrift.protocol import TBinaryProtocol
# Note that this will only work with a CDH release.
# This uses the thrift bindings generated by the ThriftHiveMetastore service in Beeswax.
# If using the Apache release of Hive this import will fail.
from hive_metastore import ThriftHiveMetastore
config = luigi.configuration.get_config()
host = config.get('hive', 'metastore_host')
port = config.getint('hive', 'metastore_port')
transport = TSocket.TSocket(host, port)
transport = TTransport.TBufferedTransport(transport)
protocol = TBinaryProtocol.TBinaryProtocol(transport)
transport.open()
self.transport = transport
return ThriftHiveMetastore.Client(protocol)
except ImportError as e:
raise Exception('Could not import Hive thrift library:' + str(e))
def __exit__(self, exc_type, exc_val, exc_tb):
self.transport.close()
class WarehouseHiveClient(HiveClient):
    """
    Client for managed tables that makes decisions based on the presence of the
    corresponding directory in HDFS.
    """
def __init__(self, hdfs_client=None, warehouse_location=None):
self.hdfs_client = hdfs_client or get_autoconfig_client()
self.warehouse_location = warehouse_location or get_hive_warehouse_location()
    def table_schema(self, table, database='default'):
return NotImplemented
    def table_location(self, table, database='default', partition=None):
return os.path.join(
self.warehouse_location,
database + '.db',
table,
self.partition_spec(partition)
)
    def table_exists(self, table, database='default', partition=None):
        """
        The table/partition is considered to exist if the corresponding path in HDFS
        exists and contains at least one file not matching the pattern set in
        `ignored_file_masks`.
        """
path = self.table_location(table, database, partition)
if self.hdfs_client.exists(path):
ignored_files = get_ignored_file_masks()
if ignored_files is None:
return True
filenames = self.hdfs_client.listdir(path)
pattern = re.compile(ignored_files)
for filename in filenames:
if not pattern.match(filename):
return True
return False
    def partition_spec(self, partition):
_validate_partition(partition)
return '/'.join([
'{}={}'.format(k, v) for (k, v) in six.iteritems(partition or {})
])
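# Illustrative sketch: for WarehouseHiveClient the partition spec is a path fragment,
# so table_location resolves to a warehouse path like (hypothetical names):
#
#   WarehouseHiveClient().table_location('sales', 'my_db', {'date': '2013-01-25'})
#   # -> '/user/hive/warehouse/my_db.db/sales/date=2013-01-25'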
def get_default_client():
syntax = get_hive_syntax()
if syntax == "apache":
return ApacheHiveCommandClient()
elif syntax == "metastore":
return MetastoreClient()
elif syntax == 'warehouse':
return WarehouseHiveClient()
else:
return HiveCommandClient()
client = get_default_client()
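# Illustrative sketch: the module-level `client` can be used for ad-hoc checks; which
# concrete client you get depends on the `release` setting. Names are hypothetical.
#
#   if client.table_exists('my_table', database='my_db'):
#       print(client.table_location('my_table', database='my_db'))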
class HiveQueryTask(luigi.contrib.hadoop.BaseHadoopJobTask):
"""
Task to run a hive query.
"""
# by default, we let hive figure these out.
n_reduce_tasks = None
bytes_per_reducer = None
reducers_max = None
    @abc.abstractmethod
def query(self):
""" Text of query to run in hive """
raise RuntimeError("Must implement query!")
    def hiverc(self):
"""
        Location of an rc file to run before the query.
        If the hiverc-location key is specified in luigi.cfg, defaults to the value there;
        otherwise returns None.
        Returning a list of rc files will load all of them in order.
"""
return luigi.configuration.get_config().get('hive', 'hiverc-location', default=None)
    def hivevars(self):
"""
Returns a dict of key=value settings to be passed along
to the hive command line via --hivevar.
        This option can be used as a separate namespace for script-local variables.
See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+VariableSubstitution
"""
return {}
    def hiveconfs(self):
"""
Returns a dict of key=value settings to be passed along
to the hive command line via --hiveconf. By default, sets
mapred.job.name to task_id and if not None, sets:
* mapred.reduce.tasks (n_reduce_tasks)
* mapred.fairscheduler.pool (pool) or mapred.job.queue.name (pool)
* hive.exec.reducers.bytes.per.reducer (bytes_per_reducer)
* hive.exec.reducers.max (reducers_max)
"""
jcs = {}
jcs['mapred.job.name'] = "'" + self.task_id + "'"
if self.n_reduce_tasks is not None:
jcs['mapred.reduce.tasks'] = self.n_reduce_tasks
if self.pool is not None:
# Supporting two schedulers: fair (default) and capacity using the same option
scheduler_type = luigi.configuration.get_config().get('hadoop', 'scheduler', 'fair')
if scheduler_type == 'fair':
jcs['mapred.fairscheduler.pool'] = self.pool
elif scheduler_type == 'capacity':
jcs['mapred.job.queue.name'] = self.pool
if self.bytes_per_reducer is not None:
jcs['hive.exec.reducers.bytes.per.reducer'] = self.bytes_per_reducer
if self.reducers_max is not None:
jcs['hive.exec.reducers.max'] = self.reducers_max
return jcs
    def job_runner(self):
return HiveQueryRunner()
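# Illustrative sketch: a minimal HiveQueryTask subclass; only query() is required.
# The table names, partition column and query below are hypothetical.
#
#   class MyDailyReport(HiveQueryTask):
#       date = luigi.DateParameter()
#
#       def query(self):
#           return ("INSERT OVERWRITE TABLE report PARTITION (dt='{0}') "
#                   "SELECT user, count(*) FROM events WHERE dt='{0}' GROUP BY user;".format(self.date))
#
#       def output(self):
#           return HivePartitionTarget('report', {'dt': str(self.date)})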
class HiveQueryRunner(luigi.contrib.hadoop.JobRunner):
"""
Runs a HiveQueryTask by shelling out to hive.
"""
    def prepare_outputs(self, job):
"""
Called before job is started.
If output is a `FileSystemTarget`, create parent directories so the hive command won't fail
"""
outputs = flatten(job.output())
for o in outputs:
if isinstance(o, FileSystemTarget):
parent_dir = os.path.dirname(o.path)
if parent_dir and not o.fs.exists(parent_dir):
logger.info("Creating parent directory %r", parent_dir)
try:
# there is a possible race condition
# which needs to be handled here
o.fs.mkdir(parent_dir)
except FileAlreadyExists:
pass
    def get_arglist(self, f_name, job):
arglist = load_hive_cmd() + ['-f', f_name]
hiverc = job.hiverc()
if hiverc:
if isinstance(hiverc, str):
hiverc = [hiverc]
for rcfile in hiverc:
arglist += ['-i', rcfile]
hiveconfs = job.hiveconfs()
if hiveconfs:
for k, v in six.iteritems(hiveconfs):
arglist += ['--hiveconf', '{0}={1}'.format(k, v)]
hivevars = job.hivevars()
if hivevars:
for k, v in six.iteritems(hivevars):
arglist += ['--hivevar', '{0}={1}'.format(k, v)]
logger.info(arglist)
return arglist
    def run_job(self, job, tracking_url_callback=None):
if tracking_url_callback is not None:
warnings.warn("tracking_url_callback argument is deprecated, task.set_tracking_url is "
"used instead.", DeprecationWarning)
self.prepare_outputs(job)
with tempfile.NamedTemporaryFile() as f:
query = job.query()
if isinstance(query, unicode):
query = query.encode('utf8')
f.write(query)
f.flush()
arglist = self.get_arglist(f.name, job)
return luigi.contrib.hadoop.run_and_track_hadoop_job(arglist, job.set_tracking_url)
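# Illustrative sketch: for a job with one hiverc file and the default hiveconfs,
# get_arglist produces a command line roughly like (paths and task id hypothetical):
#
#   ['hive', '-f', '/tmp/tmpXXXXXX', '-i', '/etc/hive/hiverc',
#    '--hiveconf', "mapred.job.name='MyDailyReport_2013_01_25_abc123'"]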
class HivePartitionTarget(luigi.Target):
"""
Target representing Hive table or Hive partition
"""
def __init__(self, table, partition, database='default', fail_missing_table=True, client=None):
"""
@param table: Table name
@type table: str
@param partition: partition specificaton in form of
dict of {"partition_column_1": "partition_value_1", "partition_column_2": "partition_value_2", ... }
If `partition` is `None` or `{}` then target is Hive nonpartitioned table
@param database: Database name
@param fail_missing_table: flag to ignore errors raised due to table nonexistence
@param client: `HiveCommandClient` instance. Default if `client is None`
"""
self.database = database
self.table = table
self.partition = partition
self.client = client or get_default_client()
self.fail_missing_table = fail_missing_table
    def exists(self):
        """
        Returns `True` if the partition/table exists.
        """
try:
logger.debug(
"Checking Hive table '{d}.{t}' for partition {p}".format(
d=self.database,
t=self.table,
p=str(self.partition or {})
)
)
return self.client.table_exists(self.table, self.database, self.partition)
except HiveCommandError:
if self.fail_missing_table:
raise
else:
if self.client.table_exists(self.table, self.database):
# a real error occurred
raise
else:
# oh the table just doesn't exist
return False
@property
def path(self):
"""
        Returns the path for this HivePartitionTarget's data.
"""
location = self.client.table_location(self.table, self.database, self.partition)
if not location:
raise Exception("Couldn't find location for table: {0}".format(str(self)))
return location
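# Illustrative sketch: using HivePartitionTarget directly (table and partition
# are hypothetical):
#
#   target = HivePartitionTarget('events', {'dt': '2013-01-25'}, database='my_db')
#   if target.exists():
#       print(target.path)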
class HiveTableTarget(HivePartitionTarget):
    """
    Target representing a non-partitioned Hive table.
    """
def __init__(self, table, database='default', client=None):
super(HiveTableTarget, self).__init__(
table=table,
partition=None,
database=database,
fail_missing_table=False,
client=client,
)
class ExternalHiveTask(luigi.ExternalTask):
"""
External task that depends on a Hive table/partition.
"""
database = luigi.Parameter(default='default')
table = luigi.Parameter()
partition = luigi.DictParameter(
default={},
description='Python dictionary specifying the target partition e.g. {"date": "2013-01-25"}'
)
    def output(self):
return HivePartitionTarget(
database=self.database,
table=self.table,
partition=self.partition,
)
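# Illustrative sketch: depending on an externally produced partition from another task
# (database, table and partition values are hypothetical):
#
#   class MyConsumer(luigi.Task):
#       def requires(self):
#           return ExternalHiveTask(database='my_db', table='events',
#                                   partition={'dt': '2013-01-25'})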