Browse Source

* fix processor.py

* add code comments
* merge *_saver.py & *_loader.py in io/
* (ancient codes) rename Loss into LossFromTorch
tags/v0.4.10
FengZiYjun 6 years ago
parent
commit
27e9453d19
22 changed files with 349 additions and 386 deletions
  1. +2
    -6
      fastNLP/api/model_zoo.py
  2. +23
    -11
      fastNLP/api/processor.py
  3. +1
    -1
      fastNLP/core/__init__.py
  4. +47
    -19
      fastNLP/core/dataset.py
  5. +33
    -22
      fastNLP/core/losses.py
  6. +7
    -0
      fastNLP/core/metrics.py
  7. +12
    -0
      fastNLP/core/optimizer.py
  8. +0
    -9
      fastNLP/core/trainer.py
  9. +0
    -16
      fastNLP/io/base_loader.py
  10. +148
    -2
      fastNLP/io/config_io.py
  11. +0
    -149
      fastNLP/io/config_loader.py
  12. +21
    -105
      fastNLP/io/dataset_loader.py
  13. +28
    -0
      fastNLP/io/model_io.py
  14. +0
    -28
      fastNLP/io/model_loader.py
  15. +1
    -1
      reproduction/Biaffine_parser/infer.py
  16. +2
    -3
      reproduction/Biaffine_parser/run.py
  17. +2
    -2
      reproduction/LSTM+self_attention_sentiment_analysis/main.py
  18. +2
    -3
      reproduction/chinese_word_segment/run.py
  19. +2
    -2
      setup.py
  20. +12
    -0
      test/api/test_processor.py
  21. +5
    -5
      test/core/test_loss.py
  22. +1
    -2
      test/io/test_config_saver.py

+ 2
- 6
fastNLP/api/model_zoo.py View File

@@ -1,5 +1,3 @@
import torch

import hashlib import hashlib
import os import os
import re import re
@@ -7,6 +5,8 @@ import shutil
import sys import sys
import tempfile import tempfile


import torch

try: try:
from requests.utils import urlparse from requests.utils import urlparse
from requests import get as urlopen from requests import get as urlopen
@@ -132,7 +132,3 @@ if tqdm is None:


sys.stderr.write('\n') sys.stderr.write('\n')



if __name__ == '__main__':
pipeline = load_url('http://10.141.208.102:5000/file/download/infer_context-4e86fd93.pkl', model_dir='.')
print(type(pipeline))

+ 23
- 11
fastNLP/api/processor.py View File

@@ -1,14 +1,15 @@
import torch
from collections import defaultdict
import re import re
from collections import defaultdict

import torch


from fastNLP.core.dataset import DataSet
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.batch import Batch from fastNLP.core.batch import Batch
from fastNLP.core.dataset import DataSet
from fastNLP.core.sampler import SequentialSampler from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.vocabulary import Vocabulary




class Processor:
class Processor(object):
def __init__(self, field_name, new_added_field_name): def __init__(self, field_name, new_added_field_name):
self.field_name = field_name self.field_name = field_name
if new_added_field_name is None: if new_added_field_name is None:
@@ -17,7 +18,7 @@ class Processor:
self.new_added_field_name = new_added_field_name self.new_added_field_name = new_added_field_name


def process(self, *args, **kwargs): def process(self, *args, **kwargs):
pass
raise NotImplementedError


def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.process(*args, **kwargs) return self.process(*args, **kwargs)
@@ -132,13 +133,14 @@ class Num2TagProcessor(Processor):




class IndexerProcessor(Processor): class IndexerProcessor(Processor):
def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False):
def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True):


assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))


super(IndexerProcessor, self).__init__(field_name, new_added_field_name) super(IndexerProcessor, self).__init__(field_name, new_added_field_name)
self.vocab = vocab self.vocab = vocab
self.delete_old_field = delete_old_field self.delete_old_field = delete_old_field
self.is_input = is_input


def set_vocab(self, vocab): def set_vocab(self, vocab):
assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
@@ -146,13 +148,14 @@ class IndexerProcessor(Processor):
self.vocab = vocab self.vocab = vocab


def process(self, dataset): def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
for ins in dataset: for ins in dataset:
tokens = ins[self.field_name] tokens = ins[self.field_name]
index = [self.vocab.to_index(token) for token in tokens] index = [self.vocab.to_index(token) for token in tokens]
ins[self.new_added_field_name] = index ins[self.new_added_field_name] = index


dataset._set_need_tensor(**{self.new_added_field_name: True})
if self.is_input:
dataset.set_input(self.new_added_field_name)


if self.delete_old_field: if self.delete_old_field:
dataset.delete_field(self.field_name) dataset.delete_field(self.field_name)
@@ -161,6 +164,9 @@ class IndexerProcessor(Processor):




class VocabProcessor(Processor): class VocabProcessor(Processor):
"""Build vocabulary with a field in the data set.

"""
def __init__(self, field_name): def __init__(self, field_name):
super(VocabProcessor, self).__init__(field_name, None) super(VocabProcessor, self).__init__(field_name, None)
self.vocab = Vocabulary() self.vocab = Vocabulary()
@@ -178,17 +184,20 @@ class VocabProcessor(Processor):




class SeqLenProcessor(Processor): class SeqLenProcessor(Processor):
def __init__(self, field_name, new_added_field_name='seq_lens'):
def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True):
super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) super(SeqLenProcessor, self).__init__(field_name, new_added_field_name)
self.is_input = is_input


def process(self, dataset): def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset: for ins in dataset:
length = len(ins[self.field_name]) length = len(ins[self.field_name])
ins[self.new_added_field_name] = length ins[self.new_added_field_name] = length
dataset._set_need_tensor(**{self.new_added_field_name: True})
if self.is_input:
dataset.set_input(self.new_added_field_name)
return dataset return dataset



class ModelProcessor(Processor): class ModelProcessor(Processor):
def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32): def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32):
""" """
@@ -238,6 +247,7 @@ class ModelProcessor(Processor):
device = torch.device(device) device = torch.device(device)
self.model.to(device) self.model.to(device)



