# Source code for astropy.units.structured

# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
This module defines structured units and quantities.
"""

from __future__ import annotations  # For python < 3.10

# Standard library
import operator

import numpy as np

from .core import UNITY, Unit, UnitBase

# Public API of this module; helper functions and ``Structure`` are private.
__all__ = ["StructuredUnit"]


# Object dtype used for every field of the internal structured array in which
# ``StructuredUnit`` stores its per-field units (each field holds a unit object).
DTYPE_OBJECT = np.dtype("O")


def _names_from_dtype(dtype):
    """Recursively extract field names from a dtype."""
    names = []
    for name in dtype.names:
        subdtype = dtype.fields[name][0]
        if subdtype.names:
            names.append([name, _names_from_dtype(subdtype)])
        else:
            names.append(name)
    return tuple(names)


def _normalize_names(names):
    """Recursively normalize, inferring upper level names for unadorned tuples.

    Generally, we want the field names to be organized like dtypes, as in
    ``(['pv', ('p', 'v')], 't')``.  But we automatically infer upper
    field names if the list is absent from items like ``(('p', 'v'), 't')``,
    by concatenating the names inside the tuple.
    """
    result = []
    for name in names:
        if isinstance(name, str) and len(name) > 0:
            result.append(name)
        elif (
            isinstance(name, list)
            and len(name) == 2
            and isinstance(name[0], str)
            and len(name[0]) > 0
            and isinstance(name[1], tuple)
            and len(name[1]) > 0
        ):
            result.append([name[0], _normalize_names(name[1])])
        elif isinstance(name, tuple) and len(name) > 0:
            new_tuple = _normalize_names(name)
            name = "".join([(i[0] if isinstance(i, list) else i) for i in new_tuple])
            result.append([name, new_tuple])
        else:
            raise ValueError(
                f"invalid entry {name!r}. Should be a name, "
                "tuple of names, or 2-element list of the "
                "form [name, tuple of names]."
            )

    return tuple(result)


class StructuredUnit:
    """Container for units for a structured Quantity.

    Parameters
    ----------
    units : unit-like, tuple of unit-like, or `~astropy.units.StructuredUnit`
        Tuples can be nested.  If a `~astropy.units.StructuredUnit` is passed
        in, it will be returned unchanged unless different names are requested.
    names : tuple of str, tuple or list; `~numpy.dtype`; or `~astropy.units.StructuredUnit`, optional
        Field names for the units, possibly nested. Can be inferred from a
        structured `~numpy.dtype` or another `~astropy.units.StructuredUnit`.
        For nested tuples, by default the name of the upper entry will be the
        concatenation of the names of the lower levels.  One can pass in a
        list with the upper-level name and a tuple of lower-level names to
        avoid this.  For tuples, not all levels have to be given; for any level
        not passed in, default field names of 'f0', 'f1', etc., will be used.

    Notes
    -----
    It is recommended to initialize the class indirectly, using
    `~astropy.units.Unit`.  E.g., ``u.Unit('AU,AU/day')``.

    When combined with a structured array to produce a structured
    `~astropy.units.Quantity`, array field names will take precedence.
    Generally, passing in ``names`` is needed only if the unit is used
    unattached to a `~astropy.units.Quantity` and one needs to access its
    fields.

    Examples
    --------
    Various ways to initialize a `~astropy.units.StructuredUnit`::

        >>> import astropy.units as u
        >>> su = u.Unit('(AU,AU/day),yr')
        >>> su
        Unit("((AU, AU / d), yr)")
        >>> su.field_names
        (['f0', ('f0', 'f1')], 'f1')
        >>> su['f1']
        Unit("yr")
        >>> su2 = u.StructuredUnit(((u.AU, u.AU/u.day), u.yr), names=(('p', 'v'), 't'))
        >>> su2 == su
        True
        >>> su2.field_names
        (['pv', ('p', 'v')], 't')
        >>> su3 = u.StructuredUnit((su2['pv'], u.day), names=(['p_v', ('p', 'v')], 't'))
        >>> su3.field_names
        (['p_v', ('p', 'v')], 't')
        >>> su3.keys()
        ('p_v', 't')
        >>> su3.values()
        (Unit("(AU, AU / d)"), Unit("d"))

    Structured units share most methods with regular units::

        >>> su.physical_type
        ((PhysicalType('length'), PhysicalType({'speed', 'velocity'})), PhysicalType('time'))
        >>> su.si
        Unit("((1.49598e+11 m, 1.73146e+06 m / s), 3.15576e+07 s)")

    """

    def __new__(cls, units, names=None):
        dtype = None
        if names is not None:
            if isinstance(names, StructuredUnit):
                dtype = names._units.dtype
                names = names.field_names
            elif isinstance(names, np.dtype):
                if not names.fields:
                    raise ValueError("dtype should be structured, with fields.")
                dtype = np.dtype([(name, DTYPE_OBJECT) for name in names.names])
                names = _names_from_dtype(names)
            else:
                if not isinstance(names, tuple):
                    names = (names,)
                names = _normalize_names(names)

        if not isinstance(units, tuple):
            units = Unit(units)
            if isinstance(units, StructuredUnit):
                # Avoid constructing a new StructuredUnit if no field names
                # are given, or if all field names are the same already anyway.
                if names is None or units.field_names == names:
                    return units

                # Otherwise, turn (the upper level) into a tuple, for renaming.
                units = units.values()
            else:
                # Single regular unit: make a tuple for iteration below.
                units = (units,)

        if names is None:
            names = tuple(f"f{i}" for i in range(len(units)))

        elif len(units) != len(names):
            raise ValueError("lengths of units and field names must match.")

        converted = []
        for unit, name in zip(units, names):
            if isinstance(name, list):
                # For list, the first item is the name of our level,
                # and the second another tuple of names, i.e., we recurse.
                unit = cls(unit, name[1])
            else:
                # We are at the lowest level.  Check unit.
                unit = Unit(unit)
                if dtype is not None and isinstance(unit, StructuredUnit):
                    raise ValueError(
                        "units do not match in depth with field "
                        "names from dtype or structured unit."
                    )

            converted.append(unit)

        self = super().__new__(cls)
        if dtype is None:
            dtype = np.dtype(
                [
                    ((name[0] if isinstance(name, list) else name), DTYPE_OBJECT)
                    for name in names
                ]
            )
        # Decay array to void so we can access by field name and number.
        self._units = np.array(tuple(converted), dtype)[()]
        return self

    def __getnewargs__(self):
        """When de-serializing, e.g. pickle, start with a blank structure."""
        return (), None

    @property
    def field_names(self):
        """Possibly nested tuple of the field names of the parts."""
        return tuple(
            ([name, unit.field_names] if isinstance(unit, StructuredUnit) else name)
            for name, unit in self.items()
        )

    # Allow StructuredUnit to be treated as an (ordered) mapping.
    def __len__(self):
        """Number of top-level fields."""
        return len(self._units.dtype.names)

    def __getitem__(self, item):
        """Unit of a top-level field, selected by field name or position."""
        # Since we are based on np.void, indexing by field number works too.
        return self._units[item]

    def values(self):
        """Tuple of the top-level units, in field order."""
        return self._units.item()

    def keys(self):
        """Tuple of the top-level field names."""
        return self._units.dtype.names

    def items(self):
        """Tuple of ``(field name, unit)`` pairs for the top-level fields."""
        return tuple(zip(self._units.dtype.names, self._units.item()))

    def __iter__(self):
        """Iterate over the top-level field names, like a mapping."""
        yield from self._units.dtype.names

    # Helpers for methods below.
    def _recursively_apply(self, func, cls=None):
        """Apply func recursively.

        Parameters
        ----------
        func : callable
            Function to apply to all parts of the structured unit,
            recursing as needed.
        cls : type, optional
            If given, should be a subclass of `~numpy.void`. By default,
            will return a new `~astropy.units.StructuredUnit` instance.
        """
        applied = tuple(func(part) for part in self.values())
        # Once not NUMPY_LT_1_23: results = np.void(applied, self._units.dtype).
        results = np.array(applied, self._units.dtype)[()]
        if cls is not None:
            return results.view((cls, results.dtype))

        # Short-cut; no need to interpret field names, etc.
        result = super().__new__(self.__class__)
        result._units = results
        return result

    def _recursively_get_dtype(self, value, enter_lists=True):
        """Get structured dtype according to value, using our field names.

        This is useful since ``np.array(value)`` would treat tuples as lower
        levels of the array, rather than as elements of a structured array.
        The routine does presume that the type of the first tuple is
        representative of the rest.  Used in ``_get_converter``.

        For the special value of ``UNITY``, all fields are assumed to be 1.0,
        and hence this will return an all-float dtype.

        Raises
        ------
        ValueError
            If ``value`` is neither ``UNITY`` nor a tuple with one entry per
            field.
        """
        if enter_lists:
            # Descend into (nested) lists to reach the first element.
            while isinstance(value, list):
                value = value[0]
        if value is UNITY:
            value = (UNITY,) * len(self)
        elif not isinstance(value, tuple) or len(self) != len(value):
            raise ValueError(f"cannot interpret value {value} for unit {self}.")
        descr = []
        for (name, unit), part in zip(self.items(), value):
            if isinstance(unit, StructuredUnit):
                descr.append(
                    (name, unit._recursively_get_dtype(part, enter_lists=False))
                )
            else:
                # Got a part associated with a regular unit. Get its dtype.
                # Like for Quantity, we cast integers to float.
                part = np.array(part)
                part_dtype = part.dtype
                if part_dtype.kind in "iu":
                    part_dtype = np.dtype(float)
                descr.append((name, part_dtype, part.shape))
        return np.dtype(descr)

    @property
    def si(self):
        """The `StructuredUnit` instance in SI units."""
        return self._recursively_apply(operator.attrgetter("si"))

    @property
    def cgs(self):
        """The `StructuredUnit` instance in cgs units."""
        return self._recursively_apply(operator.attrgetter("cgs"))

    # Needed to pass through Unit initializer, so might as well use it.
    def _get_physical_type_id(self):
        return self._recursively_apply(
            operator.methodcaller("_get_physical_type_id"), cls=Structure
        )

    @property
    def physical_type(self):
        """Physical types of all the fields."""
        return self._recursively_apply(
            operator.attrgetter("physical_type"), cls=Structure
        )

    def decompose(self, bases=set()):
        """The `StructuredUnit` composed of only irreducible units.

        Parameters
        ----------
        bases : sequence of `~astropy.units.UnitBase`, optional
            The bases to decompose into.  When not provided,
            decomposes down to any irreducible units.  When provided,
            the decomposed result will only contain the given units.
            This will raise a `UnitsError` if it's not possible
            to do so.

        Returns
        -------
        `~astropy.units.StructuredUnit`
            With the unit for each field containing only irreducible units.
        """
        # ``bases`` is only passed on, never mutated, so the shared default
        # set is safe (and mirrors the signature of ``UnitBase.decompose``).
        return self._recursively_apply(operator.methodcaller("decompose", bases=bases))

    def is_equivalent(self, other, equivalencies=[]):
        """`True` if all fields are equivalent to the other's fields.

        Parameters
        ----------
        other : `~astropy.units.StructuredUnit`
            The structured unit to compare with, or what can initialize one.
        equivalencies : list of tuple, optional
            A list of equivalence pairs to try if the units are not
            directly convertible.  See :ref:`unit_equivalencies`.
            The list will be applied to all fields.

        Returns
        -------
        bool
        """
        try:
            other = StructuredUnit(other)
        except Exception:
            return False

        if len(self) != len(other):
            return False

        for self_part, other_part in zip(self.values(), other.values()):
            if not self_part.is_equivalent(other_part, equivalencies=equivalencies):
                return False

        return True

    def _get_converter(self, other, equivalencies=[]):
        """Return a function converting values in this unit to ``other``.

        Parameters
        ----------
        other : `~astropy.units.StructuredUnit` or what can initialize one
            The target unit.  If not a `StructuredUnit`, it is interpreted
            using this unit's field names.
        equivalencies : list of tuple, optional
            Passed on to the converters of the individual fields.

        Returns
        -------
        callable
            Takes a (structured) value and returns the converted value; array
            scalars are returned as `~numpy.void`.

        Raises
        ------
        UnitConversionError
            If ``other`` has a different number of fields.
        """
        if not isinstance(other, type(self)):
            other = self.__class__(other, names=self)

        if len(self) != len(other):
            # Without this check, zip() below would silently drop the surplus
            # fields, which would then be left uninitialized in the
            # ``np.empty_like`` result of the converter.
            from .core import UnitConversionError

            raise UnitConversionError(
                f"cannot convert {self} to {other}: "
                "they have a different number of fields."
            )

        converters = [
            self_part._get_converter(other_part, equivalencies=equivalencies)
            for (self_part, other_part) in zip(self.values(), other.values())
        ]

        def converter(value):
            if not hasattr(value, "dtype"):
                value = np.array(value, self._recursively_get_dtype(value))
            result = np.empty_like(value)
            for name, converter_ in zip(result.dtype.names, converters):
                result[name] = converter_(value[name])
            # Index with empty tuple to decay array scalars to numpy void.
            return result if result.shape else result[()]

        return converter

    def to(self, other, value=np._NoValue, equivalencies=[]):
        """Return values converted to the specified unit.

        Parameters
        ----------
        other : `~astropy.units.StructuredUnit`
            The unit to convert to.  If necessary, will be converted to
            a `~astropy.units.StructuredUnit` using the dtype of ``value``.
        value : array-like, optional
            Value(s) in the current unit to be converted to the
            specified unit.  If a sequence, the first element must have
            entries of the correct type to represent all elements (i.e.,
            not have, e.g., a ``float`` where other elements have ``complex``).
            If not given, assumed to have 1. in all fields.
        equivalencies : list of tuple, optional
            A list of equivalence pairs to try if the units are not
            directly convertible.  See :ref:`unit_equivalencies`.
            This list is in addition to possible global defaults set by, e.g.,
            `set_enabled_equivalencies`.
            Use `None` to turn off all equivalencies.

        Returns
        -------
        values : scalar or array
            Converted value(s).

        Raises
        ------
        UnitsError
            If units are inconsistent
        """
        if value is np._NoValue:
            # We do not have UNITY as a default, since then the docstring
            # would list 1.0 as default, yet one could not pass that in.
            value = UNITY
        return self._get_converter(other, equivalencies=equivalencies)(value)

    def to_string(self, format="generic"):
        """Output the unit in the given format as a string.

        Units are separated by commas.

        Parameters
        ----------
        format : `astropy.units.format.Base` instance or str
            The name of a format or a formatter object.  If not
            provided, defaults to the generic format.

        Notes
        -----
        Structured units can be written to all formats, but can be
        re-read only with 'generic'.

        """
        parts = [part.to_string(format) for part in self.values()]
        out_fmt = "({})" if len(self) > 1 else "({},)"
        # ``format`` may be a formatter object rather than a string; in that
        # case use its ``name`` (if any) to detect the LaTeX formats.
        if isinstance(format, str):
            format_name = format
        else:
            format_name = str(getattr(format, "name", ""))
        if format_name.startswith("latex"):
            # Strip $ from parts and add them on the outside.
            parts = [part[1:-1] for part in parts]
            out_fmt = "$" + out_fmt + "$"
        return out_fmt.format(", ".join(parts))

    def _repr_latex_(self):
        return self.to_string("latex")

    # Opt out of numpy ufuncs, so that numpy defers binary operators with
    # arrays to our (reflected) methods instead of trying to handle them.
    __array_ufunc__ = None

    def __mul__(self, other):
        if isinstance(other, str):
            try:
                other = Unit(other, parse_strict="silent")
            except Exception:
                return NotImplemented
        if isinstance(other, UnitBase):
            new_units = tuple(part * other for part in self.values())
            return self.__class__(new_units, names=self)
        if isinstance(other, StructuredUnit):
            return NotImplemented

        # Anything not like a unit, try initialising as a structured quantity.
        try:
            from .quantity import Quantity

            return Quantity(other, unit=self)
        except Exception:
            return NotImplemented

    def __rmul__(self, other):
        return self.__mul__(other)

    def __truediv__(self, other):
        if isinstance(other, str):
            try:
                other = Unit(other, parse_strict="silent")
            except Exception:
                return NotImplemented

        if isinstance(other, UnitBase):
            new_units = tuple(part / other for part in self.values())
            return self.__class__(new_units, names=self)
        return NotImplemented

    def __rlshift__(self, m):
        try:
            from .quantity import Quantity

            return Quantity(m, self, copy=False, subok=True)
        except Exception:
            return NotImplemented

    def __str__(self):
        return self.to_string()

    def __repr__(self):
        return f'Unit("{self.to_string()}")'

    def __eq__(self, other):
        try:
            other = StructuredUnit(other)
        except Exception:
            return NotImplemented

        return self.values() == other.values()

    def __ne__(self, other):
        if not isinstance(other, type(self)):
            try:
                other = StructuredUnit(other)
            except Exception:
                return NotImplemented

        return self.values() != other.values()
class Structure(np.void):
    """Single element structure for physical type IDs, etc.

    Behaves like a `~numpy.void` and thus mostly like a tuple which can also
    be indexed with field names, but overrides ``__eq__`` and ``__ne__`` to
    compare only the contents, not the field names.  Furthermore, this way
    no `FutureWarning` about comparisons is given.

    """

    # Note that it is important for physical type IDs to not be stored in a
    # tuple, since then the physical types would be treated as alternatives in
    # :meth:`~astropy.units.UnitBase.is_equivalent`.  (Of course, in that
    # case, they could also not be indexed by name.)

    @staticmethod
    def _plain(value):
        """Unwrap a numpy void into its contents; pass anything else through."""
        return value.item() if isinstance(value, np.void) else value

    def __eq__(self, other):
        return self.item() == self._plain(other)

    def __ne__(self, other):
        return self.item() != self._plain(other)


def _structured_unit_like_dtype(
    unit: UnitBase | StructuredUnit, dtype: np.dtype
) -> StructuredUnit:
    """Make a `StructuredUnit` of one unit, with the structure of a `numpy.dtype`.

    Parameters
    ----------
    unit : UnitBase
        The unit that will be filled into the structure.
    dtype : `numpy.dtype`
        The structure for the StructuredUnit.

    Returns
    -------
    StructuredUnit
    """
    if isinstance(unit, StructuredUnit):
        # If unit is structured, it should match the dtype. This function is
        # only used in Quantity, which performs this check, so it's fine to
        # return as is.
        return unit

    def fill(field_dtype):
        # Structured sub-dtypes get a nested StructuredUnit; plain fields get
        # ``unit`` itself.
        if field_dtype.names is None:
            return unit
        return _structured_unit_like_dtype(unit, field_dtype)

    filled = tuple(fill(dtype.fields[name][0]) for name in dtype.names)
    return StructuredUnit(filled, names=dtype.names)