Browse Source

fix(mgb/trace): finalize when exception raise

GitOrigin-RevId: b8ffd00a7e
release-1.1
Megvii Engine Team 4 years ago
parent
commit
7ac4dbc27d
2 changed files with 138 additions and 62 deletions
  1. +105
    -62
      imperative/python/megengine/jit/tracing.py
  2. +33
    -0
      imperative/python/test/unit/test_tracing.py

+ 105
- 62
imperative/python/megengine/jit/tracing.py View File

@@ -125,6 +125,9 @@ class trace:
self._graph_opt_level = opt_level self._graph_opt_level = opt_level
self._tensor_shape = tensor_shape self._tensor_shape = tensor_shape


self._reset()

def _reset(self):
self._untraced = True self._untraced = True
self._tinfo = [] # handle -> TensorInfo self._tinfo = [] # handle -> TensorInfo
self._seq = [] self._seq = []
@@ -257,77 +260,117 @@ class trace:
def _record_const(self, op, outputs): def _record_const(self, op, outputs):
pass pass


@contextlib.contextmanager
def _setup(self):
def _set_active(self, active: bool):
global active_trace global active_trace
if active_trace:
raise NotImplementedError("sorry, not implemented: nested trace")
active_trace = self

if self._untraced:
apply.enable(apply_with_tracing)
apply.enable(apply_const_with_tracing)
if self._symbolic:
apply.enable(apply_symbolic_mode)
apply.enable(apply_const_symbolic_mode)
self._lazy_eval_graph = G.Graph()
if active:
if active_trace:
raise NotImplementedError("sorry, not implemented: nested trace")
active_trace = self
else: else:
apply.enable(apply_compiled_mode)
if self._graph is None:
self._compile()
self._graph.execute()

yield

assert active_trace is self
active_trace = None

def _init_trace(self, symbolic: bool):
apply.enable(apply_with_tracing)
apply.enable(apply_const_with_tracing)
if symbolic:
apply.enable(apply_symbolic_mode)
apply.enable(apply_const_symbolic_mode)
self._lazy_eval_graph = G.Graph()

def _take_escaped_tensors(self):
escaped_tensors = tuple(self._active_tensors) escaped_tensors = tuple(self._active_tensors)
self._active_tensors.clear() self._active_tensors.clear()
return escaped_tensors


if self._untraced:
for x in escaped_tensors:
info = self._tinfo[x._TraceMixin__handle]
info.data_read = True
x._TraceMixin__restore()
if self._inputs_to_restore:
for x in self._inputs_to_restore:
def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors):
active_lazy_eval_tensors = []
visited = set()
readers = []
for x in lazy_eval_tensors:
x = x()
if x is None or x in visited:
continue
reader = G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
readers.append(reader)
active_lazy_eval_tensors.append(x)
visited.add(x)
self._apply_graph_options(lazy_eval_graph)
lazy_eval_graph.compile(*readers)
lazy_eval_graph()
for r, x in zip(readers, active_lazy_eval_tensors):
assign_raw_tensor(x, as_raw_tensor(r.op.get_value()))

@contextlib.contextmanager
def _setup(self):
interrupted = False

def do_enter():
self._set_active(True)
if self._untraced:
self._init_trace(self._symbolic)
else:
apply.enable(apply_compiled_mode)
if self._graph is None:
self._compile()
self._graph.execute()

def do_finalize():
escaped_tensors = self._take_escaped_tensors()
if self._untraced:
for x in escaped_tensors:
info = self._tinfo[x._TraceMixin__handle]
info.data_read = True
x._TraceMixin__restore() x._TraceMixin__restore()
if self._symbolic:
# eval lazy eval tensors
if self._lazy_eval_tensors:
lazy_eval_tensors = []
visited = set()
readers = []
for x in self._lazy_eval_tensors:
x = x()
if x is None or x in visited:
continue
reader = G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
readers.append(reader)
lazy_eval_tensors.append(x)
visited.add(x)
self._apply_graph_options(self._lazy_eval_graph)
self._lazy_eval_graph.compile(*readers)
self._lazy_eval_graph()
for r, x in zip(readers, lazy_eval_tensors):
assign_raw_tensor(x, as_raw_tensor(r.op.get_value()))
if self._inputs_to_restore:
for x in self._inputs_to_restore:
x._TraceMixin__restore()
if self._symbolic and self._lazy_eval_tensors:
# eval lazy eval tensors
self._lazy_eval(self._lazy_eval_graph, self._lazy_eval_tensors)
self._lazy_eval_graph = None self._lazy_eval_graph = None
self._lazy_eval_tensors = None self._lazy_eval_tensors = None
self._untraced = False
else:
if self._pc != len(self._seq):
raise TraceMismatchError("premature end")
for x in escaped_tensors:
assign_raw_tensor(x, as_raw_tensor(x._dev_tensor()))
self._graph.wait()
self._reset_exec_env()
self._untraced = False
else:
# compiled_tensor leaks
if self._pc == len(self._seq):
for x in escaped_tensors:
try:
assign_raw_tensor(x, as_raw_tensor(x._dev_tensor()))
except TraceMismatchError:
# TraceMismatchError thrown in do_exit
pass
self._graph.wait()
self._reset_exec_env()

# reset status
self._pc = 0 self._pc = 0

self._tensor_remaps = None
apply.disable(apply_with_tracing)
apply.disable(apply_const_with_tracing)
apply.disable(apply_symbolic_mode)
apply.disable(apply_const_symbolic_mode)
apply.disable(apply_compiled_mode)
active_trace = None
self._tensor_remaps = None
apply.disable(apply_with_tracing)
apply.disable(apply_const_with_tracing)
apply.disable(apply_symbolic_mode)
apply.disable(apply_const_symbolic_mode)
apply.disable(apply_compiled_mode)
self._set_active(False)

def do_exit():
if not self._untraced and self._pc != len(self._seq):
raise TraceMismatchError("premature end")
if not self._symbolic or not self._untraced:
for x in self._active_tensors:
x._dev_tensor()

try:
do_enter()
yield
do_exit()
except:
interrupted = True
raise
finally:
do_finalize()
if interrupted:
self._reset()


def _begin_excluded_region(self): def _begin_excluded_region(self):
if self._capture_as_const: if self._capture_as_const:


+ 33
- 0
imperative/python/test/unit/test_tracing.py View File

@@ -307,3 +307,36 @@ def test_trace_warp_perspective():


for i in range(1): for i in range(1):
f(x, M) f(x, M)


def test_raise_on_trace():
step_count = 0
catch_count = 0
bad_step = 10

class CatchMe(Exception):
pass

a = tensor([1, 2, 3, 4])
b = tensor([5, 6, 7, 8])
c = tensor([9, 0, 1, 2])

@trace
def add_abc(a, b, c):
print("Hello")
ps = a + b
result = ps + c
if step_count == bad_step:
raise CatchMe("catch me")
return result

for i in range(100):
try:
d = add_abc(a, b, c)
except CatchMe as e:
catch_count += 1
else:
np.testing.assert_equal(d.numpy(), (a + b + c).numpy())
step_count += 1

assert catch_count == 1

Loading…
Cancel
Save