# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
The system for scheduling tasks and executing them in order.
Deals with dependencies, priorities, resources, etc.
The :py:class:`~luigi.worker.Worker` pulls tasks from the scheduler (usually over the REST interface) and executes them.
See :doc:`/central_scheduler` for more info.
"""
import collections
from collections.abc import MutableSet
import json
from luigi.batch_notifier import BatchNotifier
import pickle
import functools
import hashlib
import inspect
import itertools
import logging
import os
import re
import time
import uuid
from luigi import configuration
from luigi import notifications
from luigi import parameter
from luigi import task_history as history
from luigi.task_status import DISABLED, DONE, FAILED, PENDING, RUNNING, SUSPENDED, UNKNOWN, \
BATCH_RUNNING
from luigi.task import Config
from luigi.parameter import ParameterVisibility
from luigi.metrics import MetricsCollectors
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Labels for "upstream" status values; STATUS_TO_UPSTREAM_MAP below maps task
# statuses onto them.
UPSTREAM_RUNNING = 'UPSTREAM_RUNNING'
UPSTREAM_MISSING_INPUT = 'UPSTREAM_MISSING_INPUT'
UPSTREAM_FAILED = 'UPSTREAM_FAILED'
UPSTREAM_DISABLED = 'UPSTREAM_DISABLED'
# Fixed ordering of the upstream labels, with '' as the first entry.
# NOTE(review): presumably ordered from least to most severe - confirm against users.
UPSTREAM_SEVERITY_ORDER = (
'',
UPSTREAM_RUNNING,
UPSTREAM_MISSING_INPUT,
UPSTREAM_FAILED,
UPSTREAM_DISABLED,
)
# Callable returning the position of a label in UPSTREAM_SEVERITY_ORDER
# (raises ValueError for a label that is not listed).
UPSTREAM_SEVERITY_KEY = UPSTREAM_SEVERITY_ORDER.index
# Task status -> upstream label. Statuses that are not keys here have no entry.
STATUS_TO_UPSTREAM_MAP = {
FAILED: UPSTREAM_FAILED,
RUNNING: UPSTREAM_RUNNING,
BATCH_RUNNING: UPSTREAM_RUNNING,
PENDING: UPSTREAM_MISSING_INPUT,
DISABLED: UPSTREAM_DISABLED,
}
# Values returned by Worker.state.
WORKER_STATE_DISABLED = 'disabled'
WORKER_STATE_ACTIVE = 'active'
# Captures the leading run of characters up to (not including) the first '(' or '_'.
TASK_FAMILY_RE = re.compile(r'([^(_]+)[(_]')
# Filled in by the rpc_method decorator: function name -> client-side RPC wrapper.
RPC_METHODS = {}
# Field names of RetryPolicy, in positional order.
_retry_policy_fields = [
"retry_count",
"disable_hard_timeout",
"disable_window",
]
RetryPolicy = collections.namedtuple("RetryPolicy", _retry_policy_fields) # type: ignore
def _get_empty_retry_policy():
    """Return a ``RetryPolicy`` in which every field is ``None``."""
    return RetryPolicy(**dict.fromkeys(_retry_policy_fields))
[docs]
def rpc_method(**request_args):
    """
    Decorator factory that registers a client-side RPC wrapper for a method.

    The decorated function itself is returned unchanged. As a side effect a
    wrapper is stored in ``RPC_METHODS`` under the function's name; when called,
    that wrapper collects the call arguments into a dict and forwards them to
    ``self._request('/api/<name>', arguments, **request_args)``.

    :param request_args: extra keyword arguments passed on to ``self._request``.
    """
    def _rpc_method(fn):
        spec = inspect.getfullargspec(fn)
        # *args cannot be mapped onto named request arguments
        assert not spec.varargs
        first_arg, *param_names = spec.args
        assert first_arg == 'self'
        # Default values belong to the trailing parameters, so pair both from the end.
        default_values = dict(zip(reversed(param_names), reversed(spec.defaults or ())))
        required = frozenset(name for name in param_names if name not in default_values)
        fn_name = fn.__name__

        @functools.wraps(fn)
        def rpc_func(self, *args, **kwargs):
            call_args = dict(default_values)
            call_args.update(zip(param_names, args))
            call_args.update(kwargs)
            if any(name not in call_args for name in required):
                raise TypeError('{} takes {} arguments ({} given)'.format(
                    fn_name, len(param_names), len(call_args)))
            return self._request('/api/{}'.format(fn_name), call_args, **request_args)

        RPC_METHODS[fn_name] = rpc_func
        return fn

    return _rpc_method
[docs]
class scheduler(Config):
    """Scheduler settings, declared as luigi parameters on a ``Config`` subclass."""

    retry_delay = parameter.FloatParameter(default=900.0)
    remove_delay = parameter.FloatParameter(default=600.0)
    worker_disconnect_delay = parameter.FloatParameter(default=60.0)
    state_path = parameter.Parameter(default='/var/lib/luigi-server/state.pickle')

    batch_emails = parameter.BoolParameter(
        default=False,
        description="Send e-mails in batches rather than immediately")

    # Jobs are disabled if we see more than retry_count failures in disable_window seconds.
    # These disables last for disable_persist seconds.
    disable_window = parameter.IntParameter(default=3600)
    retry_count = parameter.IntParameter(default=999999999)
    disable_hard_timeout = parameter.IntParameter(default=999999999)
    disable_persist = parameter.IntParameter(default=86400)

    max_shown_tasks = parameter.IntParameter(default=100000)
    max_graph_nodes = parameter.IntParameter(default=100000)

    record_task_history = parameter.BoolParameter(default=False)

    prune_on_get_work = parameter.BoolParameter(default=False)

    pause_enabled = parameter.BoolParameter(default=True)

    send_messages = parameter.BoolParameter(default=True)

    metrics_collector = parameter.EnumParameter(enum=MetricsCollectors, default=MetricsCollectors.default)
    metrics_custom_import = parameter.OptionalStrParameter(default=None)

    stable_done_cooldown_secs = parameter.IntParameter(
        default=10,
        description="Sets cooldown period to avoid running the same task twice")
    """
    Sets a cooldown period in seconds after a task was completed, during this period the same task will not be accepted by the scheduler.
    """

    def _get_retry_policy(self):
        """Bundle the retry-related settings of this config into a ``RetryPolicy``."""
        return RetryPolicy(
            retry_count=self.retry_count,
            disable_hard_timeout=self.disable_hard_timeout,
            disable_window=self.disable_window,
        )
def _get_default(x, default):
    """Return ``x`` unless it is ``None``, in which case return ``default``."""
    return default if x is None else x
[docs]
class OrderedSet(MutableSet):
    """
    Set that iterates over its elements in insertion order.

    Standard Python OrderedSet recipe found at http://code.activestate.com/recipes/576694/
    Modified to include a peek function to get the last element.

    Elements live in a circular doubly linked list of ``[key, prev, next]``
    nodes anchored at the sentinel ``self.end``; ``self.map`` maps each key to
    its node so membership tests and removals do not need to walk the list.
    """

    def __init__(self, iterable=None):
        """Create the set, optionally adding every element of ``iterable`` in order."""
        self.end = end = []
        end += [None, end, end]  # sentinel node for doubly linked list
        self.map = {}  # key --> [key, prev, next]
        if iterable is not None:
            self |= iterable

    def __len__(self):
        return len(self.map)

    def __contains__(self, key):
        return key in self.map

    def add(self, key):
        """Append ``key`` at the end; a key that is already present keeps its position."""
        if key not in self.map:
            end = self.end
            curr = end[1]
            curr[2] = end[1] = self.map[key] = [key, curr, end]

    def discard(self, key):
        """Remove ``key`` if present; do nothing otherwise."""
        if key in self.map:
            key, prev_node, next_node = self.map.pop(key)
            # unlink the node by joining its neighbours
            prev_node[2] = next_node
            next_node[1] = prev_node

    def __iter__(self):
        end = self.end
        curr = end[2]
        while curr is not end:
            yield curr[0]
            curr = curr[2]

    def __reversed__(self):
        end = self.end
        curr = end[1]
        while curr is not end:
            yield curr[0]
            curr = curr[1]

    def peek(self, last=True):
        """
        Return the last (or, with ``last=False``, the first) element without removing it.

        :raises KeyError: if the set is empty.
        """
        if not self:
            raise KeyError('set is empty')
        key = self.end[1][0] if last else self.end[2][0]
        return key

    def pop(self, last=True):
        """
        Remove and return the last (or, with ``last=False``, the first) element.

        :raises KeyError: if the set is empty.
        """
        key = self.peek(last)
        self.discard(key)
        return key

    def __repr__(self):
        if not self:
            return '%s()' % (self.__class__.__name__,)
        return '%s(%r)' % (self.__class__.__name__, list(self))

    def __eq__(self, other):
        """
        Compare order-sensitively with another ``OrderedSet`` and as a plain set
        with any other iterable of hashable items.

        Operands that cannot be turned into a set (non-iterables such as
        ``None`` or numbers, iterables of unhashable items) yield
        ``NotImplemented`` so that ``==`` evaluates to ``False`` instead of
        raising ``TypeError``.
        """
        if isinstance(other, OrderedSet):
            return len(self) == len(other) and list(self) == list(other)
        try:
            return set(self) == set(other)
        except TypeError:
            return NotImplemented
[docs]
class Task:
    """
    Scheduler-side record of one task: its status, dependencies, related
    workers, parameters and failure history.
    """

    def __init__(self, task_id, status, deps, resources=None, priority=0, family='', module=None,
                 params=None, param_visibilities=None, accepts_messages=False, tracking_url=None,
                 status_message=None, progress_percentage=None, retry_policy='notoptional'):
        self.id = task_id
        # workers ids that are somehow related to this task
        # (i.e. don't prune while any of these workers are still active)
        self.stakeholders = set()
        # workers ids that can perform task - task is 'BROKEN' if none of these workers are active
        self.workers = OrderedSet()
        self.deps = set() if deps is None else set(deps)
        self.status = status  # PENDING, RUNNING, FAILED or DONE
        self.time = time.time()  # Timestamp when task was first added
        self.updated = self.time
        self.retry = None
        self.remove = None
        self.worker_running = None  # the worker id that is currently running the task or None
        self.time_running = None  # Timestamp when picked up by worker
        self.expl = None
        self.priority = priority
        self.resources = _get_default(resources, {})
        self.family = family
        self.module = module
        self.param_visibilities = _get_default(param_visibilities, {})
        self.params = {}
        self.public_params = {}
        self.hidden_params = {}
        self.set_params(params)
        self.accepts_messages = accepts_messages
        self.retry_policy = retry_policy
        self.failures = collections.deque()  # timestamps of recorded failures, oldest first
        self.first_failure_time = None
        self.tracking_url = tracking_url
        self.status_message = status_message
        self.progress_percentage = progress_percentage
        self.scheduler_message_responses = {}
        self.scheduler_disable_time = None
        self.runnable = False
        self.batchable = False
        self.batch_id = None

    def __repr__(self):
        return "Task({!r})".format(vars(self))

    def set_params(self, params):
        """
        Store ``params`` (``None`` becomes an empty dict) and rebuild the
        public and hidden views of it from ``self.param_visibilities``.
        A parameter without a visibility entry counts as public.
        """
        self.params = _get_default(params, {})
        public, hidden = {}, {}
        for key, value in self.params.items():
            visibility = self.param_visibilities.get(key, ParameterVisibility.PUBLIC)
            if visibility == ParameterVisibility.PUBLIC:
                public[key] = value
            if visibility == ParameterVisibility.HIDDEN:
                hidden[key] = value
        self.public_params = public
        self.hidden_params = hidden

    # TODO(2017-08-10) replace this function with direct calls to batchable
    # this only exists for backward compatibility
    def is_batchable(self):
        """Return ``self.batchable``, or ``False`` when the attribute is missing."""
        return getattr(self, 'batchable', False)

    def add_failure(self):
        """
        Add a failure event with the current timestamp.
        """
        now = time.time()
        self.first_failure_time = self.first_failure_time or now
        self.failures.append(now)

    def num_failures(self):
        """
        Return the number of failures in the window.

        Failures older than ``retry_policy.disable_window`` seconds are dropped
        from the history as a side effect.
        """
        cutoff = time.time() - self.retry_policy.disable_window
        history = self.failures
        while history and history[0] < cutoff:
            history.popleft()
        return len(history)

    def has_excessive_failures(self):
        """
        Return ``True`` when the hard timeout since the first failure has
        elapsed, or when the failures in the window reach ``retry_count``.
        """
        policy = self.retry_policy
        first_failure = self.first_failure_time
        if first_failure is not None and time.time() >= first_failure + policy.disable_hard_timeout:
            return True

        logger.debug('%s task num failures is %s and limit is %s', self.id, self.num_failures(), policy.retry_count)
        if self.num_failures() >= policy.retry_count:
            logger.debug('%s task num failures limit(%s) is exceeded', self.id, policy.retry_count)
            return True

        return False

    def clear_failures(self):
        """
        Clear the failures history
        """
        self.failures.clear()
        self.first_failure_time = None

    @property
    def pretty_id(self):
        """``family(key=value, ...)`` built from the public parameters, sorted by key."""
        rendered = ['{}={}'.format(key, value) for key, value in sorted(self.public_params.items())]
        return '{}({})'.format(self.family, ', '.join(rendered))
[docs]
class Worker:
    """
    Structure for tracking worker activity and keeping their references.
    """

    def __init__(self, worker_id, last_active=None):
        self.id = worker_id
        # reference to the worker in the real world. (Currently a dict containing just the host)
        self.reference = None
        self.last_active = last_active or time.time()  # seconds since epoch
        self.last_get_work = None
        self.started = time.time()  # seconds since epoch
        self.tasks = set()  # task objects
        self.info = {}
        self.disabled = False
        self.rpc_messages = []

    def add_info(self, info):
        """Merge ``info`` into the stored worker info."""
        self.info.update(info)

    def update(self, worker_reference, get_work=False):
        """
        Record contact from the worker: refresh ``last_active``, replace the
        reference when a truthy one is given, and stamp ``last_get_work``
        when ``get_work`` is set.
        """
        if worker_reference:
            self.reference = worker_reference
        self.last_active = time.time()
        if get_work:
            self.last_get_work = time.time()

    def prune(self, config):
        """
        Return ``True`` when the worker has been silent for longer than
        ``config.worker_disconnect_delay`` seconds, otherwise ``None``.
        """
        # Delete workers that haven't said anything for a while (probably killed)
        deadline = self.last_active + config.worker_disconnect_delay
        return True if deadline < time.time() else None

    def get_tasks(self, state, *statuses):
        """
        Lazily iterate over this worker's tasks that are in one of ``statuses``.

        Whichever collection is smaller is scanned: the worker's own task set,
        or the state's tasks in the requested statuses.
        """
        own_count = len(self.tasks)
        state_count = sum(len(state._status_tasks[status]) for status in statuses)

        def has_wanted_status(task):
            return task.status in statuses

        def belongs_to_me(task):
            return self.id in task.workers

        if own_count < state_count:
            return filter(has_wanted_status, self.tasks)
        return filter(belongs_to_me, state.get_active_tasks_by_status(*statuses))

    def is_trivial_worker(self, state):
        """
        Return ``True`` for a non-assistant worker none of whose PENDING tasks
        declares resources.

        We have to pass the state parameter for optimization reasons.
        """
        if self.assistant:
            return False
        return not any(task.resources for task in self.get_tasks(state, PENDING))

    @property
    def assistant(self):
        """The ``'assistant'`` entry of the worker info, ``False`` when absent."""
        return self.info.get('assistant', False)

    @property
    def enabled(self):
        return not self.disabled

    @property
    def state(self):
        return WORKER_STATE_ACTIVE if self.enabled else WORKER_STATE_DISABLED

    def add_rpc_message(self, name, **kwargs):
        """Queue an RPC message for the worker."""
        # the message has the format {'name': function name, 'kwargs': function kwargs}
        message = {'name': name, 'kwargs': kwargs}
        self.rpc_messages.append(message)

    def fetch_rpc_messages(self):
        """Return the queued RPC messages and empty the queue in place."""
        pending = list(self.rpc_messages)
        self.rpc_messages.clear()
        return pending

    def __str__(self):
        return self.id
[docs]
class SimpleTaskState:
    """
    Keep track of the current state and handle persistence.

    The point of this class is to enable other ways to keep state, eg. by using a database
    These will be implemented by creating an abstract base class that this and other classes
    inherit from.
    """

    def __init__(self, state_path):
        """
        :param state_path: path of the pickle file used by :py:meth:`dump` and :py:meth:`load`.
        """
        self._state_path = state_path
        self._tasks = {}  # map from id to a Task object
        self._status_tasks = collections.defaultdict(dict)  # status -> {task id -> Task}
        self._active_workers = {}  # map from id to a Worker object
        self._task_batchers = {}  # worker id -> {family -> (batcher_args, max_batch_size)}
        self._metrics_collector = None  # attached from outside; metrics are skipped while None

    def get_state(self):
        """Return the persisted part of the state as a ``(tasks, workers, batchers)`` tuple."""
        return self._tasks, self._active_workers, self._task_batchers

    def set_state(self, state):
        """Restore from a tuple produced by :py:meth:`get_state`; a two-element tuple without batchers is accepted."""
        self._tasks, self._active_workers = state[:2]
        if len(state) >= 3:
            self._task_batchers = state[2]

    def dump(self):
        """Pickle the state to ``state_path``; an ``IOError`` is logged, not raised."""
        try:
            with open(self._state_path, 'wb') as fobj:
                pickle.dump(self.get_state(), fobj)
        except IOError:
            logger.warning("Failed saving scheduler state", exc_info=1)
        else:
            logger.info("Saved state in %s", self._state_path)

    # prone to lead to crashes when old state is unpickled with updated code. TODO some kind of version control?
    def load(self):
        """
        Load the state from ``state_path`` if that file exists and rebuild the
        per-status index. Any error while reading or unpickling is logged and
        leaves the current state untouched.
        """
        if os.path.exists(self._state_path):
            logger.info("Attempting to load state from %s", self._state_path)
            try:
                # NOTE(review): unpickling executes code from the file; the state file must be trusted.
                with open(self._state_path, 'rb') as fobj:
                    state = pickle.load(fobj)
            except Exception:
                # Exception rather than BaseException, so that KeyboardInterrupt and
                # SystemExit still propagate instead of being logged and swallowed.
                logger.exception("Error when loading state. Starting from empty state.")
                return

            self.set_state(state)
            self._status_tasks = collections.defaultdict(dict)
            for task in self._tasks.values():
                self._status_tasks[task.status][task.id] = task
        else:
            logger.info("No prior state file exists at %s. Starting with empty state", self._state_path)

    def get_active_tasks(self):
        """Return a view of all task objects."""
        return self._tasks.values()

    def get_active_tasks_by_status(self, *statuses):
        """Return an iterator over the tasks that are in any of ``statuses``."""
        return itertools.chain.from_iterable(self._status_tasks[status].values() for status in statuses)

    def get_active_task_count_for_status(self, status):
        """Return the number of tasks in ``status``, or of all tasks when ``status`` is falsy."""
        if status:
            return len(self._status_tasks[status])
        else:
            return len(self._tasks)

    def get_batch_running_tasks(self, batch_id):
        """Return the list of BATCH_RUNNING tasks that carry ``batch_id``."""
        assert batch_id is not None
        return [
            task for task in self.get_active_tasks_by_status(BATCH_RUNNING)
            if task.batch_id == batch_id
        ]

    def set_batcher(self, worker_id, family, batcher_args, max_batch_size):
        """Remember ``(batcher_args, max_batch_size)`` for a worker / task family pair."""
        self._task_batchers.setdefault(worker_id, {})
        self._task_batchers[worker_id][family] = (batcher_args, max_batch_size)

    def get_batcher(self, worker_id, family):
        """Return the stored ``(batcher_args, max_batch_size)``, or ``(None, 1)`` when nothing is registered."""
        return self._task_batchers.get(worker_id, {}).get(family, (None, 1))

    def num_pending_tasks(self):
        """
        Return how many tasks are PENDING + RUNNING. O(1).
        """
        return len(self._status_tasks[PENDING]) + len(self._status_tasks[RUNNING])

    def get_task(self, task_id, default=None, setdefault=None):
        """
        Look up a task by id.

        With a truthy ``setdefault`` the given task is inserted when the id is
        unknown, and the stored task is (re)indexed under its status and
        returned. Otherwise the stored task or ``default`` is returned.
        """
        if setdefault:
            task = self._tasks.setdefault(task_id, setdefault)
            self._status_tasks[task.status][task.id] = task
            return task
        else:
            return self._tasks.get(task_id, default)

    def has_task(self, task_id):
        return task_id in self._tasks

    def re_enable(self, task, config=None):
        """
        Lift a scheduler disable: clear the disable time and failure history.
        When ``config`` is given the task is additionally moved to FAILED.
        """
        task.scheduler_disable_time = None
        task.clear_failures()
        if config:
            self.set_status(task, FAILED, config)
            task.clear_failures()

    def set_batch_running(self, task, batch_id, worker_id):
        """Mark ``task`` as BATCH_RUNNING in batch ``batch_id`` on worker ``worker_id``."""
        self.set_status(task, BATCH_RUNNING)
        task.batch_id = batch_id
        task.worker_running = worker_id
        task.resources_running = task.resources
        task.time_running = time.time()

    def set_status(self, task, new_status, config=None):
        """
        Move ``task`` to ``new_status`` and keep the per-status index in sync.

        ``config`` is required when ``new_status`` is FAILED. A FAILED
        transition records a failure and becomes DISABLED once the task has
        excessive failures. Requests to disable a running task, and requests
        other than DONE/DISABLED for a scheduler-disabled task, are ignored.
        """
        if new_status == FAILED:
            assert config is not None

        if new_status == DISABLED and task.status in (RUNNING, BATCH_RUNNING):
            return

        # captured up front because batch_id may be reset below
        remove_on_failure = task.batch_id is not None and not task.batchable

        if task.status == DISABLED:
            if new_status == DONE:
                self.re_enable(task)

            # don't allow workers to override a scheduler disable
            elif task.scheduler_disable_time is not None and new_status != DISABLED:
                return

        if task.status == RUNNING and task.batch_id is not None and new_status != RUNNING:
            # the running batch task is leaving RUNNING: propagate to its BATCH_RUNNING members
            for batch_task in self.get_batch_running_tasks(task.batch_id):
                self.set_status(batch_task, new_status, config)
                batch_task.batch_id = None
            task.batch_id = None

        if new_status == FAILED and task.status != DISABLED:
            task.add_failure()
            if task.has_excessive_failures():
                task.scheduler_disable_time = time.time()
                new_status = DISABLED
                if not config.batch_emails:
                    notifications.send_error_email(
                        'Luigi Scheduler: DISABLED {task} due to excessive failures'.format(task=task.id),
                        '{task} failed {failures} times in the last {window} seconds, so it is being '
                        'disabled for {persist} seconds'.format(
                            failures=task.retry_policy.retry_count,
                            task=task.id,
                            window=task.retry_policy.disable_window,
                            persist=config.disable_persist,
                        ))
        elif new_status == DISABLED:
            task.scheduler_disable_time = None

        if new_status != task.status:
            self._status_tasks[task.status].pop(task.id)
            self._status_tasks[new_status][task.id] = task
            task.status = new_status
            task.updated = time.time()
            self.update_metrics(task, config)

        if new_status == FAILED:
            task.retry = time.time() + config.retry_delay
            if remove_on_failure:
                task.remove = time.time()

    def fail_dead_worker_task(self, task, config, assistants):
        """Fail a (batch-)running task whose worker is neither a stakeholder of it nor in ``assistants``."""
        # If a running worker disconnects, tag all its jobs as FAILED and subject it to the same retry logic
        if task.status in (BATCH_RUNNING, RUNNING) and task.worker_running and task.worker_running not in task.stakeholders | assistants:
            logger.info("Task %r is marked as running by disconnected worker %r -> marking as "
                        "FAILED with retry delay of %rs", task.id, task.worker_running,
                        config.retry_delay)
            task.worker_running = None
            self.set_status(task, FAILED, config)
            task.retry = time.time() + config.retry_delay

    def update_status(self, task, config):
        """Apply time-based transitions: schedule removal, lift expired disables, retry failed tasks."""
        # Mark tasks with no remaining active stakeholders for deletion
        if (not task.stakeholders) and (task.remove is None) and (task.status != RUNNING):
            # We don't check for the RUNNING case, because that is already handled
            # by the fail_dead_worker_task function.
            logger.debug("Task %r has no stakeholders anymore -> might remove "
                         "task in %s seconds", task.id, config.remove_delay)
            task.remove = time.time() + config.remove_delay

        # Re-enable task after the disable time expires
        if task.status == DISABLED and task.scheduler_disable_time is not None:
            if time.time() - task.scheduler_disable_time > config.disable_persist:
                self.re_enable(task, config)

        # Reset FAILED tasks to PENDING if max timeout is reached, and retry delay is >= 0
        if task.status == FAILED and config.retry_delay >= 0 and task.retry < time.time():
            self.set_status(task, PENDING, config)

    def may_prune(self, task):
        """Return a truthy value when the task's removal time is set and has passed."""
        return task.remove and time.time() >= task.remove

    def inactivate_tasks(self, delete_tasks):
        """Drop the tasks with the given ids from the state and the per-status index."""
        # The terminology is a bit confusing: we used to "delete" tasks when they became inactive,
        # but with a pluggable state storage, you might very well want to keep some history of
        # older tasks as well. That's why we call it "inactivate" (as in the verb)
        for task in delete_tasks:
            task_obj = self._tasks.pop(task)
            self._status_tasks[task_obj.status].pop(task)

    def get_active_workers(self, last_active_lt=None, last_get_work_gt=None):
        """
        Yield workers, optionally restricted to those last active before
        ``last_active_lt`` and/or whose last get-work stamp is after
        ``last_get_work_gt``.
        """
        for worker in self._active_workers.values():
            if last_active_lt is not None and worker.last_active >= last_active_lt:
                continue
            last_get_work = worker.last_get_work
            if last_get_work_gt is not None and (
                    last_get_work is None or last_get_work <= last_get_work_gt):
                continue
            yield worker

    def get_assistants(self, last_active_lt=None):
        """Return an iterator over the active workers that are assistants."""
        return filter(lambda w: w.assistant, self.get_active_workers(last_active_lt))

    def get_worker_ids(self):
        return self._active_workers.keys()  # only used for unit tests

    def get_worker(self, worker_id):
        """Return the worker registered under ``worker_id``, registering a new one on first use."""
        try:
            return self._active_workers[worker_id]
        except KeyError:
            # construct a Worker only when the id is unknown
            worker = self._active_workers[worker_id] = Worker(worker_id)
            return worker

    def inactivate_workers(self, delete_workers):
        """Forget the given worker ids and detach them from every task."""
        # Mark workers as inactive
        for worker in delete_workers:
            self._active_workers.pop(worker)
        self._remove_workers_from_tasks(delete_workers)

    def _remove_workers_from_tasks(self, workers, remove_stakeholders=True):
        """Remove ``workers`` from each task's worker set and, unless disabled, its stakeholders."""
        for task in self.get_active_tasks():
            if remove_stakeholders:
                task.stakeholders.difference_update(workers)
            task.workers -= workers

    def disable_workers(self, worker_ids):
        """Disable the given workers: detach them from task worker sets and clear their task sets."""
        self._remove_workers_from_tasks(worker_ids, remove_stakeholders=False)
        for worker_id in worker_ids:
            worker = self.get_worker(worker_id)
            worker.disabled = True
            worker.tasks.clear()

    def update_metrics(self, task, config):
        """
        Report the task's new status to the metrics collector.

        Only DISABLED, DONE and FAILED are reported. Nothing happens while no
        collector is attached (``_metrics_collector`` is ``None``).
        """
        collector = self._metrics_collector
        if collector is None:
            return
        if task.status == DISABLED:
            collector.handle_task_disabled(task, config)
        elif task.status == DONE:
            collector.handle_task_done(task)
        elif task.status == FAILED:
            collector.handle_task_failed(task)
