Source code for deepdisc.inference.match_objects

import numpy as np
import torch
from detectron2 import structures
from detectron2.structures import BoxMode
from astropy.wcs import WCS

import deepdisc.astrodet.astrodet as toolkit
from deepdisc.inference.predictors import get_predictions, get_predictions_new


def get_matched_object_inds(dataset_dict, outputs, IOUthresh=0.5):
    """Return indices of matched ground-truth / detection pairs in one image.

    Every detection is paired with the ground-truth box it overlaps most
    (highest IOU).  The pair is kept when that IOU is at least ``IOUthresh``.
    Detections are matched independently of each other, so the same
    ground-truth index can appear more than once when several detections
    overlap the same object.

    Parameters
    ----------
    dataset_dict : dictionary
        The dictionary metadata for a single image.  Each entry of
        ``dataset_dict["annotations"]`` must provide a ``"bbox"`` in XYWH
        (absolute) mode, because the boxes are converted from that mode here.
    outputs : dictionary
        The model output for the image; ``outputs["instances"].pred_boxes``
        holds the detected boxes.
    IOUthresh : float
        The IOU threshold used to match detections and ground truth

    Returns
    -------
    matched_gts: list(int)
        The indices of matched objects in the ground truth list
    matched_dts: list(int)
        The indices of matched objects in the detections list
    """
    annotations = dataset_dict["annotations"]
    if not annotations:
        # With no ground truth there is nothing to match against.  Bail out
        # before the box conversion, which cannot index the columns of an
        # empty (1-D) array.
        return [], []

    gt_boxes = np.array([a["bbox"] for a in annotations])
    # Convert to the mode the model expects.
    # Make sure the input bboxes are in XYWH mode so they can be converted here
    gt_boxes = BoxMode.convert(gt_boxes, BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
    gt_boxes = structures.Boxes(torch.Tensor(gt_boxes))

    pred_boxes = outputs["instances"].pred_boxes.to("cpu")

    # IOUs[i, j] is the overlap between detection i and ground-truth box j.
    IOUs = structures.pairwise_iou(pred_boxes, gt_boxes).numpy()

    # matched_gts holds the indices of the ground truth annotations that
    # correspond to the matched detections.
    # matched_dts holds the indices of the detections that correspond to the
    # ground truth annotations.
    matched_gts = []
    matched_dts = []
    for i, dt in enumerate(IOUs):
        best_gt = dt.argmax()
        if dt[best_gt] >= IOUthresh:
            matched_gts.append(best_gt)
            matched_dts.append(i)

    return matched_gts, matched_dts
def get_object_coords(dataset_dict, outputs):
    """Return the sky coordinates of the detected objects in an image.

    The centers of the predicted bounding boxes are converted from pixel
    coordinates to world coordinates with the image's WCS.

    Parameters
    ----------
    dataset_dict : dictionary
        The dictionary metadata for a single image.  ``dataset_dict['wcs']``
        is passed to ``astropy.wcs.WCS`` to build the transformation.
    outputs : dictionary
        The model output for the image; ``outputs["instances"].pred_boxes``
        holds the detected boxes.

    Returns
    -------
    ras:
        The right ascensions, in degrees, of the detected box centers
    decs:
        The declinations, in degrees, of the detected box centers
    """
    wcs = WCS(dataset_dict["wcs"])

    # Box centers come back as an (x, y) pair per detection; the first column
    # is passed as the first pixel axis and the second column as the second.
    centers = outputs["instances"].pred_boxes.get_centers().cpu().numpy()
    coords = wcs.pixel_to_world(centers[:, 0], centers[:, 1])

    ras = coords.ra.degree
    decs = coords.dec.degree
    return ras, decs
def get_matched_object_classes(dataset_dicts, imreader, key_mapper, predictor):
    """Collect true and predicted classes for matched objects in test images.

    For every image the predictor is run, detections are matched to the
    ground truth with ``get_matched_object_inds`` (default IOU threshold), and
    the class of each matched pair is recorded.

    Parameters
    ----------
    dataset_dicts : list[dict]
        The dictionary metadata for the test images
    imreader: ImageReader object
        An object derived from ImageReader base class to read the images.
    key_mapper: function
        The key_mapper should take a dataset_dict as input and return the key
        used by imreader
    predictor: AstroPredictor
        The predictor object used to make predictions on the test set

    Returns
    -------
    true_classes: list(int)
        The ``category_id`` of each matched ground-truth object
    pred_classes: list(int)
        The predicted class of each matched detection, in the same order
    """
    true_classes, pred_classes = [], []

    for image_dict in dataset_dicts:
        outputs = get_predictions(image_dict, imreader, key_mapper, predictor)
        gt_inds, dt_inds = get_matched_object_inds(image_dict, outputs)

        for gt_idx, dt_idx in zip(gt_inds, dt_inds):
            annotation = image_dict["annotations"][int(gt_idx)]
            predicted = outputs["instances"].pred_classes.cpu().detach().numpy()
            true_classes.append(annotation["category_id"])
            pred_classes.append(predicted[int(dt_idx)])

    return true_classes, pred_classes
def get_matched_object_classes_new(dataset_dicts, predictor):
    """Collect true and predicted classes for matched objects in test images.

    Variant of the class-matching routine that obtains predictions through
    ``get_predictions_new``; it assumes the dataset_dicts have the image
    HxWxC in the 'image_shaped' field.

    Parameters
    ----------
    dataset_dicts : list[dict]
        The dictionary metadata for the test images
    predictor: AstroPredictor
        The predictor object used to make predictions on the test set

    Returns
    -------
    true_classes: list(int)
        The ``category_id`` of each matched ground-truth object
    pred_classes: list(int)
        The predicted class of each matched detection, in the same order
    """
    true_classes, pred_classes = [], []

    for image_dict in dataset_dicts:
        outputs = get_predictions_new(image_dict, predictor)
        gt_inds, dt_inds = get_matched_object_inds(image_dict, outputs)

        for gt_idx, dt_idx in zip(gt_inds, dt_inds):
            annotation = image_dict["annotations"][int(gt_idx)]
            predicted = outputs["instances"].pred_classes.cpu().detach().numpy()
            true_classes.append(annotation["category_id"])
            pred_classes.append(predicted[int(dt_idx)])

    return true_classes, pred_classes
def run_batched_match_class(dataloader, predictor):
    """Collect true and predicted classes of matched objects, batch by batch.

    Originally flagged as a test function for batch prediction.  Each batch
    from ``dataloader`` is passed to ``predictor.model`` with gradients
    disabled; every image's detections are then matched to its ground truth
    with ``get_matched_object_inds`` (default IOU threshold).

    Parameters
    ----------
    dataloader : iterable
        Yields batches of dataset dicts that are passed to the model
    predictor: AstroPredictor
        The predictor whose ``model`` is called on each batch

    Returns
    -------
    true_classes: list(int)
        The ``category_id`` of each matched ground-truth object
    pred_classes: list(int)
        The predicted class of each matched detection, in the same order
    """
    true_classes, pred_classes = [], []

    with torch.no_grad():
        for batch in dataloader:
            batched_outputs = predictor.model(batch)

            for outputs, image_dict in zip(batched_outputs, batch):
                gt_inds, dt_inds = get_matched_object_inds(image_dict, outputs)

                for gt_idx, dt_idx in zip(gt_inds, dt_inds):
                    annotation = image_dict["annotations"][int(gt_idx)]
                    predicted = outputs["instances"].pred_classes.cpu().detach().numpy()
                    true_classes.append(annotation["category_id"])
                    pred_classes.append(predicted[int(dt_idx)])

    return true_classes, pred_classes
def run_batched_get_object_coords(dataloader, predictor, oclass=True, gmm=False):
    """Run batched inference and collect per-detection outputs and sky coordinates.

    Each batch from ``dataloader`` is passed to ``predictor.model`` with
    gradients disabled.  For every detection in every image the sky position
    (via ``get_object_coords``), the exponentiated ``pred_redshift_pdf`` entry
    and the confidence score are recorded, plus the predicted class and GMM
    parameters when requested.

    Parameters
    ----------
    dataloader : Dataloader
        Dataloader that reads in images and formats input for the model
    predictor: AstroPredictor
        The predictor object used to make predictions on the test set
    oclass : bool
        If True (default), collect ``pred_classes`` for every detection.  If
        False, the returned ``oclasses`` list is empty.
    gmm : bool
        If True, also collect ``pred_gmm`` for every detection and include the
        ``gmms`` list in the return value.

    Returns
    -------
    zpreds: list(array)
        ``np.exp`` of each detection's ``pred_redshift_pdf`` entry
        (presumably a log-PDF over redshift -- confirm against the model)
    all_ras: list(float)
        The RAs of detected objects
    all_decs: list(float)
        The DECs of detected objects
    oclasses: list(int)
        The predicted classes of detected objects
    gmms: list(array)
        The ``pred_gmm`` entries of detected objects; only returned when
        ``gmm`` is True
    scores: list(float)
        The confidence scores of detected objects
    """
    zpreds = []
    all_ras = []
    all_decs = []
    scores = []
    # Always bind both lists: they were previously created only when their
    # flag was set, which raised NameError on the default gmm=False path and
    # whenever oclass=False.
    oclasses = []
    gmms = []

    with torch.no_grad():
        for dataset_dicts in dataloader:
            batched_outputs = predictor.model(dataset_dicts)

            for outputs, d in zip(batched_outputs, dataset_dicts):
                ras, decs = get_object_coords(d, outputs)
                all_ras.extend(ras)
                all_decs.extend(decs)

                instances = outputs["instances"]
                # Move each field to the CPU once per image rather than once
                # per detection.
                pdfs = instances.pred_redshift_pdf.cpu()
                image_scores = instances.scores.cpu()
                image_classes = instances.pred_classes.cpu() if oclass else None
                image_gmms = instances.pred_gmm.cpu() if gmm else None

                for dti in range(len(instances)):
                    zpreds.append(np.exp(pdfs[dti].numpy()))
                    scores.append(image_scores[dti].numpy())
                    if oclass:
                        oclasses.append(image_classes[dti].numpy())
                    if gmm:
                        gmms.append(image_gmms[dti].numpy())

    if gmm:
        return zpreds, all_ras, all_decs, oclasses, gmms, scores
    return zpreds, all_ras, all_decs, oclasses, scores