From 447746d9f556d3052ca96400b1b538b545f04220 Mon Sep 17 00:00:00 2001
From: FengZiYjun
Date: Fri, 7 Dec 2018 13:22:04 +0800
Subject: [PATCH] * remove unused code in losses.py & metrics.py
 * refine code style
 * fix tests
 * add a new tutorial

---
 fastNLP/core/losses.py               | 115 +---------------
 fastNLP/core/metrics.py              | 183 +-----------------------
 fastNLP/io/dataset_loader.py         |  12 ++
 test/core/test_loss.py               | 260 +++-------------------------------
 test/core/test_metrics.py            |   6 +-
 tutorials/fastnlp_in_six_lines.ipynb |  81 +++++++++++
 6 files changed, 119 insertions(+), 538 deletions(-)
 create mode 100644 tutorials/fastnlp_in_six_lines.ipynb

diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py
index ed935c9d..757ce465 100644
--- a/fastNLP/core/losses.py
+++ b/fastNLP/core/losses.py
@@ -195,6 +195,7 @@ class CrossEntropyLoss(LossBase):
         return F.cross_entropy(input=pred, target=target,
                                ignore_index=self.padding_idx)
 
+
 class L1Loss(LossBase):
     def __init__(self, pred=None, target=None):
         super(L1Loss, self).__init__()
@@ -212,6 +213,7 @@ class BCELoss(LossBase):
     def get_loss(self, pred, target):
         return F.binary_cross_entropy(input=pred, target=target)
 
+
 class NLLLoss(LossBase):
     def __init__(self, pred=None, target=None):
         super(NLLLoss, self).__init__()
@@ -259,7 +261,7 @@ def _prepare_losser(losser):
     elif isinstance(losser, LossBase):
         return losser
     else:
-        raise TypeError(f"Type of losser should be `fastNLP.LossBase`, got {type(losser)}")
+        raise TypeError(f"Type of loss should be `fastNLP.LossBase`, got {type(losser)}")
 
 
 def squash(predict, truth, **kwargs):
@@ -354,114 +356,3 @@ def make_mask(lens, tar_len):
     mask = torch.stack(mask, 1)
 
     return mask
-
-
-# map string to function. Just for more elegant using
-method_dict = {
-    "squash": squash,
-    "unpad": unpad,
-    "unpad_mask": unpad_mask,
-    "mask": mask,
-}
-
-loss_function_name = {
-    "L1Loss".lower(): torch.nn.L1Loss,
-    "BCELoss".lower(): torch.nn.BCELoss,
-    "MSELoss".lower(): torch.nn.MSELoss,
-    "NLLLoss".lower(): torch.nn.NLLLoss,
-    "KLDivLoss".lower(): torch.nn.KLDivLoss,
-    "NLLLoss2dLoss".lower(): torch.nn.NLLLoss2d,  # every name should end with "loss"
-    "SmoothL1Loss".lower(): torch.nn.SmoothL1Loss,
-    "SoftMarginLoss".lower(): torch.nn.SoftMarginLoss,
-    "PoissonNLLLoss".lower(): torch.nn.PoissonNLLLoss,
-    "MultiMarginLoss".lower(): torch.nn.MultiMarginLoss,
-    "CrossEntropyLoss".lower(): torch.nn.CrossEntropyLoss,
-    "BCEWithLogitsLoss".lower(): torch.nn.BCEWithLogitsLoss,
-    "MarginRankingLoss".lower(): torch.nn.MarginRankingLoss,
-    "TripletMarginLoss".lower(): torch.nn.TripletMarginLoss,
-    "HingeEmbeddingLoss".lower(): torch.nn.HingeEmbeddingLoss,
-    "CosineEmbeddingLoss".lower(): torch.nn.CosineEmbeddingLoss,
-    "MultiLabelMarginLoss".lower(): torch.nn.MultiLabelMarginLoss,
-    "MultiLabelSoftMarginLoss".lower(): torch.nn.MultiLabelSoftMarginLoss,
-}
-
-
-class LossFromTorch(object):
-    """a LossFromTorch object is a callable object represents loss functions
-
-    This class only helps you with loss functions from PyTorch.
-    It has nothing to do with Trainer.
-    """
-
-    def __init__(self, loss_name, pre_pro=[squash], **kwargs):
-        """
-
-        :param loss_name: str or None, the name of loss function
-        :param pre_pro: list of function or str, methods to reform parameters before calculating loss
-            the strings will be auto translated to pre-defined functions
-        :param **kwargs: kwargs for torch loss function
-
-        pre_pro functions should have three arguments: predict, truth, **arg
-        predict and truth is the necessary parameters in loss function
-        kwargs is the extra parameters passed-in when calling loss function
-        pre_pro functions should return two objects, respectively predict and truth that after processed
-
-        """
-        if loss_name is None:
-            # this is useful when Trainer.__init__ performs type check
-            self._loss = None
-        else:
-            if not isinstance(loss_name, str):
-                raise NotImplementedError
-            else:
-                self._loss = self._get_loss(loss_name, **kwargs)
-
-        self.pre_pro = [f if callable(f) else method_dict.get(f) for f in pre_pro]
-
-    def add_pre_pro(self, func):
-        """add a pre_pro function
-
-        :param func: a function or str, methods to reform parameters before calculating loss
-            the strings will be auto translated to pre-defined functions
-        """
-        if not callable(func):
-            func = method_dict.get(func)
-            if func is None:
-                return
-        self.pre_pro.append(func)
-
-    @staticmethod
-    def _get_loss(loss_name, **kwargs):
-        """Get loss function from torch
-
-        :param loss_name: str, the name of loss function
-        :param **kwargs: kwargs for torch loss function
-        :return: A callable loss function object
-        """
-        loss_name = loss_name.strip().lower()
-        loss_name = "".join(loss_name.split("_"))
-
-        if len(loss_name) < 4 or loss_name[-4:] != "loss":
-            loss_name += "loss"
-        return loss_function_name[loss_name](**kwargs)
-
-    def get(self):
-        """This method exists just for make some existing codes run error-freely
-        """
-        return self
-
-    def __call__(self, predict, truth, **kwargs):
-        """Call a loss function
-        predict and truth will be processed by pre_pro methods in order of addition
-
-        :param predict: Tensor, model output
-        :param truth: Tensor, truth from dataset
-        :param **kwargs: extra arguments, pass to pre_pro functions
-            for example, if used unpad_mask() in pre_pro, there should be a kwarg named lens
-        """
-        for f in self.pre_pro:
-            if f is None:
-                continue
-            predict, truth = f(predict, truth, **kwargs)
-
-        return self._loss(predict, truth)
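
With `LossFromTorch` and its string-keyed registry gone, losses are configured through the remaining `LossBase` subclasses, which map dict keys from the model's output and the dataset onto the loss arguments. A minimal sketch of that usage, mirroring the updated tests later in this patch (the key names `my_predict`/`my_truth` are arbitrary):

    import torch
    from fastNLP.core.losses import CrossEntropyLoss

    # Declare which dict keys hold the prediction and the target.
    ce = CrossEntropyLoss(pred="my_predict", target="my_truth")

    pred_dict = {"my_predict": torch.randn(3, 5)}            # unnormalized class scores
    target_dict = {"my_truth": torch.LongTensor([0, 1, 2])}  # gold class indices

    # After the key mapping this reduces to F.cross_entropy(pred, target).
    loss_value = ce(pred_dict=pred_dict, target_dict=target_dict)
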
- """ - - def __init__(self, loss_name, pre_pro=[squash], **kwargs): - """ - - :param loss_name: str or None , the name of loss function - :param pre_pro : list of function or str, methods to reform parameters before calculating loss - the strings will be auto translated to pre-defined functions - :param **kwargs: kwargs for torch loss function - - pre_pro funcsions should have three arguments: predict, truth, **arg - predict and truth is the necessary parameters in loss function - kwargs is the extra parameters passed-in when calling loss function - pre_pro functions should return two objects, respectively predict and truth that after processed - - """ - - if loss_name is None: - # this is useful when Trainer.__init__ performs type check - self._loss = None - else: - if not isinstance(loss_name, str): - raise NotImplementedError - else: - self._loss = self._get_loss(loss_name, **kwargs) - - self.pre_pro = [f if callable(f) else method_dict.get(f) for f in pre_pro] - - def add_pre_pro(self, func): - """add a pre_pro function - - :param func: a function or str, methods to reform parameters before calculating loss - the strings will be auto translated to pre-defined functions - """ - if not callable(func): - func = method_dict.get(func) - if func is None: - return - self.pre_pro.append(func) - - @staticmethod - def _get_loss(loss_name, **kwargs): - """Get loss function from torch - - :param loss_name: str, the name of loss function - :param **kwargs: kwargs for torch loss function - :return: A callable loss function object - """ - loss_name = loss_name.strip().lower() - loss_name = "".join(loss_name.split("_")) - - if len(loss_name) < 4 or loss_name[-4:] != "loss": - loss_name += "loss" - return loss_function_name[loss_name](**kwargs) - - def get(self): - """This method exists just for make some existing codes run error-freely - """ - return self - - def __call__(self, predict, truth, **kwargs): - """Call a loss function - predict and truth will be processed by pre_pro methods in order of addition - - :param predict : Tensor, model output - :param truth : Tensor, truth from dataset - :param **kwargs : extra arguments, pass to pre_pro functions - for example, if used unpad_mask() in pre_pro, there should be a kwarg named lens - """ - for f in self.pre_pro: - if f is None: - continue - predict, truth = f(predict, truth, **kwargs) - - return self._loss(predict, truth) diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 929d6ee1..34a90d5a 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -1,5 +1,4 @@ import inspect -import warnings from collections import defaultdict import numpy as np @@ -197,19 +196,19 @@ class AccuracyMetric(MetricBase): """ fast_param = {} targets = list(target_dict.values()) - if len(targets)==1 and isinstance(targets[0], torch.Tensor): - if len(pred_dict)==1: + if len(targets) == 1 and isinstance(targets[0], torch.Tensor): + if len(pred_dict) == 1: pred = list(pred_dict.values())[0] fast_param['pred'] = pred - elif len(pred_dict)==2: + elif len(pred_dict) == 2: pred1 = list(pred_dict.values())[0] pred2 = list(pred_dict.values())[1] if not (isinstance(pred1, torch.Tensor) and isinstance(pred2, torch.Tensor)): return fast_param - if len(pred1.size())len(pred2.size()) and len(pred2.size())==1: + elif len(pred1.size()) > len(pred2.size()) and len(pred2.size()) == 1: seq_lens = pred2 pred = pred1 else: @@ -308,178 +307,6 @@ def _prepare_metrics(metrics): return _metrics -""" - Attention: Codes below are not used in current FastNLP. 
-    However, it is useful.
-
-"""
-
-
-def _conver_numpy(x):
-    """convert input data to numpy array
-
-    """
-    if isinstance(x, np.ndarray):
-        return x
-    elif isinstance(x, torch.Tensor):
-        return x.numpy()
-    elif isinstance(x, list):
-        return np.array(x)
-    raise TypeError('cannot accept object: {}'.format(x))
-
-
-def _check_same_len(*arrays, axis=0):
-    """check if input array list has same length for one dimension
-
-    """
-    lens = set([x.shape[axis] for x in arrays if x is not None])
-    return len(lens) == 1
-
-
-def _label_types(y):
-    """Determine the type
-        - "binary"
-        - "multiclass"
-        - "multiclass-multioutput"
-        - "multilabel"
-        - "unknown"
-    """
-    # never squeeze the first dimension
-    y = y.squeeze() if y.shape[0] > 1 else y.resize(1, -1)
-    shape = y.shape
-    if len(shape) < 1:
-        raise ValueError('cannot accept data: {}'.format(y))
-    if len(shape) == 1:
-        return 'multiclass' if np.unique(y).shape[0] > 2 else 'binary', y
-    if len(shape) == 2:
-        return 'multiclass-multioutput' if np.unique(y).shape[0] > 2 else 'multilabel', y
-    return 'unknown', y
-
-
-def _check_data(y_true, y_pred):
-    """Check if y_true and y_pred is same type of data e.g both binary or multiclass
-
-    """
-    y_true, y_pred = _conver_numpy(y_true), _conver_numpy(y_pred)
-    if not _check_same_len(y_true, y_pred):
-        raise ValueError('cannot accept data with different shape {0}, {1}'.format(y_true, y_pred))
-    type_true, y_true = _label_types(y_true)
-    type_pred, y_pred = _label_types(y_pred)
-
-    type_set = {'binary', 'multiclass'}
-    if type_true in type_set and type_pred in type_set:
-        return type_true if type_true == type_pred else 'multiclass', y_true, y_pred
-
-    type_set = {'multiclass-multioutput', 'multilabel'}
-    if type_true in type_set and type_pred in type_set:
-        return type_true if type_true == type_pred else 'multiclass-multioutput', y_true, y_pred
-
-    raise ValueError('cannot accept data mixed of {0} and {1} target'.format(type_true, type_pred))
-
-
-def _weight_sum(y, normalize=True, sample_weight=None):
-    if normalize:
-        return np.average(y, weights=sample_weight)
-    if sample_weight is None:
-        return y.sum()
-    else:
-        return np.dot(y, sample_weight)
-
-
-def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None):
-    y_type, y_true, y_pred = _check_data(y_true, y_pred)
-    if y_type == 'multiclass-multioutput':
-        raise ValueError('cannot accept data type {0}'.format(y_type))
-    if y_type == 'multilabel':
-        equel = (y_true == y_pred).sum(1)
-        count = equel == y_true.shape[1]
-    else:
-        count = y_true == y_pred
-    return _weight_sum(count, normalize=normalize, sample_weight=sample_weight)
-
-
-def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
-    y_type, y_true, y_pred = _check_data(y_true, y_pred)
-    if average == 'binary':
-        if y_type != 'binary':
-            raise ValueError("data type is {} but use average type {}".format(y_type, average))
-        else:
-            pos = (y_true == pos_label)
-            tp = np.logical_and((y_true == y_pred), pos).sum()
-            pos_sum = pos.sum()
-            return tp / pos_sum if pos_sum > 0 else 0
-    elif average == None:
-        y_labels = set(list(np.unique(y_true)))
-        if labels is None:
-            labels = list(y_labels)
-        else:
-            for i in labels:
-                if (i not in y_labels and y_type != 'multilabel') or (y_type == 'multilabel' and i >= y_true.shape[1]):
-                    warnings.warn('label {} is not contained in data'.format(i), UserWarning)
-
-        if y_type in ['binary', 'multiclass']:
-            y_pred_right = y_true == y_pred
-            pos_list = [y_true == i for i in labels]
-            pos_sum_list = [pos_i.sum() for pos_i in pos_list]
-            return np.array([np.logical_and(y_pred_right, pos_i).sum() / sum_i if sum_i > 0 else 0 \
-                             for pos_i, sum_i in zip(pos_list, pos_sum_list)])
-        elif y_type == 'multilabel':
-            y_pred_right = y_true == y_pred
-            pos = (y_true == pos_label)
-            tp = np.logical_and(y_pred_right, pos).sum(0)
-            pos_sum = pos.sum(0)
-            return np.array([tp[i] / pos_sum[i] if pos_sum[i] > 0 else 0 for i in labels])
-        else:
-            raise ValueError('not support targets type {}'.format(y_type))
-    raise ValueError('not support for average type {}'.format(average))
-
-
-def precision_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
-    y_type, y_true, y_pred = _check_data(y_true, y_pred)
-    if average == 'binary':
-        if y_type != 'binary':
-            raise ValueError("data type is {} but use average type {}".format(y_type, average))
-        else:
-            pos = (y_true == pos_label)
-            tp = np.logical_and((y_true == y_pred), pos).sum()
-            pos_pred = (y_pred == pos_label).sum()
-            return tp / pos_pred if pos_pred > 0 else 0
-    elif average == None:
-        y_labels = set(list(np.unique(y_true)))
-        if labels is None:
-            labels = list(y_labels)
-        else:
-            for i in labels:
-                if (i not in y_labels and y_type != 'multilabel') or (y_type == 'multilabel' and i >= y_true.shape[1]):
-                    warnings.warn('label {} is not contained in data'.format(i), UserWarning)
-
-        if y_type in ['binary', 'multiclass']:
-            y_pred_right = y_true == y_pred
-            pos_list = [y_true == i for i in labels]
-            pos_sum_list = [(y_pred == i).sum() for i in labels]
-            return np.array([np.logical_and(y_pred_right, pos_i).sum() / sum_i if sum_i > 0 else 0 \
-                             for pos_i, sum_i in zip(pos_list, pos_sum_list)])
-        elif y_type == 'multilabel':
-            y_pred_right = y_true == y_pred
-            pos = (y_true == pos_label)
-            tp = np.logical_and(y_pred_right, pos).sum(0)
-            pos_sum = (y_pred == pos_label).sum(0)
-            return np.array([tp[i] / pos_sum[i] if pos_sum[i] > 0 else 0 for i in labels])
-        else:
-            raise ValueError('not support targets type {}'.format(y_type))
-    raise ValueError('not support for average type {}'.format(average))
-
-
-def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
-    precision = precision_score(y_true, y_pred, labels=labels, pos_label=pos_label, average=average)
-    recall = recall_score(y_true, y_pred, labels=labels, pos_label=pos_label, average=average)
-    if isinstance(precision, np.ndarray):
-        res = 2 * precision * recall / (precision + recall + 1e-10)
-        res[(precision + recall) <= 0] = 0
-        return res
-    return 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
-
-
 def accuracy_topk(y_true, y_prob, k=1):
     """Compute accuracy of y_true matching top-k probable labels in y_prob.
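
The `_fast_param_map` cleanup above is what lets `AccuracyMetric` unpack its inputs without explicit key mapping: a single tensor on each side is taken as `pred`/`target`, and a two-tensor `pred_dict` is split into predictions and sequence lengths by comparing dimensionality. A small sketch of the resulting dict-based call (this assumes the usual `MetricBase` interface of `__call__` plus `get_metric`; the shapes are illustrative):

    import torch
    from fastNLP.core.metrics import AccuracyMetric

    metric = AccuracyMetric()  # defaults map the keys 'pred' and 'target'

    pred_dict = {"pred": torch.zeros(4, 3)}          # (batch, num_classes) scores
    target_dict = {"target": torch.zeros(4).long()}  # (batch,) gold labels

    metric(pred_dict=pred_dict, target_dict=target_dict)  # accumulate this batch
    print(metric.get_metric())                            # e.g. {'acc': 1.0}
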
diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py
index fc2edb23..0d30c6e8 100644
--- a/fastNLP/io/dataset_loader.py
+++ b/fastNLP/io/dataset_loader.py
@@ -78,6 +78,18 @@ class DataSetLoader(BaseLoader):
         raise NotImplementedError
 
 
+@DataSet.set_reader("read_naive")
+class NativeDataSetLoader(DataSetLoader):
+    def __init__(self):
+        super(NativeDataSetLoader, self).__init__()
+
+    def load(self, path):
+        ds = DataSet.read_csv(path, headers=("raw_sentence", "label"), sep="\t")
+        ds.set_input("raw_sentence")
+        ds.set_target("label")
+        return ds
+
+
 @DataSet.set_reader('read_raw')
 class RawDataSetLoader(DataSetLoader):
     def __init__(self):
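
`NativeDataSetLoader` turns a tab-separated `raw_sentence`/`label` file into a ready-to-use `DataSet`, and the `@DataSet.set_reader("read_naive")` decorator additionally exposes it as `DataSet.read_naive`. Both entry points in a short sketch (the path is the sample file the new tutorial notebook loads):

    from fastNLP.core.dataset import DataSet
    from fastNLP.io.dataset_loader import NativeDataSetLoader

    path = "../test/data_for_tests/tutorial_sample_dataset.csv"

    ds = NativeDataSetLoader().load(path)  # explicit loader object
    ds = DataSet.read_naive(path)          # same result via the registered reader
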
diff --git a/test/core/test_loss.py b/test/core/test_loss.py
index 52860b36..a6d542fa 100644
--- a/test/core/test_loss.py
+++ b/test/core/test_loss.py
@@ -1,253 +1,13 @@
-import math
 import unittest
 
 import torch
-import torch as tc
 import torch.nn.functional as F
 
 import fastNLP.core.losses as loss
+from fastNLP.core.losses import squash, unpad
 
 
 class TestLoss(unittest.TestCase):
-
-    def test_case_1(self):
-        loss_func = loss.LossFunc(F.nll_loss)
-        nll_loss = loss.NLLLoss()
-        y = tc.Tensor(
-            [
-                [.3, .4, .3],
-                [.5, .3, .2],
-                [.3, .6, .1],
-            ]
-        )
-
-        gy = tc.LongTensor(
-            [
-                0,
-                1,
-                2,
-            ]
-        )
-
-        y = tc.log(y)
-        los = loss_func({'input': y}, {'target': gy})
-        losses = nll_loss({'input': y}, {'target': gy})
-
-        r = -math.log(.3) - math.log(.3) - math.log(.1)
-        r /= 3
-        print("loss = %f" % (los))
-        print("r = %f" % (r))
-        print("nll_loss = %f" % (losses))
-
-        self.assertEqual(int(los * 1000), int(r * 1000))
-
-    def test_case_2(self):
-        # verify the correctness of squash()
-        log = math.log
-        loss_func = loss.LossFromTorch("nll")
-
-        y = tc.Tensor(
-            [
-                [[.3, .4, .3], [.3, .4, .3], ],
-                [[.5, .3, .2], [.1, .2, .7], ],
-                [[.3, .6, .1], [.2, .1, .7], ],
-            ]
-        )
-
-        gy = tc.LongTensor(
-            [
-                [0, 2],
-                [1, 2],
-                [2, 1],
-            ]
-        )
-
-        y = tc.log(y)
-        # los = loss_func({'input': y}, {'target': gy})
-        los = loss_func(y, gy)
-
-        r = -log(.3) - log(.3) - log(.1) - log(.3) - log(.7) - log(.1)
-        r /= 6
-
-        self.assertEqual(int(los * 1000), int(r * 1000))
-
-    def test_case_3(self):
-        # verify the correctness of pack_padded_sequence()
-        log = math.log
-        loss_func = loss.NLLLoss()
-        y = tc.Tensor(
-            [
-                [[.3, .4, .3], [.3, .2, .5], [.4, .5, .1, ], ],
-                [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], ],
-                [[.3, .6, .1], [.0, .0, .0], [.0, .0, .0, ], ],
-            ]
-        )
-
-        gy = tc.LongTensor(
-            [
-                [0, 2, 1, ],
-                [1, 2, 0, ],
-                [2, 0, 0, ],
-            ]
-        )
-
-        lens = [3, 2, 1]
-
-        # pdb.set_trace()
-
-        y = tc.log(y)
-
-        yy = tc.nn.utils.rnn.pack_padded_sequence(y, lens, batch_first=True).data
-        gyy = tc.nn.utils.rnn.pack_padded_sequence(gy, lens, batch_first=True).data
-        los = loss_func({'input': yy}, {'target': gyy})
-
-        r = -log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1)
-        r /= 6
-
-        self.assertEqual(int(los * 1000), int(r * 1000))
-
-    def test_case_4(self):
-        # verify the correctness of unpad()
-        log = math.log
-        y = tc.Tensor(
-            [
-                [[.3, .4, .3], [.3, .2, .5], [.4, .5, .1, ], [.6, .3, .1, ], ],
-                [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], [.0, .0, .0, ], ],
-                [[.3, .6, .1], [.0, .0, .0], [.0, .0, .0, ], [.0, .0, .0, ], ],
-            ]
-        )
-
-        gy = tc.LongTensor(
-            [
-                [0, 2, 1, 2, ],
-                [1, 2, 0, 0, ],
-                [2, 0, 0, 0, ],
-            ]
-        )
-
-        lens = [4, 2, 1]
-        y = tc.log(y)
-
-        loss_func = loss.LossFromTorch("nll", pre_pro=["unpad"])
-        los = loss_func(y, gy, lens=lens)
-
-        r = -log(.1) - log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1)
-        r /= 7
-
-        self.assertEqual(int(los * 1000), int(r * 1000))
-
-    def test_case_5(self):
-        # verify the correctness of mask() and make_mask()
-        log = math.log
-
-        y = tc.Tensor(
-            [
-                [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], [.0, .0, .0, ], ],
-                [[.5, .4, .1], [.3, .2, .5], [.4, .5, .1, ], [.6, .1, .3, ], ],
-                [[.3, .6, .1], [.3, .2, .5], [.0, .0, .0, ], [.0, .0, .0, ], ],
-            ]
-        )
-
-        gy = tc.LongTensor(
-            [
-                [1, 2, 0, 0, ],
-                [0, 2, 1, 2, ],
-                [2, 1, 0, 0, ],
-            ]
-        )
-
-        mask = tc.ByteTensor(
-            [
-                [1, 1, 0, 0, ],
-                [1, 1, 1, 1, ],
-                [1, 1, 0, 0, ],
-            ]
-        )
-
-        y = tc.log(y)
-
-        lens = [2, 4, 2]
-
-        loss_func = loss.LossFromTorch("nll", pre_pro=["mask"])
-        los = loss_func(y, gy, mask=mask)
-
-        los2 = loss_func(y, gy, mask=loss.make_mask(lens, gy.size()[-1]))
-
-        r = -log(.3) - log(.7) - log(.5) - log(.5) - log(.5) - log(.3) - log(.1) - log(.2)
-        r /= 8
-
-        self.assertEqual(int(los * 1000), int(r * 1000))
-        self.assertEqual(int(los2 * 1000), int(r * 1000))
-
-    def test_case_6(self):
-        # verify the correctness of unpad_mask()
-        log = math.log
-        y = tc.Tensor(
-            [
-                [[.3, .4, .3], [.3, .2, .5], [.4, .5, .1, ], [.6, .3, .1, ], ],
-                [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], [.0, .0, .0, ], ],
-                [[.3, .6, .1], [.0, .0, .0], [.0, .0, .0, ], [.0, .0, .0, ], ],
-            ]
-        )
-
-        gy = tc.LongTensor(
-            [
-                [0, 2, 1, 2, ],
-                [1, 2, 0, 0, ],
-                [2, 0, 0, 0, ],
-            ]
-        )
-
-        lens = [4, 2, 1]
-
-        # pdb.set_trace()
-
-        y = tc.log(y)
-
-        loss_func = loss.LossFromTorch("nll", pre_pro=["unpad_mask"])
-        los = loss_func(y, gy, lens=lens)
-
-        r = -log(.1) - log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1)
-        r /= 7
-
-        self.assertEqual(int(los * 1000), int(r * 1000))
-
-    def test_case_7(self):
-        # verify some other behaviour
-        log = math.log
-        y = tc.Tensor(
-            [
-                [[.3, .4, .3], [.3, .2, .5], [.4, .5, .1, ], [.6, .3, .1, ], ],
-                [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], [.0, .0, .0, ], ],
-                [[.3, .6, .1], [.0, .0, .0], [.0, .0, .0, ], [.0, .0, .0, ], ],
-            ]
-        )
-
-        gy = tc.LongTensor(
-            [
-                [0, 2, 1, 2, ],
-                [1, 2, 0, 0, ],
-                [2, 0, 0, 0, ],
-            ]
-        )
-
-        lens = [4, 2, 1]
-        y = tc.log(y)
-
-        loss_func = loss.LossFromTorch("nll", pre_pro=[], weight=tc.Tensor([1, 1, 0]))
-        loss_func.add_pre_pro("unpad_mask")
-        los = loss_func(y, gy, lens=lens)
-
-        r = - log(.3) - log(.5) - log(.3)
-        r /= 3
-        self.assertEqual(int(los * 1000), int(r * 1000))
-
-    def test_case_8(self):
-        pass
-
-
-class TestLoss_v2(unittest.TestCase):
     def test_CrossEntropyLoss(self):
         ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth")
         a = torch.randn(3, 5, requires_grad=False)
@@ -276,6 +36,7 @@ class TestLoss_v2(unittest.TestCase):
         b = torch.empty(3, dtype=torch.long).random_(5)
         ans = l1({"my_predict": a}, {"my_truth": b})
         self.assertEqual(ans, torch.nn.functional.nll_loss(a, b))
 
+
 class TestLosserError(unittest.TestCase):
     def test_losser1(self):
         # (1) only input, targets passed
@@ -292,11 +53,12 @@ class TestLosserError(unittest.TestCase):
         target_dict = {'target': torch.zeros(16, 3).long()}
         los = loss.CrossEntropyLoss()
 
-        # print(los(pred_dict=pred_dict, target_dict=target_dict))
+        with self.assertRaises(RuntimeError):
+            print(los(pred_dict=pred_dict, target_dict=target_dict))
 
     def test_losser3(self):
         # (2) with corrupted size
-        pred_dict = {"pred": torch.zeros(16, 3), 'stop_fast_param':0}
+        pred_dict = {"pred": torch.zeros(16, 3), 'stop_fast_param': 0}
         target_dict = {'target': torch.zeros(16).long()}
         los = loss.CrossEntropyLoss()
@@ -311,3 +73,15 @@ class TestLosserError(unittest.TestCase):
 
         with self.assertRaises(Exception):
             ans = l1({"my_predict": a}, {"truth": b, "my": a})
+
+
+class TestLossUtils(unittest.TestCase):
+    def test_squash(self):
+        a, b = squash(torch.randn(3, 5), torch.randn(3, 5))
+        self.assertEqual(tuple(a.size()), (3, 5))
+        self.assertEqual(tuple(b.size()), (15,))
+
+    def test_unpad(self):
+        a, b = unpad(torch.randn(5, 8, 3), torch.randn(5, 8))
+        self.assertEqual(tuple(a.size()), (5, 8, 3))
+        self.assertEqual(tuple(b.size()), (5, 8))
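
The new `TestLossUtils` class pins down the shape contract of the helpers that survive in losses.py: `squash` flattens the target to 1-D so a `(batch, num_classes)` prediction can meet it, and `unpad` leaves its inputs untouched unless sequence lengths are supplied. A sketch of that behaviour (the `lens` keyword follows the removed `test_case_4`; treat it as illustrative):

    import torch
    from fastNLP.core.losses import squash, unpad

    pred, truth = torch.randn(3, 5), torch.zeros(3, 5).long()
    p, t = squash(pred, truth)
    print(p.size(), t.size())  # (3, 5) and (15,): truth is flattened to 1-D

    # Without lens, unpad is a no-op; with lens it packs away the padded
    # positions before the loss is computed (cf. the removed test_case_4).
    p, t = unpad(torch.randn(3, 4, 5), torch.zeros(3, 4).long(), lens=[4, 2, 1])
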
diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py
index d2e45379..c6267664 100644
--- a/test/core/test_metrics.py
+++ b/test/core/test_metrics.py
@@ -4,7 +4,7 @@ import numpy as np
 import torch
 
 from fastNLP.core.metrics import AccuracyMetric
-from fastNLP.core.metrics import accuracy_score, recall_score, precision_score, f1_score, pred_topk, accuracy_topk
+from fastNLP.core.metrics import pred_topk, accuracy_topk
 
 
 class TestAccuracyMetric(unittest.TestCase):
@@ -139,10 +139,6 @@ class TestUsefulFunctions(unittest.TestCase):
     # test some functions in metrics.py that look quite useful
     def test_case_1(self):
         # multi-class
-        _ = accuracy_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)))
-        _ = precision_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), average=None)
-        _ = recall_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), average=None)
-        _ = f1_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), average=None)
         _ = accuracy_topk(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), k=3)
         _ = pred_topk(np.random.randint(0, 3, size=(10, 1)))
 
diff --git a/tutorials/fastnlp_in_six_lines.ipynb b/tutorials/fastnlp_in_six_lines.ipynb
new file mode 100644
index 00000000..2d8f40d7
--- /dev/null
+++ b/tutorials/fastnlp_in_six_lines.ipynb
@@ -0,0 +1,81 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": true
+   },
+   "source": [
+    "# FastNLP in Six Lines"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from fastNLP.core.dataset import DataSet\n",
+    "import fastNLP.io.dataset_loader"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ds = DataSet.read_naive(\"../test/data_for_tests/tutorial_sample_dataset.csv\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 2",
+   "language": "python",
+   "name": "python2"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 2
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython2",
+   "version": "2.7.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
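
Since `read_naive` returns a `DataSet` whose `raw_sentence` field is already marked as input and `label` as target, code following the notebook's two filled-in cells only needs to inspect and transform fields. A hedged sketch of such a next step (`apply` with `new_field_name` follows the pattern of fastNLP's other tutorials; the lambda is purely illustrative):

    from fastNLP.core.dataset import DataSet
    import fastNLP.io.dataset_loader  # registers read_naive and the other readers

    ds = DataSet.read_naive("../test/data_for_tests/tutorial_sample_dataset.csv")
    print(len(ds), ds[0])  # quick sanity check on the loaded instances

    # Normalize the text field in place; apply() maps a function over instances.
    ds.apply(lambda ins: ins["raw_sentence"].lower(), new_field_name="raw_sentence")
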