[docs]
class Scheduler:
"""
Async scheduler that can handle multiple workers, etc.
Can be run locally or on a server (using RemoteScheduler + server.Server).
"""
def __init__(self, config=None, resources=None, task_history_impl=None, **kwargs):
    """
    Keyword Arguments:
    :param config: an object of class "scheduler"; when falsy, one is built as ``scheduler(**kwargs)``
    :param resources: a dict of str->int constraints; when falsy, the ``resources`` config section is read
    :param task_history_impl: ignore config and use this object as the task history
    """
    self._config = config or scheduler(**kwargs)
    cfg = self._config
    self._state = SimpleTaskState(cfg.state_path)

    if task_history_impl:
        task_history = task_history_impl
    elif cfg.record_task_history:
        from luigi import db_task_history  # Needs sqlalchemy, thus imported here
        task_history = db_task_history.DbTaskHistory()
    else:
        task_history = history.NopHistory()
    self._task_history = task_history

    # TODO: Can we make this a Parameter?
    self._resources = resources or configuration.get_config().getintdict('resources')
    # Task factory that pre-fills the config-wide retry policy
    self._make_task = functools.partial(Task, retry_policy=cfg._get_retry_policy())
    self._worker_requests = {}
    self._paused = False

    # the batcher only exists when batched e-mails are enabled
    if cfg.batch_emails:
        self._email_batcher = BatchNotifier()

    self._state._metrics_collector = MetricsCollectors.get(cfg.metrics_collector, cfg.metrics_custom_import)
