|
@@ -10,6 +10,7 @@ |
|
|
*/ |
|
|
*/ |
|
|
|
|
|
|
|
|
#include "./interpreter_impl.h" |
|
|
#include "./interpreter_impl.h" |
|
|
|
|
|
#include "megbrain/common.h" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
using namespace mgb; |
|
|
using namespace mgb; |
|
@@ -58,11 +59,14 @@ SmallVector<void*> ChannelImpl::apply_op( |
|
|
input_infos.reserve(inputs.size()); |
|
|
input_infos.reserve(inputs.size()); |
|
|
SmallVector<LogicalTensorDesc> input_descs; |
|
|
SmallVector<LogicalTensorDesc> input_descs; |
|
|
input_descs.reserve(inputs.size()); |
|
|
input_descs.reserve(inputs.size()); |
|
|
|
|
|
std::unique_lock<decltype(m_mutex)> lock(m_mutex); |
|
|
for (auto i : inputs) { |
|
|
for (auto i : inputs) { |
|
|
auto info = reinterpret_cast<TensorInfo*>(i); |
|
|
auto info = reinterpret_cast<TensorInfo*>(i); |
|
|
|
|
|
mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!"); |
|
|
input_infos.push_back(info); |
|
|
input_infos.push_back(info); |
|
|
input_descs.push_back(info->desc); |
|
|
input_descs.push_back(info->desc); |
|
|
} |
|
|
} |
|
|
|
|
|
lock.unlock(); |
|
|
|
|
|
|
|
|
auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs); |
|
|
auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs); |
|
|
ApplyOp cmd{std::move(op)}; |
|
|
ApplyOp cmd{std::move(op)}; |
|
@@ -101,6 +105,7 @@ HostTensorND ChannelImpl::get_value(void* handle) { |
|
|
std::unique_lock<decltype(m_mutex)> lock(m_mutex); |
|
|
std::unique_lock<decltype(m_mutex)> lock(m_mutex); |
|
|
mgb_assert(!m_waitee); |
|
|
mgb_assert(!m_waitee); |
|
|
if (!info->value_fetched) { |
|
|
if (!info->value_fetched) { |
|
|
|
|
|
mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!"); |
|
|
m_waitee = info; |
|
|
m_waitee = info; |
|
|
m_worker.add_task(GetValue{info}); |
|
|
m_worker.add_task(GetValue{info}); |
|
|
m_cv.wait(lock, [&]() { |
|
|
m_cv.wait(lock, [&]() { |
|
@@ -222,6 +227,7 @@ void ChannelImpl::process_one_task(Command& cmd) { |
|
|
SmallVector<TensorPtr> tensor_inputs; |
|
|
SmallVector<TensorPtr> tensor_inputs; |
|
|
tensor_inputs.reserve(cmd.inputs.size()); |
|
|
tensor_inputs.reserve(cmd.inputs.size()); |
|
|
for (auto i : cmd.inputs) { |
|
|
for (auto i : cmd.inputs) { |
|
|
|
|
|
mgb_assert(i->ptr, "Invalid input tensor ptr!"); |
|
|
tensor_inputs.push_back(i->ptr); |
|
|
tensor_inputs.push_back(i->ptr); |
|
|
} |
|
|
} |
|
|
auto tensor_outputs = OpDef::apply_on_physical_tensor(*cmd.op, tensor_inputs); |
|
|
auto tensor_outputs = OpDef::apply_on_physical_tensor(*cmd.op, tensor_inputs); |
|
@@ -232,6 +238,7 @@ void ChannelImpl::process_one_task(Command& cmd) { |
|
|
} else if constexpr (std::is_same_v<T, Del>) { |
|
|
} else if constexpr (std::is_same_v<T, Del>) { |
|
|
free(cmd.dest); |
|
|
free(cmd.dest); |
|
|
} else if constexpr (std::is_same_v<T, GetValue>) { |
|
|
} else if constexpr (std::is_same_v<T, GetValue>) { |
|
|
|
|
|
mgb_assert(cmd.dest->ptr, "Invalid tensor ptr!"); |
|
|
cmd.dest->ptr->fetch_value(); |
|
|
cmd.dest->ptr->fetch_value(); |
|
|
MGB_LOCK_GUARD(m_mutex); |
|
|
MGB_LOCK_GUARD(m_mutex); |
|
|
cmd.dest->value_fetched = true; |
|
|
cmd.dest->value_fetched = true; |
|
@@ -243,6 +250,13 @@ void ChannelImpl::process_one_task(Command& cmd) { |
|
|
} |
|
|
} |
|
|
} catch (...) { |
|
|
} catch (...) { |
|
|
MGB_LOCK_GUARD(m_mutex); |
|
|
MGB_LOCK_GUARD(m_mutex); |
|
|
|
|
|
if constexpr (std::is_same_v<T, ApplyOp>) { |
|
|
|
|
|
for (auto oup : cmd.outputs) { |
|
|
|
|
|
oup->invalid = true; |
|
|
|
|
|
} |
|
|
|
|
|
} else if constexpr (std::is_same_v<T, Put>) { |
|
|
|
|
|
cmd.dest->invalid = true; |
|
|
|
|
|
} |
|
|
m_worker_exc = std::current_exception(); |
|
|
m_worker_exc = std::current_exception(); |
|
|
m_cv.notify_all(); |
|
|
m_cv.notify_all(); |
|
|
} |
|
|
} |
|
|