Source code for paramspace.yaml_constructors

"""Defines the yaml constructors for the generation of
:py:class:`~paramspace.paramspace.ParamSpace` and
:py:class:`~paramspace.paramdim.ParamDim` during loading of YAML files.

Note that they are not registered in this module but in the
:py:mod:`paramspace.yaml` module.
"""
import logging
import warnings
from collections import OrderedDict
from typing import Any, Iterable, Union

import ruamel.yaml

from .paramdim import CoupledParamDim, ParamDim
from .paramspace import ParamSpace
from .tools import create_indices

# Get logger: a module-level logger named after this module, used for the
# debug messages emitted by the constructor functions in this file
log = logging.getLogger(__name__)


# Top-level functions for the yaml-module to import ---------------------------


def pspace(loader, node) -> ParamSpace:
    """YAML constructor that builds a ParamSpace object from a mapping node.

    The constructed mapping is recursively sorted by key before the
    ParamSpace is instantiated.

    Suggested tag: ``!pspace``
    """
    return _pspace_constructor(loader=loader, node=node, sort_if_mapping=True)
def pspace_unsorted(loader, node) -> ParamSpace:
    """YAML constructor that builds a ParamSpace object from a mapping node
    while keeping the mapping's original key order.

    In contrast to ``pspace``, the input is NOT sorted before the ParamSpace
    is instantiated.

    Suggested tag: ``!pspace-unsorted``
    """
    keep_order = False
    return _pspace_constructor(
        loader=loader, node=node, sort_if_mapping=keep_order
    )
def pdim(loader, node) -> ParamDim:
    """YAML constructor that builds a ParamDim object from a mapping node.

    Suggested tag: ``!pdim``
    """
    return _pdim_constructor(loader=loader, node=node)
def pdim_default(loader, node) -> Any:
    """YAML constructor that builds a ParamDim object from a mapping node,
    but returns only the ParamDim's default value.

    Suggested tag: ``!pdim-default``

    Returns:
        The ``default`` attribute of the constructed ParamDim. Its type
        depends on the YAML content, hence the ``Any`` annotation (the
        object returned is not a ParamDim).
    """
    # Named ``param_dim`` to avoid shadowing the module-level ``pdim`` function
    param_dim = _pdim_constructor(loader, node)
    log.debug("Returning default value of constructed ParamDim.")
    return param_dim.default
def coupled_pdim(loader, node) -> CoupledParamDim:
    """YAML constructor that builds a CoupledParamDim object from a mapping
    node.

    Suggested tag: ``!coupled-pdim``
    """
    return _coupled_pdim_constructor(loader=loader, node=node)
def coupled_pdim_default(loader, node) -> Any:
    """YAML constructor that builds a CoupledParamDim object from a mapping
    node, but returns only the CoupledParamDim's default value.

    Suggested tag: ``!coupled-pdim-default``

    Returns:
        The ``default`` attribute of the constructed CoupledParamDim. Its
        type depends on the YAML content, hence the ``Any`` annotation (the
        object returned is not a CoupledParamDim).
    """
    cpdim = _coupled_pdim_constructor(loader, node)
    log.debug("Returning default value of constructed CoupledParamDim.")
    return cpdim.default
# The actual constructor functions --------------------------------------------
def _pspace_constructor(
    loader, node, sort_if_mapping: bool = True
) -> ParamSpace:
    """Constructor for instantiating a ParamSpace from a mapping node.

    Args:
        loader: The YAML loader invoking this constructor; its
            ``construct_mapping`` method is used to build the content.
        node: The tagged YAML node. Only ``ruamel.yaml.nodes.MappingNode``
            instances are supported.
        sort_if_mapping (bool, optional): Whether to recursively sort the
            constructed mapping by key (via ``recursively_sort_dict``) before
            instantiating the ParamSpace, to have consistent loading.

    Returns:
        ParamSpace: The ParamSpace built from the node's content.

    Raises:
        TypeError: If ``node`` is not a mapping node.
    """
    log.debug("Encountered tag associated with ParamSpace.")

    # Only mapping nodes are supported; the previous error message wrongly
    # advertised sequence support as well.
    if not isinstance(node, ruamel.yaml.nodes.MappingNode):
        raise TypeError(
            f"ParamSpace node can only be constructed from a mapping, "
            f"got node of type {type(node)} with value:\n{node}."
        )

    log.debug("Constructing mapping from node ...")
    d = loader.construct_mapping(node, deep=True)

    # Recursively order the content to have consistent loading
    if sort_if_mapping:
        log.debug("Recursively sorting the mapping ...")
        d = recursively_sort_dict(OrderedDict(d))

    log.debug("Instantiating ParamSpace ...")
    return ParamSpace(d)