class Index2WordProcessor(Processor): class Index2WordProcessor(Processor):
def __init__(self, vocab, field_name, new_added_field_name): def __init__(self, vocab, field_name, new_added_field_name):
super(Index2WordProcessor, self).__init__(field_name, new_added_field_name) super(Index2WordProcessor, self).__init__(field_name, new_added_field_name)
@@ -251,6 +261,7 @@ class Index2WordProcessor(Processor):




class SetTensorProcessor(Processor): class SetTensorProcessor(Processor):
# TODO: remove it. It is strange.
def __init__(self, field_dict, default=False): def __init__(self, field_dict, default=False):
super(SetTensorProcessor, self).__init__(None, None) super(SetTensorProcessor, self).__init__(None, None)
self.field_dict = field_dict self.field_dict = field_dict
@@ -264,6 +275,7 @@ class SetTensorProcessor(Processor):




class SetIsTargetProcessor(Processor): class SetIsTargetProcessor(Processor):
# TODO; remove it.
def __init__(self, field_dict, default=False): def __init__(self, field_dict, default=False):
super(SetIsTargetProcessor, self).__init__(None, None) super(SetIsTargetProcessor, self).__init__(None, None)
self.field_dict = field_dict self.field_dict = field_dict


+ 1
- 1
fastNLP/core/__init__.py View File

@@ -2,7 +2,7 @@ from .batch import Batch
from .dataset import DataSet from .dataset import DataSet
from .fieldarray import FieldArray from .fieldarray import FieldArray
from .instance import Instance from .instance import Instance
from .losses import Loss
from .losses import LossFromTorch
from .optimizer import Optimizer from .optimizer import Optimizer
from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler
from .tester import Tester from .tester import Tester


+ 47
- 19
fastNLP/core/dataset.py View File

@@ -9,32 +9,20 @@ from fastNLP.core.utils import get_func_signature
_READERS = {} _READERS = {}




def construct_dataset(sentences):
"""Construct a data set from a list of sentences.

:param sentences: list of list of str
:return dataset: a DataSet object
"""
dataset = DataSet()
for sentence in sentences:
instance = Instance()
instance['raw_sentence'] = sentence
dataset.append(instance)
return dataset


class DataSet(object): class DataSet(object):
"""DataSet is the collection of examples. """DataSet is the collection of examples.
DataSet provides instance-level interface. You can append and access an instance of the DataSet. DataSet provides instance-level interface. You can append and access an instance of the DataSet.
However, it stores data in a different way: Field-first, Instance-second. However, it stores data in a different way: Field-first, Instance-second.


""" """

def __init__(self, data=None): def __init__(self, data=None):
""" """


:param data: a dict or a list. If it is a dict, the key is the name of a field and the value is the field.
All values must be of the same length.
If it is a list, it must be a list of Instance objects.
:param data: a dict or a list.
If `data` is a dict, the key is the name of a FieldArray and the value is the FieldArray. All values
must be of the same length.
If `data` is a list, it must be a list of Instance objects.
""" """
self.field_arrays = {} self.field_arrays = {}
if data is not None: if data is not None:
@@ -60,6 +48,7 @@ class DataSet(object):
def iter_func(): def iter_func():
for idx in range(len(self)): for idx in range(len(self)):
yield self[idx] yield self[idx]

return iter_func() return iter_func()


def _inner_iter(self): def _inner_iter(self):
@@ -69,7 +58,8 @@ class DataSet(object):
self.idx = idx self.idx = idx


def __getitem__(self, item): def __getitem__(self, item):
assert item in self.dataset.field_arrays, "no such field:{} in Instance {}".format(item, self.dataset[self.idx])
assert item in self.dataset.field_arrays, "no such field:{} in Instance {}".format(item, self.dataset[
self.idx])
assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx) assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx)
return self.dataset.field_arrays[item][self.idx] return self.dataset.field_arrays[item][self.idx]


@@ -79,6 +69,7 @@ class DataSet(object):
def inner_iter_func(): def inner_iter_func():
for idx in range(len(self)): for idx in range(len(self)):
yield Iter_ptr(self, idx) yield Iter_ptr(self, idx)

return inner_iter_func() return inner_iter_func()


def __getitem__(self, idx): def __getitem__(self, idx):
@@ -217,9 +208,17 @@ class DataSet(object):
raise KeyError("{} is not a valid field name.".format(name)) raise KeyError("{} is not a valid field name.".format(name))


def get_input_name(self): def get_input_name(self):
"""Get all field names with `is_input` as True.

:return list field_names: a list of str
"""
return [name for name, field in self.field_arrays.items() if field.is_input] return [name for name, field in self.field_arrays.items() if field.is_input]


def get_target_name(self): def get_target_name(self):
"""Get all field names with `is_target` as True.

:return list field_names: a list of str
"""
return [name for name, field in self.field_arrays.items() if field.is_target] return [name for name, field in self.field_arrays.items() if field.is_target]


@classmethod @classmethod
@@ -243,7 +242,7 @@ class DataSet(object):
:return results: if new_field_name is not passed, returned values of the function over all instances. :return results: if new_field_name is not passed, returned values of the function over all instances.
""" """
results = [func(ins) for ins in self._inner_iter()] results = [func(ins) for ins in self._inner_iter()]
if len(list(filter(lambda x: x is not None, results)))==0: # all None
if len(list(filter(lambda x: x is not None, results))) == 0: # all None
raise ValueError("{} always return None.".format(get_func_signature(func=func))) raise ValueError("{} always return None.".format(get_func_signature(func=func)))


extra_param = {} extra_param = {}
@@ -269,6 +268,12 @@ class DataSet(object):
return results return results


