|
|
@@ -15,6 +15,7 @@ |
|
|
|
#include "megbrain/graph/static_infer.h" |
|
|
|
#include "megbrain/imperative/ops/backward_graph.h" |
|
|
|
#include "megbrain/imperative/ops/opr_attr.h" |
|
|
|
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" |
|
|
|
#include "megbrain/opr/io.h" |
|
|
|
#include "megbrain/opr/tensor_manip.h" |
|
|
|
#include "megbrain/opr/utility.h" |
|
|
@@ -509,6 +510,8 @@ SmallVector<LogicalTensorDesc> ProxyGraph::infer_output_attrs( |
|
|
|
const OpDef& opdef, const SmallVector<Tensor*>& inputs) { |
|
|
|
SmallVector<LogicalTensorDesc> ret; |
|
|
|
CUR_OPR_GUARD(get_proxy_opr(opdef, inputs)); |
|
|
|
::mgb::opr::intl::WorkspaceLimitHook::set_impl( |
|
|
|
m_graph.get(), ProxyGraph::get_workspace_limit); |
|
|
|
do_shape_infer(true); |
|
|
|
for (auto&& i : m_cur_opr->usable_output()) { |
|
|
|
mgb_assert(i->dtype().valid() && i->comp_node().valid()); |
|
|
@@ -547,6 +550,14 @@ void ProxyGraph::init_output_tensor( |
|
|
|
// get proxy opr |
|
|
|
auto proxy = m_cur_opr; |
|
|
|
|
|
|
|
// Workspace-limit callback: returns the sum, over every entry of the captured
// `workspaces`, of dtype().size(shape().total_nr_elems()).
// `workspaces` is captured through the by-copy default capture ([=]).
// Neither `cn` nor `old_limit` is read in the body, so the result does not
// depend on the arguments the caller supplies.
// NOTE(review): dtype().size(n) presumably yields the byte size of n elements,
// making the result a total in bytes -- confirm against the DType API.
auto get_workspace_size = [=](CompNode cn, size_t old_limit) { |
|
|
|
size_t limit = 0; |
|
|
|
for (auto&& var : workspaces) { |
|
|
|
// Accumulate this entry's size as reported by its dtype for its element count.
limit += var->dtype().size(var->shape().total_nr_elems()); |
|
|
|
} |
|
|
|
return limit; |
|
|
|
}; |
|
|
|
::mgb::opr::intl::WorkspaceLimitHook::set_impl(m_graph.get(), get_workspace_size); |
|
|
|
do_shape_infer(true); |
|
|
|
|
|
|
|
size_t j = 0; |
|
|
@@ -640,6 +651,8 @@ std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> ProxyGraph:: |
|
|
|
const SmallVector<MemoryDesc>& inputs_mems) { |
|
|
|
auto opr = get_proxy_opr(def, inputs_tensors); |
|
|
|
CUR_OPR_GUARD(opr); |
|
|
|
::mgb::opr::intl::WorkspaceLimitHook::set_impl( |
|
|
|
m_graph.get(), ProxyGraph::get_workspace_limit); |
|
|
|
do_shape_infer(true); |
|
|
|
SmallVector<MemoryDesc> outputs; |
|
|
|
SmallVector<MemoryDesc> workspaces; |
|
|
|