Browse Source

feat(mge/imperative): implement trace and dump under new core implementation

GitOrigin-RevId: 4edc38eaf2
release-1.2
Megvii Engine Team 4 years ago
parent
commit
b310f2615b
14 changed files with 652 additions and 413 deletions
  1. +1
    -1
      imperative/python/megengine/core/ops/special.py
  2. +16
    -10
      imperative/python/megengine/core/tensor/megbrain_graph.py
  3. +19
    -1
      imperative/python/megengine/jit/__init__.py
  4. +175
    -302
      imperative/python/megengine/jit/tracing.py
  5. +3
    -2
      imperative/python/megengine/tensor.py
  6. +6
    -1
      imperative/python/src/grad.cpp
  7. +221
    -28
      imperative/python/src/tensor.cpp
  8. +50
    -10
      imperative/python/src/tensor.h
  9. +94
    -0
      imperative/python/src/trace.cpp
  10. +3
    -2
      imperative/python/src/trace.h
  11. +24
    -0
      imperative/python/src/trace_info.h
  12. +39
    -54
      imperative/python/test/unit/test_tracing.py
  13. +0
    -1
      sdk/load-and-run/dump_with_testcase_mge.py
  14. +1
    -1
      sdk/xor-deploy/xornet.py

+ 1
- 1
imperative/python/megengine/core/ops/special.py View File

@@ -20,4 +20,4 @@ class Const:


def __call__(self, *reference): def __call__(self, *reference):
Wrapper = type(reference[0]) Wrapper = type(reference[0])
return (Wrapper(self.value, self.dtype, self.device),)
return (Wrapper(self.value, self.dtype, self.device, True),)

+ 16
- 10
imperative/python/megengine/core/tensor/megbrain_graph.py View File

@@ -19,10 +19,11 @@ import numpy as np
from ...utils.comp_graph_tools import set_priority_to_id as _set_priority_to_id from ...utils.comp_graph_tools import set_priority_to_id as _set_priority_to_id
from .. import _imperative_rt from .. import _imperative_rt
from .._imperative_rt import GraphOptimizeOptions from .._imperative_rt import GraphOptimizeOptions
from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode
from .._imperative_rt.ops import BackwardGraph from .._imperative_rt.ops import BackwardGraph
from .._wrap import device as as_device from .._wrap import device as as_device
from ..ops.builtin import OpDef from ..ops.builtin import OpDef
from .core import OpBase, TensorBase, apply
from .core import OpBase, TensorBase




class Graph(_imperative_rt.ComputingGraph): class Graph(_imperative_rt.ComputingGraph):
@@ -269,9 +270,8 @@ def optimize_for_inference(dest_vars, **kwargs):
if kwargs: if kwargs:
raise ValueError("unknown options: %s" % list(kwargs)) raise ValueError("unknown options: %s" % list(kwargs))


res_vars = _imperative_rt.optimize_for_inference(
[i._node for i in dest_vars], inference_options
)
dest_vars = [var._node for var in dest_vars]
res_vars = _imperative_rt.optimize_for_inference(dest_vars, inference_options)
return [VarNode(i) for i in res_vars] return [VarNode(i) for i in res_vars]




@@ -437,19 +437,25 @@ def _unwrap(x):
return x return x




@apply.register()
def _(op: OpDef, *args: VarNode):
def apply_normal_op(op: OpDef, *args: VarNode):
outputs = _imperative_rt.invoke_op(op, _unwrap(args)) outputs = _imperative_rt.invoke_op(op, _unwrap(args))
return _wrap(outputs) return _wrap(outputs)




@apply.register()
def _(op: BackwardGraph, *args: VarNode):
def apply_backward_varnode(op: BackwardGraph, *args: VarNode):
assert args assert args
graph = args[0].graph graph = args[0].graph
return BackwardGraph.interpret(
op, lambda op, args: apply(op, *args), graph._make_const_for_backward, args
outputs = op.interpret(
op,
lambda op, args: apply_normal_op(op, *args),
graph._make_const_for_backward,
args,
) )
outputs = [o._node if hasattr(o, "_node") else o for o in outputs]
return outputs


set_cpp_apply_backward_varnode(apply_backward_varnode)




def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None): def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None):


+ 19
- 1
imperative/python/megengine/jit/__init__.py View File

@@ -6,5 +6,23 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..core._imperative_rt.core2 import (
set_cpp_apply_compiled_mode,
set_cpp_apply_const_compiled_mode,
set_cpp_apply_const_with_tracing,
set_cpp_apply_with_tracing,
)
from .sublinear_memory_config import SublinearMemoryConfig from .sublinear_memory_config import SublinearMemoryConfig
from .tracing import exclude_from_trace, trace
from .tracing import (
apply_compiled_mode,
apply_const_compiled_mode,
apply_const_with_tracing,
apply_with_tracing,
exclude_from_trace,
trace,
)

set_cpp_apply_with_tracing(apply_with_tracing)
set_cpp_apply_const_with_tracing(apply_const_with_tracing)
set_cpp_apply_compiled_mode(apply_compiled_mode)
set_cpp_apply_const_compiled_mode(apply_const_compiled_mode)

+ 175
- 302
imperative/python/megengine/jit/tracing.py View File

@@ -18,8 +18,20 @@ import weakref


import numpy as np import numpy as np


from ..core._imperative_rt import GraphProfiler
from ..core._imperative_rt.core2 import Tensor
from ..core._imperative_rt import GraphProfiler, common, put
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import (
TensorWeakRef,
apply,
call_level,
set_compiled,
set_symbolic,
set_tracing,
skip_tracing,
unset_compiled,
unset_symbolic,
unset_tracing,
)
from ..core._imperative_rt.ops import ( from ..core._imperative_rt.ops import (
CollectiveComm, CollectiveComm,
GaussianRNG, GaussianRNG,
@@ -29,10 +41,9 @@ from ..core._imperative_rt.ops import (
) )
from ..core._trace_option import set_symbolic_shape from ..core._trace_option import set_symbolic_shape
from ..core._wrap import device as as_device from ..core._wrap import device as as_device
from ..core.ops.builtin import OpDef
from ..core.ops.special import Const from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G from ..core.tensor import megbrain_graph as G
from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor
from .sublinear_memory_config import SublinearMemoryConfig from .sublinear_memory_config import SublinearMemoryConfig




@@ -45,7 +56,6 @@ class TraceMismatchError(RuntimeError):




active_trace = None active_trace = None
skip_tracing = False




def is_tracing(): def is_tracing():
@@ -63,11 +73,13 @@ def exclude_from_trace():
return return
try: try:
skip_tracing = True skip_tracing = True
unset_tracing()
if active_trace is not None: if active_trace is not None:
active_trace._begin_excluded_region() active_trace._begin_excluded_region()
yield yield
finally: finally:
skip_tracing = False skip_tracing = False
set_tracing()




class TensorInfo: class TensorInfo:
@@ -75,9 +87,6 @@ class TensorInfo:
# collected attributes # collected attributes
"external", "external",
"exported", "exported",
"data_read",
"shape_read",
"value_read",
"device", "device",
"dtype", "dtype",
"shape", "shape",
@@ -93,9 +102,6 @@ class TensorInfo:


def __init__(self): def __init__(self):
self.exported = None self.exported = None
self.data_read = None
self.shape_read = None
self.value_read = None
self.bound_data = None self.bound_data = None


self.data_setter = None self.data_setter = None
@@ -147,6 +153,8 @@ class trace:
self._profiler = None self._profiler = None
self._graph_opt_level = opt_level self._graph_opt_level = opt_level
self._symbolic_shape = symbolic_shape self._symbolic_shape = symbolic_shape
self._handle2tensors = {}
self._handle2compiledtensors = {}


self._reset() self._reset()


@@ -158,9 +166,9 @@ class trace:
self._graph = None self._graph = None
self._need_reset_nodes = None self._need_reset_nodes = None
self._lazy_eval_graph = None self._lazy_eval_graph = None
self._lazy_eval_tensors = weakref.WeakSet()
self._lazy_eval_tensors = set()
self._lazy_eval_links = None self._lazy_eval_links = None
self._active_tensors = weakref.WeakSet()
self._active_tensors = set()
self._tensor_remaps = None self._tensor_remaps = None
self._inputs_to_restore = None self._inputs_to_restore = None
self._arg_bindings = None self._arg_bindings = None
@@ -220,66 +228,72 @@ class trace:
) )
info.data_setter.set_value(x._dev_tensor()) info.data_setter.set_value(x._dev_tensor())
else: else:
if x.__class__ is not CompiledTensorProxy:
if x not in self._tensor_remaps:
raise TraceMismatchError(
"unexpected capture: trying to use an external tensor as "
"input, but that input was an internal tensor last time"
)
else:
x = self._tensor_remaps[x]
if x._CompiledTensorProxy__handle != h:
raise TraceMismatchError(
"mis-wiring: input edge to an data flow "
"graph node is different from last time"
)
pass
# if x.__class__ is not CompiledTensorProxy:
# if x not in self._tensor_remaps:
# raise TraceMismatchError(
# "unexpected capture: trying to use an external tensor as "
# "input, but that input was an internal tensor last time"
# )
# else:
# x = self._tensor_remaps[x]
# if x._CompiledTensorProxy__handle != h:
# raise TraceMismatchError(
# "mis-wiring: input edge to an data flow "
# "graph node is different from last time"
# )


