@@ -34,8 +34,13 @@ Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
     info->desc.layout = value.layout();
     info->desc.comp_node = value.comp_node();
+    info->desc.value = value.proxy_to_default_cpu();
+    info->h_value = value;
+    m_valid_handle.insert(info);
     m_buffer.enqueue(Put{info, value, no_cache});
     if (m_async_level == 0) {
         sync();
+        info->desc.comp_node.sync();
     }
     return info;
 }
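
Read as a whole, the change to `put` keeps two host-side views of the value on the new handle: `desc.value`, a default-CPU proxy, and `h_value`, the original host tensor, so the default-CPU dispatch path further down can read the data without waiting for the device copy. Below is a minimal standalone sketch of that bookkeeping; `Value`, `Handle` and `put_sketch` are simplified stand-ins for illustration, not the real `HostTensorND`/`TensorInfo`/`ChannelImpl::put`.

```cpp
// Illustrative sketch only: Value, Handle and put_sketch are stand-ins,
// not MegEngine's HostTensorND / TensorInfo / ChannelImpl::put.
#include <cassert>
#include <string>
#include <vector>

struct Value {
    std::vector<float> data;
    std::string comp_node;      // e.g. "gpu0" or "cpu0"
};

struct Handle {
    Value h_value;              // original host tensor, kept for host compute
    Value cpu_proxy;            // stands in for desc.value (default-CPU proxy)
    bool device_ready = false;  // becomes true once the async upload finishes
};

Handle* put_sketch(const Value& value) {
    auto* info = new Handle{};
    info->h_value = value;
    info->cpu_proxy = value;            // real code: value.proxy_to_default_cpu()
    info->cpu_proxy.comp_node = "cpu0";
    // ... enqueue the asynchronous upload to the device here ...
    return info;
}

int main() {
    Handle* h = put_sketch({{1.f, 2.f, 3.f}, "gpu0"});
    // The host value is readable immediately, even before device_ready flips.
    assert(!h->h_value.data.empty());
    delete h;
    return 0;
}
```
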
@@ -90,14 +95,19 @@ void ChannelImpl::dispatch_default_cpu(
     {
         MGB_LOCK_GUARD(m_mutex);
         for (auto&& info : input_infos) {
             mgb_assert(info->ptr, "invalid tensor ptr!");
+            auto input_cn = info->desc.comp_node;
             if (!output_cn.valid()) {
-                output_cn = info->ptr->comp_node();
+                output_cn = input_cn;
             } else {
-                mgb_assert(output_cn == info->ptr->comp_node(), "cannot decide output comp node");
+                mgb_assert(output_cn == input_cn, "cannot decide output comp node");
             }
 
-            mgb_assert(info->ptr->try_get_value(), "no valid host value");
-            input_tensornds.emplace_back(info->ptr->get_value().proxy_to_default_cpu());
+            if (info->ptr && info->ptr->try_get_value()) {
+                input_tensornds.emplace_back(info->ptr->get_value().proxy_to_default_cpu());
+            } else {
+                mgb_assert(!info->h_value.empty(), "inp->h_value is empty!");
+                input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
+            }
         }
     }
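
The input loop above now prefers the already-computed device value when it is ready (`try_get_value()`), and otherwise falls back to the `h_value` recorded by `put`; the output comp node is taken from the input descriptors rather than from the device tensors. A compact standalone model of that fallback, with simplified types in place of the real `TensorInfo`/`HostTensorND` API:

```cpp
// Illustrative model of the input fallback only; not the real MegEngine API.
#include <cassert>
#include <optional>
#include <vector>

struct Input {
    std::optional<std::vector<float>> device_value;  // stands in for ptr->try_get_value()
    std::vector<float> h_value;                      // host value recorded by put()
};

// Mirrors the if/else above: use the device value when ready, else the host copy.
std::vector<std::vector<float>> gather_host_inputs(const std::vector<Input>& inputs) {
    std::vector<std::vector<float>> tensornds;
    for (const auto& in : inputs) {
        if (in.device_value) {
            tensornds.push_back(*in.device_value);
        } else {
            assert(!in.h_value.empty() && "inp->h_value is empty!");
            tensornds.push_back(in.h_value);
        }
    }
    return tensornds;
}

int main() {
    std::vector<Input> inputs = {
        {std::vector<float>{1.f, 2.f}, {}},  // device value already computed
        {std::nullopt, {3.f, 4.f}},          // only the host value is available
    };
    auto vals = gather_host_inputs(inputs);
    assert(vals.size() == 2 && vals[1][0] == 3.f);
    return 0;
}
```
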
@@ -116,18 +126,12 @@ void ChannelImpl::dispatch_default_cpu(
     SmallVector<TensorInfo*> output_infos;
     output_infos.reserve(output_descs.size());
     for (auto&& tensornd : output_tensornds) {
         // tensornd -> host_tensornd
         HostTensorND host_tensornd = HostTensorND::make_proxy(tensornd)
                 .proxy_to_comp_node(output_cn);
-        // tensornd -> desc
-        LogicalTensorDesc desc = {tensornd.layout(), output_cn, tensornd};
-        // tensornd -> tensor
-        auto info = alloc();
-        info->desc = desc;
-        m_valid_handle.insert(info);
+        // use `put` for consistency
+        auto info = reinterpret_cast<TensorInfo*>(put(host_tensornd, false));
         mgb_assert(info->desc.layout.ndim != 0);
         output_infos.push_back(info);
-        info->ptr = Tensor::make(host_tensornd, true); // host_only=true
-        info->value_fetched = true;
         outputs->push_back(info);
     }
@@ -159,6 +163,11 @@ void ChannelImpl::dispatch_kernel(
     for (auto&& desc : output_descs) {
         auto info = alloc();
         info->desc = desc;
+        // make sure desc's value is consistent with h_value
+        if (!info->desc.value.empty()) {
+            info->h_value = HostTensorND::make_proxy(desc.value)
+                    .proxy_to_comp_node(desc.comp_node);
+        }
         m_valid_handle.insert(info);
         cmd.outputs.push_back(info);
         outputs->push_back(info);
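
The block added to `dispatch_kernel` keeps `h_value` in step with a statically known `desc.value`: when value inference has already produced the output's contents, the same data is also materialized as a host tensor on the output's comp node. A rough standalone sketch of that invariant, using simplified stand-ins rather than the real `LogicalTensorDesc`/`HostTensorND`:

```cpp
// Sketch of the desc.value / h_value consistency rule; stand-in types only.
#include <cassert>
#include <string>
#include <vector>

struct Desc {
    std::vector<float> value;    // statically inferred value, may be empty
    std::string comp_node;
};

struct OutputInfo {
    Desc desc;
    std::vector<float> h_value;  // host copy consumed by default-CPU dispatch
    std::string h_comp_node;
};

// Mirrors the added block: only when the value is known up front,
// keep an equivalent host tensor tagged with the output's comp node.
void sync_h_value(OutputInfo& info) {
    if (!info.desc.value.empty()) {
        info.h_value = info.desc.value;          // make_proxy(desc.value)
        info.h_comp_node = info.desc.comp_node;  // .proxy_to_comp_node(...)
    }
}

int main() {
    OutputInfo out{{{1.f, 2.f}, "xpu0"}, {}, {}};
    sync_h_value(out);
    assert(out.h_value.size() == 2 && out.h_comp_node == "xpu0");
    return 0;
}
```
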
@@ -220,7 +229,6 @@ SmallVector<Handle> ChannelImpl::apply_op(
             break;
         }
     }
 
-    mgb_assert(outputs.size() > 0, "Invalid dispatch mode!");
     return outputs;
 }