def drop(self, func): def drop(self, func):
"""Drop instances if a condition holds.

:param func: a function that takes an Instance object as input, and returns bool.
The instance will be dropped if the function returns True.

"""
results = [ins for ins in self._inner_iter() if not func(ins)] results = [ins for ins in self._inner_iter() if not func(ins)]
for name, old_field in self.field_arrays.items(): for name, old_field in self.field_arrays.items():
self.field_arrays[name].content = [ins[name] for ins in results] self.field_arrays[name].content = [ins[name] for ins in results]
@@ -338,10 +343,33 @@ class DataSet(object):
return cls(_dict) return cls(_dict)


def save(self, path): def save(self, path):
"""Save the DataSet object as pickle.

:param str path: the path to the pickle
"""
with open(path, 'wb') as f: with open(path, 'wb') as f:
pickle.dump(self, f) pickle.dump(self, f)


@staticmethod @staticmethod
def load(path): def load(path):
"""Load a DataSet object from pickle.

:param str path: the path to the pickle
:return DataSet data_set:
"""
with open(path, 'rb') as f: with open(path, 'rb') as f:
return pickle.load(f) return pickle.load(f)


def construct_dataset(sentences):
"""Construct a data set from a list of sentences.

:param sentences: list of list of str
:return dataset: a DataSet object
"""
dataset = DataSet()
for sentence in sentences:
instance = Instance()
instance['raw_sentence'] = sentence
dataset.append(instance)
return dataset

+ 33
- 22
fastNLP/core/losses.py View File

@@ -7,14 +7,13 @@ import torch.nn.functional as F
from fastNLP.core.utils import CheckError from fastNLP.core.utils import CheckError
from fastNLP.core.utils import CheckRes from fastNLP.core.utils import CheckRes
from fastNLP.core.utils import _build_args from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_function_or_method
from fastNLP.core.utils import _check_arg_dict_list from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _check_function_or_method
from fastNLP.core.utils import get_func_signature from fastNLP.core.utils import get_func_signature




class LossBase(object): class LossBase(object):
def __init__(self): def __init__(self):
# key: name in target function; value: name in output function
self.param_map = {} self.param_map = {}
self._checked = False self._checked = False


@@ -159,8 +158,18 @@ class LossBase(object):


return loss return loss



class LossFunc(LossBase): class LossFunc(LossBase):
"""A wrapper of user-provided loss function.

"""
def __init__(self, func, key_map=None, **kwargs): def __init__(self, func, key_map=None, **kwargs):
"""

:param func: a callable object, such as a function.
:param dict key_map:
:param kwargs:
"""
super(LossFunc, self).__init__() super(LossFunc, self).__init__()
_check_function_or_method(func) _check_function_or_method(func)
if key_map is not None: if key_map is not None:
@@ -254,19 +263,19 @@ def _prepare_losser(losser):




def squash(predict, truth, **kwargs): def squash(predict, truth, **kwargs):
'''To reshape tensors in order to fit Loss functions in pytorch
"""To reshape tensors in order to fit loss functions in pytorch


:param predict : Tensor, model output :param predict : Tensor, model output
:param truth : Tensor, truth from dataset :param truth : Tensor, truth from dataset
:param **kwargs : extra arguments :param **kwargs : extra arguments


:return predict , truth: predict & truth after processing :return predict , truth: predict & truth after processing
'''
"""
return predict.view(-1, predict.size()[-1]), truth.view(-1, ) return predict.view(-1, predict.size()[-1]), truth.view(-1, )




def unpad(predict, truth, **kwargs): def unpad(predict, truth, **kwargs):
'''To process padded sequence output to get true loss
"""To process padded sequence output to get true loss
Using pack_padded_sequence() method Using pack_padded_sequence() method
This method contains squash() This method contains squash()


@@ -277,7 +286,7 @@ def unpad(predict, truth, **kwargs):
the i-th element is true lengths of i-th sequence the i-th element is true lengths of i-th sequence


:return predict , truth: predict & truth after processing :return predict , truth: predict & truth after processing
'''
"""
if kwargs.get("lens") is None: if kwargs.get("lens") is None:
return predict, truth return predict, truth
lens = torch.LongTensor(kwargs["lens"]) lens = torch.LongTensor(kwargs["lens"])
@@ -288,7 +297,7 @@ def unpad(predict, truth, **kwargs):




def unpad_mask(predict, truth, **kwargs): def unpad_mask(predict, truth, **kwargs):
'''To process padded sequence output to get true loss
"""To process padded sequence output to get true loss
Using mask() method Using mask() method
This method contains squash() This method contains squash()


@@ -299,7 +308,7 @@ def unpad_mask(predict, truth, **kwargs):
the i-th element is true lengths of i-th sequence the i-th element is true lengths of i-th sequence


:return predict , truth: predict & truth after processing :return predict , truth: predict & truth after processing
'''
"""
if kwargs.get("lens") is None: if kwargs.get("lens") is None:
return predict, truth return predict, truth
mas = make_mask(kwargs["lens"], truth.size()[1]) mas = make_mask(kwargs["lens"], truth.size()[1])
@@ -307,7 +316,7 @@ def unpad_mask(predict, truth, **kwargs):




def mask(predict, truth, **kwargs): def mask(predict, truth, **kwargs):
'''To select specific elements from Tensor
"""To select specific elements from Tensor
This method contains squash() This method contains squash()


:param predict : Tensor, [batch_size , max_len , tag_size] :param predict : Tensor, [batch_size , max_len , tag_size]
@@ -317,7 +326,7 @@ def mask(predict, truth, **kwargs):
the mask Tensor , the position that is 1 will be selected the mask Tensor , the position that is 1 will be selected


:return predict , truth: predict & truth after processing :return predict , truth: predict & truth after processing
'''
"""
if kwargs.get("mask") is None: if kwargs.get("mask") is None:
return predict, truth return predict, truth
mask = kwargs["mask"] mask = kwargs["mask"]
@@ -332,14 +341,14 @@ def mask(predict, truth, **kwargs):




def make_mask(lens, tar_len): def make_mask(lens, tar_len):
'''to generate a mask that select [:lens[i]] for i-th element
"""to generate a mask that select [:lens[i]] for i-th element
embezzle from fastNLP.models.sequence_modeling.seq_mask embezzle from fastNLP.models.sequence_modeling.seq_mask


