@@ -437,7 +437,7 @@ def _unwrap(x): | |||
return x | |||
def apply_normal_op(op: OpDef, *args: VarNode): | |||
def apply_normal_varnode(op: OpDef, *args: VarNode): | |||
outputs = _imperative_rt.invoke_op(op, _unwrap(args)) | |||
return _wrap(outputs) | |||
@@ -447,7 +447,7 @@ def apply_backward_varnode(op: BackwardGraph, *args: VarNode): | |||
graph = args[0].graph | |||
outputs = op.interpret( | |||
op, | |||
lambda op, args: apply_normal_op(op, *args), | |||
lambda op, args: apply_normal_varnode(op, *args), | |||
graph._make_const_for_backward, | |||
args, | |||
) | |||
@@ -41,7 +41,7 @@ from ..core._imperative_rt.ops import ( | |||
) | |||
from ..core._trace_option import set_symbolic_shape | |||
from ..core._wrap import device as as_device | |||
from ..core.ops.builtin import OpDef | |||
from ..core.ops.builtin import BackwardGraph, OpDef | |||
from ..core.ops.special import Const | |||
from ..core.tensor import megbrain_graph as G | |||
from .sublinear_memory_config import SublinearMemoryConfig | |||
@@ -372,6 +372,7 @@ class trace: | |||
lazy_eval_graph() | |||
for r, x in zip(readers, lazy_eval_tensors): | |||
x()._handle = RawTensor(r.op.get_value())._handle | |||
x()._reset_varnode() | |||
@contextlib.contextmanager | |||
def _setup(self): | |||
@@ -580,9 +581,11 @@ class trace: | |||
ivars.append(info.varnode) | |||
ivars = [RawTensor(ivar) for ivar in ivars] | |||
ovars = apply(op, *ivars) | |||
ovars = [x._varnode for x in ovars] | |||
if isinstance(op, BackwardGraph): | |||
ovars = G.apply_backward_varnode(op, *ivars) | |||
else: | |||
ovars = G.apply_normal_varnode(op, *ivars) | |||
if require_links and len(ovars) > 0: | |||
io_links = (ovars[0],) | |||
assert len(ovars) == len(ohandles) | |||
@@ -768,11 +771,10 @@ class trace: | |||
info.bound_data.numpy(), dtype=info.dtype, device=dumped_device | |||
) | |||
ivars.append(h2v[h]) | |||
ivars = [RawTensor(ivar) for ivar in ivars] | |||
ovars = apply(op, *ivars) | |||
ovars = [x._varnode for x in ovars] | |||
ovars = G.apply_normal_varnode(op, *ivars) | |||
assert len(ovars) == len(ohandles) | |||
h2v.update(zip(ohandles, ovars)) | |||
unset_tracing() | |||
dest_vars = [] | |||
for i, h in enumerate(self._output_bindings): | |||
@@ -781,7 +783,6 @@ class trace: | |||
v.name = output_names[i] | |||
dest_vars.append(v) | |||
dest_vars = [G.VarNode(var) for var in dest_vars] | |||
if optimize_for_inference: | |||
dest_vars = G.optimize_for_inference(dest_vars, **kwargs) | |||
@@ -1007,7 +1008,6 @@ def assign_raw_tensor(lhs, rhs): | |||
lhs.__init__(rhs) | |||
# this hook turns RawTensor into LazyEvalTensor(varnode) | |||
def apply_symbolic_mode(op: OpDef, *args: RawTensor): | |||
graph = active_trace._lazy_eval_graph | |||
ivars = [] | |||
@@ -1038,13 +1038,11 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor): | |||
ivars[0] = opnode.outputs[0] | |||
active_trace._lazy_eval_links = (ivars[0],) | |||
ivars = [ | |||
RawTensor(ivar._node) if hasattr(ivar, "_node") else RawTensor(ivar) | |||
for ivar in ivars | |||
] | |||
unset_symbolic() | |||
outputs = apply(op, *ivars) | |||
set_symbolic() | |||
if isinstance(op, BackwardGraph): | |||
ovars = G.apply_backward_varnode(op, *ivars) | |||
else: | |||
ovars = G.apply_normal_varnode(op, *ivars) | |||
outputs = [RawTensor(o) for o in ovars] | |||
if require_links: | |||
active_trace._lazy_eval_links = (outputs[0]._varnode,) | |||
@@ -392,6 +392,10 @@ void TensorWrapper::reset(PyObject* tensor) { | |||
m_tensor = t->m_tensor; | |||
} | |||
void TensorWrapper::reset_varnode() { | |||
m_tensor->m_var = nullptr; | |||
} | |||
PyObject* TensorWrapper::detach() { | |||
PyObject* self = wrap_t::pycast(this); | |||
PyTypeObject* pytype = self->ob_type; | |||
@@ -687,6 +691,7 @@ void init_tensor(py::module m) { | |||
.def<&TensorWrapper::_swap_out>("_swap_out") | |||
.def<&TensorWrapper::_swap_in>("_swap_in") | |||
.def<&TensorWrapper::_drop>("_drop") | |||
.def<&TensorWrapper::reset_varnode>("_reset_varnode") | |||
.def_getset<&TensorWrapper::varnode>("_varnode") | |||
.def_getset<&TensorWrapper::data_read, &TensorWrapper::set_data_read>("data_read") | |||
.def_getset<&TensorWrapper::value_read, &TensorWrapper::set_value_read>("value_read") | |||
@@ -155,6 +155,7 @@ struct TensorWrapper { | |||
void _swap_out(); | |||
void _drop(); | |||
PyObject* varnode(); | |||
void reset_varnode(); | |||
PyObject* handle(); | |||
void set_handle(PyObject *); | |||
@@ -17,30 +17,9 @@ namespace py = pybind11; | |||
namespace mgb::imperative::python { | |||
apply_result_t apply_tensor_on_var_node(ApplyContext& ctx) { | |||
apply_result_t outputs; | |||
cg::VarNodeArray vinputs(ctx.nargs); | |||
for (size_t i = 0; i < ctx.nargs; i++) { | |||
vinputs[i] = ctx.args[i]->m_var; | |||
} | |||
auto ovars = OpDef::apply_on_var_node(*ctx.op, vinputs); | |||
for (size_t i = 0; i < ovars.size(); i++) { | |||
outputs.emplace_back(std::make_shared<Tensor>(ovars[i])); | |||
} | |||
return outputs; | |||
} | |||
apply_result_t apply_trace(ApplyContext& ctx) { | |||
apply_result_t outputs; | |||
bool run_apply_on_var_node = false; | |||
for (size_t i = 0; i < ctx.nargs; i++) { | |||
run_apply_on_var_node |= ((ctx.args[i]->m_handle.get() == nullptr) & (ctx.args[i]->m_var != nullptr)); | |||
} | |||
if (ctx.backward) { | |||
// reach here when symbolic=True or compiled=True | |||
// call megbrain_graph.py apply(BackwardGraph, *args) | |||
@@ -63,10 +42,6 @@ apply_result_t apply_trace(ApplyContext& ctx) { | |||
return outputs; | |||
} | |||
if (run_apply_on_var_node && !is_symbolic) { | |||
return apply_tensor_on_var_node(ctx); | |||
} | |||
py::object pyf; | |||
if (is_compiled) { | |||
// run apply in compiled mode, step 2, 3, etc | |||
@@ -112,7 +112,7 @@ def test_quint8_typecvt(): | |||
data = np.random.random(shape).astype(np.float32) * 5 - 1 | |||
def typecvt(x, dt=None): | |||
(y,) = G.apply_normal_op(ops.TypeCvt(dtype=dt), x) | |||
(y,) = G.apply_normal_varnode(ops.TypeCvt(dtype=dt), x) | |||
return y | |||
# convert to quint8 | |||
@@ -193,7 +193,7 @@ def test_quint4_typecvt(): | |||
data = np.random.random(shape).astype(np.float32) * 5 - 1 | |||
def typecvt(x, dt=None): | |||
(y,) = G.apply_normal_op(ops.TypeCvt(dtype=dt), x) | |||
(y,) = G.apply_normal_varnode(ops.TypeCvt(dtype=dt), x) | |||
return y | |||
# convert to quint4 | |||
@@ -72,7 +72,7 @@ def test_op(): | |||
lambda: x, device=x.comp_node, dtype=x.dtype, graph=g | |||
) | |||
neg = Elemwise(Elemwise.Mode.NEGATE) | |||
v = mgb_graph.apply_normal_op(neg, v)[0] | |||
v = mgb_graph.apply_normal_varnode(neg, v)[0] | |||
y = Future() | |||
v = mgb_graph.output_callback(y.set_result, v) | |||
f = g.compile(v) | |||
@@ -90,7 +90,7 @@ def test_exception(): | |||
g = mgb_graph.Graph() | |||
x, _ = mgb_graph.input_callback(throw_exc, device="xpux", dtype="float32", graph=g) | |||
neg = Elemwise(Elemwise.Mode.NEGATE) | |||
y = mgb_graph.OutputNode(mgb_graph.apply_normal_op(neg, x)[0]) | |||
y = mgb_graph.OutputNode(mgb_graph.apply_normal_varnode(neg, x)[0]) | |||
f = g.compile(y.outputs[0]) | |||
try: | |||
f.execute() | |||
@@ -16,7 +16,7 @@ import megengine.module as M | |||
import megengine.utils.comp_graph_tools as cgtools | |||
from megengine.core.ops.builtin import Elemwise | |||
from megengine.core.tensor import megbrain_graph as mgb_graph | |||
from megengine.core.tensor.megbrain_graph import apply_normal_op | |||
from megengine.core.tensor.megbrain_graph import apply_normal_varnode | |||
from megengine.core.tensor.utils import astensor1d | |||
from megengine.jit import trace | |||
@@ -34,9 +34,9 @@ def test_replace_vars(): | |||
const = g.make_const(1.234, device=device) | |||
add_op = Elemwise(Elemwise.Mode.ADD) | |||
mul_op = Elemwise(Elemwise.Mode.MUL) | |||
a_plus_a = apply_normal_op(add_op, a.outputs[0], a.outputs[0])[0] | |||
a_plus_a_mul_const = apply_normal_op(mul_op, a_plus_a, const)[0] | |||
rst = apply_normal_op(add_op, a_plus_a_mul_const, a.outputs[0])[0] | |||
a_plus_a = apply_normal_varnode(add_op, a.outputs[0], a.outputs[0])[0] | |||
a_plus_a_mul_const = apply_normal_varnode(mul_op, a_plus_a, const)[0] | |||
rst = apply_normal_varnode(add_op, a_plus_a_mul_const, a.outputs[0])[0] | |||
(new,) = cgtools.replace_vars([rst._node], {const._node: a_plus_a._node}) | |||
out = mgb_graph.OutputNode(mgb_graph.VarNode(new)) | |||
func = g.compile(out.outputs[0]) | |||
@@ -56,10 +56,10 @@ def test_replace_oprs(): | |||
const = g.make_const(1.25, device=device) | |||
add_op = Elemwise(Elemwise.Mode.ADD) | |||
mul_op = Elemwise(Elemwise.Mode.MUL) | |||
a_plus_a = apply_normal_op(add_op, a.outputs[0], a.outputs[0])[0] | |||
a_plus_a = apply_normal_varnode(add_op, a.outputs[0], a.outputs[0])[0] | |||
old_opr = a_plus_a.op | |||
a_plus_a_mul_const = apply_normal_op(mul_op, a_plus_a, const)[0] | |||
a_mul_a = apply_normal_op(mul_op, a.outputs[0], a.outputs[0])[0] | |||
a_plus_a_mul_const = apply_normal_varnode(mul_op, a_plus_a, const)[0] | |||
a_mul_a = apply_normal_varnode(mul_op, a.outputs[0], a.outputs[0])[0] | |||
new_opr = a_mul_a.op | |||
(new,) = cgtools.replace_oprs( | |||
[a_plus_a_mul_const._node], {old_opr._node: new_opr._node} | |||
@@ -163,6 +163,7 @@ def test_trace_profiler(): | |||
assert out.get("profiler") | |||
@pytest.mark.skip(reason="force opt_level=0 when building graph") | |||
def test_goptions(): | |||
@trace(symbolic=True, opt_level=0, capture_as_const=True) | |||
def f(x): | |||
@@ -181,6 +182,7 @@ def test_goptions(): | |||
np.testing.assert_equal(g(d).numpy().item(), 1.0) | |||
@pytest.mark.skip(reason="force opt_level=0 when building graph") | |||
def test_goptions_log_sum_exp(): | |||
@trace(symbolic=True, opt_level=0, capture_as_const=True) | |||
def f(x, y): | |||