@@ -196,10 +196,11 @@ ValueRefList TracingTransformation::apply_transformation(
             return outputs;
         }
         bool is_const = create_tensor->kind() == CreateTensor::Const;
+        bool as_const = is_const || m_capture_as_const;
         auto wrapped_input = record_var(
-                outputs[0], is_const || m_capture_as_const,
-                is_const ? VarKind::Constant : VarKind::External);
-        auto wrapped_output = record_var(outputs[0], false, VarKind::Internal);
+                outputs[0], as_const, is_const ? VarKind::Constant : VarKind::External);
+        // bound data to outputs too to reduce runtime overhead for shape/value infer
+        auto wrapped_output = record_var(outputs[0], as_const, VarKind::Internal);
         auto input_id = wrapped_input->id();
         auto output_id = wrapped_output->id();
         m_seq.push_back({{}, {input_id}, {output_id}});
@@ -311,6 +312,18 @@ void CompiledTransformation::compile() {
     auto make_output = [&](TraceResult::VarInfo* var_info, SymbolVar node) {
         VarAccessor accessor;
         accessor.node = node.node();
+        if (auto bound_data = var_info->bound_data) {
+            accessor.shape_getter = [bound_data]() -> TensorShape {
+                return bound_data.shape()->as_tensor_shape();
+            };
+            accessor.data_getter = [bound_data]() -> DeviceTensorND {
+                return bound_data.dev_tensor()->as_nd();
+            };
+            accessor.value_getter = [bound_data]() -> HostTensorND {
+                return bound_data.numpy()->as_nd();
+            };
+            return accessor;
+        }
         if (var_info->data_required) {
             // reduce d2h when data is available
             // FIXME: compile should not change var_info in-place
@@ -410,16 +423,28 @@ void CompiledTransformation::compile() {
                         "internal node should be valid when used as input");
                 }
             }
-            input_vars.push_back(var_accessors[input].node);
+            auto& node = var_accessors[input].node;
+            if (input_vars.empty() && require_link && mm_io_link.node()) {
+                /*mgb_assert(
+                        !input_vars.empty(),
+                        "io-mm operator should have at least one input");*/
+                auto comp_node = mm_io_link.node()->comp_node();
+                // auto comp_node = input_vars[0]->comp_node();
+                node = opr::VirtualDep::make({SymbolVar(node), mm_io_link}, comp_node)
+                               .node();
+            }
+            input_vars.push_back(node);
         }
-        if (require_link && mm_io_link.node()) {
+        /*if (require_link && mm_io_link.node()) {
             mgb_assert(
                     !input_vars.empty(),
                     "io-mm operator should have at least one input");
-            input_vars[0] =
-                    opr::VirtualDep::make({SymbolVar(input_vars[0]), mm_io_link})
-                            .node();
-        }
+            auto comp_node = mm_io_link.node()->comp_node();
+            // auto comp_node = input_vars[0]->comp_node();
+            input_vars[0] = opr::VirtualDep::make(
+                                    {SymbolVar(input_vars[0]), mm_io_link}, comp_node)
+                                    .node();
+        }*/
         VarNodeArray output_vars;
         if (item.op) {
             output_vars = OpDef::apply_on_var_node(*item.op, input_vars);
@@ -479,6 +504,12 @@ void CompiledTransformation::recompile() {
 }
 
 void CompiledTransformation::assert_tensor_equal(ValueRef lhs, ValueRef rhs) {
+    if (!lhs.is<HostValue>()) {
+        lhs = lhs.numpy();
+    }
+    if (!rhs.is<HostValue>()) {
+        rhs = rhs.numpy();
+    }
     trace_assert(m_value_comparator(lhs, rhs), "tensors not equals");
 }
 
@@ -507,6 +538,7 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
                 break;
             }
             case VarKind::Constant: {
+                // expect host value here
                 mgb_assert(var.bound_data, "const var without data bound");
                 assert_tensor_equal(var.bound_data, value);
                 break;
@@ -611,7 +643,17 @@ ValueRefList CompiledTransformation::apply_create_tensor(
     trace_assert(item.op == nullptr, "operator mismatch");
     auto input_id = item.inputs[0];
     auto output_id = item.outputs[0];
-    auto tensor = imperative::apply(create_tensor, inputs)[0];
+    ValueRef tensor;
+    if (create_tensor.kind() == CreateTensor::Const) {
+        auto args = create_tensor.parse(inputs);
+        if (args.host) {
+            // performance issue
+            tensor = HostValue::make(*args.host);
+        }
+    }
+    if (!tensor) {
+        tensor = imperative::apply(create_tensor, inputs)[0];
+    }
     trace_input(input_id, tensor);
     return {trace_output(output_id)};
 }