|
|
@@ -16,6 +16,7 @@ import megengine |
|
|
|
import megengine.core.tensor.megbrain_graph as G |
|
|
|
import megengine.module as M |
|
|
|
from megengine import cgtools, tensor |
|
|
|
from megengine.core._trace_option import set_tensor_shape |
|
|
|
from megengine.core.ops import builtin as ops |
|
|
|
from megengine.core.tensor import megbrain_graph as G |
|
|
|
from megengine.core.tensor.core import apply |
|
|
@@ -274,3 +275,15 @@ def test_optimize_for_inference(): |
|
|
|
res = G.load_comp_graph_from_file(out) |
|
|
|
computing_input = res.output_vars_list[0].owner.inputs[0] |
|
|
|
assert computing_input.dtype == np.float16 |
|
|
|
|
|
|
|
|
|
|
|
def test_trace_cvt_bool(): |
|
|
|
set_tensor_shape(True) |
|
|
|
x = tensor([0], dtype=np.int32) |
|
|
|
|
|
|
|
@trace(symbolic=True) |
|
|
|
def f(x): |
|
|
|
return x.shape[0] == 0 |
|
|
|
|
|
|
|
for i in range(3): |
|
|
|
np.testing.assert_equal(f(x).numpy()[0], False) |