Browse Source

fix(imperative): fix dangling pointer in dispatcher

GitOrigin-RevId: b21f7e06f0
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
0bb4969696
1 changed files with 20 additions and 3 deletions
  1. +20
    -3
      imperative/python/src/dispatcher.cpp

+ 20
- 3
imperative/python/src/dispatcher.cpp View File

@@ -11,6 +11,7 @@


#include "./dispatcher.h" #include "./dispatcher.h"
#include "./pyext17.h" #include "./pyext17.h"
#include "megbrain/exception.h"
#include "megbrain/utils/hash.h" #include "megbrain/utils/hash.h"
#include "megbrain/utils/small_vector.h" #include "megbrain/utils/small_vector.h"


@@ -56,9 +57,25 @@ struct ObjectIdHash : std::hash<void*> {
} }
}; };


namespace {
using Container = std::vector<Frame>;
struct DispatcherStack: Container {
constexpr static size_t MAX_RECURSIVE_DEPTH = 1024u;
DispatcherStack() { reserve(MAX_RECURSIVE_DEPTH); }

template<typename... Args>
auto&& emplace_back_safely(Args&& ...args) {
mgb_throw_if(size() >= MAX_RECURSIVE_DEPTH, mgb::MegBrainError,
"recursion depth %zu is greater than the MAX_RECURSIVE_DEPTH(%zu)",
size(), MAX_RECURSIVE_DEPTH);
return emplace_back(std::forward<Args>(args)...);
}
};
} // anonymous namespace

struct Dispatcher { struct Dispatcher {
std::unordered_map<FastSig, std::unique_ptr<MRO>, FastSigHash> cache; std::unordered_map<FastSig, std::unique_ptr<MRO>, FastSigHash> cache;
std::vector<Frame> stack;
DispatcherStack stack;
std::unordered_map<py::object, std::unique_ptr<Handler>, ObjectIdHash> registry; std::unordered_map<py::object, std::unique_ptr<Handler>, ObjectIdHash> registry;


inline py::handle self() { inline py::handle self() {
@@ -78,7 +95,7 @@ struct Dispatcher {
return false; return false;
} }
} }
stack.emplace_back(it->second.get());
stack.emplace_back_safely(it->second.get());
return true; return true;
} }


@@ -145,7 +162,7 @@ public:
PyErr_SetString(PyExc_RuntimeError, "super called at top level"); PyErr_SetString(PyExc_RuntimeError, "super called at top level");
return nullptr; return nullptr;
} }
stack.emplace_back(stack.back()).mro_offset++;
stack.emplace_back_safely(stack.back()).mro_offset++;
return do_call([=](PyObject* func){return _PyObject_FastCall(func, const_cast<PyObject**>(args), nargs);}); return do_call([=](PyObject* func){return _PyObject_FastCall(func, const_cast<PyObject**>(args), nargs);});
} }




Loading…
Cancel
Save