|
|
@@ -373,7 +373,7 @@ SmallVector<Handle> ChannelImpl::apply_op_impl( |
|
|
|
MGB_LOCK_GUARD(m_mutex); |
|
|
|
for (auto i : inputs) { |
|
|
|
auto info = reinterpret_cast<TensorInfo*>(i); |
|
|
|
mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!"); |
|
|
|
mgb_assert(!info->invalid, "an input tensor is unusable due to previous error"); |
|
|
|
input_infos.push_back(info); |
|
|
|
input_descs.push_back(info->desc); |
|
|
|
} |
|
|
@@ -403,7 +403,7 @@ HostTensorND ChannelImpl::get_value(Handle handle) { |
|
|
|
"invalid handle: %p", handle); |
|
|
|
auto info = reinterpret_cast<TensorInfo*>(handle); |
|
|
|
// donnot use info->value_fetched, it's unsafe |
|
|
|
mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!"); |
|
|
|
mgb_assert(!info->invalid, "tensor is unusable due to previous error"); |
|
|
|
return wait_tensor(info, TensorProp::HostValue)->get_value(); |
|
|
|
} |
|
|
|
|
|
|
@@ -776,7 +776,7 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd) { |
|
|
|
RECORD_EVENT(OpExecuteFinishEvent, apply_id); |
|
|
|
// End profiling operator |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void ChannelImpl::flush_apply_stack() { |
|
|
|
m_applying = true; |
|
|
|
auto& state = get_worker_state(); |
|
|
@@ -1002,7 +1002,7 @@ std::tuple<SmallVector<MemoryDesc>, SmallVector<TensorPtr>, SmallVector<TensorPt |
|
|
|
} |
|
|
|
return tensors; |
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
return {outputs_desc, alloc_storage(outputs_desc), alloc_storage(workspaces_desc)}; |
|
|
|
} |
|
|
|
|
|
|
@@ -1021,6 +1021,15 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) { |
|
|
|
RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandFinishEvent::Put); |
|
|
|
sample_on_device(cmd.dest->desc.comp_node, false); |
|
|
|
} else if constexpr (std::is_same_v<T, ApplyOp>) { |
|
|
|
for (auto& i : cmd.inputs) { |
|
|
|
if (i->invalid) { |
|
|
|
MGB_LOCK_GUARD(m_mutex); |
|
|
|
for (auto& i : cmd.outputs) { |
|
|
|
i->invalid = true; |
|
|
|
} |
|
|
|
return; |
|
|
|
} |
|
|
|
} |
|
|
|
m_apply_stack.push({cmd, 0, nullptr}); |
|
|
|
flush_apply_stack(); |
|
|
|
for (size_t i = 0; i < cmd.outputs.size(); ++i) { |
|
|
@@ -1085,21 +1094,23 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) { |
|
|
|
RECORD_EVENT(TensorCommandFinishEvent, tensor_id, TensorCommandFinishEvent::Del); |
|
|
|
sample_on_device(device, false); |
|
|
|
} else if constexpr (std::is_same_v<T, GetValue>) { |
|
|
|
if (cmd.dest->invalid) return; |
|
|
|
imperative_log_profile_begin("GetValue"); |
|
|
|
if (!cmd.dest->ptr && cmd.dest->evict_type != EvictType::NONE) { |
|
|
|
regenerate(cmd.dest); |
|
|
|
} |
|
|
|
mgb_assert(cmd.dest->ptr, "Invalid tensor ptr!"); |
|
|
|
cmd.dest->ptr->fetch_value(); |
|
|
|
MGB_LOCK_GUARD(m_mutex); |
|
|
|
notify_tensor_unsafe(cmd.dest); |
|
|
|
imperative_log_profile_end("GetValue"); |
|
|
|
} else if constexpr (std::is_same_v<T, SwapIn>) { |
|
|
|
if (cmd.dest->invalid) return; |
|
|
|
RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandEvent::SwapIn); |
|
|
|
produce_tensor(cmd.dest, Tensor::make(cmd.dest->h_value)); |
|
|
|
RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandFinishEvent::SwapIn); |
|
|
|
sample_on_device(cmd.dest->desc.comp_node, false); |
|
|
|
} else if constexpr (std::is_same_v<T, SwapOut>) { |
|
|
|
if (cmd.dest->invalid) return; |
|
|
|
RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandEvent::SwapOut); |
|
|
|
cmd.dest->h_value = cmd.dest->ptr->get_value(); |
|
|
|
if (cmd.dest->evict_type == EvictType::NONE) { |
|
|
@@ -1110,6 +1121,7 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) { |
|
|
|
RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandFinishEvent::SwapOut); |
|
|
|
sample_on_device(cmd.dest->desc.comp_node, false); |
|
|
|
} else if constexpr (std::is_same_v<T, Drop>) { |
|
|
|
if (cmd.dest->invalid) return; |
|
|
|
RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandEvent::Drop); |
|
|
|
do_drop(cmd.dest, true); |
|
|
|
RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandFinishEvent::Drop); |
|
|
@@ -1186,7 +1198,11 @@ void ChannelImpl::check_worker_exc_unsafe() { |
|
|
|
m_waitee = nullptr; |
|
|
|
std::exception_ptr exc; |
|
|
|
std::swap(exc, m_worker_exc); |
|
|
|
std::rethrow_exception(exc); |
|
|
|
try { |
|
|
|
std::rethrow_exception(exc); |
|
|
|
} catch (...) { |
|
|
|
throw AsyncError(); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|