self._pc += 1 self._pc += 1
outputs = tuple([CompiledTensorProxy(h) for h in ohandles])
self._active_tensors.update(outputs)
for h in ohandles:
t = CompiledTensorProxy(h)
t._dev_tensor()
self._handle2compiledtensors[h] = t
outputs = [self._handle2tensors[h] for h in ohandles]
self._active_tensors.update([TensorWeakRef(o) for o in outputs])
return outputs return outputs


def _apply_const(self, op, args):
def _apply_const(self, value, dtype, device):
assert not self._untraced assert not self._untraced
# check against trace # check against trace
if self._pc >= len(self._seq): if self._pc >= len(self._seq):
raise TraceMismatchError("trace should end here, but more op observed") raise TraceMismatchError("trace should end here, but more op observed")
record = self._seq[self._pc] record = self._seq[self._pc]
op_, ihandles, ohandles = record op_, ihandles, ohandles = record
assert isinstance(op_, Const)

eq = op_.value == op.value
if not isinstance(eq, bool):
eq = all(eq)
if not eq:
raise TraceMismatchError(
"const tensor violated: got a different tensor this time"
)
assert isinstance(op_, str) and op_ == "Const"

# TODO : assert on const value
# eq = value == self._tinfo[ohandles[0]].bound_data.numpy()
# if not isinstance(eq, bool):
# eq = all(eq)
# if not eq:
# raise TraceMismatchError(
# "const tensor violated: got a different tensor this time"
# )


self._pc += 1 self._pc += 1
(h,) = ohandles (h,) = ohandles
outputs = tuple([self._tinfo[h].bound_data])
outputs = [self._tinfo[h].bound_data]
return outputs return outputs


def _record_op(self, op, inputs, outputs): def _record_op(self, op, inputs, outputs):
if skip_tracing: if skip_tracing:
for x in inputs: for x in inputs:
h = getattr(x, "_TraceMixin__handle", None)
if h is not None:
self._tinfo[h].data_read = True
h = getattr(x, "mixin_handle", -1)
if h >= 0:
x.data_read = True
return return


ihandles = [] ihandles = []
for x in inputs: for x in inputs:
h = getattr(x, "_TraceMixin__handle", None)
if h is None or (not self._capture_as_const and self._tinfo[h].exported):
h = getattr(x, "mixin_handle", -1)
if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
h, info = self._new_handle() h, info = self._new_handle()
info.external = True info.external = True
info.device = x.device info.device = x.device
info.dtype = x.dtype info.dtype = x.dtype
info.shape = x.shape info.shape = x.shape
if self._capture_as_const: if self._capture_as_const:
info.bound_data = x
info.bound_data = RawTensor(x.numpy(), x.dtype, x.device, False)


ihandles.append(h) ihandles.append(h)


@@ -288,17 +302,18 @@ class trace:
h, info = self._new_handle() h, info = self._new_handle()
ohandles.append(h) ohandles.append(h)
info.external = False info.external = False
TraceMixin._TraceMixin__inject(x, h)
x.mixin_handle = h
self._handle2tensors[h] = x


self._seq.append((op, tuple(ihandles), tuple(ohandles))) self._seq.append((op, tuple(ihandles), tuple(ohandles)))
self._active_tensors.update(outputs)
self._active_tensors.update([TensorWeakRef(o) for o in outputs])


def _record_const(self, op, outputs):
def _record_const(self, outputs):
if skip_tracing: if skip_tracing:
(x,) = outputs (x,) = outputs
h = getattr(x, "_TraceMixin__handle", None)
if h is not None:
self._tinfo[h].data_read = True
h = getattr(x, "mixin_handle", -1)
if h >= 0:
x.data_read = True
return return


(x,) = outputs (x,) = outputs
@@ -310,8 +325,9 @@ class trace:
info.shape = x.shape info.shape = x.shape
info.bound_data = x info.bound_data = x
info.is_const = True info.is_const = True
TraceMixin._TraceMixin__inject(x, h)
self._seq.append((op, tuple(), tuple(ohandles)))
x.mixin_handle = h
self._handle2tensors[h] = x
self._seq.append(("Const", tuple(), tuple(ohandles)))


def _set_active(self, active: bool): def _set_active(self, active: bool):
global active_trace global active_trace
@@ -324,11 +340,8 @@ class trace:
active_trace = None active_trace = None


def _init_trace(self, symbolic: bool): def _init_trace(self, symbolic: bool):
apply.enable(apply_with_tracing)
apply.enable(apply_const_with_tracing)
if symbolic: if symbolic:
apply.enable(apply_symbolic_mode)
apply.enable(apply_const_symbolic_mode)
set_symbolic()
self._lazy_eval_graph = G.Graph() self._lazy_eval_graph = G.Graph()
self._apply_graph_options(self._lazy_eval_graph) self._apply_graph_options(self._lazy_eval_graph)
self._lazy_eval_links = () self._lazy_eval_links = ()
@@ -339,10 +352,7 @@ class trace:
return escaped_tensors return escaped_tensors


def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links): def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
readers = [
G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
for x in lazy_eval_tensors
]
readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors]
self._apply_graph_options(lazy_eval_graph) self._apply_graph_options(lazy_eval_graph)
# FIXME # FIXME
if self._graph_opt_level is not None: if self._graph_opt_level is not None:
@@ -353,20 +363,22 @@ class trace:
lazy_eval_graph.compile(*lazy_eval_links, *readers) lazy_eval_graph.compile(*lazy_eval_links, *readers)
lazy_eval_graph() lazy_eval_graph()
for r, x in zip(readers, lazy_eval_tensors): for r, x in zip(readers, lazy_eval_tensors):
assign_raw_tensor(x, as_raw_tensor(r.op.get_value()))
x()._handle = RawTensor(r.op.get_value())._handle


@contextlib.contextmanager @contextlib.contextmanager
def _setup(self): def _setup(self):
interrupted = False interrupted = False


def do_enter(): def do_enter():
set_tracing()
self._save_symbolic_shape = set_symbolic_shape(self._symbolic_shape) self._save_symbolic_shape = set_symbolic_shape(self._symbolic_shape)
self._set_active(True) self._set_active(True)
if self._untraced: if self._untraced:
self._init_trace(self._symbolic) self._init_trace(self._symbolic)
else: else:
apply.enable(apply_compiled_mode)
apply.enable(apply_const_compiled_mode)
# disable symbolic mode
unset_symbolic()
set_compiled()
if self._graph is None: if self._graph is None:
self._compile() self._compile()
self._graph.execute() self._graph.execute()
@@ -375,12 +387,12 @@ class trace:
escaped_tensors = self._take_escaped_tensors() escaped_tensors = self._take_escaped_tensors()
if self._untraced: if self._untraced:
for x in escaped_tensors: for x in escaped_tensors:
info = self._tinfo[x._TraceMixin__handle]
info.data_read = True
x._TraceMixin__restore()
info = self._tinfo[x().mixin_handle]
x().data_read = True
x().mixin_handle = -1
if self._inputs_to_restore: if self._inputs_to_restore:
for x in self._inputs_to_restore: for x in self._inputs_to_restore:
x._TraceMixin__restore()
x.mixin_handle = -1
if self._symbolic and ( if self._symbolic and (
self._lazy_eval_tensors or self._lazy_eval_links self._lazy_eval_tensors or self._lazy_eval_links
): ):
@@ -399,7 +411,7 @@ class trace:
if self._pc == len(self._seq): if self._pc == len(self._seq):
for x in escaped_tensors: for x in escaped_tensors:
try: try:
assign_raw_tensor(x, as_raw_tensor(x._dev_tensor()))
assign_raw_tensor(x(), RawTensor(x()._dev_tensor()))
except TraceMismatchError: except TraceMismatchError:
# TraceMismatchError thrown in do_exit # TraceMismatchError thrown in do_exit
pass pass
@@ -409,22 +421,20 @@ class trace:
# reset status # reset status
self._pc = 0 self._pc = 0
self._tensor_remaps = None self._tensor_remaps = None
apply.disable(apply_with_tracing)
apply.disable(apply_const_with_tracing)
apply.disable(apply_symbolic_mode)
apply.disable(apply_const_symbolic_mode)
apply.disable(apply_compiled_mode)
apply.disable(apply_const_compiled_mode)
self._set_active(False) self._set_active(False)
# Restore global variable
set_symbolic_shape(self._save_symbolic_shape) set_symbolic_shape(self._save_symbolic_shape)
unset_compiled()
unset_symbolic()
unset_tracing()


def do_exit(): def do_exit():
unset_tracing()
if not self._untraced and self._pc != len(self._seq): if not self._untraced and self._pc != len(self._seq):
raise TraceMismatchError("premature end") raise TraceMismatchError("premature end")
if not self._symbolic or not self._untraced: if not self._symbolic or not self._untraced:
for x in self._active_tensors: for x in self._active_tensors:
x._dev_tensor()
x()._dev_tensor()
x().mixin_handle = -1