[docs]
def load(self):
    """Load the scheduler state by delegating to the state object."""
    state = self._state
    state.load()
[docs]
def dump(self):
    """Persist the scheduler state, then send batched e-mails when batching is enabled."""
    self._state.dump()
    if not self._config.batch_emails:
        return
    self._email_batcher.send_email()
[docs]
@rpc_method()
def prune(self):
    """Run one pruning pass: workers first, then tasks, then batched e-mails."""
    logger.debug("Starting pruning of task graph")
    for prune_step in (self._prune_workers, self._prune_tasks, self._prune_emails):
        prune_step()
    logger.debug("Done pruning task graph")
def _prune_workers(self):
    """Inactivate every worker whose ``prune`` check reports it as timed out."""
    timed_out = []
    for worker in self._state.get_active_workers():
        if not worker.prune(self._config):
            continue
        logger.debug("Worker %s timed out (no contact for >=%ss)", worker, self._config.worker_disconnect_delay)
        timed_out.append(worker.id)

    self._state.inactivate_workers(timed_out)
def _prune_tasks(self):
    """
    For every task: fail it if its worker is gone, apply time-based status
    updates, and collect it for removal when it may be pruned. Collected
    tasks are inactivated after the pass.
    """
    state, config = self._state, self._config
    assistant_ids = set(assistant.id for assistant in state.get_assistants())
    doomed = []

    for task in state.get_active_tasks():
        state.fail_dead_worker_task(task, config, assistant_ids)
        state.update_status(task, config)
        if state.may_prune(task):
            logger.info("Removing task %r", task.id)
            doomed.append(task.id)

    state.inactivate_tasks(doomed)
