Source code for astropy.utils.diff

import difflib
import functools
import numbers
import sys

import numpy as np

from .misc import indent

__all__ = [
    "fixed_width_indent",
    "diff_values",
    "report_diff_values",
    "where_not_allclose",
]


# Smaller default shift-width for indent
fixed_width_indent = functools.partial(indent, width=2)


[docs]def diff_values(a, b, rtol=0.0, atol=0.0): """ Diff two scalar values. If both values are floats, they are compared to within the given absolute and relative tolerance. Parameters ---------- a, b : int, float, str Scalar values to compare. rtol, atol : float Relative and absolute tolerances as accepted by :func:`numpy.allclose`. Returns ------- is_different : bool `True` if they are different, else `False`. """ if isinstance(a, float) and isinstance(b, float): if np.isnan(a) and np.isnan(b): return False return not np.allclose(a, b, rtol=rtol, atol=atol) else: return a != b
[docs]def report_diff_values(a, b, fileobj=sys.stdout, indent_width=0, rtol=0.0, atol=0.0): """ Write a diff report between two values to the specified file-like object. Parameters ---------- a, b Values to compare. Anything that can be turned into strings and compared using :py:mod:`difflib` should work. fileobj : object File-like object to write to. The default is ``sys.stdout``, which writes to terminal. indent_width : int Character column(s) to indent. rtol, atol : float Relative and absolute tolerances as accepted by :func:`numpy.allclose`. Returns ------- identical : bool `True` if no diff, else `False`. """ if isinstance(a, np.ndarray) and isinstance(b, np.ndarray): if a.shape != b.shape: fileobj.write( fixed_width_indent(" Different array shapes:\n", indent_width) ) report_diff_values( str(a.shape), str(b.shape), fileobj=fileobj, indent_width=indent_width + 1, ) return False if np.issubdtype(a.dtype, np.floating) and np.issubdtype(b.dtype, np.floating): diff_indices = np.transpose(where_not_allclose(a, b, rtol=rtol, atol=atol)) else: diff_indices = np.transpose(np.where(a != b)) num_diffs = diff_indices.shape[0] for idx in diff_indices[:3]: lidx = idx.tolist() fileobj.write(fixed_width_indent(f" at {lidx!r}:\n", indent_width)) report_diff_values( a[tuple(idx)], b[tuple(idx)], fileobj=fileobj, indent_width=indent_width + 1, rtol=rtol, atol=atol, ) if num_diffs > 3: fileobj.write( fixed_width_indent( f" ...and at {num_diffs - 3:d} more indices.\n", indent_width ) ) return False return num_diffs == 0 typea = type(a) typeb = type(b) if typea == typeb: lnpad = " " sign_a = "a>" sign_b = "b>" if isinstance(a, numbers.Number): a = repr(a) b = repr(b) else: a = str(a) b = str(b) else: padding = max(len(typea.__name__), len(typeb.__name__)) + 3 lnpad = (padding + 1) * " " sign_a = ("(" + typea.__name__ + ") ").rjust(padding) + "a>" sign_b = ("(" + typeb.__name__ + ") ").rjust(padding) + "b>" is_a_str = isinstance(a, str) is_b_str = isinstance(b, str) a = ( repr(a) if ( (is_a_str and not is_b_str) or (not is_a_str and isinstance(a, numbers.Number)) ) else str(a) ) b = ( repr(b) if ( (is_b_str and not is_a_str) or (not is_b_str and isinstance(b, numbers.Number)) ) else str(b) ) identical = True for line in difflib.ndiff(a.splitlines(), b.splitlines()): if line[0] == "-": identical = False line = sign_a + line[1:] elif line[0] == "+": identical = False line = sign_b + line[1:] else: line = lnpad + line fileobj.write( fixed_width_indent(" {}\n".format(line.rstrip("\n")), indent_width) ) return identical
[docs]def where_not_allclose(a, b, rtol=1e-5, atol=1e-8): """ A version of :func:`numpy.allclose` that returns the indices where the two arrays differ, instead of just a boolean value. Parameters ---------- a, b : array-like Input arrays to compare. rtol, atol : float Relative and absolute tolerances as accepted by :func:`numpy.allclose`. Returns ------- idx : tuple of array Indices where the two arrays differ. """ # Create fixed mask arrays to handle INF and NaN; currently INF and NaN # are handled as equivalent if not np.all(np.isfinite(a)): a = np.ma.fix_invalid(a).data if not np.all(np.isfinite(b)): b = np.ma.fix_invalid(b).data if atol == 0.0 and rtol == 0.0: # Use a faster comparison for the most simple (and common) case return np.where(a != b) return np.where(np.abs(a - b) > (atol + rtol * np.abs(b)))