try: try:
do_enter() do_enter()
@@ -447,9 +457,9 @@ class trace:
# conditionally reading a compiled tensor in excluded region # conditionally reading a compiled tensor in excluded region
# is permitted, so we have to assume every tensor might be read # is permitted, so we have to assume every tensor might be read
for x in self._active_tensors: for x in self._active_tensors:
info = self._tinfo[x._TraceMixin__handle]
info = self._tinfo[x().mixin_handle]
info.exported = True info.exported = True
info.data_read = True
x().data_read = True


def _apply_graph_options(self, graph): def _apply_graph_options(self, graph):


@@ -503,7 +513,7 @@ class trace:
in_out_links += opnode.outputs[1:] in_out_links += opnode.outputs[1:]


for op, ihandles, ohandles in self._seq: for op, ihandles, ohandles in self._seq:
if isinstance(op, Const):
if isinstance(op, str) and op == "Const":
assert len(ihandles) == 0 assert len(ihandles) == 0
(h,) = ohandles (h,) = ohandles
info = self._tinfo[h] info = self._tinfo[h]
@@ -554,7 +564,10 @@ class trace:
io_links = (info.varnode,) io_links = (info.varnode,)


ivars.append(info.varnode) ivars.append(info.varnode)

ivars = [RawTensor(ivar) for ivar in ivars]
ovars = apply(op, *ivars) ovars = apply(op, *ivars)
ovars = [x._varnode for x in ovars]
if require_links and len(ovars) > 0: if require_links and len(ovars) > 0:
io_links = (ovars[0],) io_links = (ovars[0],)
assert len(ovars) == len(ohandles) assert len(ovars) == len(ohandles)
@@ -568,7 +581,8 @@ class trace:
readers.append(opnode.outputs[0]) readers.append(opnode.outputs[0])
in_out_links = opnode.outputs in_out_links = opnode.outputs


if info.data_read:
x = self._handle2tensors[h]
if x.data_read:
# Shape can be obtained from data so doesn't need its own # Shape can be obtained from data so doesn't need its own
# output node. On the other hand, value is read separately # output node. On the other hand, value is read separately
# to leverage eager h2d copy # to leverage eager h2d copy
@@ -581,6 +595,7 @@ class trace:
if info.shape_read: if info.shape_read:
opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links) opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links)
add_reader(opnode) add_reader(opnode)

# FIXME # FIXME
if self._graph_opt_level is not None: if self._graph_opt_level is not None:
graph.options.graph_opt_level = self._graph_opt_level graph.options.graph_opt_level = self._graph_opt_level
@@ -593,18 +608,6 @@ class trace:
for opnode in self._need_reset_nodes: for opnode in self._need_reset_nodes:
opnode.reset() opnode.reset()


def _require_shape(self, handle):
info = self._tinfo[handle]
info.shape_read = True

def _require_value(self, handle):
info = self._tinfo[handle]
info.value_read = True

def _require_data(self, handle):
info = self._tinfo[handle]
info.data_read = True

def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
if is_tracing(): if is_tracing():
return self.__wrapped__(*args, **kwargs) return self.__wrapped__(*args, **kwargs)
@@ -728,8 +731,9 @@ class trace:
dtype=info.dtype, device=dumped_device, shape=info.shape or (1,), name=k dtype=info.dtype, device=dumped_device, shape=info.shape or (1,), name=k
) )


set_tracing()
for op, ihandles, ohandles in self._seq: for op, ihandles, ohandles in self._seq:
if isinstance(op, Const):
if isinstance(op, str) and op == "Const":
assert len(ihandles) == 0 assert len(ihandles) == 0
(h,) = ohandles (h,) = ohandles
info = self._tinfo[h] info = self._tinfo[h]
@@ -750,7 +754,9 @@ class trace:
info.bound_data.numpy(), dtype=info.dtype, device=dumped_device info.bound_data.numpy(), dtype=info.dtype, device=dumped_device
) )
ivars.append(h2v[h]) ivars.append(h2v[h])
ivars = [RawTensor(ivar) for ivar in ivars]
ovars = apply(op, *ivars) ovars = apply(op, *ivars)
ovars = [x._varnode for x in ovars]
assert len(ovars) == len(ohandles) assert len(ovars) == len(ohandles)
h2v.update(zip(ohandles, ovars)) h2v.update(zip(ohandles, ovars))


@@ -761,6 +767,7 @@ class trace:
v.name = output_names[i] v.name = output_names[i]
dest_vars.append(v) dest_vars.append(v)


dest_vars = [G.VarNode(var) for var in dest_vars]
if optimize_for_inference: if optimize_for_inference:
dest_vars = G.optimize_for_inference(dest_vars, **kwargs) dest_vars = G.optimize_for_inference(dest_vars, **kwargs)


@@ -782,15 +789,15 @@ class trace:
info.external = False info.external = False
info.device = x.device info.device = x.device
info.dtype = x.dtype info.dtype = x.dtype
info.shape = x.shape
TraceMixin._TraceMixin__inject(x, h)
info.shape = x.numpy().shape
x.mixin_handle = h
self._handle2tensors[h] = x
self._inputs_to_restore.append(x) self._inputs_to_restore.append(x)
return h return h


