
perf(mge): add more specialized grad rules

GitOrigin-RevId: f88809a6d7
release-1.2
Megvii Engine Team, 4 years ago
commit a892e5d0e4
3 changed files with 147 additions and 174 deletions:
  1. imperative/python/megengine/core/autodiff/builtin_op_utils.py (+0, -173)
  2. imperative/python/megengine/core/autodiff/grad.py (+0, -1)
  3. imperative/python/src/grad_override.cpp (+147, -0)
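
The commit drops the Python-level backward overrides in builtin_op_utils.py and registers equivalent specialized grad rules in C++ (grad_override.cpp) for Reshape, Subtensor, IndexingMultiAxisVec, Reduce (SUM mode), AddAxis and RemoveAxis. From the Python side the autodiff API is unchanged; the following minimal sketch (not part of this commit, assuming a MegEngine 1.2-style API with GradManager, megengine.functional and tensor indexing) exercises the ops whose backward now goes through the new C++ rules.

# Hypothetical usage sketch: gradients flow through Reshape, Subtensor,
# AddAxis (expand_dims) and Reduce(SUM), all handled by the specialized rules.
import numpy as np
import megengine as mge
import megengine.functional as F
from megengine.autodiff import GradManager

x = mge.tensor(np.random.randn(4, 6).astype("float32"))
gm = GradManager().attach([x])
with gm:
    # Reshape -> Subtensor -> AddAxis -> Reduce(SUM)
    y = F.expand_dims(x.reshape(2, 12)[0, :5], 0).sum()
    gm.backward(y)
print(x.grad.shape)  # (4, 6): gradient propagated back through every op above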

imperative/python/megengine/core/autodiff/builtin_op_utils.py (+0, -173)

@@ -1,173 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import itertools

import numpy as np

from .._imperative_rt import TensorAttr, imperative
from .._imperative_rt.core2 import apply
from ..ops.builtin import (
    Broadcast,
    Elemwise,
    GetVarShape,
    IndexingMultiAxisVec,
    IndexingSetMultiAxisVec,
    OpDef,
    Reduce,
    Reshape,
    SetSubtensor,
    Subtensor,
)
from ..ops.special import Const


def default_grad_fn(op, inputs, outputs, input_requires_grad):
    def get_tensor_attr(x):
        attr = TensorAttr()
        attr.dtype = x.dtype
        attr.comp_node = x.device.to_c()
        return attr

    output_has_grads = [True,] * len(outputs)
    result = imperative.make_backward_graph(
        op, list(map(get_tensor_attr, inputs)), input_requires_grad, output_has_grads
    )
    if result is None:
        nr_inputs = len(inputs)
        nr_outputs = len(outputs)

        def backward(*args):
            return nr_inputs * [
                None,
            ]

        return backward, nr_outputs * [False,]
    backward_graph, save_for_backward_mask, input_has_grad = result

    input_output_mask = save_for_backward_mask[: len(inputs + outputs)]
    output_grad_mask = save_for_backward_mask[len(inputs + outputs) :]
    save_for_backward = tuple(
        val for val, mask in zip(inputs + outputs, input_output_mask) if mask
    )

    del inputs
    del outputs

    def backward(*args):
        output_grads = tuple(val for val, mask in zip(args, output_grad_mask) if mask)
        assert None not in output_grads
        ret = iter(apply(backward_graph, *(save_for_backward + output_grads)))
        return tuple(next(ret) if mask else None for mask in input_has_grad)

    return backward, output_grad_mask


def get_shape(x):
    (s,) = apply(GetVarShape(), x._data)
    return Tensor(s)


# override for Elemwise.add
def elemwise_add_grad_fn(op, inputs, outputs, input_requires_grad):
    assert len(inputs) == len(input_requires_grad) == 2

    input_shapes = [
        get_shape(x) if i else None for i, x in zip(input_requires_grad, inputs)
    ]

    def reduce_to(x, s):
        (y,) = apply(Reduce(), x, s)
        return y

    def backward(dy):
        return tuple(
            reduce_to(dy, s) if i else None
            for i, s in zip(input_requires_grad, input_shapes)
        )

    return backward, [True]


# override for Reshape
def reshape_grad_fn(op, inputs, outputs, input_requires_grad):
    assert len(inputs) == len(input_requires_grad) == 2

    input_shapes = [
        get_shape(x) if i else None for i, x in zip(input_requires_grad, inputs)
    ]

    def reshape_to(dy, s):
        (dx,) = apply(Reshape(), dy, s)
        return dx

    def backward(dy):
        return tuple(
            reshape_to(dy, s) if i else None
            for i, s in zip(input_requires_grad, input_shapes)
        )

    return backward, [True]


# override for Subtensor
def subtensor_grad_fn(op, inputs, outputs, input_requires_grad):
    grad_op = SetSubtensor(op.items)

    input_shape = get_shape(inputs[0])
    params = inputs[1:]

    def make_grad(grad_op, dy):
        (_z,) = Const(0, dtype=dy.dtype, device=dy.device)(dy)
        (grad,) = apply(Broadcast(), _z, input_shape)
        (dx,) = apply(grad_op, grad, dy, *params)
        return dx

    def backward(dy):
        return tuple(
            make_grad(grad_op, dy) if mask else None for mask in input_requires_grad
        )

    return backward, [True]


# override for IndexingMultiAxisVec
def indexingMultiAxisVec_grad_fn(op, inputs, outputs, input_requires_grad):
    grad_op = IndexingSetMultiAxisVec(op.items)

    input_shape = get_shape(inputs[0])
    params = inputs[1:]

    def make_grad(grad_op, dy):
        (_z,) = Const(0, dtype=dy.dtype, device=dy.device)(dy)
        (grad,) = apply(Broadcast(), _z, input_shape)
        (dx,) = apply(grad_op, grad, dy, *params)
        return dx

    def backward(dy):
        return tuple(
            make_grad(grad_op, dy) if mask else None for mask in input_requires_grad
        )

    return backward, [True]


# override for Reduce.sum
def reduce_sum_grad_fn(op, inputs, outputs, input_requires_grad):
    assert len(inputs) == len(input_requires_grad) == 1
    input_shape = get_shape(inputs[0])

    def broadcast_to(dy, s):
        (dx,) = apply(Broadcast(), dy, s)
        return dx

    def backward(dy):
        return (broadcast_to(dy, input_shape) if input_requires_grad[0] else None,)

    return backward, [True]

imperative/python/megengine/core/autodiff/grad.py (+0, -1)

@@ -19,7 +19,6 @@ import megengine as mge
from .._imperative_rt import core2, ops
from ..ops.builtin import Elemwise, OpDef, RemoteSend
from ..ops.special import Const
from . import builtin_op_utils

""" Some notes:
1. Initialize the optimizer:


imperative/python/src/grad_override.cpp (+147, -0)

@@ -25,6 +25,25 @@ std::shared_ptr<Tensor> reduce_to(Tensor* x, Tensor* s) {
    return python::apply(op, x, s)[0];
}

std::shared_ptr<Tensor> reshape_to(Tensor* x, Tensor* s) {
    static auto op = Reshape::make();
    return python::apply(op, x, s)[0];
}

std::shared_ptr<Tensor> broadcast_to(Tensor* x, Tensor* s) {
    static auto op = Broadcast::make();
    return python::apply(op, x, s)[0];
}

std::shared_ptr<Tensor> make_tensor(CompNode cn, Tensor* shape, float v = 0) {
    HostTensorND scalar{cn, {{1}, dtype::Float32()}};
    scalar.ptr<float>()[0] = v;
    interpreter::Interpreter::Handle handle = interpreter_for_py->put(scalar);
    auto&& t = std::make_shared<Tensor>(handle);
    auto&& res = broadcast_to(t.get(), shape);
    return res;
}

apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto& op = ctx.op->cast_final_safe<Elemwise>();
    if (op.mode == Elemwise::Mode::ADD) {
@@ -52,10 +71,138 @@ apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& make
    throw GradRuleFallback();
}

apply_result_t reshape_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    mgb_assert(ctx.nargs == 2);
    std::array<std::shared_ptr<Tensor>, 2> input_shapes;
    for (size_t i = 0; i < 2; ++i) {
        if (input_requires_grad(ctx, i)) {
            input_shapes[i] = get_shape(ctx.args[i]);
        }
    }
    maker.output_size(1).output_captured(0, false);
    maker.backward([shapes=std::move(input_shapes)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        Tensor* grad = grads[0];
        apply_result_t ret(2);
        for (size_t i = 0; i < 2; ++i) {
            if (shapes[i]) {
                ret[i] = reshape_to(grad, shapes[i].get());
            }
        }
        return ret;
    });
    return apply(ctx);
}

apply_result_t subtensor_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto&& op = ctx.op->cast_final_safe<Subtensor>();
    auto&& grad_op = SetSubtensor::make(op.items);
    SmallVector<std::shared_ptr<Tensor>> inputs;
    if (input_requires_grad(ctx, 0)) {
        inputs.push_back(get_shape(ctx.args[0]));
        for (size_t i = 1; i < ctx.nargs; ++i) {
            inputs.push_back(ctx.args[i]->copy());
        }
    }
    maker.output_size(1).output_captured(0, false);
    maker.backward([inputs=std::move(inputs), grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        apply_result_t ret(1);
        if (inputs[0]) {
            SmallVector<Tensor*> args_(inputs.size()+1);
            Tensor* grad = grads[0];
            auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
            args_[0] = zeros.get();
            args_[1] = grad;
            for (size_t i = 1; i < inputs.size(); ++i) {
                args_[i+1] = inputs[i].get();
            }
            ret[0] = python::apply(grad_op_, args_)[0];
        }
        return ret;
    });
    return apply(ctx);
}

apply_result_t indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto&& op = ctx.op->cast_final_safe<IndexingMultiAxisVec>();
    auto&& grad_op = IndexingSetMultiAxisVec::make(op.items);
    SmallVector<std::shared_ptr<Tensor>> inputs;
    if (input_requires_grad(ctx, 0)) {
        inputs.push_back(get_shape(ctx.args[0]));
        for (size_t i = 1; i < ctx.nargs; ++i) {
            inputs.push_back(ctx.args[i]->copy());
        }
    }
    maker.output_size(1).output_captured(0, false);
    maker.backward([inputs=std::move(inputs), grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        apply_result_t ret(1);
        if (inputs[0]) {
            SmallVector<Tensor*> args_(inputs.size()+1);
            Tensor* grad = grads[0];
            auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
            args_[0] = zeros.get();
            args_[1] = grad;
            for (size_t i = 1; i < inputs.size(); ++i) {
                args_[i+1] = inputs[i].get();
            }
            ret[0] = python::apply(grad_op_, args_)[0];
        }
        return ret;
    });
    return apply(ctx);
}

apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto& op = ctx.op->cast_final_safe<Reduce>();
    if (op.mode == Reduce::Mode::SUM) {
        mgb_assert(ctx.nargs == 1);
        std::array<std::shared_ptr<Tensor>, 1> input_shapes;
        if (input_requires_grad(ctx, 0)) {
            input_shapes[0] = get_shape(ctx.args[0]);
        }
        maker.output_size(1).output_captured(0, false);
        maker.backward([shapes=std::move(input_shapes)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
            mgb_assert(ngrads == 1);
            Tensor* grad = grads[0];
            apply_result_t ret(1);
            if (shapes[0]) {
                ret[0] = broadcast_to(grad, shapes[0].get());
            }
            return ret;
        });
        return apply(ctx);
    }
    throw GradRuleFallback();
}

template<typename T, typename U>
apply_result_t axisAddRemove_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto&& op = ctx.op->cast_final_safe<T>();
    mgb_assert(ctx.nargs == 1);
    auto&& grad_op = U::make(op.axis);
    maker.output_size(1).output_captured(0, false);
    maker.backward([grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        Tensor* grad = grads[0];
        apply_result_t ret(1);
        ret[0] = python::apply(grad_op_, grad)[0];
        return ret;
    });
    return apply(ctx);
}

struct Init {
    Init() {
        auto& reg = grad_rule_registry();
        reg.emplace(Elemwise::typeinfo(), elemwise_grad_rule);
        reg.emplace(Reshape::typeinfo(), reshape_grad_rule);
        reg.emplace(Subtensor::typeinfo(), subtensor_grad_rule);
        reg.emplace(IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule);
        reg.emplace(Reduce::typeinfo(), reduce_grad_rule);
        reg.emplace(AddAxis::typeinfo(), axisAddRemove_grad_rule<AddAxis, RemoveAxis>);
        reg.emplace(RemoveAxis::typeinfo(), axisAddRemove_grad_rule<RemoveAxis, AddAxis>);
    }
} _;
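
Ops and modes without a specialized rule still differentiate correctly: a rule that throws GradRuleFallback (for example elemwise_grad_rule for any Elemwise mode other than ADD) falls back to the generic backward-graph machinery. A hedged sketch of that behaviour, under the same assumed MegEngine 1.2-style Python API as above:

# Hypothetical sketch: Elemwise MUL has no specialized rule in this commit,
# so its gradient comes from the generic backward-graph fallback.
import numpy as np
import megengine as mge
from megengine.autodiff import GradManager

a = mge.tensor(np.full((3,), 2.0, dtype="float32"))
b = mge.tensor(np.full((3,), 3.0, dtype="float32"))
gm = GradManager().attach([a, b])
with gm:
    gm.backward((a * b).sum())
print(a.grad.numpy(), b.grad.numpy())  # expect [3. 3. 3.] and [2. 2. 2.]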


