Source code for bayesmark.serialize

# Copyright (c) 2019 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A serialization abstraction layer (SAL) to save and load experimental results. All IO of experimental results should
go through this module. This makes changing the backend (between different databases) transparent to the benchmark code.
"""
import json
import os
import uuid
from abc import ABC, abstractmethod
from datetime import datetime
from tempfile import mkdtemp

import xarray as xr
from pathvalidate.argparse import validate_filename, validate_filepath

from bayesmark.path_util import join_safe_r, join_safe_w
from bayesmark.util import chomp, str_join_safe

NEWLINE = "\n"  # Just to be explicit, in case this ever gets run on Windows
PREFIX_FMT = "bo_%Y%m%d_%H%M%S_"  # The format we use for generating a new database name if none is specified

_XR_EXT = ".json"  # Extension we use for dumping xr.Dataset variables
_LOG_EXT = ".log"  # Extension to recommend for logging files
_DERIVED_DIR = "derived"  # The folder for derived variables (datasets)
_LOGGING_DIR = "log"  # The folder to recommend for logging
# Template for the manual setup instructions. It takes two %-arguments: the experiment folder that must exist, and
# the space separated list of sub-folders to hand to `mkdir`.
_SETUP_STR = """
User must ensure
%s
exists, and setup folder using
mkdir %s
User must ensure equal reps of each optimizer for unbiased results."""


class Serializer(ABC):
    """Interface that every backend of the serialization abstraction layer has to implement.

    Every operation is declared as an abstract static method, so a concrete backend must override all of them.
    """

    @staticmethod
    @abstractmethod
    def init_db(db_root, keys, db=None, exist_ok=True):
        """Set up a "database" at the given location so that data can be stored in it.

        Parameters
        ----------
        db_root : str
            Absolute path to the database.
        keys : list(str)
            Names of the non-derived variables (the keys) that the database will hold.
        db : str
            Name of the database. When ``None``, a name that does not conflict with existing ones is generated.
        exist_ok : bool
            When true, an already existing database is not treated as an error.

        Returns
        -------
        db : str
            Name of the database.
        """

    @staticmethod
    @abstractmethod
    def get_keys(db_root, db):
        """Enumerate the non-derived keys found in the database.

        Parameters
        ----------
        db_root : str
            Absolute path to the database.
        db : str
            Name of the database.

        Returns
        -------
        keys : list(str)
            Names of the non-derived variables (the keys) in the database.
        """

    @staticmethod
    @abstractmethod
    def get_derived_keys(db_root, db):
        """Enumerate the derived keys that are currently found in the database.

        Parameters
        ----------
        db_root : str
            Absolute path to the database.
        db : str
            Name of the database.

        Returns
        -------
        keys : list(str)
            Names of the derived variables (the keys) in the database.
        """

    @staticmethod
    @abstractmethod
    def get_uuids(db_root, db, key):
        """Enumerate the UUIDs of all stored versions of one variable (a non-derived key).

        Parameters
        ----------
        db_root : str
            Absolute path to the database.
        db : str
            Name of the database.
        key : str
            Name of the non-derived variable in the database.

        Returns
        -------
        uuids : list(uuid.UUID)
            One UUID per stored version of this key.
        """

    @staticmethod
    @abstractmethod
    def save(data, meta, db_root, db, key, uuid_):
        """Store experimental data; the specifics depend on the type of `data` and are left to the backend.
        """

    @staticmethod
    @abstractmethod
    def load(db_root, db, key, uuid_):
        """Retrieve experimental data; the specifics depend on the type of `data` and are left to the backend.
        """

    @staticmethod
    @abstractmethod
    def save_derived(data, meta, db_root, db, key):
        """Store derived data; the specifics depend on the type of `data` and are left to the backend.
        """

    @staticmethod
    @abstractmethod
    def load_derived(db_root, db, key):
        """Retrieve derived data; the specifics depend on the type of `data` and are left to the backend.
        """


class XRSerializer(Serializer):
    """Serialization layer when saving and loading `xarray` datasets (currently) as `json`.

    The "database" is the folder ``db_root/db``. It contains one sub-folder per non-derived key (holding one file per
    UUID), a sub-folder for derived data (holding one file per derived key) and a sub-folder for log files.

    All methods are static: they can be called on the class or on an instance interchangeably.
    """

    @staticmethod
    def init_db(db_root, keys, db=None, exist_ok=True):  # pragma: io
        """Initialize a "database" for storing data at the specified location.

        Parameters
        ----------
        db_root : str
            Absolute path to the database.
        keys : list(str)
            The variable names (or keys) we will store in the database for non-derived data.
        db : str
            The name of the database. If ``None``, a new folder whose name starts with the current UTC time formatted
            by `PREFIX_FMT` is created inside `db_root`.
        exist_ok : bool
            If true, do not raise an error if this database (or one of its sub-folders) already exists.

        Returns
        -------
        db : str
            The name of the database.
        """
        XRSerializer._validate(db_root, keys, db)

        if db is None:
            # mkdtemp both picks a non-conflicting name and creates the folder
            folder_prefix = datetime.utcnow().strftime(PREFIX_FMT)
            exp_subdir = mkdtemp(prefix=folder_prefix, dir=db_root)
            db = os.path.basename(exp_subdir)
            # Sanity checks that the generated folder sits directly under db_root with the expected prefix
            assert db.startswith(folder_prefix)
            assert os.path.join(db_root, db) == exp_subdir
        else:
            exp_subdir = os.path.join(db_root, db)
            os.makedirs(exp_subdir, exist_ok=exist_ok)

        subdirs = [_DERIVED_DIR, _LOGGING_DIR] + list(keys)
        for subd in subdirs:
            os.makedirs(os.path.join(exp_subdir, subd), exist_ok=exist_ok)

        return db

    @staticmethod
    def init_db_manual(db_root, keys, db):
        """Instruction for how one would manually initialize the "database" on another system.

        Parameters
        ----------
        db_root : str
            Absolute path to the database.
        keys : list(str)
            The variable names (or keys) we will store in the database for non-derived data.
        db : str
            The name of the database.

        Returns
        -------
        manual_setup_info : str
            The setup instructions.
        """
        XRSerializer._validate(db_root, keys, db)
        assert db is not None, "Must specify db name to setup manually."

        exp_subdir = os.path.join(db_root, db)
        subdirs = [_DERIVED_DIR, _LOGGING_DIR] + list(keys)

        manual_setup_info = _SETUP_STR % (exp_subdir, str_join_safe(" ", subdirs))
        return manual_setup_info

    @staticmethod
    def get_keys(db_root, db):  # pragma: io
        """List the non-derived keys available in the database.

        Parameters
        ----------
        db_root : str
            Absolute path to the database.
        db : str
            The name of the database.

        Returns
        -------
        keys : list(str)
            The sorted entries of the database folder, without the derived and logging folders.

        Raises
        ------
        ValueError
            If the derived or the logging folder is missing from the database folder.
        """
        XRSerializer._validate(db_root, keys=(), db=db)

        keys = sorted(os.listdir(os.path.join(db_root, db)))
        # Everything except the two bookkeeping folders is a key; list.remove raises if they are absent
        keys.remove(_DERIVED_DIR)
        keys.remove(_LOGGING_DIR)
        return keys

    @staticmethod
    def get_derived_keys(db_root, db):  # pragma: io
        """List the derived keys currently available in the database.

        Parameters
        ----------
        db_root : str
            Absolute path to the database.
        db : str
            The name of the database.

        Returns
        -------
        keys : list(str)
            The keys for derived data, one per file in the derived folder, in sorted file name order.
        """
        XRSerializer._validate(db_root, keys=(), db=db)

        fnames = sorted(os.listdir(os.path.join(db_root, db, _DERIVED_DIR)))
        keys = [XRSerializer._fname_to_key(ff) for ff in fnames]
        return keys

    @staticmethod
    def get_uuids(db_root, db, key):  # pragma: io
        """List the UUIDs for the versions of a variable (non-derived key) available in the database.

        Parameters
        ----------
        db_root : str
            Absolute path to the database.
        db : str
            The name of the database.
        key : str
            The variable name in the database for non-derived data.

        Returns
        -------
        uuids : list(uuid.UUID)
            The UUIDs for the versions of this key, one per file in the key folder, in sorted file name order.
        """
        XRSerializer._validate(db_root, keys=[key], db=db)

        fnames = sorted(os.listdir(os.path.join(db_root, db, key)))
        uuids = [XRSerializer._fname_to_uuid(ff) for ff in fnames]
        return uuids

    @staticmethod
    def save(data, meta, db_root, db, key, uuid_):  # pragma: io
        """Save a dataset under a key name in the database.

        Parameters
        ----------
        data : :class:`xarray:xarray.Dataset`
            An :class:`xarray:xarray.Dataset` variable we would like to store as non-derived data from an experiment.
        meta : json-serializable
            Associated meta-data with the experiment. This can be anything json serializable.
        db_root : str
            Absolute path to the database.
        db : str
            The name of the database.
        key : str
            The variable name in the database for the data.
        uuid_ : uuid.UUID
            The UUID to represent the version of this variable we are storing.
        """
        XRSerializer._validate(db_root, keys=[key], db=db)

        fname = XRSerializer._uuid_to_fname(uuid_)
        path = (db_root, db, key, fname)

        with open(join_safe_w(*path), "w") as f:
            _dump_xr(f, ds=data, meta=meta)

    @staticmethod
    def load(db_root, db, key, uuid_):  # pragma: io
        """Load a dataset under a key name in the database. This is the inverse of :func:`.save`.

        Parameters
        ----------
        db_root : str
            Absolute path to the database.
        db : str
            The name of the database.
        key : str
            The variable name in the database for the data.
        uuid_ : uuid.UUID
            The UUID to represent the version of this variable we want to load.

        Returns
        -------
        data : :class:`xarray:xarray.Dataset`
            An :class:`xarray:xarray.Dataset` variable for the non-derived data from an experiment.
        meta : json-serializable
            Associated meta-data with the experiment. This can be anything json serializable.
        """
        XRSerializer._validate(db_root, keys=[key], db=db)

        fname = XRSerializer._uuid_to_fname(uuid_)
        path = (db_root, db, key, fname)

        with open(join_safe_r(*path), "r") as f:
            ds, meta = _load_xr(f)
        return ds, meta

    @staticmethod
    def save_derived(data, meta, db_root, db, key):  # pragma: io
        """Save a dataset under a key name in the database as derived data.

        Parameters
        ----------
        data : :class:`xarray:xarray.Dataset`
            An :class:`xarray:xarray.Dataset` variable we would like to store as derived data from experiments.
        meta : json-serializable
            Associated meta-data with the experiments. This can be anything json serializable.
        db_root : str
            Absolute path to the database.
        db : str
            The name of the database.
        key : str
            The variable name in the database for the data.
        """
        XRSerializer._validate(db_root, keys=[key], db=db)

        fname = XRSerializer._key_to_fname(key)
        path = (db_root, db, _DERIVED_DIR, fname)

        with open(join_safe_w(*path), "w") as f:
            _dump_xr(f, ds=data, meta=meta)

    @staticmethod
    def load_derived(db_root, db, key):  # pragma: io
        """Load a dataset under a key name in the database as derived data. This is the inverse of
        :func:`.save_derived`.

        Parameters
        ----------
        db_root : str
            Absolute path to the database.
        db : str
            The name of the database.
        key : str
            The variable name in the database for the data.

        Returns
        -------
        data : :class:`xarray:xarray.Dataset`
            An :class:`xarray:xarray.Dataset` variable for the derived data from experiments.
        meta : json-serializable
            Associated meta-data with the experiments. This can be anything json serializable.
        """
        XRSerializer._validate(db_root, keys=[key], db=db)

        fname = XRSerializer._key_to_fname(key)
        path = (db_root, db, _DERIVED_DIR, fname)

        with open(join_safe_r(*path), "r") as f:
            data, meta = _load_xr(f)
        return data, meta

    @staticmethod
    def logging_path(db_root, db, uuid_):  # pragma: io
        """Get an absolute path for logging from an experiment given its UUID.

        Parameters
        ----------
        db_root : str
            Absolute path to the database.
        db : str
            The name of the database.
        uuid_ : uuid.UUID
            The UUID to represent this experiment.

        Returns
        -------
        logfile : str
            Absolute path suitable for logging in this experiment.
        """
        XRSerializer._validate(db_root, keys=(), db=db)
        assert isinstance(uuid_, uuid.UUID)

        fname = uuid_.hex + _LOG_EXT
        logfile = join_safe_w(db_root, db, _LOGGING_DIR, fname)
        return logfile

    @staticmethod
    def _fname_to_uuid(fname):
        """Convert the file name of a stored dataset back into its UUID. Inverse of `_uuid_to_fname`.
        """
        uuid_ = uuid.UUID(chomp(fname, _XR_EXT))
        return uuid_

    @staticmethod
    def _uuid_to_fname(uuid_):
        """Build the file name for a stored dataset: the hex form of the UUID followed by `_XR_EXT`.
        """
        assert isinstance(uuid_, uuid.UUID)  # This can be eliminated once we use type hints
        fname = uuid_.hex + _XR_EXT
        return fname

    @staticmethod
    def _key_to_fname(key):
        """Build the file name for a derived dataset: the key followed by `_XR_EXT`.
        """
        fname = key + _XR_EXT
        return fname

    @staticmethod
    def _fname_to_key(fname):
        """Convert the file name of a derived dataset back into its key. Inverse of `_key_to_fname`.
        """
        key = chomp(fname, _XR_EXT)
        return key

    @staticmethod
    def _validate(db_root, keys=(), db=None):
        """Check that the pieces used to build paths into the database are acceptable.

        `db_root` must be a valid, absolute file path. `db` (when given) and every element of `keys` must be valid
        file names, since each of them becomes a single path component.
        """
        validate_filepath(db_root, platform="auto")
        assert os.path.isabs(db_root), "db_root must be absolute path"

        if db is not None:
            validate_filename(db, platform="universal")

        for kk in keys:
            validate_filename(kk, platform="universal")
def _dump_xr(f, ds, meta):  # pragma: io
    """Write a dataset together with its meta-data to an open text file as one JSON object spread over two lines.

    Helper routine to `XRSerializer.save` and `XRSerializer.save_derived`. The first line holds the ``"meta"`` entry,
    the second line holds the ``"data"`` entry, which is the dict form of the dataset.
    """
    # For now only Dataset is accepted, not DataArray
    assert isinstance(ds, xr.Dataset)

    # Anything the json module can encode is fine as meta
    encoded_meta = json.dumps(meta)
    # The encoder is expected to escape line breaks; confirm it so the meta really stays on a single line
    assert NEWLINE not in encoded_meta

    # The built in dumper cannot break lines at the top level only, hence the envelope is written by hand
    for piece in ('{"meta": %s,' % (encoded_meta,), NEWLINE, '"data": '):
        f.write(piece)
    json.dump(ds.to_dict(), f)
    for piece in ("}", NEWLINE):
        f.write(piece)


def _load_xr(f):  # pragma: io
    """Read back a dataset and its meta-data from an open file holding a JSON object with ``"meta"`` and ``"data"``.

    Helper routine to `XRSerializer.load` and `XRSerializer.load_derived`.
    """
    payload = json.load(f)
    meta = payload.pop("meta")
    data_dict = payload.pop("data")
    return xr.Dataset.from_dict(data_dict), meta