self._arg_bindings = [] self._arg_bindings = []
for i, x in enumerate(args): for i, x in enumerate(args):
x = find_raw_tensor(x)
if x is None:
if not isinstance(x, RawTensor):
raise TypeError( raise TypeError(
"positional arguments should all be tensor " "positional arguments should all be tensor "
"but args[%d] cannot be recognized as one" % i "but args[%d] cannot be recognized as one" % i
@@ -799,8 +806,7 @@ class trace:


self._kwarg_bindings = {} self._kwarg_bindings = {}
for k, x in kwargs.items(): for k, x in kwargs.items():
x = find_raw_tensor(x)
if x is not None:
if isinstance(x, RawTensor):
self._kwarg_bindings[k] = record_input(x) self._kwarg_bindings[k] = record_input(x)
else: else:
if len(args) != len(self._arg_bindings): if len(args) != len(self._arg_bindings):
@@ -809,8 +815,7 @@ class trace:
self._tensor_remaps = {} self._tensor_remaps = {}


for i, (h, x) in enumerate(zip(self._arg_bindings, args)): for i, (h, x) in enumerate(zip(self._arg_bindings, args)):
x = find_raw_tensor(x)
if x is None:
if not isinstance(x, RawTensor):
raise TypeError( raise TypeError(
"positional arguments should all be tensor " "positional arguments should all be tensor "
"but args[%d] cannot be recognized as one" % i "but args[%d] cannot be recognized as one" % i
@@ -825,8 +830,7 @@ class trace:


kwargs_tensors = {} kwargs_tensors = {}
for k, x in kwargs.items(): for k, x in kwargs.items():
x = find_raw_tensor(x)
if x is not None:
if isinstance(x, RawTensor):
kwargs_tensors[k] = x kwargs_tensors[k] = x
if set(kwargs_tensors) != set(self._kwarg_bindings): if set(kwargs_tensors) != set(self._kwarg_bindings):
too_many = set(kwargs_tensors) - set(self._kwarg_bindings) too_many = set(kwargs_tensors) - set(self._kwarg_bindings)
@@ -877,18 +881,17 @@ class trace:
self._output_bindings = [] self._output_bindings = []


for i, x in enumerate(outputs): for i, x in enumerate(outputs):
x = find_raw_tensor(x)
if x is None:
if not isinstance(x, RawTensor):
raise TypeError("every item of return value should be tensor") raise TypeError("every item of return value should be tensor")
if self._untraced: if self._untraced:
if not isinstance(x, TraceMixin):
h = x.mixin_handle
if h < 0:
raise RuntimeError("output is not computed from inputs") raise RuntimeError("output is not computed from inputs")
h = x._TraceMixin__handle
self._output_bindings.append(h) self._output_bindings.append(h)
else: else:
if not isinstance(x, CompiledTensorProxy):
h = x.mixin_handle
if h not in self._handle2compiledtensors:
raise RuntimeError("output is not computed from inputs") raise RuntimeError("output is not computed from inputs")
h = x._CompiledTensorProxy__handle
if h != self._output_bindings[i]: if h != self._output_bindings[i]:
raise TraceMismatchError( raise TraceMismatchError(
"retval[%s] is a different tensor than last time" "retval[%s] is a different tensor than last time"
@@ -912,7 +915,7 @@ class trace:
) )




class CompiledTensorProxy(RawTensor):
class CompiledTensorProxy:
""" """
Duck-typed RawTensor Duck-typed RawTensor
""" """
@@ -924,6 +927,8 @@ class CompiledTensorProxy(RawTensor):
self.__shape = None self.__shape = None
self.__data = None self.__data = None
self.__value = None self.__value = None
self.__tensor = active_trace._handle2tensors[handle]
self.__tensor.mixin_handle = handle


@property @property
def dtype(self): def dtype(self):
@@ -938,19 +943,19 @@ class CompiledTensorProxy(RawTensor):
if self._isscalar: if self._isscalar:
return () return ()
if self.__shape is None: if self.__shape is None:
if self.__info.shape_read:
if self.__tensor.shape_read:
self.__shape = self.__info.shape_reader.get_value().shape self.__shape = self.__info.shape_reader.get_value().shape
elif self.__info.data_read:
self.__shape = self._dev_tensor().shape
elif self.__tensor.data_read:
self.__shape = self.__tensor._dev_tensor().shape
else: else:
raise TraceMismatchError("shape of this tensor is not read in trace") raise TraceMismatchError("shape of this tensor is not read in trace")
return self.__shape return self.__shape


def numpy(self): def numpy(self):
if self.__value is None: if self.__value is None:
if self.__info.value_read:
if self.__tensor.value_read:
self.__value = self.__info.value_reader.get_value() self.__value = self.__info.value_reader.get_value()
elif self.__info.data_read:
elif self.__tensor.data_read:
self.__value = self._dev_tensor().numpy() self.__value = self._dev_tensor().numpy()
else: else:
raise TraceMismatchError("value of this tensor is not read in trace") raise TraceMismatchError("value of this tensor is not read in trace")
@@ -960,9 +965,11 @@ class CompiledTensorProxy(RawTensor):


def _dev_tensor(self): def _dev_tensor(self):
if self.__data is None: if self.__data is None:
if not self.__info.data_read:
if not self.__tensor.data_read:
raise TraceMismatchError("raw data of this tensor is not read in trace") raise TraceMismatchError("raw data of this tensor is not read in trace")
self.__data = self.__info.data_reader.get_value() self.__data = self.__info.data_reader.get_value()
self.__tensor._reset(RawTensor(self.__data))
self.__tensor.mixin_handle = self.__handle
return self.__data return self.__data


def _drop(self): def _drop(self):
@@ -975,132 +982,31 @@ class CompiledTensorProxy(RawTensor):
return return


def __del__(self): def __del__(self):
if self.__info.shape_read and self.__shape is not None:
if self.__tensor.shape_read and self.__shape is not None:
self.__info.shape_reader.drop_value() self.__info.shape_reader.drop_value()
if self.__info.value_read and self.__value is not None:
self.__info.value_reader.drop_value()
if self.__info.data_read and self.__data is not None:
# if self.__tensor.value_read and self.__value is not None:
# self.__info.value_reader.drop_value()
if self.__tensor.data_read and self.__data is not None:
self.__info.data_reader.drop_value() self.__info.data_reader.drop_value()




class LazyEvalTensor(RawTensor):
def __init__(self, varnode, isscalar=False):
super().__init__()
self.__varnode = varnode
self._isscalar = isscalar

@property
def dtype(self):
return self.__varnode.dtype

@property
def device(self):
return self.__varnode.device

@property
def shape(self):
if self._isscalar:
return ()
return self.__varnode.shape

def numpy(self):
ret = self.__varnode.value
if self._isscalar:
ret = ret.squeeze()
return ret

def _drop(self):
return

def _swap_in(self):
return

def _swap_out(self):
return

def _dev_tensor(self):
raise RuntimeError("cannot access data during symbolic tracing")


class TraceMixin:
__subclass_cache = {}

def __inject(self, handle):
cache = __class__.__subclass_cache
cls = self.__class__
subcls = cache.get(cls)
if subcls is None:
subcls = cache[cls] = type("Traced" + cls.__name__, (__class__, cls), {})
self.__class__ = subcls
self.__handle = handle
self.__cls = cls
return self

def __restore(self):
cls = self.__cls
del self.__handle
del self.__cls
self.__class__ = cls
return self

@property
def shape(self):
if not skip_tracing:
active_trace._require_shape(self.__handle)
return super().shape

def numpy(self):
if not skip_tracing:
active_trace._require_value(self.__handle)
return super().numpy()

def _dev_tensor(self):
if not skip_tracing:
active_trace._require_data(self.__handle)
return super()._dev_tensor()

def _drop(self):
return

def _swap_in(self):
return

def _swap_out(self):
return


class TracedRawTensor(TraceMixin, RawTensor):
pass


class TracedLazyTensor(TraceMixin, LazyEvalTensor):
pass


def assign_raw_tensor(lhs, rhs): def assign_raw_tensor(lhs, rhs):
handle = rhs._handle
# Keep isscalar of lhs
isscalar = lhs._isscalar
rhs.__dict__.clear()
lhs.__dict__.clear()
lhs.__class__ = RawTensor
lhs.__init__(handle, isscalar=isscalar)
lhs.__init__(rhs)




# this hook turns RawTensor into LazyEvalTensor
@apply.register()
# this hook turns RawTensor into LazyEvalTensor(varnode)
def apply_symbolic_mode(op: OpDef, *args: RawTensor): def apply_symbolic_mode(op: OpDef, *args: RawTensor):
graph = active_trace._lazy_eval_graph graph = active_trace._lazy_eval_graph
ivars = [] ivars = []
for x in args: for x in args:
var = getattr(x, "_LazyEvalTensor__varnode", None)
var = getattr(x, "_varnode", None)
if var: if var:
ivars.append(var) ivars.append(var)
else: else:
data_setter = G.InputNode( data_setter = G.InputNode(
device=x.device, device=x.device,
dtype=x.dtype, dtype=x.dtype,
shape=x.shape or (1,),
shape=x.numpy().shape or (1,),
graph=graph, graph=graph,
use_static_shape=True, use_static_shape=True,
) )
@@ -1119,108 +1025,75 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
ivars[0] = opnode.outputs[0] ivars[0] = opnode.outputs[0]
active_trace._lazy_eval_links = (ivars[0],) active_trace._lazy_eval_links = (ivars[0],)


ovars = apply(op, *ivars)
ivars = [
RawTensor(ivar._node) if hasattr(ivar, "_node") else RawTensor(ivar)
for ivar in ivars
]
unset_symbolic()
outputs = apply(op, *ivars)
set_symbolic()


if require_links: if require_links:
active_trace._lazy_eval_links = (ovars[0],)
active_trace._lazy_eval_links = (outputs[0]._varnode,)


outputs = [LazyEvalTensor(v) for v in ovars]
active_trace._lazy_eval_tensors.update(outputs)
active_trace._lazy_eval_tensors.update([TensorWeakRef(o) for o in outputs])
return outputs return outputs




apply.disable(apply_symbolic_mode)


@apply.register()
def apply_const_symbolic_mode(op: Const, *args: RawTensor):
def apply_const_symbolic_mode(value, dtype, device):
graph = active_trace._lazy_eval_graph graph = active_trace._lazy_eval_graph
ret = LazyEvalTensor(
graph.make_const(op.value, dtype=op.dtype, device=op.device), isscalar=True
)
active_trace._lazy_eval_tensors.add(ret)
# don't need to unset tracing
# because varnode construction will ignore tracing flag
ret = RawTensor(graph.make_const(value, dtype=dtype, device=device))
active_trace._lazy_eval_tensors.add(TensorWeakRef(ret))
return (ret,) return (ret,)




apply.disable(apply_const_symbolic_mode)


@apply.register()
def apply_compiled_mode(op: OpDef, *args: RawTensor): def apply_compiled_mode(op: OpDef, *args: RawTensor):
if skip_tracing: if skip_tracing:
args = [ args = [
as_raw_tensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
for x in args for x in args
] ]
return apply.super(op, *args)
unset_tracing()
ret = apply(op, *args)
set_tracing()
return ret
return active_trace._apply_op(op, args) return active_trace._apply_op(op, args)




apply.disable(apply_compiled_mode)


@apply.register()
def apply_const_compiled_mode(op: Const, *args: RawTensor):
def apply_const_compiled_mode(value, dtype, device, is_const):
if skip_tracing: if skip_tracing:
args = [ args = [
as_raw_tensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
for x in args for x in args
] ]
return apply.super(op, *args)
return active_trace._apply_const(op, args)
apply.disable(apply_const_compiled_mode)
unset_tracing()
ret = RawTensor(value, dtype, device, False)
set_tracing()
return ret
return active_trace._apply_const(value, dtype, device)




# this hook injects TraceMixin # this hook injects TraceMixin
@apply.register()
def apply_with_tracing(op: OpDef, *args: RawTensor): def apply_with_tracing(op: OpDef, *args: RawTensor):
outputs = apply.super(op, *args)
active_trace._record_op(op, args, outputs)
return outputs


apply.disable(apply_with_tracing)


@apply.register()
def apply_const_with_tracing(op: Const, *args: RawTensor):
outputs = apply.super(op, *args)
active_trace._record_const(op, outputs)
return outputs


apply.disable(apply_const_with_tracing)


class BrokenRawTensor(RawTensor):
def __getattribute__(self, _):
raise RuntimeError("broken due to misuse of tracing")

def __setattr__(self, *_):
raise RuntimeError("broken due to misuse of tracing")


@functools.singledispatch
def find_raw_tensor(x):
return None


@find_raw_tensor.register(RawTensor)
def _(x):
return x

if active_trace._symbolic:
outputs = apply_symbolic_mode(op, *args)
else:
unset_tracing()
outputs = apply(op, *args)
set_tracing()


@find_raw_tensor.register(TensorWrapperBase)
def _(x):
x = getattr(x, "__wrapped__", None)
if x is not None:
return find_raw_tensor(x)
active_trace._record_op(op, args, outputs)
return list(outputs)




@find_raw_tensor.register(Tensor)
def _(x):
x = getattr(x, "_data", None)
if x is not None:
return find_raw_tensor(x)
def apply_const_with_tracing(value, dtype, device, is_const):
if active_trace._symbolic:
outputs = apply_const_symbolic_mode(value, dtype, device)
else:
unset_tracing()
outputs = (RawTensor(value, dtype, device, False),)
set_tracing()
active_trace._record_const(outputs)
return list(outputs)

+ 3
- 2
imperative/python/megengine/tensor.py View File

@@ -28,7 +28,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
dmap_callback = None dmap_callback = None
q_dict = {"mode": None, "scale": None, "zero_point": None} q_dict = {"mode": None, "scale": None, "zero_point": None}


def __new__(cls, data, dtype=None, device=None):
def __new__(cls, data, dtype=None, device=None, is_const=False):
if device is None: if device is None:
cn = get_default_device() cn = get_default_device()
elif isinstance(device, str): elif isinstance(device, str):
@@ -40,6 +40,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
assert isinstance(device, CompNode) assert isinstance(device, CompNode)
cn = device cn = device


# import pdb; pdb.set_trace()
if isinstance(data, _Tensor): if isinstance(data, _Tensor):
obj = _Tensor.__new__(cls, data) obj = _Tensor.__new__(cls, data)
else: else:
@@ -47,7 +48,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
if 0 in data.strides: if 0 in data.strides:
data = data.squeeze().reshape(data.shape) data = data.squeeze().reshape(data.shape)


obj = _Tensor.__new__(cls, data, dtype, cn)
obj = _Tensor.__new__(cls, data, dtype, cn, is_const)
return obj return obj


@property @property


+ 6
- 1
imperative/python/src/grad.cpp View File

@@ -296,7 +296,9 @@ void accum_grad(std::shared_ptr<Tensor>& grad, std::shared_ptr<Tensor>&& delta)
Tensor* args[2] = {grad.get(), delta.get()}; Tensor* args[2] = {grad.get(), delta.get()};
ctx.args = args; ctx.args = args;
ctx.flags = grad->m_flags | delta->m_flags; ctx.flags = grad->m_flags | delta->m_flags;

if (is_tracing) {
ctx.flags |= Tensor::Flags::TRACE;
}
grad = apply(ctx)[0]; grad = apply(ctx)[0];
} }


@@ -354,6 +356,9 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
} }
ctx.args = args; ctx.args = args;