def _prune_emails(self):
    """Let the e-mail batcher update itself; a no-op unless batched e-mails are enabled."""
    if not self._config.batch_emails:
        return
    self._email_batcher.update()
def _update_worker(self, worker_id, worker_reference=None, get_work=False):
    """
    Record activity for ``worker_id`` and return its worker object.

    Keeps track of whenever the worker was last active; returning the
    worker is a convenience for callers.
    """
    tracked = self._state.get_worker(worker_id)
    tracked.update(worker_reference, get_work=get_work)
    return tracked
def _update_priority(self, task, prio, worker):
    """
    Raise the priority of ``task`` to ``prio`` and propagate it to its dependencies.

    Priority can only be increased. Dependencies that are known to the state
    and have a lower priority are updated recursively; unknown dependency ids
    are skipped. ``worker`` is only passed along to the recursive calls.
    """
    effective = max(prio, task.priority)
    task.priority = effective
    for dep_id in task.deps or []:
        dep_task = self._state.get_task(dep_id)
        if dep_task is None:
            continue
        if effective > dep_task.priority:
            self._update_priority(dep_task, effective, worker)
[docs]
@rpc_method()
def add_task_batcher(self, worker, task_family, batched_args, max_batch_size=float('inf')):
    """
    Register batching information for a worker / task family pair.

    :param worker: id of the worker registering the batcher
    :param task_family: task family the batcher applies to
    :param batched_args: stored as the batcher arguments for that pair
    :param max_batch_size: stored alongside the arguments; unlimited by default
    """
    state = self._state
    state.set_batcher(worker, task_family, batched_args, max_batch_size)
[docs]
@rpc_method()
def forgive_failures(self, task_id=None):
    """
    Reset a FAILED task to PENDING.

    :return: ``{"task_id": ..., "status": ...}`` with the task's resulting
             status, or status ``None`` when the task is unknown. Tasks that
             are not FAILED are left untouched.
    """
    task = self._state.get_task(task_id)
    if task is None:
        return {"task_id": task_id, "status": None}

    # we forgive only failures
    if task.status == FAILED:
        # forgive but do not forget
        self._update_task_history(task, PENDING)
        self._state.set_status(task, PENDING, self._config)
    return {"task_id": task_id, "status": task.status}
[docs]
@rpc_method()
def mark_as_done(self, task_id=None):
    """
    Force a RUNNING, FAILED or DISABLED task to DONE.

    :return: ``{"task_id": ..., "status": ...}`` with the task's resulting
             status, or status ``None`` when the task is unknown. Tasks in
             any other status are left untouched.
    """
    task = self._state.get_task(task_id)
    if task is None:
        return {"task_id": task_id, "status": None}

    # we can force mark DONE for running or failed tasks
    if task.status in (RUNNING, FAILED, DISABLED):
        self._update_task_history(task, DONE)
        self._state.set_status(task, DONE, self._config)
    return {"task_id": task_id, "status": task.status}
[docs]
@rpc_method()
def add_task(self, task_id=None, status=PENDING, runnable=True,
             deps=None, new_deps=None, expl=None, resources=None,
             priority=0, family='', module=None, params=None, param_visibilities=None, accepts_messages=False,
             assistant=False, tracking_url=None, worker=None, batchable=None,
             batch_id=None, retry_policy_dict=None, owners=None, **kwargs):
    """
    * add task identified by task_id if it doesn't exist
    * if deps is not None, update dependency list
    * update status of task
    * add additional workers/stakeholders
    * update priority when needed

    ``worker`` (the reporting worker's id) is required. Returns ``None`` in
    all cases; the call is ignored when the task is unknown and the worker is
    disabled, when a disabled worker reports on a task that is not RUNNING,
    or when PENDING is claimed for a task that became DONE less than
    ``stable_done_cooldown_secs`` seconds ago.
    """
    assert worker is not None
    worker_id = worker
    worker = self._update_worker(worker_id)

    resources = {} if resources is None else resources.copy()

    if retry_policy_dict is None:
        retry_policy_dict = {}

    retry_policy = self._generate_retry_policy(retry_policy_dict)

    # only enabled workers may introduce new tasks
    if worker.enabled:
        _default_task = self._make_task(
            task_id=task_id, status=PENDING, deps=deps, resources=resources,
            priority=priority, family=family, module=module, params=params, param_visibilities=param_visibilities,
        )
    else:
        _default_task = None

    task = self._state.get_task(task_id, setdefault=_default_task)

    if task is None or (task.status != RUNNING and not worker.enabled):
        return

    # Ignore claims that the task is PENDING if it very recently was marked as DONE.
    if status == PENDING and task.status == DONE and (time.time() - task.updated) < self._config.stable_done_cooldown_secs:
        return

    # for setting priority, we'll sometimes create tasks with unset family and params
    if not task.family:
        task.family = family
    if not getattr(task, 'module', None):
        task.module = module
    if not getattr(task, 'param_visibilities', None):
        task.param_visibilities = _get_default(param_visibilities, {})
    if not task.params:
        task.set_params(params)

    if batch_id is not None:
        task.batch_id = batch_id
    if status == RUNNING and not task.worker_running:
        task.worker_running = worker_id
        if batch_id:
            # copy resources_running of the first batch task
            # NOTE(review): assumes at least one BATCH_RUNNING task carries this batch_id
            batch_tasks = self._state.get_batch_running_tasks(batch_id)
            task.resources_running = batch_tasks[0].resources_running.copy()
        task.time_running = time.time()

    if accepts_messages is not None:
        task.accepts_messages = accepts_messages

    if tracking_url is not None or task.status != RUNNING:
        task.tracking_url = tracking_url
        if task.batch_id is not None:
            for batch_task in self._state.get_batch_running_tasks(task.batch_id):
                batch_task.tracking_url = tracking_url

    if batchable is not None:
        task.batchable = batchable

    if task.remove is not None:
        task.remove = None  # unmark task for removal so it isn't removed after being added

    if expl is not None:
        task.expl = expl
        if task.batch_id is not None:
            for batch_task in self._state.get_batch_running_tasks(task.batch_id):
                batch_task.expl = expl

    task_is_not_running = task.status not in (RUNNING, BATCH_RUNNING)
    task_started_a_run = status in (DONE, FAILED, RUNNING)
    running_on_this_worker = task.worker_running == worker_id
    if task_is_not_running or (task_started_a_run and running_on_this_worker) or new_deps:
        # don't allow re-scheduling of task while it is running, it must either fail or succeed on the worker actually running it
        if status != task.status or status == PENDING:
            # Update the DB only if there was an actual change, to prevent noise.
            # We also check for status == PENDING b/c that's the default value
            # (so checking for status != task.status would lie)
            self._update_task_history(task, status)
        self._state.set_status(task, PENDING if status == SUSPENDED else status, self._config)

    if status == FAILED and self._config.batch_emails:
        batched_params, _ = self._state.get_batcher(worker_id, family)
        if batched_params:
            unbatched_params = {
                param: value
                for param, value in task.params.items()
                if param not in batched_params
            }
        else:
            unbatched_params = task.params
        try:
            expl_raw = json.loads(expl)
        except (TypeError, ValueError):
            # ValueError: expl is not valid JSON; TypeError: expl is not a string
            # at all (e.g. the default None). Either way pass it on unchanged.
            expl_raw = expl

        self._email_batcher.add_failure(
            task.pretty_id, task.family, unbatched_params, expl_raw, owners)
        if task.status == DISABLED:
            self._email_batcher.add_disable(
                task.pretty_id, task.family, unbatched_params, owners)

    if deps is not None:
        task.deps = set(deps)

    if new_deps is not None:
        task.deps.update(new_deps)

    # NOTE(review): resources was normalised to a dict above, so this branch always runs
    if resources is not None:
        task.resources = resources

    if worker.enabled and not assistant:
        task.stakeholders.add(worker_id)

        # Task dependencies might not exist yet. Let's create dummy tasks for them for now.
        # Otherwise the task dependencies might end up being pruned if scheduling takes a long time
        for dep in task.deps or []:
            t = self._state.get_task(dep, setdefault=self._make_task(task_id=dep, status=UNKNOWN, deps=None, priority=priority))
            t.stakeholders.add(worker_id)

    self._update_priority(task, priority, worker_id)

    # Because some tasks (non-dynamic dependencies) are `_make_task`ed
    # before we know their retry_policy, we always set it here
    task.retry_policy = retry_policy

    if runnable and status != FAILED and worker.enabled:
        task.workers.add(worker_id)
        self._state.get_worker(worker_id).tasks.add(task)
        task.runnable = runnable
