# Source code for deepdisc.astrodet.scarlet_nocatalog

import os
import sys
import time

import astropy.io.fits as fits
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import scarlet
import scarlet.wavelet
import sep
from astropy.io import ascii
from astropy.stats import median_absolute_deviation as astromad
from astropy.visualization.lupton_rgb import AsinhMapping, LinearMapping
from astropy.wcs import WCS
from matplotlib.patches import Ellipse
from scipy.stats import median_abs_deviation as mad


def write_scarlet_results(
    datas,
    observation,
    starlet_sources,
    model_frame,
    catalog_deblended,
    segmentation_masks,
    dirpath,
    filters,
    s,
):
    """
    Save per-filter image and model FITS files (plus one segmentation-mask file).

    For every filter an image file (primary HDU = ``datas[i]``) and a model file
    (empty primary HDU followed by one ImageHDU per source, each carrying a
    header built from the source's bbox and deblended-catalog entry) are written.

    Parameters
    ----------
    datas: array
        Image cube indexed by channel; ``datas[i]`` is saved as the image of ``filters[i]``.
    observation: scarlet.Observation
        Observation used to render each source model.
    starlet_sources: list
        List of scarlet source objects.
    model_frame: scarlet.Frame
        Image frame of the source models.
    catalog_deblended: sequence
        Deblended catalog, one entry per source; fields "a", "b", "theta" are read.
    segmentation_masks: list or None
        Segmentation mask of each source. If None, no segmask file is written.
    dirpath : str
        Output directory.
    filters : list
        Filter names, e.g. ['g', 'r', 'i'].
    s : str
        File basename string.

    Returns
    -------
    filenames : dict
        Paths of the written files, keyed ``img_{F}``, ``model_{F}`` (F = upper-cased
        filter) and ``segmask``.
    """

    def _make_hdr(starlet_source, cat):
        """
        Build the FITS header describing one source.

        Parameters
        ----------
        starlet_source: scarlet source
            Source whose 3-D bbox (channel, y, x) is described.
        cat: structured array row
            Catalog entry for the source ("a", "b", "theta" are read).

        Returns
        -------
        model_hdr : astropy.io.fits.Header
            Header with keys BBOX ("x,y,w,h"), AREA, ELL_PARM ("a,b,theta") and CAT_ID.
        """
        bbox_h = starlet_source.bbox.shape[1]
        bbox_w = starlet_source.bbox.shape[2]
        # NOTE(review): both centre coordinates are offset by half the bbox *width*
        # (not the height for y). The format is kept unchanged so headers stay
        # consistent with files already written and the code that reads them.
        bbox_y = starlet_source.bbox.origin[1] + int(np.floor(bbox_w / 2))
        bbox_x = starlet_source.bbox.origin[2] + int(np.floor(bbox_w / 2))

        # Ellipse parameters (a, b, theta) from the deblend catalog
        ell_parm = np.concatenate((cat["a"], cat["b"], cat["theta"]))

        model_hdr = fits.Header()
        model_hdr["bbox"] = ",".join(map(str, [bbox_x, bbox_y, bbox_w, bbox_h]))
        model_hdr["area"] = bbox_w * bbox_h
        model_hdr["ell_parm"] = ",".join(map(str, list(ell_parm)))
        # TODO: set the category id based on whether the source is extended or not
        model_hdr["cat_id"] = 1
        return model_hdr

    filenames = {}
    pairs = list(zip(starlet_sources, catalog_deblended))

    # Render each source once: the rendered cube holds every channel, so the
    # filter loop below only needs to slice out channel i.
    rendered = []
    for src, _cat in pairs:
        full_model = observation.render(src.get_model(frame=model_frame))
        rendered.append(src.bbox.extract_from(full_model))

    # NOTE: model_hdul is intentionally NOT reset between filters, so the model
    # file of the n-th filter holds the models of that filter AND all preceding
    # ones (N, 2N, 3N, ... extensions). Readers of these files index into this
    # cumulative layout, so it is preserved.
    # TODO: write only the current band's models once all readers handle that.
    model_hdul = []
    for i, f in enumerate(filters):
        f = f.upper()
        img_hdu = fits.PrimaryHDU(data=datas[i])
        for (src, cat), model in zip(pairs, rendered):
            model_hdul.append(fits.ImageHDU(data=model[i], header=_make_hdr(src, cat)))

        save_img_hdul = fits.HDUList([img_hdu])
        # A fresh primary HDU per file; also avoids an unbound name with no sources.
        save_model_hdul = fits.HDUList([fits.PrimaryHDU(), *model_hdul])

        filenames[f"img_{f}"] = os.path.join(dirpath, f"{f}-{s}_scarlet_img.fits")
        save_img_hdul.writeto(filenames[f"img_{f}"], overwrite=True)

        filenames[f"model_{f}"] = os.path.join(dirpath, f"{f}-{s}_scarlet_model.fits")
        save_model_hdul.writeto(filenames[f"model_{f}"], overwrite=True)

    # A single segmentation-mask file, named after the first filter exactly as
    # given (not upper-cased). Previously this iterated over the *characters*
    # of filters[0], which misbehaved for multi-character filter names.
    if segmentation_masks is not None:
        first_filter = filters[0]
        segmask_hdul = [fits.PrimaryHDU()]
        for k, (src, cat) in enumerate(pairs):
            segmask_hdul.append(
                fits.ImageHDU(data=segmentation_masks[k], header=_make_hdr(src, cat))
            )
        filenames["segmask"] = os.path.join(
            dirpath, f"{first_filter}-{s}_scarlet_segmask.fits"
        )
        fits.HDUList(segmask_hdul).writeto(filenames["segmask"], overwrite=True)

    return filenames
