@@ -74,6 +74,11 @@ class Graph(_imperative_rt.ComputingGraph):
self.execute(*args)
return self.wait()
def _make_const_for_backward(self, data):
device = as_device(data.comp_node).to_c()
data = data.numpy()
return self._wrap(_imperative_rt.make_const(self, data, device, data.dtype))
def make_const(self, data, dtype=None, device=None):
if isinstance(data, _imperative_rt.DeviceTensorND):
assert dtype is None and device is None
@@ -437,7 +442,9 @@ def _(op: OpDef, *args: VarNode):
def _(op: BackwardGraph, *args: VarNode):
assert args
graph = args[0].graph
return op.interpret(lambda op, args: apply(op, *args), graph.make_const, args)
return op.interpret(
    lambda op, args: apply(op, *args), graph._make_const_for_backward, args
)
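# A minimal sketch (not part of the patch) of why the const factory changes here:
# constants materialized while interpreting a BackwardGraph arrive as device
# tensors, so _make_const_for_backward keeps the captured comp node and dtype
# when turning them into graph constants, roughly:
#     device = as_device(data.comp_node).to_c()   # same device as the capture
#     var = _imperative_rt.make_const(graph, data.numpy(), device, data.dtype)
# interpret() then receives two callables: one that applies each recorded OpDef
# and one that builds these constant var nodes.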
def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None):
@@ -449,12 +456,26 @@ def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=N
class InputNode(OpNode):
def __init__(self, *args: VarNode, device=None, dtype=None, shape=None, graph=None):
def __init__(
    self,
    *args: VarNode,
    device=None,
    dtype=None,
    shape=None,
    graph=None,
    use_static_shape=False
):
r = _imperative_rt.DeviceTensorNDRendezvous()
if device is not None:
device = as_device(device).to_c()
outputs = _imperative_rt.input_callback(
r, device, dtype, shape, _unwrap(args), graph=graph
    r,
    device,
    dtype,
    shape,
    _unwrap(args),
    graph=graph,
    use_static_shape=use_static_shape,
)
super().__init__(outputs[0].owner)
self._rendezvous = r
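# A minimal sketch (not part of the patch) of how the new flag is used from
# tracing code; `g` and `value` below are placeholders:
#     inp = InputNode(device="xpux", dtype="float32", shape=(4, 8),
#                     graph=g, use_static_shape=True)
#     var = inp.outputs[0]     # var node with a statically known shape
#     inp.set_value(value)     # DeviceTensorND delivered through the rendezvous
# With use_static_shape=False the shape is only a hint; with True it must be
# concrete, and the delivered tensor is expected to match it at execution time.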
@@ -11,6 +11,7 @@ import contextlib
import functools
import itertools
import json
import os
import typing
import warnings
import weakref
@@ -35,6 +36,10 @@ from ..core.tensor.tensor import Tensor
from .sublinear_memory_config import SublinearMemoryConfig
def _input_node_use_static_shape():
return os.environ.get("MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE") is not None
class TraceMismatchError(RuntimeError):
pass
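# A minimal sketch (not part of the patch): the static-shape path is opt-in via
# an environment variable; only presence is checked, not the value, so any
# setting enables it. Set it before the trace graph is compiled:
#     import os
#     os.environ["MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE"] = "1"  # enable
#     del os.environ["MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE"]    # disable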
@@ -76,6 +81,7 @@ class TensorInfo:
"device",
"dtype",
"shape",
"is_const",
"bound_data",
# resources for execution
"varnode",
@@ -242,6 +248,28 @@ class trace:
self._active_tensors.update(outputs)
return outputs
def _apply_const(self, op, args):
assert not self._untraced
# check against trace
if self._pc >= len(self._seq):
raise TraceMismatchError("trace should end here, but more op observed")
record = self._seq[self._pc]
op_, ihandles, ohandles = record
assert isinstance(op_, Const)
eq = op_.value == op.value
if not isinstance(eq, bool):
eq = all(eq)
if not eq:
raise TraceMismatchError(
    "const tensor violated: got a different tensor this time"
)
self._pc += 1
(h,) = ohandles
outputs = tuple([self._tinfo[h].bound_data])
return outputs
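# A minimal sketch (not part of the patch) of the equality dance above:
# comparing Python scalars yields a plain bool, while comparing numpy arrays
# yields an elementwise boolean array that must be reduced before branching.
#     import numpy as np
#     eq = 1.0 == 1.0                    # -> True (bool)
#     eq = np.ones(3) == np.ones(3)      # -> array([ True,  True,  True])
#     if not isinstance(eq, bool):
#         eq = all(eq)                   # collapse to a single bool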
def _record_op(self, op, inputs, outputs):
if skip_tracing:
for x in inputs:
@@ -275,7 +303,24 @@ class trace:
self._active_tensors.update(outputs)
def _record_const(self, op, outputs):
pass
if skip_tracing:
(x,) = outputs
h = getattr(x, "_TraceMixin__handle", None)
if h is not None:
self._tinfo[h].data_read = True
return
(x,) = outputs
h, info = self._new_handle()
ohandles = [h]
info.external = True
info.device = x.device
info.dtype = x.dtype
info.shape = x.shape
info.bound_data = x
info.is_const = True
TraceMixin._TraceMixin__inject(x, h)
self._seq.append((op, tuple(), tuple(ohandles)))
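# A minimal sketch (not part of the patch) of what the new recording stores for
# a Const: the sequence entry has no input handles, and the TensorInfo carries
# everything needed later, roughly:
#     info.external   = True    # value originates outside the traced graph
#     info.is_const   = True    # _compile() will emit graph.make_const for it
#     info.bound_data = x       # concrete tensor, re-checked by _apply_const
# so a compiled run can both validate the constant and rebuild it as a var node.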
def _set_active(self, active: bool):
global active_trace
@@ -308,6 +353,11 @@ class trace:
for x in lazy_eval_tensors
]
self._apply_graph_options(lazy_eval_graph)
# FIXME
if self._graph_opt_level is not None:
lazy_eval_graph.options.graph_opt_level = self._graph_opt_level
else:
lazy_eval_graph.options.graph_opt_level = 2
lazy_eval_graph.compile(*lazy_eval_links, *readers)
lazy_eval_graph()
for r, x in zip(readers, lazy_eval_tensors):
@@ -323,6 +373,7 @@ class trace:
self._init_trace(self._symbolic)
else:
apply.enable(apply_compiled_mode)
apply.enable(apply_const_compiled_mode)
if self._graph is None:
self._compile()
self._graph.execute()
@@ -370,6 +421,7 @@ class trace:
apply.disable(apply_symbolic_mode)
apply.disable(apply_const_symbolic_mode)
apply.disable(apply_compiled_mode)
apply.disable(apply_const_compiled_mode)
self._set_active(False)
def do_exit():
@@ -409,8 +461,10 @@ class trace:
graph.options.no_force_inplace = True
graph.options.seq_opt.enable_seq_comp_node_opt = False
# graph opt level
if self._graph_opt_level is not None:
graph.options.graph_opt_level = self._graph_opt_level
# if self._graph_opt_level is not None:
# graph.options.graph_opt_level = self._graph_opt_level
# FIXME
graph.options.graph_opt_level = 0
# sublinear
if self._sublinear_memory_config is not None:
graph.options.enable_sublinear_memory_opt = True
@@ -442,22 +496,49 @@ class trace:
for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
info = self._tinfo[h]
opnode = info.data_setter = G.InputNode(
device=info.device, dtype=info.dtype, shape=info.shape, graph=graph
    device=info.device,
    dtype=info.dtype,
    shape=info.shape,
    graph=graph,
    use_static_shape=_input_node_use_static_shape(),
)
need_reset_nodes.append(opnode)
info.varnode = opnode.outputs[0]
links += opnode.outputs[1:]
for op, ihandles, ohandles in self._seq:
require_links = type(op) in _io_op_types
if isinstance(op, Const):
assert len(ihandles) == 0
(h,) = ohandles
info = self._tinfo[h]
if not hasattr(info, "varnode"):
assert info.external
assert info.bound_data
info.varnode = graph.make_const(
    info.bound_data.numpy(),
    info.bound_data.dtype,
    info.bound_data.device,
)
continue
require_links = type(op) in _io_op_types
ivars = []
for i, h in enumerate(ihandles):
info = self._tinfo[h]
if not hasattr(info, "varnode"):
assert info.external
if info.bound_data:
info.varnode = graph.make_const(info.bound_data._dev_tensor())
if hasattr(info, "is_const") and info.is_const:
info.varnode = graph.make_const(
    info.bound_data.numpy(),
    info.bound_data.dtype,
    info.bound_data.device,
)
else:
info.varnode = graph.make_const(
    info.bound_data._dev_tensor()
    # info.bound_data.numpy()
)
else:
opnode = info.data_setter = G.InputNode(
*links,
@@ -465,6 +546,7 @@ class trace:
dtype=info.dtype,
shape=info.shape,
graph=graph,
use_static_shape=_input_node_use_static_shape(),
)
need_reset_nodes.append(opnode)
info.varnode, *links = opnode.outputs
@@ -500,7 +582,11 @@ class trace:
if info.shape_read:
opnode = info.shape_reader = G.AttrOutputNode(v, *links)
add_reader(opnode)
# FIXME
if self._graph_opt_level is not None:
graph.options.graph_opt_level = self._graph_opt_level
else:
graph.options.graph_opt_level = 2
graph.compile(*readers)
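# A minimal sketch (not part of the patch) of the FIXME pattern used here and
# in the lazy-eval path: graph optimization is forced to 0 while the recorded
# sequence is rebuilt into var nodes, then the user's opt level (default 2) is
# restored immediately before compiling:
#     graph.options.graph_opt_level = 0        # while replaying self._seq
#     ...                                      # emit var nodes
#     if self._graph_opt_level is not None:
#         graph.options.graph_opt_level = self._graph_opt_level
#     else:
#         graph.options.graph_opt_level = 2
#     graph.compile(*readers)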
def _reset_exec_env(self):
@@ -643,6 +729,17 @@ class trace:
)
for op, ihandles, ohandles in self._seq:
if isinstance(op, Const):
assert len(ihandles) == 0
(h,) = ohandles
info = self._tinfo[h]
if h not in h2v:
assert info.external
assert info.bound_data
h2v[h] = graph.make_const(
    info.bound_data.numpy(), dtype=info.dtype, device=info.device,
)
continue
ivars = []
for h in ihandles:
info = self._tinfo[h]
@@ -874,6 +971,7 @@ class CompiledTensorProxy(RawTensor):
class LazyEvalTensor(RawTensor):
def __init__(self, varnode):
super(LazyEvalTensor, self).__init__()
self.__varnode = varnode
@property
@@ -953,11 +1051,22 @@ def assign_raw_tensor(lhs, rhs):
@apply.register()
def apply_symbolic_mode(op: OpDef, *args: RawTensor):
graph = active_trace._lazy_eval_graph
ivars = [
getattr(x, "_LazyEvalTensor__varnode", None)
or graph.make_const(x._dev_tensor())
for x in args
]
ivars = []
for x in args:
var = getattr(x, "_LazyEvalTensor__varnode", None)
if var:
ivars.append(var)
else:
data_setter = G.InputNode(
    device=x.device,
    dtype=x.dtype,
    shape=x.shape,
    graph=graph,
    use_static_shape=True,
)
var = data_setter.outputs[0]
ivars.append(var)
data_setter.set_value(x._dev_tensor())
require_links = type(op) in _io_op_types
@@ -1004,6 +1113,20 @@ def apply_compiled_mode(op: OpDef, *args: RawTensor):
apply.disable(apply_compiled_mode)
@apply.register()
def apply_const_compiled_mode(op: Const, *args: RawTensor):
if skip_tracing:
args = [
    as_raw_tensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
    for x in args
]
return apply.super(op, *args)
return active_trace._apply_const(op, args)
apply.disable(apply_const_compiled_mode)
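# A minimal sketch (not part of the patch) of how this handler is wired in:
# like apply_compiled_mode it is registered once and left disabled; entering a
# compiled trace run enables it and leaving the trace disables it again, so
# Const ops are only intercepted while a recorded sequence is replaying:
#     apply.enable(apply_const_compiled_mode)     # on trace enter (compiled run)
#     apply.disable(apply_const_compiled_mode)    # on trace exit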
# this hook injects TraceMixin
@apply.register()
def apply_with_tracing(op: OpDef, *args: RawTensor):
@@ -145,11 +145,6 @@ void init_graph_rt(py::module m) {
.def_property_readonly("comp_node", [](cg::VarNode* v) {return v->comp_node();})
.def_property_readonly("shape", [](cg::VarNode* v) -> const TensorShape* {
auto&& mgr = v->owner_graph()->static_infer_manager();
auto&& type = mgr.get_infer_type(v);
using InferType = cg::static_infer::InferType;
if (!(type.shape & (InferType::CONST | InferType::RT_STATIC))) {
return nullptr;
}
return mgr.infer_shape_fallible(v);
})
.def_property_readonly("value", [](cg::VarNode* v) -> py::object {
@@ -437,7 +432,8 @@ void init_graph_rt(py::module m) {
const DType& dtype,
const TensorShape& shape,
const std::vector<cg::VarNode*>& inputs,
cg::ComputingGraph* graph) {
cg::ComputingGraph* graph,
bool use_static_shape) {
if (!graph) {
graph = inputs[0]->owner_graph();
}
@@ -446,7 +442,9 @@ void init_graph_rt(py::module m) {
sinputs.emplace_back(i);
}
static_assert(!std::is_reference<decltype(callback)>::value);
auto soutputs = opr::InputCallback::make(*graph, std::move(callback), comp_node, dtype, shape, sinputs);
auto soutputs = opr::InputCallback::make(*graph, std::move(callback),
                                         comp_node, dtype, shape,
                                         sinputs, use_static_shape);
std::vector<VarNode*> outputs;
outputs.reserve(soutputs.size());
for (auto i : soutputs) {
@@ -490,23 +488,29 @@ void init_graph_rt(py::module m) {
const DType& dtype,
const TensorShape& shape,
const std::vector<cg::VarNode*>& inputs,
cg::ComputingGraph* graph) {
return input_callback([f=std::move(callback)](){py::gil_scoped_acquire _; return f();}, comp_node, dtype, shape, inputs, graph);
cg::ComputingGraph* graph,
bool use_static_shape) {
return input_callback(
    [f=std::move(callback)](){py::gil_scoped_acquire _; return f();},
    comp_node, dtype, shape, inputs, graph, use_static_shape);
},
py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(), py::arg("graph") = py::none());
py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(),
py::arg("graph") = py::none(), py::arg("use_static_shape") = false);
m.def("input_callback", [input_callback](std::shared_ptr<Rendezvous<DeviceTensorND>> p,
const CompNode& comp_node,
const DType& dtype,
const TensorShape& shape,
const std::vector<cg::VarNode*>& inputs,
cg::ComputingGraph* graph) {
cg::ComputingGraph* graph,
bool use_static_shape) {
auto f = [p]() -> DeviceTensorND {
return p->get();
};
return input_callback(std::move(f), comp_node, dtype, shape, inputs, graph);
return input_callback(std::move(f), comp_node, dtype, shape, inputs, graph, use_static_shape);
},
py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(), py::arg("graph") = py::none());
py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(),
py::arg("graph") = py::none(), py::arg("use_static_shape") = false);
auto output_callback = [](auto callback, const std::vector<cg::VarNode*>& inputs,
std::shared_ptr<RendezvousBase> r = {}, bool borrow = false, bool prefer_host_value = false) {
@@ -97,7 +97,9 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
for param in net.parameters():
ori_params[param] = np.copy(param.numpy())
train_func(np.random.random(data_shape).astype(np.float32), opt=opt, gm=gm)
train_func(
    tensor(np.random.random(data_shape).astype(np.float32)), opt=opt, gm=gm
)
step += 1
check_func(ori_params, net.parameters(), step)
@@ -176,6 +176,7 @@ def test_trace_profiler():
assert out.get("profiler")
@pytest.mark.skip(reason="force opt_level=0 when building graph")
def test_goptions():
@trace(symbolic=True, opt_level=0, capture_as_const=True)
def f(x):
@@ -194,6 +195,7 @@ def test_goptions():
np.testing.assert_equal(g(d).numpy().item(), 1.0)
@pytest.mark.skip(reason="force opt_level=0 when building graph")
def test_goptions_log_sum_exp():
@trace(symbolic=True, opt_level=0, capture_as_const=True)
def f(x, y):
@@ -33,14 +33,18 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(InputCallback);
InputCallback::InputCallback(cg::ComputingGraph& graph, callback_t callback,
const VarNodeArray& inputs,
const TensorShape& output_shape,
const OperatorNodeConfig& config)
const OperatorNodeConfig& config,
bool use_static_shape)
: Super(&graph, config, "input_callback", inputs),
m_output_shape(output_shape), m_callback(callback) {
m_output_shape(output_shape), m_callback(callback), m_use_static_shape(use_static_shape) {
for (VarNode* i : inputs) {
add_input({i});
}
DType dt = config.output_dtype();
mgb_assert(dt.valid());
if(m_use_static_shape){
mgb_assert(m_output_shape.ndim);
}
add_output(None)->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC).dtype(dt);
add_output(None)
->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
@@ -52,7 +56,8 @@ InputCallback::InputCallback(cg::ComputingGraph& graph, callback_t callback,
SymbolVarArray InputCallback::make(cg::ComputingGraph& graph,
callback_t callback, CompNode comp_node,
DType dtype, const TensorShape& shape,
const SymbolVarArray& inputs) {
const SymbolVarArray& inputs,
bool use_static_shape) {
mgb_assert(comp_node.valid());
mgb_assert(dtype.valid());
OperatorNodeConfig config;
@@ -60,24 +65,33 @@ SymbolVarArray InputCallback::make(cg::ComputingGraph& graph,
config.output_dtype(dtype);
auto vinputs = to_var_node_array(inputs);
auto opr = graph.insert_opr(
std::make_unique<InputCallback>(graph, callback, vinputs, shape, config));
std::make_unique<InputCallback>(graph, callback, vinputs, shape, config, use_static_shape));
return to_symbol_var_array(opr->output());
}
void InputCallback::init_output_static_infer_desc() {
if (m_output_shape.ndim) {
// Write this shape to static infer manager. The effect is
// that infer_shape_fallible() will return a non-empty shape
// while get_infer_type() remains NO_DESC. Most places check
// infer type before relying on inferred shape so things
// won't break. Memory optimizer however, deliberately omits
// infer type check so it will be able to use this shape for hint.
using namespace cg::static_infer;
auto* var = output(0);
var->shape(m_output_shape);
auto&& mgr = cg::ComputingGraphImpl::downcast(owner_graph())->static_infer_manager_impl();
auto* handle = mgr.get_tag_handler_for_shape(var);
handle->sync_from_var();
using namespace cg::static_infer;
if(m_use_static_shape) {
auto &&mgr = owner_graph()->static_infer_manager();
auto infer_shape = [this](TensorShape &dest, const InpVal &) {
dest = m_output_shape;
return true;
};
mgr.register_shape_infer(output(0), {SourceType::CONSTANT, {}, infer_shape});
} else {
if (m_output_shape.ndim) {
// Write this shape to static infer manager. The effect is
// that infer_shape_fallible() will return a non-empty shape
// while get_infer_type() remains NO_DESC. Most places check
// infer type before relying on inferred shape so things
// won't break. Memory optimizer however, deliberately omits
// infer type check so it will be able to use this shape for hint.
auto* var = output(0);
var->shape(m_output_shape);
auto&& mgr = cg::ComputingGraphImpl::downcast(owner_graph())->static_infer_manager_impl();
auto* handle = mgr.get_tag_handler_for_shape(var);
handle->sync_from_var();
}
}
}
@@ -92,6 +106,9 @@ cg::OperatorNodeBase::NodeProp* InputCallback::do_make_node_prop() const {
void InputCallback::scn_do_execute() {
auto dev_tensor = m_callback();
if (m_use_static_shape) {
mgb_assert(dev_tensor.shape().eq_shape(m_output_shape));
}
output(0)->reset_dev_tensor_from_tensor(dev_tensor);
}
@@ -101,7 +118,10 @@ cg::OperatorNodeBase* InputCallback::shallow_copy(
const OperatorNodeConfig &config) {
auto &&opr = opr_.cast_final_safe<InputCallback>();
auto* graph = ctx.owner_graph(opr, inputs);
return graph->insert_opr(std::make_unique<InputCallback>(*graph, opr.m_callback, inputs, opr.m_output_shape, config));
return graph->insert_opr(
    std::make_unique<InputCallback>(*graph, opr.m_callback,
                                    inputs, opr.m_output_shape,
                                    config, opr.m_use_static_shape));
}
MGB_REG_OPR_SHALLOW_COPY(InputCallback, InputCallback::shallow_copy);
@@ -35,13 +35,15 @@ public:
callback_t callback,
const VarNodeArray& inputs,
const TensorShape& output_shape,
const OperatorNodeConfig &config);
const OperatorNodeConfig &config,
bool use_static_shape);
static SymbolVarArray make(cg::ComputingGraph& graph,
callback_t callback,
CompNode comp_node,
DType dtype,
const TensorShape& shape,
const SymbolVarArray& inputs = {});
const SymbolVarArray& inputs = {},
bool use_static_shape = false);
static cg::OperatorNodeBase* shallow_copy(
const serialization::OprShallowCopyContext &ctx,
const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
@@ -53,6 +55,7 @@ protected:
private:
TensorShape m_output_shape;
callback_t m_callback;
bool m_use_static_shape;
};
MGB_DEFINE_OPR_CLASS(OutputCallback, cg::SingleCNOperatorNodeBase) // {