# Source code for astropy.io.fits.hdu.groups

# Licensed under a 3-clause BSD style license - see PYFITS.rst

import sys

import numpy as np

from astropy.io.fits.column import FITS2NUMPY, ColDefs, Column
from astropy.io.fits.fitsrec import FITS_rec, FITS_record
from astropy.io.fits.util import _is_int, _is_pseudo_integer, _pseudo_zero
from astropy.utils import lazyproperty

from .base import DELAYED, DTYPE2BITPIX
from .image import PrimaryHDU
from .table import _TableLikeHDU


class Group(FITS_record):
    """
    A single group (one row) of random groups data.

    A `Group` is a record view onto its parent random-groups array: the
    parameter values and the data portion are read from (and written to)
    the row ``self.row`` of ``self.array``.
    """

    def __init__(self, input, row=0, start=None, end=None, step=None, base=None):
        super().__init__(input, row, start, end, step, base)

    @property
    def parnames(self):
        """The parameter names taken from the parent array."""
        return self.array.parnames

    @property
    def data(self):
        """The data portion of this group."""
        # The data portion always lives in the final column of the coldefs.
        last_column = self.array._coldefs.names[-1]
        return self.field(last_column)

    @lazyproperty
    def _unique(self):
        # Maps each upper-cased parameter name to the list of column
        # indices that carry that name.
        return _par_indices(self.parnames)

    def par(self, parname):
        """
        Get the group parameter value.

        Parameters
        ----------
        parname : int or str
            Column index, or case-insensitive parameter name.  When several
            parameters share the same name their values are added together
            after converting the first one to ``f8``.
        """
        if _is_int(parname):
            return self.array[self.row][parname]

        first, *rest = self._unique[parname.upper()]
        record = self.array[self.row]
        if not rest:
            return record[first]

        # Several same-named parameters: accumulate in double precision.
        total = record[first].astype("f8")
        for column in rest:
            total += record[column]
        return total

    def setpar(self, parname, value):
        """
        Set the group parameter value.

        Parameters
        ----------
        parname : int or str
            Column index, or case-insensitive parameter name.
        value
            The new value.  If ``parname`` names more than one parameter,
            this must be a list or tuple with one entry per parameter.

        Raises
        ------
        ValueError
            If ``parname`` is shared by several parameters and ``value`` is
            not a list/tuple of matching length.
        """
        # TODO: It would be nice if, instead of requiring a multi-part value to
        # be an array, there were an *option* to automatically split the value
        # into multiple columns if it doesn't already fit in the array data
        # type.
        if _is_int(parname):
            self.array[self.row][parname] = value
            return

        indices = self._unique[parname.upper()]
        if len(indices) == 1:
            self.array[self.row][indices[0]] = value
            return

        # The name is shared, so the value must supply one entry (array or
        # number) for each parameter carrying that name.
        if not (isinstance(value, (list, tuple)) and len(indices) == len(value)):
            raise ValueError(
                f"Parameter value must be a sequence with {len(indices)} arrays/numbers."
            )

        record = self.array[self.row]
        for column, item in zip(indices, value):
            record[column] = item
