@@ -10,6 +10,7 @@ from .adadelta import Adadelta
from .adagrad import Adagrad
from .adam import Adam
from .adamw import AdamW
from .clip_grad import *
from .lr_scheduler import LRScheduler
from .multi_step_lr import MultiStepLR
from .optimizer import Optimizer
@@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=redefined-builtin
from typing import Iterable, Union

from ..core._imperative_rt.core2 import pop_scope, push_scope
from ..functional import clip, concat, minimum, norm
from ..tensor import Tensor

__all__ = ["clip_grad_norm", "clip_grad_value"]


def clip_grad_norm(
    tensors: Union[Tensor, Iterable[Tensor]], max_norm: float, ord: float = 2.0,
):
    r"""Clips gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place
    (see the usage sketch after this function).

    :param tensors: an iterable of Tensors or a single Tensor.
    :param max_norm: max norm of the gradients.
    :param ord: type of the used p-norm. Can be ``float("inf")`` for the infinity norm.
    :return: total norm of the parameters (viewed as a single vector).
    """
    push_scope("clip_grad_norm")
    if isinstance(tensors, Tensor):
        tensors = [tensors]
    tensors = [t for t in tensors if t.grad is not None]
    if len(tensors) == 0:
        pop_scope("clip_grad_norm")
        return Tensor(0.0)
    norm_ = [norm(t.grad.flatten(), ord=ord) for t in tensors]
    if len(norm_) > 1:
        norm_ = norm(concat(norm_), ord=ord)
    else:
        norm_ = norm_[0]
    scale = max_norm / (norm_ + 1e-6)
    scale = minimum(scale, 1)
    for tensor in tensors:
        tensor.grad._reset(tensor.grad * scale)
    pop_scope("clip_grad_norm")
    return norm_
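
# Usage sketch for ``clip_grad_norm`` (illustrative only, kept as comments so it is
# never executed at import time; ``param`` is an assumed name for a parameter whose
# ``grad`` was already populated by ``GradManager.backward``):
#
#   total_norm = clip_grad_norm(param, max_norm=1.0, ord=2.0)
#   # ``param.grad`` is rescaled in-place by ``min(1, max_norm / (total_norm + 1e-6))``;
#   # the returned value is the total gradient norm before clipping.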


def clip_grad_value(
    tensors: Union[Tensor, Iterable[Tensor]], lower: float, upper: float
):
    r"""Clips gradient of an iterable of parameters to the range specified by
    ``lower`` and ``upper``. Gradients are modified in-place
    (see the usage sketch after this function).

    The gradients are clipped in the range:

    .. math:: \left[\text{lower}, \text{upper}\right]

    :param tensors: an iterable of Tensors or a single Tensor.
    :param lower: minimum allowed value of the gradients.
    :param upper: maximum allowed value of the gradients.
    """
    push_scope("clip_grad_value")
    if isinstance(tensors, Tensor):
        tensors = [tensors]
    for tensor in tensors:
        if tensor.grad is None:
            continue
        tensor.grad._reset(clip(tensor.grad, lower, upper))
    pop_scope("clip_grad_value")
@@ -0,0 +1,120 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import itertools

import numpy as np
import pytest

import megengine as mge
import megengine.autodiff as ad
import megengine.functional as F
import megengine.optimizer as optim
from megengine import Tensor
from megengine.jit import trace
from megengine.module import Linear, Module
from megengine.optimizer import SGD

batch_size = 64
data_shape = (batch_size, 2)
label_shape = (batch_size,)


def minibatch_generator():
    while True:
        inp_data = np.zeros((batch_size, 2))
        label = np.zeros(batch_size, dtype=np.int32)
        for i in range(batch_size):
            # [x0, x1], sampled from U[-1, 1]
            inp_data[i, :] = np.random.rand(2) * 2 - 1
            label[i] = 0 if np.prod(inp_data[i]) < 0 else 1
        yield inp_data.astype(np.float32), label.astype(np.int32)


def calculate_precision(data: np.ndarray, pred: np.ndarray) -> float:
    """Calculate precision for given data and prediction
    (a worked example follows this function).

    :type data: [[x, y], ...]
    :param data: Input data
    :type pred: [[x_pred, y_pred], ...]
    :param pred: Network output data
    """
    correct = 0
    assert len(data) == len(pred)
    for inp_data, pred_output in zip(data, pred):
        label = 0 if np.prod(inp_data) < 0 else 1
        pred_label = np.argmax(pred_output)
        if pred_label == label:
            correct += 1
    return float(correct) / len(data)
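
# Worked example for ``calculate_precision`` (hypothetical values): for
# data = [[0.5, -0.5], [0.3, 0.2]] the true XOR-sign labels are [0, 1]
# (product negative -> 0, otherwise 1); with pred = [[0.9, 0.1], [0.2, 0.8]]
# the argmax labels are [0, 1], so the returned precision is 1.0.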


class XORNet(Module):
    def __init__(self):
        self.mid_layers = 14
        self.num_class = 2
        super().__init__()

        self.fc0 = Linear(self.num_class, self.mid_layers, bias=True)
        self.fc1 = Linear(self.mid_layers, self.mid_layers, bias=True)
        self.fc2 = Linear(self.mid_layers, self.num_class, bias=True)

    def forward(self, x):
        x = self.fc0(x)
        x = F.tanh(x)
        x = self.fc1(x)
        x = F.tanh(x)
        x = self.fc2(x)
        return x


def test_training_converge():
    net = XORNet()
    opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    gm = ad.GradManager().attach(net.parameters())

    @trace(symbolic=False)
    def train(data, label):
        with gm:
            pred = net(data)
            loss = F.nn.cross_entropy(pred, label)
            gm.backward(loss)
        # clip the global gradient norm inside the traced training step
        optim.clip_grad_norm(net.parameters(), max_norm=0.2, ord=2.0)
        return loss

    def infer(data):
        return net(data)

    train_dataset = minibatch_generator()
    losses = []

    for data, label in itertools.islice(train_dataset, 2000):
        data = Tensor(data, dtype=np.float32)
        label = Tensor(label, dtype=np.int32)
        opt.clear_grad()
        loss = train(data, label)
        # additionally clamp each gradient element before the optimizer step
        optim.clip_grad_value(net.parameters(), lower=-0.1, upper=0.1)
        opt.step()
        losses.append(loss.numpy())

    print(np.mean(losses[-100:]))
    assert np.mean(losses[-100:]) < 0.1, "final training loss must be low enough"

    ngrid = 10
    x = np.linspace(-1.0, 1.0, ngrid)
    xx, yy = np.meshgrid(x, x)
    xx = xx.reshape((ngrid * ngrid, 1))
    yy = yy.reshape((ngrid * ngrid, 1))
    data = np.concatenate((xx, yy), axis=1).astype(np.float32)
    pred = infer(data).numpy()
    precision = calculate_precision(data, pred)
    print("precision=", precision)
    assert precision == 1.0, "test precision must be high enough, got {}".format(
        precision
    )
@@ -0,0 +1,80 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import platform
import weakref

import numpy as np
import pytest

import megengine as mge
import megengine.autodiff as ad
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim


class Net(M.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = M.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = M.BatchNorm2d(64)
        self.avgpool = M.AvgPool2d(kernel_size=5, stride=5, padding=0)
        self.fc = M.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.avgpool(x)
        x = F.avg_pool2d(x, 22)
        x = F.flatten(x, 1)
        x = self.fc(x)
        return x


def save_grad_value(net):
    for param in net.parameters():
        param.grad_backup = param.grad.numpy().copy()


def test_clip_grad_norm():
    net = Net()
    x = mge.tensor(np.random.randn(10, 3, 224, 224))
    gm = ad.GradManager().attach(net.parameters())
    opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)
    with gm:
        loss = net(x).sum()
        gm.backward(loss)
    save_grad_value(net)
    max_norm = 1.0
    original_norm = optim.clip_grad_norm(net.parameters(), max_norm=max_norm, ord=2)
    # after clipping, every gradient should equal the saved gradient scaled by
    # max_norm / original_norm
    scale = max_norm / original_norm
    for param in net.parameters():
        np.testing.assert_almost_equal(param.grad.numpy(), param.grad_backup * scale)
    opt.step().clear_grad()


def test_clip_grad_value():
    net = Net()
    x = np.random.randn(10, 3, 224, 224).astype("float32")
    gm = ad.GradManager().attach(net.parameters())
    opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)
    with gm:
        y = net(x)
        y = y.mean()
        gm.backward(y)
    save_grad_value(net)
    max_val = 5
    min_val = -2
    optim.clip_grad_value(net.parameters(), lower=min_val, upper=max_val)
    # every gradient element should now be clamped into [min_val, max_val]
    for param in net.parameters():
        np.testing.assert_almost_equal(
            param.grad.numpy(),
            np.maximum(np.minimum(param.grad_backup, max_val), min_val),
        )
    opt.step().clear_grad()
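
# Note: the expected value above is equivalent to ``np.clip(param.grad_backup,
# min_val, max_val)``; the nested ``np.maximum``/``np.minimum`` form mirrors the
# element-wise clamp performed by ``optim.clip_grad_value``.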