# GildedRose-Refactoring-Kata/.venv/lib/python3.12/site-packages/mrjob/util.py
# Copyright 2009-2016 Yelp and Contributors
# Copyright 2017-2018 Yelp
# Copyright 2019 Yelp
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for MRJob
"""
# don't add imports here that aren't part of the standard Python library,
# since MRJobs need to run in Amazon's generic EMR environment
import logging
import os
import os.path
import pipes
import random
import shlex
import shutil
import sys
import tarfile

from contextlib import contextmanager
from datetime import timedelta
from distutils.spawn import find_executable
from logging import getLogger
from zipfile import ZIP_DEFLATED
from zipfile import ZIP_STORED
from zipfile import ZipFile
from zipfile import is_zipfile

from mrjob.py2 import PY2


log = getLogger(__name__)


class NullHandler(logging.Handler):
    def emit(self, record):
        pass


def cmd_line(args):
    """build a command line that works in a shell.
    """
    args = [str(x) for x in args]
    return ' '.join(pipes.quote(x) for x in args)


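# Example usage (an illustrative sketch, not from the upstream source):
# arguments containing spaces or other shell metacharacters come back quoted,
# so the joined string is safe to hand to a shell.
#
#   >>> cmd_line(['grep', '-e', 'hello world'])
#   "grep -e 'hello world'"

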
def expand_path(path):
    """Resolve ``~`` (home dir) and environment variables in *path*.

    If *path* is ``None``, return ``None``.
    """
    if path is None:
        return None
    else:
        return os.path.expanduser(os.path.expandvars(path))


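# Example usage (illustrative; the expanded value depends on the caller's
# environment, so the result shown here is hypothetical):
#
#   >>> expand_path('$HOME/.mrjob.conf')  # doctest: +SKIP
#   '/home/alice/.mrjob.conf'

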
def file_ext(filename):
    """return the file extension, including the ``.``

    >>> file_ext('foo.tar.gz')
    '.tar.gz'
    >>> file_ext('.emacs')
    ''
    >>> file_ext('.mrjob.conf')
    '.conf'
    """
    stripped_name = filename.lstrip('.')
    dot_index = stripped_name.find('.')

    if dot_index == -1:
        return ''

    return stripped_name[dot_index:]


def log_to_null(name=None):
    """Set up a null handler for the given logger, to suppress
    "no handlers could be found" warnings."""
    logger = logging.getLogger(name)
    logger.addHandler(NullHandler())


def log_to_stream(name=None, stream=None, format=None, level=None,
                  debug=False):
    """Set up logging.

    :type name: str
    :param name: name of the logger, or ``None`` for the root logger
    :type stream: file object
    :param stream: stream to log to (default is ``sys.stderr``)
    :type format: str
    :param format: log message format (default is '%(message)s')
    :param level: log level to use
    :type debug: bool
    :param debug: quick way of setting the log level: if true, use
                  ``logging.DEBUG``, otherwise use ``logging.INFO``
    """
    if level is None:
        level = logging.DEBUG if debug else logging.INFO

    if format is None:
        format = '%(message)s'

    if stream is None:
        stream = sys.stderr

    handler = logging.StreamHandler(stream)
    handler.setLevel(level)
    handler.setFormatter(logging.Formatter(format))

    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.addHandler(handler)


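# Example usage (illustrative): route the mrjob package's own log messages to
# stderr at DEBUG level. 'mrjob' is the package's top-level logger name, since
# modules in this package call getLogger(__name__).
#
#   log_to_stream('mrjob', debug=True)

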
def random_identifier():
    """A random 16-digit hex string."""
    return '%016x' % random.randint(0, 2 ** 64 - 1)


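# Example output (illustrative; the value is random, so this exact string is
# hypothetical):
#
#   >>> random_identifier()  # doctest: +SKIP
#   '32aedc9d82e344f5'

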
# Thanks to http://lybniz2.sourceforge.net/safeeval.html for
# explaining how to do this!
def safeeval(expr, globals=None, locals=None):
    """Like eval, but with nearly everything in the environment
    blanked out, so that it's difficult to cause mischief.

    *globals* and *locals* are optional dictionaries mapping names to
    values for those names (just like in :py:func:`eval`).
    """
    # blank out builtins, but keep None, True, and False
    safe_globals = {
        'False': False,
        'None': None,
        'True': True,
        '__builtin__': None,
        '__builtins__': None,
        'set': set,
    }

    # xrange is range in Python 3
    if PY2:
        safe_globals['xrange'] = xrange
    else:
        safe_globals['range'] = range

    # PyPy needs special magic
    def open(*args, **kwargs):
        raise NameError("name 'open' is not defined")

    safe_globals['open'] = open

    # add the user-specified global variables
    if globals:
        safe_globals.update(globals)

    return eval(expr, safe_globals, locals)


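# Example usage (illustrative; output reprs are Python 3's): expressions built
# from literals and whitelisted names evaluate normally, while blocked
# builtins such as open() raise NameError.
#
#   >>> safeeval('set([1, 2]) | set([2, 3])')
#   {1, 2, 3}
#   >>> safeeval("open('/etc/passwd')")
#   Traceback (most recent call last):
#       ...
#   NameError: name 'open' is not defined

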
@contextmanager
def save_current_environment():
    """Context manager that saves os.environ and loads
    it back again after execution
    """
    original_environ = os.environ.copy()

    try:
        yield

    finally:
        os.environ.clear()
        os.environ.update(original_environ)


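# Example usage (illustrative; the variable and value are hypothetical):
# changes to os.environ made inside the block are rolled back on exit, even
# if an exception is raised.
#
#   with save_current_environment():
#       os.environ['TMPDIR'] = '/mnt/scratch'
#       ...  # run code that reads $TMPDIR
#   # os.environ now has its original contents again

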
@contextmanager
def save_cwd():
    """Context manager that saves the current working directory,
    and chdir's back to it after execution."""
    original_cwd = os.getcwd()

    try:
        yield
    finally:
        os.chdir(original_cwd)


@contextmanager
def save_sys_std():
    """Context manager that saves the current values of `sys.stdin`,
    `sys.stdout`, and `sys.stderr`, and flushes these filehandles before
    and after switching them out."""
    stdin, stdout, stderr = sys.stdin, sys.stdout, sys.stderr

    try:
        sys.stdout.flush()
        sys.stderr.flush()

        yield

        # at this point, sys.stdout/stderr may have been patched. Don't
        # raise an exception if flush() fails
        try:
            sys.stdout.flush()
        except:
            pass
        try:
            sys.stderr.flush()
        except:
            pass
    finally:
        sys.stdin, sys.stdout, sys.stderr = stdin, stdout, stderr


@contextmanager
def save_sys_path():
    """Context manager that saves sys.path and restores it after execution."""
    original_sys_path = list(sys.path)

    try:
        yield
    finally:
        sys.path = original_sys_path


def shlex_split(s):
    """Wrapper around ``shlex.split()``, but convert *s* to ``str`` on
    Python versions before 2.7.3, when unicode support was added.
    """
    if sys.version_info < (2, 7, 3):
        return shlex.split(str(s))
    else:
        return shlex.split(s)


def strip_microseconds(delta):
    """Return the given :py:class:`datetime.timedelta`, without microseconds.

    Useful for printing :py:class:`datetime.timedelta` objects.
    """
    return timedelta(delta.days, delta.seconds)


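# Example usage (illustrative; the repr shown is from Python 3.7+):
#
#   >>> strip_microseconds(timedelta(seconds=90, microseconds=250000))
#   datetime.timedelta(seconds=90)

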
def to_lines(chunks):
    """Take in data as a sequence of bytes, and yield it, one line at a time.

    Only breaks lines on ``\\n`` (not ``\\r``), and does not add
    a trailing newline.

    For efficiency, passes through anything with a ``readline()`` attribute.
    """
    # hopefully this is good enough for anything mrjob will encounter
    if hasattr(chunks, 'readline'):
        return chunks
    else:
        return _to_lines(chunks)


def _to_lines(chunks):
    """Take in data as a sequence of bytes, and yield it, one line at a time.

    Only breaks lines on ``\\n`` (not ``\\r``), and does not add
    a trailing newline.

    Exception: if we encounter an empty bytestring ``b''``, immediately yield
    what we have so far rather than joining it to the next chunk. This allows
    us to handle bytes from multiple files without joining the end of one
    file to the beginning of the next one.

    Optimizes for:

    * chunks bigger than lines (e.g. reading test files)
    * chunks that are lines (idempotency)
    """
    # list of chunks with no final newline
    leftovers = []

    for chunk in chunks:
        # special case for b'' standing for EOF
        if chunk == b'':
            if leftovers:
                yield b''.join(leftovers)
                leftovers = []
            continue

        start = 0

        while start < len(chunk):
            end = chunk.find(b'\n', start) + 1

            if end == 0:  # no newlines found
                leftovers.append(chunk[start:])
                break

            if leftovers:
                leftovers.append(chunk[start:end])
                yield b''.join(leftovers)
                leftovers = []
            else:
                yield chunk[start:end]

            start = end

    if leftovers:
        yield b''.join(leftovers)


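# Example (illustrative): to_lines() re-chunks an iterable of bytestrings into
# complete lines, and treats an empty bytestring as an end-of-file marker.
#
#   >>> list(to_lines([b'one\ntw', b'o\nthree']))
#   [b'one\n', b'two\n', b'three']
#   >>> list(to_lines([b'no newline', b'', b'next file\n']))
#   [b'no newline', b'next file\n']

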
def unique(items):
    """Yield items from *items* in order, skipping duplicates."""
    seen = set()

    for item in items:
        if item in seen:
            continue
        else:
            yield item
            seen.add(item)


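# Example usage (illustrative):
#
#   >>> list(unique([1, 2, 1, 3, 2]))
#   [1, 2, 3]

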
def unarchive(archive_path, dest):
    """Extract the contents of a tar or zip file at *archive_path* into the
    directory *dest*.

    :type archive_path: str
    :param archive_path: path to archive file
    :type dest: str
    :param dest: path to directory where archive will be extracted

    *dest* will be created if it doesn't already exist.

    tar files can be gzip compressed, bzip2 compressed, or uncompressed. Files
    within zip files can be deflated or stored.
    """
    if tarfile.is_tarfile(archive_path):
        with tarfile.open(archive_path, 'r') as archive:
            archive.extractall(dest)
    elif is_zipfile(archive_path):
        with ZipFile(archive_path, 'r') as archive:
            for name in archive.namelist():
                # the zip spec specifies that front slashes are always
                # used as directory separators
                dest_path = os.path.join(dest, *name.split('/'))

                # now, split out any dirname and filename and create
                # one and/or the other
                dirname, filename = os.path.split(dest_path)
                if dirname and not os.path.exists(dirname):
                    os.makedirs(dirname)
                if filename:
                    with open(dest_path, 'wb') as dest_file:
                        dest_file.write(archive.read(name))
    else:
        raise IOError('Unknown archive type: %s' % (archive_path,))


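# Example usage (illustrative; the paths are hypothetical):
#
#   unarchive('job-files.tar.gz', '/tmp/job-files')

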
def which(cmd, path=None):
    """Like the UNIX which command: search in *path* for the executable named
    *cmd*. *path* defaults to :envvar:`PATH`. Returns ``None`` if no
    such executable is found.

    This is basically ``shutil.which()`` (which was introduced in Python 3.3)
    without the *mode* argument. Best practice is to always specify *path*
    as a keyword argument.
    """
    if hasattr(shutil, 'which'):
        # added in Python 3.3
        return shutil.which(cmd, path=path)
    elif path is None and os.environ.get('PATH') is None:
        # find_executable() errors if neither path nor $PATH is set
        return None
    else:
        return find_executable(cmd, path=path)


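# Example usage (illustrative; whether anything is found, and where, depends
# on the caller's environment, so the path shown is hypothetical):
#
#   >>> which('hadoop')  # doctest: +SKIP
#   '/usr/local/hadoop/bin/hadoop'

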
def zip_dir(dir, out_path, filter=None, prefix=''):
    """Compress the given *dir* into a zip file at *out_path*.

    If we encounter symlinks, include the actual file, not the symlink.

    :type dir: str
    :param dir: dir to zip up
    :type out_path: str
    :param out_path: where to write the zip file to
    :param filter: if defined, a function that takes paths (relative to *dir*)
                   and returns ``True`` if we should keep them
    :type prefix: str
    :param prefix: subdirectory inside the zip file to put everything into
                   (e.g. ``'mrjob'``)
    """
    if not os.path.isdir(dir):
        raise IOError('Not a directory: %r' % (dir,))

    if not filter:
        filter = lambda path: True

    with _create_zip_file(out_path) as zip_file:
        for dirpath, dirnames, filenames in os.walk(dir, followlinks=True):
            for filename in filenames:
                path = os.path.join(dirpath, filename)
                rel_path = os.path.relpath(path, dir)

                if filter(rel_path):
                    # copy over real files, not symlinks
                    real_path = os.path.realpath(path)
                    path_in_zip_file = os.path.join(prefix, rel_path)
                    zip_file.write(real_path, arcname=path_in_zip_file)


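# Example usage (illustrative; the paths, prefix, and filter are
# hypothetical): zip up a source tree, skipping compiled bytecode, under a
# 'mrjob/' prefix inside the archive.
#
#   zip_dir('mrjob', 'mrjob.zip',
#           filter=lambda path: not path.endswith('.pyc'),
#           prefix='mrjob')

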
# this is also used by the Spark runner
def _create_zip_file(path):
    try:
        return ZipFile(path, mode='w', compression=ZIP_DEFLATED)
    except RuntimeError:  # zlib not available
        return ZipFile(path, mode='w', compression=ZIP_STORED)