* refine code style
* fix tests
* add a new tutorial

tags/v0.4.10
@@ -195,6 +195,7 @@ class CrossEntropyLoss(LossBase):
        return F.cross_entropy(input=pred, target=target,
                               ignore_index=self.padding_idx)


class L1Loss(LossBase):
    def __init__(self, pred=None, target=None):
        super(L1Loss, self).__init__()
@@ -212,6 +213,7 @@ class BCELoss(LossBase):
    def get_loss(self, pred, target):
        return F.binary_cross_entropy(input=pred, target=target)


class NLLLoss(LossBase):
    def __init__(self, pred=None, target=None):
        super(NLLLoss, self).__init__()
@@ -259,7 +261,7 @@ def _prepare_losser(losser):
    elif isinstance(losser, LossBase):
        return losser
    else:
-       raise TypeError(f"Type of losser should be `fastNLP.LossBase`, got {type(losser)}")
+       raise TypeError(f"Type of loss should be `fastNLP.LossBase`, got {type(losser)}")
def squash(predict, truth, **kwargs):

@@ -354,114 +356,3 @@ def make_mask(lens, tar_len):
    mask = torch.stack(mask, 1)
    return mask
# map strings to functions, for more convenient lookup
method_dict = {
    "squash": squash,
    "unpad": unpad,
    "unpad_mask": unpad_mask,
    "mask": mask,
}
loss_function_name = {
    "L1Loss".lower(): torch.nn.L1Loss,
    "BCELoss".lower(): torch.nn.BCELoss,
    "MSELoss".lower(): torch.nn.MSELoss,
    "NLLLoss".lower(): torch.nn.NLLLoss,
    "KLDivLoss".lower(): torch.nn.KLDivLoss,
    "NLLLoss2dLoss".lower(): torch.nn.NLLLoss2d,  # every name should end with "loss"
    "SmoothL1Loss".lower(): torch.nn.SmoothL1Loss,
    "SoftMarginLoss".lower(): torch.nn.SoftMarginLoss,
    "PoissonNLLLoss".lower(): torch.nn.PoissonNLLLoss,
    "MultiMarginLoss".lower(): torch.nn.MultiMarginLoss,
    "CrossEntropyLoss".lower(): torch.nn.CrossEntropyLoss,
    "BCEWithLogitsLoss".lower(): torch.nn.BCEWithLogitsLoss,
    "MarginRankingLoss".lower(): torch.nn.MarginRankingLoss,
    "TripletMarginLoss".lower(): torch.nn.TripletMarginLoss,
    "HingeEmbeddingLoss".lower(): torch.nn.HingeEmbeddingLoss,
    "CosineEmbeddingLoss".lower(): torch.nn.CosineEmbeddingLoss,
    "MultiLabelMarginLoss".lower(): torch.nn.MultiLabelMarginLoss,
    "MultiLabelSoftMarginLoss".lower(): torch.nn.MultiLabelSoftMarginLoss,
}
class LossFromTorch(object):
    """A LossFromTorch object is a callable object that wraps a loss function.

    This class only helps you with loss functions from PyTorch.
    It has nothing to do with Trainer.
    """

    def __init__(self, loss_name, pre_pro=[squash], **kwargs):
        """
        :param loss_name: str or None, the name of the loss function
        :param pre_pro: list of functions or str, methods to transform parameters before computing the loss;
                strings are automatically translated to the pre-defined functions
        :param **kwargs: kwargs for the torch loss function

        pre_pro functions should take three arguments: predict, truth, **args.
        predict and truth are the required parameters of a loss function;
        kwargs are the extra parameters passed in when calling the loss function.
        pre_pro functions should return two objects: the processed predict and truth, respectively.
        """
        if loss_name is None:
            # this is useful when Trainer.__init__ performs a type check
            self._loss = None
        else:
            if not isinstance(loss_name, str):
                raise NotImplementedError
            else:
                self._loss = self._get_loss(loss_name, **kwargs)

        self.pre_pro = [f if callable(f) else method_dict.get(f) for f in pre_pro]
    def add_pre_pro(self, func):
        """Add a pre_pro function.

        :param func: a function or str; strings are automatically translated to the pre-defined functions
        """
        if not callable(func):
            func = method_dict.get(func)
        if func is None:
            return
        self.pre_pro.append(func)
    @staticmethod
    def _get_loss(loss_name, **kwargs):
        """Get a loss function from torch.

        :param loss_name: str, the name of the loss function
        :param **kwargs: kwargs for the torch loss function
        :return: a callable loss function object
        """
        loss_name = loss_name.strip().lower()
        loss_name = "".join(loss_name.split("_"))

        if len(loss_name) < 4 or loss_name[-4:] != "loss":
            loss_name += "loss"
        return loss_function_name[loss_name](**kwargs)

    def get(self):
        """This method exists just to let some existing code run without errors."""
        return self
    def __call__(self, predict, truth, **kwargs):
        """Call the loss function.

        predict and truth are processed by the pre_pro methods in the order they were added.

        :param predict: Tensor, model output
        :param truth: Tensor, ground truth from the dataset
        :param **kwargs: extra arguments passed to the pre_pro functions;
                for example, if unpad_mask() is used in pre_pro, there should be a kwarg named lens
        """
        for f in self.pre_pro:
            if f is None:
                continue
            predict, truth = f(predict, truth, **kwargs)

        return self._loss(predict, truth)
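For orientation, a hedged usage sketch of the class above; the tensors and shapes are illustrative, not from the diff. It also shows how _get_loss normalizes a name such as "cross_entropy" before the table lookup:

import torch

# "cross_entropy" -> strip/lower -> drop "_" -> "crossentropy" -> append "loss"
# -> "crossentropyloss", which is a key in loss_function_name
loss_func = LossFromTorch("cross_entropy")
predict = torch.randn(4, 3)               # (batch, n_classes) scores, illustrative
truth = torch.LongTensor([0, 2, 1, 0])    # (batch,) class indices, illustrative
value = loss_func(predict, truth)         # squash() is applied first by default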
@@ -1,5 +1,4 @@
import inspect
import warnings
from collections import defaultdict

import numpy as np
@@ -197,19 +196,19 @@ class AccuracyMetric(MetricBase):
        """
        fast_param = {}
        targets = list(target_dict.values())
-       if len(targets)==1 and isinstance(targets[0], torch.Tensor):
-           if len(pred_dict)==1:
+       if len(targets) == 1 and isinstance(targets[0], torch.Tensor):
+           if len(pred_dict) == 1:
                pred = list(pred_dict.values())[0]
                fast_param['pred'] = pred
-           elif len(pred_dict)==2:
+           elif len(pred_dict) == 2:
                pred1 = list(pred_dict.values())[0]
                pred2 = list(pred_dict.values())[1]
                if not (isinstance(pred1, torch.Tensor) and isinstance(pred2, torch.Tensor)):
                    return fast_param
-               if len(pred1.size())<len(pred2.size()) and len(pred1.size())==1:
+               if len(pred1.size()) < len(pred2.size()) and len(pred1.size()) == 1:
                    seq_lens = pred1
                    pred = pred2
-               elif len(pred1.size())>len(pred2.size()) and len(pred2.size())==1:
+               elif len(pred1.size()) > len(pred2.size()) and len(pred2.size()) == 1:
                    seq_lens = pred2
                    pred = pred1
                else:
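The fast-parameter path in the hunk above disambiguates two prediction tensors by rank: the 1-D tensor is taken to be seq_lens and the higher-rank tensor the predictions. A minimal sketch, with shapes assumed purely for illustration:

import torch

pred1 = torch.randint(1, 20, (8,))   # rank 1 -> interpreted as seq_lens
pred2 = torch.randn(8, 20, 5)        # higher rank -> interpreted as pred
assert len(pred1.size()) == 1 and len(pred1.size()) < len(pred2.size())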
@@ -308,178 +307,6 @@ def _prepare_metrics(metrics):
    return _metrics

"""
Attention: the code below is not used in the current fastNLP.
However, it may still be useful.
"""
def _conver_numpy(x):
    """Convert input data to a numpy array."""
    if isinstance(x, np.ndarray):
        return x
    elif isinstance(x, torch.Tensor):
        return x.numpy()
    elif isinstance(x, list):
        return np.array(x)
    raise TypeError('cannot accept object: {}'.format(x))
def _check_same_len(*arrays, axis=0):
    """Check whether the input arrays have the same length along one dimension."""
    lens = set([x.shape[axis] for x in arrays if x is not None])
    return len(lens) == 1
def _label_types(y):
    """Determine the label type:

    - "binary"
    - "multiclass"
    - "multiclass-multioutput"
    - "multilabel"
    - "unknown"
    """
    # never squeeze the first dimension
    y = y.squeeze() if y.shape[0] > 1 else y.resize(1, -1)
    shape = y.shape
    if len(shape) < 1:
        raise ValueError('cannot accept data: {}'.format(y))
    if len(shape) == 1:
        return 'multiclass' if np.unique(y).shape[0] > 2 else 'binary', y
    if len(shape) == 2:
        return 'multiclass-multioutput' if np.unique(y).shape[0] > 2 else 'multilabel', y
    return 'unknown', y
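For example, under the rules above (a small sketch, assuming numpy inputs):

import numpy as np

_label_types(np.array([0, 1, 1, 0]))            # ('binary', ...): 1-D, two distinct values
_label_types(np.array([0, 1, 2, 1]))            # ('multiclass', ...): 1-D, three distinct values
_label_types(np.array([[1, 0, 1], [0, 1, 0]]))  # ('multilabel', ...): 2-D, two distinct values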
def _check_data(y_true, y_pred):
    """Check whether y_true and y_pred are the same kind of data, e.g. both binary or both multiclass."""
    y_true, y_pred = _conver_numpy(y_true), _conver_numpy(y_pred)
    if not _check_same_len(y_true, y_pred):
        raise ValueError('cannot accept data with different shapes {0}, {1}'.format(y_true, y_pred))
    type_true, y_true = _label_types(y_true)
    type_pred, y_pred = _label_types(y_pred)

    type_set = {'binary', 'multiclass'}
    if type_true in type_set and type_pred in type_set:
        return type_true if type_true == type_pred else 'multiclass', y_true, y_pred

    type_set = {'multiclass-multioutput', 'multilabel'}
    if type_true in type_set and type_pred in type_set:
        return type_true if type_true == type_pred else 'multiclass-multioutput', y_true, y_pred

    raise ValueError('cannot accept mixed data of {0} and {1} targets'.format(type_true, type_pred))
def _weight_sum(y, normalize=True, sample_weight=None):
    if normalize:
        return np.average(y, weights=sample_weight)
    if sample_weight is None:
        return y.sum()
    else:
        return np.dot(y, sample_weight)
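Concretely, for a 0/1 count vector (a small worked sketch with illustrative values):

import numpy as np

y = np.array([1, 0, 1])
_weight_sum(y, normalize=True)                                     # 2/3, the mean
_weight_sum(y, normalize=False)                                    # 2, the raw count
_weight_sum(y, normalize=True, sample_weight=np.array([1, 1, 2]))  # (1 + 0 + 2) / 4 = 0.75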
def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None):
    y_type, y_true, y_pred = _check_data(y_true, y_pred)
    if y_type == 'multiclass-multioutput':
        raise ValueError('cannot accept data type {0}'.format(y_type))
    if y_type == 'multilabel':
        equal = (y_true == y_pred).sum(1)
        count = equal == y_true.shape[1]
    else:
        count = y_true == y_pred
    return _weight_sum(count, normalize=normalize, sample_weight=sample_weight)
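A small worked example of the multilabel branch (values chosen for illustration): a row only counts as correct when every label in it matches.

import numpy as np

y_true = np.array([[1, 0, 1], [0, 1, 0]])
y_pred = np.array([[1, 0, 1], [0, 0, 0]])
accuracy_score(y_true, y_pred)   # 0.5: row 0 matches on all 3 labels, row 1 does not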
def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
    y_type, y_true, y_pred = _check_data(y_true, y_pred)
    if average == 'binary':
        if y_type != 'binary':
            raise ValueError("data type is {} but average type is {}".format(y_type, average))
        else:
            pos = (y_true == pos_label)
            tp = np.logical_and((y_true == y_pred), pos).sum()
            pos_sum = pos.sum()
            return tp / pos_sum if pos_sum > 0 else 0
    elif average is None:
        y_labels = set(list(np.unique(y_true)))
        if labels is None:
            labels = list(y_labels)
        else:
            for i in labels:
                if (i not in y_labels and y_type != 'multilabel') or (y_type == 'multilabel' and i >= y_true.shape[1]):
                    warnings.warn('label {} is not contained in the data'.format(i), UserWarning)

        if y_type in ['binary', 'multiclass']:
            y_pred_right = y_true == y_pred
            pos_list = [y_true == i for i in labels]
            pos_sum_list = [pos_i.sum() for pos_i in pos_list]
            return np.array([np.logical_and(y_pred_right, pos_i).sum() / sum_i if sum_i > 0 else 0
                             for pos_i, sum_i in zip(pos_list, pos_sum_list)])
        elif y_type == 'multilabel':
            y_pred_right = y_true == y_pred
            pos = (y_true == pos_label)
            tp = np.logical_and(y_pred_right, pos).sum(0)
            pos_sum = pos.sum(0)
            return np.array([tp[i] / pos_sum[i] if pos_sum[i] > 0 else 0 for i in labels])
        else:
            raise ValueError('unsupported target type {}'.format(y_type))

    raise ValueError('unsupported average type {}'.format(average))
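A worked check of the 'binary' branch (illustrative values): recall is true positives over actual positives.

import numpy as np

y_true = np.array([1, 1, 0, 1])
y_pred = np.array([1, 0, 0, 1])
recall_score(y_true, y_pred)   # tp = 2, actual positives = 3 -> 2/3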
def precision_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
    y_type, y_true, y_pred = _check_data(y_true, y_pred)
    if average == 'binary':
        if y_type != 'binary':
            raise ValueError("data type is {} but average type is {}".format(y_type, average))
        else:
            pos = (y_true == pos_label)
            tp = np.logical_and((y_true == y_pred), pos).sum()
            pos_pred = (y_pred == pos_label).sum()
            return tp / pos_pred if pos_pred > 0 else 0
    elif average is None:
        y_labels = set(list(np.unique(y_true)))
        if labels is None:
            labels = list(y_labels)
        else:
            for i in labels:
                if (i not in y_labels and y_type != 'multilabel') or (y_type == 'multilabel' and i >= y_true.shape[1]):
                    warnings.warn('label {} is not contained in the data'.format(i), UserWarning)

        if y_type in ['binary', 'multiclass']:
            y_pred_right = y_true == y_pred
            pos_list = [y_true == i for i in labels]
            pos_sum_list = [(y_pred == i).sum() for i in labels]
            return np.array([np.logical_and(y_pred_right, pos_i).sum() / sum_i if sum_i > 0 else 0
                             for pos_i, sum_i in zip(pos_list, pos_sum_list)])
        elif y_type == 'multilabel':
            y_pred_right = y_true == y_pred
            pos = (y_true == pos_label)
            tp = np.logical_and(y_pred_right, pos).sum(0)
            pos_sum = (y_pred == pos_label).sum(0)
            return np.array([tp[i] / pos_sum[i] if pos_sum[i] > 0 else 0 for i in labels])
        else:
            raise ValueError('unsupported target type {}'.format(y_type))

    raise ValueError('unsupported average type {}'.format(average))
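And the corresponding check for precision (illustrative values): true positives over predicted positives.

import numpy as np

y_true = np.array([1, 1, 0, 1])
y_pred = np.array([1, 0, 1, 1])
precision_score(y_true, y_pred)   # tp = 2, predicted positives = 3 -> 2/3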
def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
    precision = precision_score(y_true, y_pred, labels=labels, pos_label=pos_label, average=average)
    recall = recall_score(y_true, y_pred, labels=labels, pos_label=pos_label, average=average)
    if isinstance(precision, np.ndarray):
        res = 2 * precision * recall / (precision + recall + 1e-10)
        res[(precision + recall) <= 0] = 0
        return res
    return 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
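For instance, with precision 0.5 and recall 1.0 the scalar branch computes the harmonic mean:

2 * 0.5 * 1.0 / (0.5 + 1.0)   # = 2/3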
def accuracy_topk(y_true, y_prob, k=1):
    """Compute the accuracy of y_true matching the top-k most probable
    labels in y_prob.
@@ -78,6 +78,18 @@ class DataSetLoader(BaseLoader):
        raise NotImplementedError


@DataSet.set_reader("read_naive")
class NativeDataSetLoader(DataSetLoader):
    def __init__(self):
        super(NativeDataSetLoader, self).__init__()

    def load(self, path):
        ds = DataSet.read_csv(path, headers=("raw_sentence", "label"), sep="\t")
        ds.set_input("raw_sentence")
        ds.set_target("label")
        return ds
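A hedged usage sketch of the loader added above; the sample path is the one referenced by the new tutorial notebook at the end of this diff.

loader = NativeDataSetLoader()
ds = loader.load("../test/data_for_tests/tutorial_sample_dataset.csv")
# equivalently, via the name registered by the decorator:
ds = DataSet.read_naive("../test/data_for_tests/tutorial_sample_dataset.csv")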
@DataSet.set_reader('read_raw')
class RawDataSetLoader(DataSetLoader):
    def __init__(self):
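For orientation, a hedged sketch of what a set_reader-style registration decorator could look like. This is an assumption about the internals, not code from this diff; the diff only shows that the decorator exposes the loader as DataSet.read_naive, which the tutorial notebook below calls.

# Hypothetical sketch (assumed, not from the diff).
def set_reader(method_name):
    def wrapper(loader_cls):
        def read(path, *args, **kwargs):
            # instantiate the loader and delegate to its load()
            return loader_cls().load(path, *args, **kwargs)
        setattr(DataSet, method_name, staticmethod(read))
        return loader_cls
    return wrapper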
@@ -1,253 +1,13 @@
import math
import unittest

import torch
import torch as tc
import torch.nn.functional as F

import fastNLP.core.losses as loss
from fastNLP.core.losses import squash, unpad
class TestLoss(unittest.TestCase):
    def test_case_1(self):
        loss_func = loss.LossFunc(F.nll_loss)
        nll_loss = loss.NLLLoss()
        y = tc.Tensor(
            [
                [.3, .4, .3],
                [.5, .3, .2],
                [.3, .6, .1],
            ]
        )
        gy = tc.LongTensor(
            [
                0,
                1,
                2,
            ]
        )

        y = tc.log(y)
        los = loss_func({'input': y}, {'target': gy})
        losses = nll_loss({'input': y}, {'target': gy})

        r = -math.log(.3) - math.log(.3) - math.log(.1)
        r /= 3

        print("loss = %f" % (los))
        print("r = %f" % (r))
        print("nll_loss = %f" % (losses))
        self.assertEqual(int(los * 1000), int(r * 1000))
    def test_case_2(self):
        # verify the correctness of squash()
        log = math.log
        loss_func = loss.LossFromTorch("nll")
        y = tc.Tensor(
            [
                [[.3, .4, .3], [.3, .4, .3], ],
                [[.5, .3, .2], [.1, .2, .7], ],
                [[.3, .6, .1], [.2, .1, .7], ],
            ]
        )
        gy = tc.LongTensor(
            [
                [0, 2],
                [1, 2],
                [2, 1],
            ]
        )

        y = tc.log(y)
        # los = loss_func({'input': y}, {'target': gy})
        los = loss_func(y, gy)

        r = -log(.3) - log(.3) - log(.1) - log(.3) - log(.7) - log(.1)
        r /= 6
        self.assertEqual(int(los * 1000), int(r * 1000))
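What squash does here, in effect (a sketch consistent with the shapes asserted in TestLossUtils at the bottom of this file): the 3-D predictions are flattened to 2-D and the targets to 1-D, so the sequence case reduces to ordinary NLL.

# y: (3, 2, 3) log-probabilities, gy: (3, 2) targets
y_flat = y.view(-1, y.size(-1))       # (6, 3)
gy_flat = gy.contiguous().view(-1)    # (6,)
# F.nll_loss(y_flat, gy_flat) then averages over all 6 positions, matching r above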
    def test_case_3(self):
        # verify the correctness of pack_padded_sequence()
        log = math.log
        loss_func = loss.NLLLoss()
        y = tc.Tensor(
            [
                [[.3, .4, .3], [.3, .2, .5], [.4, .5, .1, ], ],
                [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], ],
                [[.3, .6, .1], [.0, .0, .0], [.0, .0, .0, ], ],
            ]
        )
        gy = tc.LongTensor(
            [
                [0, 2, 1, ],
                [1, 2, 0, ],
                [2, 0, 0, ],
            ]
        )
        lens = [3, 2, 1]

        # pdb.set_trace()
        y = tc.log(y)

        yy = tc.nn.utils.rnn.pack_padded_sequence(y, lens, batch_first=True).data
        gyy = tc.nn.utils.rnn.pack_padded_sequence(gy, lens, batch_first=True).data
        los = loss_func({'input': yy}, {'target': gyy})

        r = -log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1)
        r /= 6
        self.assertEqual(int(los * 1000), int(r * 1000))
    def test_case_4(self):
        # verify the correctness of unpad()
        log = math.log
        y = tc.Tensor(
            [
                [[.3, .4, .3], [.3, .2, .5], [.4, .5, .1, ], [.6, .3, .1, ], ],
                [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], [.0, .0, .0, ], ],
                [[.3, .6, .1], [.0, .0, .0], [.0, .0, .0, ], [.0, .0, .0, ], ],
            ]
        )
        gy = tc.LongTensor(
            [
                [0, 2, 1, 2, ],
                [1, 2, 0, 0, ],
                [2, 0, 0, 0, ],
            ]
        )
        lens = [4, 2, 1]

        y = tc.log(y)

        loss_func = loss.LossFromTorch("nll", pre_pro=["unpad"])
        los = loss_func(y, gy, lens=lens)

        r = -log(.1) - log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1)
        r /= 7
        self.assertEqual(int(los * 1000), int(r * 1000))
    def test_case_5(self):
        # verify the correctness of mask() and make_mask()
        log = math.log
        y = tc.Tensor(
            [
                [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], [.0, .0, .0, ], ],
                [[.5, .4, .1], [.3, .2, .5], [.4, .5, .1, ], [.6, .1, .3, ], ],
                [[.3, .6, .1], [.3, .2, .5], [.0, .0, .0, ], [.0, .0, .0, ], ],
            ]
        )
        gy = tc.LongTensor(
            [
                [1, 2, 0, 0, ],
                [0, 2, 1, 2, ],
                [2, 1, 0, 0, ],
            ]
        )
        mask = tc.ByteTensor(
            [
                [1, 1, 0, 0, ],
                [1, 1, 1, 1, ],
                [1, 1, 0, 0, ],
            ]
        )
        y = tc.log(y)
        lens = [2, 4, 2]

        loss_func = loss.LossFromTorch("nll", pre_pro=["mask"])
        los = loss_func(y, gy, mask=mask)
        los2 = loss_func(y, gy, mask=loss.make_mask(lens, gy.size()[-1]))

        r = -log(.3) - log(.7) - log(.5) - log(.5) - log(.5) - log(.3) - log(.1) - log(.2)
        r /= 8

        self.assertEqual(int(los * 1000), int(r * 1000))
        self.assertEqual(int(los2 * 1000), int(r * 1000))
    def test_case_6(self):
        # verify the correctness of unpad_mask()
        log = math.log
        y = tc.Tensor(
            [
                [[.3, .4, .3], [.3, .2, .5], [.4, .5, .1, ], [.6, .3, .1, ], ],
                [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], [.0, .0, .0, ], ],
                [[.3, .6, .1], [.0, .0, .0], [.0, .0, .0, ], [.0, .0, .0, ], ],
            ]
        )
        gy = tc.LongTensor(
            [
                [0, 2, 1, 2, ],
                [1, 2, 0, 0, ],
                [2, 0, 0, 0, ],
            ]
        )
        lens = [4, 2, 1]

        # pdb.set_trace()
        y = tc.log(y)

        loss_func = loss.LossFromTorch("nll", pre_pro=["unpad_mask"])
        los = loss_func(y, gy, lens=lens)

        r = -log(.1) - log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1)
        r /= 7
        self.assertEqual(int(los * 1000), int(r * 1000))
    def test_case_7(self):
        # verify some other things
        log = math.log
        y = tc.Tensor(
            [
                [[.3, .4, .3], [.3, .2, .5], [.4, .5, .1, ], [.6, .3, .1, ], ],
                [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], [.0, .0, .0, ], ],
                [[.3, .6, .1], [.0, .0, .0], [.0, .0, .0, ], [.0, .0, .0, ], ],
            ]
        )
        gy = tc.LongTensor(
            [
                [0, 2, 1, 2, ],
                [1, 2, 0, 0, ],
                [2, 0, 0, 0, ],
            ]
        )
        lens = [4, 2, 1]

        y = tc.log(y)

        loss_func = loss.LossFromTorch("nll", pre_pro=[], weight=tc.Tensor([1, 1, 0]))
        loss_func.add_pre_pro("unpad_mask")
        los = loss_func(y, gy, lens=lens)

        r = -log(.3) - log(.5) - log(.3)
        r /= 3
        self.assertEqual(int(los * 1000), int(r * 1000))
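Why only three terms survive in test_case_7: unpad_mask keeps 7 positions (lens 4 + 2 + 1), and the class weight [1, 1, 0] then zeroes every position whose target is class 2, in both the sum and the weighted denominator. That leaves targets 0 and 1 from row 0 and target 1 from row 1, so:

(-math.log(.3) - math.log(.5) - math.log(.3)) / 3   # ~1.034, the r computed above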
    def test_case_8(self):
        pass


class TestLoss_v2(unittest.TestCase):
    def test_CrossEntropyLoss(self):
        ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth")
        a = torch.randn(3, 5, requires_grad=False)
@@ -276,6 +36,7 @@ class TestLoss_v2(unittest.TestCase):
        ans = l1({"my_predict": a}, {"my_truth": b})
        self.assertEqual(ans, torch.nn.functional.nll_loss(a, b))
class TestLosserError(unittest.TestCase):
    def test_losser1(self):
        # (1) only input, targets passed
@@ -292,11 +53,12 @@ class TestLosserError(unittest.TestCase):
        target_dict = {'target': torch.zeros(16, 3).long()}
        los = loss.CrossEntropyLoss()

        # print(los(pred_dict=pred_dict, target_dict=target_dict))
        with self.assertRaises(RuntimeError):
            print(los(pred_dict=pred_dict, target_dict=target_dict))

    def test_losser3(self):
        # (2) with corrupted size
-       pred_dict = {"pred": torch.zeros(16, 3), 'stop_fast_param':0}
+       pred_dict = {"pred": torch.zeros(16, 3), 'stop_fast_param': 0}
        target_dict = {'target': torch.zeros(16).long()}
        los = loss.CrossEntropyLoss()
@@ -311,3 +73,15 @@ class TestLosserError(unittest.TestCase):
        with self.assertRaises(Exception):
            ans = l1({"my_predict": a}, {"truth": b, "my": a})
class TestLossUtils(unittest.TestCase):
    def test_squash(self):
        a, b = squash(torch.randn(3, 5), torch.randn(3, 5))
        self.assertEqual(tuple(a.size()), (3, 5))
        self.assertEqual(tuple(b.size()), (15,))

    def test_unpad(self):
        a, b = unpad(torch.randn(5, 8, 3), torch.randn(5, 8))
        self.assertEqual(tuple(a.size()), (5, 8, 3))
        self.assertEqual(tuple(b.size()), (5, 8))
@@ -4,7 +4,7 @@ import numpy as np
import torch

from fastNLP.core.metrics import AccuracyMetric
-from fastNLP.core.metrics import accuracy_score, recall_score, precision_score, f1_score, pred_topk, accuracy_topk
+from fastNLP.core.metrics import pred_topk, accuracy_topk


class TestAccuracyMetric(unittest.TestCase):
@@ -139,10 +139,6 @@ class TestUsefulFunctions(unittest.TestCase):
    # test some of the seemingly useful functions in metrics.py
    def test_case_1(self):
        # multi-class
-       _ = accuracy_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)))
-       _ = precision_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), average=None)
-       _ = recall_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), average=None)
-       _ = f1_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), average=None)
        _ = accuracy_topk(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), k=3)
        _ = pred_topk(np.random.randint(0, 3, size=(10, 1)))
@@ -0,0 +1,81 @@ | |||
{ | |||
"cells": [ | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": { | |||
"collapsed": true | |||
}, | |||
"source": [ | |||
"# 六行代码搞定FastNLP" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"from fastNLP.core.dataset import DataSet\n", | |||
"import fastNLP.io.dataset_loader" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"ds = DataSet.read_naive(\"../test/data_for_tests/tutorial_sample_dataset.csv\")" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [] | |||
} | |||
], | |||
"metadata": { | |||
"kernelspec": { | |||
"display_name": "Python 2", | |||
"language": "python", | |||
"name": "python2" | |||
}, | |||
"language_info": { | |||
"codemirror_mode": { | |||
"name": "ipython", | |||
"version": 2 | |||
}, | |||
"file_extension": ".py", | |||
"mimetype": "text/x-python", | |||
"name": "python", | |||
"nbconvert_exporter": "python", | |||
"pygments_lexer": "ipython2", | |||
"version": "2.7.6" | |||
} | |||
}, | |||
"nbformat": 4, | |||
"nbformat_minor": 0 | |||
} |