Browse Source

feat(mge/trace): support dict return value processing in trace

GitOrigin-RevId: 5b1c08848b
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
f211124865
2 changed files with 115 additions and 106 deletions
  1. +12
    -10
      imperative/python/megengine/jit/tracing.py
  2. +103
    -96
      imperative/python/test/unit/test_tracing.py

+ 12
- 10
imperative/python/megengine/jit/tracing.py View File

@@ -642,22 +642,24 @@ class trace:
if self._capture_as_const:
self._process_inputs(*args, **kwargs)
outputs = self.__wrapped__(*args, **kwargs)
transform = False
# outputs can be None
if self._capture_as_const:
self._process_outputs(outputs)

# outputs could be None
if outputs is not None:
if not isinstance(outputs, collections.abc.Sequence):
transform = True
outputs = (outputs,)
for o in outputs:
list_outputs = outputs
if isinstance(outputs, collections.abc.Mapping):
_, list_outputs = zip(*sorted(outputs.items()))
elif not isinstance(outputs, collections.abc.Sequence):
list_outputs = (outputs,)

for o in list_outputs:
# if an output was copied, refresh its entry in the trace data structures with the newest tensor
if o._copied:
self._active_tensors[o._mixin_handle] = TensorWeakRef(o)
if self._untraced and self._symbolic:
self._lazy_eval_tensors[o._mixin_handle] = TensorWeakRef(o)
if self._capture_as_const:
self._process_outputs(outputs)
if transform:
outputs = outputs[0]

return outputs

def dump(


+ 103
- 96
imperative/python/test/unit/test_tracing.py View File

@@ -28,18 +28,32 @@ from megengine.module import Module
from megengine.random import normal, uniform


def test_trace():
for symbolic in [False, True]:

@trace(symbolic=symbolic)
def f(x):
@pytest.mark.parametrize("trace_mode", [False, True])
@pytest.mark.parametrize("return_mode", ["Value", "Tuple", "List", "Dict"])
def test_trace(trace_mode, return_mode):
@trace(symbolic=trace_mode)
def f(x):
if return_mode == "Tuple":
return (-x,)
elif return_mode == "List":
return [-x]
elif return_mode == "Dict":
return {"neg": -x}
else:
return -x

x = tensor([1])
y = f(x).numpy()
def get_numpy(y):
if return_mode == "Tuple" or return_mode == "List":
return y[0].numpy()
elif return_mode == "Dict":
return y["neg"].numpy()
return y.numpy()

for i in range(3):
np.testing.assert_equal(f(x).numpy(), y)
x = tensor([1])
y = get_numpy(f(x))

for i in range(3):
np.testing.assert_equal(get_numpy(f(x)), y)


def test_output_copy_trace():
@@ -54,51 +68,46 @@ def test_output_copy_trace():
x = F.exp(x)
return x

net = Simple()

gm = GradManager().attach(net.parameters())
opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)
data = tensor(np.arange(4).reshape(2, 2), dtype="float32")
ys = {False: [], True: []}

@trace(symbolic=False)
def train_f1(d):
with gm:
loss = net(d)
gm.backward(loss)
opt.step().clear_grad()
return loss
for symbolic in [False, True]:
net = Simple()
gm = GradManager().attach(net.parameters())
opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)
data = tensor(np.arange(4).reshape(2, 2), dtype="float32")

@trace(symbolic=True)
def train_f2(d):
with gm:
loss = net(d)
gm.backward(loss)
opt.step().clear_grad()
return loss
@trace(symbolic=symbolic)
def train_func(d):
with gm:
loss = net(d)
gm.backward(loss)
opt.step().clear_grad()
return loss

for i in range(2):
y1 = train_f1(data).numpy()
y2 = train_f2(data).numpy()
np.testing.assert_equal(y1, y2)
for i in range(3):
y = train_func(data).numpy()
ys[symbolic].append(y)

for i in range(3):
np.testing.assert_equal(ys[False][i], ys[True][i])

def test_exclude_from_trace():
for symbolic in [False, True]:

