|
|
@@ -62,13 +62,13 @@ class SGD(Optimizer): |
|
|
|
# since `conver_inputs` is disabled for param updates, |
|
|
|
# scalar should be explicitly tansforred to tensor |
|
|
|
|
|
|
|
_lr = tensor(lr) |
|
|
|
_weight_decay = tensor(weight_decay) |
|
|
|
_momentum = tensor(momentum) |
|
|
|
_lr = tensor(lr, dtype="float32") |
|
|
|
_weight_decay = tensor(weight_decay, dtype="float32") |
|
|
|
_momentum = tensor(momentum, dtype="float32") |
|
|
|
|
|
|
|
inplace_mode = int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0")) |
|
|
|
if inplace_mode: |
|
|
|
_neg_lr = tensor(-lr) |
|
|
|
_neg_lr = tensor(-lr, dtype="float32") |
|
|
|
c1 = tensor([1.0]) |
|
|
|
|
|
|
|
for param in param_group["params"]: |
|
|
|