From e9d6361e33b978f80b90bb762a0c03a4b8d7d76f Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 23 Apr 2021 16:24:56 +0800
Subject: [PATCH] fix(mge/optimizer): fix optimizer update step inplace add grad

GitOrigin-RevId: d677d1ca6b37bf94b89305a6102d2f1a11d6c872
---
 imperative/python/megengine/optimizer/adadelta.py | 2 +-
 imperative/python/megengine/optimizer/adagrad.py  | 2 +-
 imperative/python/megengine/optimizer/adam.py     | 2 +-
 imperative/python/megengine/optimizer/sgd.py      | 2 +-
 4 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/imperative/python/megengine/optimizer/adadelta.py b/imperative/python/megengine/optimizer/adadelta.py
index 60b18593..1c321d21 100644
--- a/imperative/python/megengine/optimizer/adadelta.py
+++ b/imperative/python/megengine/optimizer/adadelta.py
@@ -84,7 +84,7 @@ class Adadelta(Optimizer):
             step += c1
             grad = param.grad
             if weight_decay != 0.0:
-                grad += param * _weight_decay
+                grad = grad + param * _weight_decay
 
             square_avg = states["square_avg"]
             acc_delta = states["acc_delta"]
diff --git a/imperative/python/megengine/optimizer/adagrad.py b/imperative/python/megengine/optimizer/adagrad.py
index 9b309077..c983c791 100644
--- a/imperative/python/megengine/optimizer/adagrad.py
+++ b/imperative/python/megengine/optimizer/adagrad.py
@@ -82,7 +82,7 @@ class Adagrad(Optimizer):
             step += c1
             grad = param.grad
             if weight_decay != 0.0:
-                grad += param * _weight_decay
+                grad = grad + param * _weight_decay
 
             square_avg = states["square_avg"]
             square_avg += grad ** c2
diff --git a/imperative/python/megengine/optimizer/adam.py b/imperative/python/megengine/optimizer/adam.py
index 4bd7bea6..40d5eec5 100644
--- a/imperative/python/megengine/optimizer/adam.py
+++ b/imperative/python/megengine/optimizer/adam.py
@@ -85,7 +85,7 @@ class Adam(Optimizer):
 
             grad = param.grad
             if weight_decay != 0.0:
-                grad += param * _weight_decay
+                grad = grad + param * _weight_decay
 
             states = self._state[param]
diff --git a/imperative/python/megengine/optimizer/sgd.py b/imperative/python/megengine/optimizer/sgd.py
index 95e5867c..5ed256d2 100644
--- a/imperative/python/megengine/optimizer/sgd.py
+++ b/imperative/python/megengine/optimizer/sgd.py
@@ -72,7 +72,7 @@ class SGD(Optimizer):
 
             grad = param.grad
             if weight_decay != 0.0:
-                grad += param * _weight_decay
+                grad = grad + param * _weight_decay
 
             if inplace_mode:
                 if momentum:
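
Note on the fix: `grad = param.grad` binds `grad` to the very tensor held in
`param.grad`, so the in-place `grad += param * _weight_decay` wrote the
weight-decay term back into the stored gradient, corrupting it for any later
reader and compounding across steps if the gradient buffer is reused. The
out-of-place `grad = grad + param * _weight_decay` allocates a fresh tensor
and leaves `param.grad` untouched. A minimal sketch of the aliasing behavior,
using NumPy arrays as stand-ins for megengine Tensors (an illustrative
assumption, not code from this patch):

    import numpy as np

    param = np.array([1.0, 2.0])
    stored_grad = np.array([0.1, 0.1])  # stand-in for param.grad
    weight_decay = 0.01

    # Buggy pattern: `grad` aliases the stored gradient, so += mutates it.
    grad = stored_grad
    grad += param * weight_decay        # in-place add writes through the alias
    print(stored_grad)                  # [0.11 0.12] -- gradient corrupted

    # Fixed pattern: out-of-place add binds `grad` to a fresh array.
    stored_grad = np.array([0.1, 0.1])
    grad = stored_grad
    grad = grad + param * weight_decay  # new array; stored_grad unchanged
    print(stored_grad)                  # [0.1 0.1] -- gradient preserved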