"""
Compare how astropy's and scipy's convolution routines handle missing data.

A crop of the ``galactic_center/gc_msx_e.fits`` image is loaded, its
brightest pixels are replaced with NaN to simulate a "saturated" data set,
and the result is smoothed with a Gaussian kernel in three ways:

1. ``scipy.ndimage.convolve`` on the NaN-containing image,
2. ``scipy.ndimage.convolve`` on a copy whose NaNs were set to zero, and
3. ``astropy.convolution.convolve``, which interpolates over the NaNs.

The four images are shown side by side, followed by a plot of the values
along one image column.
"""
import matplotlib.pyplot as plt
import numpy as np
from astropy.convolution import Gaussian2DKernel, convolve
from astropy.io import fits
from astropy.utils.data import get_pkg_data_filename
from scipy.ndimage import convolve as scipy_convolve

# Load the data from data.astropy.org. The context manager closes the FITS
# file when we are done; multiplying by 1e5 produces a new in-memory array,
# so ``img`` does not depend on the file staying open.
filename = get_pkg_data_filename('galactic_center/gc_msx_e.fits')
with fits.open(filename) as hdul:
    # Scale the data to have reasonable numbers (mostly so that colorbars do
    # not have too many digits), and crop it so individual pixels are visible.
    img = hdul[0].data[50:90, 60:100] * 1e5

# This example demonstrates how astropy.convolve and scipy.convolve handle
# missing data, so we start by setting the brightest pixels to NaN to
# simulate a "saturated" data set.
img[img > 20] = np.nan
nan_mask = np.isnan(img)

# A copy of the data with those NaNs set to zero, used for the second scipy
# convolution.
img_zerod = img.copy()
img_zerod[nan_mask] = 0

# Smooth with a Gaussian kernel with x_stddev=1 (and y_stddev=1).
# It is a 9x9 array.
kernel = Gaussian2DKernel(x_stddev=1)

# scipy's direct convolution spreads out NaNs (see panel 2 below).
# scipy operates on plain arrays, so pass the kernel's underlying ndarray.
scipy_conv = scipy_convolve(img, kernel.array)

# scipy's direct convolution on the zeroed image has no NaNs, but leaves
# very low-valued zones where the NaNs were (see panel 3 below).
scipy_conv_zerod = scipy_convolve(img_zerod, kernel.array)

# astropy's convolution replaces the NaN pixels with a kernel-weighted
# interpolation from their neighbors.
astropy_conv = convolve(img, kernel)

# ---------------------------------------------------------------------------
# Figure 1: the four images side by side, on a common color scale. In the
# first two panels the originally masked values are marked with red X's.
# ---------------------------------------------------------------------------
imshow_kwargs = dict(vmin=-2.0, vmax=20.0, origin='lower',
                     interpolation='nearest', cmap='viridis')

# (image, panel title, whether to mark the originally-NaN pixels)
panels = [
    (img, "Input Data", True),
    (scipy_conv, "Scipy convolved", True),
    (scipy_conv_zerod, "Scipy convolved (NaN to zero)", False),
    (astropy_conv, "Astropy convolved", False),
]

# Row/column pixel coordinates of the masked values.
nan_y, nan_x = np.where(nan_mask)

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(8, 8))
for axis, (data, title, mark_nans) in zip(axes.flat, panels):
    # Turning the axis off also hides its ticks and tick labels.
    axis.axis('off')
    axis.imshow(data, **imshow_kwargs)
    if mark_nans:
        axis.plot(nan_x, nan_y, 'rx', markersize=4)
    axis.set_title(title)

# tight_layout computes the subplot spacing itself, so no manual
# subplots_adjust call is needed beforehand.
fig.tight_layout()

# ---------------------------------------------------------------------------
# Figure 2: amplitude vs. pixel position along one column of every image,
# to more clearly illustrate the value differences.
# ---------------------------------------------------------------------------
cut_column = 25
line_kwargs = dict(drawstyle='steps-mid', linewidth=2, alpha=0.5)

plt.figure(2, figsize=(8, 6)).clf()
plt.plot(img[:, cut_column], label='Input data', **line_kwargs)
plt.plot(scipy_conv[:, cut_column], label='SciPy convolved',
         marker='s', **line_kwargs)
plt.plot(scipy_conv_zerod[:, cut_column],
         label='SciPy convolved (NaN to zero)', marker='s', **line_kwargs)
plt.plot(astropy_conv[:, cut_column], label='Astropy convolved',
         **line_kwargs)
plt.xlabel("Pixel")
plt.ylabel("Amplitude")
plt.legend(loc='best')
plt.show()