class GroupData(FITS_rec):
    """
    Random groups data object.

    Allows structured access to FITS Group data in a manner analogous
    to tables.
    """

    _record_type = Group

    def __new__(
        cls,
        input=None,
        bitpix=None,
        pardata=None,
        parnames=None,
        bscale=None,
        bzero=None,
        parbscales=None,
        parbzeros=None,
    ):
        """
        Parameters
        ----------
        input : array or FITS_rec instance
            input data, either the group data itself (a
            `numpy.ndarray`) or a record array (`FITS_rec`) which will
            contain both group parameter info and the data.  The rest
            of the arguments are used only for the first case.

        bitpix : int
            data type as expressed in FITS ``BITPIX`` value (8, 16, 32,
            64, -32, or -64).  Derived from ``input.dtype`` when omitted.

        pardata : sequence of array
            parameter data, as a list of (numeric) arrays.

        parnames : sequence of str, optional
            list of parameter names.  Must have one entry per array in
            ``pardata``.  When omitted (or `None`) the names default to
            ``PAR1``, ``PAR2``, ...

        bscale : int
            ``BSCALE`` of the data

        bzero : int
            ``BZERO`` of the data

        parbscales : sequence of int
            list of bscales for the parameters

        parbzeros : sequence of int
            list of bzeros for the parameters

        Raises
        ------
        ValueError
            If the number of names does not match the number of parameter
            data arrays.
        """
        if not isinstance(input, FITS_rec):
            npars = 0 if pardata is None else len(pardata)

            if parbscales is None:
                parbscales = [None] * npars
            if parbzeros is None:
                parbzeros = [None] * npars

            # ``None`` (the default) means "generate names".  Any other
            # sequence is copied into a fresh list so that tuples are
            # accepted and the caller's list is not aliased by this array.
            if parnames is None:
                parnames = [f"PAR{idx + 1}" for idx in range(npars)]
            else:
                parnames = list(parnames)

            if len(parnames) != npars:
                raise ValueError(
                    "The number of parameter data arrays does "
                    "not match the number of parameters."
                )

            unique_parnames = _unique_parnames(parnames + ["DATA"])

            if bitpix is None:
                bitpix = DTYPE2BITPIX[input.dtype.name]

            fits_fmt = GroupsHDU._bitpix2tform[bitpix]  # -32 -> 'E'
            np_fmt = FITS2NUMPY[fits_fmt]  # 'E' -> 'f4'
            # Each row's data field has the shape of one group (all axes of
            # ``input`` after the first, which counts the groups).
            data_fmt = f"{input.shape[1:]}{np_fmt}"
            formats = ",".join(([np_fmt] * npars) + [data_fmt])
            gcount = input.shape[0]

            cols = [
                Column(
                    name=unique_parnames[idx],
                    format=fits_fmt,
                    bscale=parbscales[idx],
                    bzero=parbzeros[idx],
                )
                for idx in range(npars)
            ]
            cols.append(
                Column(
                    name=unique_parnames[-1],
                    format=fits_fmt,
                    bscale=bscale,
                    bzero=bzero,
                )
            )

            coldefs = ColDefs(cols)

            self = FITS_rec.__new__(
                cls,
                np.rec.array(None, formats=formats, names=coldefs.names, shape=gcount),
            )

            # By default the data field will just be 'DATA', but it may be
            # uniquified if 'DATA' is already used by one of the group names
            self._data_field = unique_parnames[-1]

            self._coldefs = coldefs
            self.parnames = parnames

            for idx, name in enumerate(unique_parnames[:-1]):
                column = coldefs[idx]
                # Note: _get_scale_factors is used here and in other cases
                # below to determine whether the column has non-default
                # scale/zero factors.
                # TODO: Find a better way to do this than using this interface
                scale, zero = self._get_scale_factors(column)[3:5]
                if scale or zero:
                    self._cache_field(name, pardata[idx])
                else:
                    np.rec.recarray.field(self, idx)[:] = pardata[idx]

            column = coldefs[self._data_field]
            scale, zero = self._get_scale_factors(column)[3:5]
            if scale or zero:
                self._cache_field(self._data_field, input)
            else:
                np.rec.recarray.field(self, npars)[:] = input
        else:
            self = FITS_rec.__new__(cls, input)
            self.parnames = None
        return self

    def __array_finalize__(self, obj):
        super().__array_finalize__(obj)
        # Propagate parameter names to views/slices of this array.
        if isinstance(obj, GroupData):
            self.parnames = obj.parnames
        elif isinstance(obj, FITS_rec):
            self.parnames = obj._coldefs.names

    def __getitem__(self, key):
        out = super().__getitem__(key)
        if isinstance(out, GroupData):
            out.parnames = self.parnames
        return out

    @property
    def data(self):
        """
        The raw group data represented as a multi-dimensional `numpy.ndarray`
        array.
        """
        # The last column in the coldefs is the data portion of the group
        return self.field(self._coldefs.names[-1])

    @lazyproperty
    def _unique(self):
        # Upper-cased parameter name -> list of column indices using it.
        return _par_indices(self.parnames)

    def par(self, parname):
        """
        Get the group parameter values.

        Parameters
        ----------
        parname : int or str
            Column index, or case-insensitive parameter name.  When several
            parameters share the name, their columns are summed after
            converting the first to ``f8``.
        """
        if _is_int(parname):
            result = self.field(parname)
        else:
            indx = self._unique[parname.upper()]
            if len(indx) == 1:
                result = self.field(indx[0])

            # if more than one group parameter have the same name
            else:
                result = self.field(indx[0]).astype("f8")
                for i in indx[1:]:
                    result += self.field(i)

        return result
