From 46f03b9f7ca854b85079246a35daeab3a1a86c92 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 31 Aug 2020 15:41:07 +0800
Subject: [PATCH] perf(mge/imperative): optimize gpu memory with standalone grad_fn

GitOrigin-RevId: 44952b235ebe213e168c050a3b20636efee53d5d
---
 .../megengine/core/autodiff/builtin_op_utils.py | 133 +++++++++++++++++----
 imperative/python/test/unit/test_autodiff.py    |  84 +++++++++++++
 2 files changed, 195 insertions(+), 22 deletions(-)

diff --git a/imperative/python/megengine/core/autodiff/builtin_op_utils.py b/imperative/python/megengine/core/autodiff/builtin_op_utils.py
index 54d959d1..51f54194 100644
--- a/imperative/python/megengine/core/autodiff/builtin_op_utils.py
+++ b/imperative/python/megengine/core/autodiff/builtin_op_utils.py
@@ -12,9 +12,26 @@ import itertools
 import numpy as np
 
 from .._imperative_rt import TensorAttr, imperative
-from ..ops.builtin import Elemwise, GetVarShape, OpDef, OprAttr, Reduce, Reshape
+from ..ops.builtin import (
+    Broadcast,
+    Elemwise,
+    GetVarShape,
+    IndexingMultiAxisVec,
+    IndexingSetMultiAxisVec,
+    OpDef,
+    OprAttr,
+    Reduce,
+    Reshape,
+    SetSubtensor,
+    Subtensor,
+)
+from ..ops.special import Const
 from ..tensor.core import apply
 from ..tensor.function import Function
+from ..tensor.tensor_wrapper import TensorWrapper
+
+_elemwise_add_param = Elemwise(mode="add").to_c().param
+_reduce_sum_param = Reduce(mode="SUM").to_c().param[0]
 
 
 @functools.singledispatch
@@ -22,19 +39,17 @@ def builtin_op_get_backward_fn(op: OpDef, inputs, outputs, input_requires_grad):
     assert 0
 
 
-_elemwise_add_param = Elemwise(mode="add").to_c().param
-
-
 @builtin_op_get_backward_fn.register(OpDef)
 def _(op: OpDef, inputs, outputs, input_requires_grad):
-    if (
-        isinstance(op, OprAttr)
-        and op.type == "Elemwise"
-        and op.param == _elemwise_add_param
-    ):
-        grad_fn = elemwise_grad_fn
-    elif isinstance(op, OprAttr) and op.type == Reshape.name:
-        grad_fn = reshape_grad_fn
+    if isinstance(op, OprAttr):
+        grad_fn = _oprAttr_grad_fn.get(op.type, None)
+        if grad_fn is None:
+            if op.type == Elemwise.name and op.param == _elemwise_add_param:
+                grad_fn = elemwise_add_grad_fn
+            elif op.type == Reduce.name and op.param[0] == _reduce_sum_param:
+                grad_fn = reduce_sum_grad_fn
+            else:
+                grad_fn = default_grad_fn
     else:
         grad_fn = default_grad_fn
     return grad_fn(op, inputs, outputs, input_requires_grad)
@@ -73,6 +88,7 @@ def default_grad_fn(op, inputs, outputs, input_requires_grad):
     save_for_backward = tuple(
         val for val, mask in zip(inputs + outputs, intput_output_mask) if mask
     )
+
     del inputs
     del outputs
 
@@ -85,13 +101,14 @@ return backward, output_grad_mask
 
 
-# override for elemwise
-def elemwise_grad_fn(op, inputs, outputs, input_requires_grad):
-    assert len(inputs) == len(input_requires_grad) == 2
+def get_shape(x):
+    (s,) = apply(GetVarShape(), x)
+    return s
+
 
-    def get_shape(x):
-        (s,) = apply(GetVarShape(), x)
-        return s
+# override for Elemwise.add
+def elemwise_add_grad_fn(op, inputs, outputs, input_requires_grad):
+    assert len(inputs) == len(input_requires_grad) == 2
 
     input_shapes = [
         get_shape(x) if i else None for i, x in zip(input_requires_grad, inputs)
     ]
@@ -110,13 +127,10 @@ def elemwise_grad_fn(op, inputs, outputs, input_requires_grad):
     return backward, [True]
 
 
+# override for Reshape
 def reshape_grad_fn(op, inputs, outputs, input_requires_grad):
     assert len(inputs) == len(input_requires_grad) == 2
 
-    def get_shape(x):
-        (s,) = apply(GetVarShape(), x)
-        return s
-
     input_shapes = [
         get_shape(x) if i else None for i, x in zip(input_requires_grad, inputs)
     ]
@@ -132,3 +146,78 @@ def reshape_grad_fn(op, inputs, outputs, input_requires_grad):
     )
 
     return backward, [True]
+
+
+# override for Subtensor
+def subtensor_grad_fn(op, inputs, outputs, input_requires_grad):
+    grad_op = OprAttr()
+    grad_op.type = SetSubtensor.name
+    grad_op.param = op.param
+
+    input_shape = get_shape(inputs[0])
+    params = inputs[1:]
+
+    def make_grad(grad_op, dy):
+        grad = (
+            TensorWrapper(0, dtype=dy.dtype, device=dy.device)
+            .broadcast(TensorWrapper(input_shape))
+            .__wrapped__
+        )
+        (dx,) = apply(grad_op, grad, dy, *params)
+        return dx
+
+    def backward(dy):
+        return tuple(
+            make_grad(grad_op, dy) if mask else None for mask in input_requires_grad
+        )
+
+    return backward, [True]
+
+
+# override for IndexingMultiAxisVec
+def indexingMultiAxisVec_grad_fn(op, inputs, outputs, input_requires_grad):
+    grad_op = OprAttr()
+    grad_op.type = IndexingSetMultiAxisVec.name
+    grad_op.param = op.param
+
+    input_shape = get_shape(inputs[0])
+    params = inputs[1:]
+
+    def make_grad(grad_op, dy):
+        grad = (
+            TensorWrapper(0, dtype=dy.dtype, device=dy.device)
+            .broadcast(TensorWrapper(input_shape))
+            .__wrapped__
+        )
+        (dx,) = apply(grad_op, grad, dy, *params)
+        return dx
+
+    def backward(dy):
+        return tuple(
+            make_grad(grad_op, dy) if mask else None for mask in input_requires_grad
+        )
+
+    return backward, [True]
+
+
+# override for Reduce.sum
+def reduce_sum_grad_fn(op, inputs, outputs, input_requires_grad):
+    assert len(inputs) == len(input_requires_grad) == 1
+    input_shape = get_shape(inputs[0])
+
+    def broadcast_to(dy, s):
+        (dx,) = apply(Broadcast(), dy, s)
+        return dx
+
+    def backward(dy):
+        return (broadcast_to(dy, input_shape) if input_requires_grad[0] else None,)
+
+    return backward, [True]
+
+
+_oprAttr_grad_fn = {
+    Reshape.name: reshape_grad_fn,
+    Subtensor.name: subtensor_grad_fn,
+    IndexingMultiAxisVec.name: indexingMultiAxisVec_grad_fn,
+    Broadcast.name: elemwise_add_grad_fn,
+}
diff --git a/imperative/python/test/unit/test_autodiff.py b/imperative/python/test/unit/test_autodiff.py
index 85b60e82..3caaed61 100644
--- a/imperative/python/test/unit/test_autodiff.py
+++ b/imperative/python/test/unit/test_autodiff.py
@@ -14,6 +14,7 @@ import pytest
 
 import megengine as mge
 import megengine.distributed as dist
+import megengine.functional as F
 from megengine.core._imperative_rt import TensorAttr, imperative
 from megengine.core._imperative_rt.imperative import sync
 from megengine.core.autodiff.grad import Grad
@@ -229,3 +230,86 @@ def test_elemwise_relu_backward_fn():
     result = imperative.make_backward_graph(op, [attr], [True], [True])
     backward_graph, save_for_backward_mask, input_has_grad = result
     assert save_for_backward_mask == [False, True, True], save_for_backward_mask
+
+
+def test_reshape():
+    x_np = np.random.rand(2, 5).astype("float32")
+    x = TensorWrapper(x_np)
+
+    grad = Grad().wrt(x, callback=save_to(x))
+    y = x.reshape(5, 2)
+
+    grad(y, F.ones_like(y))
+    np.testing.assert_equal(np.ones((2, 5), dtype=np.float32), x.grad.numpy())
+
+
+def test_subtensor():
+    x_np = np.random.rand(3, 3).astype("float32")
+    x = TensorWrapper(x_np)
+
+    grad = Grad().wrt(x, callback=save_to(x))
+    y = x[1:-1, :2]
+
+    grad(y, F.ones_like(y))
+    np.testing.assert_equal(
+        np.array([[0, 0, 0], [1, 1, 0], [0, 0, 0]], dtype=np.float32), x.grad.numpy()
+    )
+
+
+def test_IndexingMultiAxisVec():
+    x_np = np.random.rand(3, 3).astype("float32")
+    x = TensorWrapper(x_np)
+
+    grad = Grad().wrt(x, callback=save_to(x))
+    y = x[[0, 2], [0, 2]]
+
+    grad(y, F.ones_like(y))
+    np.testing.assert_equal(
+        np.array([[1, 0, 0], [0, 0, 0], [0, 0, 1]], dtype=np.float32), x.grad.numpy()
+    )
+
+
+def test_AxisAddRemove():
+    x_np = np.random.rand(1, 5).astype("float32")
+    x = TensorWrapper(x_np)
+
+    grad = Grad().wrt(x, callback=save_to(x))
+    y = F.remove_axis(F.add_axis(x, 2), 0)
+
+    grad(y, F.ones_like(y))
+    np.testing.assert_equal(
+        np.array([[1, 1, 1, 1, 1]], dtype=np.float32), x.grad.numpy()
+    )
+
+
+def test_Broadcast():
+    x_np = np.random.rand(3, 3, 1).astype("float32")
+    x = TensorWrapper(x_np)
+
+    grad = Grad().wrt(x, callback=save_to(x))
+    y = F.broadcast(x, (3, 3, 10))
+
+    grad(y, F.ones_like(y))
+    np.testing.assert_equal(np.ones((3, 3, 1), dtype=np.float32) * 10, x.grad.numpy())
+
+
+def test_Reduce_sum():
+    x_np = np.random.rand(3, 3).astype("float32")
+    x = TensorWrapper(x_np)
+
+    grad = Grad().wrt(x, callback=save_to(x))
+    y = x.sum(axis=0)
+
+    grad(y, F.ones_like(y))
+    np.testing.assert_equal(np.ones((3, 3), dtype=np.float32), x.grad.numpy())
+
+
+def test_Reduce_mean():
+    x_np = np.random.rand(3, 3).astype("float32")
+    x = TensorWrapper(x_np)
+
+    grad = Grad().wrt(x, callback=save_to(x))
+    y = x.mean(axis=0)
+
+    grad(y, F.ones_like(y))
+    np.testing.assert_equal(np.ones((3, 3), dtype=np.float32) / 3, x.grad.numpy())
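
Reviewer note (not part of the patch): the memory saving comes from the standalone grad_fns closing over only the input shape (via get_shape) rather than the full input/output tensors saved by default_grad_fn, so the forward tensors can be released before backward runs; unrecognized OprAttr types still fall back to the generic path. Below is a minimal NumPy-only sketch of that dispatch idea, using hypothetical names (get_backward_fn, _grad_fn_table) that are not MegEngine APIs.

import numpy as np

def reduce_sum_grad_fn(inputs):
    # specialized path: keep only the shape, the tensor itself is not captured
    input_shape = inputs[0].shape

    def backward(dy):
        # gradient of a sum is the incoming gradient broadcast back to the input shape
        return (np.broadcast_to(dy, input_shape),)

    return backward

def default_grad_fn(inputs):
    # generic fallback: has to keep the actual tensors alive for backward
    saved = tuple(inputs)

    def backward(dy):
        raise NotImplementedError("would replay a generated backward graph over `saved`")

    return backward

_grad_fn_table = {"ReduceSum": reduce_sum_grad_fn}

def get_backward_fn(op_type, inputs):
    # dispatch on op type, fall back to the generic grad_fn
    return _grad_fn_table.get(op_type, default_grad_fn)(inputs)

x = np.random.rand(3, 3).astype("float32")
backward = get_backward_fn("ReduceSum", (x,))
dy = np.ones((1, 3), dtype="float32")  # gradient of x.sum(axis=0, keepdims=True)
(dx,) = backward(dy)
assert dx.shape == (3, 3) and (dx == 1).all()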