|
@@ -177,11 +177,27 @@ apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) |
|
|
throw GradRuleFallback(); |
|
|
throw GradRuleFallback(); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
template<typename T, typename U> |
|
|
|
|
|
apply_result_t axisAddRemove_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) { |
|
|
|
|
|
auto&& op = ctx.op->cast_final_safe<T>(); |
|
|
|
|
|
|
|
|
apply_result_t addAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) { |
|
|
|
|
|
auto&& op = ctx.op->cast_final_safe<AddAxis>(); |
|
|
mgb_assert(ctx.nargs == 1); |
|
|
mgb_assert(ctx.nargs == 1); |
|
|
auto&& grad_op = U::make(op.axis); |
|
|
|
|
|
|
|
|
auto&& grad_op = RemoveAxis::make(op.axis); |
|
|
|
|
|
std::sort(grad_op->axis.begin(), grad_op->axis.end(), std::greater<int32_t>()); |
|
|
|
|
|
maker.output_size(1).output_captured(0, false); |
|
|
|
|
|
maker.backward([grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) { |
|
|
|
|
|
mgb_assert(ngrads == 1); |
|
|
|
|
|
Tensor* grad = grads[0]; |
|
|
|
|
|
apply_result_t ret(1); |
|
|
|
|
|
ret[0] = python::apply(grad_op_, grad)[0]; |
|
|
|
|
|
return ret; |
|
|
|
|
|
}); |
|
|
|
|
|
return apply(ctx); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
apply_result_t removeAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) { |
|
|
|
|
|
auto&& op = ctx.op->cast_final_safe<RemoveAxis>(); |
|
|
|
|
|
mgb_assert(ctx.nargs == 1); |
|
|
|
|
|
auto&& grad_op = AddAxis::make(op.axis); |
|
|
|
|
|
std::sort(grad_op->axis.begin(), grad_op->axis.end()); |
|
|
maker.output_size(1).output_captured(0, false); |
|
|
maker.output_size(1).output_captured(0, false); |
|
|
maker.backward([grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) { |
|
|
maker.backward([grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) { |
|
|
mgb_assert(ngrads == 1); |
|
|
mgb_assert(ngrads == 1); |
|
@@ -201,8 +217,8 @@ struct Init { |
|
|
reg.emplace(Subtensor::typeinfo(), subtensor_grad_rule); |
|
|
reg.emplace(Subtensor::typeinfo(), subtensor_grad_rule); |
|
|
reg.emplace(IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule); |
|
|
reg.emplace(IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule); |
|
|
reg.emplace(Reduce::typeinfo(), reduce_grad_rule); |
|
|
reg.emplace(Reduce::typeinfo(), reduce_grad_rule); |
|
|
reg.emplace(AddAxis::typeinfo(), axisAddRemove_grad_rule<AddAxis, RemoveAxis>); |
|
|
|
|
|
reg.emplace(RemoveAxis::typeinfo(), axisAddRemove_grad_rule<RemoveAxis, AddAxis>); |
|
|
|
|
|
|
|
|
reg.emplace(AddAxis::typeinfo(), addAxis_grad_rule); |
|
|
|
|
|
reg.emplace(RemoveAxis::typeinfo(), removeAxis_grad_rule); |
|
|
} |
|
|
} |
|
|
} _; |
|
|
} _; |
|
|
|
|
|
|
|
|