# Copyright (c) 2019 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""General utilities that should arguably be included in Python.
"""
import shlex
def in_or_none(x, L):
    """Check if item `x` is in list `L`, or if `L` is None.

    Parameters
    ----------
    x : object
        Item to look for.
    L : list or None
        Container to search. ``None`` acts as a wildcard that accepts every item.

    Returns
    -------
    found : bool
        True if `L` is None or `x` is a member of `L`.
    """
    if L is None:
        # No restriction given: every item is accepted.
        return True
    return x in L
def all_unique(L):
    """Check whether a list contains no repeated elements.

    Uniqueness is decided by building a `set`, so the elements of `L` must be hashable.

    Parameters
    ----------
    L : list
        List we would like to check for uniqueness.

    Returns
    -------
    uniq : bool
        True if all elements in `L` are unique.
    """
    n_items = len(L)
    n_distinct = len(set(L))
    return n_items == n_distinct
def strict_sorted(L):
    """Return a strictly sorted version of `L`. Therefore, this raises an error if `L` contains duplicates.

    Parameters
    ----------
    L : list
        List we would like to sort. Elements must be hashable (for the uniqueness check) and mutually orderable.

    Returns
    -------
    S : list
        Strictly sorted version of `L`.

    Raises
    ------
    AssertionError
        If `L` contains duplicate elements.
    """
    # Raise explicitly instead of using an ``assert`` statement: asserts are stripped under ``python -O``, which
    # would silently break the "strict" guarantee. AssertionError is kept so existing callers are unaffected.
    if not all_unique(L):
        raise AssertionError("Cannot strict sort because list contains duplicates.")
    S = sorted(L)
    return S
def range_str(stop):
    """Version of ``range(stop)`` that also provides zero padded strings, so all strings in the iteration have the same
    length.

    Parameters
    ----------
    stop : int
        Stop value equivalent to ``range(stop)``.

    Yields
    ------
    x : int
        Integer from ``range(stop)``.
    x_str : str
        String representation of `x` zero padded so all strings from this iterator have the same ``len(x_str)``,
        e.g., ``"07"`` when ``stop=11``.
    """
    # Width of the largest value produced. This is moot when stop <= 0, because the range is then empty.
    str_len = len(str(stop - 1))
    # The outermost iterable ``range(stop)`` is evaluated immediately, so a bad `stop` still fails at call time.
    G = ((x, str(x).zfill(str_len)) for x in range(stop))
    return G
def str_join_safe(delim, str_vec, append=False):
    """Version of `str.join` that is guaranteed to be invertible.

    Parameters
    ----------
    delim : str
        Delimiter to join the strings.
    str_vec : iterable of str
        Strings to join. A `ValueError` is raised if `delim` is present in any of these strings. Any iterable is
        accepted, including one-shot iterators such as generators.
    append : bool
        If true, assume the first element is already joined and we are appending to it. So, `str_vec[0]` can contain
        `delim`.

    Returns
    -------
    joined_str : str
        Joined version of `str_vec`, which is always recoverable with ``joined_str.split(delim)``.

    Raises
    ------
    ValueError
        If any checked element of `str_vec` contains `delim`.

    Examples
    --------
    Append is required because,

    .. code-block:: python

        ss = str_join_safe('_', ('foo', 'bar'))
        str_join_safe('_', (ss, 'baz', 'qux'))

    would fail because we are appending ``'baz'`` and ``'qux'`` to the already joined string ``ss = 'foo_bar'``.
    In this case, we use

    .. code-block:: python

        ss = str_join_safe('_', ('foo', 'bar'))
        str_join_safe('_', (ss, 'baz', 'qux'), append=True)
    """
    # Materialize first. Iterating a one-shot iterable for the delimiter check would otherwise exhaust it, so the
    # join would silently return "". This also makes the slice below work for non-sequence inputs.
    str_vec = list(str_vec)
    # With append=True the first element is an already-joined prefix and may legitimately contain `delim`.
    chk_vec = str_vec[1:] if append else str_vec
    for ss in chk_vec:
        if delim in ss:
            raise ValueError("%s cannot contain delimeter %s" % (ss, delim))
    joined_str = delim.join(str_vec)
    return joined_str
def shell_join(argv, delim=" "):
    """Join strings together in a way that is an inverse of `shlex` shell parsing into `argv`.

    Each argument is escaped with `shlex.quote`, so ``shlex.split(cmd)`` recovers exactly the original arguments.

    Parameters
    ----------
    argv : iterable of str
        Arguments to collect into a command line string. They will be escaped accordingly. Any iterable is accepted,
        including one-shot iterators such as generators.
    delim : str
        Whitespace delimiter to join the strings.

    Returns
    -------
    cmd : str
        Properly escaped and joined command line string.
    """
    # Materialize once. A one-shot iterable would otherwise be exhausted by the quoting step, leaving nothing to
    # compare against in the round-trip check below.
    argv = list(argv)
    cmd = delim.join(shlex.quote(arg) for arg in argv)
    # Internal sanity check: shell-splitting the result must return exactly the input arguments. This fails if, for
    # example, `delim` is not whitespace.
    assert shlex.split(cmd) == argv
    return cmd
def chomp(str_val, ext="\n"):
    """Chomp a suffix off a string.

    Parameters
    ----------
    str_val : str
        String we want to chomp off a suffix, e.g., ``"foo.log"``, and we want to chomp the file extension.
    ext : str
        The suffix we want to chomp. Must be non-empty. An error is raised if `str_val` doesn't end in `ext`.

    Returns
    -------
    chomped : str
        Version of `str_val` with `ext` removed from the end.

    Raises
    ------
    AssertionError
        If `ext` is empty or `str_val` does not end with `ext`.
    """
    n = len(ext)
    # Raise explicitly rather than via ``assert`` statements. Under ``python -O`` those are stripped, and a string
    # lacking the suffix would be silently truncated by ``n`` characters. AssertionError is kept for compatibility.
    if n <= 0:
        # ``str_val[:-0]`` would be "", so an empty suffix must be rejected.
        raise AssertionError("ext must be a non-empty suffix")
    if not str_val.endswith(ext):
        raise AssertionError("%s must end with %s" % (repr(str_val), repr(ext)))
    chomped = str_val[:-n]
    return chomped
def preimage_func(f, x):
    """Pre-image a function at a set of input points.

    `f` is called exactly once per input point, in iteration order.

    Parameters
    ----------
    f : typing.Callable
        The function we would like to pre-image. The output type must be hashable.
    x : typing.Iterable
        Input points at which we evaluate `f`. Each point must be of a type acceptable by `f`.

    Returns
    -------
    D : dict(object, list(object))
        Maps each output of `f` to the list of `x` values that produce it. Keys appear in order of first occurrence,
        and each list preserves input order.
    """
    preimage = {}
    for point in x:
        image = f(point)
        if image not in preimage:
            preimage[image] = []
        preimage[image].append(point)
    return preimage