GitOrigin-RevId: e10f7c323a
release-1.1
@@ -570,7 +570,9 @@ class trace: | |||||
if h not in h2v: | if h not in h2v: | ||||
assert info.external | assert info.external | ||||
assert info.bound_data | assert info.bound_data | ||||
h2v[h] = graph.make_const(info.bound_data._dev_tensor()) | |||||
h2v[h] = graph.make_const( | |||||
info.bound_data.numpy(), dtype=info.dtype, device=info.device | |||||
) | |||||
ivars.append(h2v[h]) | ivars.append(h2v[h]) | ||||
ovars = apply(op, *ivars) | ovars = apply(op, *ivars) | ||||
assert len(ovars) == len(ohandles) | assert len(ovars) == len(ohandles) | ||||
@@ -150,7 +150,7 @@ def test_dump_volatile(): | |||||
(out,) = outputs | (out,) = outputs | ||||
assert ( | assert ( | ||||
cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1]) | cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1]) | ||||
== "SharedDeviceTensor" | |||||
== "ImmutableTensor" | |||||
) | ) | ||||
@@ -235,6 +235,18 @@ def test_optimize_for_inference(): | |||||
assert computing_input.dtype == np.float16 | assert computing_input.dtype == np.float16 | ||||
def test_optimize_for_inference_broadcast(): | |||||
a = tensor(np.ones(1, dtype=np.float32)) | |||||
@trace(capture_as_const=True, tensor_shape=True) | |||||
def f(): | |||||
(b,) = apply(ops.Broadcast(), a, tensor([1, 10], dtype=np.int32)) | |||||
return b | |||||
f() | |||||
f.dump(io.BytesIO()) | |||||
def test_trace_cvt_bool(): | def test_trace_cvt_bool(): | ||||
set_tensor_shape(True) | set_tensor_shape(True) | ||||
x = tensor([0], dtype=np.int32) | x = tensor([0], dtype=np.int32) | ||||
@@ -561,7 +561,7 @@ void ParamFusePass::apply(OptState &state) const { | |||||
} | } | ||||
SymbolVar new_var; | SymbolVar new_var; | ||||
bool is_default_format = var->layout().format.is_default(); | |||||
bool is_default_format = var->format().is_default(); | |||||
if (cg::is_static_var_value(var) && is_default_format) { | if (cg::is_static_var_value(var) && is_default_format) { | ||||
// use ImmutableTensor for inferable vars | // use ImmutableTensor for inferable vars | ||||
HostTensorND hv; | HostTensorND hv; | ||||