def plot_stretch_Q(datas, stretches=(0.01, 0.1, 0.5, 1), Qs=(1, 10, 5, 100)):
    """
    Plot a grid of asinh normalizations of an image for different stretch/Q values.

    Parameters
    ----------
    datas : array
        Image cube passed to ``scarlet.display.img_to_rgb``.
    stretches : sequence of float
        Stretch values, one per grid row. Default is (0.01, 0.1, 0.5, 1).
    Qs : sequence of float
        Q values, one per grid column. Default is (1, 10, 5, 100).

    Code adapted from: https://pmelchior.github.io/scarlet/tutorials/display.html

    Returns
    -------
    fig : matplotlib Figure
    """
    # squeeze=False keeps `ax` 2-D even when a sequence has a single value,
    # so ax[i][j] indexing always works.
    fig, ax = plt.subplots(len(stretches), len(Qs), figsize=(9, 9), squeeze=False)
    for i, stretch in enumerate(stretches):
        for j, Q in enumerate(Qs):
            asinh = scarlet.display.AsinhMapping(minimum=0, stretch=stretch, Q=Q)
            # Scale the RGB channels for the image
            img_rgb = scarlet.display.img_to_rgb(datas, norm=asinh)
            ax[i][j].imshow(img_rgb)
            ax[i][j].set_title("Stretch {}, Q {}".format(stretch, Q))
            ax[i][j].axis("off")
    return fig
