# Licensed under a 3-clause BSD style license - see PYFITS.rst
import sys
import numpy as np
from astropy.io.fits.column import FITS2NUMPY, ColDefs, Column
from astropy.io.fits.fitsrec import FITS_rec, FITS_record
from astropy.io.fits.util import _is_int, _is_pseudo_integer, _pseudo_zero
from astropy.utils import lazyproperty
from .base import DELAYED, DTYPE2BITPIX
from .image import PrimaryHDU
from .table import _TableLikeHDU


class Group(FITS_record):
"""
One group of the random group data.
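
    Instances are normally obtained by indexing a `GroupData` object;
    e.g. ``gdata[0]`` returns the first group as a `Group` instance.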
"""
def __init__(self, input, row=0, start=None, end=None, step=None, base=None):
super().__init__(input, row, start, end, step, base)
@property
def parnames(self):
return self.array.parnames
@property
def data(self):
# The last column in the coldefs is the data portion of the group
return self.field(self.array._coldefs.names[-1])
@lazyproperty
def _unique(self):
return _par_indices(self.parnames)

    def par(self, parname):
"""
Get the group parameter value.
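
        If more than one parameter in the group shares the name ``parname``,
        the returned value is the sum of all parameters with that name, per
        the FITS random-groups convention for multi-part parameters.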
"""
if _is_int(parname):
result = self.array[self.row][parname]
else:
indx = self._unique[parname.upper()]
if len(indx) == 1:
result = self.array[self.row][indx[0]]
            # If more than one group parameter has the same name, their
            # values are summed (per the FITS random-groups convention)
else:
result = self.array[self.row][indx[0]].astype("f8")
for i in indx[1:]:
result += self.array[self.row][i]
return result

    def setpar(self, parname, value):
"""
Set the group parameter value.
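
        If ``parname`` occurs more than once in the group, ``value`` must be
        a list or tuple with one array/number per occurrence.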
"""
# TODO: It would be nice if, instead of requiring a multi-part value to
# be an array, there were an *option* to automatically split the value
# into multiple columns if it doesn't already fit in the array data
# type.
if _is_int(parname):
self.array[self.row][parname] = value
else:
indx = self._unique[parname.upper()]
if len(indx) == 1:
self.array[self.row][indx[0]] = value
            # if more than one group parameter has the same name, the
            # value must be a list (or tuple) containing arrays
else:
if isinstance(value, (list, tuple)) and len(indx) == len(value):
for i in range(len(indx)):
self.array[self.row][indx[i]] = value[i]
else:
raise ValueError(
"Parameter value must be a sequence with "
"{} arrays/numbers.".format(len(indx))
)


class GroupData(FITS_rec):
"""
Random groups data object.
Allows structured access to FITS Group data in a manner analogous
to tables.
"""
_record_type = Group
def __new__(
cls,
input=None,
bitpix=None,
pardata=None,
        parnames=None,
bscale=None,
bzero=None,
parbscales=None,
parbzeros=None,
):
"""
Parameters
----------
input : array or FITS_rec instance
input data, either the group data itself (a
`numpy.ndarray`) or a record array (`FITS_rec`) which will
contain both group parameter info and the data. The rest
of the arguments are used only for the first case.
bitpix : int
data type as expressed in FITS ``BITPIX`` value (8, 16, 32,
64, -32, or -64)
pardata : sequence of array
parameter data, as a list of (numeric) arrays.
parnames : sequence of str
list of parameter names.
bscale : int
``BSCALE`` of the data
bzero : int
``BZERO`` of the data
parbscales : sequence of int
list of bscales for the parameters
parbzeros : sequence of int
list of bzeros for the parameters
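
        Examples
        --------
        A minimal sketch of building group data from scratch; the array
        shapes and parameter names here are purely illustrative:

        >>> import numpy as np
        >>> from astropy.io import fits
        >>> imdata = np.zeros((10, 1, 1, 2, 2), dtype=np.float32)
        >>> pdata1 = np.arange(10, dtype=np.float32) + 0.1
        >>> pdata2 = 42.0 * np.ones(10, dtype=np.float32)
        >>> gdata = fits.GroupData(imdata, parnames=['abc', 'xyz'],
        ...                        pardata=[pdata1, pdata2], bitpix=-32)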
"""
if not isinstance(input, FITS_rec):
if pardata is None:
npars = 0
else:
npars = len(pardata)
if parbscales is None:
parbscales = [None] * npars
if parbzeros is None:
parbzeros = [None] * npars
if parnames is None:
parnames = [f"PAR{idx + 1}" for idx in range(npars)]
if len(parnames) != npars:
raise ValueError(
"The number of parameter data arrays does "
"not match the number of parameters."
)
unique_parnames = _unique_parnames(parnames + ["DATA"])
if bitpix is None:
bitpix = DTYPE2BITPIX[input.dtype.name]
fits_fmt = GroupsHDU._bitpix2tform[bitpix] # -32 -> 'E'
format = FITS2NUMPY[fits_fmt] # 'E' -> 'f4'
data_fmt = f"{str(input.shape[1:])}{format}"
formats = ",".join(([format] * npars) + [data_fmt])
gcount = input.shape[0]
cols = [
Column(
name=unique_parnames[idx],
format=fits_fmt,
bscale=parbscales[idx],
bzero=parbzeros[idx],
)
for idx in range(npars)
]
cols.append(
Column(
name=unique_parnames[-1],
format=fits_fmt,
bscale=bscale,
bzero=bzero,
)
)
coldefs = ColDefs(cols)
self = FITS_rec.__new__(
cls,
np.rec.array(None, formats=formats, names=coldefs.names, shape=gcount),
)
# By default the data field will just be 'DATA', but it may be
# uniquified if 'DATA' is already used by one of the group names
self._data_field = unique_parnames[-1]
self._coldefs = coldefs
self.parnames = parnames
for idx, name in enumerate(unique_parnames[:-1]):
column = coldefs[idx]
# Note: _get_scale_factors is used here and in other cases
# below to determine whether the column has non-default
# scale/zero factors.
# TODO: Find a better way to do this than using this interface
scale, zero = self._get_scale_factors(column)[3:5]
if scale or zero:
self._cache_field(name, pardata[idx])
else:
np.rec.recarray.field(self, idx)[:] = pardata[idx]
column = coldefs[self._data_field]
scale, zero = self._get_scale_factors(column)[3:5]
if scale or zero:
self._cache_field(self._data_field, input)
else:
np.rec.recarray.field(self, npars)[:] = input
else:
self = FITS_rec.__new__(cls, input)
self.parnames = None
return self
def __array_finalize__(self, obj):
super().__array_finalize__(obj)
if isinstance(obj, GroupData):
self.parnames = obj.parnames
elif isinstance(obj, FITS_rec):
self.parnames = obj._coldefs.names
def __getitem__(self, key):
out = super().__getitem__(key)
if isinstance(out, GroupData):
out.parnames = self.parnames
return out
@property
def data(self):
"""
The raw group data represented as a multi-dimensional `numpy.ndarray`
array.
"""
# The last column in the coldefs is the data portion of the group
return self.field(self._coldefs.names[-1])
@lazyproperty
def _unique(self):
return _par_indices(self.parnames)

    def par(self, parname):
"""
Get the group parameter values.
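
        Returns the values of the parameter ``parname`` for all groups as an
        array; values of parameters that share a name are summed.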
"""
if _is_int(parname):
result = self.field(parname)
else:
indx = self._unique[parname.upper()]
if len(indx) == 1:
result = self.field(indx[0])
            # if more than one group parameter has the same name
else:
result = self.field(indx[0]).astype("f8")
for i in indx[1:]:
result += self.field(i)
return result


class GroupsHDU(PrimaryHDU, _TableLikeHDU):
"""
FITS Random Groups HDU class.
See the :ref:`astropy:random-groups` section in the Astropy documentation
for more details on working with this type of HDU.
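
    A minimal sketch of wrapping group data in an HDU, reusing the
    illustrative ``gdata`` from the `GroupData` example::

        >>> hdu = fits.GroupsHDU(gdata)
        >>> hdu.parnames
        ['abc', 'xyz']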
"""
_bitpix2tform = {8: "B", 16: "I", 32: "J", 64: "K", -32: "E", -64: "D"}
_data_type = GroupData
_data_field = "DATA"
"""
The name of the table record array field that will contain the group data
for each group; 'DATA' by default, but may be preceded by any number of
underscores if 'DATA' is already a parameter name
"""
def __init__(self, data=None, header=None):
super().__init__(data=data, header=header)
if data is not DELAYED:
self.update_header()
# Update the axes; GROUPS HDUs should always have at least one axis
if len(self._axes) <= 0:
self._axes = [0]
self._header["NAXIS"] = 1
self._header.set("NAXIS1", 0, after="NAXIS")
@lazyproperty
def data(self):
"""
The data of a random group FITS file will be like a binary table's
data.
"""
if self._axes == [0]:
return
data = self._get_tbdata()
data._coldefs = self.columns
data.parnames = self.parnames
del self.columns
return data
@lazyproperty
def parnames(self):
"""The names of the group parameters as described by the header."""
pcount = self._header["PCOUNT"]
# The FITS standard doesn't really say what to do if a parname is
# missing, so for now just assume that won't happen
return [self._header["PTYPE" + str(idx + 1)] for idx in range(pcount)]
@lazyproperty
    def columns(self):
        # Construct a ColDefs describing the group parameters (from the
        # PTYPEn / PSCALn / PZEROn cards) plus one trailing column that
        # holds the group data itself.
if self._has_data and hasattr(self.data, "_coldefs"):
return self.data._coldefs
format = self._bitpix2tform[self._header["BITPIX"]]
pcount = self._header["PCOUNT"]
parnames = []
bscales = []
bzeros = []
for idx in range(pcount):
bscales.append(self._header.get("PSCAL" + str(idx + 1), None))
bzeros.append(self._header.get("PZERO" + str(idx + 1), None))
parnames.append(self._header["PTYPE" + str(idx + 1)])
formats = [format] * len(parnames)
        dims = [None] * len(parnames)
# Now create columns from collected parameters, but first add the DATA
# column too, to contain the group data.
parnames.append("DATA")
bscales.append(self._header.get("BSCALE"))
        bzeros.append(self._header.get("BZERO"))
data_shape = self.shape[:-1]
formats.append(str(int(np.prod(data_shape))) + format)
        dims.append(data_shape)
parnames = _unique_parnames(parnames)
self._data_field = parnames[-1]
cols = [
Column(name=name, format=fmt, bscale=bscale, bzero=bzero, dim=dim)
for name, fmt, bscale, bzero, dim in zip(
                parnames, formats, bscales, bzeros, dims
)
]
coldefs = ColDefs(cols)
return coldefs
@property
def _nrows(self):
if not self._data_loaded:
# The number of 'groups' equates to the number of rows in the table
# representation of the data
return self._header.get("GCOUNT", 0)
else:
return len(self.data)
@lazyproperty
def _theap(self):
# Only really a lazyproperty for symmetry with _TableBaseHDU
return 0
@property
def is_image(self):
return False
@property
def size(self):
"""
Returns the size (in bytes) of the HDU's data part.
"""
size = 0
naxis = self._header.get("NAXIS", 0)
        # For a random-groups HDU, NAXIS1 is always 0, so skip it and take
        # the product starting from NAXIS2.
if naxis > 1:
size = 1
for idx in range(1, naxis):
size = size * self._header["NAXIS" + str(idx + 1)]
bitpix = self._header["BITPIX"]
gcount = self._header.get("GCOUNT", 1)
pcount = self._header.get("PCOUNT", 0)
size = abs(bitpix) * gcount * (pcount + size) // 8
return size
def _writedata_internal(self, fileobj):
"""
Basically copy/pasted from `_ImageBaseHDU._writedata_internal()`, but
we have to get the data's byte order a different way...
TODO: Might be nice to store some indication of the data's byte order
as an attribute or function so that we don't have to do this.
"""
size = 0
if self.data is not None:
self.data._scale_back()
# Based on the system type, determine the byteorders that
# would need to be swapped to get to big-endian output
if sys.byteorder == "little":
swap_types = ("<", "=")
else:
swap_types = ("<",)
# deal with unsigned integer 16, 32 and 64 data
if _is_pseudo_integer(self.data.dtype):
# Convert the unsigned array to signed
output = np.array(
self.data - _pseudo_zero(self.data.dtype),
dtype=f">i{self.data.dtype.itemsize}",
)
should_swap = False
else:
output = self.data
fname = self.data.dtype.names[0]
byteorder = self.data.dtype.fields[fname][0].str[0]
should_swap = byteorder in swap_types
if should_swap:
if output.flags.writeable:
output.byteswap(True)
try:
fileobj.writearray(output)
finally:
output.byteswap(True)
else:
# For read-only arrays, there is no way around making
# a byteswapped copy of the data.
fileobj.writearray(output.byteswap(False))
else:
fileobj.writearray(output)
size += output.size * output.itemsize
return size
def _verify(self, option="warn"):
errs = super()._verify(option=option)
# Verify locations and values of mandatory keywords.
self.req_cards(
"NAXIS", 2, lambda v: (_is_int(v) and 1 <= v <= 999), 1, option, errs
)
self.req_cards("NAXIS1", 3, lambda v: (_is_int(v) and v == 0), 0, option, errs)
after = self._header["NAXIS"] + 3
pos = lambda x: x >= after
self.req_cards("GCOUNT", pos, _is_int, 1, option, errs)
self.req_cards("PCOUNT", pos, _is_int, 0, option, errs)
self.req_cards("GROUPS", pos, lambda v: (v is True), True, option, errs)
return errs
def _calculate_datasum(self):
"""
Calculate the value for the ``DATASUM`` card in the HDU.
"""
if self._has_data:
# We have the data to be used.
# Check the byte order of the data. If it is little endian we
# must swap it before calculating the datasum.
# TODO: Maybe check this on a per-field basis instead of assuming
# that all fields have the same byte order?
byteorder = self.data.dtype.fields[self.data.dtype.names[0]][0].str[0]
if byteorder != ">":
if self.data.flags.writeable:
byteswapped = True
d = self.data.byteswap(True)
d.dtype = d.dtype.newbyteorder(">")
else:
# If the data is not writeable, we just make a byteswapped
# copy and don't bother changing it back after
d = self.data.byteswap(False)
d.dtype = d.dtype.newbyteorder(">")
byteswapped = False
else:
byteswapped = False
d = self.data
byte_data = d.view(type=np.ndarray, dtype=np.ubyte)
cs = self._compute_checksum(byte_data)
# If the data was byteswapped in this method then return it to
# its original little-endian order.
if byteswapped:
d.byteswap(True)
d.dtype = d.dtype.newbyteorder("<")
return cs
else:
# This is the case where the data has not been read from the file
# yet. We can handle that in a generic manner so we do it in the
# base class. The other possibility is that there is no data at
# all. This can also be handled in a generic manner.
return super()._calculate_datasum()
def _summary(self):
summary = super()._summary()
name, ver, classname, length, shape, format, gcount = summary
# Drop the first axis from the shape
if shape:
shape = shape[1:]
if shape and all(shape):
# Update the format
format = self.columns[0].dtype.name
# Update the GCOUNT report
gcount = f"{self._gcount} Groups {self._pcount} Parameters"
return (name, ver, classname, length, shape, format, gcount)


def _par_indices(names):
"""
Given a list of objects, returns a mapping of objects in that list to the
index or indices at which that object was found in the list.
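
    For example::

        >>> _par_indices(['a', 'b', 'A'])
        {'A': [0, 2], 'B': [1]}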
"""
unique = {}
for idx, name in enumerate(names):
# Case insensitive
name = name.upper()
if name in unique:
unique[name].append(idx)
else:
unique[name] = [idx]
return unique


def _unique_parnames(names):
"""
    Given a list of parnames, including possible duplicates, returns a new
    list of parnames in which each duplicate is prefixed with one or more
    underscores to make it unique. The comparison is case-insensitive.
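
    For example::

        >>> _unique_parnames(['abc', 'ABC', 'def'])
        ['abc', '_ABC', 'def']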
"""
upper_names = set()
unique_names = []
for name in names:
name_upper = name.upper()
while name_upper in upper_names:
name = "_" + name
name_upper = "_" + name_upper
unique_names.append(name)
upper_names.add(name_upper)
return unique_names