Browse Source

fix(mge/trace): fix graph option in trace

GitOrigin-RevId: 7bec84f56d
release-1.1
Megvii Engine Team 4 years ago
parent
commit
1b56851715
2 changed files with 32 additions and 20 deletions
  1. +3
    -1
      imperative/python/megengine/jit/tracing.py
  2. +29
    -19
      imperative/python/test/unit/test_tracing.py

+ 3
- 1
imperative/python/megengine/jit/tracing.py View File

@@ -284,6 +284,7 @@ class trace:
apply.enable(apply_symbolic_mode) apply.enable(apply_symbolic_mode)
apply.enable(apply_const_symbolic_mode) apply.enable(apply_const_symbolic_mode)
self._lazy_eval_graph = G.Graph() self._lazy_eval_graph = G.Graph()
self._apply_graph_options(self._lazy_eval_graph)


def _take_escaped_tensors(self): def _take_escaped_tensors(self):
escaped_tensors = tuple(self._active_tensors) escaped_tensors = tuple(self._active_tensors)
@@ -302,7 +303,6 @@ class trace:
readers.append(reader) readers.append(reader)
active_lazy_eval_tensors.append(x) active_lazy_eval_tensors.append(x)
visited.add(x) visited.add(x)
self._apply_graph_options(lazy_eval_graph)
lazy_eval_graph.compile(*readers) lazy_eval_graph.compile(*readers)
lazy_eval_graph() lazy_eval_graph()
for r, x in zip(readers, active_lazy_eval_tensors): for r, x in zip(readers, active_lazy_eval_tensors):
@@ -599,6 +599,8 @@ class trace:


h2v = {} h2v = {}
graph = G.Graph() graph = G.Graph()
# only graph_opt_level takes effect in dump
self._apply_graph_options(graph)


for i, h in enumerate(self._arg_bindings): for i, h in enumerate(self._arg_bindings):
info = self._tinfo[h] info = self._tinfo[h]


+ 29
- 19
imperative/python/test/unit/test_tracing.py View File

@@ -174,31 +174,24 @@ def test_trace_profiler():
assert out.get("profiler") assert out.get("profiler")




@pytest.mark.skip(reason="could not disable opt_level")
def test_goptions_log_exp():
def test_goptions():
@trace(symbolic=True, opt_level=0, capture_as_const=True) @trace(symbolic=True, opt_level=0, capture_as_const=True)
def f(x): def f(x):
return log(exp(x))
# directly return x / x will not trigger gopt
# since there's no way to tell the two x are the same
y = 2.0 * x
return y / y


@trace(symbolic=True, opt_level=1, capture_as_const=True) @trace(symbolic=True, opt_level=1, capture_as_const=True)
def g(x): def g(x):
return log(exp(x))

f(tensor(1.0))
_, out = mkstemp()
f.dump(out, optimize_for_inference=False)
*_, outputs = G.load_graph(out)
oprs_1 = cgtools.get_oprs_seq(outputs)

g(tensor(1.0))
g.dump(out, optimize_for_inference=False)
*_, outputs = G.load_graph(out)
oprs_2 = cgtools.get_oprs_seq(outputs)
y = 2.0 * x
return y / y


assert len(oprs_1) - len(oprs_2) == 2
d = tensor(0.0)
assert not np.isfinite(f(d).numpy())
np.testing.assert_equal(g(d).numpy().item(), 1.0)




@pytest.mark.skip(reason="could not disable opt_level")
def test_goptions_log_sum_exp(): def test_goptions_log_sum_exp():
@trace(symbolic=True, opt_level=0, capture_as_const=True) @trace(symbolic=True, opt_level=0, capture_as_const=True)
def f(x, y): def f(x, y):
@@ -208,13 +201,30 @@ def test_goptions_log_sum_exp():
def g(x, y): def g(x, y):
return log(exp(x) + exp(y)) return log(exp(x) + exp(y))


f(tensor(1.0), tensor(2.0))
val = 1.0e4
d = tensor(val)
o = tensor(0.0)
assert not np.isfinite(f(d, o).numpy())
np.testing.assert_almost_equal(g(d, o), val)


@pytest.mark.skip(reason="could not use opt_level=0 with dump")
def test_goptions_log_exp():
@trace(symbolic=True, opt_level=0, capture_as_const=True)
def f(x):
return log(exp(x))

@trace(symbolic=True, opt_level=1, capture_as_const=True)
def g(x):
return log(exp(x))

f(tensor(1.0))
_, out = mkstemp() _, out = mkstemp()
f.dump(out, optimize_for_inference=False) f.dump(out, optimize_for_inference=False)
*_, outputs = G.load_graph(out) *_, outputs = G.load_graph(out)
oprs_1 = cgtools.get_oprs_seq(outputs) oprs_1 = cgtools.get_oprs_seq(outputs)


g(tensor(1.0), tensor(2.0))
g(tensor(1.0))
g.dump(out, optimize_for_inference=False) g.dump(out, optimize_for_inference=False)
*_, outputs = G.load_graph(out) *_, outputs = G.load_graph(out)
oprs_2 = cgtools.get_oprs_seq(outputs) oprs_2 = cgtools.get_oprs_seq(outputs)


Loading…
Cancel
Save