
fix(grad): stop using exception in grad_override

Megvii Engine Team · 3 years ago · release-1.7
commit 000517c641
GitOrigin-RevId: 00ae38d48b
3 changed files with 17 additions and 18 deletions:
  1. imperative/python/src/grad.cpp           +3  -5
  2. imperative/python/src/grad.h             +2  -1
  3. imperative/python/src/grad_override.cpp  +12 -12

imperative/python/src/grad.cpp  (+3 -5)

@@ -393,13 +393,11 @@ apply_result_t apply_grad(ApplyContext& ctx) {
         auto&& it = registry.find(ctx.op->dyn_typeinfo());
         if (it != registry.end()) {
             auto&& maker = grad_fn_holder.emplace<CustomBackward>().maker(ctx);
-            try {
-                auto ret = it->second(ctx, maker);
+            if (auto ret = it->second(ctx, maker)) {
                 maker.finalize();
-                return ret;
-            } catch (GradRuleFallback&) {
-                grad_fn_holder.reset();
+                return *ret;
             }
+            grad_fn_holder.reset();
         }
         return backward_graph_grad_rule(ctx, grad_fn_holder);
     }();
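For readers skimming the hunk above: the dispatcher no longer relies on catching GradRuleFallback. A rule that declines simply yields an empty std::optional, and the caller tests it before dereferencing, so the common fallback path costs no exception unwinding. Below is a minimal, self-contained sketch of that call-site idiom with toy stand-ins (Result, grad_rule, fallback_rule are illustrative names, not MegEngine's API):

// Toy sketch of "optional result or fall back", assuming nothing beyond the
// standard library. Not the engine's code; just the control-flow shape.
#include <iostream>
#include <optional>
#include <string>

using Result = std::string;

// A custom rule either produces a result or declines by returning an empty optional.
std::optional<Result> grad_rule(bool can_handle) {
    if (can_handle) {
        return Result{"custom rule"};
    }
    return {};  // std::nullopt: defer to the generic fallback
}

Result fallback_rule() { return "generic backward graph"; }

Result dispatch(bool can_handle) {
    if (auto ret = grad_rule(can_handle)) {  // engaged only if the rule handled it
        return *ret;
    }
    return fallback_rule();
}

int main() {
    std::cout << dispatch(true) << '\n';   // prints: custom rule
    std::cout << dispatch(false) << '\n';  // prints: generic backward graph
}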


imperative/python/src/grad.h  (+2 -1)

@@ -16,6 +16,7 @@
 
 #include <megbrain/utils/small_vector.h>
 #include <memory>
+#include <optional>
 
 namespace mgb::imperative::python {

@@ -154,7 +155,7 @@ public:
     Maker maker(ApplyContext& ctx) {return {*this, ctx};}
 };
 
-using GradRuleFn = std::function<apply_result_t(ApplyContext&, CustomBackward::Maker&)>;
+using GradRuleFn = std::function<std::optional<apply_result_t>(ApplyContext&, CustomBackward::Maker&)>;
 
 std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry();
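The typedef change above is what lets every registered rule signal "no custom gradient" without throwing. The snippet below is a small stand-alone illustration of the std::optional semantics the new GradRuleFn relies on; apply_result and RuleFn here are simplified stand-ins (a vector and an int parameter instead of the real apply_result_t and ApplyContext). The key detail: in a function returning std::optional<T>, return {}; produces an empty optional (nullopt), not a default-constructed T.

// Stand-alone illustration; types are placeholders, not the engine's.
#include <cassert>
#include <functional>
#include <optional>
#include <vector>

using apply_result = std::vector<int>;                         // stand-in for apply_result_t
using RuleFn = std::function<std::optional<apply_result>(int)>;

int main() {
    RuleFn rule = [](int nargs) -> std::optional<apply_result> {
        if (nargs != 2) {
            return {};               // empty optional (nullopt), NOT an empty vector
        }
        return apply_result{1, 2};   // engaged optional holding a result
    };

    assert(!rule(1));                            // rule declined; caller would fall back
    assert(rule(2) && rule(2)->size() == 2);     // rule handled the case
}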



imperative/python/src/grad_override.cpp  (+12 -12)

@@ -37,14 +37,14 @@ std::shared_ptr<Tensor> broadcast_to(Tensor* x, Tensor* s) {
 
 std::shared_ptr<Tensor> make_empty_tensor(CompNode cn, Tensor* shape, DType dtype) {
     HostTensorND scalar{cn, {{1}, dtype}};
-    std:memset(scalar.raw_ptr(), 0, dtype.size());
+    std::memset(scalar.raw_ptr(), 0, dtype.size());
     interpreter::Interpreter::Handle handle = interpreter_for_py->put(scalar, false);
     auto&& t = std::make_shared<Tensor>(handle);
     auto res = broadcast_to(t.get(), shape);
     return res;
 }
 
-apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+std::optional<apply_result_t> elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto& op = ctx.op->cast_final_safe<Elemwise>();
     if (op.mode == Elemwise::Mode::ADD) {
         mgb_assert(ctx.nargs == 2);
@@ -71,10 +71,10 @@ apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
         });
         return apply(ctx);
     }
-    throw GradRuleFallback();
+    return {};
 }
 
-apply_result_t reshape_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+std::optional<apply_result_t> reshape_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     mgb_assert(ctx.nargs == 2);
     std::array<std::shared_ptr<Tensor>, 2> input_shapes;
     for (size_t i = 0; i < 2; ++i) {
@@ -100,7 +100,7 @@ apply_result_t reshape_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     return apply(ctx);
 }
 
-apply_result_t subtensor_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+std::optional<apply_result_t> subtensor_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto&& op = ctx.op->cast_final_safe<Subtensor>();
     auto&& grad_op = SetSubtensor::make(op.items);
     SmallVector<std::shared_ptr<Tensor>> inputs;
@@ -130,7 +130,7 @@ apply_result_t subtensor_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     return apply(ctx);
 }
 
-apply_result_t indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+std::optional<apply_result_t> indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto&& op = ctx.op->cast_final_safe<IndexingMultiAxisVec>();
     auto&& grad_op = IndexingSetMultiAxisVec::make(op.items);
     SmallVector<std::shared_ptr<Tensor>> inputs;
@@ -160,11 +160,11 @@ apply_result_t indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     return apply(ctx);
 }
 
-apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+std::optional<apply_result_t> reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto& op = ctx.op->cast_final_safe<Reduce>();
     if (op.mode == Reduce::Mode::SUM) {
         if (ctx.nargs != 1) {
-            throw GradRuleFallback();
+            return {};
         }
         std::array<std::shared_ptr<Tensor>, 1> input_shapes;
         if (input_requires_grad(ctx, 0)) {
@@ -182,10 +182,10 @@ apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
         });
         return apply(ctx);
     }
-    throw GradRuleFallback();
+    return {};
 }
 
-apply_result_t addAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+std::optional<apply_result_t> addAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto&& op = ctx.op->cast_final_safe<AddAxis>();
     mgb_assert(ctx.nargs == 1);
     bool flag = input_requires_grad(ctx, 0);
@@ -204,7 +204,7 @@ apply_result_t addAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     return apply(ctx);
 }
 
-apply_result_t removeAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+std::optional<apply_result_t> removeAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto&& op = ctx.op->cast_final_safe<RemoveAxis>();
     mgb_assert(ctx.nargs == 1);
     bool flag = input_requires_grad(ctx, 0);
@@ -223,7 +223,7 @@ apply_result_t removeAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     return apply(ctx);
 }
 
-apply_result_t fastpathcopy_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+std::optional<apply_result_t> fastpathcopy_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     mgb_assert(ctx.nargs == 1);
     maker.output_size(1).output_captured(0, false);
     maker.backward([](BackwardContext&, Tensor*const* grads, size_t ngrads) {
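All of the rules above now follow the same convention: handle the cases they know and return an empty optional otherwise. The sketch below mirrors that rule-author side with toy types; Ctx, Result, registry() and the type_index keying are simplified stand-ins for ApplyContext, apply_result_t and grad_rule_registry() (which in the real code is keyed by Typeinfo* and populated elsewhere in this file), so treat it as an illustration of the control flow rather than the engine's registration mechanism.

// Hedged sketch: a rule that either handles a mode or declines via an empty optional.
#include <functional>
#include <iostream>
#include <optional>
#include <string>
#include <typeindex>
#include <unordered_map>

struct Ctx { std::string mode; };                  // stand-in for ApplyContext
using Result = std::string;                        // stand-in for apply_result_t
using RuleFn = std::function<std::optional<Result>(const Ctx&)>;

// Stand-in for grad_rule_registry(): op type -> rule.
std::unordered_map<std::type_index, RuleFn>& registry() {
    static std::unordered_map<std::type_index, RuleFn> reg;
    return reg;
}

struct Elemwise {};                                // toy op tag

std::optional<Result> elemwise_rule(const Ctx& ctx) {
    if (ctx.mode == "ADD") {
        return Result{"shortcut backward for ADD"};
    }
    return {};                                     // unsupported mode: decline
}

int main() {
    registry().emplace(typeid(Elemwise), elemwise_rule);

    const auto& rule = registry().at(typeid(Elemwise));
    for (const char* mode : {"ADD", "POW"}) {
        if (auto ret = rule(Ctx{mode})) {
            std::cout << mode << ": " << *ret << '\n';
        } else {
            std::cout << mode << ": fall back to backward graph\n";
        }
    }
}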

