Browse Source

fix(mge): fix optimize_for_inference during trace.dump

GitOrigin-RevId: e10f7c323a
release-1.1
Megvii Engine Team 4 years ago
parent
commit
59a9275c66
3 changed files with 17 additions and 3 deletions
  1. +3
    -1
      imperative/python/megengine/jit/tracing.py
  2. +13
    -1
      imperative/python/test/unit/test_tracing.py
  3. +1
    -1
      src/gopt/impl/inference.cpp

+ 3
- 1
imperative/python/megengine/jit/tracing.py View File

@@ -570,7 +570,9 @@ class trace:
if h not in h2v: if h not in h2v:
assert info.external assert info.external
assert info.bound_data assert info.bound_data
h2v[h] = graph.make_const(info.bound_data._dev_tensor())
h2v[h] = graph.make_const(
info.bound_data.numpy(), dtype=info.dtype, device=info.device
)
ivars.append(h2v[h]) ivars.append(h2v[h])
ovars = apply(op, *ivars) ovars = apply(op, *ivars)
assert len(ovars) == len(ohandles) assert len(ovars) == len(ohandles)


+ 13
- 1
imperative/python/test/unit/test_tracing.py View File

@@ -150,7 +150,7 @@ def test_dump_volatile():
(out,) = outputs (out,) = outputs
assert ( assert (
cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1]) cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1])
== "SharedDeviceTensor"
== "ImmutableTensor"
) )




@@ -235,6 +235,18 @@ def test_optimize_for_inference():
assert computing_input.dtype == np.float16 assert computing_input.dtype == np.float16




def test_optimize_for_inference_broadcast():
a = tensor(np.ones(1, dtype=np.float32))

@trace(capture_as_const=True, tensor_shape=True)
def f():
(b,) = apply(ops.Broadcast(), a, tensor([1, 10], dtype=np.int32))
return b

f()
f.dump(io.BytesIO())


def test_trace_cvt_bool(): def test_trace_cvt_bool():
set_tensor_shape(True) set_tensor_shape(True)
x = tensor([0], dtype=np.int32) x = tensor([0], dtype=np.int32)


+ 1
- 1
src/gopt/impl/inference.cpp View File

@@ -561,7 +561,7 @@ void ParamFusePass::apply(OptState &state) const {
} }


SymbolVar new_var; SymbolVar new_var;
bool is_default_format = var->layout().format.is_default();
bool is_default_format = var->format().is_default();
if (cg::is_static_var_value(var) && is_default_format) { if (cg::is_static_var_value(var) && is_default_format) {
// use ImmutableTensor for inferable vars // use ImmutableTensor for inferable vars
HostTensorND hv; HostTensorND hv;


Loading…
Cancel
Save