if (is_tracing)
ctx.flags |= Tensor::Flags::TRACE;

auto grads = apply(ctx); auto grads = apply(ctx);


size_t j = 0; size_t j = 0;


+ 221
- 28
imperative/python/src/tensor.cpp View File

@@ -11,8 +11,10 @@


#include "./tensor.h" #include "./tensor.h"
#include "./grad.h" #include "./grad.h"
#include "./trace.h"
#include "./common.h" #include "./common.h"
#include "./numpy_dtypes.h" #include "./numpy_dtypes.h"
#include "./graph_rt.h"


#include <pybind11/numpy.h> #include <pybind11/numpy.h>
#include <pybind11/operators.h> #include <pybind11/operators.h>
@@ -23,6 +25,47 @@ namespace mgb::imperative::python {


std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py; std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py;


py::object cpp_apply_with_tracing, cpp_apply_const_with_tracing,
cpp_apply_compiled_mode, cpp_apply_const_compiled_mode;

py::object cpp_apply_backward_varnode;

#define REGISTE_APPLY_FUNC(mode) \
void set_##mode(py::object pyf) { \
mode = pybind11::reinterpret_steal<py::object>(pyf); \
}

REGISTE_APPLY_FUNC(cpp_apply_with_tracing)
REGISTE_APPLY_FUNC(cpp_apply_const_with_tracing)
REGISTE_APPLY_FUNC(cpp_apply_compiled_mode)
REGISTE_APPLY_FUNC(cpp_apply_const_compiled_mode)
REGISTE_APPLY_FUNC(cpp_apply_backward_varnode)

#undef REGISTE_APPLY_FUNC

bool is_tracing = false;
bool is_symbolic = false;
bool is_compiled = false;

int64_t call_level = 0;


#define SET_UNSET_PROP(mode) \
void set_##mode() { \
is_##mode = true; \
} \
void unset_##mode() { \
is_##mode = false; \
} \

SET_UNSET_PROP(tracing)
SET_UNSET_PROP(symbolic)
SET_UNSET_PROP(compiled)

#undef SET_UNSET_PROP

bool skip_tracing = false;

apply_result_t apply(ApplyContext& ctx) { apply_result_t apply(ApplyContext& ctx) {
// emulating scalar should be put to specific op's apply, e.g., // emulating scalar should be put to specific op's apply, e.g.,
// elementwise, reduce, typecvt. Currently it's still handled at python // elementwise, reduce, typecvt. Currently it's still handled at python
@@ -36,7 +79,7 @@ apply_result_t apply(ApplyContext& ctx) {
} }


if (ctx.flags & Tensor::Flags::TRACE) { if (ctx.flags & Tensor::Flags::TRACE) {
// TODO: trace
return apply_trace(ctx);
} else { } else {
SmallVector<interpreter::Interpreter::Handle> handles(ctx.nargs); SmallVector<interpreter::Interpreter::Handle> handles(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; ++i) { for (size_t i = 0; i < ctx.nargs; ++i) {
@@ -58,7 +101,6 @@ apply_result_t apply(ApplyContext& ctx) {


PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */) { PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */) {
try { try {

// if (kwnames && PyTuple_GET_SIZE(kwnames)) { // if (kwnames && PyTuple_GET_SIZE(kwnames)) {
// PyErr_SetString(PyExc_TypeError, "keyword argument not allowed"); // PyErr_SetString(PyExc_TypeError, "keyword argument not allowed");
// return nullptr; // return nullptr;
@@ -67,6 +109,7 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
PyErr_SetString(PyExc_TypeError, "expect Op"); PyErr_SetString(PyExc_TypeError, "expect Op");
return nullptr; return nullptr;
} }

auto* op = args[0]; auto* op = args[0];


PyTypeObject* pytype = args[1]->ob_type; PyTypeObject* pytype = args[1]->ob_type;
@@ -79,18 +122,23 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
SmallVector<Tensor*, 64> tensors(nargs); SmallVector<Tensor*, 64> tensors(nargs);
ctx.args = &tensors[0]; ctx.args = &tensors[0];
ctx.nargs = nargs; ctx.nargs = nargs;
if (strstr(op->ob_type->tp_name, "BackwardGraph")) {
ctx.backward = true;
}


for (size_t i = 0; i < nargs; ++i) { for (size_t i = 0; i < nargs; ++i) {
TensorWrapper* tw = TensorWrapper::cast_safe(args[i]);
if (!tw) {
if (TensorWrapper* tw = TensorWrapper::cast_safe(args[i])) {
auto* t = tensors[i] = tw->m_tensor.get();
ctx.flags |= t->m_flags;
} else {
PyErr_SetString(PyExc_TypeError, "expect Tensor"); PyErr_SetString(PyExc_TypeError, "expect Tensor");
return nullptr; return nullptr;
} }
auto* t = tensors[i] = tw->m_tensor.get();
ctx.flags |= t->m_flags;
} }


// TODO: set TRACE flag
if (is_tracing) {
ctx.flags |= Tensor::Flags::TRACE;
}


auto outputs = apply(ctx); auto outputs = apply(ctx);
size_t nout = outputs.size(); size_t nout = outputs.size();
@@ -99,7 +147,6 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
ret[i] = TensorWrapper::make(pytype, std::move(outputs[i])); ret[i] = TensorWrapper::make(pytype, std::move(outputs[i]));
} }
return ret.release().ptr(); return ret.release().ptr();

} catch (std::exception& e) { } catch (std::exception& e) {
PyErr_SetString(PyExc_RuntimeError, e.what()); PyErr_SetString(PyExc_RuntimeError, e.what());
return nullptr; return nullptr;
@@ -122,36 +169,116 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
} }
m_tensor = t->m_tensor; m_tensor = t->m_tensor;
} else { } else {
if (nargs != 3) {
throw py::type_error("expect 3 arguments");
}
py::detail::loader_life_support life_sup; // required to cast DType
auto data = tup[0].cast<py::array>();
DType dtype = tup[1].cast<DType>();
CompNode cn = tup[2].cast<CompNode>();

interpreter::Interpreter::Handle handle;
constexpr auto size_threshhold = TensorShape::MAX_NDIM;
if (data.size() > size_threshhold) {
handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype));
if (nargs == 1) {
auto arg0 = PyTuple_GetItem(args, 0);
// for lazy_eval_tensor
if (strstr(arg0->ob_type->tp_name, "VarNode")) {
if (PyObject_HasAttrString(arg0, "_node")) {
arg0 = PyObject_GetAttrString(arg0, "_node");
}
m_tensor = std::make_shared<Tensor>(py::handle(arg0).cast<cg::VarNode *>());
} else {
// for DeviceTensorND
if (strstr(arg0->ob_type->tp_name, "DeviceTensorND")) {
auto dv = py::handle(arg0).cast<DeviceTensorND>();
interpreter::Interpreter::Handle handle = interpreter_for_py->put(dv);
m_tensor = std::make_shared<Tensor>(handle);
} else {
throw py::type_error("single argument is not tensor, varnode or devicetensor");
}
}
} else { } else {
HostTensorND ret(cn);
handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype));
}
py::detail::loader_life_support life_sup; // required to cast DType
auto data = tup[0].cast<py::array>();
DType dtype = tup[1].cast<DType>();
CompNode cn = tup[2].cast<CompNode>();
bool is_const = tup[3].cast<bool>();
if (nargs != 4) {
throw py::type_error("expect 3 arguments");
}

