Browse Source

fix(imperative/grad): fix hardcode dtype in subtensor_grad_rule

GitOrigin-RevId: 50da4af26d
release-1.5
Megvii Engine Team huangxinda 4 years ago
parent
commit
f12355f727
1 changed files with 5 additions and 5 deletions
  1. +5
    -5
      imperative/python/src/grad_override.cpp

+ 5
- 5
imperative/python/src/grad_override.cpp View File

@@ -35,9 +35,9 @@ std::shared_ptr<Tensor> broadcast_to(Tensor* x, Tensor* s) {
return python::apply(op, x, s)[0];
}

std::shared_ptr<Tensor> make_tensor(CompNode cn, Tensor* shape, float v = 0) {
HostTensorND scalar{cn, {{1}, dtype::Float32()}};
scalar.ptr<float>()[0] = v;
std::shared_ptr<Tensor> make_empty_tensor(CompNode cn, Tensor* shape, DType dtype) {
HostTensorND scalar{cn, {{1}, dtype}};
std::memset(scalar.raw_ptr(), 0, dtype.size());
interpreter::Interpreter::Handle handle = interpreter_for_py->put(scalar, false);
auto&& t = std::make_shared<Tensor>(handle);
auto res = broadcast_to(t.get(), shape);
@@ -117,7 +117,7 @@ apply_result_t subtensor_grad_rule(ApplyContext& ctx, CustomBackward::Maker& mak
apply_result_t ret(1);
if (grad && inputs[0]) {
SmallVector<Tensor*> args_(inputs.size()+1);
auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
auto&& zeros = make_empty_tensor(grad->comp_node(), inputs[0].get(), grad->dtype());
args_[0] = zeros.get();
args_[1] = grad;
for (size_t i = 1; i < inputs.size(); ++i) {
@@ -147,7 +147,7 @@ apply_result_t indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward:
apply_result_t ret(1);
if (grad && inputs[0]) {
SmallVector<Tensor*> args_(inputs.size()+1);
auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
auto&& zeros = make_empty_tensor(grad->comp_node(), inputs[0].get(), grad->dtype());
args_[0] = zeros.get();
args_[1] = grad;
for (size_t i = 1; i < inputs.size(); ++i) {


Loading…
Cancel
Save