|
|
@@ -150,7 +150,7 @@ def test_dump_volatile(): |
|
|
|
(out,) = outputs |
|
|
|
assert ( |
|
|
|
cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1]) |
|
|
|
== "SharedDeviceTensor" |
|
|
|
== "ImmutableTensor" |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
@@ -235,6 +235,18 @@ def test_optimize_for_inference(): |
|
|
|
assert computing_input.dtype == np.float16 |
|
|
|
|
|
|
|
|
|
|
|
def test_optimize_for_inference_broadcast(): |
|
|
|
a = tensor(np.ones(1, dtype=np.float32)) |
|
|
|
|
|
|
|
@trace(capture_as_const=True, tensor_shape=True) |
|
|
|
def f(): |
|
|
|
(b,) = apply(ops.Broadcast(), a, tensor([1, 10], dtype=np.int32)) |
|
|
|
return b |
|
|
|
|
|
|
|
f() |
|
|
|
f.dump(io.BytesIO()) |
|
|
|
|
|
|
|
|
|
|
|
def test_trace_cvt_bool(): |
|
|
|
set_tensor_shape(True) |
|
|
|
x = tensor([0], dtype=np.int32) |
|
|
|