Source code for asdf.tests.helpers

import io
import os
import warnings
from contextlib import contextmanager
from pathlib import Path

try:
    from astropy.coordinates import ICRS
except ImportError:
    ICRS = None

try:
    from astropy.coordinates.representation import CartesianRepresentation
except ImportError:
    CartesianRepresentation = None

try:
    from astropy.coordinates.representation import CartesianDifferential
except ImportError:
    CartesianDifferential = None

import yaml

import asdf

from .. import generic_io, versioning
from ..asdf import AsdfFile, get_asdf_library_info
from ..block import Block
from ..constants import YAML_TAG_PREFIX
from ..exceptions import AsdfConversionWarning
from ..extension import default_extensions
from ..resolver import Resolver, ResolverChain
from ..tags.core import AsdfObject
from ..versioning import (
    AsdfVersion,
    asdf_standard_development_version,
    get_version_map,
    split_tag_version,
    supported_versions,
)
from .httpserver import RangeHTTPServer

try:
    from pytest_remotedata.disable_internet import INTERNET_OFF
except ImportError:
    INTERNET_OFF = False


__all__ = [
    "get_test_data_path",
    "assert_tree_match",
    "assert_roundtrip_tree",
    "yaml_to_asdf",
    "get_file_sizes",
    "display_warnings",
]


def get_test_data_path(name, module=None):
    if module is None:
        from . import data as test_data

        module = test_data

    module_root = Path(module.__file__).parent

    if name is None or name == "":
        return str(module_root)
    else:
        return str(module_root / name)

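# Illustrative usage sketch (not part of the module): resolving paths inside
# the bundled test data package. "example.asdf" is a hypothetical file name.
#
#     data_dir = get_test_data_path("")              # root of the data package
#     path = get_test_data_path("example.asdf")      # a specific file within it
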
def assert_tree_match(old_tree, new_tree, ctx=None, funcname="assert_equal", ignore_keys=None):
    """
    Assert that two ASDF trees match.

    Parameters
    ----------
    old_tree : ASDF tree
    new_tree : ASDF tree
    ctx : ASDF file context
        Used to look up the set of types in effect.
    funcname : `str` or `callable`
        The name of a method on members of old_tree and new_tree that will be
        used to compare custom objects.  The default of ``assert_equal``
        handles Numpy arrays.
    ignore_keys : list of str
        List of keys to ignore
    """
    seen = set()

    if ignore_keys is None:
        ignore_keys = ["asdf_library", "history"]
    ignore_keys = set(ignore_keys)

    if ctx is None:
        version_string = str(versioning.default_version)
        ctx = default_extensions.extension_list
    else:
        version_string = ctx.version_string

    def recurse(old, new):
        if id(old) in seen or id(new) in seen:
            return
        seen.add(id(old))
        seen.add(id(new))

        old_type = ctx.type_index.from_custom_type(type(old), version_string)
        new_type = ctx.type_index.from_custom_type(type(new), version_string)

        if (
            old_type is not None
            and new_type is not None
            and old_type is new_type
            and (callable(funcname) or hasattr(old_type, funcname))
        ):
            if callable(funcname):
                funcname(old, new)
            else:
                getattr(old_type, funcname)(old, new)
        elif isinstance(old, dict) and isinstance(new, dict):
            assert {x for x in old.keys() if x not in ignore_keys} == {x for x in new.keys() if x not in ignore_keys}
            for key in old.keys():
                if key not in ignore_keys:
                    recurse(old[key], new[key])
        elif isinstance(old, (list, tuple)) and isinstance(new, (list, tuple)):
            assert len(old) == len(new)
            for a, b in zip(old, new):
                recurse(a, b)
        # The astropy classes CartesianRepresentation, CartesianDifferential,
        # and ICRS do not define equality in a way that is meaningful for unit
        # tests. We explicitly compare the fields that we care about in order
        # to enable our unit testing. It is possible that in the future it will
        # be necessary or useful to account for fields that are not currently
        # compared.
        elif CartesianRepresentation is not None and isinstance(old, CartesianRepresentation):
            assert old.x == new.x and old.y == new.y and old.z == new.z
        elif CartesianDifferential is not None and isinstance(old, CartesianDifferential):
            assert old.d_x == new.d_x and old.d_y == new.d_y and old.d_z == new.d_z
        elif ICRS is not None and isinstance(old, ICRS):
            assert old.ra == new.ra and old.dec == new.dec
        else:
            assert old == new

    recurse(old_tree, new_tree)

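# Illustrative usage sketch (not part of the module): numpy arrays in the two
# trees are compared through the registered ASDF type's ``assert_equal``
# method, so plain value equality is all the caller needs to set up.
#
#     import numpy as np
#     assert_tree_match({"data": np.arange(10)}, {"data": np.arange(10)})
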
def assert_roundtrip_tree(*args, **kwargs):
    """
    Assert that a given tree saves to ASDF and, when loaded back, the tree
    matches the original tree.

    Parameters
    ----------
    tree : ASDF tree

    tmp_path : `str` or `pathlib.Path`
        Path to temporary directory to save file

    tree_match_func : `str` or `callable`
        Passed to `assert_tree_match` and used to compare two objects in the
        tree.

    raw_yaml_check_func : callable, optional
        Will be called with the raw YAML content as a string to perform any
        additional checks.

    asdf_check_func : callable, optional
        Will be called with the reloaded ASDF file to perform any additional
        checks.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings("error", category=AsdfConversionWarning)
        _assert_roundtrip_tree(*args, **kwargs)

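# Illustrative pytest usage sketch (not part of the module); the tree contents
# are an arbitrary example and ``tmp_path`` is the standard pytest fixture.
#
#     import numpy as np
#
#     def test_my_roundtrip(tmp_path):
#         tree = {"data": np.arange(100)}
#         assert_roundtrip_tree(tree, tmp_path)
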
def _assert_roundtrip_tree(
    tree,
    tmp_path,
    *,
    asdf_check_func=None,
    raw_yaml_check_func=None,
    write_options={},
    init_options={},
    extensions=None,
    tree_match_func="assert_equal",
):
    fname = os.path.join(str(tmp_path), "test.asdf")

    # First, test writing/reading a BytesIO buffer
    buff = io.BytesIO()
    AsdfFile(tree, extensions=extensions, **init_options).write_to(buff, **write_options)
    assert not buff.closed
    buff.seek(0)
    with asdf.open(buff, mode="rw", extensions=extensions) as ff:
        assert not buff.closed
        assert isinstance(ff.tree, AsdfObject)
        assert "asdf_library" in ff.tree
        assert ff.tree["asdf_library"] == get_asdf_library_info()
        assert_tree_match(tree, ff.tree, ff, funcname=tree_match_func)
        if asdf_check_func:
            asdf_check_func(ff)

    buff.seek(0)
    ff = AsdfFile(extensions=extensions, **init_options)
    content = AsdfFile._open_impl(ff, buff, mode="r", _get_yaml_content=True)
    buff.close()
    # We *never* want to get any raw python objects out
    assert b"!!python" not in content
    assert b"!core/asdf" in content
    assert content.startswith(b"%YAML 1.1")
    if raw_yaml_check_func:
        raw_yaml_check_func(content)

    # Then, test writing/reading to a real file
    ff = AsdfFile(tree, extensions=extensions, **init_options)
    ff.write_to(fname, **write_options)
    with asdf.open(fname, mode="rw", extensions=extensions) as ff:
        assert_tree_match(tree, ff.tree, ff, funcname=tree_match_func)
        if asdf_check_func:
            asdf_check_func(ff)

    # Make sure everything works without a block index
    write_options["include_block_index"] = False
    buff = io.BytesIO()
    AsdfFile(tree, extensions=extensions, **init_options).write_to(buff, **write_options)
    assert not buff.closed
    buff.seek(0)
    with asdf.open(buff, mode="rw", extensions=extensions) as ff:
        assert not buff.closed
        assert isinstance(ff.tree, AsdfObject)
        assert_tree_match(tree, ff.tree, ff, funcname=tree_match_func)
        if asdf_check_func:
            asdf_check_func(ff)

    # Now try everything on an HTTP range server
    if not INTERNET_OFF:
        server = RangeHTTPServer()
        try:
            ff = AsdfFile(tree, extensions=extensions, **init_options)
            ff.write_to(os.path.join(server.tmpdir, "test.asdf"), **write_options)
            with asdf.open(server.url + "test.asdf", mode="r", extensions=extensions) as ff:
                assert_tree_match(tree, ff.tree, ff, funcname=tree_match_func)
                if asdf_check_func:
                    asdf_check_func(ff)
        finally:
            server.finalize()

    # Now don't be lazy and check that nothing breaks
    with io.BytesIO() as buff:
        AsdfFile(tree, extensions=extensions, **init_options).write_to(buff, **write_options)
        buff.seek(0)
        ff = asdf.open(buff, extensions=extensions, copy_arrays=True, lazy_load=False)
        # Ensure that all the blocks are loaded
        for block in ff.blocks._internal_blocks:
            assert isinstance(block, Block)
            assert block._data is not None
    # The underlying file is closed at this time and everything should still work
    assert_tree_match(tree, ff.tree, ff, funcname=tree_match_func)
    if asdf_check_func:
        asdf_check_func(ff)

    # Now repeat with copy_arrays=False and a real file to test mmap()
    AsdfFile(tree, extensions=extensions, **init_options).write_to(fname, **write_options)
    with asdf.open(fname, mode="rw", extensions=extensions, copy_arrays=False, lazy_load=False) as ff:
        for block in ff.blocks._internal_blocks:
            assert isinstance(block, Block)
            assert block._data is not None
        assert_tree_match(tree, ff.tree, ff, funcname=tree_match_func)
        if asdf_check_func:
            asdf_check_func(ff)

def yaml_to_asdf(yaml_content, yaml_headers=True, standard_version=None):
    """
    Given a string of YAML content, adds the extra pre- and post-amble to
    make it an ASDF file.

    Parameters
    ----------
    yaml_content : string

    yaml_headers : bool, optional
        When True (default) add the standard ASDF YAML headers.

    Returns
    -------
    buff : io.BytesIO()
        A file-like object containing the ASDF-like content.
    """
    if isinstance(yaml_content, str):
        yaml_content = yaml_content.encode("utf-8")

    buff = io.BytesIO()

    if standard_version is None:
        standard_version = versioning.default_version
    standard_version = AsdfVersion(standard_version)

    vm = get_version_map(standard_version)
    file_format_version = vm["FILE_FORMAT"]
    yaml_version = vm["YAML_VERSION"]
    tree_version = vm["tags"]["tag:stsci.edu:asdf/core/asdf"]

    if yaml_headers:
        buff.write(
            """#ASDF {}
#ASDF_STANDARD {}
%YAML {}
%TAG ! tag:stsci.edu:asdf/
--- !core/asdf-{}
""".format(
                file_format_version, standard_version, yaml_version, tree_version
            ).encode(
                "ascii"
            )
        )
    buff.write(yaml_content)
    if yaml_headers:
        buff.write(b"\n...\n")

    buff.seek(0)
    return buff

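# Illustrative usage sketch (not part of the module): wrapping a raw YAML
# snippet in ASDF headers and reading it back with ``asdf.open``.
#
#     buff = yaml_to_asdf("a: [1, 2, 3]\n")
#     with asdf.open(buff) as af:
#         assert af.tree["a"] == [1, 2, 3]
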
def get_file_sizes(dirname):
    """
    Get the file sizes in a directory.

    Parameters
    ----------
    dirname : string
        Path to a directory

    Returns
    -------
    sizes : dict
        Dictionary of (file, size) pairs.
    """
    files = {}
    for filename in os.listdir(dirname):
        path = os.path.join(dirname, filename)
        if os.path.isfile(path):
            files[filename] = os.stat(path).st_size

    return files

def display_warnings(_warnings):
    """
    Return a string that displays a list of unexpected warnings

    Parameters
    ----------
    _warnings : iterable
        List of warnings to be displayed

    Returns
    -------
    msg : str
        String containing the warning messages to be displayed
    """
    if len(_warnings) == 0:
        return "No warnings occurred (was one expected?)"

    msg = "Unexpected warning(s) occurred:\n"
    for warning in _warnings:
        msg += f"{warning.filename}:{warning.lineno}: {warning.category.__name__}: {warning.message}\n"

    return msg

@contextmanager
def assert_no_warnings(warning_class=None):
    """
    Assert that no warnings were emitted within the context.
    Requires that pytest be installed.

    Parameters
    ----------
    warning_class : type, optional
        Assert only that no warnings of the specified class were emitted.
    """
    import pytest

    if warning_class is None:
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            yield
    else:
        with pytest.warns(Warning) as recorded_warnings:
            yield

        assert not any(isinstance(w.message, warning_class) for w in recorded_warnings), display_warnings(
            recorded_warnings
        )


def assert_extension_correctness(extension):
    """
    Assert that an ASDF extension's types are all correctly formed and
    that the extension provides all of the required schemas.

    Parameters
    ----------
    extension : asdf.AsdfExtension
        The extension to validate
    """
    __tracebackhide__ = True

    resolver = ResolverChain(
        Resolver(extension.tag_mapping, "tag"),
        Resolver(extension.url_mapping, "url"),
    )

    for extension_type in extension.types:
        _assert_extension_type_correctness(extension, extension_type, resolver)


def _assert_extension_type_correctness(extension, extension_type, resolver):
    __tracebackhide__ = True

    if extension_type.yaml_tag is not None and extension_type.yaml_tag.startswith(YAML_TAG_PREFIX):
        return

    if extension_type == asdf.stream.Stream:
        # Stream is a special case.  It was implemented as a subclass of NDArrayType,
        # but shares a tag with that class, so it isn't really a distinct type.
        return

    assert extension_type.name is not None, f"{extension_type.__name__} must set the 'name' class attribute"

    # Currently ExtensionType sets a default version of 1.0.0,
    # but we want to encourage an explicit version on the subclass.
    assert "version" in extension_type.__dict__, "{} must set the 'version' class attribute".format(
        extension_type.__name__
    )

    # check the default version
    types_to_check = [extension_type]

    # Adding or updating a schema/type version might involve updating multiple
    # packages. This can result in types without schema and schema without types
    # for the development version of the asdf-standard. To account for this,
    # don't include versioned siblings of types with versions that are not
    # in one of the asdf-standard versions in supported_versions (excluding the
    # current development version).
    asdf_standard_versions = supported_versions.copy()
    if asdf_standard_development_version in asdf_standard_versions:
        asdf_standard_versions.remove(asdf_standard_development_version)

    for sibling in extension_type.versioned_siblings:
        tag_base, version = split_tag_version(sibling.yaml_tag)
        for asdf_standard_version in asdf_standard_versions:
            vm = get_version_map(asdf_standard_version)
            if tag_base in vm["tags"] and AsdfVersion(vm["tags"][tag_base]) == version:
                types_to_check.append(sibling)
                break

    for check_type in types_to_check:
        schema_location = resolver(check_type.yaml_tag)

        assert schema_location is not None, (
            f"{extension_type.__name__} supports tag, {check_type.yaml_tag}, "
            + "but tag does not resolve. Check the tag_mapping and uri_mapping "
            + f"properties on the related extension ({extension_type.__name__})."
        )

        if schema_location not in asdf.get_config().resource_manager:
            try:
                with generic_io.get_file(schema_location) as f:
                    yaml.safe_load(f.read())
            except Exception:
                assert False, (
                    f"{extension_type.__name__} supports tag, {check_type.yaml_tag}, "
                    + f"which resolves to schema at {schema_location}, but "
                    + "schema cannot be read."
                )
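

# Illustrative usage sketches (not part of the module). ``do_something_quiet``
# and ``MyExtension`` are hypothetical names used only for demonstration.
#
#     # Fail if anything inside the block emits a warning:
#     with assert_no_warnings():
#         do_something_quiet()
#
#     # Check that a custom extension's types and schemas are consistent:
#     assert_extension_correctness(MyExtension())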