:param lens : list or LongTensor, [batch_size] :param lens : list or LongTensor, [batch_size]
:param tar_len : int :param tar_len : int


:return mask : ByteTensor :return mask : ByteTensor
'''
"""
lens = torch.LongTensor(lens) lens = torch.LongTensor(lens)
mask = [torch.ge(lens, i + 1) for i in range(tar_len)] mask = [torch.ge(lens, i + 1) for i in range(tar_len)]
mask = torch.stack(mask, 1) mask = torch.stack(mask, 1)
@@ -376,9 +385,11 @@ loss_function_name = {
} }




class Loss(object):
"""a Loss object is a callable object represents loss functions
class LossFromTorch(object):
"""a LossFromTorch object is a callable object represents loss functions


This class only helps you with loss functions from PyTorch.
It has nothing to do with Trainer.
""" """


def __init__(self, loss_name, pre_pro=[squash], **kwargs): def __init__(self, loss_name, pre_pro=[squash], **kwargs):
@@ -408,11 +419,11 @@ class Loss(object):
self.pre_pro = [f if callable(f) else method_dict.get(f) for f in pre_pro] self.pre_pro = [f if callable(f) else method_dict.get(f) for f in pre_pro]


def add_pre_pro(self, func): def add_pre_pro(self, func):
'''add a pre_pro function
"""add a pre_pro function


:param func: a function or str, methods to reform parameters before calculating loss :param func: a function or str, methods to reform parameters before calculating loss
the strings will be auto translated to pre-defined functions the strings will be auto translated to pre-defined functions
'''
"""
if not callable(func): if not callable(func):
func = method_dict.get(func) func = method_dict.get(func)
if func is None: if func is None:
@@ -421,12 +432,12 @@ class Loss(object):


@staticmethod @staticmethod
def _get_loss(loss_name, **kwargs): def _get_loss(loss_name, **kwargs):
'''Get loss function from torch
"""Get loss function from torch


:param loss_name: str, the name of loss function :param loss_name: str, the name of loss function
:param **kwargs: kwargs for torch loss function :param **kwargs: kwargs for torch loss function
:return: A callable loss function object :return: A callable loss function object
'''
"""
loss_name = loss_name.strip().lower() loss_name = loss_name.strip().lower()
loss_name = "".join(loss_name.split("_")) loss_name = "".join(loss_name.split("_"))


@@ -435,19 +446,19 @@ class Loss(object):
return loss_function_name[loss_name](**kwargs) return loss_function_name[loss_name](**kwargs)


def get(self): def get(self):
'''This method exists just for make some existing codes run error-freely
'''
"""This method exists just for make some existing codes run error-freely
"""
return self return self


def __call__(self, predict, truth, **kwargs): def __call__(self, predict, truth, **kwargs):
'''call a loss function
"""Call a loss function
predict and truth will be processed by pre_pro methods in order of addition predict and truth will be processed by pre_pro methods in order of addition


