mirror of
https://github.com/emilybache/GildedRose-Refactoring-Kata.git
synced 2026-02-08 19:21:28 +00:00
428 lines
12 KiB
Python
428 lines
12 KiB
Python
# Copyright 2009-2016 Yelp and Contributors
|
|
# Copyright 2017-2018 Yelp
|
|
# Copyright 2019 Yelp
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Utility functions for MRJob
|
|
"""
|
|
# don't add imports here that aren't part of the standard Python library,
|
|
# since MRJobs need to run in Amazon's generic EMR environment
|
|
import logging
|
|
import os
|
|
import os.path
|
|
import pipes
|
|
import random
|
|
import shlex
|
|
import shutil
|
|
import sys
|
|
import tarfile
|
|
from contextlib import contextmanager
|
|
from datetime import timedelta
|
|
from distutils.spawn import find_executable
|
|
from logging import getLogger
|
|
from zipfile import ZIP_DEFLATED
|
|
from zipfile import ZIP_STORED
|
|
from zipfile import ZipFile
|
|
from zipfile import is_zipfile
|
|
|
|
from mrjob.py2 import PY2
|
|
|
|
|
|
log = getLogger(__name__)
|
|
|
|
|
|
class NullHandler(logging.Handler):
|
|
def emit(self, record):
|
|
pass
|
|
|
|
|
|
def cmd_line(args):
|
|
"""build a command line that works in a shell.
|
|
"""
|
|
args = [str(x) for x in args]
|
|
return ' '.join(pipes.quote(x) for x in args)
|
|
|
|
|
|
def expand_path(path):
|
|
"""Resolve ``~`` (home dir) and environment variables in *path*.
|
|
|
|
If *path* is ``None``, return ``None``.
|
|
"""
|
|
if path is None:
|
|
return None
|
|
else:
|
|
return os.path.expanduser(os.path.expandvars(path))
|
|
|
|
|
|
def file_ext(filename):
|
|
"""return the file extension, including the ``.``
|
|
|
|
>>> file_ext('foo.tar.gz')
|
|
'.tar.gz'
|
|
|
|
>>> file_ext('.emacs')
|
|
''
|
|
|
|
>>> file_ext('.mrjob.conf')
|
|
'.conf'
|
|
"""
|
|
stripped_name = filename.lstrip('.')
|
|
dot_index = stripped_name.find('.')
|
|
|
|
if dot_index == -1:
|
|
return ''
|
|
return stripped_name[dot_index:]
|
|
|
|
|
|
def log_to_null(name=None):
|
|
"""Set up a null handler for the given stream, to suppress
|
|
"no handlers could be found" warnings."""
|
|
logger = logging.getLogger(name)
|
|
logger.addHandler(NullHandler())
|
|
|
|
|
|
def log_to_stream(name=None, stream=None, format=None, level=None,
|
|
debug=False):
|
|
"""Set up logging.
|
|
|
|
:type name: str
|
|
:param name: name of the logger, or ``None`` for the root logger
|
|
:type stream: file object
|
|
:param stream: stream to log to (default is ``sys.stderr``)
|
|
:type format: str
|
|
:param format: log message format (default is '%(message)s')
|
|
:param level: log level to use
|
|
:type debug: bool
|
|
:param debug: quick way of setting the log level: if true, use
|
|
``logging.DEBUG``, otherwise use ``logging.INFO``
|
|
"""
|
|
if level is None:
|
|
level = logging.DEBUG if debug else logging.INFO
|
|
|
|
if format is None:
|
|
format = '%(message)s'
|
|
|
|
if stream is None:
|
|
stream = sys.stderr
|
|
|
|
handler = logging.StreamHandler(stream)
|
|
handler.setLevel(level)
|
|
handler.setFormatter(logging.Formatter(format))
|
|
|
|
logger = logging.getLogger(name)
|
|
logger.setLevel(level)
|
|
logger.addHandler(handler)
|
|
|
|
|
|
def random_identifier():
|
|
"""A random 16-digit hex string."""
|
|
return '%016x' % random.randint(0, 2 ** 64 - 1)
|
|
|
|
|
|
# Thanks to http://lybniz2.sourceforge.net/safeeval.html for
|
|
# explaining how to do this!
|
|
def safeeval(expr, globals=None, locals=None):
|
|
"""Like eval, but with nearly everything in the environment
|
|
blanked out, so that it's difficult to cause mischief.
|
|
|
|
*globals* and *locals* are optional dictionaries mapping names to
|
|
values for those names (just like in :py:func:`eval`).
|
|
"""
|
|
# blank out builtins, but keep None, True, and False
|
|
safe_globals = {
|
|
'False': False,
|
|
'None': None,
|
|
'True': True,
|
|
'__builtin__': None,
|
|
'__builtins__': None,
|
|
'set': set
|
|
}
|
|
|
|
# xrange is range in Python 3
|
|
if PY2:
|
|
safe_globals['xrange'] = xrange
|
|
else:
|
|
safe_globals['range'] = range
|
|
|
|
# PyPy needs special magic
|
|
def open(*args, **kwargs):
|
|
raise NameError("name 'open' is not defined")
|
|
safe_globals['open'] = open
|
|
|
|
# add the user-specified global variables
|
|
if globals:
|
|
safe_globals.update(globals)
|
|
|
|
return eval(expr, safe_globals, locals)
|
|
|
|
|
|
@contextmanager
|
|
def save_current_environment():
|
|
""" Context manager that saves os.environ and loads
|
|
it back again after execution
|
|
"""
|
|
original_environ = os.environ.copy()
|
|
|
|
try:
|
|
yield
|
|
|
|
finally:
|
|
os.environ.clear()
|
|
os.environ.update(original_environ)
|
|
|
|
|
|
@contextmanager
|
|
def save_cwd():
|
|
"""Context manager that saves the current working directory,
|
|
and chdir's back to it after execution."""
|
|
original_cwd = os.getcwd()
|
|
|
|
try:
|
|
yield
|
|
|
|
finally:
|
|
os.chdir(original_cwd)
|
|
|
|
|
|
@contextmanager
|
|
def save_sys_std():
|
|
"""Context manager that saves the current values of `sys.stdin`,
|
|
`sys.stdout`, and `sys.stderr`, and flushes these filehandles before
|
|
and after switching them out."""
|
|
|
|
stdin, stdout, stderr = sys.stdin, sys.stdout, sys.stderr
|
|
|
|
try:
|
|
sys.stdout.flush()
|
|
sys.stderr.flush()
|
|
|
|
yield
|
|
|
|
# at this point, sys.stdout/stderr may have been patched. Don't
|
|
# raise an exception if flush() fails
|
|
try:
|
|
sys.stdout.flush()
|
|
except:
|
|
pass
|
|
|
|
try:
|
|
sys.stderr.flush()
|
|
except:
|
|
pass
|
|
finally:
|
|
sys.stdin, sys.stdout, sys.stderr = stdin, stdout, stderr
|
|
|
|
|
|
@contextmanager
|
|
def save_sys_path():
|
|
"""Context manager that saves sys.path and restores it after execution."""
|
|
original_sys_path = list(sys.path)
|
|
|
|
try:
|
|
yield
|
|
|
|
finally:
|
|
sys.path = original_sys_path
|
|
|
|
|
|
def shlex_split(s):
|
|
"""Wrapper around shlex.split(), but convert to str if Python version <
|
|
2.7.3 when unicode support was added.
|
|
"""
|
|
if sys.version_info < (2, 7, 3):
|
|
return shlex.split(str(s))
|
|
else:
|
|
return shlex.split(s)
|
|
|
|
|
|
def strip_microseconds(delta):
|
|
"""Return the given :py:class:`datetime.timedelta`, without microseconds.
|
|
|
|
Useful for printing :py:class:`datetime.timedelta` objects.
|
|
"""
|
|
return timedelta(delta.days, delta.seconds)
|
|
|
|
|
|
def to_lines(chunks):
|
|
"""Take in data as a sequence of bytes, and yield it, one line at a time.
|
|
|
|
Only breaks lines on ``\\n`` (not ``\\r``), and does not add
|
|
a trailing newline.
|
|
|
|
For efficiency, passes through anything with a ``readline()`` attribute.
|
|
"""
|
|
# hopefully this is good enough for anything mrjob will encounter
|
|
if hasattr(chunks, 'readline'):
|
|
return chunks
|
|
else:
|
|
return _to_lines(chunks)
|
|
|
|
|
|
def _to_lines(chunks):
|
|
"""Take in data as a sequence of bytes, and yield it, one line at a time.
|
|
|
|
Only breaks lines on ``\\n`` (not ``\\r``), and does not add
|
|
a trailing newline.
|
|
|
|
Exception: if we encounter an empty bytestring ``b''``, immediately yield
|
|
what we have so far rather than joining it to the next chunk. This allows
|
|
us to handle bytes from multiple files without joining the end of one
|
|
file to the beginning of the next one.
|
|
|
|
Optimizes for:
|
|
|
|
* chunks bigger than lines (e.g. reading test files)
|
|
* chunks that are lines (idempotency)
|
|
"""
|
|
# list of chunks with no final newline
|
|
leftovers = []
|
|
|
|
for chunk in chunks:
|
|
# special case for b'' standing for EOF
|
|
if chunk == b'':
|
|
if leftovers:
|
|
yield b''.join(leftovers)
|
|
leftovers = []
|
|
|
|
continue
|
|
|
|
start = 0
|
|
|
|
while start < len(chunk):
|
|
end = chunk.find(b'\n', start) + 1
|
|
|
|
if end == 0: # no newlines found
|
|
leftovers.append(chunk[start:])
|
|
break
|
|
|
|
if leftovers:
|
|
leftovers.append(chunk[start:end])
|
|
yield b''.join(leftovers)
|
|
leftovers = []
|
|
else:
|
|
yield chunk[start:end]
|
|
|
|
start = end
|
|
|
|
if leftovers:
|
|
yield b''.join(leftovers)
|
|
|
|
|
|
def unique(items):
|
|
"""Yield items from *item* in order, skipping duplicates."""
|
|
seen = set()
|
|
|
|
for item in items:
|
|
if item in seen:
|
|
continue
|
|
else:
|
|
yield item
|
|
seen.add(item)
|
|
|
|
|
|
def unarchive(archive_path, dest):
|
|
"""Extract the contents of a tar or zip file at *archive_path* into the
|
|
directory *dest*.
|
|
|
|
:type archive_path: str
|
|
:param archive_path: path to archive file
|
|
:type dest: str
|
|
:param dest: path to directory where archive will be extracted
|
|
|
|
*dest* will be created if it doesn't already exist.
|
|
|
|
tar files can be gzip compressed, bzip2 compressed, or uncompressed. Files
|
|
within zip files can be deflated or stored.
|
|
"""
|
|
if tarfile.is_tarfile(archive_path):
|
|
with tarfile.open(archive_path, 'r') as archive:
|
|
archive.extractall(dest)
|
|
elif is_zipfile(archive_path):
|
|
with ZipFile(archive_path, 'r') as archive:
|
|
for name in archive.namelist():
|
|
# the zip spec specifies that front slashes are always
|
|
# used as directory separators
|
|
dest_path = os.path.join(dest, *name.split('/'))
|
|
|
|
# now, split out any dirname and filename and create
|
|
# one and/or the other
|
|
dirname, filename = os.path.split(dest_path)
|
|
if dirname and not os.path.exists(dirname):
|
|
os.makedirs(dirname)
|
|
if filename:
|
|
with open(dest_path, 'wb') as dest_file:
|
|
dest_file.write(archive.read(name))
|
|
else:
|
|
raise IOError('Unknown archive type: %s' % (archive_path,))
|
|
|
|
|
|
def which(cmd, path=None):
|
|
"""Like the UNIX which command: search in *path* for the executable named
|
|
*cmd*. *path* defaults to :envvar:`PATH`. Returns ``None`` if no
|
|
such executable found.
|
|
|
|
This is basically ``shutil.which()`` (which was introduced in Python 3.3)
|
|
without the *mode* argument. Best practice is to always specify *path*
|
|
as a keyword argument.
|
|
"""
|
|
if hasattr(shutil, 'which'):
|
|
# added in Python 3.3
|
|
return shutil.which(cmd, path=path)
|
|
elif path is None and os.environ.get('PATH') is None:
|
|
# find_executable() errors if neither path nor $PATH is set
|
|
return None
|
|
else:
|
|
return find_executable(cmd, path=path)
|
|
|
|
|
|
def zip_dir(dir, out_path, filter=None, prefix=''):
|
|
"""Compress the given *dir* into a zip file at *out_path*.
|
|
|
|
If we encounter symlinks, include the actual file, not the symlink.
|
|
|
|
:type dir: str
|
|
:param dir: dir to tar up
|
|
:type out_path: str
|
|
:param out_path: where to write the tarball too
|
|
:param filter: if defined, a function that takes paths (relative to *dir*
|
|
and returns ``True`` if we should keep them
|
|
:type prefix: str
|
|
:param prefix: subdirectory inside the tarball to put everything into (e.g.
|
|
``'mrjob'``)
|
|
"""
|
|
if not os.path.isdir(dir):
|
|
raise IOError('Not a directory: %r' % (dir,))
|
|
|
|
if not filter:
|
|
filter = lambda path: True
|
|
|
|
with _create_zip_file(out_path) as zip_file:
|
|
for dirpath, dirnames, filenames in os.walk(dir, followlinks=True):
|
|
for filename in filenames:
|
|
path = os.path.join(dirpath, filename)
|
|
rel_path = os.path.relpath(path, dir)
|
|
|
|
if filter(rel_path):
|
|
# copy over real files, not symlinks
|
|
real_path = os.path.realpath(path)
|
|
path_in_zip_file = os.path.join(prefix, rel_path)
|
|
zip_file.write(real_path, arcname=path_in_zip_file)
|
|
|
|
|
|
# this is also used by spark runner
|
|
def _create_zip_file(path):
|
|
try:
|
|
return ZipFile(path, mode='w', compression=ZIP_DEFLATED)
|
|
except RuntimeError: # zlib not available
|
|
return ZipFile(path, mode='w', compression=ZIP_STORED)
|