def make_catalog(
    datas,
    lvl=4,
    wave=True,
    segmentation_map=False,
    maskthresh=10.0,
    object_limit=100000,
):
    """
    Create a sep detection catalog from a (optionally wavelet-filtered) detection image.

    Parameters
    ----------
    datas: ndarray or (data_lr, data_hr) pair
        Either an image cube indexed by channel (the detection image is the sum
        over channels), or a low/high resolution pair whose elements expose an
        ``images`` attribute.
    lvl: float
        Detection threshold passed to ``sep.extract``; since ``err`` is given,
        it is relative to the global background rms of the detection image.
    wave: bool
        If True, detect on a starlet-filtered image (for a 2-D detection image:
        the sum of the first three starlet scales) instead of the raw image.
    segmentation_map : bool
        Passed to ``sep.extract``. When True, sep returns ``(objects, segmap)``
        and that pair is what is returned as ``catalog``.
    maskthresh : float
        Passed to ``sep.extract``.
    object_limit : int
        Passed to ``sep.set_sub_object_limit`` (a global sep setting) before extraction.

    Code adapted from https://pmelchior.github.io/scarlet/tutorials/wavelet_model.html

    Returns
    -------
    catalog: ndarray
        sep catalog of detected sources (use ``catalog.dtype.names`` for fields).
    bkg_rms: ndarray or list
        Per-channel background rms: ``mad_wavelet_own`` of each channel for
        ndarray input, otherwise ``scarlet.wavelet.mad_wavelet`` of each
        ``data.images``.
    """
    is_cube = isinstance(datas, np.ndarray)

    if is_cube:
        # Detection image as the sum over all channels
        detect_image = np.sum(datas, axis=0)
    else:
        data_lr, data_hr = datas
        # NOTE(review): `interpolate` is not defined or imported anywhere visible
        # in this file; confirm it is provided, otherwise this branch raises NameError.
        interp = interpolate(data_lr, data_hr)
        # Normalization of the interpolated low-res images
        interp = interp / np.sum(interp, axis=(1, 2))[:, None, None]
        # Normalization of the high-res data
        hr_images = data_hr.images / np.sum(data_hr.images, axis=(1, 2))[:, None, None]
        # Detection image as the sum over all images
        detect_image = np.sum(interp, axis=0) + np.sum(hr_images, axis=0)
        detect_image *= np.sum(data_hr.images)

    if detect_image.ndim == 4:
        if wave:
            # Wavelet detection: drop the coarsest scale before reconstructing
            wave_detect = scarlet.Starlet(detect_image.mean(axis=0), lvl=5).coefficients
            wave_detect[:, -1, :, :] = 0
            detect = scarlet.Starlet(coefficients=wave_detect).image
        else:
            # Direct detection
            detect = detect_image.mean(axis=0)
    else:
        if wave:
            wave_detect = scarlet.Starlet.from_image(detect_image).coefficients
            detect = wave_detect[0] + wave_detect[1] + wave_detect[2]
        else:
            detect = detect_image

    bkg = sep.Background(detect)
    # Set the limit on the number of sub-objects when deblending.
    sep.set_sub_object_limit(object_limit)
    catalog = sep.extract(
        detect,
        lvl,
        err=bkg.globalrms,
        segmentation_map=segmentation_map,
        maskthresh=maskthresh,
    )

    # Estimate the per-channel background. scarlet's mad_wavelet no longer takes
    # raw ndarrays, so cubes go through mad_wavelet_own channel by channel.
    # (The former condition `len(datas > 2)` was a misplaced-parenthesis typo
    # that evaluated to `len(datas)`; the fallback branch cannot handle ndarrays.)
    if is_cube:
        bkg_rms = np.zeros(len(datas))
        for i, data in enumerate(datas):
            bkg_rms[i] = mad_wavelet_own(data)
    else:
        bkg_rms = [scarlet.wavelet.mad_wavelet(data.images) for data in datas]

    return catalog, bkg_rms
def mad_wavelet_own(image):
    """
    Return the scaled median absolute deviation of the first starlet scale.

    (Not a wavelet for mad scientists, sorry.)

    astropy's MAD is used rather than scipy's because it can ignore NaNs while
    reducing over a tuple of axes.

    Parameters
    ----------
    image: array
        An image or a cube of images.

    Returns
    -------
    mad: array
        MAD of the first starlet scale over the last two axes, multiplied by
        1.4826 (one value per image in a cube).
    """
    # scarlet no longer accepts raw ndarrays for its own helper, so the starlet
    # transform is computed here and the finest scale (index 0) is kept.
    finest_scale = scarlet.Starlet.from_image(image, scales=2).coefficients[0, ...]
    raw_mad = astromad(finest_scale, axis=(-2, -1), ignore_nan=True)
    # Dividing by 1/1.4826 reproduces the older scipy MAD scaling.
    return raw_mad / (1 / 1.4826)
