Browse Source

fix(mge): expand custom op before trace

GitOrigin-RevId: 725a5b87cb
release-1.2
Megvii Engine Team 4 years ago
parent
commit
e9e5f442a7
1 changed file with 21 additions and 20 deletions
  1. +21
    -20
      imperative/python/src/tensor.cpp

+ 21
- 20
imperative/python/src/tensor.cpp View File

@@ -96,29 +96,30 @@ apply_result_t apply(ApplyContext& ctx) {
return apply_grad(ctx);
}

if (auto* op = ctx.op->try_cast_final<GenericPyOp>()) {
py::tuple pyin(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; ++i) {
pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this());
}
auto f = py::getattr(op->obj, "_default_rule");
auto pyout = py::reinterpret_steal<py::object>(PyObject_Call(f.ptr(), pyin.ptr(), nullptr));
if (!pyout) throw py::error_already_set();
if (auto* tw = TensorWrapper::try_cast(pyout.ptr())) {
return {tw->m_tensor};
}
apply_result_t ret;
ret.reserve(py::len(pyout));
for (auto&& i : pyout) {
auto* tw = TensorWrapper::try_cast(i.ptr());
mgb_assert(tw);
ret.push_back(tw->m_tensor);
}
return ret;
}

if (flags & Tensor::Flags::TRACE) {
return apply_trace(ctx);
} else {
if (auto* op = ctx.op->try_cast_final<GenericPyOp>()) {
py::tuple pyin(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; ++i) {
pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this());
}
auto f = py::getattr(op->obj, "_default_rule");
auto pyout = py::reinterpret_steal<py::object>(PyObject_Call(f.ptr(), pyin.ptr(), nullptr));
if (auto* tw = TensorWrapper::try_cast(pyout.ptr())) {
return {tw->m_tensor};
}
apply_result_t ret;
ret.reserve(py::len(pyout));
for (auto&& i : pyout) {
auto* tw = TensorWrapper::try_cast(i.ptr());
mgb_assert(tw);
ret.push_back(tw->m_tensor);
}
return ret;
}

SmallVector<interpreter::Interpreter::Handle> handles(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; ++i) {
handles[i] = ctx.args[i]->m_handle.get();


Loading…
Cancel
Save