|
|
@@ -54,21 +54,25 @@ void ChannelImpl::del(void* handle) { |
|
|
|
SmallVector<void*> ChannelImpl::apply_op( |
|
|
|
std::shared_ptr<OpDef> op, |
|
|
|
const SmallVector<void*>& inputs) { |
|
|
|
SmallVector<TensorInfo*> input_infos; |
|
|
|
input_infos.reserve(inputs.size()); |
|
|
|
SmallVector<LogicalTensorDesc> input_descs; |
|
|
|
input_descs.reserve(inputs.size()); |
|
|
|
for (auto h : inputs) { |
|
|
|
auto info = reinterpret_cast<TensorInfo*>(h); |
|
|
|
for (auto i : inputs) { |
|
|
|
auto info = reinterpret_cast<TensorInfo*>(i); |
|
|
|
input_infos.push_back(info); |
|
|
|
input_descs.push_back(info->desc); |
|
|
|
} |
|
|
|
auto output_descs = OpDef::infer_output_attrs_fallible(*op, input_descs); |
|
|
|
ApplyOp cmd{std::move(op)}; |
|
|
|
cmd.inputs.reserve(inputs.size()); |
|
|
|
for (auto i : inputs) { |
|
|
|
cmd.inputs.push_back(reinterpret_cast<TensorInfo*>(i)); |
|
|
|
} |
|
|
|
cmd.inputs = std::move(input_infos); |
|
|
|
cmd.outputs.reserve(output_descs.size()); |
|
|
|
SmallVector<void*> outputs; |
|
|
|
bool is_fallible = false; |
|
|
|
for (auto&& desc : output_descs) { |
|
|
|
if (desc.layout.ndim == 0) { |
|
|
|
is_fallible = true; |
|
|
|
} |
|
|
|
auto info = alloc(); |
|
|
|
info->desc = desc; |
|
|
|
m_valid_handle.insert(info); |
|
|
@@ -76,6 +80,9 @@ SmallVector<void*> ChannelImpl::apply_op( |
|
|
|
outputs.push_back(info); |
|
|
|
} |
|
|
|
m_worker.add_task(std::move(cmd)); |
|
|
|
if (is_fallible && m_async_level <= 1) { |
|
|
|
sync(); |
|
|
|
} |
|
|
|
return outputs; |
|
|
|
} |
|
|
|
|
|
|
@@ -162,7 +169,12 @@ void ChannelImpl::close() { |
|
|
|
} |
|
|
|
|
|
|
|
void ChannelImpl::config_async_level(int level) { |
|
|
|
mgb_assert(0); |
|
|
|
mgb_assert(level <= 2 and level >= 0, "async_level should be 0, 1 or 2"); |
|
|
|
m_async_level = level; |
|
|
|
} |
|
|
|
|
|
|
|
int ChannelImpl::get_async_level() { |
|
|
|
return m_async_level; |
|
|
|
} |
|
|
|
|
|
|
|
TensorInfo* ChannelImpl::alloc() { |
|
|
|