
fix(trace): treat constants in backward graph as ImmutableTensor correctly

GitOrigin-RevId: 5fd6a5e00c
release-1.5
Megvii Engine Team huangxinda, 4 years ago
commit 2ac3c9dc8b
2 changed files with 25 additions and 5 deletions:
  1. imperative/python/src/tensor.cpp (+21, -0)
  2. imperative/python/src/tensor.h (+4, -5)

imperative/python/src/tensor.cpp (+21, -0)

@@ -13,6 +13,7 @@
 #include "megbrain/common.h"
 #include "megbrain/imperative/ops/utility.h"
 #include "megbrain/imperative/ops/backward_graph.h"
+#include "megbrain/opr/io.h"
 
 #include "./tensor.h"
 #include "./grad.h"
@@ -39,6 +40,26 @@ interpreter::Interpreter::Channel* interpreter_for_py;
 PyObject *cpp_apply_with_tracing, *cpp_apply_const_with_tracing;
 PyObject *cpp_apply_backward_varnode;
 
+std::shared_ptr<Tensor> make_const(imperative::TensorPtr value) {
+    if (!(ApplyContext::global_enable & Tensor::Flags::TRACE)) {
+        return std::make_shared<Tensor>(interpreter_for_py->put(value->dev_tensor()));
+    }
+    py::tuple tup(6);
+    auto data = value->get_value();
+    tup[0] = py::reinterpret_steal<py::array>(ndarray_from_tensor(data, npy::ShareType::MUST_SHARE));
+    tup[1] = value->dtype();
+    tup[2] = value->comp_node();
+    tup[3] = true;
+    tup[4] = false;
+    tup[5] = py::none{};
+    auto py_ret = PyObject_Call(cpp_apply_const_with_tracing, tup.ptr(), nullptr);
+    if (!py_ret) throw py::error_already_set();
+    auto py_list = py::reinterpret_steal<py::list>(py_ret);
+    auto* tensor_wrapper = TensorWrapper::try_cast(py_list[0].ptr());
+    auto tensor = tensor_wrapper->m_tensor;
+    return tensor_wrapper->m_tensor;
+}
+
 #define REGISTE_APPLY_FUNC(mode) \
 void set_##mode(py::object pyf) { \
     mode = pyf.ptr(); \
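The added make_const packs the constant's host value, dtype and comp_node into a six-element tuple and hands it to the Python-side cpp_apply_const_with_tracing callback, so that under trace the constant is recorded like any other op application instead of being put on the interpreter directly. Below is a minimal, self-contained sketch of that calling pattern; the recorder lambda, the literal values and the meaning of the last three tuple slots are illustrative assumptions, not MegEngine code.

// sketch_call_into_tracing.cpp : hedged illustration, not MegEngine source.
#include <pybind11/embed.h>
#include <iostream>
namespace py = pybind11;

int main() {
    py::scoped_interpreter guard;

    // Stand-in for cpp_apply_const_with_tracing: a Python callable that takes
    // six positional arguments and returns the "tensor" wrapped in a list,
    // mirroring the list that make_const unpacks afterwards.
    py::object recorder = py::eval("lambda *args: [args[0]]");

    py::tuple tup(6);
    tup[0] = 3.14;        // host value of the constant (an ndarray in the real code)
    tup[1] = "float32";   // dtype, as in tup[1] = value->dtype()
    tup[2] = "cpu0";      // comp_node, as in tup[2] = value->comp_node()
    tup[3] = true;        // the remaining slots mirror the diff; their exact meaning
    tup[4] = false;       // is not spelled out there, so treat these as placeholders
    tup[5] = py::none();

    // Same CPython call / reference-steal pattern as make_const.
    PyObject* py_ret = PyObject_Call(recorder.ptr(), tup.ptr(), nullptr);
    if (!py_ret) throw py::error_already_set();
    auto py_list = py::reinterpret_steal<py::list>(py_ret);
    std::cout << py_list[0].cast<double>() << "\n";   // prints 3.14
}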


imperative/python/src/tensor.h (+4, -5)

@@ -271,18 +271,17 @@ auto apply(std::shared_ptr<OpDef> op, T&& tensors)
     return apply(op, args, nargs);
 }
 
+std::shared_ptr<Tensor> make_const(imperative::TensorPtr value);
+
 inline auto apply(Subgraph graph, Tensor*const* args, size_t nargs) {
     SmallVector<std::shared_ptr<Tensor>> inputs;
     for (size_t i = 0; i < nargs; ++i) {
         inputs.push_back(args[i]->shared_from_this());
     }
     auto apply_functor = [](std::shared_ptr<OpDef> op, SmallVector<std::shared_ptr<Tensor>> inputs) {
-        return apply(op, inputs);
-    };
-    auto const_functor = [](imperative::TensorPtr value) {
-        return std::make_shared<Tensor>(interpreter_for_py->put(value->dev_tensor()));
+        return apply(op, std::move(inputs));
     };
-    return graph.apply(inputs, apply_functor, const_functor);
+    return graph.apply(inputs, apply_functor, &make_const);
 }
 
 template <typename T>

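In tensor.h the inline const_functor lambda, which used to put the constant straight onto the interpreter, is replaced by a pointer to the new trace-aware make_const, so constants produced while evaluating a backward Subgraph follow the same tracing path as every other tensor. The following is a small standalone sketch of that functor-injection pattern; all names below are illustrative stand-ins, not MegEngine's Subgraph API.

// sketch_const_functor.cpp : hedged illustration of injecting a const factory.
#include <functional>
#include <iostream>
#include <memory>
#include <vector>

struct Tensor { int value; bool traced; };
using TensorPtr = std::shared_ptr<Tensor>;

bool trace_enabled = true;   // stand-in for ApplyContext::global_enable & TRACE

// Analogue of make_const: how a constant becomes a live tensor depends on
// whether tracing is on.
TensorPtr make_const(int raw) {
    return std::make_shared<Tensor>(Tensor{raw, trace_enabled});
}

// Analogue of Subgraph::apply: the evaluator materializes every constant node
// through whatever const_functor the caller hands it.
TensorPtr eval_graph(const std::vector<TensorPtr>& inputs,
                     const std::vector<int>& constants,
                     const std::function<TensorPtr(int)>& const_functor) {
    int sum = 0;
    bool traced = true;
    for (const auto& t : inputs) { sum += t->value; traced = traced && t->traced; }
    for (int c : constants) {
        auto t = const_functor(c);           // constants enter the graph here
        sum += t->value;
        traced = traced && t->traced;
    }
    return std::make_shared<Tensor>(Tensor{sum, traced});
}

int main() {
    std::vector<TensorPtr> inputs{make_const(1), make_const(2)};
    // Passing &make_const, as the patched apply() does, keeps backward-graph
    // constants on the trace-aware path instead of a bypassing lambda.
    auto out = eval_graph(inputs, {10}, &make_const);
    std::cout << out->value << " traced=" << out->traced << "\n";   // 13 traced=1
}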
