@@ -13,6 +13,7 @@
#include "megbrain/common.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/opr/io.h"

#include "./tensor.h"
#include "./grad.h"
@@ -39,6 +40,26 @@ interpreter::Interpreter::Channel* interpreter_for_py;

PyObject *cpp_apply_with_tracing, *cpp_apply_const_with_tracing;
PyObject *cpp_apply_backward_varnode;
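// Editorial note: these raw PyObject* hooks are filled in from Python via the
// set_* functions generated by REGISTE_APPLY_FUNC below. Since only pyf.ptr()
// is stored (no reference count is taken), the Python side presumably keeps
// the callables alive for the lifetime of the module.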
std::shared_ptr<Tensor> make_const(imperative::TensorPtr value) {
    if (!(ApplyContext::global_enable & Tensor::Flags::TRACE)) {
        // Not tracing: hand the device tensor to the interpreter and wrap
        // the resulting handle directly.
        return std::make_shared<Tensor>(interpreter_for_py->put(value->dev_tensor()));
    }
    // Tracing: pack the host value and its metadata into a 6-tuple and defer
    // to the Python-side hook so the constant is recorded in the trace.
    py::tuple tup(6);
    auto data = value->get_value();
    tup[0] = py::reinterpret_steal<py::array>(ndarray_from_tensor(data, npy::ShareType::MUST_SHARE));
    tup[1] = value->dtype();
    tup[2] = value->comp_node();
    tup[3] = true;
    tup[4] = false;
    tup[5] = py::none{};
    auto py_ret = PyObject_Call(cpp_apply_const_with_tracing, tup.ptr(), nullptr);
    if (!py_ret) throw py::error_already_set();
    auto py_list = py::reinterpret_steal<py::list>(py_ret);
    auto* tensor_wrapper = TensorWrapper::try_cast(py_list[0].ptr());
    // try_cast returns nullptr on a type mismatch; fail loudly instead of
    // dereferencing a null pointer.
    mgb_assert(tensor_wrapper);
    return tensor_wrapper->m_tensor;
}
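// Illustrative sketch (not part of this diff): a caller wrapping a freshly
// produced device tensor as a constant. `dev_value` is a hypothetical
// imperative::TensorPtr obtained elsewhere.
//
//     std::shared_ptr<Tensor> t = make_const(dev_value);
//     // Outside of tracing, `t` wraps a fresh interpreter handle; under
//     // tracing, it is the tensor returned by the Python-side hook
//     // installed in cpp_apply_const_with_tracing.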

#define REGISTE_APPLY_FUNC(mode) \
    void set_##mode(py::object pyf) { \
        mode = pyf.ptr(); \