import logging
import math
import numpy as np
import numpy.polynomial.polynomial as poly
from astropy.stats import sigma_clipped_stats
from astropy.timeseries import LombScargle
from BayesicFitting import ConstantModel, Fitter, LevenbergMarquardtFitter, RobustShell, SineModel
from scipy.interpolate import pchip
from jwst.residual_fringe.fitter import spline_fitter
log = logging.getLogger(__name__)
# Number of knots for bkg model if no other info provided
# Hard coded parameter, has been selected based on testing but can be changed
NUM_KNOTS = 80
# Define some constants describing the two central fringe frequencies
# (primary fringe, and dichroic fringe) and a range around them to search for residual fringes,
# along with the maximum number of fringes to fit and the max amplitude allowed for those fringes.
FFREQ_1D = [2.9, 0.4]
DFFREQ_1D = [1.5, 0.15]
MAX_NFRINGES_1D = [10, 15]
MAXLINE_1D = 0.05
MAXAMP_1D = 0.4
__all__ = [
"slice_info",
"fill_wavenumbers",
"multi_sine",
"fit_envelope",
"find_lines",
"clip_spectral_features",
"check_res_fringes",
"interp_helper",
"fit_1d_background_complex",
"fit_1d_fringes_bayes_evidence",
"make_knots",
"fit_1d_background_complex_1d",
"fit_1d_fringes_bayes_evidence_1d",
"fit_residual_fringes_1d",
]
def slice_info(slice_map, channel):
"""
Identify pixels by slice.
Parameters
----------
slice_map : ndarray of int
2D image containing slice identification values by pixel.
Slice ID values are integers with the value 100 * channel number
+ slice number. Pixels not included in a slice have value 0.
channel : int
Channel number.
Returns
-------
slices_in_channel : ndarray of int
1D array of slice IDs included in the channel.
xrange_channel : ndarray of int
1D array with two elements: minimum and maximum x indices
for the channel.
slice_x_ranges : ndarray of int
N x 3 array for N slices, where the first column is the slice ID,
second column is the minimum x index for the slice,
and the third column is the maximum x index for the slice.
all_slice_masks : ndarray of int
N x nx x ny for N slices, matching the x and y shape of the
input slice_map. Values are 1 for pixels included in the slice,
0 otherwise.
"""
slice_inventory = np.unique(slice_map)
slices_in_channel = slice_inventory[
np.where((slice_inventory >= 100 * channel) & (slice_inventory < 100 * (channel + 1)))
]
log.info(f"Number of slices in channel {slices_in_channel.shape[0]} ")
slice_x_ranges = np.zeros((slices_in_channel.shape[0], 3), dtype=int)
all_slice_masks = np.zeros((slices_in_channel.shape[0], slice_map.shape[0], slice_map.shape[1]))
for n, s in enumerate(slices_in_channel):
# create a mask of the slice
pixels = np.where(slice_map == s)
slice_mask = np.zeros(slice_map.shape)
slice_mask[pixels] = 1
# add this to the all_slice_mask array
all_slice_masks[n] = slice_mask
# get the indices at the start and end of the slice
collapsed_slice = np.sum(slice_mask, axis=0)
indices = np.where(collapsed_slice[:-1] != collapsed_slice[1:])[0]
slice_x_ranges[n, 0], slice_x_ranges[n, 1], slice_x_ranges[n, 2] = (
int(s),
int(np.amin(indices)),
int(np.amax(indices) + 1),
)
log.debug(
f"For slice {slice_x_ranges[n, 0]} x ranges of slices "
f"region {slice_x_ranges[n, 1]}, {slice_x_ranges[n, 2]}"
)
log.debug(
"Min and max x pixel values of all slices "
f"in channel {np.amin(slice_x_ranges[:, 1])} {np.amax(slice_x_ranges[:, 2])}"
)
xrange_channel = np.zeros(2)
xrange_channel[0] = np.amin(slice_x_ranges[:, 1])
xrange_channel[1] = np.amax(slice_x_ranges[:, 2])
return slices_in_channel, xrange_channel, slice_x_ranges, all_slice_masks
def fill_wavenumbers(wnums):
"""
Fill in missing wavenumber values.
Given a wavenumber array with missing values (e.g., columns with
on-slice and off-slice pixels), fit the good points using a
polynomial, then use the coefficients to estimate wavenumbers
on the off-slice pixels.
Note that these new values are physically meaningless but having them
in the wavenum array stops the BayesicFitting package from crashing with
a LinAlgErr.
Parameters
----------
wnums : ndarray
The wavenumber array.
Returns
-------
wnums_filled : ndarray
The wavenumber array with off-slice pixels filled
"""
# set the off-slice pixels to nans and get their indices
wnums[wnums == 0] = np.nan
idx = np.isfinite(wnums)
# fit the on-slice wavenumbers
coefs = poly.polyfit(np.arange(wnums.shape[0])[idx], wnums[idx], 3)
wnums_filled = poly.polyval(np.arange(wnums.shape[0]), coefs)
# keep the original wnums for the on-slice pixels
wnums_filled[idx] = wnums[idx]
# clean
del idx, coefs
return wnums_filled
def multi_sine(n_sines):
"""
Create a multi-sine model.
Parameters
----------
n_sines : int
Number of sines to include.
Returns
-------
model : BayesicFitting.SineModel
The model composed of n sines.
"""
# make the first sine
mdl = SineModel()
# make a copy
model = mdl.copy()
# add the copy n - 1 times
for _ in range(1, n_sines):
mdl.addModel(model.copy())
# clean
del model
return mdl
def fit_envelope(wavenum, signal, check_extra_neighbors=False):
"""
Fit the upper and lower envelope of signal using a univariate spline.
Parameters
----------
wavenum : ndarray
Wavenumber values.
signal : ndarray
Signal values
check_extra_neighbors : bool, optional
If True, check two neighboring pixels instead of one, for
identifying peaks and troughs.
Returns
-------
lower_fit : ndarray
Fit to the lower envelope.
l_x : list
Input lower wavenum values.
l_y : list
Input lower signal values.
upper_fit : ndarray
Fit to the upper envelope.
u_x : list
Input upper wavenum values.
u_y : list
Input lower wavenum values.
"""
# Detect troughs and mark their location. Define endpoints
l_x = [wavenum[0]]
l_y = [signal[0]]
u_x = [wavenum[0]]
u_y = [signal[0]]
start = 2 if check_extra_neighbors else 1
for k in np.arange(start, len(signal) - start):
neighbor_check = [np.sign(signal[k] - signal[k - 1]), np.sign(signal[k] - signal[k + 1])]
if check_extra_neighbors:
neighbor_check.extend(
[np.sign(signal[k] - signal[k - 2]), np.sign(signal[k] - signal[k + 2])]
)
if all(n == -1 for n in neighbor_check):
# add to troughs: pixel is lower than its neighbors
l_x.append(wavenum[k])
l_y.append(signal[k])
elif all(n == 1 for n in neighbor_check):
# add to peaks: pixel is higher than its neighbors
u_x.append(wavenum[k])
u_y.append(signal[k])
# Append the last value of (s) to the interpolating values.
# This forces the model to use the same ending point
l_x.append(wavenum[-1])
l_y.append(signal[-1])
u_x.append(wavenum[-1])
u_y.append(signal[-1])
# fit a model
pcl = pchip(l_x, l_y)
pcu = pchip(u_x, u_y)
return pcl(wavenum), l_x, l_y, pcu(wavenum), u_x, u_y
def find_lines(signal, max_amp):
"""
Determine the location of large spectral features.
Parameters
----------
signal : ndarray
Signal data.
max_amp : ndarray
Maximum amplitude, by column. Features larger than
this value are flagged.
Returns
-------
weights : ndarray
1D array matching signal dimensions, containing 0 values
for large features and 1 values where no features were
detected.
"""
r_x = np.arange(signal.shape[0] - 1)
# setup the output arrays
signal_check = signal.copy()
weights_factors = np.ones(signal.shape[0])
# Detect peaks
u_y, u_x, l_y, l_x = [], [], [], []
for x in r_x:
# check for values near zero
if np.allclose(signal_check[x - 1 : x + 2], 0.0):
continue
# pixel is higher than immediate neighbors
if (np.sign(signal_check[x] - signal_check[x - 1]) == 1) and (
np.sign(signal_check[x] - signal_check[x + 1]) == 1
):
u_y.append(signal_check[x])
u_x.append(x)
# pixel is lower than immediate neighbors
if (np.sign(signal_check[x] - signal_check[x - 1]) == -1) and (
np.sign(signal_check[x] - signal_check[x + 1]) == -1
):
l_y.append(signal_check[x])
l_x.append(x)
for n, amp in enumerate(u_y):
max_amp_val = max_amp[u_x[n]]
if amp > max_amp_val:
# peak in x
xpeaks = [u_x[n] - 1, u_x[n], u_x[n] + 1]
# find nearest troughs
for xp in xpeaks:
log.debug(f"find_lines: checking ind {xp}")
try:
x1 = l_x[np.argsort(np.abs(l_x - xp))[0]]
try:
x2 = l_x[np.argsort(np.abs(l_x - xp))[1]]
if x1 < x2:
xlow = x1
xhigh = x2
if x1 > x2:
xhigh = x1
xlow = x2
except IndexError:
# raised if x1 is at the edge
xlow = x1
xhigh = x1
# set the weights to 0 and signal 1
log.debug("find_lines: setting weights between troughs to 0")
signal_check[xlow:xhigh] = 0
weights_factors[xlow:xhigh] = 0
except IndexError:
pass
log.debug(f"find_lines: Found {len(u_x)} peaks, {len(l_x)} troughs")
# catch any remaining signal significantly higher than the max amplitude
weights_factors[signal_check > 2 * max_amp] = 0
return weights_factors
def clip_spectral_features(signal, sigma=4.0):
"""
Clip out large spectral features.
Parameters
----------
signal : ndarray
Signal data. Input is expected to be the relative difference of the
spectrum from a smoothed spline fit to the lower envelope of the spectrum.
sigma : float, optional
Upper threshold for clipping features.
Returns
-------
weights : ndarray
1D array matching signal dimensions, containing 0 values
for large features and 1 values where no features were
detected.
"""
_, med, stddev = sigma_clipped_stats(signal, sigma_lower=np.inf, sigma_upper=sigma)
weights = (signal < (med + sigma * stddev)).astype(int)
return weights
def check_res_fringes(res_fringe_fit, max_amp):
"""
Check for regions with bad fringe fits.
Set the beat where this happens to 0 to avoid making the fringes worse.
Parameters
----------
res_fringe_fit : ndarray
The residual fringe fit.
max_amp : ndarray
The maximum amplitude array.
Returns
-------
res_fringe_fit: ndarray
The residual fringe fit with bad fit regions removed and replaced
with 0.
flags: ndarray
1D flag array indicating where the fit was altered, matching
the size of the first dimension of `res_fringe_fit`.
1 indicates a bad fit region; 0 indicates a good region, left
unchanged.
"""
flags = np.zeros(res_fringe_fit.shape[0])
# get fit envelope
npix = np.arange(res_fringe_fit.shape[0])
lenv_fit, _, _, uenv_fit, _, _ = fit_envelope(npix, res_fringe_fit)
# get the indices of the nodes (where uenv slope goes from
# negative to positive), add 0 and 1023
node_ind = [0]
for k in np.arange(1, len(uenv_fit) - 1):
if (np.sign(uenv_fit[k] - uenv_fit[k - 1]) == -1) and (
(np.sign(uenv_fit[k] - uenv_fit[k + 1])) == -1
):
node_ind.append(k)
node_ind.append(res_fringe_fit.shape[0] - 1)
node_ind = np.asarray(node_ind)
log.debug(f"check_res_fringes: found {len(node_ind)} nodes")
# find where res_fringes goes above max_amp
runaway_rfc = np.argwhere((np.abs(lenv_fit) + np.abs(uenv_fit)) > (max_amp))
# check which signal env the blow ups are located in and set to 1, and set a flag array
if len(runaway_rfc) > 0:
log.debug(f"check_res_fringes: {len(runaway_rfc)} data points exceed threshold")
log.debug("check_res_fringes: resetting fits to related beats")
for i in runaway_rfc:
# find where the index is compared to the nodes
node_loc = np.searchsorted(node_ind, i)
# set the res_fringes between the nodes to 0
lind = node_ind[node_loc - 1]
uind = node_ind[node_loc]
res_fringe_fit[lind[0] : uind[0]] = 0
flags[lind[0] : uind[0]] = 1 # set flag to 1 for reject fit region
return res_fringe_fit, flags
def interp_helper(mask):
"""
Create a convenience function for indexing low-weight values.
Low-weight is defined to be a value < 1e-5.
Parameters
----------
mask : ndarray
The 1D mask array (weights).
Returns
-------
index_array : ndarray of bool
Boolean index array for low weight pixels.
index_function : callable
A function, with signature indices = index_function(index_array),
to convert logical indices to equivalent direct index values.
"""
return mask < 1e-05, lambda z: z.nonzero()[0]
def fit_1d_background_complex(flux, weights, wavenum, ffreq=None, channel=1):
"""
Fit the background signal using a piecewise spline.
Note that this will also try to identify obvious emission lines
and flag them, so they aren't considered in the fitting.
Parameters
----------
flux : ndarray
1D array of fluxes.
weights : ndarray
1D array of weights.
wavenum : ndarray
1D array of wavenumbers.
ffreq : float, optional
The expected fringe frequency, used to determine number of knots.
If None, defaults to NUM_KNOTS constant
channel : int, optional
The channel to process. Used to determine if other arrays
need to be reversed given the direction of increasing
wavelength down the detector in MIRIFULONG.
Returns
-------
bg_fit : ndarray
The fitted background.
bgindx: ndarray
The location of the knots.
"""
# first get the weighted pixel fraction
weighted_pix_frac = (weights > 1e-05).sum() / flux.shape[0]
# define number of knots using fringe freq, want 1 knot per period
if ffreq is not None:
log.debug(f"fit_1d_background_complex: knot positions for {ffreq} cm-1")
nknots = int((np.amax(wavenum) - np.amin(wavenum)) / (ffreq))
else:
log.debug(f"fit_1d_background_complex: using num_knots={NUM_KNOTS}")
nknots = int((flux.shape[0] / 1024) * NUM_KNOTS)
log.debug(f"fit_1d_background_complex: number of knots = {nknots}")
# recale wavenums to around 1 for bayesicfitting
factor = np.amin(wavenum)
wavenum_scaled = wavenum.copy() / factor
# get number of fringe periods in array
nper = (np.amax(wavenum) - np.amin(wavenum)) // ffreq
log.debug(f"fit_1d_background_complex: column is {nper} fringe periods")
# now reduce by the weighted pixel fraction to see how many can be fitted
nper_cor = int(nper * weighted_pix_frac)
log.debug(f"fit_1d_background_complex: column has {nper_cor} weighted fringe periods")
# require at least 5 sine periods to fit
if nper < 5:
log.info(" not enough weighted data, no fit performed")
return flux.copy(), np.zeros(flux.shape[0]), None
bgindx = make_knots(flux.copy(), int(nknots), weights=weights.copy())
bgknots = wavenum_scaled[bgindx].astype(float)
# Reverse (and clip) the fit data as scipy/astropy need monotone increasing data for SW detector
if channel == 3 or channel == 4:
t = bgknots[1:-1]
x = wavenum_scaled
y = flux
w = weights
elif channel == 1 or channel == 2:
t = bgknots[::-1][1:-1]
x = wavenum_scaled[::-1]
y = flux[::-1]
w = weights[::-1]
else:
raise ValueError("channel not in 1-4")
# Fit the spline
if ffreq > 1.5:
bg_model = spline_fitter(x, y, w, t, 2, reject_outliers=True)
else:
# robust fitting causing problems for fringe 2 in channels 3 and 4,
# just use the fitter class
bg_model = spline_fitter(x, y, w, t, 1, reject_outliers=False)
# fit the background
bg_fit = bg_model(wavenum_scaled)
bg_fit *= np.where(weights.copy() > 1e-07, 1, 1e-08)
# linearly interpolate over the feature gaps if possible, stops issues later
try:
nz, z = interp_helper(weights)
bg_fit[nz] = np.interp(z(nz), z(~nz), bg_fit[~nz])
del nz, z
except ValueError:
pass
return bg_fit, bgindx
def fit_1d_fringes_bayes_evidence(
res_fringes, weights, wavenum, ffreq, dffreq, max_nfringes, pgram_res, col_snr2
):
"""
Fit the residual fringe signal.
Takes an input 1D array of residual fringes and fits using the
supplied mode in the BayesicFitting package.
Parameters
----------
res_fringes : ndarray
The 1D array with residual fringes.
weights : ndarray
The 1D array of weights
wavenum : ndarray
The 1D array of wavenum.
ffreq : float
The central scan frequency
dffreq : float
The one-sided interval of scan frequencies.
max_nfringes : int
The maximum number of fringes to check.
pgram_res : float
Resolution of the periodogram scan in cm-1.
col_snr2 : ndarray
Location of pixels with sufficient SNR to fit.
Returns
-------
res_fringe_fit : ndarray
The residual fringe fit data.
"""
# initialize output to none
res_fringe_fit = None
# get the number of weighted pixels
weighted_pix_num = (weights > 1e-05).sum()
# get scan res
res = np.around((2 * dffreq) / pgram_res).astype(int)
log.debug(f"fit_1d_fringes_bayes: scan res = {res}")
factor = np.amin(wavenum)
wavenum = wavenum.copy() / factor
ffreq = ffreq / factor
dffreq = dffreq / factor
# setup frequencies to scan
freq = np.linspace(ffreq - dffreq, ffreq + dffreq, res)
# handle out of slice pixels
res_fringes = np.nan_to_num(res_fringes)
res_fringes[res_fringes == np.inf] = 0
res_fringes[res_fringes == -np.inf] = 0
# initialise some parameters
res_fringes_proc = res_fringes.copy()
nfringes = 0
keep_dict = {}
fitted_frequencies = []
# get the initial evidence from ConstantModel
sdml = ConstantModel(values=1.0)
sftr = Fitter(wavenum, sdml)
_ = sftr.fit(res_fringes, weights=weights)
evidence1 = sftr.getEvidence(limits=[-3, 10], noiseLimits=[0.001, 10])
log.debug(f"fit_1d_fringes_bayes_evidence: Initial Evidence: {evidence1}")
for f in np.arange(max_nfringes):
log.debug(f"Starting fringe {f + 1}")
# get the scan arrays
weights *= col_snr2
res_fringe_scan = res_fringes_proc[np.where(weights > 1e-05)]
wavenum_scan = wavenum[np.where(weights > 1e-05)]
# use a Lomb-Scargle periodogram to get PSD and identify the strongest frequency
log.debug("fit_1d_fringes_bayes_evidence: get the periodogram")
pgram = LombScargle(wavenum_scan[::-1], res_fringe_scan[::-1]).power(1 / freq)
log.debug(
"fit_1d_fringes_bayes_evidence: get the most significant frequency in the periodogram"
)
peak = np.argmax(pgram)
freqs = 1.0 / freq[peak]
# fix the most significant frequency in the fixed dict that is passed to fitter
keep_ind = nfringes * 3
keep_dict[keep_ind] = freqs
log.debug(
f"fit_1d_fringes_bayes_evidence: creating multisine model of {nfringes + 1} freqs"
)
mdl = multi_sine(nfringes + 1)
# fit the multi-sine model and get evidence
if ffreq * factor > 1.5:
fitter = LevenbergMarquardtFitter(wavenum, mdl, verbose=0, keep=keep_dict)
ftr = RobustShell(fitter, domain=10)
try:
pars = ftr.fit(res_fringes, weights=weights)
# free the parameters and refit
mdl = multi_sine(nfringes + 1)
mdl.parameters = pars
fitter = LevenbergMarquardtFitter(wavenum, mdl, verbose=0)
ftr = RobustShell(fitter, domain=10)
pars = ftr.fit(res_fringes, weights=weights)
# try to get evidence (may fail for large component
# fits to noisy data, set to very negative value)
try:
evidence2 = fitter.getEvidence(limits=[-3, 10], noiseLimits=[0.001, 10])
except ValueError:
evidence2 = -1e9
except Exception:
evidence2 = -1e9
else:
fitter = LevenbergMarquardtFitter(wavenum, mdl, verbose=0, keep=keep_dict)
try:
pars = fitter.fit(res_fringes, weights=weights)
# free the parameters and refit
mdl = multi_sine(nfringes + 1)
mdl.parameters = pars
fitter = LevenbergMarquardtFitter(wavenum, mdl, verbose=0)
pars = fitter.fit(res_fringes, weights=weights)
# try to get evidence (may fail for large component
# fits to noisy data, set to very negative value)
try:
evidence2 = fitter.getEvidence(limits=[-3, 10], noiseLimits=[0.001, 10])
except ValueError:
evidence2 = -1e9
except Exception:
evidence2 = -1e9
log.debug(
f"fit_1d_fringes_bayes_evidence: nfringe={nfringes + 1} "
f"ev={evidence2} chi={fitter.chisq}"
)
bayes_factor = evidence2 - evidence1
log.debug(f"fit_1d_fringes_bayes_evidence: bayes factor={bayes_factor}")
if bayes_factor > 1:
# strong evidence threshold (log(bayes factor)>1, Kass and Raftery 1995)
evidence1 = evidence2
best_mdl = mdl.copy()
fitted_frequencies.append(freqs)
log.debug(
f"fit_1d_fringes_bayes_evidence: strong evidence for nfringes={nfringes + 1} "
)
else:
log.debug(f"fit_1d_fringes_bayes_evidence: no evidence for nfringes={nfringes + 1}")
break
# subtract the fringes for this frequency
res_fringe_fit = best_mdl(wavenum)
res_fringes_proc = res_fringes.copy() - res_fringe_fit
nfringes += 1
log.debug(f"fit_1d_fringes_bayes_evidence: optimal={nfringes} fringes")
# create outputs to return
fitted_frequencies = (1 / np.asarray(fitted_frequencies)) * factor
peak_freq = fitted_frequencies[0]
freq_min = np.amin(fitted_frequencies)
freq_max = np.amax(fitted_frequencies)
return res_fringe_fit, weighted_pix_num, nfringes, peak_freq, freq_min, freq_max
def make_knots(flux, nknots=20, weights=None):
"""
Define knot positions for piecewise models.
This function simply splits the array into sections. It does
NOT take into account the shape of the data.
Parameters
----------
flux : ndarray
The flux array or any array of the same dimension.
nknots : int, optional
The number of knots to create (excluding 0 and 1023).
weights : ndarray or None, optional
Optionally supply a weights array. This will be used to
add knots at the edge of bad pixels or features.
Returns
-------
knot_idx : ndarray
The indices of the knots.
"""
log.debug(f"make_knots: creating {nknots} knots on flux array")
# handle nans or infs that may exist
flux = np.nan_to_num(flux, posinf=1e-08, neginf=1e-08)
flux[flux < 0] = 1e-08
if weights is not None:
weights = np.nan_to_num(weights, posinf=1e-08, neginf=1e-08)
weights[weights < 0] = 1e-08
# create an array of indices
npoints = flux.shape[0]
# kstep is the number of points / number of knots
knot_step = npoints / nknots
# define an initial knot index array
init_knot_idx = np.zeros(nknots + 1)
for n in range(nknots):
init_knot_idx[n] = round(n * knot_step)
init_knot_idx[-1] = npoints - 1
# get difference between the indices
knot_split = np.ediff1d(init_knot_idx)
# the last diff will sometimes be different than the others
last_split = knot_split[-1]
# create new knot array with knots at 0 and 1023, then initial and final splits = last_split/2
knot_idx = np.ones(nknots + 2)
for n in range(nknots):
knot_idx[n + 1] = (n * knot_step) + (last_split / 2)
knot_idx[0] = 0
knot_idx[-1] = npoints - 1
# if the weights array is supplied, determine the edges of good data and set knots there
if weights is not None:
log.debug("make_knots: adding knots at edges of bad pixels in weights array")
# if there are bad pixels in the flux array with flux~0,
# add these to weights array if not already there
weights *= (flux > 1e-03).astype(int)
# use a two-point difference method
weights_diff = np.ediff1d(weights)
# set edges where diff should be almost equal to the largest of the two datapoints used
# iterate over the diffs and compare to the datapoints
edges_idx_list = []
for n, wd in enumerate(weights_diff):
# get the data points used for the diff
datapoints = np.array([weights[n], weights[n + 1]])
# get the value and index of the larges
largest = np.amax(datapoints)
largest_idx = np.argmax(datapoints)
# we don't need knots in the bad pixels so ignore these
if largest > 1e-03:
# check if the absolute values are almost equal
if math.isclose(largest, np.abs(wd), rel_tol=1e-01):
# if so, set the index and adjust depending on whether the
# first or second datapoint is the largest
idx = n + largest_idx
# check if this is right next to another index already defined
# causes problems in fitting, minimal difference
if (idx - 1 in knot_idx) | (idx + 1 in knot_idx):
pass
else:
# append to the index list
edges_idx_list.append(idx)
else:
pass
# convert the list to array, add to the knot_idx array, remove duplicates and sort
edges_idx = np.asarray(edges_idx_list)
knot_idx = np.sort(np.concatenate((knot_idx, edges_idx), axis=0), axis=0)
knot_idx = np.unique(knot_idx.astype(int))
return knot_idx.astype(int)
# The below functions were added to enable residual fringe correction
# in 1D extracted data.
def fit_1d_background_complex_1d(flux, weights, wavenum, ffreq=None):
"""
Fit the background signal using a piecewise spline of n knots.
Note that this will also try to identify obvious emission lines and
flag them so they aren't considered in the fitting.
Parameters
----------
flux : ndarray
The 1D array of fluxes.
weights : ndarray
The 1D array of weights.
wavenum : ndarray
The 1D array of wavenum.
ffreq : float or None, optional
The expected fringe frequency, used to determine number of knots.
If None, defaults to NUM_KNOTS constant.
Returns
-------
bg_fit : ndarray
The fitted background.
bgindx : ndarray
The location of the knots.
"""
# first get the weighted pixel fraction
weighted_pix_frac = (weights > 1e-05).sum() / flux.shape[0]
# define number of knots using fringe freq, want 1 knot per period
if ffreq is not None:
log.debug(f"fit_1d_background_complex: knot positions for {ffreq} cm-1")
nknots = int((np.amax(wavenum) - np.amin(wavenum)) / (ffreq))
else:
log.debug(f"fit_1d_background_complex: using num_knots={NUM_KNOTS}")
nknots = int((flux.shape[0] / 1024) * NUM_KNOTS)
log.debug(f"fit_1d_background_complex: number of knots = {nknots}")
# recale wavenums to around 1 for bayesicfitting
factor = np.amin(wavenum)
wavenum_scaled = wavenum.copy() / factor
# get number of fringe periods in array
nper = (np.amax(wavenum) - np.amin(wavenum)) // ffreq
log.debug(f"fit_1d_background_complex: column is {nper} fringe periods")
# now reduce by the weighted pixel fraction to see how many can be fitted
nper_cor = int(nper * weighted_pix_frac)
log.debug(f"fit_1d_background_complex: column has {nper_cor} weighted fringe periods")
# require at least 5 sine periods to fit
if nper < 5:
log.info(" not enough weighted data, no fit performed")
return flux.copy(), np.zeros(flux.shape[0]), None
bgindx = make_knots(flux.copy(), int(nknots), weights=weights.copy())
bgknots = wavenum_scaled[bgindx].astype(float)
# Reverse (and clip) the fit data as scipy/astropy need monotone increasing data for SW detector
t = bgknots[::-1][1:-1]
x = wavenum_scaled[::-1]
y = flux[::-1]
w = weights[::-1]
# Fit the spline
if ffreq > 1.5:
bg_model = spline_fitter(x, y, w, t, 2, reject_outliers=True)
else:
# robust fitting causing problems for fringe 2, change to just using fitter there
bg_model = spline_fitter(x, y, w, t, 1, reject_outliers=False)
# fit the background
bg_fit = bg_model(wavenum_scaled)
bg_fit *= np.where(weights.copy() > 1e-07, 1, 1e-08)
# linearly interpolate over the feature gaps if possible, stops issues later
try:
nz, z = interp_helper(weights)
bg_fit[nz] = np.interp(z(nz), z(~nz), bg_fit[~nz])
del nz, z
except ValueError:
pass
return bg_fit, bgindx
def fit_1d_fringes_bayes_evidence_1d(
res_fringes, weights, wavenum, ffreq, dffreq, max_nfringes, pgram_res
):
"""
Fit the residual fringe signal in 1D.
Takes an input 1D array of residual fringes and fits them using
the supplied mode in the BayesicFitting package.
Parameters
----------
res_fringes : ndarray
The 1D array with residual fringes.
weights : ndarray
The 1D array of weights.
wavenum : ndarray
The 1D array of wavenum.
ffreq : float
The central scan frequency.
dffreq : float
The one-sided interval of scan frequencies.
max_nfringes : int
The maximum number of fringes to check.
pgram_res : float
Resolution of the periodogram scan in cm-1.
Returns
-------
res_fringe_fit : ndarray
The residual fringe fit data.
"""
# initialize output to none
res_fringe_fit = None
# get the number of weighted pixels
weighted_pix_num = (weights > 1e-05).sum()
# get scan res
res = np.around((2 * dffreq) / pgram_res).astype(int)
factor = np.amin(wavenum)
wavenum = wavenum.copy() / factor
ffreq = ffreq / factor
dffreq = dffreq / factor
# setup frequencies to scan
freq = np.linspace(ffreq - dffreq, ffreq + dffreq, res)
# handle out of slice pixels
res_fringes = np.nan_to_num(res_fringes)
# initialise some parameters
res_fringes_proc = res_fringes.copy()
nfringes = 0
keep_dict = {}
fitted_frequencies = []
# get the initial evidence from ConstantModel
sdml = ConstantModel(values=1.0)
sftr = Fitter(wavenum, sdml)
_ = sftr.fit(res_fringes, weights=weights)
evidence1 = sftr.getEvidence(limits=[-2, 1000], noiseLimits=[0.001, 1])
for nfringe in range(max_nfringes):
log.debug(f"Fitting fringe {nfringe} of {max_nfringes} max")
# get the scan arrays
res_fringe_scan = res_fringes_proc[np.where(weights > 1e-05)]
wavenum_scan = wavenum[np.where(weights > 1e-05)]
# use a Lomb-Scargle periodogram to get PSD and identify the strongest frequency
pgram = LombScargle(wavenum_scan[::-1], res_fringe_scan[::-1]).power(1 / freq)
peak = np.argmax(pgram)
freqs = 1.0 / freq[peak]
# fix the most significant frequency in the fixed dict that is passed to fitter
keep_ind = nfringes * 3
keep_dict[keep_ind] = freqs
mdl = multi_sine(nfringes + 1)
# fit the multi-sine model and get evidence
fitter = LevenbergMarquardtFitter(wavenum, mdl, verbose=0, keep=keep_dict)
ftr = RobustShell(fitter, domain=10)
try:
pars = ftr.fit(res_fringes, weights=weights)
# free the parameters and refit
mdl = multi_sine(nfringes + 1)
mdl.parameters = pars
fitter = LevenbergMarquardtFitter(wavenum, mdl, verbose=0)
ftr = RobustShell(fitter, domain=10)
ftr.fit(res_fringes, weights=weights)
# try to get evidence (may fail with ValueError
# for large component fits to noisy data)
evidence2 = fitter.getEvidence(limits=[-2, 1000], noiseLimits=[0.001, 1])
except (ValueError, RuntimeError, np.linalg.LinAlgError) as err:
# set evidence to large negative value in case of failure
log.debug("Fringe fit failed: %s", str(err))
evidence2 = -1e9
bayes_factor = evidence2 - evidence1
if bayes_factor > 1: # strong evidence thresh (log(bayes factor)>1, Kass and Raftery 1995)
evidence1 = evidence2
best_mdl = mdl.copy()
fitted_frequencies.append(freqs)
else:
break
# subtract the fringes for this frequency
res_fringe_fit = best_mdl(wavenum)
res_fringes_proc = res_fringes.copy() - res_fringe_fit
nfringes += 1
# Check for any successful fit
if len(fitted_frequencies) == 0:
raise ValueError(f"Failed to fit any fringes for frequency {ffreq}")
# create outputs to return
fitted_frequencies = (1 / np.asarray(fitted_frequencies)) * factor
peak_freq = fitted_frequencies[0]
freq_min = np.amin(fitted_frequencies)
freq_max = np.amax(fitted_frequencies)
return res_fringe_fit, weighted_pix_num, nfringes, peak_freq, freq_min, freq_max
[docs]
def fit_residual_fringes_1d(
flux,
wavelength,
channel=1,
dichroic_only=False,
max_amp=None,
clip_features=True,
clip_sigma=5.0,
max_line=None,
ignore_regions=None,
):
"""
Fit residual fringes in 1D.
Parameters
----------
flux : ndarray
The 1D array of fluxes.
wavelength : ndarray
The 1D array of wavelengths.
channel : int, optional
The MRS spectral channel.
dichroic_only : bool, optional
Fit only dichroic fringes.
max_amp : float, optional
The maximum relative amplitude value for fringe correction. If not provided,
is set to ``MAXAMP_1D``.
clip_features : bool, optional
If True, spectral features are masked via sigma clipping. If False, they
are detected and masked via comparison to the ``max_line`` value.
clip_sigma : float, optional
If ``clip_features`` is True, then this value is used as the sigma threshold
for clipping spectral features.
max_line : float, optional
The maximum relative amplitude value to detect an emission line. If not provided,
is set to ``MAXLINE_1D``. Used only if ``clip_features`` is False.
ignore_regions : list of list of float, optional
If provided, data in the wavelengths specified is ignored in the fringe
fits. The expected format is a list of [min_region, max_region] values, in
input wavelength units.
Returns
-------
output : ndarray
Modified version of input flux array.
"""
# Restrict to just the non-zero positive fluxes
indx = np.where(flux > 0)
useflux = flux[indx]
usewave = wavelength[indx]
wavenum = 10000.0 / usewave
weights = useflux / np.nanmedian(useflux)
weights[weights == np.inf] = 0
weights[np.isnan(weights)] = 0
# Zero out any weights longward of 27.6 microns as the calibration is too uncertain
# and can bias the fringe finding
weights[usewave > 27.6] = 0
# Zero out weights for any user-specified regions
if ignore_regions is not None:
for region in ignore_regions:
weights[(usewave > region[0]) & (usewave < region[1])] = 0
if np.all(weights == 0):
log.warning("No good data. Skipping correction.")
return flux
# get the maxamp of the fringes
if max_amp is None:
max_amp_array = np.full(useflux.shape, MAXAMP_1D)
else:
max_amp_array = np.full(useflux.shape, max_amp)
if max_line is None:
max_line_array = np.full(useflux.shape, MAXLINE_1D)
else:
max_line_array = np.full(useflux.shape, max_line)
# find spectral features by comparing to a low-order fit to the middle of the spectrum
lenv, _, _, uenv, _, _ = fit_envelope(
np.arange(useflux.shape[0]), useflux, check_extra_neighbors=True
)
model = (lenv + uenv) / 2
mod = np.abs((useflux - model) / model)
# given signal in mod find location of lines
if clip_features:
weight_factors = clip_spectral_features(mod, sigma=clip_sigma)
else:
weight_factors = find_lines(mod, max_line_array)
weights_feat = weights * weight_factors
if dichroic_only is True:
if channel not in [3, 4]:
raise ValueError(
"Dichroic fringe should only be removed from channels 3 and 4, stopping!"
)
ffreq_vals = [FFREQ_1D[1]]
dffreq_vals = [DFFREQ_1D[1]]
max_nfringes_vals = [MAX_NFRINGES_1D[1]]
else:
# check the channel and remove second fringe for channels 1 and 2
if channel == 1 or channel == 2:
ffreq_vals = [FFREQ_1D[0]]
dffreq_vals = [DFFREQ_1D[0]]
max_nfringes_vals = [MAX_NFRINGES_1D[0]]
else:
ffreq_vals = FFREQ_1D
dffreq_vals = DFFREQ_1D
max_nfringes_vals = MAX_NFRINGES_1D
# BayesicFitting doesn't like 0s at data or weight array edges so set to small value
useflux[useflux <= 0] = 1e-08
weights_feat[weights_feat <= 0] = 1e-08
# check for off-slice pixels and send to be filled with interpolated/extrapolated wnums
# to stop BayesicFitting crashing, will not be fitted anyway
found_bad = np.logical_or(np.isnan(wavenum), np.isinf(wavenum))
num_bad = len(np.where(found_bad)[0])
if num_bad > 0:
wavenum[found_bad] = 0
wavenum = fill_wavenumbers(wavenum)
# do the processing
proc_data = useflux.copy()
for n, ffreq in enumerate(ffreq_vals):
log.debug(f"Fitting frequency {n} ({ffreq})")
bg_fit, bgindx = fit_1d_background_complex_1d(proc_data, weights_feat, wavenum, ffreq=ffreq)
# get the residual fringes as fraction of signal
res_fringes = np.divide(
proc_data, bg_fit, out=np.full_like(proc_data, 1e-8), where=bg_fit != 0
)
np.subtract(res_fringes, 1, out=res_fringes, where=res_fringes != 0)
res_fringes *= np.where(weights > 1e-07, 1, 1e-08)
# fit the residual fringes
try:
res_fringe_fit, wpix_num, opt_nfringes, peak_freq, freq_min, freq_max = (
fit_1d_fringes_bayes_evidence_1d(
res_fringes,
weights_feat,
wavenum,
ffreq,
dffreq_vals[n],
max_nfringes_vals[n],
0.001,
)
)
except ValueError as err:
log.warning(str(err))
continue
# check for fit blowing up, reset rfc fit to 0, raise a flag
res_fringe_fit, res_fringe_fit_flag = check_res_fringes(res_fringe_fit, max_amp_array)
# correct for residual fringes
rfc_factors = 1 / (res_fringe_fit * (weights > 1e-05).astype(int) + 1)
proc_data *= rfc_factors
# handle nans or infs that may exist
proc_data[proc_data == np.inf] = 0
proc_data = np.nan_to_num(proc_data)
# Embed output back in a full-size array
output = flux.copy()
output[indx] = proc_data
return output