""" Copyright 2020 Tianshu AI Platform. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ============================================================= """ from .base import Callback from typing import Callable, Union, Sequence import weakref import random from kamal.utils import move_to_device, set_mode, split_batch, colormap from kamal.core.attach import AttachTo import torch import numpy as np import matplotlib.pyplot as plt import matplotlib matplotlib.use('agg') import math import numbers class VisualizeOutputs(Callback): def __init__(self, model, dataset: torch.utils.data.Dataset, idx_list_or_num_vis: Union[int, Sequence]=5, normalizer: Callable=None, prepare_fn: Callable=None, decode_fn: Callable=None, # decode targets and preds tag: str='viz'): super(VisualizeOutputs, self).__init__() self._dataset = dataset self._model = weakref.ref(model) if isinstance(idx_list_or_num_vis, int): self.idx_list = self._get_vis_idx_list(self._dataset, idx_list_or_num_vis) elif isinstance(idx_list_or_num_vis, Sequence): self.idx_list = idx_list_or_num_vis self._normalizer = normalizer self._decode_fn = decode_fn if prepare_fn is None: prepare_fn = VisualizeOutputs.get_prepare_fn() self._prepare_fn = prepare_fn self._tag = tag def _get_vis_idx_list(self, dataset, num_vis): return random.sample(list(range(len(dataset))), num_vis) @torch.no_grad() def __call__(self, trainer): if trainer.tb_writer is None: trainer.logger.warning("summary writer was not found in trainer") return device = trainer.device model = self._model() with torch.no_grad(), set_mode(model, training=False): for img_id, idx in enumerate(self.idx_list): batch = move_to_device(self._dataset[idx], device) batch = [ d.unsqueeze(0) for d in batch ] inputs, targets, preds = self._prepare_fn(model, batch) if self._normalizer is not None: inputs = self._normalizer(inputs) inputs = inputs.detach().cpu().numpy() preds = preds.detach().cpu().numpy() targets = targets.detach().cpu().numpy() if self._decode_fn: # to RGB 0~1 NCHW preds = self._decode_fn(preds) targets = self._decode_fn(targets) inputs = inputs[0] preds = preds[0] targets = targets[0] trainer.tb_writer.add_images("%s-%d"%(self._tag, img_id), np.stack( [inputs, targets, preds], axis=0), global_step=trainer.state.iter) @staticmethod def get_prepare_fn(attach_to=None, pred_fn=lambda x: x): attach_to = AttachTo(attach_to) def wrapper(model, batch): inputs, targets = split_batch(batch) outputs = model(inputs) outputs, targets = attach_to(outputs, targets) return inputs, targets, pred_fn(outputs) return wrapper @staticmethod def get_seg_decode_fn(cmap=colormap(), index_transform=lambda x: x+1): # 255->0, 0->1, def wrapper(preds): if len(preds.shape)>3: preds = preds.squeeze(1) out = cmap[ index_transform(preds.astype('uint8')) ] out = out.transpose(0, 3, 1, 2) / 255 return out return wrapper @staticmethod def get_depth_decode_fn(max_depth, log_scale=True, cmap=plt.get_cmap('jet')): def wrapper(preds): if log_scale: _max_depth = np.log( max_depth ) preds = np.log( preds ) else: _max_depth = max_depth if len(preds.shape)>3: preds = preds.squeeze(1) out = (cmap(preds.clip(0, _max_depth)/_max_depth)).transpose(0, 3, 1, 2)[:, :3] return out return wrapper class VisualizeSegmentation(VisualizeOutputs): def __init__( self, model, dataset: torch.utils.data.Dataset, idx_list_or_num_vis: Union[int, Sequence]=5, cmap = colormap(), attach_to=None, normalizer: Callable=None, prepare_fn: Callable=None, decode_fn: Callable=None, tag: str='seg' ): if prepare_fn is None: prepare_fn = VisualizeOutputs.get_prepare_fn(attach_to=attach_to, pred_fn=lambda x: x.max(1)[1]) if decode_fn is None: decode_fn = VisualizeOutputs.get_seg_decode_fn(cmap=cmap, index_transform=lambda x: x+1) super(VisualizeSegmentation, self).__init__( model=model, dataset=dataset, idx_list_or_num_vis=idx_list_or_num_vis, normalizer=normalizer, prepare_fn=prepare_fn, decode_fn=decode_fn, tag=tag ) class VisualizeDepth(VisualizeOutputs): def __init__( self, model, dataset: torch.utils.data.Dataset, idx_list_or_num_vis: Union[int, Sequence]=5, max_depth = 10, log_scale = True, attach_to = None, normalizer: Callable=None, prepare_fn: Callable=None, decode_fn: Callable=None, tag: str='depth' ): if prepare_fn is None: prepare_fn = VisualizeOutputs.get_prepare_fn(attach_to=attach_to, pred_fn=lambda x: x) if decode_fn is None: decode_fn = VisualizeOutputs.get_depth_decode_fn(max_depth=max_depth, log_scale=log_scale) super(VisualizeDepth, self).__init__( model=model, dataset=dataset, idx_list_or_num_vis=idx_list_or_num_vis, normalizer=normalizer, prepare_fn=prepare_fn, decode_fn=decode_fn, tag=tag )