:param predict : Tensor, model output :param predict : Tensor, model output
:param truth : Tensor, truth from dataset :param truth : Tensor, truth from dataset
:param **kwargs : extra arguments, pass to pre_pro functions :param **kwargs : extra arguments, pass to pre_pro functions
for example, if used unpad_mask() in pre_pro, there should be a kwarg named lens for example, if used unpad_mask() in pre_pro, there should be a kwarg named lens
'''
"""
for f in self.pre_pro: for f in self.pre_pro:
if f is None: if f is None:
continue continue


+ 7
- 0
fastNLP/core/metrics.py View File

@@ -308,6 +308,13 @@ def _prepare_metrics(metrics):
return _metrics return _metrics




"""
Attention: Codes below are not used in current FastNLP.
However, it is useful.

"""


def _conver_numpy(x): def _conver_numpy(x):
"""convert input data to numpy array """convert input data to numpy array




+ 12
- 0
fastNLP/core/optimizer.py View File

@@ -11,6 +11,12 @@ class Optimizer(object):


class SGD(Optimizer): class SGD(Optimizer):
def __init__(self, model_params=None, lr=0.01, momentum=0): def __init__(self, model_params=None, lr=0.01, momentum=0):
"""

:param model_params: a generator. E.g. model.parameters() for PyTorch models.
:param float lr: learning rate. Default: 0.01
:param float momentum: momentum. Default: 0
"""
super(SGD, self).__init__(model_params, lr=lr, momentum=momentum) super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)


def construct_from_pytorch(self, model_params): def construct_from_pytorch(self, model_params):
@@ -23,6 +29,12 @@ class SGD(Optimizer):


class Adam(Optimizer): class Adam(Optimizer):
def __init__(self, model_params=None, lr=0.01, weight_decay=0): def __init__(self, model_params=None, lr=0.01, weight_decay=0):
"""

:param model_params: a generator. E.g. model.parameters() for PyTorch models.
:param float lr: learning rate
:param float weight_decay:
"""
super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay) super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay)


def construct_from_pytorch(self, model_params): def construct_from_pytorch(self, model_params):


+ 0
- 9
fastNLP/core/trainer.py View File

@@ -140,7 +140,6 @@ class Trainer(object):
def train(self): def train(self):
"""Start Training. """Start Training.


:return:
""" """
try: try:
if torch.cuda.is_available() and self.use_cuda: if torch.cuda.is_available() and self.use_cuda:
@@ -216,14 +215,6 @@ class Trainer(object):
pbar.close() pbar.close()


def _print_train(self): def _print_train(self):
"""

:param data_iterator:
:param model:
:param epoch:
:param start:
:return:
"""
epoch = 1 epoch = 1
start = time.time() start = time.time()
while epoch <= self.n_epochs: while epoch <= self.n_epochs:


+ 0
- 16
fastNLP/io/base_loader.py View File

@@ -29,19 +29,3 @@ class BaseLoader(object):
with open(cache_path, 'wb') as f: with open(cache_path, 'wb') as f:
pickle.dump(obj, f) pickle.dump(obj, f)
return obj return obj


class ToyLoader0(BaseLoader):
"""
For CharLM
"""

def __init__(self, data_path):
super(ToyLoader0, self).__init__(data_path)

def load(self):
with open(self.data_path, 'r') as f:
corpus = f.read().lower()
import re
corpus = re.sub(r"<unk>", "unk", corpus)
return corpus.split()

fastNLP/io/config_saver.py → fastNLP/io/config_io.py View File

@@ -1,6 +1,152 @@
import configparser
import json
import os import os


from fastNLP.io.config_loader import ConfigSection, ConfigLoader
from fastNLP.io.base_loader import BaseLoader


class ConfigLoader(BaseLoader):
"""loader for configuration files"""

def __init__(self, data_path=None):
super(ConfigLoader, self).__init__()
if data_path is not None:
self.config = self.parse(super(ConfigLoader, self).load(data_path))

@staticmethod
def parse(string):
raise NotImplementedError

@staticmethod
def load_config(file_path, sections):
"""
:param file_path: the path of config file
:param sections: the dict of {section_name(string): Section instance}
Example:
test_args = ConfigSection()
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})
:return: return nothing, but the value of attributes are saved in sessions
"""
assert isinstance(sections, dict)
cfg = configparser.ConfigParser()
if not os.path.exists(file_path):
raise FileNotFoundError("config file {} not found. ".format(file_path))
cfg.read(file_path)
for s in sections:
attr_list = [i for i in sections[s].__dict__.keys() if
not callable(getattr(sections[s], i)) and not i.startswith("__")]
if s not in cfg:
print('section %s not found in config file' % (s))
continue
gen_sec = cfg[s]
for attr in gen_sec.keys():
try:
val = json.loads(gen_sec[attr])
# print(s, attr, val, type(val))
if attr in attr_list:
assert type(val) == type(getattr(sections[s], attr)), \
'type not match, except %s but got %s' % \
(type(getattr(sections[s], attr)), type(val))
"""
if attr in attr_list then check its type and
update its value.
else add a new attr in sections[s]
"""
setattr(sections[s], attr, val)
except Exception as e:
print("cannot load attribute %s in section %s"
% (attr, s))
pass


class ConfigSection(object):

def __init__(self):
pass

def __getitem__(self, key):
"""
:param key: str, the name of the attribute
:return attr: the value of this attribute
if key not in self.__dict__.keys():
return self[key]
else:
raise AttributeError
"""
if key in self.__dict__.keys():
return getattr(self, key)
raise AttributeError("do NOT have attribute %s" % key)

def __setitem__(self, key, value):
"""
:param key: str, the name of the attribute
:param value: the value of this attribute
if key not in self.__dict__.keys():
self[key] will be added
else:
self[key] will be updated
"""
if key in self.__dict__.keys():
if not isinstance(value, type(getattr(self, key))):
raise AttributeError("attr %s except %s but got %s" %
(key, str(type(getattr(self, key))), str(type(value))))
setattr(self, key, value)

def __contains__(self, item):
"""
:param item: The key of item.
:return: True if the key in self.__dict__.keys() else False.
"""
return item in self.__dict__.keys()

def __eq__(self, other):
"""Overwrite the == operator

:param other: Another ConfigSection() object which to be compared.
:return: True if value of each key in each ConfigSection() object are equal to the other, else False.
"""
for k in self.__dict__.keys():
if k not in other.__dict__.keys():
return False
if getattr(self, k) != getattr(self, k):
return False

for k in other.__dict__.keys():
if k not in self.__dict__.keys():
return False
if getattr(self, k) != getattr(self, k):
return False

return True

def __ne__(self, other):
"""Overwrite the != operator

:param other:
:return:
"""
return not self.__eq__(other)

@property
def data(self):
return self.__dict__


if __name__ == "__main__":
config = ConfigLoader('there is no data')

section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()}
"""
General and My can be found in config file, so the attr and
value will be updated
A cannot be found in config file, so nothing will be done
"""

config.load_config("../../test/data_for_tests/config", section)
for s in section:
print(s)
for attr in section[s].__dict__.keys():
print(s, attr, getattr(section[s], attr), type(getattr(section[s], attr)))




class ConfigSaver(object): class ConfigSaver(object):
@@ -125,7 +271,7 @@ class ConfigSaver(object):
# logger = create_logger(__name__, "./config_loader.log") # logger = create_logger(__name__, "./config_loader.log")
# logger.warning("section [%s] in config file [%s] has been changed" % ( # logger.warning("section [%s] in config file [%s] has been changed" % (
# section_name, self.file_path # section_name, self.file_path
#))
# ))
change_file = True change_file = True
break break
if not change_file: if not change_file:

+ 0
- 149
fastNLP/io/config_loader.py View File

@@ -1,149 +0,0 @@
import configparser
import json
import os

from fastNLP.io.base_loader import BaseLoader


class ConfigLoader(BaseLoader):
"""loader for configuration files"""

def __init__(self, data_path=None):
super(ConfigLoader, self).__init__()
if data_path is not None:
self.config = self.parse(super(ConfigLoader, self).load(data_path))

@staticmethod
def parse(string):
raise NotImplementedError

@staticmethod
def load_config(file_path, sections):
"""
:param file_path: the path of config file
:param sections: the dict of {section_name(string): Section instance}
Example:
test_args = ConfigSection()
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})
:return: return nothing, but the value of attributes are saved in sessions
"""
assert isinstance(sections, dict)
cfg = configparser.ConfigParser()
if not os.path.exists(file_path):
raise FileNotFoundError("config file {} not found. ".format(file_path))
cfg.read(file_path)
for s in sections:
attr_list = [i for i in sections[s].__dict__.keys() if
not callable(getattr(sections[s], i)) and not i.startswith("__")]
if s not in cfg:
print('section %s not found in config file' % (s))
continue
gen_sec = cfg[s]
for attr in gen_sec.keys():
try:
val = json.loads(gen_sec[attr])
# print(s, attr, val, type(val))
if attr in attr_list:
assert type(val) == type(getattr(sections[s], attr)), \
'type not match, except %s but got %s' % \
(type(getattr(sections[s], attr)), type(val))
"""
if attr in attr_list then check its type and
update its value.
else add a new attr in sections[s]
"""
setattr(sections[s], attr, val)
except Exception as e:
print("cannot load attribute %s in section %s"
% (attr, s))
pass


class ConfigSection(object):

def __init__(self):
pass

def __getitem__(self, key):
"""
:param key: str, the name of the attribute
:return attr: the value of this attribute
if key not in self.__dict__.keys():
return self[key]
else:
raise AttributeError
"""
if key in self.__dict__.keys():
return getattr(self, key)
raise AttributeError("do NOT have attribute %s" % key)

def __setitem__(self, key, value):
"""
:param key: str, the name of the attribute
:param value: the value of this attribute
if key not in self.__dict__.keys():
self[key] will be added
else:
self[key] will be updated
"""
if key in self.__dict__.keys():
if not isinstance(value, type(getattr(self, key))):
raise AttributeError("attr %s except %s but got %s" %
(key, str(type(getattr(self, key))), str(type(value))))
setattr(self, key, value)

def __contains__(self, item):
"""
:param item: The key of item.
:return: True if the key in self.__dict__.keys() else False.
"""
return item in self.__dict__.keys()

def __eq__(self, other):
"""Overwrite the == operator

:param other: Another ConfigSection() object which to be compared.
:return: True if value of each key in each ConfigSection() object are equal to the other, else False.
"""
for k in self.__dict__.keys():
if k not in other.__dict__.keys():
return False
if getattr(self, k) != getattr(self, k):
return False

for k in other.__dict__.keys():
if k not in self.__dict__.keys():
return False
if getattr(self, k) != getattr(self, k):
return False

return True

def __ne__(self, other):
"""Overwrite the != operator

:param other:
:return:
"""
return not self.__eq__(other)

@property
def data(self):
return self.__dict__


if __name__ == "__main__":
config = ConfigLoader('there is no data')

section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()}
"""
General and My can be found in config file, so the attr and
value will be updated
A cannot be found in config file, so nothing will be done
"""

config.load_config("../../test/data_for_tests/config", section)
for s in section:
print(s)
for attr in section[s].__dict__.keys():
print(s, attr, getattr(section[s], attr), type(getattr(section[s], attr)))

+ 21
- 105
fastNLP/io/dataset_loader.py View File

@@ -1,4 +1,3 @@
#TODO: need fix for current DataSet
import os import os


from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet
@@ -20,8 +19,7 @@ def convert_seq_dataset(data):
""" """
dataset = DataSet() dataset = DataSet()
for word_seq in data: for word_seq in data:
x = TextField(word_seq, is_target=False)
dataset.append(Instance(word_seq=x))
dataset.append(Instance(word_seq=word_seq))
return dataset return dataset




@@ -40,11 +38,7 @@ def convert_seq2tag_dataset(data):
""" """
dataset = DataSet() dataset = DataSet()
for sample in data: for sample in data:
word_seq, label = sample[0], sample[1]
ins = Instance()
ins.add_field("word_seq", TextField(word_seq, is_target=False)) \
.add_field("label", LabelField(label, is_target=True))
dataset.append(ins)
dataset.append(Instance(word_seq=sample[0], label=sample[1]))
return dataset return dataset




@@ -63,11 +57,7 @@ def convert_seq2seq_dataset(data):
""" """
dataset = DataSet() dataset = DataSet()
for sample in data: for sample in data:
word_seq, label_seq = sample[0], sample[1]
ins = Instance()
ins.add_field("word_seq", TextField(word_seq, is_target=False)) \
.add_field("label_seq", TextField(label_seq, is_target=True))
dataset.append(ins)
dataset.append(Instance(word_seq=sample[0], label_seq=sample[1]))
return dataset return dataset




@@ -273,85 +263,6 @@ class ClassDataSetLoader(DataSetLoader):
return convert_seq2tag_dataset(data) return convert_seq2tag_dataset(data)




@DataSet.set_reader('read_conll')
class ConllLoader(DataSetLoader):
"""loader for conll format files"""

def __init__(self):
"""
:param str data_path: the path to the conll data set
"""
super(ConllLoader, self).__init__()

def load(self, data_path):
"""
:return: list lines: all lines in a conll file
"""
with open(data_path, "r", encoding="utf-8") as f:
lines = f.readlines()
data = self.parse(lines)
return self.convert(data)

@staticmethod
def parse(lines):
"""
:param list lines:a list containing all lines in a conll file.
:return: a 3D list
"""
sentences = list()
tokens = list()
for line in lines:
if line[0] == "#":
# skip the comments
continue
if line == "\n":
sentences.append(tokens)
tokens = []
continue
tokens.append(line.split())
return sentences

def convert(self, data):
pass


@DataSet.set_reader('read_lm')
class LMDataSetLoader(DataSetLoader):
"""Language Model Dataset Loader

This loader produces data for language model training in a supervised way.
That means it has X and Y.

"""

def __init__(self):
super(LMDataSetLoader, self).__init__()

def load(self, data_path):
if not os.path.exists(data_path):
raise FileNotFoundError("file {} not found.".format(data_path))
with open(data_path, "r", encoding="utf=8") as f:
text = " ".join(f.readlines())
tokens = text.strip().split()
data = self.sentence_cut(tokens)
return self.convert(data)

def sentence_cut(self, tokens, sentence_length=15):
start_idx = 0
data_set = []
for idx in range(len(tokens) // sentence_length):
x = tokens[start_idx * idx: start_idx * idx + sentence_length]
y = tokens[start_idx * idx + 1: start_idx * idx + sentence_length + 1]
if start_idx * idx + sentence_length + 1 >= len(tokens):
# ad hoc
y.extend(["<unk>"])
data_set.append([x, y])
return data_set

def convert(self, data):
pass


@DataSet.set_reader('read_people_daily') @DataSet.set_reader('read_people_daily')
class PeopleDailyCorpusLoader(DataSetLoader): class PeopleDailyCorpusLoader(DataSetLoader):
""" """
@@ -403,10 +314,19 @@ class PeopleDailyCorpusLoader(DataSetLoader):
pos_tag_examples.append([sent_words, sent_pos_tag]) pos_tag_examples.append([sent_words, sent_pos_tag])
ner_examples.append([sent_words, sent_ner]) ner_examples.append([sent_words, sent_ner])
# List[List[List[str], List[str]]] # List[List[List[str], List[str]]]
return pos_tag_examples, ner_examples
# ner_examples not used
return self.convert(pos_tag_examples)


def convert(self, data): def convert(self, data):
pass
data_set = DataSet()
for item in data:
sent_words, sent_pos_tag = item[0], item[1]
data_set.append(Instance(words=sent_words, tags=sent_pos_tag))
data_set.apply(lambda ins: len(ins), new_field_name="seq_len")
data_set.set_target("tags")
data_set.set_input("sent_words")
data_set.set_input("seq_len")
return data_set




class SNLIDataSetLoader(DataSetLoader): class SNLIDataSetLoader(DataSetLoader):
@@ -462,17 +382,13 @@ class SNLIDataSetLoader(DataSetLoader):
for example in data: for example in data:
p, h, l = example p, h, l = example
# list, list, str # list, list, str
x1 = TextField(p, is_target=False)
x2 = TextField(h, is_target=False)
x1_len = TextField([1] * len(p), is_target=False)
x2_len = TextField([1] * len(h), is_target=False)
y = LabelField(l, is_target=True)
instance = Instance() instance = Instance()
instance.add_field("premise", x1)
instance.add_field("hypothesis", x2)
instance.add_field("premise_len", x1_len)
instance.add_field("hypothesis_len", x2_len)
instance.add_field("truth", y)
instance.add_field("premise", p)
instance.add_field("hypothesis", h)
instance.add_field("truth", l)
data_set.append(instance) data_set.append(instance)

data_set.apply(lambda ins: len(ins["premise"]), new_field_name="premise_len")
data_set.apply(lambda ins: len(ins["hypothesis"]), new_field_name="hypothesis_len")
data_set.set_input("premise", "hypothesis", "premise_len", "hypothesis_len")
data_set.set_target("truth")
return data_set return data_set

fastNLP/io/model_saver.py → fastNLP/io/model_io.py View File

@@ -1,5 +1,32 @@
import torch import torch


from fastNLP.io.base_loader import BaseLoader


class ModelLoader(BaseLoader):
"""
Loader for models.
"""

def __init__(self):
super(ModelLoader, self).__init__()

@staticmethod
def load_pytorch(empty_model, model_path):
"""
Load model parameters from .pkl files into the empty PyTorch model.
:param empty_model: a PyTorch model with initialized parameters.
:param model_path: str, the path to the saved model.
"""
empty_model.load_state_dict(torch.load(model_path))

@staticmethod
def load_pytorch_model(model_path):
"""Load the entire model.

"""
return torch.load(model_path)



class ModelSaver(object): class ModelSaver(object):
"""Save a model """Save a model
@@ -8,6 +35,7 @@ class ModelSaver(object):
saver.save_pytorch(model) saver.save_pytorch(model)


""" """

def __init__(self, save_path): def __init__(self, save_path):
""" """



+ 0
- 28
fastNLP/io/model_loader.py View File

@@ -1,28 +0,0 @@
import torch

from fastNLP.io.base_loader import BaseLoader


class ModelLoader(BaseLoader):
"""
Loader for models.
"""

def __init__(self):
super(ModelLoader, self).__init__()

@staticmethod
def load_pytorch(empty_model, model_path):
"""
Load model parameters from .pkl files into the empty PyTorch model.
:param empty_model: a PyTorch model with initialized parameters.
:param model_path: str, the path to the saved model.
"""
empty_model.load_state_dict(torch.load(model_path))

@staticmethod
def load_pytorch_model(model_path):
"""Load the entire model.

"""
return torch.load(model_path)

+ 1
- 1
reproduction/Biaffine_parser/infer.py View File

@@ -5,7 +5,7 @@ sys.path.extend(['/home/yfshao/workdir/dev_fastnlp'])


from fastNLP.api.processor import * from fastNLP.api.processor import *
from fastNLP.models.biaffine_parser import BiaffineParser from fastNLP.models.biaffine_parser import BiaffineParser
from fastNLP.io.config_loader import ConfigSection, ConfigLoader
from fastNLP.io.config_io import ConfigSection, ConfigLoader


import _pickle as pickle import _pickle as pickle
import torch import torch


+ 2
- 3
reproduction/Biaffine_parser/run.py View File

@@ -13,11 +13,10 @@ from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet
from fastNLP.core.field import TextField, SeqLabelField from fastNLP.core.field import TextField, SeqLabelField
from fastNLP.core.tester import Tester from fastNLP.core.tester import Tester
from fastNLP.io.config_loader import ConfigLoader, ConfigSection
from fastNLP.io.model_loader import ModelLoader
from fastNLP.io.config_io import ConfigLoader, ConfigSection
from fastNLP.io.model_io import ModelLoader, ModelSaver
from fastNLP.io.embed_loader import EmbedLoader from fastNLP.io.embed_loader import EmbedLoader
from fastNLP.models.biaffine_parser import BiaffineParser from fastNLP.models.biaffine_parser import BiaffineParser
from fastNLP.io.model_saver import ModelSaver


BOS = '<BOS>' BOS = '<BOS>'
EOS = '<EOS>' EOS = '<EOS>'


+ 2
- 2
reproduction/LSTM+self_attention_sentiment_analysis/main.py View File

@@ -2,8 +2,8 @@ import torch.nn.functional as F


from fastNLP.core.trainer import ClassificationTrainer from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.core.utils import ClassPreprocess as Preprocess from fastNLP.core.utils import ClassPreprocess as Preprocess
from fastNLP.io.config_loader import ConfigLoader
from fastNLP.io.config_loader import ConfigSection
from fastNLP.io.config_io import ConfigLoader
from fastNLP.io.config_io import ConfigSection
from fastNLP.io.dataset_loader import ClassDataSetLoader as Dataset_loader from fastNLP.io.dataset_loader import ClassDataSetLoader as Dataset_loader
from fastNLP.models.base_model import BaseModel from fastNLP.models.base_model import BaseModel
from fastNLP.modules.aggregator.self_attention import SelfAttention from fastNLP.modules.aggregator.self_attention import SelfAttention


+ 2
- 3
reproduction/chinese_word_segment/run.py View File

@@ -3,12 +3,11 @@ import sys


sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))


from fastNLP.io.config_loader import ConfigLoader, ConfigSection
from fastNLP.io.config_io import ConfigLoader, ConfigSection
from fastNLP.core.trainer import SeqLabelTrainer from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.io.dataset_loader import BaseLoader, TokenizeDataSetLoader from fastNLP.io.dataset_loader import BaseLoader, TokenizeDataSetLoader
from fastNLP.core.utils import load_pickle from fastNLP.core.utils import load_pickle
from fastNLP.io.model_saver import ModelSaver
from fastNLP.io.model_loader import ModelLoader
from fastNLP.io.model_io import ModelLoader, ModelSaver
from fastNLP.core.tester import SeqLabelTester from fastNLP.core.tester import SeqLabelTester
from fastNLP.models.sequence_modeling import AdvSeqLabel from fastNLP.models.sequence_modeling import AdvSeqLabel
from fastNLP.core.predictor import SeqLabelInfer from fastNLP.core.predictor import SeqLabelInfer


+ 2
- 2
setup.py View File

@@ -12,12 +12,12 @@ with open('requirements.txt', encoding='utf-8') as f:
reqs = f.read() reqs = f.read()


setup( setup(
name='fastNLP',
name='FastNLP',
version='0.1.1', version='0.1.1',
description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team',
long_description=readme, long_description=readme,
license=license, license=license,
author='fudanNLP',
author='FudanNLP',
python_requires='>=3.5', python_requires='>=3.5',
packages=find_packages(), packages=find_packages(),
install_requires=reqs.strip().split('\n'), install_requires=reqs.strip().split('\n'),


+ 12
- 0
test/api/test_processor.py View File

@@ -0,0 +1,12 @@
import unittest

from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor
from fastNLP.core.dataset import DataSet


class TestProcessor(unittest.TestCase):
def test_FullSpaceToHalfSpaceProcessor(self):
ds = DataSet({"word": ["00, u1, u), (u2, u2"]})
proc = FullSpaceToHalfSpaceProcessor("word")
ds = proc(ds)
self.assertTrue(ds.field_arrays["word"].content, ["00, u1, u), (u2, u2"])

+ 5
- 5
test/core/test_loss.py View File

@@ -45,7 +45,7 @@ class TestLoss(unittest.TestCase):
# 验证squash()的正确性 # 验证squash()的正确性


log = math.log log = math.log
loss_func = loss.Loss("nll")
loss_func = loss.LossFromTorch("nll")


y = tc.Tensor( y = tc.Tensor(
[ [
@@ -129,7 +129,7 @@ class TestLoss(unittest.TestCase):
lens = [4, 2, 1] lens = [4, 2, 1]
y = tc.log(y) y = tc.log(y)


loss_func = loss.Loss("nll", pre_pro=["unpad"])
loss_func = loss.LossFromTorch("nll", pre_pro=["unpad"])
los = loss_func(y, gy, lens=lens) los = loss_func(y, gy, lens=lens)


r = -log(.1) - log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1) r = -log(.1) - log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1)
@@ -169,7 +169,7 @@ class TestLoss(unittest.TestCase):


lens = [2, 4, 2] lens = [2, 4, 2]


loss_func = loss.Loss("nll", pre_pro=["mask"])
loss_func = loss.LossFromTorch("nll", pre_pro=["mask"])
los = loss_func(y, gy, mask=mask) los = loss_func(y, gy, mask=mask)


los2 = loss_func(y, gy, mask=loss.make_mask(lens, gy.size()[-1])) los2 = loss_func(y, gy, mask=loss.make_mask(lens, gy.size()[-1]))
@@ -205,7 +205,7 @@ class TestLoss(unittest.TestCase):


y = tc.log(y) y = tc.log(y)


loss_func = loss.Loss("nll", pre_pro=["unpad_mask"])
loss_func = loss.LossFromTorch("nll", pre_pro=["unpad_mask"])
los = loss_func(y, gy, lens=lens) los = loss_func(y, gy, lens=lens)


r = -log(.1) - log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1) r = -log(.1) - log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1)
@@ -235,7 +235,7 @@ class TestLoss(unittest.TestCase):
lens = [4, 2, 1] lens = [4, 2, 1]
y = tc.log(y) y = tc.log(y)


loss_func = loss.Loss("nll", pre_pro=[], weight=tc.Tensor([1, 1, 0]))
loss_func = loss.LossFromTorch("nll", pre_pro=[], weight=tc.Tensor([1, 1, 0]))
loss_func.add_pre_pro("unpad_mask") loss_func.add_pre_pro("unpad_mask")
los = loss_func(y, gy, lens=lens) los = loss_func(y, gy, lens=lens)




+ 1
- 2
test/io/test_config_saver.py View File

@@ -1,8 +1,7 @@
import os import os
import unittest import unittest


from fastNLP.io.config_loader import ConfigSection, ConfigLoader
from fastNLP.io.config_saver import ConfigSaver
from fastNLP.io.config_io import ConfigSection, ConfigLoader, ConfigSaver




class TestConfigSaver(unittest.TestCase): class TestConfigSaver(unittest.TestCase):


Loading…
Cancel
Save