[docs]
@rpc_method()
def announce_scheduling_failure(self, task_name, family, params, expl, owners, **kwargs):
    """
    Queue a scheduling-failure notice in the e-mail batcher.

    Does nothing unless batched e-mails are enabled. The reporting worker's
    id is read from ``kwargs['worker']``; parameters that the worker
    registered as batched for ``family`` are left out of the notice.
    """
    if not self._config.batch_emails:
        return

    worker_id = kwargs['worker']
    batched_params, _ = self._state.get_batcher(worker_id, family)
    if not batched_params:
        unbatched_params = params
    else:
        unbatched_params = {}
        for param, value in params.items():
            if param not in batched_params:
                unbatched_params[param] = value

    self._email_batcher.add_scheduling_fail(task_name, family, unbatched_params, expl, owners)
[docs]
@rpc_method()
def add_worker(self, worker, info, **kwargs):
    """Attach ``info`` to the worker with id ``worker``; extra keyword arguments are ignored."""
    target = self._state.get_worker(worker)
    target.add_info(info)
[docs]
@rpc_method()
def disable_worker(self, worker):
    """Disable the single worker with id ``worker``."""
    worker_ids = {worker}
    self._state.disable_workers(worker_ids)
[docs]
@rpc_method()
def set_worker_processes(self, worker, n):
    """Queue a ``set_worker_processes`` RPC message with ``n`` for the given worker."""
    target = self._state.get_worker(worker)
    target.add_rpc_message('set_worker_processes', n=n)
[docs]
@rpc_method()
def send_scheduler_message(self, worker, task, content):
    """
    Queue a scheduler message about ``task`` for ``worker``.

    :return: ``{"message_id": <new uuid4 string>}``, or ``{"message_id": None}``
             when sending messages is disabled in the config.
    """
    if not self._config.send_messages:
        return {"message_id": None}

    message_id = str(uuid.uuid4())
    target = self._state.get_worker(worker)
    target.add_rpc_message(
        'dispatch_scheduler_message',
        task_id=task,
        message_id=message_id,
        content=content,
    )

    return {"message_id": message_id}
[docs]
@rpc_method()
def add_scheduler_message_response(self, task_id, message_id, response):
    """Store ``response`` under ``message_id`` on the task; unknown tasks are ignored."""
    if not self._state.has_task(task_id):
        return
    task = self._state.get_task(task_id)
    task.scheduler_message_responses[message_id] = response
[docs]
@rpc_method()
def get_scheduler_message_response(self, task_id, message_id):
    """
    Pop the stored response for ``message_id`` from the task.

    :return: ``{"response": ...}``; the value is ``None`` when the task or
             the message id is unknown.
    """
    if self._state.has_task(task_id):
        responses = self._state.get_task(task_id).scheduler_message_responses
        response = responses.pop(message_id, None)
    else:
        response = None
    return {"response": response}
[docs]
@rpc_method()
def has_task_history(self):
    """Return the ``record_task_history`` config setting."""
    config = self._config
    return config.record_task_history
[docs]
@rpc_method()
def is_pause_enabled(self):
    """Return ``{'enabled': ...}`` carrying the ``pause_enabled`` config setting."""
    enabled = self._config.pause_enabled
    return {'enabled': enabled}
[docs]
@rpc_method()
def is_paused(self):
    """Return ``{'paused': ...}`` carrying the current pause flag."""
    paused = self._paused
    return {'paused': paused}
[docs]
@rpc_method()
def pause(self):
    """Set the paused flag; a no-op when ``pause_enabled`` is off in the config."""
    if self._config.pause_enabled:
        self._paused = True
[docs]
@rpc_method()
def unpause(self):
    """Clear the paused flag; a no-op when ``pause_enabled`` is off in the config."""
    if self._config.pause_enabled:
        self._paused = False
[docs]
@rpc_method()
def update_resources(self, **resources):
    """
    Merge the given ``name=amount`` pairs into the resource totals.

    Existing entries with the same name are overwritten; other entries are kept.
    """
    # The totals mapping may still be None; start from an empty dict in that case.
    if self._resources is None:
        self._resources = {}
    self._resources.update(resources)
[docs]
@rpc_method()
def update_resource(self, resource, amount):
    """
    Set the total amount of a single resource.

    :param resource: name of the resource
    :param amount: new total; must be a non-negative ``int``
    :return: ``True`` if the total was stored, ``False`` if ``amount`` was rejected
    """
    if not isinstance(amount, int) or amount < 0:
        return False
    # ``update_resources`` accepts a totals mapping that is still None and
    # starts from an empty dict; do the same here instead of failing on the
    # item assignment below.
    if self._resources is None:
        self._resources = {}
    self._resources[resource] = amount
    return True
def _generate_retry_policy(self, task_retry_policy_dict):
    """
    Build a ``RetryPolicy`` from the config's policy with per-task overrides.

    :param task_retry_policy_dict: mapping of policy field to value; entries
                                   whose value is ``None`` are ignored
    :return: a ``RetryPolicy`` built from the merged fields
    """
    policy_fields = self._config._get_retry_policy()._asdict()
    for field, override in task_retry_policy_dict.items():
        if override is not None:
            policy_fields[field] = override
    return RetryPolicy(**policy_fields)
def _has_resources(self, needed_resources, used_resources):
    """
    Tell whether ``needed_resources`` fits on top of ``used_resources``.

    :param needed_resources: mapping of resource name to amount, or ``None``
                             (which always fits)
    :param used_resources: mapping of resource name to amount already taken;
                           indexed directly, so it must supply a value for
                           every needed name (e.g. a ``defaultdict(int)``)
    :return: ``False`` as soon as one resource would exceed its total.
             A resource with no configured total is treated as having a total of 1.
    """
    if needed_resources is None:
        return True
    totals = self._resources or {}
    return not any(
        wanted + used_resources[name] > totals.get(name, 1)
        for name, wanted in needed_resources.items()
    )
def _used_resources(self):
    """
    Sum the resources held by all RUNNING tasks.

    A task's ``resources_running`` attribute is used when present, otherwise
    its ``resources``.

    :return: ``defaultdict(int)`` of resource name to amount in use; empty
             when the scheduler's resource totals are ``None``
    """
    in_use = collections.defaultdict(int)
    if self._resources is None:
        return in_use
    for running_task in self._state.get_active_tasks_by_status(RUNNING):
        held = getattr(running_task, "resources_running", running_task.resources)
        if not held:
            continue
        for name, quantity in held.items():
            in_use[name] += quantity
    return in_use
def _rank(self, task):
    """
    Sort key for ordering tasks during scheduling.

    :param task: object exposing ``priority`` and ``time``
    :return: the tuple ``(task.priority, -task.time)``
    """
    priority = task.priority
    return (priority, -task.time)
def _schedulable(self, task):
    """
    Tell whether ``task`` is PENDING and every dependency is a known DONE task.

    :return: ``False`` if the task is not PENDING, or if any dependency is
             missing from the state or not DONE; ``True`` otherwise
    """
    if task.status != PENDING:
        return False
    dep_tasks = (self._state.get_task(dep_id, default=None) for dep_id in task.deps)
    return all(dep is not None and dep.status == DONE for dep in dep_tasks)
def _reset_orphaned_batch_running_tasks(self, worker_id):
    """
    Set orphaned BATCH_RUNNING tasks of ``worker_id`` back to PENDING.

    A BATCH_RUNNING task counts as orphaned when no RUNNING task of the same
    worker carries its ``batch_id``.
    """
    live_batch_ids = set()
    for running in self._state.get_active_tasks_by_status(RUNNING):
        if running.worker_running == worker_id:
            live_batch_ids.add(running.batch_id)

    # Collect first, then change statuses, so the status query is fully
    # consumed before any task is modified.
    orphans = []
    for candidate in self._state.get_active_tasks_by_status(BATCH_RUNNING):
        if candidate.worker_running == worker_id and candidate.batch_id not in live_batch_ids:
            orphans.append(candidate)

    for orphan in orphans:
        self._state.set_status(orphan, PENDING)