@trace(symbolic=symbolic)
def f(x):
x = -x
with exclude_from_trace():
if i % 2:
x = -x
x = -x
return x
@pytest.mark.parametrize("trace_mode", [False, True])
def test_exclude_from_trace(trace_mode):
@trace(symbolic=trace_mode)
def f(x):
x = -x
with exclude_from_trace():
if i % 2:
x = -x
x = -x
return x

x = tensor([1])
x = tensor([1])

for i in range(3):
y = f(x).numpy()
np.testing.assert_equal(f(x).numpy(), y)
for i in range(3):
y = f(x).numpy()
np.testing.assert_equal(f(x).numpy(), y)


def test_print_in_trace():
@@ -191,21 +200,20 @@ def test_dump_volatile():
)


def test_trace_profiler():
for symbolic in [False, True]:

@trace(symbolic=symbolic, profiling=True)
def f(x):
return -x
@pytest.mark.parametrize("trace_mode", [False, True])
def test_trace_profiler(trace_mode):
@trace(symbolic=trace_mode, profiling=True)
def f(x):
return -x

x = tensor([1])
y = f(x).numpy()
x = tensor([1])
y = f(x).numpy()

f(x)
f(x) # XXX: has to run twice
f(x)
f(x) # XXX: has to run twice

out = f.get_profile()
assert out.get("profiler")
out = f.get_profile()
assert out.get("profiler")


@pytest.mark.skip(reason="force opt_level=0 when building graph")
@@ -306,20 +314,20 @@ def test_trace_cvt_bool():
np.testing.assert_equal(f(x).numpy(), False)


def test_trace_reshape():
for symbolic in [False, True]:
x1 = tensor(np.random.randn(2, 10, 10))
x2 = tensor(np.random.randn(4, 10, 10))
x3 = tensor(np.random.randn(8, 10, 10))
@pytest.mark.parametrize("trace_mode", [False, True])
def test_trace_reshape(trace_mode):
x1 = tensor(np.random.randn(2, 10, 10))
x2 = tensor(np.random.randn(4, 10, 10))
x3 = tensor(np.random.randn(8, 10, 10))

@trace(symbolic=symbolic, capture_as_const=True)
def f(x):
y = x.reshape(x.shape[0], 100)
return y
@trace(symbolic=trace_mode, capture_as_const=True)
def f(x):
y = x.reshape(x.shape[0], 100)
return y

f(x1)
f(x2)
f(x3)
f(x1)
f(x2)
f(x3)


def test_trace_topk():
@@ -387,20 +395,20 @@ def test_raise_on_trace():
assert catch_count == 1


def test_trace_broadcast():
for symbolic in [False, True]:
x1 = tensor(np.random.randn(3, 1, 1))
x2 = tensor(np.random.randn(1, 4, 1))
x3 = tensor(np.random.randn(1, 1, 5))
@pytest.mark.parametrize("trace_mode", [False, True])
def test_trace_broadcast(trace_mode):
x1 = tensor(np.random.randn(3, 1, 1))
x2 = tensor(np.random.randn(1, 4, 1))
x3 = tensor(np.random.randn(1, 1, 5))

@trace(symbolic=symbolic, capture_as_const=True)
def f(x):
y = F.broadcast_to(x, (3, 4, 5))
return y
@trace(symbolic=trace_mode, capture_as_const=True)
def f(x):
y = F.broadcast_to(x, (3, 4, 5))
return y

f(x1)
f(x2)
f(x3)
f(x1)
f(x2)
f(x3)


def test_trace_nms():
@@ -466,21 +474,20 @@ def test_slice():
y + y


def test_random():
@pytest.mark.parametrize("shape_mode", [False, True])
def test_random(shape_mode):
def run_test(op):
for symbolic_shape in [True, False]:

@trace(symbolic=True, symbolic_shape=symbolic_shape)
def f():
out = op(size=[10, 10])
out_shape = out.shape
assert out_shape is not None
if not isinstance(out_shape, tuple):
assert out.shape.numpy() is not None
return out

for _ in range(3):
f()
@trace(symbolic=True, symbolic_shape=shape_mode)
def f():
out = op(size=[10, 10])
out_shape = out.shape
assert out_shape is not None
if not isinstance(out_shape, tuple):
assert out.shape.numpy() is not None
return out

for _ in range(3):
f()

run_test(uniform)
run_test(normal)

Loading…
Cancel
Save