Source code for luigi.contrib.ssh

# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Light-weight remote execution library and utilities.

There are some examples in the unittest but I added another that is more
luigi-specific in the examples directory (examples/ssh_remote_execution.py)

:class:`RemoteContext` is meant to provide functionality similar to that of the
standard library subprocess module, but where the commands executed are run on
a remote machine instead, without the user having to think about prefixing
everything with "ssh" and credentials etc.

Using this mini library (which is just a convenience wrapper for subprocess),
:class:`RemoteTarget` is created to let you stream data from a remotely stored file using
the luigi :class:`~luigi.target.FileSystemTarget` semantics.

As a bonus, :class:`RemoteContext` also provides a really cool feature that let's you
set up ssh tunnels super easily using a python context manager (there is an example
in the integration part of unittests).

This can be super convenient when you want secure communication using a non-secure
protocol or circumvent firewalls (as long as they are open for ssh traffic).
"""

import contextlib
import logging
import os
import random
import subprocess
import posixpath

import luigi
import luigi.format
import luigi.target


logger = logging.getLogger('luigi-interface')


[docs] class RemoteCalledProcessError(subprocess.CalledProcessError): def __init__(self, returncode, command, host, output=None): super(RemoteCalledProcessError, self).__init__(returncode, command, output) self.host = host def __str__(self): return "Command '%s' on host %s returned non-zero exit status %d" % ( self.cmd, self.host, self.returncode)
[docs] class RemoteContext: def __init__(self, host, **kwargs): self.host = host self.username = kwargs.get('username', None) self.key_file = kwargs.get('key_file', None) self.connect_timeout = kwargs.get('connect_timeout', None) self.port = kwargs.get('port', None) self.no_host_key_check = kwargs.get('no_host_key_check', False) self.sshpass = kwargs.get('sshpass', False) self.tty = kwargs.get('tty', False) def __repr__(self): return '%s(%r, %r, %r, %r, %r)' % ( type(self).__name__, self.host, self.username, self.key_file, self.connect_timeout, self.port) def __eq__(self, other): return repr(self) == repr(other) def __hash__(self): return hash(repr(self)) def _host_ref(self): if self.username: return "{0}@{1}".format(self.username, self.host) else: return self.host def _prepare_cmd(self, cmd): connection_cmd = ["ssh", self._host_ref(), "-o", "ControlMaster=no"] if self.sshpass: connection_cmd = ["sshpass", "-e"] + connection_cmd else: connection_cmd += ["-o", "BatchMode=yes"] # no password prompts etc if self.port: connection_cmd.extend(["-p", self.port]) if self.connect_timeout is not None: connection_cmd += ['-o', 'ConnectTimeout=%d' % self.connect_timeout] if self.no_host_key_check: connection_cmd += ['-o', 'UserKnownHostsFile=/dev/null', '-o', 'StrictHostKeyChecking=no'] if self.key_file: connection_cmd.extend(["-i", self.key_file]) if self.tty: connection_cmd.append('-t') return connection_cmd + cmd
[docs] def Popen(self, cmd, **kwargs): """ Remote Popen. """ prefixed_cmd = self._prepare_cmd(cmd) return subprocess.Popen(prefixed_cmd, **kwargs)
[docs] def check_output(self, cmd): """ Execute a shell command remotely and return the output. Simplified version of Popen when you only want the output as a string and detect any errors. """ p = self.Popen(cmd, stdout=subprocess.PIPE) output, _ = p.communicate() if p.returncode != 0: raise RemoteCalledProcessError(p.returncode, cmd, self.host, output=output) return output
[docs] @contextlib.contextmanager def tunnel(self, local_port, remote_port=None, remote_host="localhost"): """ Open a tunnel between localhost:local_port and remote_host:remote_port via the host specified by this context. Remember to close() the returned "tunnel" object in order to clean up after yourself when you are done with the tunnel. """ tunnel_host = "{0}:{1}:{2}".format(local_port, remote_host, remote_port) proc = self.Popen( # cat so we can shut down gracefully by closing stdin ["-L", tunnel_host, "echo -n ready && cat"], stdin=subprocess.PIPE, stdout=subprocess.PIPE, ) # make sure to get the data so we know the connection is established ready = proc.stdout.read(5) assert ready == b"ready", "Didn't get ready from remote echo" yield # user code executed here proc.communicate() assert proc.returncode == 0, "Tunnel process did an unclean exit (returncode %s)" % (proc.returncode,)
[docs] class RemoteFileSystem(luigi.target.FileSystem): def __init__(self, host, **kwargs): self.remote_context = RemoteContext(host, **kwargs)
[docs] def exists(self, path): """ Return `True` if file or directory at `path` exist, False otherwise. """ try: self.remote_context.check_output(["test", "-e", path]) except subprocess.CalledProcessError as e: if e.returncode == 1: return False else: raise return True
[docs] def listdir(self, path): while path.endswith('/'): path = path[:-1] path = path or '.' listing = self.remote_context.check_output(["find", "-L", path, "-type", "f"]).splitlines() return [v.decode('utf-8') for v in listing]
[docs] def isdir(self, path): """ Return `True` if directory at `path` exist, False otherwise. """ try: self.remote_context.check_output(["test", "-d", path]) except subprocess.CalledProcessError as e: if e.returncode == 1: return False else: raise return True
[docs] def remove(self, path, recursive=True): """ Remove file or directory at location `path`. """ if recursive: cmd = ["rm", "-r", path] else: cmd = ["rm", path] self.remote_context.check_output(cmd)
[docs] def mkdir(self, path, parents=True, raise_if_exists=False): if self.exists(path): if raise_if_exists: raise luigi.target.FileAlreadyExists() elif not self.isdir(path): raise luigi.target.NotADirectory() else: return if parents: cmd = ['mkdir', '-p', path] else: cmd = ['mkdir', path, '2>&1'] try: self.remote_context.check_output(cmd) except subprocess.CalledProcessError as e: if b'no such file' in e.output.lower(): raise luigi.target.MissingParentDirectory() raise
def _scp(self, src, dest): cmd = ["scp", "-q", "-C", "-o", "ControlMaster=no"] if self.remote_context.sshpass: cmd = ["sshpass", "-e"] + cmd else: cmd.append("-B") if self.remote_context.no_host_key_check: cmd.extend(['-o', 'UserKnownHostsFile=/dev/null', '-o', 'StrictHostKeyChecking=no']) if self.remote_context.key_file: cmd.extend(["-i", self.remote_context.key_file]) if self.remote_context.port: cmd.extend(["-P", self.remote_context.port]) if os.path.isdir(src): cmd.extend(["-r"]) cmd.extend([src, dest]) p = subprocess.Popen(cmd) output, _ = p.communicate() if p.returncode != 0: raise subprocess.CalledProcessError(p.returncode, cmd, output=output)
[docs] def put(self, local_path, path): # create parent folder if not exists normpath = posixpath.normpath(path) folder = os.path.dirname(normpath) if folder and not self.exists(folder): self.remote_context.check_output(['mkdir', '-p', folder]) tmp_path = path + '-luigi-tmp-%09d' % random.randrange(0, 10_000_000_000) self._scp(local_path, "%s:%s" % (self.remote_context._host_ref(), tmp_path)) self.remote_context.check_output(['mv', tmp_path, path])
[docs] def get(self, path, local_path): # Create folder if it does not exist normpath = os.path.normpath(local_path) folder = os.path.dirname(normpath) if folder: try: os.makedirs(folder) except OSError: pass tmp_local_path = local_path + '-luigi-tmp-%09d' % random.randrange(0, 10_000_000_000) self._scp("%s:%s" % (self.remote_context._host_ref(), path), tmp_local_path) os.rename(tmp_local_path, local_path)
[docs] class AtomicRemoteFileWriter(luigi.format.OutputPipeProcessWrapper): def __init__(self, fs, path): self._fs = fs self.path = path # create parent folder if not exists normpath = os.path.normpath(self.path) folder = os.path.dirname(normpath) if folder: self.fs.mkdir(folder) self.__tmp_path = self.path + '-luigi-tmp-%09d' % random.randrange(0, 10_000_000_000) super(AtomicRemoteFileWriter, self).__init__( self.fs.remote_context._prepare_cmd(['cat', '>', self.__tmp_path])) def __del__(self): super(AtomicRemoteFileWriter, self).__del__() try: if self.fs.exists(self.__tmp_path): self.fs.remote_context.check_output(['rm', self.__tmp_path]) except Exception: # Don't propagate the exception; bad things can happen. logger.exception('Failed to delete in-flight file')
[docs] def close(self): super(AtomicRemoteFileWriter, self).close() self.fs.remote_context.check_output(['mv', self.__tmp_path, self.path])
@property def tmp_path(self): return self.__tmp_path @property def fs(self): return self._fs
[docs] class RemoteTarget(luigi.target.FileSystemTarget): """ Target used for reading from remote files. The target is implemented using ssh commands streaming data over the network. """ def __init__(self, path, host, format=None, **kwargs): super(RemoteTarget, self).__init__(path) if format is None: format = luigi.format.get_default_format() self.format = format self._fs = RemoteFileSystem(host, **kwargs) @property def fs(self): return self._fs
[docs] def open(self, mode='r'): if mode == 'w': file_writer = AtomicRemoteFileWriter(self.fs, self.path) if self.format: return self.format.pipe_writer(file_writer) else: return file_writer elif mode == 'r': file_reader = luigi.format.InputPipeProcessWrapper( self.fs.remote_context._prepare_cmd(["cat", self.path])) if self.format: return self.format.pipe_reader(file_reader) else: return file_reader else: raise Exception("mode must be 'r' or 'w' (got: %s)" % mode)
[docs] def put(self, local_path): self.fs.put(local_path, self.path)
[docs] def get(self, local_path): self.fs.get(self.path, local_path)