def fit_scarlet_blend(
    starlet_sources,
    observation,
    catalog,
    max_iters=15,
    e_rel=1e-4,
    plot_likelihood=True,
    savefigs=False,
    figpath="",
):
    """
    Fit a scarlet Blend of the given sources to an observation.

    Parameters
    ----------
    starlet_sources : list
        scarlet source objects to fit jointly.
    observation : scarlet.Observation
        Observation the blend is fit to.
    catalog : sized
        Detection catalog; only its length is used, in the message printed on failure.
    max_iters : int
        Maximum number of fit iterations; the fit may end earlier if the
        likelihood and constraints converge.
    e_rel : float
        Relative convergence tolerance passed to ``Blend.fit``.
    plot_likelihood : bool
        If True, plot log-likelihood vs. iteration.
    savefigs : bool
        If True (and plotting), save the likelihood plot.
    figpath : str
        Filename prefix; the plot is saved as ``figpath + "scarlet_likelihood.png"``.

    Returns
    -------
    starlet_blend : scarlet.Blend
        The fitted blend.
    logL : float
        Log-likelihood reported by ``Blend.fit``.

    Raises
    ------
    AssertionError
        Re-raised from scarlet (e.g. when there are no detections) after
        printing the catalog length.
    """
    print("Fitting Blend model.")
    try:
        starlet_blend = scarlet.Blend(starlet_sources, observation)
        it, logL = starlet_blend.fit(max_iters, e_rel=e_rel)
        print(f"Scarlet ran for {it} iterations to logL = {logL}")
    except AssertionError:
        # Typically raised when there is nothing to fit (no detections)
        print(f"Length of detection catalog is {len(catalog)}.")
        raise

    if plot_likelihood:
        scarlet.display.show_likelihood(starlet_blend)
        plt.ylabel("log-Likelihood", fontsize=15)
        plt.ticklabel_format(style="sci", axis="y", scilimits=(0, 0))
        plt.xlabel("Iteration", fontsize=15)
        plt.xticks(fontsize=13)
        plt.yticks(fontsize=13)
        plt.subplots_adjust(left=0.2)
        if savefigs:
            plt.savefig(figpath + "scarlet_likelihood.png")
        plt.show()

    return starlet_blend, logL
def _plot_wavelet(datas):
    """
    Show starlet-transform diagnostics for an image cube.

    The first figure shows the starlet coefficients of the first channel at
    every scale. The second shows the first channel's original image, its
    starlet reconstruction and their absolute difference.

    Parameters
    ----------
    datas: array
        Image cube passed to ``scarlet.Starlet.from_image``.

    Returns
    -------
    None
    """
    starlet = scarlet.Starlet.from_image(datas)
    coeffs = starlet.coefficients
    # Inverse transform, used to check that the original image is recovered
    reconstructed = starlet.image

    n_scales = coeffs.shape[1]
    coeff_fig = plt.figure(figsize=(n_scales * 5 + 5, 5))
    coeff_fig.suptitle("Wavelet coefficients")
    for scale in range(n_scales):
        ax = coeff_fig.add_subplot(1, n_scales, scale + 1)
        ax.set_title("scale" + str(scale + 1))
        im = ax.imshow(coeffs[0, scale], cmap="inferno")
        coeff_fig.colorbar(im, ax=ax)
    plt.show()

    panels = (
        ("Original image", datas[0]),
        ("Starlet-reconstructed image", reconstructed[0]),
        ("Absolute difference", np.abs(reconstructed[0] - datas[0])),
    )
    check_fig = plt.figure(figsize=(30, 10))
    for position, (title, img) in enumerate(panels, start=1):
        ax = check_fig.add_subplot(1, 3, position)
        ax.set_title(title, fontsize=20)
        im = ax.imshow(img, cmap="inferno")
        check_fig.colorbar(im, ax=ax)
    plt.show()
