@@ -12,6 +12,7 @@
 #include "./grad.h"
 #include "megbrain/imperative/proxy_graph_detail.h"
 #include "megbrain/imperative/ops/autogen.h"
+#include "megbrain/imperative/ops/utility.h"
 #include "megbrain/utils/mempool.h"
 #include "range/v3/all.hpp"
@@ -21,6 +22,9 @@ namespace views = ranges::views;
 namespace mgb::imperative::python {
+using scoped_disable = ApplyContext::scoped_disable;
+using Flags = Tensor::Flags;
+
 namespace {
 struct GradSlotWeakPtr {
@@ -78,6 +82,21 @@ std::shared_ptr<BackwardGraphResult> make_backward_graph(
     return result;
 }
+
+struct BackwardContext {
+    PyTypeObject* pytype = nullptr;
+
+    auto wrap_tensor(std::shared_ptr<Tensor> t) {
+        if (pytype) {
+            return TensorWrapper::make(pytype, std::move(t));
+        }
+        return TensorWrapper::make(std::move(t));
+    }
+
+    auto wrap_tensor(Tensor* t) {
+        return wrap_tensor(t->shared_from_this());
+    }
+};
 struct BackwardGraphWithClosure {
     std::shared_ptr<BackwardGraphResult> backward_graph;
     SmallVector<std::shared_ptr<Tensor>> closure;
@@ -119,7 +138,7 @@ struct BackwardGraphWithClosure {
     }
     template <typename T, typename R>
-    void operator()(T&& grads, R&& receiver) {
+    void operator()(BackwardContext&, T&& grads, R&& receiver) {
         Tensor* args[closure.size() + grads.size()];
         size_t nargs = 0;
         for (auto&& t : closure) {
@@ -143,7 +162,7 @@ struct BackwardGraphWithClosure {
         ApplyContext ctx;
         ctx.op = backward_graph->backward;
-        ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0;
+        ctx.flags = is_tracing ? Flags::TRACE : 0;
         ctx.nargs = nargs;
         ctx.args = args;
         for (size_t i = 0; i < nargs; ++i) {
@@ -174,6 +193,47 @@
     }
 };
+
+struct PythonBackward {
+    py::object pyfunc;
+    size_t input_size;
+
+    PythonBackward(py::object f, size_t nin)
+            : pyfunc(f), input_size(nin) {}
+
+    template <typename T, typename R>
+    void operator()(BackwardContext& ctx, T&& grads, R&& receiver) {
+        auto args = py::tuple(grads.size());
+        for (size_t i = 0; i < grads.size(); ++i) {
+            auto&& g = grads[i];
+            args[i] = g ? ctx.wrap_tensor(g) : py::none();
+        }
+        auto input_grads = py::reinterpret_steal<py::object>(PyObject_Call(pyfunc.ptr(), args.ptr(), nullptr));
+        if (input_grads.is_none()) return;
+        if (auto* tw = TensorWrapper::try_cast(input_grads.ptr())) {
+            if (input_size != 1) {
+                throw py::value_error("custom grad rule returned wrong number of grads");
+            }
+            receiver(0, tw->m_tensor);
+            return;
+        }
+        if (py::len(input_grads) != input_size) {
+            throw py::value_error("custom grad rule returned wrong number of grads");
+        }
+        for (auto [i, g] : views::enumerate(input_grads)) {
+            if (g.is_none()) continue;
+            auto* tw = TensorWrapper::try_cast(g.ptr());
+            if (!tw) {
+                throw py::type_error("custom grad rule returned non-tensor");
+            }
+            receiver(i, tw->m_tensor);
+        }
+    }
+
+    static constexpr bool input_has_grad(size_t) {return true;}
+    static constexpr bool output_requires_grad(size_t) {return true;}
+    static constexpr bool output_captured(size_t) {return true;}
+};
 } // namespace
 struct GradProducerRecord : intrusive_list::Node<GradProducerRecord> {
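The PythonBackward functor added above hands control back to a user-supplied Python callable: it passes one wrapped gradient per forward output (None where no gradient exists) and accepts either a single tensor when the forward op had one input, or a sequence of length input_size whose entries are tensors or None. A minimal sketch of such a callable for a hypothetical two-input op (names and body are illustrative assumptions, not part of this diff):

```python
# Hedged sketch of the callable PythonBackward invokes: one positional
# argument per forward output, returning one gradient per forward input.
def backward(dy):
    if dy is None:          # the output received no gradient
        return None, None
    dx0 = dy * 2            # gradient w.r.t. the first forward input
    dx1 = None              # the second input needs no gradient
    return dx0, dx1
```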
@@ -210,7 +270,7 @@ struct GradFn : std::enable_shared_from_this<GradFn> {
     // same length as inputs (of forward op)
     SmallVector<GradSlotProducerPtr> dsts;
     // encapsules actual function to compute gradient
-    std::variant<std::monostate, BackwardGraphWithClosure> backward;
+    std::variant<std::monostate, BackwardGraphWithClosure, PythonBackward> backward;
     // a flag used during backward
     bool in_ref_keeper = false;
@@ -268,6 +328,30 @@ apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_gra
     return outputs;
 }
+
+apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
+    auto* op = ctx.op->try_cast_final<GenericPyOp>();
+    py::tuple pyin(ctx.nargs);
+    for (size_t i = 0; i < ctx.nargs; ++i) {
+        pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this());
+    }
+    auto grad_rule = py::getattr(op->obj, "_grad_rule");
+    auto pyret = (scoped_disable(Flags::GRAD),
+            py::reinterpret_steal<py::object>(PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr))); // comma expression
+    auto [outputs, backward] = py::cast<std::tuple<py::object, py::function>>(pyret);
+    ret_grad_fn.emplace<PythonBackward>(std::move(backward), ctx.nargs);
+    if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) {
+        return {tw->m_tensor};
+    }
+    apply_result_t ret;
+    ret.reserve(py::len(outputs));
+    for (auto&& i : outputs) {
+        auto* tw = TensorWrapper::try_cast(i.ptr());
+        mgb_assert(tw);
+        ret.push_back(tw->m_tensor);
+    }
+    return ret;
+}
 } // namespace
 apply_result_t apply_grad(ApplyContext& ctx) {
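python_grad_rule above wraps the forward inputs, disables gradient recording for the duration of the call via the scoped_disable(Flags::GRAD) comma expression, and invokes the op's _grad_rule method, expecting an (outputs, backward) pair in return; backward is then stored in a PythonBackward. A hedged Python-side sketch of that contract (the body is an illustrative assumption):

```python
# Hedged sketch of the _grad_rule protocol consumed by python_grad_rule:
# called with the input tensors, it must return (outputs, backward).
def _grad_rule(self, x):
    y = self._default_rule(x)   # reuse the forward rule for the outputs

    def backward(dy):
        return dy * 2           # gradient w.r.t. x for this toy op

    return y, backward
```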
@@ -290,21 +374,23 @@ apply_result_t apply_grad(ApplyContext& ctx) {
                 // cleanup stale grad info
                 // under what condition?
                 tensor->m_grad_info = {};
-                tensor->m_flags &= ~Tensor::Flags::GRAD;
+                tensor->m_flags &= ~Flags::GRAD;
             }
         } else {
-            tensor->m_flags &= ~Tensor::Flags::GRAD;
+            tensor->m_flags &= ~Flags::GRAD;
         }
     }
-    ctx.flags &= ~Tensor::Flags::GRAD;
+    ctx.flags &= ~Flags::GRAD;
     if (!grad_key) {
         return apply(ctx);
     }
     GradFnHelper grad_fn_holder;
-    auto outputs = backward_graph_grad_rule(ctx, grad_fn_holder);
+    auto outputs = ctx.op->same_type<GenericPyOp>() ?
+            python_grad_rule(ctx, grad_fn_holder) :
+            backward_graph_grad_rule(ctx, grad_fn_holder);
     auto& grad_fn = grad_fn_holder.grad_fn;
     if (!grad_fn) {
@@ -341,7 +427,7 @@ apply_result_t apply_grad(ApplyContext& ctx) {
                 grad_info.grad_fn = grad_fn;
                 grad_info.idx = i;
                 grad_info.insert_after(grad_key->free_vars_head);
-                outputs[i]->m_flags |= Tensor::Flags::GRAD;
+                outputs[i]->m_flags |= Flags::GRAD;
             }
         }
     }
@@ -357,7 +443,7 @@ void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) {
     if (nargs != 2) {
         throw py::type_error("expect 2 arguments");
     }
-    auto* tw = TensorWrapper::cast_safe(args[0]);
+    auto* tw = TensorWrapper::try_cast(args[0]);
     if (!tw) {
         throw py::type_error("argument 1 must be Tensor");
     }
@@ -390,14 +476,15 @@ void GradKey::attach(Tensor* tensor, pybind11::object callback) {
         grad_fn->key = shared_from_this();
         grad_fn->slots.resize(1);
         tensor->m_grad_info.insert_after(free_vars_head);
-        tensor->m_flags |= Tensor::Flags::GRAD;
+        tensor->m_flags |= Flags::GRAD;
     }
     tensor->m_grad_info.grad_fn->slots[0].callback = std::move(callback);
 }
-void accum_grad(std::shared_ptr<Tensor>& grad, std::shared_ptr<Tensor>&& delta) {
+template<typename T>
+void accum_grad(std::shared_ptr<Tensor>& grad, T&& delta) {
     if (!grad) {
-        grad = std::forward<decltype(delta)>(delta);
+        grad = std::forward<T>(delta);
         return;
     }
     static ApplyContext ctx;
@@ -409,7 +496,7 @@ void accum_grad(std::shared_ptr<Tensor>& grad, std::shared_ptr<Tensor>&& delta)
     ctx.args = args;
     ctx.flags = grad->m_flags | delta->m_flags;
     if (is_tracing) {
-        ctx.flags |= Tensor::Flags::TRACE;
+        ctx.flags |= Flags::TRACE;
     }
     grad = apply(ctx)[0];
 }
@@ -440,6 +527,7 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
         }
     }
+    BackwardContext bctx{pytype};
     std::vector<std::shared_ptr<GradFn>> ref_keeper;
     ref_keeper.reserve(tape.size());
     // back-propagation in reverse order
@@ -456,7 +544,7 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
             mgb_assert(0);
         } else {
             auto&& grads = views::transform(grad_fn->slots, [](auto&& slot) {return slot.grad.get();});
-            backward(std::forward<decltype(grads)>(grads), grad_receiver);
+            backward(bctx, std::forward<decltype(grads)>(grads), grad_receiver);
         }
     }, grad_fn->backward);
@@ -14,6 +14,7 @@
 #include "megbrain/imperative.h"
 #include "megbrain/imperative/ops/backward_graph.h"
 #include "megbrain/imperative/ops/opr_attr.h"
+#include "megbrain/imperative/ops/utility.h"
 #include "megbrain/imperative/ops/autogen.h"
 #include <Python.h>
@@ -245,6 +246,35 @@ void _init_py_backward_graph(py::module m) {
     mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(BackwardGraph::typeinfo(), &py_type).second);
 }
+
+struct PyOpBase : PyOpDef {
+    static PyTypeObject py_type;
+    static PyObject* tp_new(PyTypeObject* type, PyObject*, PyObject*) {
+        auto* obj = type->tp_alloc(type, 0);
+        if (obj) {
+            auto* self = reinterpret_cast<PyOpBase*>(obj);
+            new(&self->op) decltype(self->op);
+        }
+        return obj;
+    }
+};
+
+PyTypeObject PyOpBase::py_type;
+
+void _init_py_op_base(py::module m) {
+    using py_op = PyOpBase;
+    auto& py_type = PyOpBase::py_type;
+    py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
+    py_type.tp_name = "megengine.core._imperative_rt.ops.PyOpBase";
+    py_type.tp_basicsize = sizeof(py_op);
+    py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
+    py_type.tp_doc = "PyOpBase";
+    py_type.tp_base = &PyOpType(OpDef);
+    py_type.tp_dealloc = py_dealloc_generic<py_op>;
+    py_type.tp_new = py_op::tp_new;
+    mgb_assert(PyType_Ready(&py_type) >= 0);
+    m.add_object("PyOpBase", reinterpret_cast<PyObject*>(&py_type));
+}
+
 /*********** end of hand-write opdefs **************/
 // auto generated opdefs
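With PyOpBase exported as megengine.core._imperative_rt.ops.PyOpBase, an op can be written entirely in Python by subclassing it and supplying the two hooks the C++ side looks up by name: _default_rule for the forward computation and _grad_rule for autodiff. A hedged end-to-end sketch (the op itself is an illustrative assumption):

```python
# Hedged sketch of a Python-only op built on the new PyOpBase binding.
from megengine.core._imperative_rt.ops import PyOpBase

class Double(PyOpBase):
    def _default_rule(self, x):
        # forward rule, used by apply() when no grad/trace handling applies
        return x * 2

    def _grad_rule(self, x):
        # must return (outputs, backward) as python_grad_rule expects
        return self._default_rule(x), lambda dy: dy * 2
```

When an instance of such a class is passed where an OpDef is expected, the type_caster<OpDef>::load change below wraps it in a GenericPyOp, and cast unwraps it back to the original Python object.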
@@ -260,9 +290,16 @@ bool type_caster<OpDef>::load(handle src, bool convert) {
         return false;
     }
     value = reinterpret_cast<PyOp(OpDef)*>(obj)->op;
+    if (!value) {
+        // opdef only defined in Python
+        value = std::make_shared<GenericPyOp>(reinterpret_borrow<object>(src));
+    }
     return true;
 }
 handle type_caster<OpDef>::cast(const OpDef& op, return_value_policy, handle) {
+    if (auto* pyop = op.try_cast_final<GenericPyOp>()) {
+        return object(pyop->obj).release();
+    }
     PyTypeObject* pytype;
     auto& c2p = PyOp(OpDef)::ctype2pytype;
     auto&& iter = c2p.find(op.dyn_typeinfo());
@@ -283,5 +320,6 @@ handle type_caster<OpDef>::cast(const OpDef& op, return_value_policy, handle) {
 void init_ops(py::module m) {
     _init_py_op_def(m);
     _init_py_backward_graph(m);
+    _init_py_op_base(m);
     INIT_ALL_OP(m)
 }
@@ -11,6 +11,7 @@
 #include "megbrain/dtype.h"
 #include "megbrain/common.h"
+#include "megbrain/imperative/ops/utility.h"
 #include "./tensor.h"
 #include "./grad.h"
@@ -22,10 +23,12 @@
 #include <pybind11/numpy.h>
 #include <pybind11/operators.h>
+#include <range/v3/all.hpp>
 #include <unordered_map>
 namespace py = pybind11;
+namespace views = ranges::views;
 namespace mgb::imperative::python {
@@ -69,21 +72,45 @@ SET_UNSET_PROP(compiled)
 bool skip_tracing = false;
+Tensor::flags_t ApplyContext::global_disable = 0;
 apply_result_t apply(ApplyContext& ctx) {
     // emulating scalar should be put to specific op's apply, e.g.,
     // elementwise, reduce, typecvt. Currently it's still handled at python
     // side. It could be move to C++ side if it has an impact on performance
-    if (ctx.flags & Tensor::Flags::SCALAR) {
+    auto flags = ctx.flags & ~ApplyContext::global_disable;
+    if (flags & Tensor::Flags::SCALAR) {
         // TODO: emulate scalar
     }
-    if (ctx.flags & Tensor::Flags::GRAD) {
+    if (flags & Tensor::Flags::GRAD) {
         return apply_grad(ctx);
     }
-    if (ctx.flags & Tensor::Flags::TRACE) {
+    if (flags & Tensor::Flags::TRACE) {
         return apply_trace(ctx);
     } else {
+        if (auto* op = ctx.op->try_cast_final<GenericPyOp>()) {
+            py::tuple pyin(ctx.nargs);
+            for (size_t i = 0; i < ctx.nargs; ++i) {
+                pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this());
+            }
+            auto f = py::getattr(op->obj, "_default_rule");
+            auto pyout = py::reinterpret_steal<py::object>(PyObject_Call(f.ptr(), pyin.ptr(), nullptr));
+            if (auto* tw = TensorWrapper::try_cast(pyout.ptr())) {
+                return {tw->m_tensor};
+            }
+            apply_result_t ret;
+            ret.reserve(py::len(pyout));
+            for (auto&& i : pyout) {
+                auto* tw = TensorWrapper::try_cast(i.ptr());
+                mgb_assert(tw);
+                ret.push_back(tw->m_tensor);
+            }
+            return ret;
+        }
         SmallVector<interpreter::Interpreter::Handle> handles(ctx.nargs);
         for (size_t i = 0; i < ctx.nargs; ++i) {
             handles[i] = ctx.args[i]->m_handle.get();
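For a GenericPyOp, apply() above delegates the whole computation to the op's _default_rule and only checks the shape of the result: a single tensor is accepted directly, and anything else is iterated with every element required to be a tensor. A hedged sketch of the two acceptable return forms (illustrative bodies, reusing the assumed PyOpBase import from the earlier sketch):

```python
# Hedged sketch of the return-shape contract enforced by the GenericPyOp
# branch in apply(): a lone tensor, or an iterable of tensors.
from megengine.core._imperative_rt.ops import PyOpBase

class SingleOutput(PyOpBase):
    def _default_rule(self, x):
        return x * 2            # one tensor is returned as-is

class MultiOutput(PyOpBase):
    def _default_rule(self, x, y):
        return x + y, x - y     # any iterable of tensors also works
```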
@@ -125,12 +152,13 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
         SmallVector<Tensor*, 64> tensors(nargs);
         ctx.args = &tensors[0];
         ctx.nargs = nargs;
+        ctx.pytype = pytype;
         if (strstr(op->ob_type->tp_name, "BackwardGraph")) {
             ctx.backward = true;
         }
         for (size_t i = 0; i < nargs; ++i) {
-            if (TensorWrapper* tw = TensorWrapper::cast_safe(args[i])) {
+            if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
                 auto* t = tensors[i] = tw->m_tensor.get();
                 ctx.flags |= t->m_flags;
             } else {
@@ -166,7 +194,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
     if (nargs == 0) {
         throw py::type_error("too few arguments");
     }
-    if (auto* t = cast_safe(tup[0].ptr())) {
+    if (auto* t = try_cast(tup[0].ptr())) {
         if (nargs > 1) {
             throw py::type_error("expect 1 argument");
         }
@@ -211,7 +239,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
             auto ret = pyf(*tup);
             auto py_ret = py::reinterpret_borrow<py::list>(ret);
-            if (auto* t = cast_safe(py_ret[0].ptr())) {
+            if (auto* t = try_cast(py_ret[0].ptr())) {
                 m_tensor = t->m_tensor;
             }
             return;
@@ -349,7 +377,7 @@ PyObject* TensorWrapper::varnode() {
 }
 void TensorWrapper::reset(PyObject* tensor) {
-    TensorWrapper* t = TensorWrapper::cast_safe(tensor);
+    TensorWrapper* t = TensorWrapper::try_cast(tensor);
     if (!t) {
         throw py::type_error("expect Tensor");
     }
@@ -446,7 +474,7 @@ uint8_t max_priority(SmallVector<PyArray_Descr*> types) {
     }
 }
-// Returns the data type with sufficient size to hold all types of 
+// Returns the data type with sufficient size to hold all types of
 // category `cat` in the list `types`.
 PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) {
     // Return value: New reference
@@ -507,7 +535,7 @@ PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) {
     for (size_t i = 0; i < nargs; ++i) {
         PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i];
         if (handle == Py_None) continue;
-        TensorWrapper* tw = TensorWrapper::cast_safe(handle);
+        TensorWrapper* tw = TensorWrapper::try_cast(handle);
         if (tw) {
             mgb::DType type = tw->m_tensor->dtype();
             auto&& descr = npy::dtype_mgb2np_descr(type);
@@ -562,7 +590,7 @@ CompNode _get_device(PyObject*const* args, size_t nargs) {
     CompNode cn;
     for (size_t i = 0; i < nargs; ++i) {
         PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i];
-        TensorWrapper* tw = TensorWrapper::cast_safe(handle);
+        TensorWrapper* tw = TensorWrapper::try_cast(handle);
         if (tw) {
             if (!valid) {
                 cn = tw->m_tensor->comp_node();
@@ -124,7 +124,7 @@ struct TensorWrapper {
     friend wrap_t;
     inline static TensorWrapper* cast(PyObject* op) {return reinterpret_cast<wrap_t*>(op)->inst();}
-    inline static TensorWrapper* cast_safe(PyObject* op) {
+    inline static TensorWrapper* try_cast(PyObject* op) {
         if (!wrap_t::type().isinstance(op)) return nullptr;
         return cast(op);
     }
@@ -173,11 +173,26 @@ struct TensorWrapper {
 PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */);
 struct ApplyContext {
+    static Tensor::flags_t global_disable;
+
     Tensor::flags_t flags;
     std::shared_ptr<OpDef> op;
     Tensor*const* args;
     size_t nargs;
+    PyTypeObject* pytype = nullptr;
     bool backward = false;
+
+    class scoped_disable : NonCopyableObj {
+        Tensor::flags_t saved_flags;
+    public:
+        scoped_disable(Tensor::flags_t flags) : saved_flags(ApplyContext::global_disable) {
+            ApplyContext::global_disable |= flags;
+        }
+
+        ~scoped_disable() {
+            ApplyContext::global_disable = saved_flags;
+        }
+    };
 };
 using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>;
@@ -85,7 +85,7 @@ apply_result_t apply_trace(ApplyContext& ctx) {
         // assumption: python function always returns PyList
         auto tup = py::reinterpret_borrow<py::list>(ret);
         for (auto i = 0; i < tup.size(); i++) {
-            auto tw = TensorWrapper::cast_safe(tup[i].ptr());
+            auto tw = TensorWrapper::try_cast(tup[i].ptr());
             outputs.emplace_back(tw->m_tensor);
         }
         return outputs;
@@ -0,0 +1,21 @@
+/**
+ * \file imperative/src/impl/ops/utility.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "megbrain/imperative/ops/utility.h"
+#include "megbrain/imperative/ops/opr_attr.h"
+#include "megbrain/opr/utility.h"
+#include "../op_trait.h"
+
+namespace mgb::imperative {
+
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp);
+
+} // namespace mgb::imperative
@@ -0,0 +1,38 @@
+/**
+ * \file imperative/src/include/megbrain/imperative/ops/utility.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "megbrain/imperative/op_def.h"
+#include "megbrain/utils/hash.h"
+
+#include <pybind11/pybind11.h>
+
+namespace mgb::imperative {
+
+struct GenericPyOp final : OpDefImplBase<GenericPyOp> {
+    pybind11::object obj;
+
+    GenericPyOp(pybind11::object obj_) : obj(std::move(obj_)) {};
+
+    size_t hash() const override {
+        return pybind11::hash(obj);
+    }
+
+    bool is_same_st(const Hashable& rhs) const override {
+        return obj.equal(static_cast<const GenericPyOp&>(rhs).obj);
+    }
+
+    MGB_DYN_TYPE_OBJ_FINAL_DECL;
+};
+
+} // namespace mgb::imperative