|
@@ -96,29 +96,30 @@ apply_result_t apply(ApplyContext& ctx) { |
|
|
return apply_grad(ctx); |
|
|
return apply_grad(ctx); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (auto* op = ctx.op->try_cast_final<GenericPyOp>()) { |
|
|
|
|
|
py::tuple pyin(ctx.nargs); |
|
|
|
|
|
for (size_t i = 0; i < ctx.nargs; ++i) { |
|
|
|
|
|
pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this()); |
|
|
|
|
|
} |
|
|
|
|
|
auto f = py::getattr(op->obj, "_default_rule"); |
|
|
|
|
|
auto pyout = py::reinterpret_steal<py::object>(PyObject_Call(f.ptr(), pyin.ptr(), nullptr)); |
|
|
|
|
|
if (!pyout) throw py::error_already_set(); |
|
|
|
|
|
if (auto* tw = TensorWrapper::try_cast(pyout.ptr())) { |
|
|
|
|
|
return {tw->m_tensor}; |
|
|
|
|
|
} |
|
|
|
|
|
apply_result_t ret; |
|
|
|
|
|
ret.reserve(py::len(pyout)); |
|
|
|
|
|
for (auto&& i : pyout) { |
|
|
|
|
|
auto* tw = TensorWrapper::try_cast(i.ptr()); |
|
|
|
|
|
mgb_assert(tw); |
|
|
|
|
|
ret.push_back(tw->m_tensor); |
|
|
|
|
|
} |
|
|
|
|
|
return ret; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
if (flags & Tensor::Flags::TRACE) { |
|
|
if (flags & Tensor::Flags::TRACE) { |
|
|
return apply_trace(ctx); |
|
|
return apply_trace(ctx); |
|
|
} else { |
|
|
} else { |
|
|
if (auto* op = ctx.op->try_cast_final<GenericPyOp>()) { |
|
|
|
|
|
py::tuple pyin(ctx.nargs); |
|
|
|
|
|
for (size_t i = 0; i < ctx.nargs; ++i) { |
|
|
|
|
|
pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this()); |
|
|
|
|
|
} |
|
|
|
|
|
auto f = py::getattr(op->obj, "_default_rule"); |
|
|
|
|
|
auto pyout = py::reinterpret_steal<py::object>(PyObject_Call(f.ptr(), pyin.ptr(), nullptr)); |
|
|
|
|
|
if (auto* tw = TensorWrapper::try_cast(pyout.ptr())) { |
|
|
|
|
|
return {tw->m_tensor}; |
|
|
|
|
|
} |
|
|
|
|
|
apply_result_t ret; |
|
|
|
|
|
ret.reserve(py::len(pyout)); |
|
|
|
|
|
for (auto&& i : pyout) { |
|
|
|
|
|
auto* tw = TensorWrapper::try_cast(i.ptr()); |
|
|
|
|
|
mgb_assert(tw); |
|
|
|
|
|
ret.push_back(tw->m_tensor); |
|
|
|
|
|
} |
|
|
|
|
|
return ret; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
SmallVector<interpreter::Interpreter::Handle> handles(ctx.nargs); |
|
|
SmallVector<interpreter::Interpreter::Handle> handles(ctx.nargs); |
|
|
for (size_t i = 0; i < ctx.nargs; ++i) { |
|
|
for (size_t i = 0; i < ctx.nargs; ++i) { |
|
|
handles[i] = ctx.args[i]->m_handle.get(); |
|
|
handles[i] = ctx.args[i]->m_handle.get(); |
|
|