// const op
if (is_const && is_tracing) {
py::object pyf;
if (is_compiled) {
pyf = cpp_apply_const_compiled_mode;
} else {
pyf = cpp_apply_const_with_tracing;
}

auto ret = pyf(*tup);
auto py_ret = py::reinterpret_borrow<py::list>(ret);
if (auto* t = cast_safe(py_ret[0].ptr())) {
m_tensor = t->m_tensor;
}
return;
}

interpreter::Interpreter::Handle handle;
constexpr auto size_threshhold = TensorShape::MAX_NDIM;
if (data.size() > size_threshhold) {
handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype));
} else {
HostTensorND ret(cn);
handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype));
}

m_tensor = std::make_shared<Tensor>(handle);


m_tensor = std::make_shared<Tensor>(handle);
if (data.ndim() == 0) {
m_tensor->m_flags |= Tensor::Flags::SCALAR;
if (data.ndim() == 0) {
m_tensor->m_flags |= Tensor::Flags::SCALAR;
}
} }
} }
} }




#define REGISTE_TENSORWRAPPER_FUNC(type, member) \
PyObject* TensorWrapper::member() { \
return py::cast(m_tensor->m_trace_info.member).release().ptr(); \
} \
void TensorWrapper::set_##member(PyObject* dest) { \
auto py_dest = py::reinterpret_borrow<py::object>(dest); \
type real_dest = py_dest.cast<type>(); \
m_tensor->m_trace_info.member = real_dest; \
}

REGISTE_TENSORWRAPPER_FUNC(bool, data_read)
REGISTE_TENSORWRAPPER_FUNC(bool, value_read)
REGISTE_TENSORWRAPPER_FUNC(bool, shape_read)
REGISTE_TENSORWRAPPER_FUNC(int64_t, mixin_handle)

#undef REGISTE_TENSORWRAPPER_FUNC


PyObject* TensorWrapper::handle() {
return py::cast(m_tensor->m_handle).release().ptr();
}


void TensorWrapper::set_handle(PyObject* dest) {
auto py_dest = py::reinterpret_borrow<py::object>(dest);
SharedHandle real_dest = py_dest.cast<SharedHandle>();
auto&& t = std::move(m_tensor->m_handle);
m_tensor->m_handle = std::move(real_dest);
}


PyObject* TensorWrapper::shape() { PyObject* TensorWrapper::shape() {
if (!skip_tracing) {
set_shape_read(py::cast(true). release().ptr());
}
if (m_tensor->m_flags & Tensor::Flags::SCALAR) { if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
return PyTuple_New(0); return PyTuple_New(0);
} }
auto&& shape = m_tensor->shape();

TensorShape shape;
if (m_tensor->m_var) {
shape = m_tensor->m_var->shape();
} else {
shape = m_tensor->shape();
}

if (!shape.ndim) { if (!shape.ndim) {
Py_RETURN_NONE; Py_RETURN_NONE;
} }
@@ -164,16 +291,38 @@ PyObject* TensorWrapper::shape() {




PyObject* TensorWrapper::dtype() { PyObject* TensorWrapper::dtype() {
if (m_tensor->m_var) {
return py::cast(m_tensor->m_var->dtype()).release().ptr();
}
return py::cast(m_tensor->dtype()).release().ptr(); return py::cast(m_tensor->dtype()).release().ptr();
} }




PyObject* TensorWrapper::device() { PyObject* TensorWrapper::device() {
if (m_tensor->m_var) {
return py::cast(m_tensor->m_var->comp_node()).release().ptr();
}
return py::cast(m_tensor->comp_node()).release().ptr(); return py::cast(m_tensor->comp_node()).release().ptr();
} }




PyObject* TensorWrapper::numpy() { PyObject* TensorWrapper::numpy() {
if (!skip_tracing) {
set_value_read(py::cast(true).release().ptr());
}
if (m_tensor->m_handle.get() == nullptr && m_tensor->m_var != nullptr) {
auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
auto&& type = mgr.get_infer_type(m_tensor->m_var);
using InferType = cg::static_infer::InferType;
if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
return nullptr;
}
auto* val = mgr.infer_value_fallible(m_tensor->m_var);
if (!val) {
return nullptr;
}
return py::cast(*val).attr("numpy")().release().ptr();
}
auto&& hv = interpreter_for_py->get_value(m_tensor->m_handle.get()); auto&& hv = interpreter_for_py->get_value(m_tensor->m_handle.get());
auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE)); auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE));
if (!arr) return nullptr; if (!arr) return nullptr;
@@ -184,6 +333,13 @@ PyObject* TensorWrapper::numpy() {
return arr.release().ptr(); return arr.release().ptr();
} }


PyObject* TensorWrapper::varnode() {
if (m_tensor->m_var) {
return py::cast(m_tensor->m_var).release().ptr();
}
return nullptr;
}

