@@ -59,6 +59,9 @@ apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& make
             mgb_assert(ngrads == 1);
             Tensor* grad = grads[0];
             apply_result_t ret(2);
+            if (!grad) {
+                return ret;
+            }
             for (size_t i = 0; i < 2; ++i) {
                 if (shapes[i]) {
                     ret[i] = reduce_to(grad, shapes[i].get());
@@ -84,6 +87,9 @@ apply_result_t reshape_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker
         mgb_assert(ngrads == 1);
         Tensor* grad = grads[0];
         apply_result_t ret(2);
+        if (!grad) {
+            return ret;
+        }
         for (size_t i = 0; i < 2; ++i) {
             if (shapes[i]) {
                 ret[i] = reshape_to(grad, shapes[i].get());
@@ -107,10 +113,10 @@ apply_result_t subtensor_grad_rule(ApplyContext& ctx, CustomBackward::Maker& mak
     maker.output_size(1).output_captured(0, false);
     maker.backward([inputs=std::move(inputs), grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
         mgb_assert(ngrads == 1);
+        Tensor* grad = grads[0];
         apply_result_t ret(1);
-        if (inputs[0]) {
+        if (grad && inputs[0]) {
             SmallVector<Tensor*> args_(inputs.size()+1);
-            Tensor* grad = grads[0];
             auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
             args_[0] = zeros.get();
             args_[1] = grad;
@@ -137,10 +143,10 @@ apply_result_t indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward:
     maker.output_size(1).output_captured(0, false);
     maker.backward([inputs=std::move(inputs), grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
         mgb_assert(ngrads == 1);
+        Tensor* grad = grads[0];
         apply_result_t ret(1);
-        if (inputs[0]) {
+        if (grad && inputs[0]) {
             SmallVector<Tensor*> args_(inputs.size()+1);
-            Tensor* grad = grads[0];
             auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
             args_[0] = zeros.get();
             args_[1] = grad;
@@ -167,7 +173,7 @@ apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker)
             mgb_assert(ngrads == 1);
             Tensor* grad = grads[0];
             apply_result_t ret(1);
-            if (shapes[0]) {
+            if (grad && shapes[0]) {
                 ret[0] = broadcast_to(grad, shapes[0].get());
             }
             return ret;
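
Note: each backward lambda above now tolerates a null incoming gradient, returning the preallocated (empty) result instead of dereferencing grads[0]. A minimal standalone sketch of that guard, using hypothetical simplified types rather than the real Tensor/apply_result_t from this file:

    #include <cstdio>
    #include <memory>
    #include <vector>

    // Hypothetical stand-ins for illustration only; not the MegEngine types.
    struct Tensor { float value; };
    using TensorPtr = std::shared_ptr<Tensor>;
    using Grads = std::vector<TensorPtr>;   // plays the role of apply_result_t

    // Backward rule for a two-input op: allocate the result slots up front and
    // return them untouched (all null) when there is no incoming gradient,
    // mirroring the `if (!grad) return ret;` guard added above.
    Grads add_backward(const Tensor* grad) {
        Grads ret(2);
        if (!grad) {
            return ret;           // nothing to propagate
        }
        for (size_t i = 0; i < 2; ++i) {
            ret[i] = std::make_shared<Tensor>(Tensor{grad->value});
        }
        return ret;
    }

    int main() {
        Tensor g{1.f};
        std::printf("%d\n", add_backward(&g)[0] != nullptr);      // 1
        std::printf("%d\n", add_backward(nullptr)[0] != nullptr); // 0
    }
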
@@ -180,14 +186,17 @@ apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker)
 apply_result_t addAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto&& op = ctx.op->cast_final_safe<AddAxis>();
     mgb_assert(ctx.nargs == 1);
+    bool flag = input_requires_grad(ctx, 0);
     auto&& grad_op = RemoveAxis::make(op.axis);
     std::sort(grad_op->axis.begin(), grad_op->axis.end(), std::greater<int32_t>());
     maker.output_size(1).output_captured(0, false);
-    maker.backward([grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
+    maker.backward([grad_op_=std::move(grad_op), flag_=flag](BackwardContext&, Tensor*const* grads, size_t ngrads) {
         mgb_assert(ngrads == 1);
         Tensor* grad = grads[0];
         apply_result_t ret(1);
-        ret[0] = python::apply(grad_op_, grad)[0];
+        if (grad && flag_) {
+            ret[0] = python::apply(grad_op_, grad)[0];
+        }
         return ret;
     });
     return apply(ctx);
@@ -196,14 +205,17 @@ apply_result_t addAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker
 apply_result_t removeAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto&& op = ctx.op->cast_final_safe<RemoveAxis>();
     mgb_assert(ctx.nargs == 1);
+    bool flag = input_requires_grad(ctx, 0);
     auto&& grad_op = AddAxis::make(op.axis);
     std::sort(grad_op->axis.begin(), grad_op->axis.end());
     maker.output_size(1).output_captured(0, false);
-    maker.backward([grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
+    maker.backward([grad_op_=std::move(grad_op), flag_=flag](BackwardContext&, Tensor*const* grads, size_t ngrads) {
         mgb_assert(ngrads == 1);
         Tensor* grad = grads[0];
         apply_result_t ret(1);
-        ret[0] = python::apply(grad_op_, grad)[0];
+        if (grad && flag_) {
+            ret[0] = python::apply(grad_op_, grad)[0];
+        }
         return ret;
     });
     return apply(ctx);
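
For the last two hunks, the rule additionally captures, when it is set up, whether the single input requires a gradient (flag), and the backward only runs the inverse axis op when both that flag is set and a gradient actually arrived. A standalone sketch of that capture-and-guard pattern, again with hypothetical simplified types (the real code applies RemoveAxis/AddAxis through python::apply):

    #include <functional>
    #include <iostream>
    #include <memory>
    #include <vector>

    // Hypothetical stand-ins for illustration only.
    struct Tensor { std::vector<size_t> shape; };
    using TensorPtr = std::shared_ptr<Tensor>;
    using Backward = std::function<TensorPtr(const Tensor*)>;

    // Build the backward closure once, capturing the requires-grad flag by
    // value (like `flag_=flag` above); the inverse of AddAxis is modelled here
    // by simply dropping the inserted axis from the shape.
    Backward make_add_axis_backward(bool input_requires_grad, size_t axis) {
        return [flag_ = input_requires_grad, axis](const Tensor* grad) -> TensorPtr {
            TensorPtr ret;                    // stays null unless work is needed
            if (grad && flag_) {
                auto out = std::make_shared<Tensor>(*grad);
                out->shape.erase(out->shape.begin() + axis);
                ret = out;
            }
            return ret;
        };
    }

    int main() {
        auto bwd = make_add_axis_backward(/*input_requires_grad=*/true, /*axis=*/0);
        Tensor g{{1, 3, 4}};
        std::cout << bwd(&g)->shape.size() << "\n";   // 2
        std::cout << (bwd(nullptr) ? 1 : 0) << "\n";  // 0: null grad short-circuits
    }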