class GroupsHDU(PrimaryHDU, _TableLikeHDU):
    """
    FITS Random Groups HDU class.

    See the :ref:`astropy:random-groups` section in the Astropy documentation
    for more details on working with this type of HDU.
    """

    _bitpix2tform = {8: "B", 16: "I", 32: "J", 64: "K", -32: "E", -64: "D"}

    _data_type = GroupData
    # The name of the table record array field that will contain the group
    # data for each group; 'DATA' by default, but may be preceded by any
    # number of underscores if 'DATA' is already a parameter name.
    _data_field = "DATA"

    def __init__(self, data=None, header=None):
        super().__init__(data=data, header=header)
        if data is not DELAYED:
            self.update_header()

        # Update the axes; GROUPS HDUs should always have at least one axis
        if len(self._axes) <= 0:
            self._axes = [0]
            self._header["NAXIS"] = 1
            self._header.set("NAXIS1", 0, after="NAXIS")

    @classmethod
    def match_header(cls, header):
        """
        Return True when ``header`` describes a random groups HDU: its first
        card is ``SIMPLE`` and it has a ``GROUPS`` keyword whose value is
        `True`.
        """
        keyword = header.cards[0].keyword
        return keyword == "SIMPLE" and "GROUPS" in header and header["GROUPS"] is True

    @lazyproperty
    def data(self):
        """
        The data of a random group FITS file will be like a binary table's
        data.
        """
        if self._axes == [0]:
            return None

        data = self._get_tbdata()
        data._coldefs = self.columns
        data.parnames = self.parnames
        del self.columns
        return data

    @lazyproperty
    def parnames(self):
        """The names of the group parameters as described by the header."""
        pcount = self._header["PCOUNT"]
        # The FITS standard doesn't really say what to do if a parname is
        # missing, so for now just assume that won't happen
        return [self._header["PTYPE" + str(idx + 1)] for idx in range(pcount)]

    @lazyproperty
    def columns(self):
        """
        Column definitions: one column per group parameter (from the
        ``PTYPEn``/``PSCALn``/``PZEROn`` keywords) followed by the data column
        (scaled by ``BSCALE``/``BZERO``).  Returns the loaded data's coldefs
        when those are available.
        """
        if self._has_data and hasattr(self.data, "_coldefs"):
            return self.data._coldefs

        tform = self._bitpix2tform[self._header["BITPIX"]]
        pcount = self._header["PCOUNT"]
        parnames = []
        bscales = []
        bzeros = []

        for idx in range(pcount):
            bscales.append(self._header.get("PSCAL" + str(idx + 1), None))
            bzeros.append(self._header.get("PZERO" + str(idx + 1), None))
            parnames.append(self._header["PTYPE" + str(idx + 1)])

        formats = [tform] * len(parnames)
        dims = [None] * len(parnames)

        # Now create columns from collected parameters, but first add the DATA
        # column too, to contain the group data.
        parnames.append("DATA")
        bscales.append(self._header.get("BSCALE"))
        # The standard keyword is BZERO (the one update_header writes).  The
        # previous "BZEROS" lookup meant a header BZERO was never passed to
        # the data column.
        bzeros.append(self._header.get("BZERO"))
        data_shape = self.shape[:-1]
        formats.append(str(int(np.prod(data_shape))) + tform)
        dims.append(data_shape)
        parnames = _unique_parnames(parnames)

        self._data_field = parnames[-1]

        cols = [
            Column(name=name, format=fmt, bscale=bscale, bzero=bzero, dim=dim)
            for name, fmt, bscale, bzero, dim in zip(
                parnames, formats, bscales, bzeros, dims
            )
        ]

        return ColDefs(cols)

    @property
    def _nrows(self):
        if not self._data_loaded:
            # The number of 'groups' equates to the number of rows in the table
            # representation of the data
            return self._header.get("GCOUNT", 0)
        else:
            return len(self.data)

    @lazyproperty
    def _theap(self):
        # Only really a lazyproperty for symmetry with _TableBaseHDU
        return 0

    @property
    def is_image(self):
        return False

    @property
    def size(self):
        """
        Returns the size (in bytes) of the HDU's data part.
        """
        size = 0
        naxis = self._header.get("NAXIS", 0)

        # for random group image, NAXIS1 should be 0, so we skip NAXIS1.
        if naxis > 1:
            size = 1
            for idx in range(1, naxis):
                size = size * self._header["NAXIS" + str(idx + 1)]
            bitpix = self._header["BITPIX"]
            gcount = self._header.get("GCOUNT", 1)
            pcount = self._header.get("PCOUNT", 0)
            size = abs(bitpix) * gcount * (pcount + size) // 8
        return size

    def update_header(self):
        """
        Bring the header in line with the current data.

        This sets ``BITPIX`` and the ``NAXISn`` keywords (removing stale
        ones).  For `GroupData` it also sets ``GROUPS``, ``PCOUNT`` and
        ``GCOUNT``, writes ``BSCALE``/``BZERO`` and
        ``PTYPEn``/``PSCALn``/``PZEROn`` for non-default scale factors, and
        moves an existing ``EXTEND`` card after the last ``NAXISn``.

        Raises
        ------
        ValueError
            If loaded data is neither `GroupData` nor `None`.
        """
        old_naxis = self._header.get("NAXIS", 0)

        if self._data_loaded:
            if isinstance(self.data, GroupData):
                self._axes = list(self.data.data.shape)[1:]
                self._axes.reverse()
                self._axes = [0] + self._axes
                field0 = self.data.dtype.names[0]
                field0_code = self.data.dtype.fields[field0][0].name
            elif self.data is None:
                self._axes = [0]
                field0_code = "uint8"  # For lack of a better default
            else:
                raise ValueError("incorrect array type")

            self._header["BITPIX"] = DTYPE2BITPIX[field0_code]

        self._header["NAXIS"] = len(self._axes)

        # add NAXISi if it does not exist
        for idx, axis in enumerate(self._axes):
            if idx == 0:
                after = "NAXIS"
            else:
                after = "NAXIS" + str(idx)

            self._header.set("NAXIS" + str(idx + 1), axis, after=after)

        # delete extra NAXISi's
        for idx in range(len(self._axes) + 1, old_naxis + 1):
            try:
                del self._header["NAXIS" + str(idx)]
            except KeyError:
                pass

        if self._has_data and isinstance(self.data, GroupData):
            self._header.set("GROUPS", True, after="NAXIS" + str(len(self._axes)))
            self._header.set("PCOUNT", len(self.data.parnames), after="GROUPS")
            self._header.set("GCOUNT", len(self.data), after="PCOUNT")

            column = self.data._coldefs[self._data_field]
            scale, zero = self.data._get_scale_factors(column)[3:5]
            if scale:
                self._header.set("BSCALE", column.bscale)
            if zero:
                self._header.set("BZERO", column.bzero)

            for idx, name in enumerate(self.data.parnames):
                self._header.set("PTYPE" + str(idx + 1), name)
                column = self.data._coldefs[idx]
                scale, zero = self.data._get_scale_factors(column)[3:5]
                if scale:
                    self._header.set("PSCAL" + str(idx + 1), column.bscale)
                if zero:
                    self._header.set("PZERO" + str(idx + 1), column.bzero)

        # Update the position of the EXTEND keyword if it already exists
        if "EXTEND" in self._header:
            if len(self._axes):
                after = "NAXIS" + str(len(self._axes))
            else:
                after = "NAXIS"
            self._header.set("EXTEND", after=after)

    def _writedata_internal(self, fileobj):
        """
        Basically copy/pasted from `_ImageBaseHDU._writedata_internal()`, but
        we have to get the data's byte order a different way...

        Returns the number of bytes written.

        TODO: Might be nice to store some indication of the data's byte order
        as an attribute or function so that we don't have to do this.
        """
        size = 0

        if self.data is not None:
            self.data._scale_back()

            # Based on the system type, determine the byteorders that
            # would need to be swapped to get to big-endian output
            if sys.byteorder == "little":
                swap_types = ("<", "=")
            else:
                swap_types = ("<",)
            # deal with unsigned integer 16, 32 and 64 data
            if _is_pseudo_integer(self.data.dtype):
                # Convert the unsigned array to signed
                output = np.array(
                    self.data - _pseudo_zero(self.data.dtype),
                    dtype=f">i{self.data.dtype.itemsize}",
                )
                should_swap = False
            else:
                output = self.data
                fname = self.data.dtype.names[0]
                byteorder = self.data.dtype.fields[fname][0].str[0]
                should_swap = byteorder in swap_types

            if should_swap:
                if output.flags.writeable:
                    output.byteswap(True)
                    try:
                        fileobj.writearray(output)
                    finally:
                        output.byteswap(True)
                else:
                    # For read-only arrays, there is no way around making
                    # a byteswapped copy of the data.
                    fileobj.writearray(output.byteswap(False))
            else:
                fileobj.writearray(output)

            size += output.size * output.itemsize
        return size

    def _verify(self, option="warn"):
        errs = super()._verify(option=option)

        # Verify locations and values of mandatory keywords.
        self.req_cards(
            "NAXIS", 2, lambda v: (_is_int(v) and 1 <= v <= 999), 1, option, errs
        )

        self.req_cards("NAXIS1", 3, lambda v: (_is_int(v) and v == 0), 0, option, errs)

        after = self._header["NAXIS"] + 3

        def pos(x):
            # GCOUNT/PCOUNT/GROUPS must follow the last NAXISn card.
            return x >= after

        self.req_cards("GCOUNT", pos, _is_int, 1, option, errs)
        self.req_cards("PCOUNT", pos, _is_int, 0, option, errs)
        self.req_cards("GROUPS", pos, lambda v: (v is True), True, option, errs)
        return errs

    def _calculate_datasum(self):
        """
        Calculate the value for the ``DATASUM`` card in the HDU.
        """
        if self._has_data:
            # We have the data to be used.

            # Check the byte order of the data.  If it is little endian we
            # must swap it before calculating the datasum.
            # TODO: Maybe check this on a per-field basis instead of assuming
            # that all fields have the same byte order?
            byteorder = self.data.dtype.fields[self.data.dtype.names[0]][0].str[0]

            if byteorder != ">":
                if self.data.flags.writeable:
                    byteswapped = True
                    d = self.data.byteswap(True)
                    d.dtype = d.dtype.newbyteorder(">")
                else:
                    # If the data is not writeable, we just make a byteswapped
                    # copy and don't bother changing it back after
                    d = self.data.byteswap(False)
                    d.dtype = d.dtype.newbyteorder(">")
                    byteswapped = False
            else:
                byteswapped = False
                d = self.data

            try:
                byte_data = d.view(type=np.ndarray, dtype=np.ubyte)
                return self._compute_checksum(byte_data)
            finally:
                # If the data was byteswapped in place by this method, return
                # it to its original little-endian order -- also when the
                # checksum computation raises, so the caller's data is never
                # left swapped.
                if byteswapped:
                    d.byteswap(True)
                    d.dtype = d.dtype.newbyteorder("<")
        else:
            # This is the case where the data has not been read from the file
            # yet.  We can handle that in a generic manner so we do it in the
            # base class.  The other possibility is that there is no data at
            # all.  This can also be handled in a generic manner.
            return super()._calculate_datasum()

    def _summary(self):
        summary = super()._summary()
        name, ver, classname, length, shape, format, gcount = summary

        # Drop the first axis from the shape
        if shape:
            shape = shape[1:]

            if shape and all(shape):
                # Update the format
                format = self.columns[0].dtype.name

        # Update the GCOUNT report
        gcount = f"{self._gcount} Groups {self._pcount} Parameters"
        return (name, ver, classname, length, shape, format, gcount)
def _par_indices(names): """ Given a list of objects, returns a mapping of objects in that list to the index or indices at which that object was found in the list. """ unique = {} for idx, name in enumerate(names): # Case insensitive name = name.upper() if name in unique: unique[name].append(idx) else: unique[name] = [idx] return unique def _unique_parnames(names): """ Given a list of parnames, including possible duplicates, returns a new list of parnames with duplicates prepended by one or more underscores to make them unique. This is also case insensitive. """ upper_names = set() unique_names = [] for name in names: name_upper = name.upper() while name_upper in upper_names: name = "_" + name name_upper = "_" + name_upper unique_names.append(name) upper_names.add(name_upper) return unique_names