
perf(mge/imperative): optimize gpu memory with standalone grad_fn

GitOrigin-RevId: 44952b235e
tags/v1.0.0-rc1
Megvii Engine Team · 4 years ago · commit 46f03b9f7c
2 changed files with 195 additions and 22 deletions
  1. imperative/python/megengine/core/autodiff/builtin_op_utils.py (+111, -22)
  2. imperative/python/test/unit/test_autodiff.py (+84, -0)
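Context for review: default_grad_fn below saves forward inputs and outputs for the backward pass, whereas the standalone grad_fns added in this commit capture only the metadata the backward actually needs, typically the input shape plus the index tensors, which is where the GPU memory saving comes from. A rough plain-NumPy sketch of the pattern; the names here are illustrative, not the MegEngine API:

import numpy as np

def reduce_sum_grad_fn_sketch(x):
    input_shape = x.shape                  # keep only the shape; x need not stay alive for backward

    def backward(dy):
        return np.broadcast_to(dy, input_shape).copy()

    return backward

def subtensor_grad_fn_sketch(x, index):
    input_shape, dtype = x.shape, x.dtype  # again: shape and dtype only

    def backward(dy):
        dx = np.zeros(input_shape, dtype=dtype)
        dx[index] = dy                     # scatter dy back into the region that was read
        return dx

    return backward

# illustrative registry playing the role of _oprAttr_grad_fn in the diff below
_grad_fn_registry_sketch = {
    "ReduceSum": reduce_sum_grad_fn_sketch,
    "Subtensor": subtensor_grad_fn_sketch,
}

The real grad_fns below do the same with MegEngine ops (GetVarShape, Broadcast, SetSubtensor, IndexingSetMultiAxisVec) instead of NumPy.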

imperative/python/megengine/core/autodiff/builtin_op_utils.py (+111, -22)

@@ -12,9 +12,26 @@ import itertools
import numpy as np

from .._imperative_rt import TensorAttr, imperative
-from ..ops.builtin import Elemwise, GetVarShape, OpDef, OprAttr, Reduce, Reshape
+from ..ops.builtin import (
+    Broadcast,
+    Elemwise,
+    GetVarShape,
+    IndexingMultiAxisVec,
+    IndexingSetMultiAxisVec,
+    OpDef,
+    OprAttr,
+    Reduce,
+    Reshape,
+    SetSubtensor,
+    Subtensor,
+)
from ..ops.special import Const
from ..tensor.core import apply
from ..tensor.function import Function
from ..tensor.tensor_wrapper import TensorWrapper

+_elemwise_add_param = Elemwise(mode="add").to_c().param
+_reduce_sum_param = Reduce(mode="SUM").to_c().param[0]


@functools.singledispatch
@@ -22,19 +39,17 @@ def builtin_op_get_backward_fn(op: OpDef, inputs, outputs, input_requires_grad):
    assert 0


-_elemwise_add_param = Elemwise(mode="add").to_c().param


@builtin_op_get_backward_fn.register(OpDef)
def _(op: OpDef, inputs, outputs, input_requires_grad):
-    if (
-        isinstance(op, OprAttr)
-        and op.type == "Elemwise"
-        and op.param == _elemwise_add_param
-    ):
-        grad_fn = elemwise_grad_fn
-    elif isinstance(op, OprAttr) and op.type == Reshape.name:
-        grad_fn = reshape_grad_fn
+    if isinstance(op, OprAttr):
+        grad_fn = _oprAttr_grad_fn.get(op.type, None)
+        if grad_fn is None:
+            if op.type == Elemwise.name and op.param == _elemwise_add_param:
+                grad_fn = elemwise_add_grad_fn
+            elif op.type == Reduce.name and op.param[0] == _reduce_sum_param:
+                grad_fn = reduce_sum_grad_fn
+            else:
+                grad_fn = default_grad_fn
    else:
        grad_fn = default_grad_fn
    return grad_fn(op, inputs, outputs, input_requires_grad)
@@ -73,6 +88,7 @@ def default_grad_fn(op, inputs, outputs, input_requires_grad):
    save_for_backward = tuple(
        val for val, mask in zip(inputs + outputs, intput_output_mask) if mask
    )

    del inputs
    del outputs

@@ -85,13 +101,14 @@ def default_grad_fn(op, inputs, outputs, input_requires_grad):
    return backward, output_grad_mask


-# override for elemwise
-def elemwise_grad_fn(op, inputs, outputs, input_requires_grad):
-    assert len(inputs) == len(input_requires_grad) == 2
+def get_shape(x):
+    (s,) = apply(GetVarShape(), x)
+    return s


-    def get_shape(x):
-        (s,) = apply(GetVarShape(), x)
-        return s
+# override for Elemwise.add
+def elemwise_add_grad_fn(op, inputs, outputs, input_requires_grad):
+    assert len(inputs) == len(input_requires_grad) == 2

    input_shapes = [
        get_shape(x) if i else None for i, x in zip(input_requires_grad, inputs)
@@ -110,13 +127,10 @@ def elemwise_grad_fn(op, inputs, outputs, input_requires_grad):
    return backward, [True]


# override for Reshape
def reshape_grad_fn(op, inputs, outputs, input_requires_grad):
    assert len(inputs) == len(input_requires_grad) == 2

-    def get_shape(x):
-        (s,) = apply(GetVarShape(), x)
-        return s

    input_shapes = [
        get_shape(x) if i else None for i, x in zip(input_requires_grad, inputs)
    ]
@@ -132,3 +146,78 @@ def reshape_grad_fn(op, inputs, outputs, input_requires_grad):
    )

    return backward, [True]


# override for Subtensor
def subtensor_grad_fn(op, inputs, outputs, input_requires_grad):
    grad_op = OprAttr()
    grad_op.type = SetSubtensor.name
    grad_op.param = op.param

    input_shape = get_shape(inputs[0])
    params = inputs[1:]

    def make_grad(grad_op, dy):
        grad = (
            TensorWrapper(0, dtype=dy.dtype, device=dy.device)
            .broadcast(TensorWrapper(input_shape))
            .__wrapped__
        )
        (dx,) = apply(grad_op, grad, dy, *params)
        return dx

    def backward(dy):
        return tuple(
            make_grad(grad_op, dy) if mask else None for mask in input_requires_grad
        )

    return backward, [True]
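Reviewer note: make_grad above builds a zero tensor broadcast to the saved input shape and then applies the SetSubtensor-typed grad_op to write dy into the region the forward read, so only the shape and the index tensors (params) are retained for backward. A NumPy cross-check of the expected result, mirroring the test_subtensor case added below (illustration only, not MegEngine code):

import numpy as np

dy = np.ones((1, 2), dtype="float32")   # gradient of y = x[1:-1, :2] for a 3x3 input x
dx = np.zeros((3, 3), dtype="float32")  # zeros in the shape of x
dx[1:-1, :2] = dy                       # the SetSubtensor step
np.testing.assert_equal(dx, np.array([[0, 0, 0], [1, 1, 0], [0, 0, 0]], dtype="float32"))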


# override for IndexingMultiAxisVec
def indexingMultiAxisVec_grad_fn(op, inputs, outputs, input_requires_grad):
    grad_op = OprAttr()
    grad_op.type = IndexingSetMultiAxisVec.name
    grad_op.param = op.param

    input_shape = get_shape(inputs[0])
    params = inputs[1:]

    def make_grad(grad_op, dy):
        grad = (
            TensorWrapper(0, dtype=dy.dtype, device=dy.device)
            .broadcast(TensorWrapper(input_shape))
            .__wrapped__
        )
        (dx,) = apply(grad_op, grad, dy, *params)
        return dx

    def backward(dy):
        return tuple(
            make_grad(grad_op, dy) if mask else None for mask in input_requires_grad
        )

    return backward, [True]


# override for Reduce.sum
def reduce_sum_grad_fn(op, inputs, outputs, input_requires_grad):
    assert len(inputs) == len(input_requires_grad) == 1
    input_shape = get_shape(inputs[0])

    def broadcast_to(dy, s):
        (dx,) = apply(Broadcast(), dy, s)
        return dx

    def backward(dy):
        return (broadcast_to(dy, input_shape) if input_requires_grad[0] else None,)

    return backward, [True]


_oprAttr_grad_fn = {
    Reshape.name: reshape_grad_fn,
    Subtensor.name: subtensor_grad_fn,
    IndexingMultiAxisVec.name: indexingMultiAxisVec_grad_fn,
    Broadcast.name: elemwise_add_grad_fn,
}
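Reviewer note on the last registry entry: Broadcast reuses elemwise_add_grad_fn, presumably because both need the same backward step, reducing dy back to the saved input shape (the target-shape input of Broadcast requires no gradient, which the input_requires_grad mask already covers). A quick NumPy illustration matching the test_Broadcast case below (illustration only):

import numpy as np

x_shape = (3, 3, 1)
dy = np.ones((3, 3, 10), dtype="float32")  # upstream gradient of broadcast(x, (3, 3, 10))
dx = dy.sum(axis=2, keepdims=True)         # reduce back to the shape of x
np.testing.assert_equal(dx, np.full(x_shape, 10, dtype="float32"))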

imperative/python/test/unit/test_autodiff.py (+84, -0)

@@ -14,6 +14,7 @@ import pytest

import megengine as mge
import megengine.distributed as dist
+import megengine.functional as F
from megengine.core._imperative_rt import TensorAttr, imperative
from megengine.core._imperative_rt.imperative import sync
from megengine.core.autodiff.grad import Grad
@@ -229,3 +230,86 @@ def test_elemwise_relu_backward_fn():
    result = imperative.make_backward_graph(op, [attr], [True], [True])
    backward_graph, save_for_backward_mask, input_has_grad = result
    assert save_for_backward_mask == [False, True, True], save_for_backward_mask


def test_reshape():
    x_np = np.random.rand(2, 5).astype("float32")
    x = TensorWrapper(x_np)

    grad = Grad().wrt(x, callback=save_to(x))
    y = x.reshape(5, 2)

    grad(y, F.ones_like(y))
    np.testing.assert_equal(np.ones((2, 5), dtype=np.float32), x.grad.numpy())


def test_subtensor():
    x_np = np.random.rand(3, 3).astype("float32")
    x = TensorWrapper(x_np)

    grad = Grad().wrt(x, callback=save_to(x))
    y = x[1:-1, :2]

    grad(y, F.ones_like(y))
    np.testing.assert_equal(
        np.array([[0, 0, 0], [1, 1, 0], [0, 0, 0]], dtype=np.float32), x.grad.numpy()
    )


def test_IndexingMultiAxisVec():
    x_np = np.random.rand(3, 3).astype("float32")
    x = TensorWrapper(x_np)

    grad = Grad().wrt(x, callback=save_to(x))
    y = x[[0, 2], [0, 2]]

    grad(y, F.ones_like(y))
    np.testing.assert_equal(
        np.array([[1, 0, 0], [0, 0, 0], [0, 0, 1]], dtype=np.float32), x.grad.numpy()
    )


def test_AxisAddRemove():
    x_np = np.random.rand(1, 5).astype("float32")
    x = TensorWrapper(x_np)

    grad = Grad().wrt(x, callback=save_to(x))
    y = F.remove_axis(F.add_axis(x, 2), 0)

    grad(y, F.ones_like(y))
    np.testing.assert_equal(
        np.array([[1, 1, 1, 1, 1]], dtype=np.float32), x.grad.numpy()
    )


def test_Broadcast():
    x_np = np.random.rand(3, 3, 1).astype("float32")
    x = TensorWrapper(x_np)

    grad = Grad().wrt(x, callback=save_to(x))
    y = F.broadcast(x, (3, 3, 10))

    grad(y, F.ones_like(y))
    np.testing.assert_equal(np.ones((3, 3, 1), dtype=np.float32) * 10, x.grad.numpy())


def test_Reduce_sum():
    x_np = np.random.rand(3, 3).astype("float32")
    x = TensorWrapper(x_np)

    grad = Grad().wrt(x, callback=save_to(x))
    y = x.sum(axis=0)

    grad(y, F.ones_like(y))
    np.testing.assert_equal(np.ones((3, 3), dtype=np.float32), x.grad.numpy())


def test_Reduce_mean():
    x_np = np.random.rand(3, 3).astype("float32")
    x = TensorWrapper(x_np)

    grad = Grad().wrt(x, callback=save_to(x))
    y = x.mean(axis=0)

    grad(y, F.ones_like(y))
    np.testing.assert_equal(np.ones((3, 3), dtype=np.float32) / 3, x.grad.numpy())
