|
|
@@ -47,7 +47,7 @@ VarNodeArray TraceResult::dump( |
|
|
|
auto& node = nodes[input]; |
|
|
|
// TODO: cambricon CompNode |
|
|
|
auto host = std::make_shared<HostTensorND>( |
|
|
|
CompNode::load("xpux"), shape, var.dtype); |
|
|
|
CompNode::load("xpux"), shape, *var.dtype); |
|
|
|
OperatorNodeConfig config; |
|
|
|
// if prefer_input_names, prefer names from dump args |
|
|
|
// else prefer names got from trace procedure |
|
|
@@ -211,7 +211,6 @@ ValueRefList TracingTransformation::apply_transformation( |
|
|
|
auto& var_info = m_vars[tracing_value->id()]; |
|
|
|
switch (get_attr->attr()) { |
|
|
|
case GetAttr::Shape: |
|
|
|
// TODO: reduce h2d when data or value is available |
|
|
|
var_info.shape_required = true; |
|
|
|
break; |
|
|
|
case GetAttr::Data: |
|
|
@@ -301,8 +300,8 @@ void CompiledTransformation::compile() { |
|
|
|
auto box = make_box<DeviceTensorND>(); |
|
|
|
// TODO: attach ref count, release early |
|
|
|
auto outputs = opr::InputCallback::make( |
|
|
|
*m_graph, [box] { return box->take_value(); }, var_info->device, |
|
|
|
var_info->dtype, var_info->shape, io_links, m_input_shape_static); |
|
|
|
*m_graph, [box] { return box->take_value(); }, *var_info->device, |
|
|
|
*var_info->dtype, var_info->shape, io_links, m_input_shape_static); |
|
|
|
// attach input_callback to io_links |
|
|
|
accessor.node = outputs[0].node(); |
|
|
|
io_links = {outputs[1]}; |
|
|
@@ -312,6 +311,11 @@ void CompiledTransformation::compile() { |
|
|
|
auto make_output = [&](TraceResult::VarInfo* var_info, SymbolVar node) { |
|
|
|
VarAccessor accessor; |
|
|
|
accessor.node = node.node(); |
|
|
|
if (var_info->data_required) { |
|
|
|
// reduce d2h when data is available |
|
|
|
// FIXME: compile should not change var_info in-place |
|
|
|
var_info->shape_required = false; |
|
|
|
} |
|
|
|
if (var_info->shape_required) { |
|
|
|
// TODO: use static infer manager for some vars? |
|
|
|
auto box = make_box<TensorShape>(); |
|
|
@@ -334,6 +338,12 @@ void CompiledTransformation::compile() { |
|
|
|
accessor.data_getter = [box]() -> DeviceTensorND { |
|
|
|
return box->get_value(); |
|
|
|
}; |
|
|
|
if (!accessor.shape_getter) { |
|
|
|
// also implement shape_getter |
|
|
|
accessor.shape_getter = [box]() -> TensorShape { |
|
|
|
return box->get_value().shape(); |
|
|
|
}; |
|
|
|
} |
|
|
|
} |
|
|
|
if (var_info->value_required) { |
|
|
|
struct ValueWithEvent { |
|
|
@@ -341,7 +351,7 @@ void CompiledTransformation::compile() { |
|
|
|
CompNode::Event* event = nullptr; |
|
|
|
}; |
|
|
|
auto box = make_box<ValueWithEvent>(); |
|
|
|
auto event = EventPool::without_timer().alloc_shared(var_info->device); |
|
|
|
auto event = EventPool::without_timer().alloc_shared(*var_info->device); |
|
|
|
auto callback = [box, event](DeviceTensorND data) { |
|
|
|
HostTensorND host_val; |
|
|
|
host_val.copy_from(data); |
|
|
@@ -355,7 +365,7 @@ void CompiledTransformation::compile() { |
|
|
|
}; |
|
|
|
SymbolVarArray inputs = io_links; |
|
|
|
inputs.insert(inputs.begin(), node); |
|
|
|
auto output = opr::OutputCallback::make({callback, false, true}, inputs); |
|
|
|
auto output = opr::OutputCallback::make({callback, true, true}, inputs); |
|
|
|
io_links = {output}; |
|
|
|
accessor.value_getter = [box]() -> HostTensorND { |
|
|
|
auto&& [value, event] = box->get_value(); |
|
|
@@ -486,11 +496,12 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) { |
|
|
|
DType dtype = *value.dtype(); |
|
|
|
CompNode device = *value.device(); |
|
|
|
trace_assert( |
|
|
|
var.dtype == dtype, "dtype mismatch: %s vs %s", |
|
|
|
var.dtype.name(), dtype.name()); |
|
|
|
*var.dtype == dtype, "dtype mismatch: %s vs %s", |
|
|
|
var.dtype->name(), dtype.name()); |
|
|
|
trace_assert( |
|
|
|
var.device == device, "comp_node mismatch: %s vs %s", |
|
|
|
var.device.to_string().c_str(), device.to_string().c_str()); |
|
|
|
*var.device == device, "comp_node mismatch: %s vs %s", |
|
|
|
var.device->to_string().c_str(), |
|
|
|
device.to_string().c_str()); |
|
|
|
} |
|
|
|
var_accessor.data_setter(value.dev_tensor()->as_nd()); |
|
|
|
break; |
|
|
@@ -535,17 +546,11 @@ ShapeValue::ref_t CompiledTransformation::TracedInfo::shape() const { |
|
|
|
} |
|
|
|
|
|
|
|
// Returns a DTypeValue::ref_t for the variable referenced by m_var.
//
// NOTE(review): as written, this body contains two return paths back to back:
//   1. fill m_dtype via DTypeValue::make(m_var->dtype) when it is unset, then
//      return m_dtype;
//   2. return m_var->dtype directly.
// The second return can never execute after the first. This looks like two
// revisions of the same body interleaved -- confirm which one is intended and
// drop the other.
DTypeValue::ref_t CompiledTransformation::TracedInfo::dtype() const { |
|
|
|
// Path 1: create m_dtype on first use from the dtype stored on m_var.
if (!m_dtype) { |
|
|
|
m_dtype = DTypeValue::make(m_var->dtype); |
|
|
|
} |
|
|
|
return m_dtype; |
|
|
|
// Path 2: unreachable as written (follows an unconditional return).
return m_var->dtype; |
|
|
|
} |
|
|
|
|
|
|
|
// Returns a CompNodeValue::ref_t for the variable referenced by m_var.
//
// NOTE(review): as written, this body contains two return paths back to back:
//   1. fill m_comp_node via CompNodeValue::make(m_var->device) when it is
//      unset, then return m_comp_node;
//   2. return m_var->device directly.
// The second return can never execute after the first. This looks like two
// revisions of the same body interleaved -- confirm which one is intended and
// drop the other.
CompNodeValue::ref_t CompiledTransformation::TracedInfo::comp_node() const { |
|
|
|
// Path 1: create m_comp_node on first use from the device stored on m_var.
if (!m_comp_node) { |
|
|
|
m_comp_node = CompNodeValue::make(m_var->device); |
|
|
|
} |
|
|
|
return m_comp_node; |
|
|
|
// Path 2: unreachable as written (follows an unconditional return).
return m_var->device; |
|
|
|
} |
|
|
|
auto CompiledTransformation::TracedInfo::accessor() const -> const VarAccessor& { |
|
|
|
return *m_accessor; |
|
|
|