[docs]
@rpc_method()
def count_pending(self, worker):
    """
    Summarise the outstanding work registered for a worker.

    Tasks whose upstream status is ``UPSTREAM_DISABLED`` are left out of
    every figure.

    :param worker: worker id
    :return: dict with the keys ``n_pending_tasks`` (PENDING or FAILED tasks),
             ``n_unique_pending`` (those with exactly one worker),
             ``n_pending_last_scheduled`` (those whose last worker is this one),
             ``worker_state`` and ``running_tasks``
    """
    worker_id = worker
    worker = self._state.get_worker(worker_id)
    status_cache = {}

    # The running tasks are returned to the client to ease troubleshooting.
    running_tasks = []
    for task in worker.get_tasks(self._state, RUNNING):
        if self._upstream_status(task.id, status_cache) == UPSTREAM_DISABLED:
            continue
        runner = self._state.get_worker(task.worker_running)
        if runner is None:
            continue
        entry = {'task_id': task.id, 'worker': str(runner)}
        entry.update(runner.info)
        running_tasks.append(entry)

    num_pending = 0
    num_unique_pending = 0
    num_pending_last_scheduled = 0
    for task in worker.get_tasks(self._state, PENDING, FAILED):
        if self._upstream_status(task.id, status_cache) == UPSTREAM_DISABLED:
            continue
        num_pending += 1
        num_unique_pending += int(len(task.workers) == 1)
        num_pending_last_scheduled += int(task.workers.peek(last=True) == worker_id)

    return {
        'n_pending_tasks': num_pending,
        'n_unique_pending': num_unique_pending,
        'n_pending_last_scheduled': num_pending_last_scheduled,
        'worker_state': worker.state,
        'running_tasks': running_tasks,
    }
[docs]
@rpc_method(allow_null=False)
def get_work(self, host=None, assistant=False, current_tasks=None, worker=None, **kwargs):
    """
    Pick the next task (or batch of tasks) for ``worker`` and mark it as started.

    :param host: passed to ``_update_worker`` as the worker reference and to
                 the task history when a task is started
    :param assistant: when truthy, the worker is tagged with ``('assistant', assistant)``
                      and may also take any task whose ``runnable`` flag is set
    :param current_tasks: ids of the tasks the worker says it is running, or ``None``.
                          When given, a RUNNING task assigned to this worker that is
                          missing from the list is handed out again, and orphaned
                          BATCH_RUNNING tasks of the worker are reset
    :param worker: id of the requesting worker; required (asserted)
    :return: the ``count_pending`` reply extended with ``task_id`` and, when a task
             was chosen, ``task_family``, ``task_module`` and ``task_params``.
             For a batch ``task_id`` is ``None`` and ``batch_id`` /
             ``batch_task_ids`` are added. A disabled worker gets a fixed reply
             with ``task_id`` ``None`` and zero counts
    """
    # TODO: remove any expired nodes
    # Algo: iterate over all nodes, find the highest priority node no dependencies and available
    # resources.
    # Resource checking looks both at currently available resources and at which resources would
    # be available if all running tasks died and we rescheduled all workers greedily. We do both
    # checks in order to prevent a worker with many low-priority tasks from starving other
    # workers with higher priority tasks that share the same resources.
    # TODO: remove tasks that can't be done, figure out if the worker has absolutely
    # nothing it can wait for
    if self._config.prune_on_get_work:
        self.prune()
    assert worker is not None
    worker_id = worker
    worker = self._update_worker(
        worker_id,
        worker_reference={'host': host},
        get_work=True)
    if not worker.enabled:
        reply = {'n_pending_tasks': 0,
                 'running_tasks': [],
                 'task_id': None,
                 'n_unique_pending': 0,
                 'worker_state': worker.state,
                 }
        return reply
    if assistant:
        self.add_worker(worker_id, [('assistant', assistant)])
    batched_params, unbatched_params, batched_tasks, max_batch_size = None, None, [], 1
    best_task = None
    if current_tasks is not None:
        ct_set = set(current_tasks)
        # The sort is ascending by rank and every match overwrites best_task,
        # so the highest-ranked unclaimed task is the one that remains.
        for task in sorted(self._state.get_active_tasks_by_status(RUNNING), key=self._rank):
            if task.worker_running == worker_id and task.id not in ct_set:
                best_task = task
    if current_tasks is not None:
        # batch running tasks that weren't claimed since the last get_work go back in the pool
        self._reset_orphaned_batch_running_tasks(worker_id)
    greedy_resources = collections.defaultdict(int)
    worker = self._state.get_worker(worker_id)
    if self._paused:
        # used_resources / greedy_workers stay undefined here; that is safe only
        # because the empty task list means the loop below never reads them.
        relevant_tasks = []
    elif worker.is_trivial_worker(self._state):
        relevant_tasks = worker.get_tasks(self._state, PENDING, RUNNING)
        used_resources = collections.defaultdict(int)
        greedy_workers = dict()  # If there's no resources, then they can grab any task
    else:
        relevant_tasks = self._state.get_active_tasks_by_status(PENDING, RUNNING)
        used_resources = self._used_resources()
        activity_limit = time.time() - self._config.worker_disconnect_delay
        active_workers = self._state.get_active_workers(last_get_work_gt=activity_limit)
        # worker id -> number of tasks that worker may still take in the greedy simulation
        greedy_workers = dict((worker.id, worker.info.get('workers', 1))
                              for worker in active_workers)
    tasks = list(relevant_tasks)
    # Highest rank first.
    tasks.sort(key=self._rank, reverse=True)
    for task in tasks:
        # Once a batchable best_task exists, add every compatible schedulable task
        # to its batch (same family, same unbatched params, same resources) until
        # max_batch_size is reached.
        if (best_task and batched_params and task.family == best_task.family and
                len(batched_tasks) < max_batch_size and task.is_batchable() and all(
                task.params.get(name) == value for name, value in unbatched_params.items()) and
                task.resources == best_task.resources and self._schedulable(task)):
            for name, params in batched_params.items():
                params.append(task.params.get(name))
            batched_tasks.append(task)
        # After a task has been chosen the rest of the loop only collects batch members.
        if best_task:
            continue
        # A task already running on a known worker uses up one of that worker's
        # slots and its resources in the greedy simulation.
        if task.status == RUNNING and (task.worker_running in greedy_workers):
            greedy_workers[task.worker_running] -= 1
            for resource, amount in (getattr(task, 'resources_running', task.resources) or {}).items():
                greedy_resources[resource] += amount
        if self._schedulable(task) and self._has_resources(task.resources, greedy_resources):
            in_workers = (assistant and task.runnable) or worker_id in task.workers
            if in_workers and self._has_resources(task.resources, used_resources):
                best_task = task
                batch_param_names, max_batch_size = self._state.get_batcher(
                    worker_id, task.family)
                if batch_param_names and task.is_batchable():
                    try:
                        batched_params = {
                            name: [task.params[name]] for name in batch_param_names
                        }
                        unbatched_params = {
                            name: value for name, value in task.params.items()
                            if name not in batched_params
                        }
                        batched_tasks.append(task)
                    except KeyError:
                        # The task lacks one of the batch parameters: do not batch it.
                        batched_params, unbatched_params = None, None
            else:
                # This worker cannot take the task now; let the first worker with a
                # free slot take it in the simulation.
                workers = itertools.chain(task.workers, [worker_id]) if assistant else task.workers
                for task_worker in workers:
                    if greedy_workers.get(task_worker, 0) > 0:
                        # use up a worker
                        greedy_workers[task_worker] -= 1
                        # keep track of the resources used in greedy scheduling
                        for resource, amount in (task.resources or {}).items():
                            greedy_resources[resource] += amount
                        break
    reply = self.count_pending(worker_id)
    if len(batched_tasks) > 1:
        # The batch id is the MD5 hex digest of the member task ids joined with '|'.
        batch_string = '|'.join(task.id for task in batched_tasks)
        batch_id = hashlib.new('md5', batch_string.encode('utf-8'), usedforsecurity=False).hexdigest()
        for task in batched_tasks:
            self._state.set_batch_running(task, batch_id, worker_id)
        # For batched parameter names the single value is replaced by the list
        # of values collected from all batch members.
        combined_params = best_task.params.copy()
        combined_params.update(batched_params)
        reply['task_id'] = None
        reply['task_family'] = best_task.family
        reply['task_module'] = getattr(best_task, 'module', None)
        reply['task_params'] = combined_params
        reply['batch_id'] = batch_id
        reply['batch_task_ids'] = [task.id for task in batched_tasks]
    elif best_task:
        self.update_metrics_task_started(best_task)
        self._state.set_status(best_task, RUNNING, self._config)
        best_task.worker_running = worker_id
        # Copied so that later changes to resources_running leave task.resources untouched.
        best_task.resources_running = best_task.resources.copy()
        best_task.time_running = time.time()
        self._update_task_history(best_task, RUNNING, host=host)
        reply['task_id'] = best_task.id
        reply['task_family'] = best_task.family
        reply['task_module'] = getattr(best_task, 'module', None)
        reply['task_params'] = best_task.params
    else:
        reply['task_id'] = None
    return reply
[docs]
@rpc_method(attempts=1)
def ping(self, **kwargs):
    """
    Pass the calling worker to ``_update_worker`` and return its RPC messages.

    :param kwargs: must contain ``worker``, the worker id
    :return: ``{"rpc_messages": <result of the worker's fetch_rpc_messages()>}``
    """
    worker_id = kwargs['worker']
    worker = self._update_worker(worker_id)
    return {"rpc_messages": worker.fetch_rpc_messages()}
