""" Copyright 2020 Tianshu AI Platform. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ============================================================= """ import abc import torch import torch.nn as nn import torch.nn.functional as F import sys import typing from typing import Callable, Dict, List, Any from collections import Mapping, Sequence from . import loss from kamal.core import metrics, exceptions from kamal.core.attach import AttachTo class Task(object): def __init__(self, name): self.name = name @abc.abstractmethod def get_loss( self, outputs, targets ) -> Dict: pass @abc.abstractmethod def predict(self, outputs) -> Any: pass class GeneralTask(Task): def __init__(self, name: str, loss_fn: Callable, scaling:float=1.0, pred_fn: Callable=lambda x: x, attach_to=None): super(GeneralTask, self).__init__(name) self._attach = AttachTo(attach_to) self.loss_fn = loss_fn self.pred_fn = pred_fn self.scaling = scaling def get_loss(self, outputs, targets): outputs, targets = self._attach(outputs, targets) return { self.name: self.loss_fn( outputs, targets ) * self.scaling } def predict(self, outputs): outputs = self._attach(outputs) return self.pred_fn(outputs) def __repr__(self): rep = "Task: [%s loss_fn=%s scaling=%.4f attach=%s]"%(self.name, str(self.loss_fn), self.scaling, self._attach) return rep class TaskCompose(list): def __init__(self, tasks: list): for task in tasks: if isinstance(task, Task): self.append(task) def get_loss(self, outputs, targets): loss_dict = {} for task in self: loss_dict.update( task.get_loss( outputs, targets ) ) return loss_dict def predict(self, outputs): results = [] for task in self: results.append( task.predict( outputs ) ) return results def __repr__(self): rep="TaskCompose: \n" for task in self: rep+="\t%s\n"%task class StandardTask: @staticmethod def classification(name='ce', scaling=1.0, attach_to=None): return GeneralTask( name=name, loss_fn=nn.CrossEntropyLoss(), scaling=scaling, pred_fn=lambda x: x.max(1)[1], attach_to=attach_to ) @staticmethod def binary_classification(name='bce', scaling=1.0, attach_to=None): return GeneralTask(name=name, loss_fn=F.binary_cross_entropy_with_logits, scaling=scaling, pred_fn=lambda x: (x>0.5), attach_to=attach_to ) @staticmethod def regression(name='mse', scaling=1.0, attach_to=None): return GeneralTask(name=name, loss_fn=nn.MSELoss(), scaling=scaling, pred_fn=lambda x: x, attach_to=attach_to ) @staticmethod def segmentation(name='ce', scaling=1.0, attach_to=None): return GeneralTask(name=name, loss_fn=nn.CrossEntropyLoss(ignore_index=255), scaling=scaling, pred_fn=lambda x: x.max(1)[1], attach_to=attach_to ) @staticmethod def monocular_depth(name='l1', scaling=1.0, attach_to=None): return GeneralTask(name=name, loss_fn=nn.L1Loss(), scaling=scaling, pred_fn=lambda x: x, attach_to=attach_to) @staticmethod def detection(): raise NotImplementedError @staticmethod def distillation(name='kld', T=1.0, scaling=1.0, attach_to=None): return GeneralTask(name=name, loss_fn=loss.KLDiv(T=T), scaling=scaling, pred_fn=lambda x: x.max(1)[1], 

class StandardMetrics(object):
    @staticmethod
    def classification(attach_to=None):
        return metrics.MetricCompose(
            metric_dict={'acc': metrics.Accuracy(attach_to=attach_to)})

    @staticmethod
    def regression(attach_to=None):
        return metrics.MetricCompose(
            metric_dict={'mse': metrics.MeanSquaredError(attach_to=attach_to)})

    @staticmethod
    def segmentation(num_classes, ignore_idx=255, attach_to=None):
        confusion_matrix = metrics.ConfusionMatrix(
            num_classes=num_classes, ignore_idx=ignore_idx, attach_to=attach_to)
        return metrics.MetricCompose(
            metric_dict={'acc': metrics.Accuracy(attach_to=attach_to),
                         'confusion_matrix': confusion_matrix,
                         'miou': metrics.mIoU(confusion_matrix)})

    @staticmethod
    def monocular_depth(attach_to=None):
        return metrics.MetricCompose(
            metric_dict={
                'rmse': metrics.RootMeanSquaredError(attach_to=attach_to),
                'rmse_log': metrics.RootMeanSquaredError(log_scale=True, attach_to=attach_to),
                'rmse_scale_inv': metrics.ScaleInveriantMeanSquaredError(attach_to=attach_to),
                'abs rel': metrics.AbsoluteRelativeDifference(attach_to=attach_to),
                'sq rel': metrics.SquaredRelativeDifference(attach_to=attach_to),
                'percents within thresholds': metrics.Threshold(
                    thresholds=[1.25, 1.25**2, 1.25**3], attach_to=attach_to)
            })

    @staticmethod
    def loss_metric(loss_fn):
        return metrics.MetricCompose(
            metric_dict={'loss': metrics.AverageMetric(loss_fn)})
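
# Usage sketch for the metric builders (illustrative; `model` and `val_loader`
# are hypothetical, and the update()/get_results() calls below are an
# assumption about the MetricCompose interface in kamal.core.metrics):
#
#   metric = StandardMetrics.segmentation(num_classes=21)
#   with torch.no_grad():
#       for images, labels in val_loader:
#           metric.update(model(images), labels)
#   results = metric.get_results()   # e.g. {'acc': ..., 'miou': ...}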