Source code for entente.validation

def validate_shape(a, *shape, **kwargs):
    """
    Check that the given argument has the expected shape. Shape dimensions can
    be ints or -1 for a wildcard. The wildcard dimensions are returned, which
    allows them to be used for subsequent validation or elsewhere in the
    function.

    Args:
        a (np.arraylike): An array-like input. It must expose a `shape`
            attribute whose `len()` can be taken; objects without one
            (e.g. plain lists) are rejected.
        shape (list): Shape to validate. To require 3 by 1, pass `3`. To
            require n by 3, pass `-1, 3`. Passing no dimensions requires a
            0-d array (`a.shape == ()`).
        name (str): Variable name to embed in the error message.

    Returns:
        object: The wildcard dimension (if one) or a tuple of wildcard
        dimensions (if more than one); `None` when `shape` has no wildcards.

    Raises:
        ValueError: If any requested dimension is neither an int nor -1, if
            `a` is `None` or has no usable `shape`, or if `a.shape` differs
            from `shape` in rank or in any non-wildcard dimension.
    """

    def is_wildcard(dim):
        return dim == -1

    # A single invalid dimension is enough to reject the request, hence `any`.
    if any(not isinstance(dim, int) and not is_wildcard(dim) for dim in shape):
        raise ValueError("Expected shape dimensions to be int")

    if "name" in kwargs:
        preamble = "{} must be an array".format(kwargs["name"])
    else:
        preamble = "Expected an array"

    if a is None:
        raise ValueError("{} with shape {}; got None".format(preamble, shape))

    # Probe for an array-like `shape`; report the offending type otherwise.
    try:
        len(a.shape)
    except (AttributeError, TypeError):
        raise ValueError(
            "{} with shape {}; got {}".format(preamble, shape, a.__class__)
        )

    # Rank must match exactly; then every non-wildcard dimension must match.
    # The rank check runs first so `zip` never silently truncates.
    if len(a.shape) != len(shape) or any(
        actual != expected
        for actual, expected in zip(a.shape, shape)
        if not is_wildcard(expected)
    ):
        raise ValueError("{} with shape {}; got {}".format(preamble, shape, a.shape))

    wildcard_dims = [
        actual for actual, expected in zip(a.shape, shape) if is_wildcard(expected)
    ]
    if not wildcard_dims:
        return None
    if len(wildcard_dims) == 1:
        return wildcard_dims[0]
    return tuple(wildcard_dims)
def validate_shape_from_ns(namespace, name, *shape):
    """
    Convenience function for invoking `validate_shape()` with a `locals()` dict.

    Looks up `namespace[name]` and validates it, using `name` as the variable
    name in any error message.

    Args:
        namespace (dict): A subscriptable object, typically `locals()`.
        name (str): Key to pull from `namespace`; also used as the label in
            the error message.
        shape (list): Shape to validate. To require 3 by 1, pass `3`. To
            require n by 3, pass `-1, 3`.

    Returns:
        object: The wildcard dimension (if one) or a tuple of wildcard
        dimensions (if more than one), exactly as returned by
        `validate_shape()`.

    Raises:
        KeyError: Whatever `namespace[name]` raises for a missing key
            (`KeyError` for a plain dict) propagates unchanged.
        ValueError: Propagated from `validate_shape()` when the value does
            not match `shape`.

    Example:
        validate_shape_from_ns(locals(), 'points', -1, 3)
    """
    return validate_shape(namespace[name], *shape, name=name)