def _upstream_status(self, task_id, upstream_status_table):
    """
    Return the most severe upstream status found among ``task_id`` and its dependencies.

    The walk is iterative (explicit stack, no recursion) and stores every
    visited task in ``upstream_status_table``.

    :param task_id: id of the task to inspect
    :param upstream_status_table: dict of task id to upstream status, used as a
                                  memo; updated in place and reusable across calls
    :return: one of the ``UPSTREAM_*`` strings or ``''``. ``None`` is returned
             implicitly when the task is neither in the table nor known to the state
    """
    if task_id in upstream_status_table:
        return upstream_status_table[task_id]
    elif self._state.has_task(task_id):
        task_stack = [task_id]
        while task_stack:
            dep_id = task_stack.pop()
            dep = self._state.get_task(dep_id)
            if dep:
                # DONE tasks contribute nothing and are never written to the table.
                if dep.status == DONE:
                    continue
                if dep_id not in upstream_status_table:
                    if dep.status == PENDING and dep.deps:
                        # Push the task again underneath its dependencies so it is
                        # revisited after all of them have been handled.
                        task_stack += [dep_id] + list(dep.deps)
                        upstream_status_table[dep_id] = ''  # will be updated postorder
                    else:
                        dep_status = STATUS_TO_UPSTREAM_MAP.get(dep.status, '')
                        upstream_status_table[dep_id] = dep_status
                elif upstream_status_table[dep_id] == '' and dep.deps:
                    # This is the postorder update step when we set the
                    # status based on the previously calculated child elements
                    # (the most severe one, ordered by UPSTREAM_SEVERITY_ORDER).
                    status = max((upstream_status_table.get(a_task_id, '')
                                  for a_task_id in dep.deps),
                                 key=UPSTREAM_SEVERITY_KEY)
                    upstream_status_table[dep_id] = status
        # The root sits at the bottom of the stack, so it is the last id popped.
        # NOTE(review): if the root task itself is DONE, the loop skips it without
        # adding a table entry and this lookup raises KeyError. The callers visible
        # here only pass RUNNING, PENDING or FAILED tasks -- confirm before reusing.
        return upstream_status_table[dep_id]
def _serialize_task(self, task_id, include_deps=True, deps=None):
    """
    Build the dict representation of a task.

    :param task_id: id of the task to look up in the state
    :param include_deps: add a ``deps`` list to the result
    :param deps: dependencies to report instead of ``task.deps``
                 (only used when ``include_deps`` is true)
    :return: dict of task fields. ``re_enable_able`` is present only for
             DISABLED tasks, ``accepts_messages`` only for RUNNING tasks when
             ``send_messages`` is on in the config
    """
    task = self._state.get_task(task_id)

    def optional(attr):
        # These attributes are not present on every task.
        return getattr(task, attr, None)

    serialized = {
        'display_name': task.pretty_id,
        'status': task.status,
        'workers': list(task.workers),
        'worker_running': task.worker_running,
        'time_running': optional("time_running"),
        'start_time': task.time,
        'last_updated': getattr(task, "updated", task.time),
        'params': task.public_params,
        'name': task.family,
        'priority': task.priority,
        'resources': task.resources,
        'resources_running': optional("resources_running"),
        'tracking_url': optional("tracking_url"),
        'status_message': optional("status_message"),
        'progress_percentage': optional("progress_percentage"),
    }

    if task.status == DISABLED:
        serialized['re_enable_able'] = task.scheduler_disable_time is not None

    if include_deps:
        reported_deps = task.deps if deps is None else deps
        serialized['deps'] = list(reported_deps)

    if self._config.send_messages and task.status == RUNNING:
        serialized['accepts_messages'] = task.accepts_messages

    return serialized
[docs]
@rpc_method()
def graph(self, **kwargs):
    """
    Serialize the dependency graphs of all active tasks into one mapping.

    The state is pruned first. A shared ``seen`` set keeps a task from being
    used as a traversal root after an earlier traversal has reached it.

    :return: dict of task id to serialized node
    """
    self.prune()
    visited = set()
    nodes = {}
    for active_task in self._state.get_active_tasks():
        subgraph = self._traverse_graph(active_task.id, visited)
        nodes.update(subgraph)
    return nodes
def _filter_done(self, task_ids):
    """
    Lazily yield the ids from ``task_ids`` that do not belong to a DONE task.

    Ids unknown to the state are kept.
    """
    for candidate_id in task_ids:
        candidate = self._state.get_task(candidate_id)
        if candidate is not None and candidate.status == DONE:
            continue
        yield candidate_id
def _traverse_graph(self, root_task_id, seen=None, dep_func=None, include_done=True):
    """ Returns the dependency graph rooted at root_task_id

    This does a breadth-first traversal to find the nodes closest to the
    root before hitting the scheduler.max_graph_nodes limit.

    :param root_task_id: the id of the graph's root
    :param seen: set of already visited task ids; updated in place. If the
                 root is already in it, an empty dict is returned
    :param dep_func: callable mapping a task to the ids to follow from it;
                     defaults to the task's ``deps``
    :param include_done: when false, ids of DONE tasks are not followed
    :return: A map of task id to serialized node. Only the root node keeps
             its ``display_name`` entry
    """
    if seen is None:
        seen = set()
    elif root_task_id in seen:
        return {}
    seen.add(root_task_id)

    if dep_func is None:
        def dep_func(t):
            return t.deps

    nodes = {}
    pending = collections.deque([root_task_id])
    while pending:
        current_id = pending.popleft()
        task = self._state.get_task(current_id)

        if task is not None and task.family:
            neighbours = dep_func(task)
            if not include_done:
                neighbours = list(self._filter_done(neighbours))
            nodes[current_id] = self._serialize_task(current_id, deps=neighbours)
            for neighbour_id in sorted(neighbours):
                if neighbour_id in seen:
                    continue
                seen.add(neighbour_id)
                pending.append(neighbour_id)
        else:
            logger.debug('Missing task for id [%s]', current_id)
            # NOTE : If a dependency is missing from self._state there is no way to deduce the
            # task family and parameters.
            family_match = TASK_FAMILY_RE.match(current_id)
            nodes[current_id] = {
                'deps': [],
                'status': UNKNOWN,
                'workers': [],
                'start_time': UNKNOWN,
                'params': {'task_id': current_id},
                'name': family_match.group(1) if family_match else UNKNOWN,
                'display_name': current_id,
                'priority': 0,
            }

        if current_id != root_task_id:
            del nodes[current_id]['display_name']
        # Checked after inserting, so the result holds at most max_graph_nodes
        # entries (but always at least the root).
        if len(nodes) >= self._config.max_graph_nodes:
            break

    return nodes
[docs]
@rpc_method()
def dep_graph(self, task_id, include_done=True, **kwargs):
    """
    Return the serialized dependency graph rooted at ``task_id``.

    The state is pruned first.

    :param include_done: when false, DONE dependencies are not followed
    :return: dict of task id to serialized node; empty if the task is unknown
    """
    self.prune()
    if self._state.has_task(task_id):
        return self._traverse_graph(task_id, include_done=include_done)
    return {}
[docs]
@rpc_method()
def inverse_dep_graph(self, task_id, include_done=True, **kwargs):
    """
    Return the serialized graph of the tasks that depend on ``task_id``.

    The state is pruned first. The reverse edges are rebuilt from the
    ``deps`` of all active tasks on every call.

    :param include_done: when false, DONE dependents are not followed
    :return: dict of task id to serialized node; empty if the task is unknown
    """
    self.prune()
    if not self._state.has_task(task_id):
        return {}

    dependents = collections.defaultdict(set)
    for active_task in self._state.get_active_tasks():
        for dep_id in active_task.deps:
            dependents[dep_id].add(active_task.id)

    def dependents_of(t):
        return dependents[t.id]

    return self._traverse_graph(
        task_id, dep_func=dependents_of, include_done=include_done)
[docs]
@rpc_method()
def task_list(self, status='', upstream_status='', limit=True, search=None, max_shown_tasks=None,
              **kwargs):
    """
    Query for a subset of tasks by status.

    :param status: only return tasks with this status; falsy means all active tasks
    :param upstream_status: when set, PENDING tasks are kept only if their
                            upstream status equals this value
    :param limit: when truthy, return only a count once more than the maximum
                  number of tasks would be shown
    :param search: whitespace-separated terms; a task matches when every term
                   occurs (case-insensitively) in its ``pretty_id``
    :param max_shown_tasks: overrides the config's ``max_shown_tasks`` when truthy
    :return: dict of task id to serialized task, or ``{'num_tasks': n}`` when
             the limit is exceeded (``n`` is ``-1`` if the early count cannot
             account for ``upstream_status`` filtering)
    """
    count_limit = max_shown_tasks or self._config.max_shown_tasks
    if not search:
        # Cheap early exit: without a search the per-status count is an upper
        # bound on the result size, so nothing has to be serialized.
        pre_count = self._state.get_active_task_count_for_status(status)
        if limit and pre_count > count_limit:
            return {'num_tasks': -1 if upstream_status else pre_count}
    self.prune()

    if search is None:
        def filter_func(_):
            return True
    else:
        # Case-fold the terms once here rather than once per task and term.
        terms = [term.casefold() for term in search.split()]

        def filter_func(t):
            if not terms:
                return True
            folded_id = t.pretty_id.casefold()
            return all(term in folded_id for term in terms)

    result = {}
    upstream_status_table = {}  # used to memoize upstream status
    tasks = self._state.get_active_tasks_by_status(status) if status else self._state.get_active_tasks()
    for task in filter(filter_func, tasks):
        if (task.status != PENDING or not upstream_status or
                upstream_status == self._upstream_status(task.id, upstream_status_table)):
            result[task.id] = self._serialize_task(task.id, include_deps=False)
    if limit and len(result) > count_limit:
        return {'num_tasks': len(result)}
    return result
def _first_task_display_name(self, worker):
    """
    Return a display name for the worker's ``first_task`` info entry.

    :return: the task's ``pretty_id`` when the state knows the task, otherwise
             the raw id (``''`` when the worker has no ``first_task`` entry)
    """
    first_id = worker.info.get('first_task', '')
    if not self._state.has_task(first_id):
        return first_id
    return self._state.get_task(first_id).pretty_id