def _plot_scene(
    starlet_sources,
    observation,
    norm,
    catalog,
    show_model=True,
    show_rendered=True,
    show_observed=True,
    show_residual=True,
    add_labels=True,
    add_boxes=True,
    add_ellipses=True,
    savefigs=False,
    figpath="",
):
    """
    Plot a scarlet scene, optionally overlaying sep ellipses for every detection.

    Parameters
    ----------
    starlet_sources: list
        List of scarlet source objects.
    observation: scarlet.Observation
        Observation to display.
    norm: scarlet normalization
        Normalization used for plotting.
    catalog: sequence
        Detection catalog; fields "x", "y", "a", "b", "theta" are read.
    show_model, show_rendered, show_observed, show_residual: bool
        Which panels ``scarlet.display.show_scene`` draws.
    add_labels, add_boxes: bool
        Passed to ``scarlet.display.show_scene``.
    add_ellipses: bool
        Whether to draw an ellipse per catalog source on every panel.
    savefigs: bool
        If True, save the figure as ``figpath + "scarlet_out.png"``.
    figpath: str
        Filename prefix for the saved figure.

    Returns
    -------
    fig : matplotlib Figure
        Figure object.
    """
    fig = scarlet.display.show_scene(
        starlet_sources,
        observation=observation,
        norm=norm,
        show_model=show_model,
        show_rendered=show_rendered,
        show_observed=show_observed,
        show_residual=show_residual,
        add_labels=add_labels,
        add_boxes=add_boxes,
    )
    for ax in fig.axes:
        if add_ellipses:
            for src in catalog:
                # Full width/height = 2 * (3 * semi-axis).
                # See https://sextractor.readthedocs.io/en/latest/Position.html
                e = Ellipse(
                    xy=(src["x"], src["y"]),
                    width=6 * src["a"],
                    height=6 * src["b"],
                    angle=np.rad2deg(src["theta"]),
                )
                e.set_facecolor("none")
                e.set_edgecolor("white")
                # A separate artist per axes: an artist can belong to one axes only
                ax.add_artist(e)
        ax.axis("off")

    fig.subplots_adjust(wspace=0.01)
    if savefigs:
        # Save this figure explicitly rather than whatever pyplot considers current
        fig.savefig(figpath + "scarlet_out.png")
    plt.show()
    return fig
