@@ -10,6 +10,7 @@ from .adadelta import Adadelta | |||||
from .adagrad import Adagrad | from .adagrad import Adagrad | ||||
from .adam import Adam | from .adam import Adam | ||||
from .adamw import AdamW | from .adamw import AdamW | ||||
from .clip_grad import * | |||||
from .lr_scheduler import LRScheduler | from .lr_scheduler import LRScheduler | ||||
from .multi_step_lr import MultiStepLR | from .multi_step_lr import MultiStepLR | ||||
from .optimizer import Optimizer | from .optimizer import Optimizer | ||||
@@ -0,0 +1,72 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# pylint: disable=redefined-builtin | |||||
from typing import Iterable, Union | |||||
from ..core._imperative_rt.core2 import pop_scope, push_scope | |||||
from ..functional import clip, concat, minimum, norm | |||||
from ..tensor import Tensor | |||||
__all__ = ["clip_grad_norm", "clip_grad_value"] | |||||
def clip_grad_norm( | |||||
tensors: Union[Tensor, Iterable[Tensor]], max_norm: float, ord: float = 2.0, | |||||
): | |||||
r"""Clips gradient norm of an iterable of parameters. | |||||
The norm is computed over all gradients together, as if they were | |||||
concatenated into a single vector. Gradients are modified in-place. | |||||
:param tensors: an iterable of Tensors or a single Tensor. | |||||
:param max_norm: max norm of the gradients. | |||||
:param ord: type of the used p-norm. Can be ``'inf'`` for infinity norm. | |||||
:return: total norm of the parameters (viewed as a single vector). | |||||
""" | |||||
push_scope("clip_grad_norm") | |||||
if isinstance(tensors, Tensor): | |||||
tensors = [tensors] | |||||
tensors = [t for t in tensors if t.grad is not None] | |||||
if len(tensors) == 0: | |||||
pop_scope("clip_grad_norm") | |||||
return Tensor(0.0) | |||||
norm_ = [norm(t.grad.flatten(), ord=ord) for t in tensors] | |||||
if len(norm_) > 1: | |||||
norm_ = norm(concat(norm_), ord=ord) | |||||
else: | |||||
norm_ = norm_[0] | |||||
scale = max_norm / (norm_ + 1e-6) | |||||
scale = minimum(scale, 1) | |||||
for tensor in tensors: | |||||
tensor.grad._reset(tensor.grad * scale) | |||||
pop_scope("clip_grad_norm") | |||||
return norm_ | |||||
def clip_grad_value( | |||||
tensors: Union[Tensor, Iterable[Tensor]], lower: float, upper: float | |||||
): | |||||
r"""Clips gradient of an iterable of parameters to a specified lower and | |||||
upper. Gradients are modified in-place. | |||||
The gradients are clipped in the range: | |||||
.. math:: \left[\text{lower}, \text{upper}\right] | |||||
:param tensors: an iterable of Tensors or a single Tensor. | |||||
:param lower: minimum allowed value of the gradients. | |||||
:param upper: maximum allowed value of the gradients. | |||||
""" | |||||
push_scope("clip_grad_value") | |||||
if isinstance(tensors, Tensor): | |||||
tensors = [tensors] | |||||
for tensor in tensors: | |||||
if tensor.grad is None: | |||||
continue | |||||
tensor.grad._reset(clip(tensor.grad, lower, upper)) | |||||
pop_scope("clip_grad_value") |
@@ -0,0 +1,120 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
import itertools | |||||
import numpy as np | |||||
import pytest | |||||
import megengine as mge | |||||
import megengine.autodiff as ad | |||||
import megengine.functional as F | |||||
import megengine.optimizer as optim | |||||
from megengine import Tensor | |||||
from megengine.jit import trace | |||||
from megengine.module import Linear, Module | |||||
from megengine.optimizer import SGD | |||||
batch_size = 64 | |||||
data_shape = (batch_size, 2) | |||||
label_shape = (batch_size,) | |||||
def minibatch_generator(): | |||||
while True: | |||||
inp_data = np.zeros((batch_size, 2)) | |||||
label = np.zeros(batch_size, dtype=np.int32) | |||||
for i in range(batch_size): | |||||
# [x0, x1], sampled from U[-1, 1] | |||||
inp_data[i, :] = np.random.rand(2) * 2 - 1 | |||||
label[i] = 0 if np.prod(inp_data[i]) < 0 else 1 | |||||
yield inp_data.astype(np.float32), label.astype(np.int32) | |||||
def calculate_precision(data: np.ndarray, pred: np.ndarray) -> float: | |||||
""" Calculate precision for given data and prediction. | |||||
:type data: [[x, y], ...] | |||||
:param data: Input data | |||||
:type pred: [[x_pred, y_pred], ...] | |||||
:param pred: Network output data | |||||
""" | |||||
correct = 0 | |||||
assert len(data) == len(pred) | |||||
for inp_data, pred_output in zip(data, pred): | |||||
label = 0 if np.prod(inp_data) < 0 else 1 | |||||
pred_label = np.argmax(pred_output) | |||||
if pred_label == label: | |||||
correct += 1 | |||||
return float(correct) / len(data) | |||||
class XORNet(Module): | |||||
def __init__(self): | |||||
self.mid_layers = 14 | |||||
self.num_class = 2 | |||||
super().__init__() | |||||
self.fc0 = Linear(self.num_class, self.mid_layers, bias=True) | |||||
self.fc1 = Linear(self.mid_layers, self.mid_layers, bias=True) | |||||
self.fc2 = Linear(self.mid_layers, self.num_class, bias=True) | |||||
def forward(self, x): | |||||
x = self.fc0(x) | |||||
x = F.tanh(x) | |||||
x = self.fc1(x) | |||||
x = F.tanh(x) | |||||
x = self.fc2(x) | |||||
return x | |||||
def test_training_converge(): | |||||
net = XORNet() | |||||
opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4) | |||||
gm = ad.GradManager().attach(net.parameters()) | |||||
@trace(symbolic=False) | |||||
def train(data, label): | |||||
with gm: | |||||
pred = net(data) | |||||
loss = F.nn.cross_entropy(pred, label) | |||||
gm.backward(loss) | |||||
optim.clip_grad_norm(net.parameters(), max_norm=0.2, ord=2.0) | |||||
return loss | |||||
def infer(data): | |||||
return net(data) | |||||
train_dataset = minibatch_generator() | |||||
losses = [] | |||||
for data, label in itertools.islice(train_dataset, 2000): | |||||
data = Tensor(data, dtype=np.float32) | |||||
label = Tensor(label, dtype=np.int32) | |||||
opt.clear_grad() | |||||
loss = train(data, label) | |||||
optim.clip_grad_value(net.parameters(), lower=-0.1, upper=0.1) | |||||
opt.step() | |||||
losses.append(loss.numpy()) | |||||
print(np.mean(losses[-100:])) | |||||
assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough" | |||||
ngrid = 10 | |||||
x = np.linspace(-1.0, 1.0, ngrid) | |||||
xx, yy = np.meshgrid(x, x) | |||||
xx = xx.reshape((ngrid * ngrid, 1)) | |||||
yy = yy.reshape((ngrid * ngrid, 1)) | |||||
data = np.concatenate((xx, yy), axis=1).astype(np.float32) | |||||
pred = infer(data).numpy() | |||||
precision = calculate_precision(data, pred) | |||||
print("precision=", precision) | |||||
assert precision == 1.0, "Test precision must be high enough, get {}".format( | |||||
precision | |||||
) |
@@ -0,0 +1,80 @@ | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
import platform | |||||
import weakref | |||||
import numpy as np | |||||
import pytest | |||||
import megengine as mge | |||||
import megengine.autodiff as ad | |||||
import megengine.functional as F | |||||
import megengine.module as M | |||||
import megengine.optimizer as optim | |||||
class Net(M.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.conv1 = M.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) | |||||
self.bn1 = M.BatchNorm2d(64) | |||||
self.avgpool = M.AvgPool2d(kernel_size=5, stride=5, padding=0) | |||||
self.fc = M.Linear(64, 10) | |||||
def forward(self, x): | |||||
x = self.conv1(x) | |||||
x = self.bn1(x) | |||||
x = F.relu(x) | |||||
x = self.avgpool(x) | |||||
x = F.avg_pool2d(x, 22) | |||||
x = F.flatten(x, 1) | |||||
x = self.fc(x) | |||||
return x | |||||
def save_grad_value(net): | |||||
for param in net.parameters(): | |||||
param.grad_backup = param.grad.numpy().copy() | |||||
def test_clip_grad_norm(): | |||||
net = Net() | |||||
x = mge.tensor(np.random.randn(10, 3, 224, 224)) | |||||
gm = ad.GradManager().attach(net.parameters()) | |||||
opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9) | |||||
with gm: | |||||
loss = net(x).sum() | |||||
gm.backward(loss) | |||||
save_grad_value(net) | |||||
max_norm = 1.0 | |||||
original_norm = optim.clip_grad_norm(net.parameters(), max_norm=max_norm, ord=2) | |||||
scale = max_norm / original_norm | |||||
for param in net.parameters(): | |||||
np.testing.assert_almost_equal(param.grad.numpy(), param.grad_backup * scale) | |||||
opt.step().clear_grad() | |||||
def test_clip_grad_value(): | |||||
net = Net() | |||||
x = np.random.randn(10, 3, 224, 224).astype("float32") | |||||
gm = ad.GradManager().attach(net.parameters()) | |||||
opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9) | |||||
with gm: | |||||
y = net(x) | |||||
y = y.mean() | |||||
gm.backward(y) | |||||
save_grad_value(net) | |||||
max_val = 5 | |||||
min_val = -2 | |||||
optim.clip_grad_value(net.parameters(), lower=min_val, upper=max_val) | |||||
for param in net.parameters(): | |||||
np.testing.assert_almost_equal( | |||||
param.grad.numpy(), | |||||
np.maximum(np.minimum(param.grad_backup, max_val), min_val), | |||||
) | |||||
opt.step().clear_grad() |