|
|
@@ -25,6 +25,25 @@ std::shared_ptr<Tensor> reduce_to(Tensor* x, Tensor* s) { |
|
|
|
return python::apply(op, x, s)[0]; |
|
|
|
} |
|
|
|
|
|
|
|
//! Reshape tensor `x` to the target shape carried by tensor `s`.
std::shared_ptr<Tensor> reshape_to(Tensor* x, Tensor* s) {
    // Reshape is stateless, so one shared op instance serves all calls.
    static auto op = Reshape::make();
    auto&& results = python::apply(op, x, s);
    return results[0];
}
|
|
|
|
|
|
|
//! Broadcast tensor `x` to the target shape carried by tensor `s`.
std::shared_ptr<Tensor> broadcast_to(Tensor* x, Tensor* s) {
    // Broadcast is stateless, so one shared op instance serves all calls.
    static auto op = Broadcast::make();
    auto&& results = python::apply(op, x, s);
    return results[0];
}
|
|
|
|
|
|
|
//! Build a float32 tensor on comp node `cn`, filled with `v` (default 0),
//! broadcast to the shape carried by tensor `shape`.
std::shared_ptr<Tensor> make_tensor(CompNode cn, Tensor* shape, float v = 0) {
    // One-element host tensor holding the fill value.
    HostTensorND scalar{cn, {{1}, dtype::Float32()}};
    scalar.ptr<float>()[0] = v;
    // Hand the host value to the interpreter and wrap the handle.
    auto t = std::make_shared<Tensor>(interpreter_for_py->put(scalar));
    // Expand the scalar to the requested shape.
    return broadcast_to(t.get(), shape);
}
|
|
|
|
|
|
|
apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) { |
|
|
|
auto& op = ctx.op->cast_final_safe<Elemwise>(); |
|
|
|
if (op.mode == Elemwise::Mode::ADD) { |
|
|
@@ -52,10 +71,138 @@ apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& make |
|
|
|
throw GradRuleFallback(); |
|
|
|
} |
|
|
|
|
|
|
|
//! Gradient rule for Reshape: the grad w.r.t. each input that requires it is
//! the incoming grad reshaped back to that input's original shape.
apply_result_t reshape_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    mgb_assert(ctx.nargs == 2);
    // Record the original shape of each grad-requiring input; entries for
    // inputs that need no grad stay null.
    std::array<std::shared_ptr<Tensor>, 2> saved_shapes;
    for (size_t idx = 0; idx < 2; ++idx) {
        saved_shapes[idx] = input_requires_grad(ctx, idx)
                ? get_shape(ctx.args[idx])
                : nullptr;
    }
    maker.output_size(1).output_captured(0, false);
    maker.backward([shapes = std::move(saved_shapes)](
                           BackwardContext&, Tensor* const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        apply_result_t ret(2);
        for (size_t idx = 0; idx < 2; ++idx) {
            if (shapes[idx]) {
                // Restore this input's shape on the incoming grad.
                ret[idx] = reshape_to(grads[0], shapes[idx].get());
            }
        }
        return ret;
    });
    return apply(ctx);
}
|
|
|
|
|
|
|
//! Gradient rule for Subtensor: the grad of input 0 is a zero tensor of the
//! forward input's shape with the incoming grad scattered back into the
//! sliced region via SetSubtensor. Index arguments receive no gradient.
apply_result_t subtensor_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto&& op = ctx.op->cast_final_safe<Subtensor>();
    auto&& grad_op = SetSubtensor::make(op.items);
    SmallVector<std::shared_ptr<Tensor>> inputs;
    if (input_requires_grad(ctx, 0)) {
        // Save the source shape plus copies of the index tensors so the
        // backward closure owns everything it needs.
        inputs.push_back(get_shape(ctx.args[0]));
        for (size_t i = 1; i < ctx.nargs; ++i) {
            inputs.push_back(ctx.args[i]->copy());
        }
    }
    maker.output_size(1).output_captured(0, false);
    maker.backward([inputs = std::move(inputs), grad_op_ = std::move(grad_op)](
                           BackwardContext&, Tensor* const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        apply_result_t ret(1);
        // `inputs` is empty when input 0 required no grad; check emptiness
        // first — indexing an empty SmallVector is out-of-bounds.
        if (!inputs.empty() && inputs[0]) {
            SmallVector<Tensor*> args_(inputs.size() + 1);
            Tensor* grad = grads[0];
            // Start from zeros shaped like the forward input ...
            auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
            args_[0] = zeros.get();
            // ... and write the incoming grad into the selected region.
            args_[1] = grad;
            for (size_t i = 1; i < inputs.size(); ++i) {
                args_[i + 1] = inputs[i].get();
            }
            ret[0] = python::apply(grad_op_, args_)[0];
        }
        return ret;
    });
    return apply(ctx);
}
|
|
|
|
|
|
|
//! Gradient rule for IndexingMultiAxisVec: the grad of input 0 is a zero
//! tensor of the forward input's shape with the incoming grad scattered back
//! via IndexingSetMultiAxisVec. Index arguments receive no gradient.
apply_result_t indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto&& op = ctx.op->cast_final_safe<IndexingMultiAxisVec>();
    auto&& grad_op = IndexingSetMultiAxisVec::make(op.items);
    SmallVector<std::shared_ptr<Tensor>> inputs;
    if (input_requires_grad(ctx, 0)) {
        // Save the source shape plus copies of the index tensors so the
        // backward closure owns everything it needs.
        inputs.push_back(get_shape(ctx.args[0]));
        for (size_t i = 1; i < ctx.nargs; ++i) {
            inputs.push_back(ctx.args[i]->copy());
        }
    }
    maker.output_size(1).output_captured(0, false);
    maker.backward([inputs = std::move(inputs), grad_op_ = std::move(grad_op)](
                           BackwardContext&, Tensor* const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        apply_result_t ret(1);
        // `inputs` is empty when input 0 required no grad; check emptiness
        // first — indexing an empty SmallVector is out-of-bounds.
        if (!inputs.empty() && inputs[0]) {
            SmallVector<Tensor*> args_(inputs.size() + 1);
            Tensor* grad = grads[0];
            // Start from zeros shaped like the forward input ...
            auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
            args_[0] = zeros.get();
            // ... and write the incoming grad at the indexed positions.
            args_[1] = grad;
            for (size_t i = 1; i < inputs.size(); ++i) {
                args_[i + 1] = inputs[i].get();
            }
            ret[0] = python::apply(grad_op_, args_)[0];
        }
        return ret;
    });
    return apply(ctx);
}
|
|
|
|
|
|
|
//! Gradient rule for Reduce. Only SUM is handled here: its grad is the
//! incoming grad broadcast back to the input's shape. Every other reduce
//! mode falls back to the generic path.
apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto& op = ctx.op->cast_final_safe<Reduce>();
    if (op.mode != Reduce::Mode::SUM) {
        // Not a mode this rule knows how to differentiate.
        throw GradRuleFallback();
    }
    mgb_assert(ctx.nargs == 1);
    // Remember the input shape only if its grad is actually needed.
    std::array<std::shared_ptr<Tensor>, 1> saved_shape;
    if (input_requires_grad(ctx, 0)) {
        saved_shape[0] = get_shape(ctx.args[0]);
    }
    maker.output_size(1).output_captured(0, false);
    maker.backward([shapes = std::move(saved_shape)](
                           BackwardContext&, Tensor* const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        apply_result_t ret(1);
        if (shapes[0]) {
            // Undo the reduction by broadcasting the grad to the input shape.
            ret[0] = broadcast_to(grads[0], shapes[0].get());
        }
        return ret;
    });
    return apply(ctx);
}
|
|
|
|
|
|
|
template<typename T, typename U> |
|
|
|
apply_result_t axisAddRemove_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) { |
|
|
|
auto&& op = ctx.op->cast_final_safe<T>(); |
|
|
|
mgb_assert(ctx.nargs == 1); |
|
|
|
auto&& grad_op = U::make(op.axis); |
|
|
|
maker.output_size(1).output_captured(0, false); |
|
|
|
maker.backward([grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) { |
|
|
|
mgb_assert(ngrads == 1); |
|
|
|
Tensor* grad = grads[0]; |
|
|
|
apply_result_t ret(1); |
|
|
|
ret[0] = python::apply(grad_op_, grad)[0]; |
|
|
|
return ret; |
|
|
|
}); |
|
|
|
return apply(ctx); |
|
|
|
} |
|
|
|
|
|
|
|
struct Init { |
|
|
|
Init() { |
|
|
|
auto& reg = grad_rule_registry(); |
|
|
|
reg.emplace(Elemwise::typeinfo(), elemwise_grad_rule); |
|
|
|
reg.emplace(Reshape::typeinfo(), reshape_grad_rule); |
|
|
|
reg.emplace(Subtensor::typeinfo(), subtensor_grad_rule); |
|
|
|
reg.emplace(IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule); |
|
|
|
reg.emplace(Reduce::typeinfo(), reduce_grad_rule); |
|
|
|
reg.emplace(AddAxis::typeinfo(), axisAddRemove_grad_rule<AddAxis, RemoveAxis>); |
|
|
|
reg.emplace(RemoveAxis::typeinfo(), axisAddRemove_grad_rule<RemoveAxis, AddAxis>); |
|
|
|
} |
|
|
|
} _; |
|
|
|
|
|
|
|