def run_scarlet(
    datas,
    filters,
    stretch=0.1,
    Q=5,
    sigma_model=1,
    sigma_obs=5,
    psf=None,
    subtract_background=False,
    max_chi2=5000,
    max_iters=15,
    morph_thresh=0.1,
    starlet_thresh=0.1,
    lvl=5,
    lvl_segmask=2,
    maskthresh=0.025,
    segmentation_map=True,
    plot_wavelet=False,
    plot_likelihood=True,
    plot_scene=False,
    plot_sources=False,
    add_ellipses=True,
    add_labels=False,
    add_boxes=False,
    percentiles=(1, 99),
    savefigs=False,
    figpath="",
    weights=None,
):
    """Run P. Melchior's scarlet (https://github.com/pmelchior/scarlet) on one image cube.

    The function does the following, producing optional diagnostic plots along the way:

    1. Detect sources with ``make_catalog``.
    2. Initialize one ``scarlet.ExtendedSource`` per detection.
    3. Fit the blend.
    4. Re-detect each rendered source model with sep to build a deblended
       catalog and, optionally, a binary segmentation mask per source.

    Parameters
    ----------
    datas : ndarray
        Image cube indexed by channel; ``datas.shape`` is the model frame shape.
    filters : list
        Channel names passed to ``scarlet.Frame`` / ``scarlet.Observation``.
    stretch, Q : float
        Currently unused (kept for API compatibility); the display norm is
        ``scarlet.display.AsinhPercentileNorm`` built from ``percentiles``.
    sigma_model : float
        Width of the Gaussian model PSF.
    sigma_obs : float
        Used to flag compact sources: a detection is compact when
        ``sqrt(a**2 + b**2) / sigma_obs < 1``.
    psf : array
        Observed PSF image, wrapped in ``scarlet.ImagePSF``.
    subtract_background : bool
        Only affects derived weights (when ``weights`` is None). If True, the
        weights are ``1 / bg_rms**2`` per channel; otherwise they are all ones.
        No background is subtracted from ``datas``.
    max_chi2, starlet_thresh : float
        Currently unused; the chi^2-based StarletSource refit is disabled.
    max_iters : int
        Maximum number of blend-fit iterations.
    morph_thresh : float
        ``thresh`` for ``scarlet.ExtendedSource``.
    lvl : float
        Detection threshold for the initial catalog.
    lvl_segmask : float
        Threshold, relative to each source model's background rms, used for
        re-detection and for the segmentation masks.
    maskthresh : float
        Passed to ``make_catalog`` during re-detection.
    segmentation_map : bool
        If True, build a binary segmentation mask per source.
    plot_wavelet, plot_likelihood, plot_scene, plot_sources : bool
        Toggle the respective diagnostic plots.
    add_ellipses, add_labels, add_boxes : bool
        Decorations for the scene / source plots.
    percentiles : tuple
        Percentiles for the display normalization.
    savefigs : bool
        Save produced figures with filename prefix ``figpath``.
    figpath : str
        Filename prefix for saved figures.
    weights : ndarray or None
        Observation weights. If None (default), they are derived as described
        under ``subtract_background``. (Previously this argument was ignored.)

    Returns
    -------
    observation, starlet_sources, model_frame, catalog, catalog_deblended, segmentation_masks
        ``catalog`` is the initial detection catalog. ``catalog_deblended`` stacks
        one re-detected entry per source (NaN-filled where re-detection found
        nothing). ``segmentation_masks`` is empty unless ``segmentation_map`` is True.
    """
    norm = scarlet.display.AsinhPercentileNorm(datas, percentiles=percentiles)

    # Generate the source catalog using wavelets
    t0 = time.time()
    catalog, bg_rms_hsc = make_catalog(datas, lvl, wave=True)
    print(time.time() - t0)

    # Honor caller-supplied weights; only derive them when none were given.
    if weights is None:
        if subtract_background:
            weights = np.ones_like(datas) / (bg_rms_hsc**2)[:, None, None]
        else:
            weights = np.ones_like(datas)

    print("Source catalog found ", len(catalog), "objects")

    if plot_wavelet:
        _plot_wavelet(datas)

    # Model frame and observation
    model_psf = scarlet.GaussianPSF(sigma=sigma_model)
    model_frame = scarlet.Frame(datas.shape, psf=model_psf, channels=filters)
    observation_psf = scarlet.ImagePSF(psf)
    observation = scarlet.Observation(
        datas, psf=observation_psf, weights=weights, channels=filters
    ).match(model_frame)

    # Model every detection as a single-component ExtendedSource.
    # TODO: plot chi2 vs. binned size and mag; add components (or switch to
    # StarletSource) for sources whose fit residual exceeds max_chi2.
    print("Initializing starlet sources to be fit.")
    spread = np.sqrt(catalog["a"] ** 2 + catalog["b"] ** 2) / sigma_obs

    t0 = time.time()
    starlet_sources = []
    for k, src in enumerate(catalog):
        # Is the source compact relative to the PSF?
        compact = bool(spread[k] < 1)
        starlet_sources.append(
            scarlet.ExtendedSource(
                model_frame,
                (src["y"], src["x"]),
                observation,
                K=1,
                thresh=morph_thresh,
                compact=compact,
            )
        )

    starlet_blend, logL = fit_scarlet_blend(
        starlet_sources,
        observation,
        catalog,
        max_iters=max_iters,
        plot_likelihood=plot_likelihood,
        savefigs=savefigs,
        figpath=figpath,
    )
    print(time.time() - t0)

    # Re-detect each rendered model to build the deblended catalog
    print("Extracting deblended catalog.")
    catalog_deblended = []
    segmentation_masks = []
    for k, src in enumerate(starlet_sources):
        model = observation.render(src.get_model(frame=model_frame))
        # Work inside the source's bbox only
        model = src.bbox.extract_from(model)
        model_det = np.sum(model, axis=0)
        bkgmod = sep.Background(model_det)

        try:
            cat, _ = make_catalog(
                model,
                lvl_segmask,
                wave=False,
                segmentation_map=False,
                maskthresh=maskthresh,
            )
        except Exception:
            # Best effort: a failed re-detection becomes a NaN row below.
            print(f"Exception with source {k}")
            cat = []

        # More than one detection (e.g. artifacts): keep the brightest
        if len(cat) > 1:
            idx = np.argmax([c["cflux"] for c in cat])
            cat = cat[idx : idx + 1]

        # Nothing detected: fill with NaN so every source keeps one row
        if len(cat) == 0:
            cat = [np.full(catalog[0].shape, np.nan, dtype=catalog.dtype)]

        if segmentation_map:
            # sep does not handle these model images well, so threshold manually
            mask = np.zeros_like(model_det)
            mask[model_det > lvl_segmask * bkgmod.globalrms] = 1
            segmentation_masks.append(mask)

        catalog_deblended.append(cat)

    # Combine into one named array
    catalog_deblended = np.vstack(catalog_deblended)

    if plot_scene:
        _plot_scene(
            starlet_sources,
            observation,
            norm,
            catalog,
            show_model=False,
            show_rendered=True,
            show_observed=True,
            show_residual=True,
            add_labels=add_labels,
            add_boxes=add_boxes,
            add_ellipses=add_ellipses,
            savefigs=savefigs,
            figpath=figpath,
        )

    if plot_sources:
        scarlet.display.show_sources(
            starlet_sources,
            observation,
            norm=norm,
            show_rendered=True,
            show_observed=True,
            add_boxes=add_boxes,
        )
        if savefigs:
            plt.savefig(figpath + "sources.png")
        plt.show()

    return (
        observation,
        starlet_sources,
        model_frame,
        catalog,
        catalog_deblended,
        segmentation_masks,
    )