[docs]
@rpc_method()
def worker_list(self, include_running=True, **kwargs):
    """
    List the active workers, most recently started first.

    The state is pruned first. Every entry holds the worker's ``info`` items
    plus ``name``, ``last_active``, ``started``, ``state``,
    ``first_task_display_name`` and ``num_unread_rpc_messages``.

    :param include_running: also add ``num_running``, ``num_pending``,
                            ``num_uniques`` and ``running`` (serialized
                            RUNNING tasks keyed by task id) to every entry
    :return: list of dicts, one per active worker
    """
    self.prune()
    workers = []
    for active in self._state.get_active_workers():
        workers.append(dict(
            name=active.id,
            last_active=active.last_active,
            started=active.started,
            state=active.state,
            first_task_display_name=self._first_task_display_name(active),
            num_unread_rpc_messages=len(active.rpc_messages),
            **active.info
        ))
    workers.sort(key=lambda entry: entry['started'], reverse=True)

    if not include_running:
        return workers

    running_by_worker = collections.defaultdict(dict)
    for task in self._state.get_active_tasks_by_status(RUNNING):
        if task.worker_running:
            running_by_worker[task.worker_running][task.id] = self._serialize_task(
                task.id, include_deps=False)

    pending_counts = collections.defaultdict(int)
    unique_counts = collections.defaultdict(int)
    for task in self._state.get_active_tasks_by_status(PENDING):
        for worker_id in task.workers:
            pending_counts[worker_id] += 1
        if len(task.workers) == 1:
            # The task belongs to exactly one worker.
            unique_counts[next(iter(task.workers))] += 1

    for entry in workers:
        name = entry['name']
        running = running_by_worker[name]
        entry['num_running'] = len(running)
        entry['num_pending'] = pending_counts[name]
        entry['num_uniques'] = unique_counts[name]
        entry['running'] = running
    return workers
[docs]
@rpc_method()
def resource_list(self):
    """
    Resources usage info and their consumers (tasks).

    The state is pruned first.

    :return: list of dicts with ``name``, ``num_total`` and ``num_used``; when
             resource totals exist, each dict also carries ``num_consumer`` and
             ``running`` (serialized RUNNING tasks declaring the resource,
             keyed by task id)
    """
    self.prune()
    resources = []
    for name, usage in self.resources().items():
        resources.append(dict(
            name=name,
            num_total=usage['total'],
            num_used=usage['used']
        ))

    if self._resources is None:
        return resources

    consumers = collections.defaultdict(dict)
    for task in self._state.get_active_tasks_by_status(RUNNING):
        if task.status != RUNNING or not task.resources:
            continue
        for name in task.resources:
            consumers[name][task.id] = self._serialize_task(task.id, include_deps=False)

    for entry in resources:
        holders = consumers[entry['name']]
        entry['num_consumer'] = len(holders)
        entry['running'] = holders
    return resources
[docs]
def resources(self):
    """
    Get the total and the currently used amount of every configured resource.

    :return: ``defaultdict(dict)`` mapping resource name to
             ``{'total': <configured total>, 'used': <amount in use>}``;
             empty when no resource totals are configured
    """
    used_resources = self._used_resources()
    ret = collections.defaultdict(dict)
    # The totals mapping may be None (other methods of this class check for
    # that); treat it as "no resources" instead of failing on ``.items()``.
    for resource, total in (self._resources or {}).items():
        ret[resource]['total'] = total
        # ``.get`` reads without adding missing keys to the usage mapping.
        ret[resource]['used'] = used_resources.get(resource, 0)
    return ret
[docs]
@rpc_method()
def task_search(self, task_str, **kwargs):
    """
    Query for a subset of tasks by task_id.

    The state is pruned first.

    :param task_str: substring that must occur in the task id
    :return: ``defaultdict(dict)`` mapping status to a dict of task id to
             serialized task
    """
    self.prune()
    by_status = collections.defaultdict(dict)
    for task in self._state.get_active_tasks():
        if task_str not in task.id:
            continue
        by_status[task.status][task.id] = self._serialize_task(task.id, include_deps=False)
    return by_status
[docs]
@rpc_method()
def re_enable_task(self, task_id):
    """
    Re-enable a DISABLED task that has a truthy ``scheduler_disable_time``.

    :return: the serialized task after re-enabling, or ``{}`` when the task
             is unknown or does not meet those conditions
    """
    task = self._state.get_task(task_id)
    if not (task and task.status == DISABLED and task.scheduler_disable_time):
        return {}
    self._state.re_enable(task, self._config)
    return self._serialize_task(task_id)
[docs]
@rpc_method()
def fetch_error(self, task_id, **kwargs):
    """
    Return the stored explanation (``expl``) of a task together with its identity.

    :return: dict with ``taskId``, ``error``, ``displayName``, ``taskParams``,
             ``taskModule`` and ``taskFamily``; for an unknown task only
             ``taskId`` and an empty ``error``
    """
    if not self._state.has_task(task_id):
        return {"taskId": task_id, "error": ""}
    task = self._state.get_task(task_id)
    return {
        "taskId": task_id,
        "error": task.expl,
        'displayName': task.pretty_id,
        'taskParams': task.params,
        'taskModule': task.module,
        'taskFamily': task.family,
    }
[docs]
@rpc_method()
def set_task_status_message(self, task_id, status_message):
    """
    Store a status message on a task.

    When the task is RUNNING and has a ``batch_id``, the message is copied to
    every task the state reports as batch-running under that id. Unknown
    task ids are ignored.
    """
    if not self._state.has_task(task_id):
        return
    task = self._state.get_task(task_id)
    task.status_message = status_message
    if task.status != RUNNING or task.batch_id is None:
        return
    for member in self._state.get_batch_running_tasks(task.batch_id):
        member.status_message = status_message
[docs]
@rpc_method()
def get_task_status_message(self, task_id):
    """
    Return ``{"taskId": ..., "statusMessage": ...}`` for a task.

    The message is ``""`` when the task is unknown.
    """
    if not self._state.has_task(task_id):
        return {"taskId": task_id, "statusMessage": ""}
    message = self._state.get_task(task_id).status_message
    return {"taskId": task_id, "statusMessage": message}
[docs]
@rpc_method()
def set_task_progress_percentage(self, task_id, progress_percentage):
    """
    Store a progress percentage on a task.

    When the task is RUNNING and has a ``batch_id``, the value is copied to
    every task the state reports as batch-running under that id. Unknown
    task ids are ignored.
    """
    if not self._state.has_task(task_id):
        return
    task = self._state.get_task(task_id)
    task.progress_percentage = progress_percentage
    if task.status != RUNNING or task.batch_id is None:
        return
    for member in self._state.get_batch_running_tasks(task.batch_id):
        member.progress_percentage = progress_percentage
[docs]
@rpc_method()
def get_task_progress_percentage(self, task_id):
    """
    Return ``{"taskId": ..., "progressPercentage": ...}`` for a task.

    The percentage is ``None`` when the task is unknown.
    """
    if not self._state.has_task(task_id):
        return {"taskId": task_id, "progressPercentage": None}
    percentage = self._state.get_task(task_id).progress_percentage
    return {"taskId": task_id, "progressPercentage": percentage}
[docs]
@rpc_method()
def decrease_running_task_resources(self, task_id, decrease_resources):
    """
    Lower the resources a RUNNING task is recorded as holding.

    Only positive amounts for resources the task already holds are applied,
    and a value never drops below 0. When the task has a ``batch_id``, the
    same decrease is applied to the tasks batch-running under that id.
    Unknown or non-RUNNING tasks are ignored.

    :param decrease_resources: mapping of resource name to the amount to subtract
    """
    if not self._state.has_task(task_id):
        return
    task = self._state.get_task(task_id)
    if task.status != RUNNING:
        return

    def shrink(held):
        for name, delta in decrease_resources.items():
            if delta > 0 and name in held:
                held[name] = max(0, held[name] - delta)

    shrink(task.resources_running)
    if task.batch_id is not None:
        for member in self._state.get_batch_running_tasks(task.batch_id):
            shrink(member.resources_running)
[docs]
@rpc_method()
def get_running_task_resources(self, task_id):
    """
    Return ``{"taskId": ..., "resources": ...}`` for a task.

    ``resources`` is the task's ``resources_running`` attribute, or ``None``
    when the task is unknown or has no such attribute.
    """
    if not self._state.has_task(task_id):
        return {"taskId": task_id, "resources": None}
    task = self._state.get_task(task_id)
    held = getattr(task, "resources_running", None)
    return {"taskId": task_id, "resources": held}
def _update_task_history(self, task, status, host=None):
    """
    Report a status change of ``task`` to the task history backend.

    DONE and FAILED are recorded as finished (successful only for DONE),
    PENDING as scheduled and RUNNING as started; other statuses are ignored.
    Any error raised by the backend is logged as a warning and not propagated.

    :param host: passed to the backend only for RUNNING
    """
    try:
        if status in (DONE, FAILED):
            self._task_history.task_finished(task, status == DONE)
        elif status == PENDING:
            self._task_history.task_scheduled(task)
        elif status == RUNNING:
            self._task_history.task_started(task, host)
    except BaseException:
        logger.warning("Error saving Task history", exc_info=True)
@property
def task_history(self):
    """Read-only access to the scheduler's task history object."""
    # Used by server.py to expose the calls
    return self._task_history
[docs]
@rpc_method()
def update_metrics_task_started(self, task):
    """
    Pass ``task`` to ``handle_task_started`` on the state's metrics collector.

    :param task: the task that was started
    """
    self._state._metrics_collector.handle_task_started(task)