Source code for deepdisc.preprocessing.detection

import scarlet
import scarlet.wavelet
from astropy.stats import median_absolute_deviation as astromad
import numpy as np
import sep
import matplotlib.pyplot as plt
import time
from matplotlib.patches import Ellipse
from astropy.stats import sigma_clipped_stats


def mad_wavelet_own(image):
    """Median absolute deviation of the first starlet (wavelet) scale.

    astropy's MAD implementation is used because scipy's does not support
    ignoring NaN values while reducing over several axes at once.

    Parameters
    ----------
    image: array
        An image or cube of images

    Returns
    -------
    mad: array
        median absolute deviation for each image in the cube, divided by
        ``1 / 1.4826``
    """
    # Finest-scale coefficients of a 2-scale starlet decomposition.
    starlet = scarlet.Starlet.from_image(image, scales=2)
    first_scale = starlet.coefficients[0, ...]

    # MAD over the two trailing (image) axes, skipping NaN pixels.
    raw_mad = astromad(first_scale, axis=(-2, -1), ignore_nan=True)

    # Dividing by 1/1.4826 replicates the older scipy MAD behavior.
    legacy_scale = 1 / 1.4826
    return raw_mad / legacy_scale
def make_catalog(
    datas,
    lvl=4,
    wave=True,
    segmentation_map=False,
    maskthresh=10.0,
    object_limit=100000,
):
    """Create a sep detection catalog from one image cube or a (low-res, high-res) pair.

    Code adapted from
    https://pmelchior.github.io/scarlet/tutorials/wavelet_model.html

    Parameters
    ----------
    datas: ndarray or pair of Data objects
        Either a multichannel image array (summed over axis 0 to build the
        detection image), or a ``(data_lr, data_hr)`` pair of objects
        exposing an ``images`` attribute, which are combined after
        interpolating the low resolution data.
    lvl: int
        detection threshold passed to ``sep.extract``
    wave: Bool
        set to True to use wavelet decomposition of images before detection
    segmentation_map : Bool
        Whether to run sep segmentation map
    maskthresh : float
        Mask threshold for sep segmentation
    object_limit : int
        Limit on number of sub-objects sep may create when deblending

    Returns
    -------
    catalog: sextractor catalog
        catalog of detected sources (use 'catalog.dtype.names' for info).
        This is whatever ``sep.extract`` returns for the given
        ``segmentation_map`` setting.
    bg_rms: array or list
        background noise estimate for each channel (ndarray input) or for
        each data set (Data-object input)
    """
    is_array = isinstance(datas, np.ndarray)

    if is_array:
        # Detection image as the plain sum over all channels.
        detect_image = np.sum(datas, axis=0)
    else:
        data_lr, data_hr = datas
        # Interpolate low resolution to high resolution.
        # NOTE(review): `interpolate` is not defined or imported in the
        # visible part of this file -- confirm where it comes from.
        interp = interpolate(data_lr, data_hr)
        # Normalization of the interpolated low res images
        interp = interp / np.sum(interp, axis=(1, 2))[:, None, None]
        # Normalization of the high res data
        hr_images = data_hr.images / np.sum(data_hr.images, axis=(1, 2))[:, None, None]
        # Detection image as the sum over all images
        detect_image = np.sum(interp, axis=0) + np.sum(hr_images, axis=0)
        detect_image *= np.sum(data_hr.images)

    if detect_image.ndim == 4:
        if wave:
            # Wavelet detection: zero the last coefficient plane and
            # reconstruct the image from the remaining ones.
            wave_detect = scarlet.Starlet(detect_image.mean(axis=0), lvl=5).coefficients
            wave_detect[:, -1, :, :] = 0
            detect = scarlet.Starlet(coefficients=wave_detect).image
        else:
            # Direct detection
            detect = detect_image.mean(axis=0)
    elif wave:
        # Wavelet detection in the first three levels
        wave_detect = scarlet.Starlet.from_image(detect_image).coefficients
        detect = wave_detect[0] + wave_detect[1] + wave_detect[2]
    else:
        detect = detect_image

    bkg = sep.Background(detect)

    # Set the limit on the number of sub-objects when deblending.
    sep.set_sub_object_limit(object_limit)

    # Extract detection catalog (optionally with segmentation maps).
    # Can use this to retrieve ellipse params.
    catalog = sep.extract(
        detect,
        lvl,
        err=bkg.globalrms,
        segmentation_map=segmentation_map,
        maskthresh=maskthresh,
    )

    # Estimate background noise per channel / data set.
    # The original guard read `len(datas > 2)`, which measures the length of
    # a boolean array instead of comparing the length to 2, so it was true
    # for every non-empty array. Only the type check decides the branch.
    if is_array:
        # scarlet's own mad_wavelet no longer takes a bare ndarray, hence
        # the local re-implementation.
        bkg_rms = np.zeros(len(datas))
        for i, data in enumerate(datas):
            bkg_rms[i] = mad_wavelet_own(data)
    else:
        bkg_rms = [scarlet.wavelet.mad_wavelet(data.images) for data in datas]

    return catalog, bkg_rms
