GitOrigin-RevId: 3d06e3db3c
tags/v1.0.0-rc1
@@ -29,6 +29,7 @@ class GradManager:
    def register_after_backward_callback(self, callback):
        self._after_backward_callback.append(callback)
        return self

    def backward(self, ys, dys=None):
        global backwarding_grad_manager
@@ -177,6 +177,13 @@ class Grad:
        dys = aslist(dys)
        assert len(ys) == len(dys)
        ids = [i for i, y in enumerate(ys) if self in y._extra_data.keys()]
        if len(ids) == 0:
            return
        ys = [y for i, y in enumerate(ys) if i in ids]
        dys = [dy for i, dy in enumerate(dys) if i in ids]
        # ys is changed to a list of VariableNode which contains more information
        # such as OpNode, callback, etc.
        ys = [i._extra_data[self].node for i in ys]
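For orientation, here is a minimal end-to-end sketch of the GradManager flow these hunks migrate the codebase to. The toy Simple module and the constants are illustrative (modelled on the Simple modules in the tests below), not part of this diff; the calls themselves (GradManager().register(...), gm.record(), gm.backward(...), Optimizer.clear_grad()) are the ones exercised by the hunks that follow:

    import numpy as np

    import megengine.autodiff as ad
    import megengine.optimizer as optimizer
    from megengine import Parameter, tensor
    from megengine.module import Module

    class Simple(Module):
        def __init__(self):
            super().__init__()
            self.a = Parameter(1.23, dtype=np.float32)

        def forward(self, x):
            return x * self.a

    net = Simple()
    gm = ad.GradManager().register(net.parameters())  # record grads for these params
    optim = optimizer.SGD(net.parameters(), lr=1.0)

    optim.clear_grad()         # replaces the old optim.zero_grad()
    with gm.record():          # capture the forward computation
        loss = net(tensor([2.34])).sum()
        gm.backward(loss)      # populates param.grad for registered params
    optim.step()               # consumes param.grad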
@@ -20,8 +20,8 @@ from ..core.autodiff.grad import (
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.core import apply
from ..core.tensor.tensor import Tensor, tensor_apply
from ..tensor import tensor
from ..device import get_default_device
from ..tensor import tensor
from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank
__all__ = [
@@ -11,7 +11,7 @@ from typing import Iterable, Union
import numpy as np
from ..functional import sqrt
from ..tensor_nn import Buffer, Parameter
from ..tensor_nn import Parameter
from .optimizer import Optimizer
@@ -63,16 +63,7 @@ class Adadelta(Optimizer):
        for param in param_group["params"]:
            if param.__wrapped__ in self._grad_skip:
                self._grad_skip.remove(param.__wrapped__)
                continue
            if not isinstance(param.grad, Buffer):
                raise TypeError(
                    "grad must be a Buffer, maybe you forget to call backward()?"
                )
            if not param.requires_grad:
            if not param.requires_grad or "grad" not in param.__dict__:
                continue
            states = self._state[param]
@@ -91,5 +82,3 @@ class Adadelta(Optimizer):
            acc_delta = rho * acc_delta + (1 - rho) * delta ** 2
            states["square_avg"]._reset(square_avg)
            states["acc_delta"]._reset(acc_delta)
        assert len(self._grad_skip) == 0
@@ -11,7 +11,7 @@ from typing import Iterable, Union
import numpy as np
from ..functional import sqrt
from ..tensor_nn import Buffer, Parameter
from ..tensor_nn import Parameter
from .optimizer import Optimizer
@@ -62,16 +62,7 @@ class Adagrad(Optimizer):
        for param in param_group["params"]:
            if param.__wrapped__ in self._grad_skip:
                self._grad_skip.remove(param.__wrapped__)
                continue
            if not isinstance(param.grad, Buffer):
                raise TypeError(
                    "grad must be a Buffer, maybe you forget to call backward()?"
                )
            if not param.requires_grad:
            if not param.requires_grad or "grad" not in param.__dict__:
                continue
            states = self._state[param]
@@ -87,4 +78,3 @@ class Adagrad(Optimizer):
            clr = lr / (1 + (step - 1) * lr_decay)
            param -= clr * delta
        assert len(self._grad_skip) == 0
@@ -8,7 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable, Tuple, Union
from ..tensor_nn import Buffer, Parameter
from ..tensor_nn import Parameter
from .optimizer import Optimizer
@@ -59,18 +59,9 @@ class Adam(Optimizer):
        for param in param_group["params"]:
            if param.__wrapped__ in self._grad_skip:
                self._grad_skip.remove(param.__wrapped__)
            if not param.requires_grad or "grad" not in param.__dict__:
                continue
            if not param.requires_grad:
                continue
            if not isinstance(param.grad, Buffer):
                raise TypeError(
                    "grad must be a Buffer, maybe you forget to call backward()?"
                )
            grad = param.grad
            if weight_decay != 0.0:
                grad += param * weight_decay
@@ -91,5 +82,3 @@ class Adam(Optimizer):
            # not inplace change, need to update underlying tensor handler in state
            states["exp_avg"]._reset(exp_avg)
            states["exp_avg_sq"]._reset(exp_avg_sq)
        assert len(self._grad_skip) == 0
@@ -8,7 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable, Union
from ..tensor_nn import Buffer, Parameter
from ..tensor_nn import Parameter
from .optimizer import Optimizer
@@ -52,7 +52,7 @@ class SGD(Optimizer):
        momentum = param_group["momentum"]
        for param in param_group["params"]:
            if not param.requires_grad:
            if not param.requires_grad or "grad" not in param.__dict__:
                continue
            grad = param.grad
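Taken together, the four optimizer hunks above converge on the same guard; here is a condensed sketch of the new per-parameter loop (a simplification for illustration only, omitting weight decay, momentum and the per-parameter state updates shown above):

    # Sketch only; param_group is shaped like the dicts used above.
    for param in param_group["params"]:
        # Under GradManager, a parameter that never took part in backward()
        # simply has no "grad" attribute, so it is skipped instead of raising
        # the old "grad must be a Buffer ..." TypeError.
        if not param.requires_grad or "grad" not in param.__dict__:
            continue
        grad = param.grad
        param -= param_group["lr"] * grad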
@@ -9,6 +9,7 @@
import numpy as np
import megengine
import megengine.autodiff as ad
import megengine.optimizer as optimizer
from megengine import Parameter, tensor
from megengine.module import Module
@@ -37,8 +38,9 @@ class Simple2(Module):
def test_advance_indexing():
    net = Simple()
    gm = ad.GradManager().register(net.parameters())
    optim = optimizer.SGD(net.parameters(), lr=1.0)
    optim.zero_grad()
    optim.clear_grad()
    dshape = (10, 10)
    raw_data = np.arange(100).reshape(dshape).astype(np.float32)
@@ -46,9 +48,9 @@ def test_advance_indexing():
    data = tensor(raw_data)
    mask = tensor(raw_mask)
    answer = 1.0 - raw_data[raw_mask].sum()
    with optim.record():
    with gm.record():
        loss = net(data, mask).sum()
        optim.backward(loss)
        gm.backward(loss)
    optim.step()
    np.testing.assert_almost_equal(net.a.numpy(), np.array([answer]).astype(np.float32))
@@ -56,15 +58,16 @@ def test_advance_indexing():
def test_advance_indexing_with_subtensor():
    net = Simple2()
    gm = ad.GradManager().register(net.parameters())
    optim = optimizer.SGD(net.parameters(), lr=1.0)
    optim.zero_grad()
    optim.clear_grad()
    dshape = (2, 3, 4, 3, 4, 2)
    raw_data = np.arange(576).reshape(dshape).astype(np.float32)
    data = tensor(raw_data)
    answer = 1.0 - raw_data[1, ..., :, 0:4:2, 0:2].sum()
    with optim.record():
    with gm.record():
        loss = net(data).sum()
        optim.backward(loss)
        gm.backward(loss)
    optim.step()
    np.testing.assert_almost_equal(net.a.numpy(), np.array([answer]).astype(np.float32))
@@ -9,6 +9,7 @@
import numpy as np
import megengine
import megengine.autodiff as ad
import megengine.optimizer as optimizer
from megengine import Parameter, tensor
from megengine.module import Module
@@ -27,14 +28,15 @@ class Simple(Module):
def test_ai():
    net = Simple()
    gm = ad.GradManager().register(net.parameters())
    optim = optimizer.SGD(net.parameters(), lr=1.0)
    optim.zero_grad()
    optim.clear_grad()
    dshape = (10, 10)
    data = tensor(np.ones(dshape).astype(np.float32))
    with optim.record():
    with gm.record():
        loss = net(data).sum()
        optim.backward(loss)
        gm.backward(loss)
    optim.step()
    np.testing.assert_almost_equal(
        net.a.numpy(), np.array([1.0 - dshape[0]]).astype(np.float32)
@@ -10,6 +10,7 @@ import numpy as np
import pytest
import megengine
import megengine.autodiff as ad
import megengine.optimizer as optimizer
from megengine import Parameter, tensor
from megengine.module import BatchNorm2d
@@ -24,13 +25,14 @@ def test_frozen_bn():
    saved_wt = m.weight.numpy()
    saved_bias = m.bias.numpy()
    gm = ad.GradManager().register(m.parameters())
    optim = optimizer.SGD(m.parameters(), lr=1.0)
    optim.zero_grad()
    optim.clear_grad()
    data = np.random.random((6, nchannel, 2, 2)).astype("float32")
    with optim.record():
    with gm.record():
        loss = m(data).mean()
        optim.backward(loss)
        gm.backward(loss)
    optim.step()
    np.testing.assert_equal(m.running_var.numpy(), saved_var)
@@ -44,13 +46,14 @@ def test_bn_no_track_stat():
    nchannel = 3
    m = BatchNorm2d(nchannel, track_running_stats=False)
    gm = ad.GradManager().register(m.parameters())
    optim = optimizer.SGD(m.parameters(), lr=1.0)
    optim.zero_grad()
    optim.clear_grad()
    data = np.random.random((6, nchannel, 2, 2)).astype("float32")
    with optim.record():
    with gm.record():
        loss = m(data).sum()
        optim.backward(loss)
        gm.backward(loss)
    optim.step()
@@ -65,13 +68,14 @@ def test_bn_no_track_stat2():
    saved_mean = m.running_mean.numpy()
    assert saved_mean is not None
    gm = ad.GradManager().register(m.parameters())
    optim = optimizer.SGD(m.parameters(), lr=1.0)
    optim.zero_grad()
    optim.clear_grad()
    data = np.random.random((6, nchannel, 2, 2)).astype("float32")
    with optim.record():
    with gm.record():
        loss = m(data).sum()
        optim.backward(loss)
        gm.backward(loss)
    optim.step()
    np.testing.assert_equal(m.running_var.numpy(), saved_var)
@@ -12,6 +12,7 @@ import numpy as np
import pytest
import megengine as mge
import megengine.autodiff as ad
import megengine.functional as F
from megengine import Tensor
from megengine.module import Linear, Module
@@ -76,12 +77,13 @@ def test_training_converge():
    opt = SGD(
        net.parameters(requires_grad=True), lr=0.01, momentum=0.9, weight_decay=5e-4
    )
    gm = ad.GradManager().register(net.parameters())

    def train(data, label):
        with opt.record():
        with gm.record():
            pred = net(data)
            loss = F.cross_entropy_with_softmax(pred, label)
            opt.backward(loss)
            gm.backward(loss)
        return loss

    def infer(data):
@@ -93,7 +95,7 @@ def test_training_converge():
    for data, label in itertools.islice(train_dataset, 2000):
        data = Tensor(data, dtype=np.float32)
        label = Tensor(label, dtype=np.int32)
        opt.zero_grad()
        opt.clear_grad()
        loss = train(data, label)
        opt.step()
        losses.append(loss.numpy())
@@ -15,6 +15,7 @@ import numpy as np
import pytest
import megengine as mge
import megengine.autodiff as ad
import megengine.functional as F
from megengine import jit
from megengine.core._trace_option import set_tensor_shape
@@ -89,11 +90,11 @@ class MnistNet(Module):
        return x

def train(data, label, net, opt):
    with opt.record():
def train(data, label, net, opt, gm):
    with gm.record():
        pred = net(data)
        loss = F.cross_entropy_with_softmax(pred, label)
        opt.backward(loss)
        gm.backward(loss)
    return loss
@@ -116,12 +117,13 @@ def update_model(model_path):
    net.load_state_dict(checkpoint["net_init"])
    lr = checkpoint["sgd_lr"]
    opt = SGD(net.parameters(), lr=lr)
    gm = ad.GradManager().register(net.parameters())
    data = Tensor(checkpoint["data"], dtype=np.float32)
    label = Tensor(checkpoint["label"], dtype=np.int32)
    opt.zero_grad()
    loss = train(data, label, net=net, opt=opt)
    opt.clear_grad()
    loss = train(data, label, net, opt, gm)
    opt.step()
    xpu_name = get_xpu_name()
@@ -150,6 +152,7 @@ def run_train(
    net.load_state_dict(checkpoint["net_init"])
    lr = checkpoint["sgd_lr"]
    opt = SGD(net.parameters(), lr=lr)
    gm = ad.GradManager().register(net.parameters())
    data = Tensor(checkpoint["data"], dtype=np.float32)
    label = Tensor(checkpoint["label"], dtype=np.int32)
@@ -165,8 +168,8 @@ def run_train(
        sublinear_memory_config=sublinear_memory_config,
    )
    opt.zero_grad()
    loss = train_func(data, label, net=net, opt=opt)
    opt.clear_grad()
    loss = train_func(data, label, net, opt, gm)
    opt.step()
    assertTensorClose(loss.numpy(), checkpoint["loss"], max_err=max_err)
@@ -9,6 +9,7 @@
import numpy as np
import megengine
import megengine.autodiff as ad
import megengine.optimizer as optimizer
from megengine import Parameter, tensor
from megengine.module import Module
@@ -30,13 +31,14 @@ def test_detach():
    net = Simple()
    optim = optimizer.SGD(net.parameters(), lr=1.0)
    optim.zero_grad()
    optim.clear_grad()
    gm = ad.GradManager().register(net.parameters())
    dshape = (10, 10)
    data = tensor(np.ones(dshape).astype(np.float32))
    with optim.record():
    with gm.record():
        loss = net(data).sum()
        optim.backward(loss)
        gm.backward(loss)
    optim.step()
    np.testing.assert_equal(net.a.numpy(), np.array([1.0]).astype(np.float32))
    np.testing.assert_equal(
@@ -18,6 +18,7 @@ import numpy as np
import pytest
import megengine as mge
import megengine.autodiff as ad
import megengine.distributed as dist
import megengine.functional as F
from megengine.device import get_default_device, set_default_device
@@ -94,11 +95,13 @@ class MnistNet(Module):
        return x

def train(data, label, net, opt):
    with opt.record():
def train(data, label, net, opt, gm):
    opt.clear_grad()
    with gm.record():
        pred = net(data)
        loss = F.cross_entropy_with_softmax(pred, label)
        opt.backward(loss)
        gm.backward(loss)
    opt.step()
    return loss
@@ -111,7 +114,7 @@ def update_model(model_path):
    .. code-block:: python
        from test_correctness import update_model
        from test_dp_correctness import update_model
        update_model('mnist_model_with_test.mge') # for gpu
        update_model('mnist_model_with_test_cpu.mge') # for cpu
@@ -122,6 +125,11 @@ def update_model(model_path):
    lr = checkpoint["sgd_lr"]
    opt = SGD(net.parameters(), lr=lr)
    gm = ad.GradManager()
    gm.register(
        net.parameters(), callbacks=[dist.make_allreduce_cb("MEAN", dist.WORLD)]
    )
    data = Tensor(checkpoint["data"], dtype=np.float32)
    label = Tensor(checkpoint["label"], dtype=np.int32)
@@ -158,24 +166,23 @@ def run_test(
    def worker(rank, max_err):
        dist.init_process_group("localhost", port, p_num, rank, rank)
        set_default_device(device="gpu{}".format(dist.get_rank()))
        net = MnistNet(has_bn=True)
        net.load_state_dict(checkpoint["net_init"])
        lr = checkpoint["sgd_lr"]
        opt = SGD(net.parameters(), reduce_method="mean", lr=lr)
        opt = SGD(net.parameters(), lr=lr)
        gm = ad.GradManager()
        gm.register(
            net.parameters(), callbacks=[dist.make_allreduce_cb("MEAN", dist.WORLD)]
        )
        # use same data and label for all gpu's
        # such that the result does not depend on number of gpu
        data_train = Tensor(data)
        label_train = Tensor(label)
        train_func = train
        opt.zero_grad()
        loss = train_func(data_train, label_train, net=net, opt=opt)
        opt.step()
        loss = train(data_train, label_train, net, opt, gm)
        print("{} loss {}".format(get_default_device(), loss.numpy()[0]))
        assertTensorClose(loss.numpy(), checkpoint["loss"], max_err=max_err)
    if dist.get_rank():
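For the data-parallel tests, a condensed sketch of the wiring these hunks introduce: gradient averaging moves out of SGD(reduce_method="mean") and into an all-reduce callback registered on the GradManager. make_allreduce_cb, WORLD, clear_grad, record and backward are taken from the hunks above; the net argument and learning rate are placeholders supplied by the caller:

    import megengine.autodiff as ad
    import megengine.distributed as dist
    import megengine.functional as F
    from megengine.optimizer import SGD

    def make_train_step(net, lr):
        gm = ad.GradManager()
        gm.register(
            net.parameters(), callbacks=[dist.make_allreduce_cb("MEAN", dist.WORLD)]
        )
        opt = SGD(net.parameters(), lr=lr)

        def train_step(data, label):
            opt.clear_grad()
            with gm.record():
                pred = net(data)
                loss = F.cross_entropy_with_softmax(pred, label)
                gm.backward(loss)  # callback all-reduces (averages) grads across ranks
            opt.step()
            return loss

        return train_step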
@@ -12,6 +12,7 @@ import numpy as np
import pytest
import megengine
import megengine.autodiff as ad
import megengine.optimizer as optimizer
from megengine import Parameter, tensor
from megengine.module import Module
@@ -31,12 +32,13 @@ def test_hello_world():
    net = Simple()
    optim = optimizer.SGD(net.parameters(), lr=1.0)
    optim.zero_grad()
    optim.clear_grad()
    gm = ad.GradManager().register(net.parameters())
    data = tensor([2.34])
    with optim.record():
    with gm.record():
        loss = net(data)
        optim.backward(loss)
        gm.backward(loss)
    optim.step()
    np.testing.assert_almost_equal(
        net.a.numpy(), np.array([1.23 - 2.34]).astype(np.float32)
@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import megengine.autodiff as ad
import megengine.functional as F
from megengine import Parameter, optimizer
from megengine.jit import trace
@@ -43,6 +44,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
    net = Simple()
    opt = getattr(optimizer, opt_str)(net.parameters(), **test_case)
    check_func = check_class(net, **test_case)
    gm = ad.GradManager().register(net.parameters())
    step = 0
    data_shape = (2, 28)
@@ -54,11 +56,11 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
            check_func.lr += 0.01
        data = tensor(np.random.random(data_shape).astype(np.float32))
        opt.zero_grad()
        with opt.record():
        opt.clear_grad()
        with gm.record():
            pred = net(data)
            loss = pred.sum()
            opt.backward(loss)
            gm.backward(loss)
        ori_params = TensorDict()
        for param in net.parameters():
@@ -1,6 +1,7 @@
import numpy as np
import megengine as mge
import megengine.autodiff as ad
import megengine.optimizer as optimizer
from megengine import Parameter, tensor
from megengine.core.tensor.raw_tensor import RawTensor
@@ -21,13 +22,14 @@ def test_save_load():
    net = Simple()
    optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
    optim.zero_grad()
    optim.clear_grad()
    gm = ad.GradManager().register(net.parameters())
    data = tensor([2.34])
    with optim.record():
    with gm.record():
        loss = net(data)
        optim.backward(loss)
        gm.backward(loss)
    optim.step()
@@ -53,9 +55,9 @@ def test_save_load():
    optim.load_state_dict(checkpoint["opt_state"])
    print("load done")
    with optim.record():
    with gm.record():
        loss = net([1.23])
        optim.backward(loss)
        gm.backward(loss)
    optim.step()
    # Restore device
@@ -9,6 +9,7 @@
import numpy as np
import megengine
import megengine.autodiff as ad
import megengine.optimizer as optimizer
from megengine import Parameter, tensor
from megengine.jit import trace
@@ -29,14 +30,15 @@ def test_sgd_momentum():
    net = Simple()
    optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
    optim.zero_grad()
    optim.clear_grad()
    gm = ad.GradManager().register(net.parameters())
    data = tensor([2.34])
    # do a step of train
    with optim.record():
    with gm.record():
        loss = net(data)
        optim.backward(loss)
        gm.backward(loss)
    optim.step()
    np.testing.assert_almost_equal(optim._state[net.a]["momentum_buffer"].numpy(), 2.34)
@@ -48,10 +50,10 @@ def test_sgd_momentum():
    np.testing.assert_almost_equal(optim._state[net.a]["momentum_buffer"].numpy(), 2.34)
    # do a step of train
    optim.zero_grad()
    with optim.record():
    optim.clear_grad()
    with gm.record():
        loss = net(data)
        optim.backward(loss)
        gm.backward(loss)
    optim.step()
    np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5)
@@ -9,6 +9,7 @@ import copy
import numpy as np
import megengine.autodiff as ad
import megengine.functional as F
import megengine.optimizer as optimizer
from megengine import Parameter
@@ -41,13 +42,14 @@ def test_single_input():
            return x

    net = Simple(av)
    optim = optimizer.SGD(net.parameters(), lr=1.0)
    optim.zero_grad()
    gm = ad.GradManager().register(net.parameters())
    opt = optimizer.SGD(net.parameters(), lr=1.0)
    with optim.record():
    opt.clear_grad()
    with gm.record():
        loss = net()
        optim.backward(loss.sum())
    optim.step()
        gm.backward(loss.sum())
    opt.step()
    np.testing.assert_almost_equal(loss.numpy(), (av * 10))
    np.testing.assert_almost_equal(net.a.numpy(), (av - 10))
@@ -79,13 +81,14 @@ def test_multi_input():
            return x

    net = Simple(av, bv)
    optim = optimizer.SGD(net.parameters(), lr=1.0)
    optim.zero_grad()
    gm = ad.GradManager().register(net.parameters())
    opt = optimizer.SGD(net.parameters(), lr=1.0)
    with optim.record():
    opt.clear_grad()
    with gm.record():
        loss = net()
        optim.backward(loss.sum())
    optim.step()
        gm.backward(loss.sum())
    opt.step()
    np.testing.assert_almost_equal(loss.numpy(), (av * bv))
    np.testing.assert_almost_equal(net.a.numpy(), (av - 2 * bv))
@@ -118,13 +121,14 @@ def test_multi_output():
            return x + y

    net = Simple(av, bv)
    optim = optimizer.SGD(net.parameters(), lr=1.0)
    optim.zero_grad()
    gm = ad.GradManager().register(net.parameters())
    opt = optimizer.SGD(net.parameters(), lr=1.0)
    with optim.record():
    opt.clear_grad()
    with gm.record():
        loss = net()
        optim.backward(loss.sum())
    optim.step()
        gm.backward(loss.sum())
    opt.step()
    np.testing.assert_almost_equal(loss.numpy(), (av * bv + av + bv), decimal=6)
    np.testing.assert_almost_equal(net.a.numpy(), (av - bv - 1), decimal=6)