diff --git a/imperative/python/src/dispatcher.cpp b/imperative/python/src/dispatcher.cpp index 48b1f253..79e93e26 100644 --- a/imperative/python/src/dispatcher.cpp +++ b/imperative/python/src/dispatcher.cpp @@ -11,6 +11,7 @@ #include "./dispatcher.h" #include "./pyext17.h" +#include "megbrain/exception.h" #include "megbrain/utils/hash.h" #include "megbrain/utils/small_vector.h" @@ -56,9 +57,25 @@ struct ObjectIdHash : std::hash { } }; +namespace { +using Container = std::vector; +struct DispatcherStack: Container { + constexpr static size_t MAX_RECURSIVE_DEPTH = 1024u; + DispatcherStack() { reserve(MAX_RECURSIVE_DEPTH); } + + template + auto&& emplace_back_safely(Args&& ...args) { + mgb_throw_if(size() >= MAX_RECURSIVE_DEPTH, mgb::MegBrainError, + "recursion depth %zu is greater than the MAX_RECURSIVE_DEPTH(%zu)", + size(), MAX_RECURSIVE_DEPTH); + return emplace_back(std::forward(args)...); + } +}; +} // anonymous namespace + struct Dispatcher { std::unordered_map, FastSigHash> cache; - std::vector stack; + DispatcherStack stack; std::unordered_map, ObjectIdHash> registry; inline py::handle self() { @@ -78,7 +95,7 @@ struct Dispatcher { return false; } } - stack.emplace_back(it->second.get()); + stack.emplace_back_safely(it->second.get()); return true; } @@ -145,7 +162,7 @@ public: PyErr_SetString(PyExc_RuntimeError, "super called at top level"); return nullptr; } - stack.emplace_back(stack.back()).mro_offset++; + stack.emplace_back_safely(stack.back()).mro_offset++; return do_call([=](PyObject* func){return _PyObject_FastCall(func, const_cast(args), nargs);}); }