def fit_scarlet_blend(
    starlet_sources,
    observation,
    catalog,
    max_iters=15,
    e_rel=1e-4,
    plot_likelihood=True,
    savefigs=False,
    figpath="",
):
    """Build a scarlet ``Blend`` from the given sources and fit it.

    The fit runs for at most ``max_iters`` iterations but will end early if
    likelihood and constraints converge.

    Parameters
    ----------
    starlet_sources: list
        Sources passed to ``scarlet.Blend``
    observation:
        Observation passed to ``scarlet.Blend``
    catalog:
        Detection catalog; only its length is used, in the diagnostic
        message printed when the fit raises an ``AssertionError``
    max_iters: int
        Maximum number of fit iterations
    e_rel: float
        Relative error passed to ``Blend.fit``
    plot_likelihood, savefigs, figpath:
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    starlet_blend: scarlet.Blend
        The fitted blend
    logL:
        Second value returned by ``Blend.fit``

    Raises
    ------
    AssertionError
        Re-raised from scarlet (e.g. no detections) after printing the
        catalog length.
    """
    print("Fitting Blend model.")
    try:
        starlet_blend = scarlet.Blend(starlet_sources, observation)
        it, logL = starlet_blend.fit(max_iters, e_rel=e_rel)
        print(f"Scarlet ran for {it} iterations to logL = {logL}")
    except AssertionError:
        # Catch any exceptions like no detections; report and propagate.
        print(f"Length of detection catalog is {len(catalog)}.")
        raise

    return starlet_blend, logL
