Browse Source

chore(mge): improve symbolic tracing value/shape inference

GitOrigin-RevId: d1a6baac74
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
e027dcbf2c
1 changed files with 25 additions and 3 deletions
  1. +25
    -3
      imperative/python/megengine/jit/tracing.py

+ 25
- 3
imperative/python/megengine/jit/tracing.py View File

@@ -186,6 +186,9 @@ class trace:
self._seq.append((op, tuple(ihandles), tuple(ohandles)))
self._active_tensors.update(outputs)

def _record_const(self, op, outputs):
pass

@contextlib.contextmanager
def _setup(self):
global active_trace
@@ -195,8 +198,10 @@ class trace:

if self._untraced:
apply.enable(apply_with_tracing)
apply.enable(apply_const_with_tracing)
if self._symbolic:
apply.enable(apply_symbolic_mode)
apply.enable(apply_const_symbolic_mode)
self._lazy_eval_graph = G.Graph()
else:
apply.enable(apply_compiled_mode)
@@ -239,7 +244,9 @@ class trace:
self._pc = 0

apply.disable(apply_with_tracing)
apply.disable(apply_const_with_tracing)
apply.disable(apply_symbolic_mode)
apply.disable(apply_const_symbolic_mode)
apply.disable(apply_compiled_mode)
active_trace = None

@@ -478,6 +485,16 @@ apply.disable(apply_symbolic_mode)


@apply.register()
def apply_const_symbolic_mode(op: Const, *args: RawTensor):
graph = active_trace._lazy_eval_graph
ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device))
return (ret,)


apply.disable(apply_const_symbolic_mode)


@apply.register()
def apply_compiled_mode(op: OpDef, *args: RawTensor):
if skip_tracing:
args = [
@@ -502,9 +519,14 @@ def apply_with_tracing(op: OpDef, *args: RawTensor):
apply.disable(apply_with_tracing)


# @apply.register()
# def _(op: Const, *args: RawTensor):
# return active_trace._apply_const(op, args)
@apply.register()
def apply_const_with_tracing(op: Const, *args: RawTensor):
outputs = apply.super(op, *args)
active_trace._record_const(op, outputs)
return outputs


apply.disable(apply_const_with_tracing)


class BrokenRawTensor(RawTensor):


Loading…
Cancel
Save