From 215a67e7af2d3d7a036f8edf23e192056645a0db Mon Sep 17 00:00:00 2001
From: choosewhatulike <1901722105@qq.com>
Date: Thu, 26 Jul 2018 22:45:25 +0800
Subject: [PATCH] fix optimizer: rename optimizor.py and re-export torch.optim

---
 fastNLP/action/optimizer.py |  5 +++++
 fastNLP/action/optimizor.py | 50 ---------------------------------------------
 2 files changed, 5 insertions(+), 50 deletions(-)
 create mode 100644 fastNLP/action/optimizer.py
 delete mode 100644 fastNLP/action/optimizor.py

diff --git a/fastNLP/action/optimizer.py b/fastNLP/action/optimizer.py
new file mode 100644
index 00000000..b493e3f0
--- /dev/null
+++ b/fastNLP/action/optimizer.py
@@ -0,0 +1,5 @@
+'''
+Use optimizers from PyTorch.
+'''
+
+from torch.optim import *
\ No newline at end of file
diff --git a/fastNLP/action/optimizor.py b/fastNLP/action/optimizor.py
deleted file mode 100644
index becdc499..00000000
--- a/fastNLP/action/optimizor.py
+++ /dev/null
@@ -1,50 +0,0 @@
-from torch import optim
-
-
-def get_torch_optimizer(params, alg_name='sgd', **args):
-    """
-    construct PyTorch optimizer by algorithm's name
-    optimizer's arguments can be specified, for different optimizer's arguments, please see PyTorch doc
-
-    usage:
-    optimizer = get_torch_optimizer(model.parameters(), 'SGD', lr=0.01)
-
-    """
-
-    name = alg_name.lower()
-    if name == 'adadelta':
-        return optim.Adadelta(params, **args)
-    elif name == 'adagrad':
-        return optim.Adagrad(params, **args)
-    elif name == 'adam':
-        return optim.Adam(params, **args)
-    elif name == 'adamax':
-        return optim.Adamax(params, **args)
-    elif name == 'asgd':
-        return optim.ASGD(params, **args)
-    elif name == 'lbfgs':
-        return optim.LBFGS(params, **args)
-    elif name == 'rmsprop':
-        return optim.RMSprop(params, **args)
-    elif name == 'rprop':
-        return optim.Rprop(params, **args)
-    elif name == 'sgd':
-        # SGD's parameter lr is required
-        if 'lr' not in args:
-            args['lr'] = 0.01
-        return optim.SGD(params, **args)
-    elif name == 'sparseadam':
-        return optim.SparseAdam(params, **args)
-    else:
-        raise TypeError('no such optimizer named {}'.format(alg_name))
-
-
-if __name__ == '__main__':
-    from torch.nn.modules import Linear
-
-    net = Linear(2, 5)
-
-    test1 = get_torch_optimizer(net.parameters(), 'adam', lr=1e-2, weight_decay=1e-3)
-    print(test1)
-    test2 = get_torch_optimizer(net.parameters(), 'SGD')
-    print(test2)
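
A minimal usage sketch, assuming fastNLP is importable and its package
__init__ resolves: with this change, optimizer classes come straight from
the wildcard re-export in fastNLP/action/optimizer.py rather than from the
removed get_torch_optimizer factory. The Linear model mirrors the deleted
self-test and is illustrative only.

    # The wildcard import re-exports torch.optim's optimizer classes
    # (Adam, SGD, ...), so callers construct them directly.
    from torch.nn import Linear
    from fastNLP.action.optimizer import Adam, SGD

    net = Linear(2, 5)

    # Replaces get_torch_optimizer(net.parameters(), 'adam',
    #                              lr=1e-2, weight_decay=1e-3)
    adam = Adam(net.parameters(), lr=1e-2, weight_decay=1e-3)

    # lr must now be passed explicitly; the deleted factory defaulted
    # SGD's lr to 0.01 when it was omitted.
    sgd = SGD(net.parameters(), lr=0.01)

One trade-off worth noting: the one-line "from torch.optim import *" pins
fastNLP's optimizer surface to whatever torch.optim exports, while the
deleted if/elif factory whitelisted ten algorithms by name and supplied the
SGD learning-rate default.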