def run_scarlet(
    datas,
    filters,
    catalog=None,
    stretch=0.1,
    Q=5,
    sigma_model=1,
    sigma_obs=5,
    psf=None,
    subtract_background=False,
    max_chi2=5000,
    max_iters=15,
    morph_thresh=0.1,
    lvl=5,
    lvl_segmask=2,
    maskthresh=0.025,
    segmentation_map=True,
    wave_cat=False,
    plot_wavelet=False,
    plot_likelihood=True,
    plot_scene=False,
    plot_sources=False,
    add_ellipses=True,
    add_labels=False,
    add_boxes=False,
    percentiles=(1, 99),
    savefigs=False,
    figpath="",
    weights=None,
    return_models=True
):
    """Run P. Melchior's scarlet (https://github.com/pmelchior/scarlet)
    implementation for source separation.

    This function builds (or accepts) a source detection catalog, fits a
    scarlet model for all sources in the observation scene (image), derives
    a segmentation mask per source, and optionally creates diagnostic plots.

    Parameters
    ----------
    datas: ndarray
        multichannel array of data, shape (Nfilters x N x N)
    filters: list str
        list of filters, e.g. ('u','g','r')
    catalog: pandas df
        external source catalog used to initialize source positions; must
        provide "x" and "y" columns. If None, a catalog is generated with
        ``make_catalog`` (wavelet detection + sep).
    stretch, Q: float
        plotting parameters (currently unused; kept for compatibility)
    sigma_model: float
        sigma of the Gaussian scarlet model psf
    sigma_obs: float
        sigma of the Gaussian observational psf. Use the psf argument
        instead if you have an image of the psf
    psf: ndarray
        psf image array
    subtract_background : boolean
        If True (and no catalog was supplied), the observation weights are
        set to the inverse squared background rms per channel instead of 1.
        No background is subtracted from ``datas`` in this function.
        Default is False
    max_chi2: float
        currently unused; kept for compatibility
    max_iters: int
        max number of iterations for scarlet to fit
    morph_thresh: float
        parameter for source initialization
    lvl: int
        detection level for wavelet catalog detection
    lvl_segmask: int
        segmentation mask detection level
    maskthresh: float
        mask threshold forwarded to ``make_catalog`` for per-source detection
    segmentation_map : boolean
        Whether to build a segmentation mask for each source
    wave_cat : boolean
        Kept for backward compatibility. The catalog style passed to the
        scene plot is derived from whether the catalog was generated
        internally, so this argument has no effect.
    plot_wavelet : boolean
        Plot starlet wavelet transform and inverse transform at different
        scales. NOTE: Not really useful at large image sizes
        (> ~few hundred pixels length/height). Default is False
    plot_likelihood : boolean
        Forwarded to ``fit_scarlet_blend``. Default is True
    plot_scene : boolean
        Plot full scene with the rendered model, observation, and residual.
        Default is False.
    plot_sources : boolean
        Plot the rendered model and observation for each object.
        WARNING: Not advised to do this with a large image with many
        sources! Default is False
    add_ellipses, add_labels, add_boxes : boolean
        Plot decorations forwarded to the plotting helpers
    percentiles: tuple
        percentiles used for the plotting normalization
    savefigs: boolean
        Whether to save the figures that are plotted
    figpath: str
        Path prefix for saved figures
    weights: ndarray
        Observation weights, used only when ``catalog`` is supplied; when
        the catalog is generated internally the weights are recomputed.
    return_models: boolean
        If True, additionally run sep on each rendered source model. The
        returned tuple is the same either way.

    Return
    -------
    observation: scarlet function
        Scarlet observation objects
    starlet_sources: list
        List of scarlet source objects
    model_frame: scarlet function
        Image frame of source model
    catalog:
        The detection catalog that was used (supplied or generated)
    segmentation_masks: list
        List of segmentation mask of each object in image (empty when
        ``segmentation_map`` is False)
    """
    norm = scarlet.display.AsinhPercentileNorm(datas, percentiles=percentiles)

    # Define model frame and observations:
    model_psf = scarlet.GaussianPSF(sigma=sigma_model)
    model_frame = scarlet.Frame(datas.shape, psf=model_psf, channels=filters)

    spread = None
    # True when the catalog is a sep structured array generated here rather
    # than an externally supplied DataFrame; this drives how rows are
    # accessed below and in the scene plot.
    sep_catalog = catalog is None

    if sep_catalog:
        print('Generate source catalog using wavelets')
        t0 = time.time()
        catalog, bg_rms_hsc = make_catalog(datas, lvl, wave=True)
        print(time.time() - t0)
        # Uniform weights unless inverse-variance weighting was requested.
        if subtract_background:
            weights = np.ones_like(datas) / (bg_rms_hsc**2)[:, None, None]
        else:
            weights = np.ones_like(datas)

        print("Sep found ", len(catalog), "objects")
        # Source size relative to the observational PSF.
        Rs = np.sqrt(catalog["a"] ** 2 + catalog["b"] ** 2)
        spread = Rs / sigma_obs
        centers = [(src["y"], src["x"]) for src in catalog]  # y needs to be first
    else:
        print("Source catalog has ", len(catalog), "objects")
        centers = [
            (catalog.iloc[i]["y"], catalog.iloc[i]["x"]) for i in range(len(catalog))
        ]
        print(centers)

    # Plot wavelet transform at different scales
    # NOTE(review): `_plot_wavelet` is not defined in the visible part of
    # this file -- confirm it exists before enabling plot_wavelet.
    if plot_wavelet:
        _plot_wavelet(datas)

    if psf is None:
        observation_psf = scarlet.GaussianPSF(sigma=sigma_obs)
    else:
        observation_psf = scarlet.ImagePSF(psf)

    observation = scarlet.Observation(
        datas, psf=observation_psf, weights=weights, channels=filters
    ).match(model_frame)

    # Initialize starlet sources to be fit. Assume extended sources for all
    # because we are not looking at all detections in each image
    # TODO: Plot chi2 vs. binned size and mag. Implement condition if
    # chi2 > xxx then add another component until larger sources are
    # modeled well
    print("Initializing starlet sources to be fit.")
    t0 = time.time()

    starlet_sources, skipped = scarlet.initialization.init_all_sources(
        model_frame,
        centers,
        observation,
        max_components=5,
        thresh=morph_thresh,
        fallback=False,
        silent=True,
        set_spectra=False,
    )

    # Fallback: if nothing could be initialized, model every catalog entry
    # as a single-component ExtendedSource.
    if len(starlet_sources) == 0:
        if spread is not None:
            print("Modeling as extended sources")
            # Is the source compact relative to the PSF?
            compact_flags = [bool(s < 1) for s in spread]
        else:
            compact_flags = [False] * len(centers)

        for center, compact in zip(centers, compact_flags):
            new_source = scarlet.ExtendedSource(
                model_frame,
                center,
                observation,
                K=1,
                thresh=morph_thresh,
                compact=compact,
            )
            starlet_sources.append(new_source)

    # Fit scarlet blend
    starlet_blend, logL = fit_scarlet_blend(
        starlet_sources,
        observation,
        catalog,
        max_iters=max_iters,
        plot_likelihood=plot_likelihood,
        savefigs=savefigs,
        figpath=figpath,
    )
    print(time.time() - t0)

    # Extract the deblended catalog and build the segmentation masks
    print("Extracting deblended catalog.")

    catalog_deblended = []
    segmentation_masks = []

    # The clipped standard deviation depends only on `datas`, so compute it
    # once instead of once per source.
    std = sigma_clipped_stats(datas)[2] if segmentation_map else None

    for k, src in enumerate(starlet_sources):
        model = src.get_model(frame=model_frame)
        model = observation.render(model)
        # Compute in bbox only
        model = src.bbox.extract_from(model)

        # Run sep on the rendered model
        if return_models:
            try:
                cat, _ = make_catalog(
                    model,
                    lvl_segmask,
                    wave=False,
                    segmentation_map=False,
                    maskthresh=maskthresh,
                )
            except Exception:
                # Best effort: a failed detection on one model must not
                # abort the whole run (but Ctrl-C still propagates).
                print(f"Exception with source {k}")
                cat = []

            # If more than 1 source is detected for some reason
            # (e.g. artifacts), keep the brightest
            if len(cat) > 1:
                idx = np.argmax([c["cflux"] for c in cat])
                cat = cat[idx]

            # If failed to detect model source, fill with nan
            if len(cat) == 0:
                cat = [np.full(model.shape, np.nan, dtype=model.dtype)]

            # NOTE(review): catalog_deblended is collected but not returned.
            catalog_deblended.append(cat)

        if segmentation_map:
            # For some reason sep doesn't like these images, so do the
            # segmask ourselves for now: threshold the channel-summed model.
            model_det = np.sum(model, axis=0)
            mask = np.zeros_like(model_det)
            mask[model_det > lvl_segmask * std] = 1
            segmentation_masks.append(mask)

    # Plot scene: rendered model, observations, and residuals
    if plot_scene:
        _plot_scene(
            starlet_sources,
            observation,
            norm,
            catalog=catalog,
            # Use the flag computed above; the `wave_cat` argument defaulted
            # to False and sent sep catalogs down the DataFrame code path.
            wave_cat=sep_catalog,
            show_model=False,
            show_rendered=True,
            show_observed=True,
            show_residual=True,
            add_labels=add_labels,
            add_boxes=add_boxes,
            add_ellipses=add_ellipses,
            savefigs=savefigs,
            figpath=figpath,
        )

    # Plot each source
    if plot_sources:
        scarlet.display.show_sources(
            starlet_sources,
            observation,
            norm=norm,
            show_rendered=True,
            show_observed=True,
            add_boxes=add_boxes,
        )
        if savefigs:
            plt.savefig(figpath + "sources.png")
        plt.show()

    return (
        observation,
        starlet_sources,
        model_frame,
        catalog,
        segmentation_masks,
    )