def overlapped_slices(bbox1, bbox2):
    """Return the slices of ``bbox1`` and ``bbox2`` that cover their overlap.

    Parameters
    ----------
    bbox1: `~scarlet.bbox.Box`
    bbox2: `~scarlet.bbox.Box`

    Returns
    -------
    slices: tuple of slices
        ``(slices_in_bbox1, slices_in_bbox2)``. These are the overlapping region
        expressed in the local coordinates of an array bounded by ``bbox1`` and
        of an array bounded by ``bbox2``, respectively.
    """
    shared = bbox1 & bbox2
    # Shift the shared region into each box's local frame
    local_to_first = shared - bbox1.origin
    local_to_second = shared - bbox2.origin
    return local_to_first.slices, local_to_second.slices
def _parse_processed_filename(filename, stringcap=14):
    """Return ``(tract, patch, sp)`` parsed from a processed scarlet file name.

    The basename is expected to look like ``G-{tract}-{p0},{p1}-c{sp}{suffix}.fits``,
    where ``suffix`` is ``stringcap`` characters long. For example,
    ``_scarlet_model`` is 14 characters, which matches the default.
    Only the basename is parsed, so directories containing "G-" do not interfere.
    """
    s = os.path.basename(filename).split("G-")[1].split(".fits")[0]
    tract, patch, sp = s.split("-")
    # len(sp) - stringcap (not -stringcap) so that stringcap=0 keeps everything
    return int(tract), tuple(map(int, patch.split(","))), int(sp[1 : len(sp) - stringcap])


def get_processed_hsc_DR3_data(
    filename,
    filters=("g", "r", "i"),
    dirpath="/home/g4merz/deblend/data/processed_HSC_DR3/lvl5/",
    stringcap=14,
):
    """
    Load the processed (scarlet) image of every filter for one HSC DR3 cutout.

    Parameters
    ----------
    filename : str
        Path of the G-band processed file; tract, patch and cutout index are
        parsed from its basename (``G-{tract}-{p0},{p1}-c{sp}...fits``).
    filters : sequence of str
        Filters to load (upper-cased for the file names). Default is ('g', 'r', 'i').
    dirpath : str
        Directory containing ``{F}-{tract}-{p0},{p1}-c{sp}_scarlet_img.fits`` files.
    stringcap : int
        Number of characters between the cutout index and ``.fits`` in ``filename``.

    Returns
    -------
    data : ndarray or None
        Array stacking the primary-HDU data of each filter in order (presumably
        ``[filters, N, N]`` when all images share a shape), or None if any
        filter's file could not be opened.
    """
    tract, patch, sp = _parse_processed_filename(filename, stringcap)

    datas = []
    for f in (band.upper() for band in filters):
        impath = os.path.join(dirpath, f"{f}-{tract}-{patch[0]},{patch[1]}-c{sp}_scarlet_img.fits")
        try:
            with fits.open(impath) as obs_hdul:
                datas.append(obs_hdul[0].data)
        except OSError:
            # Missing (or unreadable) file for this filter
            print("Missing filter ", f)
            return None

    return np.array(datas)
def _bbox_shape_origin(bbox_str):
    """Convert a ``"x,y,w,h"`` BBOX header value into ``(shape, origin)`` for a 2-D Box.

    The writer stores ``x = origin_x + floor(w/2)`` and ``y = origin_y + floor(w/2)``,
    i.e. half the *width* for both axes. The origin is therefore recovered by
    subtracting ``floor(w/2)`` from both coordinates. Subtracting ``floor(h/2)``
    from x misplaced non-square boxes.
    """
    x, y, w, h = (int(v) for v in bbox_str.split(","))
    half_w = int(np.floor(w / 2))
    return [h, w], [y - half_w, x - half_w]


