# Licensed under a 3-clause BSD style license - see LICENSE.rst
import contextlib
import re
import warnings
from operator import itemgetter
import numpy as np
__all__ = ["IORegistryError"]


# NOTE: a stray ``[docs]`` line (a Sphinx "viewcode" link pasted into the
# source) used to sit here; as Python it evaluated the undefined name ``docs``
# and made importing this module fail with NameError.
class IORegistryError(Exception):
    """Custom error for registry clashes."""
# -----------------------------------------------------------------------------
class _UnifiedIORegistryBase:
    """Base class for registries in Astropy's Unified IO.

    This base class provides identification functions and miscellaneous
    utilities. For an example how to build a registry subclass we suggest
    :class:`~astropy.io.registry.UnifiedInputRegistry`, which enables
    read-only registries. These higher-level subclasses will probably serve
    better as a baseclass, for instance
    :class:`~astropy.io.registry.UnifiedIORegistry` subclasses both
    :class:`~astropy.io.registry.UnifiedInputRegistry` and
    :class:`~astropy.io.registry.UnifiedOutputRegistry` to enable both
    reading from and writing to files.

    .. versionadded:: 5.0

    """

    def __init__(self):
        # registry of identifier functions, keyed by (data_format, data_class)
        self._identifiers = {}

        # what this class can do: e.g. 'read' &/or 'write'.  Each entry names
        # the instance attribute holding that registry ("attr") and the column
        # title used for it by ``get_formats`` ("column").
        self._registries = {}
        self._registries["identify"] = {
            "attr": "_identifiers",
            "column": "Auto-identify",
        }
        self._registries_order = ("identify",)  # match keys in `_registries`

        # If multiple formats are added to one class the update of the docs is quite
        # expensive. Classes for which the doc update is temporarily delayed are added
        # to this set.
        self._delayed_docs_classes = set()

    @property
    def available_registries(self):
        """Available registries.

        Returns
        -------
        ``dict_keys``
        """
        return self._registries.keys()

    def get_formats(self, data_class=None, filter_on=None):
        """
        Get the list of registered formats as a `~astropy.table.Table`.

        Parameters
        ----------
        data_class : class or None, optional
            Filter readers/writer to match data class (default = all classes).
        filter_on : str or None, optional
            Which registry to show. E.g. "identify"
            If None search for both.  Default is None.

        Returns
        -------
        format_table : :class:`~astropy.table.Table`
            Table of available I/O formats.

        Raises
        ------
        ValueError
            If ``filter_on`` is not None nor a registry name.
        """
        from astropy.table import Table

        # set up the column names
        colnames = (
            "Data class",
            "Format",
            *[self._registries[k]["column"] for k in self._registries_order],
            "Deprecated",
        )
        i_dataclass = colnames.index("Data class")
        i_format = colnames.index("Format")
        i_regstart = colnames.index(
            self._registries[self._registries_order[0]]["column"]
        )
        i_deprecated = colnames.index("Deprecated")

        # the (format, class) pairs from all registries except "identify"
        regs = set()
        for k in self._registries.keys() - {"identify"}:
            regs |= set(getattr(self, self._registries[k]["attr"]))
        format_classes = sorted(regs, key=itemgetter(0))

        rows = []
        for fmt, cls in format_classes:
            # see if can skip, else need to document in row
            if data_class is not None and not self._is_best_match(
                data_class, cls, format_classes
            ):
                continue

            # flags for each registry
            has_ = {
                k: "Yes" if (fmt, cls) in getattr(self, v["attr"]) else "No"
                for k, v in self._registries.items()
            }

            # Check if this is a short name (e.g. 'rdb') which is deprecated in
            # favor of the full 'ascii.rdb'.
            ascii_format_class = ("ascii." + fmt, cls)
            # deprecation flag
            deprecated = "Yes" if ascii_format_class in format_classes else ""

            # add to rows
            rows.append(
                (
                    cls.__name__,
                    fmt,
                    *[has_[n] for n in self._registries_order],
                    deprecated,
                )
            )

        # filter_on can be in self_registries_order or None
        if str(filter_on).lower() in self._registries_order:
            index = self._registries_order.index(str(filter_on).lower())
            rows = [row for row in rows if row[i_regstart + index] == "Yes"]
        elif filter_on is not None:
            # This message used to be a plain string whose "{0}" placeholder was
            # never filled in; report the rejected value instead.
            raise ValueError(
                f'unrecognized value for "filter_on": {filter_on}.\n'
                f"Allowed are {self._registries_order} and None."
            )

        # Sorting the list of tuples is much faster than sorting it after the
        # table is created. (#5262)
        if rows:
            # Indices represent "Data Class", "Deprecated" and "Format".
            data = list(
                zip(*sorted(rows, key=itemgetter(i_dataclass, i_deprecated, i_format)))
            )
        else:
            data = None

        # make table
        # need to filter elementwise comparison failure issue
        # https://github.com/numpy/numpy/issues/6784
        with warnings.catch_warnings():
            warnings.simplefilter(action="ignore", category=FutureWarning)

            format_table = Table(data, names=colnames)
            if not np.any(format_table["Deprecated"].data == "Yes"):
                format_table.remove_column("Deprecated")

        return format_table

    @contextlib.contextmanager
    def delay_doc_updates(self, cls):
        """Contextmanager to disable documentation updates when registering
        reader and writer. The documentation is only built once when the
        contextmanager exits.

        .. versionadded:: 1.3

        Parameters
        ----------
        cls : class
            Class for which the documentation updates should be delayed.

        Notes
        -----
        Registering multiple readers and writers can cause significant overhead
        because the documentation of the corresponding ``read`` and ``write``
        methods are built every time.

        If the ``with`` body raises, ``cls`` is still removed from the delayed
        set (so later registrations update its docs again), but the deferred
        documentation rebuild is skipped and the exception propagates.

        Examples
        --------
        see for example the source code of ``astropy.table.__init__``.
        """
        self._delayed_docs_classes.add(cls)
        try:
            yield
        finally:
            # Always undo the delay; otherwise an exception inside the block
            # would leave ``cls`` permanently excluded from doc updates.
            self._delayed_docs_classes.discard(cls)

        for method in self._registries.keys() - {"identify"}:
            self._update__doc__(cls, method)

    # =========================================================================
    # Identifier methods

    def register_identifier(self, data_format, data_class, identifier, force=False):
        """
        Associate an identifier function with a specific data type.

        Parameters
        ----------
        data_format : str
            The data format identifier. This is the string that is used to
            specify the data type when reading/writing.
        data_class : class
            The class of the object that can be read and/or written.
        identifier : function
            A function that checks the argument specified to `read` or `write` to
            determine whether the input can be interpreted as a table of type
            ``data_format``. This function should take the following arguments:

               - ``origin``: A string ``"read"`` or ``"write"`` identifying whether
                 the file is to be opened for reading or writing.
               - ``path``: The path to the file.
               - ``fileobj``: An open file object to read the file's contents, or
                 `None` if the file could not be opened.
               - ``*args``: Positional arguments for the `read` or `write`
                 function.
               - ``**kwargs``: Keyword arguments for the `read` or `write`
                 function.

            One or both of ``path`` or ``fileobj`` may be `None`.  If they are
            both `None`, the identifier will need to work from ``args[0]``.

            The function should return True if the input can be identified
            as being of format ``data_format``, and False otherwise.
        force : bool, optional
            Whether to override any existing function if already present.
            Default is ``False``.

        Raises
        ------
        IORegistryError
            If an identifier is already registered for this format and class
            and ``force`` is False.

        Examples
        --------
        To set the identifier based on extensions, for formats that take a
        filename as a first argument, you can do for example

        .. code-block:: python

            from astropy.io.registry import register_identifier, unregister_identifier
            from astropy.table import Table
            def my_identifier(*args, **kwargs):
                return isinstance(args[0], str) and args[0].endswith('.tbl')
            register_identifier('ipac', Table, my_identifier)
            unregister_identifier('ipac', Table)
        """
        if force or (data_format, data_class) not in self._identifiers:
            self._identifiers[(data_format, data_class)] = identifier
        else:
            raise IORegistryError(
                f"Identifier for format {data_format!r} and class"
                f" {data_class.__name__!r} is already defined"
            )

    def unregister_identifier(self, data_format, data_class):
        """
        Unregister an identifier function.

        Parameters
        ----------
        data_format : str
            The data format identifier.
        data_class : class
            The class of the object that can be read/written.

        Raises
        ------
        IORegistryError
            If no identifier is registered for this format and class.
        """
        key = (data_format, data_class)
        if key not in self._identifiers:
            raise IORegistryError(
                f"No identifier defined for format {data_format!r} and class"
                f" {data_class.__name__!r}"
            )
        del self._identifiers[key]

    def identify_format(self, origin, data_class_required, path, fileobj, args, kwargs):
        """Loop through identifiers to see which formats match.

        Parameters
        ----------
        origin : str
            A string ``"read"`` or ``"write"`` identifying whether the file is to be
            opened for reading or writing.
        data_class_required : object
            The specified class for the result of `read` or the class that is to be
            written.
        path : str or path-like or None
            The path to the file or None.
        fileobj : file-like or None.
            An open file object to read the file's contents, or ``None`` if the
            file could not be opened.
        args : sequence
            Positional arguments for the `read` or `write` function. Note that
            these must be provided as sequence.
        kwargs : dict-like
            Keyword arguments for the `read` or `write` function. Note that this
            parameter must be `dict`-like.

        Returns
        -------
        valid_formats : list
            List of matching formats.
        """
        valid_formats = []
        for (data_format, data_class), identifier in self._identifiers.items():
            # Only consult identifiers registered for the closest registered
            # ancestor of the requested class; the identifier itself is only
            # called when that check passes.
            if self._is_best_match(
                data_class_required, data_class, self._identifiers
            ) and identifier(origin, path, fileobj, *args, **kwargs):
                valid_formats.append(data_format)

        return valid_formats

    # =========================================================================
    # Utils

    def _get_format_table_str(self, data_class, filter_on):
        """``get_formats()``, without column "Data class", as a str."""
        format_table = self.get_formats(data_class, filter_on)
        format_table.remove_column("Data class")
        format_table_str = "\n".join(format_table.pformat(max_lines=-1))
        return format_table_str

    def _is_best_match(self, class1, class2, format_classes):
        """Determine if class2 is the "best" match for class1 in the list of classes.

        It is assumed that (class2 in classes) is True.
        class2 is the best match if:

        - ``class1`` is a subclass of ``class2`` AND
        - ``class2`` is the nearest ancestor of ``class1`` that is in classes
          (which includes the case that ``class1 is class2``)
        """
        if issubclass(class1, class2):
            classes = {cls for fmt, cls in format_classes}
            for parent in class1.__mro__:
                if parent is class2:  # class2 is closest registered ancestor
                    return True
                if parent in classes:  # class2 was superseded
                    return False
        return False

    def _get_valid_format(self, mode, cls, path, fileobj, args, kwargs):
        """
        Returns the first valid format that can be used to read/write the data in
        question.  Mode can be either 'read' or 'write'.

        Raises
        ------
        IORegistryError
            If no format matches, or if several match with equal priority.
        """
        valid_formats = self.identify_format(mode, cls, path, fileobj, args, kwargs)

        if len(valid_formats) == 0:
            format_table_str = self._get_format_table_str(cls, mode.capitalize())
            raise IORegistryError(
                "Format could not be identified based on the"
                " file name or contents, please provide a"
                " 'format' argument.\n"
                f"The available formats are:\n{format_table_str}"
            )
        elif len(valid_formats) > 1:
            return self._get_highest_priority_format(mode, cls, valid_formats)

        return valid_formats[0]

    def _get_highest_priority_format(self, mode, cls, valid_formats):
        """
        Returns the reader or writer with the highest priority. If it is a tie,
        error.

        Formats missing from the reader/writer registry are given priority
        ``-inf`` rather than raising here.

        Raises
        ------
        ValueError
            If ``mode`` is neither ``"read"`` nor ``"write"``.
        IORegistryError
            If several formats share the highest priority.
        """
        if mode == "read":
            format_dict = self._readers
        elif mode == "write":
            format_dict = self._writers
        else:
            # Previously this fell through and failed with UnboundLocalError.
            raise ValueError(f"mode must be 'read' or 'write', not {mode!r}")

        best_formats = []
        current_priority = -np.inf
        for format in valid_formats:
            try:
                _, priority = format_dict[(format, cls)]
            except KeyError:
                # We could throw an exception here, but get_reader/get_writer handle
                # this case better, instead maximally deprioritise the format.
                priority = -np.inf

            if priority == current_priority:
                best_formats.append(format)
            elif priority > current_priority:
                best_formats = [format]
                current_priority = priority

        if len(best_formats) > 1:
            # Sort the full names (not just their first character) so the
            # listed options are alphabetical.
            raise IORegistryError(
                "Format is ambiguous - options are:"
                f" {', '.join(sorted(valid_formats))}"
            )
        return best_formats[0]

    def _update__doc__(self, data_class, readwrite):
        """
        Update the docstring to include all the available readers / writers for
        the ``data_class.read``/``data_class.write`` functions (respectively).
        Don't update if the data_class does not have the relevant method.

        Parameters
        ----------
        data_class : class
            Class whose ``read``/``write`` docstring is rewritten.
        readwrite : str
            Name of the method to update (e.g. ``"read"``); its capitalized form
            is passed to ``get_formats`` as ``filter_on``.
        """
        # abort if method "readwrite" isn't on data_class
        if not hasattr(data_class, readwrite):
            return

        from .interface import UnifiedReadWrite

        FORMATS_TEXT = "The available built-in formats are:"

        # Get the existing read or write method and its docstring
        class_readwrite_func = getattr(data_class, readwrite)

        if not isinstance(class_readwrite_func.__doc__, str):
            # No docstring--could just be test code, or possibly code compiled
            # without docstrings
            return

        lines = class_readwrite_func.__doc__.splitlines()

        # Find the location of the existing formats table if it exists
        sep_indices = [ii for ii, line in enumerate(lines) if FORMATS_TEXT in line]
        if sep_indices:
            # Chop off the existing formats table, starting at its header line
            chop_index = sep_indices[0]
            lines = lines[:chop_index]

        # Find the minimum indent, skipping the first line because it might be odd.
        # ``default=0`` handles docstrings with no non-blank line after the first,
        # for which ``min()`` of an empty sequence would raise ValueError.
        matches = [re.search(r"(\S)", line) for line in lines[1:]]
        left_indent = " " * min(
            (match.start() for match in matches if match), default=0
        )

        # Get the available unified I/O formats for this class
        # Include only formats that have a reader, and drop the 'Data class' column
        format_table = self.get_formats(data_class, readwrite.capitalize())
        format_table.remove_column("Data class")

        # Get the available formats as a table, then munge the output of pformat()
        # a bit and put it into the docstring.
        new_lines = format_table.pformat(max_lines=-1, max_width=80)
        table_rst_sep = re.sub("-", "=", new_lines[1])
        new_lines[1] = table_rst_sep
        new_lines.insert(0, table_rst_sep)
        new_lines.append(table_rst_sep)

        # Check for deprecated names and include a warning at the end.
        if "Deprecated" in format_table.colnames:
            new_lines.extend(
                [
                    "",
                    "Deprecated format names like ``aastex`` will be "
                    "removed in a future version. Use the full ",
                    "name (e.g. ``ascii.aastex``) instead.",
                ]
            )

        new_lines = [FORMATS_TEXT, ""] + new_lines
        lines.extend([left_indent + line for line in new_lines])

        # Depending on Python version and whether class_readwrite_func is
        # an instancemethod or classmethod, one of the following will work.
        if isinstance(class_readwrite_func, UnifiedReadWrite):
            class_readwrite_func.__class__.__doc__ = "\n".join(lines)
        else:
            try:
                class_readwrite_func.__doc__ = "\n".join(lines)
            except AttributeError:
                class_readwrite_func.__func__.__doc__ = "\n".join(lines)