# Source code for astropy.timeseries.core
# Licensed under a 3-clause BSD style license - see LICENSE.rst
from contextlib import contextmanager
from functools import wraps
from types import FunctionType
from astropy.table import QTable
__all__ = ["BaseTimeSeries", "autocheck_required_columns"]

# Names of the table methods that add, keep, remove or rename columns.
# ``autocheck_required_columns`` wraps every name listed here so that the
# required-column check runs after each call to it.
COLUMN_RELATED_METHODS = """
    add_column
    add_columns
    keep_columns
    remove_column
    remove_columns
    rename_column
""".split()
# [docs]
def autocheck_required_columns(cls):
    """
    Class decorator that re-validates the required columns after any
    column-modifying method is called.

    Every method named in ``COLUMN_RELATED_METHODS`` is replaced on ``cls``
    by a wrapper. The wrapper calls the original method, then calls
    ``self._check_required_columns()``, and finally hands back the original
    method's return value. The aim is to catch any operation that leaves the
    table without the columns listed in ``_required_columns``.

    Parameters
    ----------
    cls : type
        The class to decorate; it is modified in place.

    Returns
    -------
    type
        The same ``cls`` object.

    Raises
    ------
    ValueError
        If one of the listed names is missing from ``cls`` or, when looked
        up on the class, is not a ``types.FunctionType``.
    """

    def _with_column_check(func):
        @wraps(func)
        def checked(self, *args, **kwargs):
            outcome = func(self, *args, **kwargs)
            # Validate only after the underlying method has finished.
            self._check_required_columns()
            return outcome

        return checked

    for method_name in COLUMN_RELATED_METHODS:
        # A missing attribute yields None, which fails the isinstance test
        # exactly like the hasattr() check would.
        candidate = getattr(cls, method_name, None)
        if not isinstance(candidate, FunctionType):
            raise ValueError(f"{method_name} is not a valid method")
        setattr(cls, method_name, _with_column_check(candidate))
    return cls
# [docs]
class BaseTimeSeries(QTable):
    """
    Base table class for time series that must start with specific columns.

    ``_required_columns`` (``None`` here, meaning "no requirement") lists the
    column names that must appear, in order, as the first columns of the
    table. ``_check_required_columns`` enforces this and raises
    ``ValueError`` on a mismatch. Checks can be suspended temporarily with
    ``_delay_required_column_checks``.
    """

    # Column names that must lead the table, in order; ``None`` disables the
    # check entirely.
    _required_columns = None
    # While False, ``_check_required_columns`` returns without checking.
    _required_columns_enabled = True
    # If _required_columns_relax is True, we don't require the columns to be
    # present but we do require them to be the correct ones IF present. Note
    # that this is a temporary state - as soon as the required columns
    # are all present, we toggle this to False
    _required_columns_relax = False

    def _check_required_columns(self):
        """
        Validate that the table starts with the required columns.

        In relaxed mode only the first ``len(self.colnames)`` required
        columns are compared, so a partially built table is accepted.
        Relaxed mode is switched off once all required columns are present.

        Raises
        ------
        ValueError
            If the table has no columns (outside relaxed mode), or if its
            leading column names differ from the required ones.
        """

        def as_scalar_or_list_str(obj):
            if not hasattr(obj, "__len__"):
                return f"'{obj}'"
            elif len(obj) == 1:
                return f"'{obj[0]}'"
            else:
                return str(obj)

        if not self._required_columns_enabled:
            return

        if self._required_columns is None:
            return

        # Normalise to a list: ``colnames`` is a list, and a list never
        # compares equal to a tuple, so a tuple declaration would otherwise
        # always be reported as a mismatch.
        all_required = list(self._required_columns)
        colnames = self.colnames

        if self._required_columns_relax:
            required_columns = all_required[: len(colnames)]
        else:
            required_columns = all_required

        plural = "s" if len(required_columns) > 1 else ""

        # ``required_columns`` must be non-empty here: an empty requirement
        # constrains nothing and would otherwise fail on ``[0]`` below.
        if (
            not self._required_columns_relax
            and len(colnames) == 0
            and required_columns
        ):
            raise ValueError(
                f"{self.__class__.__name__} object is invalid - expected"
                f" '{required_columns[0]}' as the first column{plural} but time"
                " series has no columns"
            )
        elif colnames[: len(required_columns)] != required_columns:
            raise ValueError(
                f"{self.__class__.__name__} object is invalid - expected"
                f" {as_scalar_or_list_str(required_columns)} as the first"
                f" column{plural} but found"
                f" {as_scalar_or_list_str(colnames[: len(required_columns)])}"
            )

        # Leave relaxed mode as soon as every required column is in place.
        if (
            self._required_columns_relax
            and all_required == colnames[: len(all_required)]
        ):
            self._required_columns_relax = False

    @contextmanager
    def _delay_required_column_checks(self):
        """
        Context manager that suspends required-column checks for its body.

        The previous value of ``_required_columns_enabled`` is restored on
        exit, even if the body raises. This way a failed operation cannot
        leave checks permanently disabled, and nested uses do not re-enable
        checks early. If the body completes normally, the columns are checked
        once at the end. ``_check_required_columns`` itself is a no-op if
        checks were already disabled before entering.
        """
        previously_enabled = self._required_columns_enabled
        self._required_columns_enabled = False
        try:
            yield
        finally:
            self._required_columns_enabled = previously_enabled
        # Only reached when the body did not raise, so a validation error
        # never masks the body's original exception.
        self._check_required_columns()