
* remove unused code in losses.py & metrics.py

* refine code style
* fix tests
* add a new tutorial
tags/v0.4.10
FengZiYjun committed 6 years ago
parent commit 447746d9f5
6 changed files with 119 additions and 538 deletions
  1. fastNLP/core/losses.py  (+3, -112)
  2. fastNLP/core/metrics.py  (+5, -178)
  3. fastNLP/io/dataset_loader.py  (+12, -0)
  4. test/core/test_loss.py  (+17, -243)
  5. test/core/test_metrics.py  (+1, -5)
  6. tutorials/fastnlp_in_six_lines.ipynb  (+81, -0)

fastNLP/core/losses.py  (+3, -112)

@@ -195,6 +195,7 @@ class CrossEntropyLoss(LossBase):
return F.cross_entropy(input=pred, target=target,
ignore_index=self.padding_idx)


class L1Loss(LossBase):
def __init__(self, pred=None, target=None):
super(L1Loss, self).__init__()
@@ -212,6 +213,7 @@ class BCELoss(LossBase):
def get_loss(self, pred, target):
return F.binary_cross_entropy(input=pred, target=target)


class NLLLoss(LossBase):
def __init__(self, pred=None, target=None):
super(NLLLoss, self).__init__()
@@ -259,7 +261,7 @@ def _prepare_losser(losser):
elif isinstance(losser, LossBase):
return losser
else:
raise TypeError(f"Type of losser should be `fastNLP.LossBase`, got {type(losser)}")
raise TypeError(f"Type of loss should be `fastNLP.LossBase`, got {type(losser)}")


def squash(predict, truth, **kwargs):
@@ -354,114 +356,3 @@ def make_mask(lens, tar_len):
mask = torch.stack(mask, 1)
return mask


# map strings to functions, for more convenient use
method_dict = {
"squash": squash,
"unpad": unpad,
"unpad_mask": unpad_mask,
"mask": mask,
}

loss_function_name = {
"L1Loss".lower(): torch.nn.L1Loss,
"BCELoss".lower(): torch.nn.BCELoss,
"MSELoss".lower(): torch.nn.MSELoss,
"NLLLoss".lower(): torch.nn.NLLLoss,
"KLDivLoss".lower(): torch.nn.KLDivLoss,
"NLLLoss2dLoss".lower(): torch.nn.NLLLoss2d, # every name should end with "loss"
"SmoothL1Loss".lower(): torch.nn.SmoothL1Loss,
"SoftMarginLoss".lower(): torch.nn.SoftMarginLoss,
"PoissonNLLLoss".lower(): torch.nn.PoissonNLLLoss,
"MultiMarginLoss".lower(): torch.nn.MultiMarginLoss,
"CrossEntropyLoss".lower(): torch.nn.CrossEntropyLoss,
"BCEWithLogitsLoss".lower(): torch.nn.BCEWithLogitsLoss,
"MarginRankingLoss".lower(): torch.nn.MarginRankingLoss,
"TripletMarginLoss".lower(): torch.nn.TripletMarginLoss,
"HingeEmbeddingLoss".lower(): torch.nn.HingeEmbeddingLoss,
"CosineEmbeddingLoss".lower(): torch.nn.CosineEmbeddingLoss,
"MultiLabelMarginLoss".lower(): torch.nn.MultiLabelMarginLoss,
"MultiLabelSoftMarginLoss".lower(): torch.nn.MultiLabelSoftMarginLoss,
}


class LossFromTorch(object):
"""a LossFromTorch object is a callable object represents loss functions

This class only helps you with loss functions from PyTorch.
It has nothing to do with Trainer.
"""

def __init__(self, loss_name, pre_pro=[squash], **kwargs):
"""

:param loss_name: str or None, the name of the loss function
:param pre_pro: list of functions or str, methods to preprocess parameters before calculating the loss;
strings are automatically translated to the pre-defined functions
:param **kwargs: kwargs for the torch loss function

pre_pro functions should take three arguments: predict, truth, **kwargs;
predict and truth are the required parameters of the loss function,
and kwargs are the extra parameters passed in when the loss function is called.
pre_pro functions should return two objects: the processed predict and truth.

"""

if loss_name is None:
# this is useful when Trainer.__init__ performs a type check
self._loss = None
else:
if not isinstance(loss_name, str):
raise NotImplementedError
else:
self._loss = self._get_loss(loss_name, **kwargs)

self.pre_pro = [f if callable(f) else method_dict.get(f) for f in pre_pro]

def add_pre_pro(self, func):
"""add a pre_pro function

:param func: a function or str, a method to preprocess parameters before calculating the loss;
a string is automatically translated to the corresponding pre-defined function
"""
if not callable(func):
func = method_dict.get(func)
if func is None:
return
self.pre_pro.append(func)

@staticmethod
def _get_loss(loss_name, **kwargs):
"""Get loss function from torch

:param loss_name: str, the name of loss function
:param **kwargs: kwargs for torch loss function
:return: A callable loss function object
"""
loss_name = loss_name.strip().lower()
loss_name = "".join(loss_name.split("_"))

if len(loss_name) < 4 or loss_name[-4:] != "loss":
loss_name += "loss"
return loss_function_name[loss_name](**kwargs)
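# A worked example of the normalization above (hypothetical input):
#   _get_loss("Smooth_L1") -> strip()/lower() gives "smooth_l1" -> removing "_"
#   gives "smoothl1" -> appending "loss" gives "smoothl1loss", which maps to
#   torch.nn.SmoothL1Loss(**kwargs) in loss_function_name.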

def get(self):
"""This method exists just for make some existing codes run error-freely
"""
return self

def __call__(self, predict, truth, **kwargs):
"""Call a loss function
predict and truth will be processed by pre_pro methods in order of addition

:param predict : Tensor, model output
:param truth : Tensor, truth from dataset
:param **kwargs : extra arguments, pass to pre_pro functions
for example, if unpad_mask() is used in pre_pro, a kwarg named lens is required
"""
for f in self.pre_pro:
if f is None:
continue
predict, truth = f(predict, truth, **kwargs)

return self._loss(predict, truth)
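For reference, a minimal sketch of how the removed LossFromTorch API was used, mirroring the deleted test_case_4 further down; the shapes and values are illustrative, and this API no longer exists after this commit:

import torch
import fastNLP.core.losses as loss

# log-probabilities over 3 classes for a padded batch, plus gold labels and lengths
y = torch.log(torch.softmax(torch.randn(3, 4, 3), dim=-1))
gy = torch.randint(0, 3, (3, 4))
lens = [4, 2, 1]

# "nll" resolves to torch.nn.NLLLoss via loss_function_name;
# the "unpad" pre_pro strips padded positions using the lens kwarg
loss_func = loss.LossFromTorch("nll", pre_pro=["unpad"])
los = loss_func(y, gy, lens=lens)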

fastNLP/core/metrics.py  (+5, -178)

@@ -1,5 +1,4 @@
import inspect
- import warnings
from collections import defaultdict

import numpy as np
@@ -197,19 +196,19 @@ class AccuracyMetric(MetricBase):
"""
fast_param = {}
targets = list(target_dict.values())
- if len(targets)==1 and isinstance(targets[0], torch.Tensor):
- if len(pred_dict)==1:
+ if len(targets) == 1 and isinstance(targets[0], torch.Tensor):
+ if len(pred_dict) == 1:
pred = list(pred_dict.values())[0]
fast_param['pred'] = pred
- elif len(pred_dict)==2:
+ elif len(pred_dict) == 2:
pred1 = list(pred_dict.values())[0]
pred2 = list(pred_dict.values())[1]
if not (isinstance(pred1, torch.Tensor) and isinstance(pred2, torch.Tensor)):
return fast_param
- if len(pred1.size())<len(pred2.size()) and len(pred1.size())==1:
+ if len(pred1.size()) < len(pred2.size()) and len(pred1.size()) == 1:
seq_lens = pred1
pred = pred2
- elif len(pred1.size())>len(pred2.size()) and len(pred2.size())==1:
+ elif len(pred1.size()) > len(pred2.size()) and len(pred2.size()) == 1:
seq_lens = pred2
pred = pred1
else:
@@ -308,178 +307,6 @@ def _prepare_metrics(metrics):
return _metrics


"""
Attention: Codes below are not used in current FastNLP.
However, it is useful.

"""


def _convert_numpy(x):
"""convert input data to numpy array

"""
if isinstance(x, np.ndarray):
return x
elif isinstance(x, torch.Tensor):
return x.numpy()
elif isinstance(x, list):
return np.array(x)
raise TypeError('cannot accept object: {}'.format(x))


def _check_same_len(*arrays, axis=0):
"""check if input array list has same length for one dimension

"""
lens = set([x.shape[axis] for x in arrays if x is not None])
return len(lens) == 1


def _label_types(y):
"""Determine the type
- "binary"
- "multiclass"
- "multiclass-multioutput"
- "multilabel"
- "unknown"
"""
# never squeeze the first dimension
y = y.squeeze() if y.shape[0] > 1 else y.reshape(1, -1)  # ndarray.resize is in-place and returns None; reshape is intended here
shape = y.shape
if len(shape) < 1:
raise ValueError('cannot accept data: {}'.format(y))
if len(shape) == 1:
return 'multiclass' if np.unique(y).shape[0] > 2 else 'binary', y
if len(shape) == 2:
return 'multiclass-multioutput' if np.unique(y).shape[0] > 2 else 'multilabel', y
return 'unknown', y
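# Examples of the mapping above (hypothetical inputs):
#   y = [0, 1, 1, 0]                      -> 'binary'
#   y = [0, 2, 1, 1]                      -> 'multiclass'
#   y of shape (n, k), values in {0, 1}   -> 'multilabel'
#   y of shape (n, k), more than 2 values -> 'multiclass-multioutput'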


def _check_data(y_true, y_pred):
"""Check if y_true and y_pred is same type of data e.g both binary or multiclass

"""
y_true, y_pred = _convert_numpy(y_true), _convert_numpy(y_pred)
if not _check_same_len(y_true, y_pred):
raise ValueError('cannot accept data with different shape {0}, {1}'.format(y_true, y_pred))
type_true, y_true = _label_types(y_true)
type_pred, y_pred = _label_types(y_pred)

type_set = {'binary', 'multiclass'}
if type_true in type_set and type_pred in type_set:
return type_true if type_true == type_pred else 'multiclass', y_true, y_pred

type_set = {'multiclass-multioutput', 'multilabel'}
if type_true in type_set and type_pred in type_set:
return type_true if type_true == type_pred else 'multiclass-multioutput', y_true, y_pred

raise ValueError('cannot accept data mixed of {0} and {1} target'.format(type_true, type_pred))


def _weight_sum(y, normalize=True, sample_weight=None):
if normalize:
return np.average(y, weights=sample_weight)
if sample_weight is None:
return y.sum()
else:
return np.dot(y, sample_weight)


def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None):
y_type, y_true, y_pred = _check_data(y_true, y_pred)
if y_type == 'multiclass-multioutput':
raise ValueError('cannot accept data type {0}'.format(y_type))
if y_type == 'multilabel':
equal = (y_true == y_pred).sum(1)
count = equal == y_true.shape[1]
else:
count = y_true == y_pred
return _weight_sum(count, normalize=normalize, sample_weight=sample_weight)
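# With normalize=True (the default) this returns the (sample-weighted) mean
# accuracy; with normalize=False it returns the raw count of correct samples.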


def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
y_type, y_true, y_pred = _check_data(y_true, y_pred)
if average == 'binary':
if y_type != 'binary':
raise ValueError("data type is {} but use average type {}".format(y_type, average))
else:
pos = (y_true == pos_label)
tp = np.logical_and((y_true == y_pred), pos).sum()
pos_sum = pos.sum()
return tp / pos_sum if pos_sum > 0 else 0
elif average is None:
y_labels = set(list(np.unique(y_true)))
if labels is None:
labels = list(y_labels)
else:
for i in labels:
if (i not in y_labels and y_type != 'multilabel') or (y_type == 'multilabel' and i >= y_true.shape[1]):
warnings.warn('label {} is not contained in data'.format(i), UserWarning)

if y_type in ['binary', 'multiclass']:
y_pred_right = y_true == y_pred
pos_list = [y_true == i for i in labels]
pos_sum_list = [pos_i.sum() for pos_i in pos_list]
return np.array([np.logical_and(y_pred_right, pos_i).sum() / sum_i if sum_i > 0 else 0 \
for pos_i, sum_i in zip(pos_list, pos_sum_list)])
elif y_type == 'multilabel':
y_pred_right = y_true == y_pred
pos = (y_true == pos_label)
tp = np.logical_and(y_pred_right, pos).sum(0)
pos_sum = pos.sum(0)
return np.array([tp[i] / pos_sum[i] if pos_sum[i] > 0 else 0 for i in labels])
else:
raise ValueError('unsupported target type {}'.format(y_type))
raise ValueError('unsupported average type {}'.format(average))


def precision_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
y_type, y_true, y_pred = _check_data(y_true, y_pred)
if average == 'binary':
if y_type != 'binary':
raise ValueError("data type is {} but use average type {}".format(y_type, average))
else:
pos = (y_true == pos_label)
tp = np.logical_and((y_true == y_pred), pos).sum()
pos_pred = (y_pred == pos_label).sum()
return tp / pos_pred if pos_pred > 0 else 0
elif average is None:
y_labels = set(list(np.unique(y_true)))
if labels is None:
labels = list(y_labels)
else:
for i in labels:
if (i not in y_labels and y_type != 'multilabel') or (y_type == 'multilabel' and i >= y_true.shape[1]):
warnings.warn('label {} is not contained in data'.format(i), UserWarning)

if y_type in ['binary', 'multiclass']:
y_pred_right = y_true == y_pred
pos_list = [y_true == i for i in labels]
pos_sum_list = [(y_pred == i).sum() for i in labels]
return np.array([np.logical_and(y_pred_right, pos_i).sum() / sum_i if sum_i > 0 else 0 \
for pos_i, sum_i in zip(pos_list, pos_sum_list)])
elif y_type == 'multilabel':
y_pred_right = y_true == y_pred
pos = (y_true == pos_label)
tp = np.logical_and(y_pred_right, pos).sum(0)
pos_sum = (y_pred == pos_label).sum(0)
return np.array([tp[i] / pos_sum[i] if pos_sum[i] > 0 else 0 for i in labels])
else:
raise ValueError('unsupported target type {}'.format(y_type))
raise ValueError('unsupported average type {}'.format(average))


def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
precision = precision_score(y_true, y_pred, labels=labels, pos_label=pos_label, average=average)
recall = recall_score(y_true, y_pred, labels=labels, pos_label=pos_label, average=average)
if isinstance(precision, np.ndarray):
res = 2 * precision * recall / (precision + recall + 1e-10)
res[(precision + recall) <= 0] = 0
return res
return 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
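# f1 is the harmonic mean of precision and recall, f1 = 2*P*R / (P + R);
# the 1e-10 in the array branch only guards against division by zero.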


def accuracy_topk(y_true, y_prob, k=1):
"""Compute accuracy of y_true matching top-k probable
labels in y_prob.


fastNLP/io/dataset_loader.py  (+12, -0)

@@ -78,6 +78,18 @@ class DataSetLoader(BaseLoader):
raise NotImplementedError


@DataSet.set_reader("read_naive")
class NativeDataSetLoader(DataSetLoader):
def __init__(self):
super(NativeDataSetLoader, self).__init__()

def load(self, path):
ds = DataSet.read_csv(path, headers=("raw_sentence", "label"), sep="\t")
ds.set_input("raw_sentence")
ds.set_target("label")
return ds
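# The @DataSet.set_reader("read_naive") decorator above exposes this loader as
# DataSet.read_naive(path); the new tutorial notebook below calls exactly that:
#   ds = DataSet.read_naive("../test/data_for_tests/tutorial_sample_dataset.csv")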


@DataSet.set_reader('read_raw')
class RawDataSetLoader(DataSetLoader):
def __init__(self):


test/core/test_loss.py  (+17, -243)

@@ -1,253 +1,13 @@
import math
import unittest

import torch
import torch as tc
import torch.nn.functional as F

import fastNLP.core.losses as loss
from fastNLP.core.losses import squash, unpad


class TestLoss(unittest.TestCase):

def test_case_1(self):
loss_func = loss.LossFunc(F.nll_loss)
nll_loss = loss.NLLLoss()
y = tc.Tensor(
[
[.3, .4, .3],
[.5, .3, .2],
[.3, .6, .1],
]
)

gy = tc.LongTensor(
[
0,
1,
2,
]
)

y = tc.log(y)
los = loss_func({'input': y}, {'target': gy})
losses = nll_loss({'input': y}, {'target': gy})

r = -math.log(.3) - math.log(.3) - math.log(.1)
r /= 3
print("loss = %f" % (los))
print("r = %f" % (r))
print("nll_loss = %f" % (losses))

self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_2(self):
# verify the correctness of squash()

log = math.log
loss_func = loss.LossFromTorch("nll")

y = tc.Tensor(
[
[[.3, .4, .3], [.3, .4, .3], ],
[[.5, .3, .2], [.1, .2, .7], ],
[[.3, .6, .1], [.2, .1, .7], ],
]
)

gy = tc.LongTensor(
[
[0, 2],
[1, 2],
[2, 1],
]
)

y = tc.log(y)
# los = loss_func({'input': y}, {'target': gy})
los = loss_func(y, gy)

r = -log(.3) - log(.3) - log(.1) - log(.3) - log(.7) - log(.1)
r /= 6

self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_3(self):
# verify the correctness of pack_padded_sequence()
log = math.log
loss_func = loss.NLLLoss()
y = tc.Tensor(
[
[[.3, .4, .3], [.3, .2, .5], [.4, .5, .1, ], ],
[[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], ],
[[.3, .6, .1], [.0, .0, .0], [.0, .0, .0, ], ],
]
)

gy = tc.LongTensor(
[
[0, 2, 1, ],
[1, 2, 0, ],
[2, 0, 0, ],
]
)

lens = [3, 2, 1]

# pdb.set_trace()

y = tc.log(y)

yy = tc.nn.utils.rnn.pack_padded_sequence(y, lens, batch_first=True).data
gyy = tc.nn.utils.rnn.pack_padded_sequence(gy, lens, batch_first=True).data
los = loss_func({'input': yy}, {'target': gyy})

r = -log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1)
r /= 6

self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_4(self):
# verify the correctness of unpad()
log = math.log
y = tc.Tensor(
[
[[.3, .4, .3], [.3, .2, .5], [.4, .5, .1, ], [.6, .3, .1, ], ],
[[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], [.0, .0, .0, ], ],
[[.3, .6, .1], [.0, .0, .0], [.0, .0, .0, ], [.0, .0, .0, ], ],
]
)

gy = tc.LongTensor(
[
[0, 2, 1, 2, ],
[1, 2, 0, 0, ],
[2, 0, 0, 0, ],
]
)

lens = [4, 2, 1]
y = tc.log(y)

loss_func = loss.LossFromTorch("nll", pre_pro=["unpad"])
los = loss_func(y, gy, lens=lens)

r = -log(.1) - log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1)
r /= 7

self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_5(self):
# verify the correctness of mask() and make_mask()
log = math.log

y = tc.Tensor(
[
[[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], [.0, .0, .0, ], ],
[[.5, .4, .1], [.3, .2, .5], [.4, .5, .1, ], [.6, .1, .3, ], ],
[[.3, .6, .1], [.3, .2, .5], [.0, .0, .0, ], [.0, .0, .0, ], ],
]
)

gy = tc.LongTensor(
[
[1, 2, 0, 0, ],
[0, 2, 1, 2, ],
[2, 1, 0, 0, ],
]
)

mask = tc.ByteTensor(
[
[1, 1, 0, 0, ],
[1, 1, 1, 1, ],
[1, 1, 0, 0, ],
]
)

y = tc.log(y)

lens = [2, 4, 2]

loss_func = loss.LossFromTorch("nll", pre_pro=["mask"])
los = loss_func(y, gy, mask=mask)

los2 = loss_func(y, gy, mask=loss.make_mask(lens, gy.size()[-1]))

r = -log(.3) - log(.7) - log(.5) - log(.5) - log(.5) - log(.3) - log(.1) - log(.2)
r /= 8

self.assertEqual(int(los * 1000), int(r * 1000))
self.assertEqual(int(los2 * 1000), int(r * 1000))

def test_case_6(self):
# verify the correctness of unpad_mask()
log = math.log
y = tc.Tensor(
[
[[.3, .4, .3], [.3, .2, .5], [.4, .5, .1, ], [.6, .3, .1, ], ],
[[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], [.0, .0, .0, ], ],
[[.3, .6, .1], [.0, .0, .0], [.0, .0, .0, ], [.0, .0, .0, ], ],
]
)

gy = tc.LongTensor(
[
[0, 2, 1, 2, ],
[1, 2, 0, 0, ],
[2, 0, 0, 0, ],
]
)

lens = [4, 2, 1]

# pdb.set_trace()

y = tc.log(y)

loss_func = loss.LossFromTorch("nll", pre_pro=["unpad_mask"])
los = loss_func(y, gy, lens=lens)

r = -log(.1) - log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1)
r /= 7

self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_7(self):
# verify a few other things
log = math.log
y = tc.Tensor(
[
[[.3, .4, .3], [.3, .2, .5], [.4, .5, .1, ], [.6, .3, .1, ], ],
[[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], [.0, .0, .0, ], ],
[[.3, .6, .1], [.0, .0, .0], [.0, .0, .0, ], [.0, .0, .0, ], ],
]
)

gy = tc.LongTensor(
[
[0, 2, 1, 2, ],
[1, 2, 0, 0, ],
[2, 0, 0, 0, ],
]
)

lens = [4, 2, 1]
y = tc.log(y)

loss_func = loss.LossFromTorch("nll", pre_pro=[], weight=tc.Tensor([1, 1, 0]))
loss_func.add_pre_pro("unpad_mask")
los = loss_func(y, gy, lens=lens)

r = - log(.3) - log(.5) - log(.3)
r /= 3
self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_8(self):
pass


class TestLoss_v2(unittest.TestCase):
def test_CrossEntropyLoss(self):
ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth")
a = torch.randn(3, 5, requires_grad=False)
@@ -276,6 +36,7 @@ class TestLoss_v2(unittest.TestCase):
ans = l1({"my_predict": a}, {"my_truth": b})
self.assertEqual(ans, torch.nn.functional.nll_loss(a, b))


class TestLosserError(unittest.TestCase):
def test_losser1(self):
# (1) only input, targets passed
@@ -292,11 +53,12 @@ class TestLosserError(unittest.TestCase):
target_dict = {'target': torch.zeros(16, 3).long()}
los = loss.CrossEntropyLoss()

# print(los(pred_dict=pred_dict, target_dict=target_dict))
with self.assertRaises(RuntimeError):
print(los(pred_dict=pred_dict, target_dict=target_dict))

def test_losser3(self):
# (2) with corrupted size
pred_dict = {"pred": torch.zeros(16, 3), 'stop_fast_param':0}
pred_dict = {"pred": torch.zeros(16, 3), 'stop_fast_param': 0}
target_dict = {'target': torch.zeros(16).long()}
los = loss.CrossEntropyLoss()

@@ -311,3 +73,15 @@ class TestLosserError(unittest.TestCase):

with self.assertRaises(Exception):
ans = l1({"my_predict": a}, {"truth": b, "my": a})


class TestLossUtils(unittest.TestCase):
def test_squash(self):
a, b = squash(torch.randn(3, 5), torch.randn(3, 5))
self.assertEqual(tuple(a.size()), (3, 5))
self.assertEqual(tuple(b.size()), (15,))

def test_unpad(self):
a, b = unpad(torch.randn(5, 8, 3), torch.randn(5, 8))
self.assertEqual(tuple(a.size()), (5, 8, 3))
self.assertEqual(tuple(b.size()), (5, 8))
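The shape assertions above pin down the behaviour of the two utilities: squash reshapes predictions to (-1, n_classes) and flattens the targets, while unpad leaves both tensors unchanged when no lens kwarg is given. A minimal sketch of squash consistent with these tests (an assumption for illustration; the actual fastNLP implementation is untouched by this commit):

import torch

def squash_sketch(predict, truth):
    # (3, 5), (3, 5) -> (3, 5), (15,), matching test_squash above
    return predict.view(-1, predict.size(-1)), truth.view(-1)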

test/core/test_metrics.py  (+1, -5)

@@ -4,7 +4,7 @@ import numpy as np
import torch

from fastNLP.core.metrics import AccuracyMetric
- from fastNLP.core.metrics import accuracy_score, recall_score, precision_score, f1_score, pred_topk, accuracy_topk
+ from fastNLP.core.metrics import pred_topk, accuracy_topk


class TestAccuracyMetric(unittest.TestCase):
@@ -139,10 +139,6 @@ class TestUsefulFunctions(unittest.TestCase):
# test some useful-looking functions in metrics.py
def test_case_1(self):
# multi-class
_ = accuracy_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)))
_ = precision_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), average=None)
_ = recall_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), average=None)
_ = f1_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), average=None)
_ = accuracy_topk(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), k=3)
_ = pred_topk(np.random.randint(0, 3, size=(10, 1)))



tutorials/fastnlp_in_six_lines.ipynb  (+81, -0)

@@ -0,0 +1,81 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"# 六行代码搞定FastNLP"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastNLP.core.dataset import DataSet\n",
"import fastNLP.io.dataset_loader"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ds = DataSet.read_naive(\"../test/data_for_tests/tutorial_sample_dataset.csv\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
