"""
Apply linearity correction to a science image
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
import numpy as np
from roman_datamodels import datamodels as rdd
from roman_datamodels.dqflags import group, pixel
from stcal.linearity.linearity import linearity_correction
from romancal.datamodels.fileio import open_dataset
from romancal.stpipe import RomanStep
if TYPE_CHECKING:
from typing import ClassVar
__all__ = ["LinearityStep"]
log = logging.getLogger(__name__)
def make_inl_correction(inl_model, ncols, *, channel_width=128):
    """
    Create a callable for integral nonlinearity correction.

    The detector columns are split into consecutive channels of
    ``channel_width`` columns each. Channel ``k`` (1-based) is corrected
    with ``inl_model.inl_table.science_channel_{k:02d}.correction``,
    linearly interpolated (via `numpy.interp`) at the data values against
    the shared abscissae ``inl_model.value``.

    Parameters
    ----------
    inl_model : datamodel
        The integral nonlinearity reference file model.
    ncols : int
        Number of columns in the data, used to determine which channels
        to extract.
    channel_width : int, optional
        Number of columns per channel. Defaults to 128.

    Returns
    -------
    callable
        A function that takes an array whose last axis is the column axis
        (e.g. ``(nreads, nrows, ncols)``) and returns a correction array of
        the same shape and dtype, to be added to the data.

    Raises
    ------
    ValueError
        If ``channel_width`` is not positive.
    """
    if channel_width <= 0:
        raise ValueError(f"channel_width must be positive, got {channel_width}")

    # Copy everything the returned closure needs, so that it keeps working
    # after the caller closes the reference model.
    lookup_values = inl_model.value.copy().astype("f4")

    # Map 1-based channel number -> correction table for that channel.
    channel_corrections = {}
    for start_col in range(0, ncols, channel_width):
        channel_num = start_col // channel_width + 1
        attr_name = f"science_channel_{channel_num:02d}"
        channel_corrections[channel_num] = getattr(
            inl_model.inl_table, attr_name
        ).correction.copy()

    def inl_correction(data):
        """Return the INL correction for ``data`` (same shape and dtype)."""
        result = np.zeros_like(data)
        for start_col in range(0, data.shape[-1], channel_width):
            channel_num = start_col // channel_width + 1
            correction = channel_corrections[channel_num]
            cols = slice(start_col, start_col + channel_width)
            # np.interp accepts an x array of any shape and evaluates
            # element-wise against the 1-D lookup table.
            result[..., cols] = np.interp(data[..., cols], lookup_values, correction)
        return result

    return inl_correction
[docs]
class LinearityStep(RomanStep):
    """
    LinearityStep: This step performs a correction for non-linear
    detector response, using the "classic" polynomial method.

    The LINEARITY and INVERSELINEARITY reference files are required; if
    either is unavailable the step is skipped. The INTEGRALNONLINEARITY
    reference file is optional and, when present, supplies an additional
    per-channel correction to stcal.
    """

    class_alias = "linearity"

    reference_file_types: ClassVar = [
        "linearity",
        "inverselinearity",
        "integralnonlinearity",
    ]

    def process(self, dataset):
        """
        Apply the linearity correction.

        Parameters
        ----------
        dataset : str or datamodel
            Input ramp; anything accepted by ``open_dataset``.

        Returns
        -------
        datamodel
            The processed model. ``meta.cal_step["linearity"]`` is set to
            "SKIPPED" when a required reference file is missing, otherwise
            to "COMPLETE".

        Notes
        -----
        Stores the resolved reference file names on ``self.lin_name``,
        ``self.ilin_name`` and ``self.inl_name``.
        """
        input_model = open_dataset(dataset, update_version=self.update_version)

        # Get reference file names
        self.lin_name = self.get_reference_file(input_model, "linearity")
        self.ilin_name = self.get_reference_file(input_model, "inverselinearity")
        self.inl_name = self.get_reference_file(input_model, "integralnonlinearity")
        log.info("Using LINEARITY reference file: %s", self.lin_name)
        log.info("Using INVERSELINEARITY reference file: %s", self.ilin_name)
        log.info("Using INTEGRALNONLINEARITY reference file: %s", self.inl_name)

        # Both the forward and inverse linearity files are required.
        if self.lin_name == "N/A" or self.ilin_name == "N/A":
            log.warning("No LINEARITY or INVERSELINEARITY reference file found")
            log.warning("Linearity step will be skipped")
            input_model.meta.cal_step["linearity"] = "SKIPPED"
            return input_model

        # INL correction is optional. make_inl_correction copies the tables
        # it needs, so the callable stays usable after the model is closed.
        inl_correction = None
        if self.inl_name != "N/A":
            with rdd.open(self.inl_name) as inl_model:
                inl_correction = make_inl_correction(
                    inl_model, input_model.data.shape[-1]
                )

        with (
            rdd.LinearityRefModel(self.lin_name, memmap=False) as lin_model,
            rdd.InverselinearityRefModel(self.ilin_name, memmap=False) as ilin_model,
        ):
            lin_coeffs = lin_model.coeffs
            lin_dq = lin_model.dq
            ilin_coeffs = ilin_model.coeffs

        read_pattern = input_model.meta.exposure.read_pattern

        # stcal operates on 4-D ramps: add a leading length-1 axis here and
        # drop it again from the result. The 4-D views are kept local rather
        # than assigned to the model, so the model never holds a reshaped
        # array (e.g. if stcal raises).
        data = input_model.data[np.newaxis, :]
        gdq = input_model.groupdq[np.newaxis, :]
        pdq = input_model.pixeldq

        # Call linearity correction function in stcal
        # The third return value is the processed zero frame which
        # Roman does not use.
        new_data, new_pdq, _ = linearity_correction(
            data,
            gdq,
            pdq,
            lin_coeffs,
            lin_dq,
            pixel,
            ilin_coeffs=ilin_coeffs,
            additional_correction=inl_correction,
            read_pattern=read_pattern,
        )
        input_model.data = new_data[0, :, :, :]
        input_model.pixeldq = new_pdq

        nbad = self._flag_implausible_values(input_model)
        # Only warn when something was actually flagged.
        if nbad:
            log.warning(
                "Flagged %d spurious values outside remotely plausible range.", nbad
            )

        # Update the step status
        input_model.meta.cal_step["linearity"] = "COMPLETE"

        if self.save_results:
            try:
                self.suffix = "linearity"
            except AttributeError:
                self["suffix"] = "linearity"

        return input_model

    @staticmethod
    def _flag_implausible_values(model, limit=1e6):
        """
        Clip and flag corrected values whose magnitude exceeds ``limit``.

        FIXME: force all values in array to be at least vaguely sane.
        This should not happen for good linearity corrections and linearity
        correction flagging, but current reference files have issues that
        cause more problems downstream. Full well is 65k DN; after linearity
        correction we can't be more than a factor of several away from this.

        Parameters
        ----------
        model : datamodel
            Model with ``data`` and ``groupdq`` arrays of the same shape;
            both are modified in place.
        limit : float, optional
            Magnitude threshold. Values with ``abs(value) > limit`` are
            clipped to ``[-limit, limit]`` and marked DO_NOT_USE.

        Returns
        -------
        int
            Number of values flagged.
        """
        bad = np.abs(model.data) > limit
        model.data[bad] = np.clip(model.data[bad], -limit, limit)
        model.groupdq[bad] |= group.DO_NOT_USE
        return int(np.count_nonzero(bad))