Browse Source

fix(mge/interpreter): add check for invalid tensor ptr

GitOrigin-RevId: e8edcd92a4
release-1.1
Megvii Engine Team 4 years ago
parent
commit
51fa530d2a
2 changed files with 15 additions and 0 deletions
  1. +14
    -0
      imperative/src/impl/interpreter_impl.cpp
  2. +1
    -0
      imperative/src/impl/interpreter_impl.h

+ 14
- 0
imperative/src/impl/interpreter_impl.cpp View File

@@ -10,6 +10,7 @@
*/

#include "./interpreter_impl.h"
#include "megbrain/common.h"


using namespace mgb;
@@ -58,11 +59,14 @@ SmallVector<void*> ChannelImpl::apply_op(
input_infos.reserve(inputs.size());
SmallVector<LogicalTensorDesc> input_descs;
input_descs.reserve(inputs.size());
std::unique_lock<decltype(m_mutex)> lock(m_mutex);
for (auto i : inputs) {
auto info = reinterpret_cast<TensorInfo*>(i);
mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!");
input_infos.push_back(info);
input_descs.push_back(info->desc);
}
lock.unlock();

auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
ApplyOp cmd{std::move(op)};
@@ -101,6 +105,7 @@ HostTensorND ChannelImpl::get_value(void* handle) {
std::unique_lock<decltype(m_mutex)> lock(m_mutex);
mgb_assert(!m_waitee);
if (!info->value_fetched) {
mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!");
m_waitee = info;
m_worker.add_task(GetValue{info});
m_cv.wait(lock, [&]() {
@@ -222,6 +227,7 @@ void ChannelImpl::process_one_task(Command& cmd) {
SmallVector<TensorPtr> tensor_inputs;
tensor_inputs.reserve(cmd.inputs.size());
for (auto i : cmd.inputs) {
mgb_assert(i->ptr, "Invalid input tensor ptr!");
tensor_inputs.push_back(i->ptr);
}
auto tensor_outputs = OpDef::apply_on_physical_tensor(*cmd.op, tensor_inputs);
@@ -232,6 +238,7 @@ void ChannelImpl::process_one_task(Command& cmd) {
} else if constexpr (std::is_same_v<T, Del>) {
free(cmd.dest);
} else if constexpr (std::is_same_v<T, GetValue>) {
mgb_assert(cmd.dest->ptr, "Invalid tensor ptr!");
cmd.dest->ptr->fetch_value();
MGB_LOCK_GUARD(m_mutex);
cmd.dest->value_fetched = true;
@@ -243,6 +250,13 @@ void ChannelImpl::process_one_task(Command& cmd) {
}
} catch (...) {
MGB_LOCK_GUARD(m_mutex);
if constexpr (std::is_same_v<T, ApplyOp>) {
for (auto oup : cmd.outputs) {
oup->invalid = true;
}
} else if constexpr (std::is_same_v<T, Put>) {
cmd.dest->invalid = true;
}
m_worker_exc = std::current_exception();
m_cv.notify_all();
}


+ 1
- 0
imperative/src/impl/interpreter_impl.h View File

@@ -28,6 +28,7 @@ struct TensorInfo {
TensorPtr ptr;
LogicalTensorDesc desc;
bool value_fetched = false;
bool invalid = false;
};

struct Put {


Loading…
Cancel
Save