def return_model_objects(
    fiG,
    luptonize=False,
    stringcap=14,
    dirpath="/home/shared/hsc/HSC/HSC_DR3/data/train/",
):
    """
    Load per-source G/R/I scarlet models and assemble a full-frame model per band.

    Parameters
    ----------
    fiG : str
        Path of the G-band ``*_scarlet_model.fits`` file; tract, patch and cutout
        index are parsed from it to locate the R and I files in ``dirpath``.
    luptonize : bool
        Currently unused.
    stringcap : int
        Number of characters between the cutout index and ``.fits`` in ``fiG``.
    dirpath : str
        Directory holding the R/I model files and the processed image files.

    Returns
    -------
    objectsG, objectsR, objectsI : list of ndarray
        Per-source model cutouts for each band.
    model : ndarray
        Array shaped like the processed image cube; ``model[i]`` is the sum of
        all source cutouts of band i placed at their bounding boxes.
    """
    s = fiG.split("G-")[1].split(".fits")[0]
    tract, patch, sp = s.split("-")
    tract = int(tract)
    patch = tuple(map(int, patch.split(",")))
    sp = int(sp[1:-stringcap])

    fiR = os.path.join(dirpath, f"R-{tract}-{patch[0]},{patch[1]}-c{sp}_scarlet_model.fits")
    fiI = os.path.join(dirpath, f"I-{tract}-{patch[0]},{patch[1]}-c{sp}_scarlet_model.fits")

    # NOTE(review): get_processed_hsc_DR3_data returns None if an image file is
    # missing; `d.shape` then raises AttributeError.
    d = get_processed_hsc_DR3_data(fiG, dirpath=dirpath, stringcap=stringcap)
    fb = scarlet.bbox.Box(d.shape[1:])
    model = np.zeros(d.shape)

    # The G file holds one primary HDU plus one HDU per source.
    with fits.open(fiG) as sourcesG:
        ls = len(sourcesG)
    n_src = ls - 1

    objects = ([], [], [])
    for i, file in enumerate([fiG, fiR, fiI]):
        bandmodel = np.zeros(d.shape[1:])
        with fits.open(file) as sources:
            # Files may hold only this band's sources (same length as the G
            # file), or the cumulative layout where band i's sources follow
            # those of all earlier bands.
            if len(sources) == ls:
                lo = 1
            else:
                lo = 1 + i * n_src
            for src in sources[lo : lo + n_src]:
                srcmodel = src.data
                objects[i].append(srcmodel)
                shape, origin = _bbox_shape_origin(src.header["BBOX"])
                mb = scarlet.bbox.Box(shape, origin)
                # Add only the part of the cutout that falls inside the frame
                frame_slices, model_slices = overlapped_slices(fb, mb)
                bandmodel[frame_slices] += srcmodel[model_slices]
        model[i] = bandmodel

    objectsG, objectsR, objectsI = objects
    return objectsG, objectsR, objectsI, model
def return_spliced_sources(sourceG, sourceR, sourceI):
    """
    Center-crop the G, R and I cutouts of one source to a common shape.

    Each cutout is trimmed along its first two axes to the smallest extent
    found among the three inputs. When the excess along an axis is odd, the
    start loses ``excess // 2`` and the end loses the remainder.

    Parameters
    ----------
    sourceG, sourceR, sourceI : ndarray
        Per-band cutouts with at least two dimensions.

    Returns
    -------
    list of ndarray
        ``[G, R, I]`` cutouts (slices of, or the same objects as, the inputs),
        all sharing the same first two dimensions.
    """
    bands = (sourceG, sourceR, sourceI)
    n_rows = min(band.shape[0] for band in bands)
    n_cols = min(band.shape[1] for band in bands)

    def _center_crop(arr):
        row_excess = arr.shape[0] - n_rows
        if row_excess > 0:
            top = row_excess // 2
            arr = arr[top : arr.shape[0] - (row_excess - top), :]
        col_excess = arr.shape[1] - n_cols
        if col_excess > 0:
            left = col_excess // 2
            arr = arr[:, left : arr.shape[1] - (col_excess - left)]
        return arr

    return [_center_crop(band) for band in bands]