void TensorWrapper::reset(PyObject* tensor) { void TensorWrapper::reset(PyObject* tensor) {
TensorWrapper* t = TensorWrapper::cast_safe(tensor); TensorWrapper* t = TensorWrapper::cast_safe(tensor);
if (!t) { if (!t) {
@@ -195,13 +351,22 @@ void TensorWrapper::reset(PyObject* tensor) {
PyObject* TensorWrapper::detach() { PyObject* TensorWrapper::detach() {
PyObject* self = wrap_t::pycast(this); PyObject* self = wrap_t::pycast(this);
PyTypeObject* pytype = self->ob_type; PyTypeObject* pytype = self->ob_type;
auto new_tensor = std::make_shared<Tensor>(m_tensor->m_handle);

std::shared_ptr<Tensor> new_tensor;
if (m_tensor->m_handle.get()) {
new_tensor = std::make_shared<Tensor>(m_tensor->m_handle);
} else {
new_tensor = std::make_shared<Tensor>(m_tensor->m_var);
}
auto ret = TensorWrapper::make(pytype, std::move(new_tensor)); auto ret = TensorWrapper::make(pytype, std::move(new_tensor));
return ret.release().ptr(); return ret.release().ptr();


} }


PyObject* TensorWrapper::_dev_tensor(){ PyObject* TensorWrapper::_dev_tensor(){
if (!skip_tracing) {
set_data_read(py::cast(true).release().ptr());
}
auto dev_tensor = interpreter_for_py->get_dev_tensor(m_tensor->m_handle.get()); auto dev_tensor = interpreter_for_py->get_dev_tensor(m_tensor->m_handle.get());
return py::cast(dev_tensor).release().ptr(); return py::cast(dev_tensor).release().ptr();
} }
@@ -227,11 +392,14 @@ PyObject* TensorWrapper::isscalar() {
} }
} }



void TensorWrapper::setscalar() { void TensorWrapper::setscalar() {
m_tensor->m_flags |= Tensor::Flags::SCALAR; m_tensor->m_flags |= Tensor::Flags::SCALAR;
} }




PyMethodDef apply_def{"apply", (PyCFunction)py_apply, METH_FASTCALL, nullptr};

struct TensorWeakRef { struct TensorWeakRef {
std::weak_ptr<Tensor> wptr; std::weak_ptr<Tensor> wptr;


@@ -262,6 +430,12 @@ void init_tensor(py::module m) {
.def<&TensorWrapper::_swap_out>("_swap_out") .def<&TensorWrapper::_swap_out>("_swap_out")
.def<&TensorWrapper::_swap_in>("_swap_in") .def<&TensorWrapper::_swap_in>("_swap_in")
.def<&TensorWrapper::_drop>("_drop") .def<&TensorWrapper::_drop>("_drop")
.def_getset<&TensorWrapper::varnode>("_varnode")
.def_getset<&TensorWrapper::data_read, &TensorWrapper::set_data_read>("data_read")
.def_getset<&TensorWrapper::value_read, &TensorWrapper::set_value_read>("value_read")
.def_getset<&TensorWrapper::shape_read, &TensorWrapper::set_shape_read>("shape_read")
.def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("mixin_handle")
.def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle")
.finalize(); .finalize();
if (!tensor_type) throw py::error_already_set(); if (!tensor_type) throw py::error_already_set();
py::setattr(m, "Tensor", tensor_type); py::setattr(m, "Tensor", tensor_type);
@@ -296,6 +470,25 @@ void init_tensor(py::module m) {
if (!grad_key_type) throw py::error_already_set(); if (!grad_key_type) throw py::error_already_set();
py::setattr(m, "GradKey", grad_key_type); py::setattr(m, "GradKey", grad_key_type);
py::setattr(m, "backward", py::cpp_function(&GradKeyWrapper::backward)); py::setattr(m, "backward", py::cpp_function(&GradKeyWrapper::backward));
m.def("set_cpp_apply_with_tracing", &set_cpp_apply_with_tracing);
m.def("set_cpp_apply_const_with_tracing", &set_cpp_apply_const_with_tracing);
m.def("set_cpp_apply_compiled_mode", &set_cpp_apply_compiled_mode);
m.def("set_cpp_apply_const_compiled_mode", &set_cpp_apply_const_compiled_mode);
m.def("set_cpp_apply_backward_varnode", &set_cpp_apply_backward_varnode);

m.attr("skip_tracing") = &skip_tracing;
m.attr("call_level") = &call_level;

py::class_<SharedHandle>(m, "SharedHandle")
.def(py::init<const SharedHandle&>());

m.def("set_tracing", &set_tracing);
m.def("unset_tracing", &unset_tracing);
m.def("set_symbolic", &set_symbolic);
m.def("unset_symbolic", &unset_symbolic);
m.def("set_compiled", &set_compiled);
m.def("unset_compiled", &unset_compiled);

} }


} // namespace mgb::imperative::python } // namespace mgb::imperative::python

+ 50
- 10
imperative/python/src/tensor.h View File

@@ -30,13 +30,10 @@ struct ObjectPtr : B {
} // namespace mgb::imperative::python } // namespace mgb::imperative::python


#include "./grad_info.h" // for struct GradInfo #include "./grad_info.h" // for struct GradInfo
#include "./trace_info.h" // for struct TraceInfo


namespace mgb::imperative::python { namespace mgb::imperative::python {


struct TraceInfo {

};

extern std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py; extern std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py;


class SharedHandle { class SharedHandle {
@@ -46,7 +43,9 @@ class SharedHandle {


public: public:
inline explicit SharedHandle(Handle handle) : holder(handle, [](auto* h){ inline explicit SharedHandle(Handle handle) : holder(handle, [](auto* h){
interpreter_for_py->del(h);
if (h) {
interpreter_for_py->del(h);
}
}) {} }) {}
SharedHandle(const SharedHandle&) = default; SharedHandle(const SharedHandle&) = default;
SharedHandle& operator=(const SharedHandle&) = default; SharedHandle& operator=(const SharedHandle&) = default;
@@ -71,11 +70,14 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
GradInfo m_grad_info; GradInfo m_grad_info;
TraceInfo m_trace_info; TraceInfo m_trace_info;
SharedHandle m_handle; SharedHandle m_handle;
cg::VarNode* m_var;


using Handle = interpreter::Interpreter::Handle; using Handle = interpreter::Interpreter::Handle;


inline explicit Tensor(Handle handle) : m_handle(handle) {}
inline explicit Tensor(SharedHandle handle) : m_handle(std::move(handle)) {}
inline explicit Tensor(Handle handle) : m_handle(handle), m_var(nullptr) {}
inline explicit Tensor(SharedHandle handle) : m_handle(std::move(handle)), m_var(nullptr) {}
inline explicit Tensor(cg::VarNode *var) : m_handle(nullptr), m_var(var) {}

~Tensor() = default; ~Tensor() = default;


inline std::shared_ptr<Tensor> copy() { inline std::shared_ptr<Tensor> copy() {
@@ -83,12 +85,28 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
ret->m_flags = m_flags; ret->m_flags = m_flags;
ret->m_grad_info = m_grad_info; ret->m_grad_info = m_grad_info;
ret->m_trace_info = m_trace_info; ret->m_trace_info = m_trace_info;
ret->m_var = m_var;
return ret; return ret;
} }


inline DType dtype() {return interpreter_for_py->get_dtype(m_handle.get());}
inline CompNode comp_node() {return interpreter_for_py->get_device(m_handle.get());}
inline TensorShape shape() {return interpreter_for_py->get_shape(m_handle.get());}
inline DType dtype() {
if (m_var) {
return m_var->dtype();
}
return interpreter_for_py->get_dtype(m_handle.get());
}
inline CompNode comp_node() {
if (m_var) {
return m_var->comp_node();
}
return interpreter_for_py->get_device(m_handle.get());
}
inline TensorShape shape() {
if (m_var) {
return m_var->shape();
}
return interpreter_for_py->get_shape(m_handle.get());
}
}; };




@@ -135,6 +153,19 @@ struct TensorWrapper {
void _swap_in(); void _swap_in();
void _swap_out(); void _swap_out();
void _drop(); void _drop();
PyObject* varnode();
PyObject* handle();
void set_handle(PyObject *);

PyObject* data_read();
PyObject* value_read();
PyObject* shape_read();
PyObject* mixin_handle();

void set_data_read(PyObject*);
void set_value_read(PyObject*);
void set_shape_read(PyObject*);
void set_mixin_handle(PyObject*);
}; };




@@ -145,6 +176,7 @@ struct ApplyContext {
std::shared_ptr<OpDef> op; std::shared_ptr<OpDef> op;
Tensor*const* args; Tensor*const* args;
size_t nargs; size_t nargs;
bool backward = false;
}; };


using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>; using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>;
@@ -153,6 +185,14 @@ apply_result_t apply(ApplyContext& ctx);


void init_tensor(pybind11::module); void init_tensor(pybind11::module);


extern bool is_tracing;
extern bool is_symbolic;
extern bool is_compiled;
extern int64_t call_level;

extern pybind11::object cpp_apply_with_tracing, cpp_apply_compiled_mode;
extern pybind11::object cpp_apply_backward_varnode;

} // namespace mgb::imperative::python } // namespace mgb::imperative::python


namespace pybind11::detail { namespace pybind11::detail {


+ 94
- 0
imperative/python/src/trace.cpp View File

@@ -0,0 +1,94 @@
/**
* \file imperative/python/src/trace.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "./trace.h"
#include "./helper.h"
#include "megbrain/imperative/ops/autogen.h"

namespace py = pybind11;

namespace mgb::imperative::python {

apply_result_t apply_tensor_on_var_node(ApplyContext& ctx) {
apply_result_t outputs;

cg::VarNodeArray vinputs(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; i++) {
vinputs[i] = ctx.args[i]->m_var;
}
auto ovars = OpDef::apply_on_var_node(*ctx.op, vinputs);

for (size_t i = 0; i < ovars.size(); i++) {
outputs.emplace_back(std::make_shared<Tensor>(ovars[i]));
}

return outputs;
}

apply_result_t apply_trace(ApplyContext& ctx) {
apply_result_t outputs;

bool run_apply_on_var_node = false;
for (size_t i = 0; i < ctx.nargs; i++) {
run_apply_on_var_node |= ((ctx.args[i]->m_handle.get() == nullptr) & (ctx.args[i]->m_var != nullptr));
}

if (ctx.backward) {
// reach here when symbolic=True or compiled=True
// call megbrain_graph.py apply(BackwardGraph, *args)
auto args = py::tuple(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; i++) {
args[i] = py::cast(ctx.args[i]->m_var);
}
py::object ret = cpp_apply_backward_varnode(py::cast(ctx.op), *args);

if (!ret) {
throw py::value_error("invalid py object call");
}

// assumption: python function always returns PyList
auto tup = py::reinterpret_borrow<py::list>(ret);
for (auto i = 0; i < tup.size(); i++) {
auto pitem = tup[i].cast<cg::VarNode *>();
outputs.emplace_back(std::make_shared<Tensor>(pitem));
}
return outputs;
}

if (run_apply_on_var_node && !is_symbolic) {
return apply_tensor_on_var_node(ctx);
}

py::object pyf;
if (is_compiled) {
// run apply in compiled mode, step 2, 3, etc
pyf = cpp_apply_compiled_mode;
} else {
// run first step, both symbolic and non symbolic
pyf = cpp_apply_with_tracing;
}

auto args = py::tuple(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; i++) {
args[i] = TensorWrapper::make(std::move(std::shared_ptr<Tensor>(ctx.args[i]))).release();
}
auto ret = pyf(py::cast(ctx.op), *args);

// assumption: python function always returns PyList
auto tup = py::reinterpret_borrow<py::list>(ret);
for (auto i = 0; i < tup.size(); i++) {
auto tw = TensorWrapper::cast_safe(tup[i].ptr());
outputs.emplace_back(tw->m_tensor);
}
return outputs;
}

} // namespace mgb::imperative::python

+ 3
- 2
imperative/python/src/trace.h View File

@@ -9,9 +9,10 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */


#include "./tensor.h"

namespace mgb::imperative::python { namespace mgb::imperative::python {


struct TraceInfo {
};
apply_result_t apply_trace(ApplyContext& ctx);


} // namespace mgb::imperative::python } // namespace mgb::imperative::python

+ 24
- 0
imperative/python/src/trace_info.h View File

@@ -0,0 +1,24 @@
/**
* \file imperative/python/src/trace_info.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "inttypes.h"

namespace mgb::imperative::python {

struct TraceInfo {
int64_t mixin_handle = -1;

bool data_read = false;
bool value_read = false;
bool shape_read = false;
};

} // namespace mgb::imperative::python

+ 39
- 54
imperative/python/test/unit/test_tracing.py View File

@@ -19,8 +19,6 @@ from megengine import tensor
from megengine.core._trace_option import set_symbolic_shape from megengine.core._trace_option import set_symbolic_shape
from megengine.core.ops import builtin as ops from megengine.core.ops import builtin as ops
from megengine.core.ops.builtin import Elemwise from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.core.tensor.utils import isscalar from megengine.core.tensor.utils import isscalar
from megengine.functional import exp, log from megengine.functional import exp, log
from megengine.jit import exclude_from_trace, trace from megengine.jit import exclude_from_trace, trace
@@ -32,35 +30,32 @@ def test_trace():


@trace(symbolic=symbolic) @trace(symbolic=symbolic)
def f(x): def f(x):
op = ops.Elemwise(Elemwise.Mode.NEGATE)
(y,) = apply(op, x)
return y
return -x


x = as_raw_tensor([1]).numpy()
y = f.__wrapped__(as_raw_tensor(x)).numpy()
x = tensor([1])
y = f(x).numpy()


for i in range(3): for i in range(3):
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
np.testing.assert_equal(f(x).numpy(), y)




def test_exclude_from_trace(): def test_exclude_from_trace():
for symbolic in [False, True]:
for symbolic in [False]:


@trace(symbolic=symbolic) @trace(symbolic=symbolic)
def f(x): def f(x):
neg = ops.Elemwise(Elemwise.Mode.NEGATE)
(x,) = apply(neg, x)
x = -x
with exclude_from_trace(): with exclude_from_trace():
if i % 2: if i % 2:
(x,) = apply(neg, x)
(x,) = apply(neg, x)
x = -x
x = -x
return x return x


x = as_raw_tensor([1]).numpy()
x = tensor([1])


for i in range(3): for i in range(3):
y = f.__wrapped__(as_raw_tensor(x)).numpy()
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
y = f(x).numpy()
np.testing.assert_equal(f(x).numpy(), y)




def test_print_in_trace(): def test_print_in_trace():
@@ -69,36 +64,33 @@ def test_print_in_trace():
@trace(symbolic=symbolic) @trace(symbolic=symbolic)
def f(x): def f(x):
nonlocal buf nonlocal buf
neg = ops.Elemwise(Elemwise.Mode.NEGATE)
(x,) = apply(neg, x)
x = -x
buf = x.numpy() buf = x.numpy()
(x,) = apply(neg, x)
x = -x
return x return x


buf = None buf = None
x = as_raw_tensor([1]).numpy()
x = tensor([1])


for i in range(3): for i in range(3):
y = f.__wrapped__(as_raw_tensor(x)).numpy()
y = f(x).numpy()
z = buf z = buf
buf = None buf = None
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
np.testing.assert_equal(f(x).numpy(), y)
np.testing.assert_equal(z, buf) np.testing.assert_equal(z, buf)




def test_dump(): def test_dump():
@trace(symbolic=True, capture_as_const=True) @trace(symbolic=True, capture_as_const=True)
def f(a, b): def f(a, b):
op = ops.Elemwise(Elemwise.Mode.ADD)
(y,) = apply(op, a, b)
return y
return a + b


a = as_raw_tensor([2]).numpy()
b = as_raw_tensor([4]).numpy()
y = f.__wrapped__(as_raw_tensor(a), as_raw_tensor(b)).numpy()
a = tensor([2])
b = tensor([4])
y = f(a, b).numpy()


for i in range(3): for i in range(3):
np.testing.assert_equal(f(as_raw_tensor(a), as_raw_tensor(b)).numpy(), y)
np.testing.assert_equal(f(a, b).numpy(), y)


file = io.BytesIO() file = io.BytesIO()
dump_info = f.dump(file) dump_info = f.dump(file)
@@ -111,19 +103,17 @@ def test_dump():




def test_capture_dump(): def test_capture_dump():
a = as_raw_tensor([2])
a = tensor([2])


@trace(symbolic=True, capture_as_const=True) @trace(symbolic=True, capture_as_const=True)
def f(x): def f(x):
op = ops.Elemwise(Elemwise.Mode.MUL)
(y,) = apply(op, x, a)
return y
return x * a


x = as_raw_tensor([3]).numpy()
y = f.__wrapped__(as_raw_tensor(x)).numpy()
x = tensor([3])
y = f(x).numpy()


for i in range(3): for i in range(3):
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
np.testing.assert_equal(f(x).numpy(), y)


file = io.BytesIO() file = io.BytesIO()
f.dump(file) f.dump(file)
@@ -133,19 +123,17 @@ def test_capture_dump():




def test_dump_volatile(): def test_dump_volatile():
p = as_raw_tensor([2])
p = tensor([2])


@trace(symbolic=True, capture_as_const=True) @trace(symbolic=True, capture_as_const=True)
def f(x): def f(x):
op = ops.Elemwise(Elemwise.Mode.MUL)
(y,) = apply(op, x, p)
return y
return x * p


x = as_raw_tensor([3]).numpy()
y = f.__wrapped__(as_raw_tensor(x)).numpy()
x = tensor([3])
y = f(x).numpy()


for i in range(3): for i in range(3):
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
np.testing.assert_equal(f(x).numpy(), y)


file = io.BytesIO() file = io.BytesIO()
f.dump(file, optimize_for_inference=False) f.dump(file, optimize_for_inference=False)
@@ -163,21 +151,18 @@ def test_trace_profiler():


@trace(symbolic=symbolic, profiling=True) @trace(symbolic=symbolic, profiling=True)
def f(x): def f(x):
op = ops.Elemwise(Elemwise.Mode.NEGATE)
(y,) = apply(op, x)
return y
return -x


x = as_raw_tensor([1]).numpy()
y = f.__wrapped__(as_raw_tensor(x)).numpy()
x = tensor([1])
y = f(x).numpy()


f(as_raw_tensor(x))
f(as_raw_tensor(x)) # XXX: has to run twice
f(x)
f(x) # XXX: has to run twice


out = f.get_profile() out = f.get_profile()
assert out.get("profiler") assert out.get("profiler")




@pytest.mark.skip(reason="force opt_level=0 when building graph")
def test_goptions(): def test_goptions():
@trace(symbolic=True, opt_level=0, capture_as_const=True) @trace(symbolic=True, opt_level=0, capture_as_const=True)
def f(x): def f(x):
@@ -196,7 +181,6 @@ def test_goptions():
np.testing.assert_equal(g(d).numpy().item(), 1.0) np.testing.assert_equal(g(d).numpy().item(), 1.0)




@pytest.mark.skip(reason="force opt_level=0 when building graph")
def test_goptions_log_sum_exp(): def test_goptions_log_sum_exp():
@trace(symbolic=True, opt_level=0, capture_as_const=True) @trace(symbolic=True, opt_level=0, capture_as_const=True)
def f(x, y): def f(x, y):
@@ -256,8 +240,7 @@ def test_optimize_for_inference_broadcast():


@trace(capture_as_const=True, symbolic_shape=True) @trace(capture_as_const=True, symbolic_shape=True)
def f(): def f():
(b,) = apply(ops.Broadcast(), a, tensor([1, 10], dtype=np.int32))
return b
return a._broadcast(tensor([1, 10], dtype=np.int32))


f() f()
f.dump(io.BytesIO()) f.dump(io.BytesIO())
@@ -387,7 +370,9 @@ def test_trace_nms():


@trace(symbolic=False) @trace(symbolic=False)
def f(boxes, scores): def f(boxes, scores):
# with tracing, max_output must be specified
results = F.nn.nms(boxes, scores=scores, iou_thresh=0.5, max_output=20) results = F.nn.nms(boxes, scores=scores, iou_thresh=0.5, max_output=20)
# without tracing, max output can be inferred inside nms
with exclude_from_trace(): with exclude_from_trace():
_ = F.nn.nms(boxes, scores=scores, iou_thresh=0.5) _ = F.nn.nms(boxes, scores=scores, iou_thresh=0.5)
return results return results


+ 0
- 1
sdk/load-and-run/dump_with_testcase_mge.py View File

@@ -318,7 +318,6 @@ def optimize_for_inference(args, outputs):
), "optimize_for_inference should be set when {} is given".format(k) ), "optimize_for_inference should be set when {} is given".format(k)
kwargs[v] = True kwargs[v] = True


outputs = [G.VarNode(output) for output in outputs]
if args.optimize_for_inference: if args.optimize_for_inference:
outputs = [i._node for i in G.optimize_for_inference(outputs, **kwargs)] outputs = [i._node for i in G.optimize_for_inference(outputs, **kwargs)]




+ 1
- 1
sdk/xor-deploy/xornet.py View File

@@ -84,7 +84,7 @@ def main():
minibatch = next(val_dataset) minibatch = next(val_dataset)
net.eval() net.eval()
_, loss = val_fun(data, label) _, loss = val_fun(data, label)
loss = loss.numpy()[0]
loss = loss.numpy()
val_loss.append((step, loss)) val_loss.append((step, loss))
print("Step: {} loss={}".format(step, loss)) print("Step: {} loss={}".format(step, loss))
opt.step() opt.step()


Loading…
Cancel
Save