fix(mge/autodiff): fix segfault when grad is nullptr

GitOrigin-RevId: 6139212bfd
Branch: release-1.2
Author: Megvii Engine Team, 4 years ago
Commit: cf3f58cb9b
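
In these backward rules the incoming gradient pointer grads[0] can be null (presumably when the corresponding output gradient is not required), and the old code dereferenced it unconditionally, which is the segfault this commit fixes. Below is a minimal standalone sketch of the failure mode and of the early-return guard added to the elemwise/reshape rules; Tensor and apply_result_t here are simplified stand-ins, not MegEngine's real types.

#include <cassert>
#include <cstdio>
#include <memory>
#include <vector>

struct Tensor { int comp_node; };
using apply_result_t = std::vector<std::shared_ptr<Tensor>>;

// Backward callback in the style of elemwise_grad_rule; grads[0] may be null.
apply_result_t elemwise_backward(Tensor* const* grads, size_t ngrads) {
    assert(ngrads == 1);
    Tensor* grad = grads[0];
    apply_result_t ret(2);          // one (possibly empty) slot per input
    if (!grad) {                    // the added guard: without it, the
        return ret;                 // grad->comp_node read below would crash
    }
    ret[0] = std::make_shared<Tensor>(Tensor{grad->comp_node});
    ret[1] = std::make_shared<Tensor>(Tensor{grad->comp_node});
    return ret;
}

int main() {
    Tensor* grads[1] = {nullptr};   // no incoming gradient for this output
    std::printf("slots returned: %zu\n", elemwise_backward(grads, 1).size());
    return 0;
}
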
1 changed file with 21 additions and 9 deletions

imperative/python/src/grad_override.cpp (+21, -9)

@@ -59,6 +59,9 @@ apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& make
         mgb_assert(ngrads == 1);
         Tensor* grad = grads[0];
         apply_result_t ret(2);
+        if (!grad) {
+            return ret;
+        }
         for (size_t i = 0; i < 2; ++i) {
             if (shapes[i]) {
                 ret[i] = reduce_to(grad, shapes[i].get());
@@ -84,6 +87,9 @@ apply_result_t reshape_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker
         mgb_assert(ngrads == 1);
         Tensor* grad = grads[0];
         apply_result_t ret(2);
+        if (!grad) {
+            return ret;
+        }
         for (size_t i = 0; i < 2; ++i) {
             if (shapes[i]) {
                 ret[i] = reshape_to(grad, shapes[i].get());
@@ -107,10 +113,10 @@ apply_result_t subtensor_grad_rule(ApplyContext& ctx, CustomBackward::Maker& mak
     maker.output_size(1).output_captured(0, false);
     maker.backward([inputs=std::move(inputs), grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
         mgb_assert(ngrads == 1);
+        Tensor* grad = grads[0];
         apply_result_t ret(1);
-        if (inputs[0]) {
+        if (grad && inputs[0]) {
             SmallVector<Tensor*> args_(inputs.size()+1);
-            Tensor* grad = grads[0];
             auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
             args_[0] = zeros.get();
             args_[1] = grad;
@@ -137,10 +143,10 @@ apply_result_t indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward:
     maker.output_size(1).output_captured(0, false);
     maker.backward([inputs=std::move(inputs), grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
         mgb_assert(ngrads == 1);
+        Tensor* grad = grads[0];
         apply_result_t ret(1);
-        if (inputs[0]) {
+        if (grad && inputs[0]) {
             SmallVector<Tensor*> args_(inputs.size()+1);
-            Tensor* grad = grads[0];
             auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
             args_[0] = zeros.get();
             args_[1] = grad;
@@ -167,7 +173,7 @@ apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker)
         mgb_assert(ngrads == 1);
         Tensor* grad = grads[0];
         apply_result_t ret(1);
-        if (shapes[0]) {
+        if (grad && shapes[0]) {
             ret[0] = broadcast_to(grad, shapes[0].get());
         }
         return ret;
@@ -180,14 +186,17 @@ apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker)
 apply_result_t addAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto&& op = ctx.op->cast_final_safe<AddAxis>();
     mgb_assert(ctx.nargs == 1);
+    bool flag = input_requires_grad(ctx, 0);
     auto&& grad_op = RemoveAxis::make(op.axis);
     std::sort(grad_op->axis.begin(), grad_op->axis.end(), std::greater<int32_t>());
     maker.output_size(1).output_captured(0, false);
-    maker.backward([grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
+    maker.backward([grad_op_=std::move(grad_op), flag_=flag](BackwardContext&, Tensor*const* grads, size_t ngrads) {
         mgb_assert(ngrads == 1);
         Tensor* grad = grads[0];
         apply_result_t ret(1);
-        ret[0] = python::apply(grad_op_, grad)[0];
+        if (grad && flag_) {
+            ret[0] = python::apply(grad_op_, grad)[0];
+        }
         return ret;
     });
     return apply(ctx);
@@ -196,14 +205,17 @@ apply_result_t addAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker
 apply_result_t removeAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto&& op = ctx.op->cast_final_safe<RemoveAxis>();
     mgb_assert(ctx.nargs == 1);
+    bool flag = input_requires_grad(ctx, 0);
     auto&& grad_op = AddAxis::make(op.axis);
     std::sort(grad_op->axis.begin(), grad_op->axis.end());
     maker.output_size(1).output_captured(0, false);
-    maker.backward([grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
+    maker.backward([grad_op_=std::move(grad_op), flag_=flag](BackwardContext&, Tensor*const* grads, size_t ngrads) {
         mgb_assert(ngrads == 1);
         Tensor* grad = grads[0];
         apply_result_t ret(1);
-        ret[0] = python::apply(grad_op_, grad)[0];
+        if (grad && flag_) {
+            ret[0] = python::apply(grad_op_, grad)[0];
+        }
         return ret;
     });
     return apply(ctx);
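
For AddAxis/RemoveAxis the patch combines the null check with a second guard: whether the input requires a gradient is queried once via input_requires_grad when the rule is built, captured into the backward lambda as flag_, and the gradient op is applied only when a non-null grad arrives and that flag is set. A compilable sketch of this pattern follows; the types and apply_grad_op are stand-ins for MegEngine's Tensor, apply_result_t and python::apply, which are not reproduced here.

#include <cassert>
#include <cstdio>
#include <functional>
#include <memory>
#include <vector>

struct Tensor { int value; };
using apply_result_t = std::vector<std::shared_ptr<Tensor>>;
using backward_fn = std::function<apply_result_t(Tensor* const*, size_t)>;

// Stand-in for python::apply(grad_op_, grad): just forwards the gradient.
static apply_result_t apply_grad_op(Tensor* grad) {
    return {std::make_shared<Tensor>(*grad)};
}

backward_fn make_backward(bool input_requires_grad) {
    // Flag captured by value, mirroring [grad_op_=std::move(grad_op), flag_=flag].
    return [flag_ = input_requires_grad](Tensor* const* grads, size_t ngrads) {
        assert(ngrads == 1);
        Tensor* grad = grads[0];
        apply_result_t ret(1);
        if (grad && flag_) {                 // both guards from the patch
            ret[0] = apply_grad_op(grad)[0];
        }
        return ret;                          // ret[0] stays null otherwise
    };
}

int main() {
    Tensor g{42};
    Tensor* with_grad[1] = {&g};
    Tensor* without_grad[1] = {nullptr};
    auto bwd = make_backward(/*input_requires_grad=*/true);
    std::printf("grad present: %s\n", bwd(with_grad, 1)[0] ? "propagated" : "skipped");
    std::printf("grad absent:  %s\n", bwd(without_grad, 1)[0] ? "propagated" : "skipped");
    return 0;
}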