def _plot_scene(
    starlet_sources,
    observation,
    norm,
    catalog,
    wave_cat=True,
    show_model=True,
    show_rendered=True,
    show_observed=True,
    show_residual=True,
    add_labels=True,
    add_boxes=True,
    add_ellipses=True,
    savefigs=False,
    figpath="",
):
    """Helper function to plot a scene with scarlet and overlay the catalog.

    Parameters
    ----------
    starlet_sources: List
        List of ScarletSource objects
    observation:
        Scarlet observation objects
    norm:
        Scarlet normalization for plotting
    catalog: list
        Source detection catalog
    wave_cat: bool
        If True the catalog is iterated directly and ellipses use each
        entry's 'a', 'b' and 'theta'; otherwise rows are read via
        ``catalog.iloc`` and fixed-size markers are drawn
    show_model: bool
        Whether to show model
    show_rendered: bool
        Whether to show rendered model
    show_observed: bool
        Whether to show observed
    show_residual: bool
        Whether to show residual
    add_labels: bool
        Whether to add labels
    add_boxes: bool
        Whether to add bounding boxes to each panel
    add_ellipses: bool
        Whether to add ellipses to each panel
    savefigs: bool
        Whether to save the figure as ``figpath + "scarlet_out.png"``
    figpath: str
        Path prefix for the saved figure

    Returns
    -------
    fig : matplotlib Figure
        Figure object
    """
    scene_options = dict(
        observation=observation,
        norm=norm,
        show_model=show_model,
        show_rendered=show_rendered,
        show_observed=show_observed,
        show_residual=show_residual,
        add_labels=add_labels,
        add_boxes=add_boxes,
    )
    fig = scarlet.display.show_scene(starlet_sources, **scene_options)

    draw_ellipses = add_ellipses == True

    for panel in fig.axes:
        if draw_ellipses:
            # Each panel needs its own patch objects, so build them lazily
            # per panel. See
            # https://sextractor.readthedocs.io/en/latest/Position.html
            if wave_cat:
                # sep ellipse around all sources from the detection catalog
                patches = (
                    Ellipse(
                        xy=(src["x"], src["y"]),
                        width=6 * src["a"],
                        height=6 * src["b"],
                        angle=np.rad2deg(src["theta"]),
                    )
                    for src in catalog
                )
            else:
                rows = (catalog.iloc[i] for i in range(len(catalog)))
                patches = (
                    Ellipse(xy=(row["x"], row["y"]), width=10, height=10, angle=0)
                    for row in rows
                )

            for patch in patches:
                patch.set_facecolor("none")
                patch.set_edgecolor("white")
                panel.add_artist(patch)

        panel.axis("off")

    fig.subplots_adjust(wspace=0.01)

    if savefigs:
        plt.savefig(figpath + "scarlet_out.png")

    plt.show()

    return fig