# Source code for bayesmark.xr_util

# Copyright (c) 2019 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""General utilities for `xarray` that should be included in `xarray`.
"""
from collections import OrderedDict

import numpy as np
import pandas as pd
import xarray as xr

from bayesmark.util import all_unique


def is_simple_coords(coords, min_side=0, dims=None):
    """Check if all xr coordinates are "simple". That is, equal to ``np.arange(n)``.

    Parameters
    ----------
    coords : dict-like of coordinates
        The coordinates we would like to check, e.g. from ``DataArray.coords``. Each entry must expose its data via
        a ``.values`` attribute.
    min_side : int
        The minimum side requirement. We can set this ``min_side=1`` and have empty coordinates result in a return
        value of ``False``.
    dims : None or list of dimension names
        Dimensions we want to check for simplicity. If ``None``, check all dimensions.

    Returns
    -------
    simple : bool
        True when all coordinates are simple.
    """
    for name in coords:
        # Skip any coordinate the caller did not ask us to look at
        if dims is not None and name not in dims:
            continue

        values = coords[name].values
        # The dtype is only enforced on non-empty coords; an empty coord of any dtype passes this check
        if len(values) > 0 and values.dtype != np.int_:
            return False

        as_list = values.tolist()
        n_side = len(as_list)
        if n_side < min_side:
            return False
        if as_list != list(range(n_side)):
            return False
    return True
def ds_like(ref, vars_, dims, fill=np.nan):
    """Produce a blank :class:`xarray:xarray.Dataset` copying some coordinates from another
    :class:`xarray:xarray.Dataset`.

    Parameters
    ----------
    ref : :class:`xarray:xarray.Dataset`
        The reference dataset we want to copy coordinates from.
    vars_ : typing.Iterable
        List of variable names we want in the new dataset.
    dims : list
        List of dimensions we want to copy over from `ref`. These are the dimensions of the output.
    fill : scalar
        Scalar value to fill the blank dataset. The `dtype` will be determined from the `fill` value.

    Returns
    -------
    ds : :class:`xarray:xarray.Dataset`
        A new dataset with variables `vars_` and dimensions `dims` where the coordinates have been copied from `ref`.
        All values are filled with `fill`.
    """
    shape = [ref.sizes[name] for name in dims]

    # Every variable gets its own freshly allocated array, so they never share memory
    data_vars = OrderedDict()
    for var_name in vars_:
        data_vars[var_name] = (dims, np.full(shape, fill))

    coord_map = OrderedDict()
    for name in dims:
        coord_map[name] = ref.coords[name].values

    return xr.Dataset(data_vars, coords=coord_map)
def ds_like_mixed(ref, vars_, dims, fill=np.nan):
    """The same as `ds_like` but allow different dimensions for each variable.

    Parameters
    ----------
    ref : :class:`xarray:xarray.Dataset`
        The reference dataset we want to copy coordinates from.
    vars_ : typing.Iterable
        List of (variable names, dimension) pairs we want in the new dataset. The dimensions for each variable must be
        a subset of `dims`.
    dims : list
        List of all dimensions we want to copy over from `ref`.
    fill : scalar
        Scalar value to fill the blank dataset. The `dtype` will be determined from the `fill` value.

    Returns
    -------
    ds : :class:`xarray:xarray.Dataset`
        A new dataset with variables `vars_` and dimensions `dims` where the coordinates have been copied from `ref`.
        All values are filled with `fill`.
    """
    coords = OrderedDict([(dd, ref.coords[dd].values) for dd in dims])

    data = OrderedDict()
    for var_name, var_dims in vars_:
        assert set(var_dims).issubset(dims)
        size = [ref.sizes[dd] for dd in var_dims]
        # Honor the caller's `fill` (previously NaN was hard-coded and `fill` was silently ignored)
        data[var_name] = (var_dims, np.full(size, fill))
    ds = xr.Dataset(data, coords=coords)
    return ds
def only_dataarray(ds):
    """Convert a :class:`xarray:xarray.Dataset` to a :class:`xarray:xarray.DataArray`.

    If the :class:`xarray:xarray.Dataset` does not have exactly one variable, an error is raised.

    Parameters
    ----------
    ds : :class:`xarray:xarray.Dataset`
        :class:`xarray:xarray.Dataset` we would like to convert to a :class:`xarray:xarray.DataArray`. This must
        contain only one variable.

    Returns
    -------
    da : :class:`xarray:xarray.DataArray`
        The :class:`xarray:xarray.DataArray` extracted from `ds`.

    Raises
    ------
    ValueError
        If iterating `ds` does not yield exactly one name.
    """
    # Single-element unpacking doubles as the "exactly one variable" check
    [var_name] = ds
    return ds[var_name]
def coord_compat(da_seq, dims):
    """Check if a sequence of :class:`xarray:xarray.DataArray` have compatible coordinates.

    Parameters
    ----------
    da_seq : list(:class:`xarray:xarray.DataArray`)
        Sequence of :class:`xarray:xarray.DataArray` we would like to check for compatibility.
        :class:`xarray:xarray.Dataset` work too.
    dims : list
        Subset of all dimensions in the :class:`xarray:xarray.DataArray` we are concerned with for compatibility.

    Returns
    -------
    compat : bool
        True if all the :class:`xarray:xarray.DataArray` have compatible coordinates. Coordinates of different
        lengths are never considered compatible. A sequence with at most one element is always compatible.
    """
    if len(da_seq) <= 1:
        return True

    ref = da_seq[0]
    # The reference itself is included in the loop so that its dims get the same presence check as the others
    for da in da_seq:
        # There is probably a better way to do this by attempting concat in try-except, but good enough for now:
        for dd in dims:
            assert dd in da.coords, "dim %s missing in dataarray" % dd
            # `array_equal` checks shape before values. A plain `==` would broadcast, so a length-1 coordinate could
            # match a longer one, and mismatched lengths could raise instead of returning False.
            if not np.array_equal(ref.coords[dd].values, da.coords[dd].values):
                return False
    return True
def da_to_string(da):
    """Generate a human readable version of a 1D :class:`xarray:xarray.DataArray`.

    Parameters
    ----------
    da : :class:`xarray:xarray.DataArray`
        The :class:`xarray:xarray.DataArray` to display. Must only have one dimension.

    Returns
    -------
    str_val : str
        String with human readable version of `da`.
    """
    n_dims = len(da.dims)
    assert n_dims == 1

    # Delegate the formatting to pandas
    series = da.to_series()
    return series.to_string()
def da_concat(da_dict, dims):
    """Concatenate a dictionary of :class:`xarray:xarray.DataArray` similar to :func:`pandas:pandas.concat`.

    Parameters
    ----------
    da_dict : dict(tuple(str), :class:`xarray:xarray.DataArray`)
        Dictionary of :class:`xarray:xarray.DataArray` to combine. The keys are tuples of index values. The
        :class:`xarray:xarray.DataArray` must have compatible coordinates.
    dims : list(str)
        The names of the new dimensions we create for the dictionary keys. This must be of the same length as the key
        tuples in `da_dict`.

    Returns
    -------
    da : :class:`xarray:xarray.DataArray`
        Combined data array. The new dimensions will be ``input_da.dims + dims``.
    """
    assert len(da_dict) > 0
    arrays = list(da_dict.values())
    assert all(len(arr.dims) > 0 for arr in arrays), "0-dimensional DataArray not supported"
    assert all_unique(dims)

    # Every input must share the dims of the first one, and none may clash with the new dims
    ref_dims = arrays[0].dims
    assert all(arr.dims == ref_dims for arr in arrays)
    assert set(ref_dims).isdisjoint(dims)

    # Each input becomes one column, keyed by its dictionary key. Length-1 key tuples are unwrapped to the bare
    # element before being used as the column key.
    columns = OrderedDict()
    for key, arr in da_dict.items():
        col_key = key[0] if len(key) == 1 else key
        columns[col_key] = arr.to_series()

    df = pd.concat(columns, axis=1)
    n_levels = df.columns.nlevels
    assert n_levels == len(dims)
    df.columns.names = dims

    # Move every column level into the index, which is expected to leave a Series
    stacked = df.stack(level=list(range(n_levels)))
    assert isinstance(stacked, pd.Series)

    combined = stacked.to_xarray()
    assert isinstance(combined, xr.DataArray)
    return combined
def ds_concat(ds_dict, dims):
    """Concatenate a dictionary of :class:`xarray:xarray.Dataset` similar to :func:`pandas:pandas.concat`, and a
    generalization of :func:`.da_concat`.

    Parameters
    ----------
    ds_dict : dict(tuple(str), :class:`xarray:xarray.Dataset`)
        Dictionary of :class:`xarray:xarray.Dataset` to combine. The keys are tuples of index values. The
        :class:`xarray:xarray.Dataset` must have compatible coordinates, and all have the same variables (iterating
        each dataset must yield the same names in the same order).
    dims : list(str)
        The names of the new dimensions we create for the dictionary keys. This must be of the same length as the key
        tuples in `ds_dict`.

    Returns
    -------
    ds : :class:`xarray:xarray.Dataset`
        Combined dataset. For each variable `var`, the new dimensions will be ``input_ds[var].dims + dims``.

    Raises
    ------
    ValueError
        If the datasets do not all share a single tuple of variable names.
    """
    assert len(ds_dict) > 0
    assert len(dims) > 0
    n_new_dims = len(dims)
    assert all(len(key) == n_new_dims for key in ds_dict)

    # Any element will do as the reference for the coordinates, so take the first
    first_key = list(ds_dict)[0]

    # Unpacking a one-element set enforces that every dataset has the same variables
    (var_names,) = {tuple(ds) for ds in ds_dict.values()}

    # Build up the result one variable at a time
    combined = xr.Dataset(coords=ds_dict[first_key].coords)
    for var in var_names:
        per_key = OrderedDict()
        for key, ds in ds_dict.items():
            per_key[key] = ds[var]
        combined[var] = da_concat(per_key, dims)
    return combined