From f12355f7278d5ed373e0203b7577e62ec8ca69f3 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Tue, 18 May 2021 10:46:48 +0800
Subject: [PATCH] fix(imperative/grad): fix hardcode dtype in subtensor_grad_rule

GitOrigin-RevId: 50da4af26dd4f0f0efe38f07573d704ea2fbe841
---
 imperative/python/src/grad_override.cpp | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/imperative/python/src/grad_override.cpp b/imperative/python/src/grad_override.cpp
index b49b58cb..2f495b33 100644
--- a/imperative/python/src/grad_override.cpp
+++ b/imperative/python/src/grad_override.cpp
@@ -35,9 +35,9 @@ std::shared_ptr<Tensor> broadcast_to(Tensor* x, Tensor* s) {
     return python::apply(op, x, s)[0];
 }
 
-std::shared_ptr<Tensor> make_tensor(CompNode cn, Tensor* shape, float v = 0) {
-    HostTensorND scalar{cn, {{1}, dtype::Float32()}};
-    scalar.ptr<float>()[0] = v;
+std::shared_ptr<Tensor> make_empty_tensor(CompNode cn, Tensor* shape, DType dtype) {
+    HostTensorND scalar{cn, {{1}, dtype}};
+    std::memset(scalar.raw_ptr(), 0, dtype.size());
     interpreter::Interpreter::Handle handle = interpreter_for_py->put(scalar, false);
     auto&& t = std::make_shared<Tensor>(handle);
     auto res = broadcast_to(t.get(), shape);
@@ -117,7 +117,7 @@ apply_result_t subtensor_grad_rule(ApplyContext& ctx, CustomBackward::Maker& mak
     apply_result_t ret(1);
     if (grad && inputs[0]) {
         SmallVector<Tensor*> args_(inputs.size()+1);
-        auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
+        auto&& zeros = make_empty_tensor(grad->comp_node(), inputs[0].get(), grad->dtype());
         args_[0] = zeros.get();
         args_[1] = grad;
         for (size_t i = 1; i < inputs.size(); ++i) {
@@ -147,7 +147,7 @@ apply_result_t indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward:
     apply_result_t ret(1);
     if (grad && inputs[0]) {
         SmallVector<Tensor*> args_(inputs.size()+1);
-        auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
+        auto&& zeros = make_empty_tensor(grad->comp_node(), inputs[0].get(), grad->dtype());
        args_[0] = zeros.get();
         args_[1] = grad;
         for (size_t i = 1; i < inputs.size(); ++i) {