def _pdim_constructor(loader, node) -> ParamDim:
    """Constructor for creating a ParamDim object from a mapping node.

    The entries of the constructed mapping are passed as keyword arguments
    to ``ParamDim``.

    For it to be incorporated into a ParamSpace, one parent (or higher) of
    this node needs to be tagged such that the pspace_constructor is invoked.

    Raises:
        TypeError: If ``node`` is not a mapping node.
    """
    log.debug("Encountered tag associated with ParamDim.")

    if not isinstance(node, ruamel.yaml.nodes.MappingNode):
        raise TypeError(
            f"ParamDim can only be constructed from a mapping node, got node "
            f"of type {type(node)} with value:\n{node}"
        )

    log.debug("Constructing mapping ...")
    mapping = loader.construct_mapping(node, deep=True)
    return ParamDim(**mapping)
def _coupled_pdim_constructor(loader, node) -> CoupledParamDim:
    """Constructor for creating a CoupledParamDim object from a mapping node.

    The entries of the constructed mapping are passed as keyword arguments
    to ``CoupledParamDim``.

    For it to be incorporated into a ParamSpace, one parent (or higher) of
    this node needs to be tagged such that the pspace_constructor is invoked.

    Raises:
        TypeError: If ``node`` is not a mapping node.
    """
    log.debug("Encountered tag associated with CoupledParamDim.")

    if not isinstance(node, ruamel.yaml.nodes.MappingNode):
        raise TypeError(
            f"CoupledParamDim can only be constructed from a mapping node, "
            f"got node of type {type(node)} with value:\n{node}"
        )

    log.debug("Constructing mapping ...")
    mapping = loader.construct_mapping(node, deep=True)
    return CoupledParamDim(**mapping)
# Some other constructors ----------------------------------------------------- # ...for constructing slice objects
def _slice_constructor(loader, node):
    """Build a ``slice`` object from a ``!slice``-tagged node.

    A sequence node is constructed (deeply) and its items are used as the
    positional arguments of ``slice``. Any other node is read as a single
    integer, which then becomes the ``stop`` argument.
    """
    log.debug("Encountered !slice tag.")

    is_sequence = isinstance(node, ruamel.yaml.nodes.SequenceNode)
    slice_args = (
        loader.construct_sequence(node, deep=True)
        if is_sequence
        else [loader.construct_yaml_int(node)]
    )
    log.debug(" args: %s", slice_args)

    result = slice(*slice_args)
    log.debug(" slice object created: %s", result)
    return result
def _range_constructor(loader, node):
    """Build a ``range`` object from a ``!range``-tagged node.

    A sequence node is constructed (deeply) and its items are used as the
    positional arguments of ``range``. Any other node is read as a single
    integer, which then becomes the ``stop`` argument.
    """
    log.debug("Encountered !range tag.")

    if not isinstance(node, ruamel.yaml.nodes.SequenceNode):
        range_args = [loader.construct_yaml_int(node)]
    else:
        range_args = loader.construct_sequence(node, deep=True)
    log.debug(" args: %s", range_args)

    result = range(*range_args)
    log.debug(" range object created: %s", result)
    return result
def _list_constructor(loader, node):
    """Constructor for lists, where node can be a mapping or sequence.

    A mapping node's entries are passed as keyword arguments to
    ``create_indices``. A sequence node is passed as its ``from_range``
    argument.

    Raises:
        TypeError: If ``node`` is neither a mapping nor a sequence node.
    """
    log.debug("Encountered !listgen tag.")

    if isinstance(node, ruamel.yaml.nodes.MappingNode):
        kwargs = loader.construct_mapping(node, deep=True)

    elif isinstance(node, ruamel.yaml.nodes.SequenceNode):
        # deep=True (as in the other constructors) ensures nested nodes are
        # fully constructed before create_indices consumes them.
        kwargs = dict(from_range=loader.construct_sequence(node, deep=True))

    else:
        raise TypeError(
            f"Expected mapping or sequence node for !listgen, but "
            f"got {type(node)}!"
        )

    log.debug(" kwargs: %s", kwargs)
    return create_indices(**kwargs)
# Helpers ---------------------------------------------------------------------
def recursively_sort_dict(d: dict) -> OrderedDict:
    """Return a key-sorted OrderedDict copy of ``d``, sorting nested dicts too.

    Every value that is a ``dict`` (including ``OrderedDict``) is itself
    replaced by a recursively sorted OrderedDict. All other values are kept
    as they are.

    From: http://stackoverflow.com/a/22721724/1827608

    Args:
        d (dict): The dictionary to be sorted

    Returns:
        OrderedDict: the recursively sorted dict
    """
    return OrderedDict(
        (key, recursively_sort_dict(val) if isinstance(val, dict) else val)
        for key, val in sorted(d.items())
    )