Source code for luigi.task_register

# -*- coding: utf-8 -*-
# Copyright 2012-2015 Spotify AB
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Define the centralized register of all :class:`~luigi.task.Task` classes.
"""

import abc

import logging
# Module-level logger, registered under the fixed name 'luigi-interface' rather than ``__name__``.
logger = logging.getLogger('luigi-interface')

class TaskClassException(Exception):
    """Base class for the errors raised when a task class is looked up by name."""
class TaskClassNotFoundException(TaskClassException):
    """Raised when no registered task class matches the requested name."""
class TaskClassAmbigiousException(TaskClassException):
    """Raised when the requested task name maps to several unrelated classes.

    The spelling of the class name is part of the public interface and is
    therefore left as it is.
    """
class Register(abc.ABCMeta):
    """
    The Metaclass of :py:class:`Task`.

    Acts as a global registry of Tasks with the following properties:

    1. Cache instances of objects so that eg. ``X(1, 2, 3)`` always returns the
       same object.
    2. Keep track of all subclasses of :py:class:`Task` and expose them.
    """

    # Maps ``(cls, tuple(param_values))`` to the instance built for that key.
    # ``None`` means caching is switched off (see ``disable_instance_cache``).
    __instance_cache = {}
    # Maps a module name to a namespace; consulted by ``_get_namespace``.
    _default_namespace_dict = {}
    # Every class created through this metaclass, in creation order.
    _reg = []
    AMBIGUOUS_CLASS = object()  # Placeholder denoting an error
    """If this value is returned by :py:meth:`_get_reg` then there is an
    ambiguous task name (two :py:class:`Task` have the same name). This denotes
    an error."""

    def __new__(metacls, classname, bases, classdict, **kwargs):
        """
        Custom class creation for namespacing.

        Also register all subclasses.

        When the set or inherited namespace evaluates to ``None``, set the task namespace to
        whatever the currently declared namespace is.

        :return: the newly created class, after it has been appended to ``_reg``
        """
        cls = super(Register, metacls).__new__(metacls, classname, bases, classdict, **kwargs)
        # Remember which namespace was configured for the defining module at
        # the moment the class statement was executed.
        cls._namespace_at_class_time = metacls._get_namespace(cls.__module__)
        metacls._reg.append(cls)
        return cls

    def __call__(cls, *args, **kwargs):
        """
        Custom class instantiation utilizing instance cache.

        If a Task has already been instantiated with the same parameters,
        the previous instance is returned to reduce number of object instances.

        The cache is bypassed when it has been disabled or when the parameter
        values cannot be hashed.
        """
        def instantiate():
            # Explicit two-argument ``super``: the zero-argument form is not
            # usable inside this nested, argument-less function.
            return super(Register, cls).__call__(*args, **kwargs)

        h = cls.__instance_cache

        if h is None:  # disabled
            return instantiate()

        params = cls.get_params()
        param_values = cls.get_param_values(params, args, kwargs)

        k = (cls, tuple(param_values))

        try:
            hash(k)
        except TypeError:
            logger.debug("Not all parameter values are hashable so instance isn't coming from the cache")
            return instantiate()  # unhashable types in parameters

        if k not in h:
            h[k] = instantiate()

        return h[k]

    @classmethod
    def clear_instance_cache(cls):
        """
        Clear/Reset the instance cache.

        A fresh empty dict is bound, which also re-enables a cache that was
        previously disabled.
        """
        cls.__instance_cache = {}

    @classmethod
    def disable_instance_cache(cls):
        """
        Disables the instance cache.

        ``__call__`` treats the ``None`` stored here as "always instantiate".
        """
        cls.__instance_cache = None

    @property
    def task_family(cls):
        """
        Internal note: This function will be deleted soon.

        :return: ``cls.__name__``, prefixed with ``"<namespace>."`` when
                 ``cls.get_task_namespace()`` is non-empty
        """
        task_namespace = cls.get_task_namespace()
        if not task_namespace:
            return cls.__name__
        else:
            return f"{task_namespace}.{cls.__name__}"

    @classmethod
    def _get_reg(cls):
        """Return all of the registered classes.

        :return: a ``dict`` of task_family -> class; a family claimed by
                 unrelated classes maps to :py:attr:`AMBIGUOUS_CLASS`
        """
        # We have to do this on-demand in case task names have changed later
        reg = dict()
        for task_cls in cls._reg:
            if not task_cls._visible_in_registry:
                continue

            name = task_cls.get_task_family()
            if name in reg and (
                    reg[name] is Register.AMBIGUOUS_CLASS or  # Check so issubclass doesn't crash
                    not issubclass(task_cls, reg[name])):
                # Registering two different classes - this means we can't instantiate them by name
                # The only exception is if one class is a subclass of the other. In that case, we
                # instantiate the most-derived class (this fixes some issues with decorator wrappers).
                reg[name] = Register.AMBIGUOUS_CLASS
            else:
                reg[name] = task_cls

        return reg

    @classmethod
    def _set_reg(cls, reg):
        """The writing complement of _get_reg

        :param reg: a ``dict`` as returned by :py:meth:`_get_reg`; entries
                    equal to the ambiguity placeholder are dropped
        """
        cls._reg = [task_cls for task_cls in reg.values() if task_cls is not cls.AMBIGUOUS_CLASS]

    @classmethod
    def task_names(cls):
        """
        List of task names as strings

        :return: the sorted task families currently present in the registry
        """
        return sorted(cls._get_reg().keys())

    @classmethod
    def tasks_str(cls):
        """
        Human-readable register contents dump.

        :return: the task names joined with ``,``
        """
        return ','.join(cls.task_names())

    @classmethod
    def get_task_cls(cls, name):
        """
        Returns an unambiguous class or raises an exception.

        :param name: task family to look up
        :raises TaskClassNotFoundException: if ``name`` is not registered
        :raises TaskClassAmbigiousException: if ``name`` is ambiguous
        """
        task_cls = cls._get_reg().get(name)
        if task_cls is None:
            raise TaskClassNotFoundException(cls._missing_task_msg(name))

        # The placeholder is a sentinel object, so compare by identity.
        if task_cls is cls.AMBIGUOUS_CLASS:
            raise TaskClassAmbigiousException('Task %r is ambiguous' % name)
        return task_cls

    @classmethod
    def get_all_params(cls):
        """
        Compiles and returns all parameters for all :py:class:`Task`.

        :return: a generator of tuples (TODO: we should make this more elegant)
                 of the form ``(task_name, not use_cmdline_section, param_name, param_obj)``
        """
        for task_name, task_cls in cls._get_reg().items():
            if task_cls is cls.AMBIGUOUS_CLASS:
                continue
            for param_name, param_obj in task_cls.get_params():
                yield task_name, (not task_cls.use_cmdline_section), param_name, param_obj

    @staticmethod
    def _editdistance(a, b):
        """
        Simple unweighted Levenshtein distance

        :param a: first sequence (typically a string)
        :param b: second sequence
        :return: minimum number of single-element insertions, deletions and
                 substitutions turning ``a`` into ``b``
        """
        # ``prev`` is the row for the prefix of ``a`` handled so far; the row
        # for the empty prefix is simply 0..len(b).
        prev = list(range(len(b) + 1))
        for i, item_a in enumerate(a):
            cur = [i + 1]
            for j, item_b in enumerate(b):
                # Compare by value: identity only happens to work for
                # characters that the interpreter caches.
                cost = 0 if item_a == item_b else 1
                cur.append(min(cur[j] + 1, prev[j + 1] + 1, prev[j] + cost))
            prev = cur
        # Returning the last completed row also covers an empty ``a``.
        return prev[-1]

    @classmethod
    def _missing_task_msg(cls, task_name):
        """
        Build the error text for an unknown task name.

        Registered names within an edit distance of 5 (and closer than their
        own length) are suggested, nearest first; if there are none, every
        registered name is listed instead.
        """
        weighted_tasks = [(Register._editdistance(task_name, task_name_2), task_name_2) for task_name_2 in cls.task_names()]
        ordered_tasks = sorted(weighted_tasks, key=lambda pair: pair[0])
        candidates = [task for (dist, task) in ordered_tasks if dist <= 5 and dist < len(task)]
        if candidates:
            return "No task %s. Did you mean:\n%s" % (task_name, '\n'.join(candidates))
        else:
            return "No task %s. Candidates are: %s" % (task_name, cls.tasks_str())

    @classmethod
    def _get_namespace(mcs, module_name):
        """
        Return the namespace configured for ``module_name``.

        The module itself and then each of its parent packages are checked in
        ``_default_namespace_dict``; the first truthy entry wins.
        """
        for parent in mcs._module_parents(module_name):
            entry = mcs._default_namespace_dict.get(parent)
            if entry:
                return entry
        return ''  # Default if nothing specifies

    @staticmethod
    def _module_parents(module_name):
        '''
        Yield ``module_name``, each enclosing package, and finally ``''``.

        >>> list(Register._module_parents('a.b'))
        ['a.b', 'a', '']
        '''
        spl = module_name.split('.')
        for i in range(len(spl), 0, -1):
            yield '.'.join(spl[0:i])
        # For an empty name the loop above has already produced ''.
        if module_name:
            yield ''
def load_task(module, task_name, params_str):
    """
    Imports task dynamically given a module and a task name.

    :param module: dotted name of a module to import first, or ``None`` to
                   skip the import
    :param task_name: name passed to ``Register.get_task_cls``
    :param params_str: passed unchanged to the class's ``from_str_params``
    :return: whatever ``from_str_params`` returns for the resolved class
    """
    if module is not None:
        # Imported only for its side effects; the returned module is unused.
        __import__(module)
    return Register.get_task_cls(task_name).from_str_params(params_str)