# Source code for luigi.mock
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
This module provides a class :class:`MockTarget`, an implementation of :py:class:`~luigi.target.Target`.
:class:`MockTarget` contains all data in-memory.
The main purpose is unit testing workflows without writing to disk.
"""
import multiprocessing
from io import BytesIO
import sys
from luigi import target
from luigi.format import get_default_format
[docs]
class MockFileSystem(target.FileSystem):
    """
    MockFileSystem inspects/modifies _data to simulate file system operations.

    ``_data`` is a class attribute mapping a path to its stored contents. It is
    created lazily by :py:meth:`get_all_data` and, being stored on the class,
    is shared by every instance.
    """
    _data = None

    def copy(self, path, dest, raise_if_exists=False):
        """
        Copies the contents of a single file path to dest

        :param path: key of the entry to copy.
        :param dest: key the copy is stored under.
        :param raise_if_exists: when true, refuse to overwrite an existing ``dest``.
        :raises RuntimeError: if ``raise_if_exists`` is set and ``dest`` is already stored.
        :raises KeyError: if ``path`` is not stored.
        """
        data = self.get_all_data()
        if raise_if_exists and dest in data:
            # Report the destination: that is the path that already exists.
            raise RuntimeError('Destination exists: %s' % dest)
        data[dest] = data[path]

    def get_all_data(self):
        """
        Returns the shared path -> contents mapping, creating it on first use.
        """
        # This starts a server in the background, so we don't want to do it in the global scope
        if MockFileSystem._data is None:
            MockFileSystem._data = multiprocessing.Manager().dict()
        return MockFileSystem._data

    def get_data(self, fn):
        """
        Returns the contents stored under ``fn``.

        :raises KeyError: if ``fn`` is not stored.
        """
        return self.get_all_data()[fn]

    def exists(self, path):
        """
        Returns whether ``path`` is stored, as reported by ``MockTarget(path).exists()``.
        """
        return MockTarget(path).exists()

    def remove(self, path, recursive=True, skip_trash=True):
        """
        Removes the given mockfile. skip_trash doesn't have any meaning.

        :param path: key to remove; with ``recursive`` it is treated as a prefix.
        :param recursive: when true, remove every key that starts with ``path``;
                          otherwise remove exactly ``path``.
        :param skip_trash: ignored.
        :raises KeyError: if ``recursive`` is false and ``path`` is not stored.
        """
        data = self.get_all_data()
        if recursive:
            # Collect the matching keys first so the mapping is not modified
            # while its keys are being walked.
            doomed = [key for key in data.keys() if key.startswith(path)]
            for key in doomed:
                data.pop(key)
        else:
            data.pop(path)

    def move(self, path, dest, raise_if_exists=False):
        """
        Moves a single file from path to dest

        :param path: key of the entry to move; it is removed from the store.
        :param dest: key the contents are stored under afterwards.
        :param raise_if_exists: when true, refuse to overwrite an existing ``dest``.
        :raises RuntimeError: if ``raise_if_exists`` is set and ``dest`` is already stored.
        :raises KeyError: if ``path`` is not stored.
        """
        data = self.get_all_data()
        if raise_if_exists and dest in data:
            # Report the destination: that is the path that already exists.
            raise RuntimeError('Destination exists: %s' % dest)
        data[dest] = data.pop(path)

    def listdir(self, path):
        """
        listdir does a prefix match of self.get_all_data(), but doesn't yet support globs.

        :returns: list of every stored key that starts with ``path``.
        """
        return [key for key in self.get_all_data().keys()
                if key.startswith(path)]

    def isdir(self, path):
        """
        Returns whether at least one stored key starts with ``path``.
        """
        # Test the listing for emptiness rather than the truthiness of the
        # keys themselves, so an empty-string key still counts as an entry.
        return bool(self.listdir(path))

    def mkdir(self, path, parents=True, raise_if_exists=False):
        """
        mkdir is a noop.
        """
        pass

    def clear(self):
        """
        Removes every entry from the shared store.
        """
        self.get_all_data().clear()
[docs]
class MockTarget(target.FileSystemTarget):
    """
    Target whose contents are kept in the in-memory store of :class:`MockFileSystem`.
    """
    # One MockFileSystem instance shared by all MockTarget objects.
    fs = MockFileSystem()

    def __init__(self, fn, is_tmp=None, mirror_on_stderr=False, format=None):
        """
        :param fn: path of the target; used as the key in the mock file system.
        :param is_tmp: accepted but not used by this class.
        :param mirror_on_stderr: when true, data written through :py:meth:`open`
                                 is also echoed to ``sys.stderr``.
        :param format: format object providing ``pipe_reader``/``pipe_writer``;
                       falls back to ``get_default_format()`` when falsy.
        """
        self._mirror_on_stderr = mirror_on_stderr
        self.path = fn
        self.format = format or get_default_format()

    def exists(self,):
        """
        Returns whether this target's path is present in the mock file system.
        """
        return self.path in self.fs.get_all_data()

    def move(self, path, raise_if_exists=False):
        """
        Call MockFileSystem's move command
        """
        self.fs.move(self.path, path, raise_if_exists)

    def rename(self, *args, **kwargs):
        """
        Call move to rename self
        """
        self.move(*args, **kwargs)

    def open(self, mode='r'):
        """
        Opens the target and returns a file-like object produced by ``self.format``.

        :param mode: a mode starting with ``'w'`` opens an empty buffer for
                     writing; any other mode opens the stored contents for reading.
        :returns: ``self.format.pipe_writer(...)`` or ``self.format.pipe_reader(...)``
                  wrapped around an in-memory buffer.
        :raises KeyError: when reading a path that is not stored.
        """
        fn = self.path
        mock_target = self

        class Buffer(BytesIO):
            # Just to be able to do writing + reading from the same buffer

            # True when the next mirrored chunk starts a new line and therefore
            # needs the "<path>: " prefix.
            _write_line = True

            def set_wrapper(self, wrapper):
                self.wrapper = wrapper

            def write(self, data):
                # An empty write has nothing to mirror and no last character
                # to inspect, so it only reaches the underlying buffer.
                if mock_target._mirror_on_stderr and len(data):
                    if isinstance(data, str):
                        text = data
                    else:
                        # The mirror is diagnostic only; undecodable bytes are
                        # replaced rather than aborting the real write.
                        text = bytes(data).decode('utf8', 'replace')
                    if self._write_line:
                        sys.stderr.write(fn + ": ")
                    sys.stderr.write(text)
                    # Compare text with text: indexing bytes yields an int,
                    # which never equals '\n'.
                    self._write_line = text.endswith('\n')
                # Propagate the byte count as BytesIO.write does.
                return super(Buffer, self).write(data)

            def close(self):
                if mode[0] == 'w':
                    # NOTE(review): set_wrapper() stores the wrapper on this
                    # buffer, while this lookup goes through mock_target, which
                    # is not given a `wrapper` attribute anywhere in this
                    # class -- so this flush appears to always end up in the
                    # AttributeError branch. Left unchanged because flushing
                    # the wrapper from inside close() may not be safe for every
                    # format wrapper; confirm the intent before changing.
                    try:
                        mock_target.wrapper.flush()
                    except AttributeError:
                        pass
                    # Publish what was written under the target's path.
                    mock_target.fs.get_all_data()[fn] = self.getvalue()
                super(Buffer, self).close()

            def __exit__(self, exc_type, exc_val, exc_tb):
                # Only close (and thereby store the data) when the with-block
                # finished without an exception.
                if not exc_type:
                    self.close()

            def __enter__(self):
                return self

            def readable(self):
                return mode[0] == 'r'

            def writable(self):
                return mode[0] == 'w'

            # Original (misspelled) name, kept so existing callers keep working.
            writeable = writable

            def seekable(self):
                return False

        if mode[0] == 'w':
            wrapper = self.format.pipe_writer(Buffer())
            # Hand the wrapper back so it is reachable via set_wrapper().
            wrapper.set_wrapper(wrapper)
            return wrapper
        else:
            return self.format.pipe_reader(Buffer(self.fs.get_all_data()[fn]))