GitOrigin-RevId: 474dc691a3
release-1.7
@@ -15,6 +15,7 @@ | |||||
#include "megbrain/graph/static_infer.h" | #include "megbrain/graph/static_infer.h" | ||||
#include "megbrain/imperative/ops/backward_graph.h" | #include "megbrain/imperative/ops/backward_graph.h" | ||||
#include "megbrain/imperative/ops/opr_attr.h" | #include "megbrain/imperative/ops/opr_attr.h" | ||||
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||||
#include "megbrain/opr/io.h" | #include "megbrain/opr/io.h" | ||||
#include "megbrain/opr/tensor_manip.h" | #include "megbrain/opr/tensor_manip.h" | ||||
#include "megbrain/opr/utility.h" | #include "megbrain/opr/utility.h" | ||||
@@ -509,6 +510,8 @@ SmallVector<LogicalTensorDesc> ProxyGraph::infer_output_attrs( | |||||
const OpDef& opdef, const SmallVector<Tensor*>& inputs) { | const OpDef& opdef, const SmallVector<Tensor*>& inputs) { | ||||
SmallVector<LogicalTensorDesc> ret; | SmallVector<LogicalTensorDesc> ret; | ||||
CUR_OPR_GUARD(get_proxy_opr(opdef, inputs)); | CUR_OPR_GUARD(get_proxy_opr(opdef, inputs)); | ||||
::mgb::opr::intl::WorkspaceLimitHook::set_impl( | |||||
m_graph.get(), ProxyGraph::get_workspace_limit); | |||||
do_shape_infer(true); | do_shape_infer(true); | ||||
for (auto&& i : m_cur_opr->usable_output()) { | for (auto&& i : m_cur_opr->usable_output()) { | ||||
mgb_assert(i->dtype().valid() && i->comp_node().valid()); | mgb_assert(i->dtype().valid() && i->comp_node().valid()); | ||||
@@ -547,6 +550,14 @@ void ProxyGraph::init_output_tensor( | |||||
// get proxy opr | // get proxy opr | ||||
auto proxy = m_cur_opr; | auto proxy = m_cur_opr; | ||||
// Workspace-limit callback: returns the sum, over every var in `workspaces`,
// of var->dtype().size(var->shape().total_nr_elems()).
// Neither `cn` nor `old_limit` is used by the computation.
// `workspaces` is captured by copy ([=]); it is declared outside this hunk.
auto get_workspace_size = [=](CompNode cn, size_t old_limit) { | |||||
size_t limit = 0; | |||||
for (auto&& var : workspaces) { | |||||
limit += var->dtype().size(var->shape().total_nr_elems()); | |||||
} | |||||
return limit; | |||||
}; | |||||
// Install the callback on m_graph before shape inference runs below.
::mgb::opr::intl::WorkspaceLimitHook::set_impl(m_graph.get(), get_workspace_size); | |||||
do_shape_infer(true); | do_shape_infer(true); | ||||
size_t j = 0; | size_t j = 0; | ||||
@@ -640,6 +651,8 @@ std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> ProxyGraph:: | |||||
const SmallVector<MemoryDesc>& inputs_mems) { | const SmallVector<MemoryDesc>& inputs_mems) { | ||||
auto opr = get_proxy_opr(def, inputs_tensors); | auto opr = get_proxy_opr(def, inputs_tensors); | ||||
CUR_OPR_GUARD(opr); | CUR_OPR_GUARD(opr); | ||||
::mgb::opr::intl::WorkspaceLimitHook::set_impl( | |||||
m_graph.get(), ProxyGraph::get_workspace_limit); | |||||
do_shape_infer(true); | do_shape_infer(true); | ||||
SmallVector<MemoryDesc> outputs; | SmallVector<MemoryDesc> outputs; | ||||
SmallVector<MemoryDesc> workspaces; | SmallVector<MemoryDesc> workspaces; | ||||
@@ -27,6 +27,11 @@ public: | |||||
// Hands the stored tm_async_error over to the caller by moving it into the
// returned unique_ptr.
static std::unique_ptr<MegBrainError> get_async_error() { | static std::unique_ptr<MegBrainError> get_async_error() { | ||||
return std::move(tm_async_error); | return std::move(tm_async_error); | ||||
} | } | ||||
// Workspace-limit callback used with WorkspaceLimitHook::set_impl.
// Returns the larger of cn.get_free_mem() and
// cn.get_max_block_size_available().
// `old_limit` is accepted to match the callback signature but is not used.
static size_t get_workspace_limit(CompNode cn, size_t old_limit) { | |||||
size_t free = cn.get_free_mem(); | |||||
size_t lmt = cn.get_max_block_size_available(); | |||||
return std::max(lmt, free); | |||||
} | |||||
/********************** Physical Tensor API **********************/ | /********************** Physical Tensor API **********************/ | ||||
@@ -273,6 +273,13 @@ public: | |||||
activate(); | activate(); | ||||
return m_mem_alloc->get_max_block_size_available(); | return m_mem_alloc->get_max_block_size_available(); | ||||
} | } | ||||
// Returns the free-memory value reported by cudaMemGetInfo.
// The total-memory value is fetched as well but discarded.
size_t get_free_mem() override { | |||||
// Activate the CUDA environment held by m_env before querying.
m_env.cuda_env().activate(); | |||||
size_t tot, free; | |||||
// cudaMemGetInfo takes (free, total) in that order; failures go through
// MGB_CUDA_CHECK.
MGB_CUDA_CHECK(cudaMemGetInfo(&free, &tot)); | |||||
return free; | |||||
} | |||||
#endif | #endif | ||||
Locator locator() override { return m_locator; } | Locator locator() override { return m_locator; } | ||||
@@ -336,6 +336,8 @@ public: | |||||
// Forwards to m_impl->get_max_block_size_available().
size_t get_max_block_size_available() const { | size_t get_max_block_size_available() const { | ||||
return m_impl->get_max_block_size_available(); | return m_impl->get_max_block_size_available(); | ||||
} | } | ||||
// Forwards to m_impl->get_free_mem().
size_t get_free_mem() const { return m_impl->get_free_mem(); } | |||||
#endif | #endif | ||||
//! change to another stream on the same memory node | //! change to another stream on the same memory node | ||||
@@ -519,6 +521,7 @@ protected: | |||||
} | } | ||||
// Default memory-query implementations: each returns 0 unless overridden.
virtual size_t get_used_memory() { return 0; } | virtual size_t get_used_memory() { return 0; } | ||||
virtual size_t get_max_block_size_available() { return 0; } | virtual size_t get_max_block_size_available() { return 0; } | ||||
// Default for the free-memory query; returns 0.
virtual size_t get_free_mem() { return 0; } | |||||
#endif | #endif | ||||
virtual Locator locator() = 0; | virtual Locator locator() = 0; | ||||
@@ -428,13 +428,12 @@ void WorkspaceLimitHook::set_impl(GetWorkspaceLimitImpl /* impl */) { | |||||
mgb_assert(false); | mgb_assert(false); | ||||
} | } | ||||
// Stub variant of get_impl(): it never returns a callback and always fails
// via mgb_assert(false).
// The signature appears twice below: the removed two-line form, then the
// reformatted one-line form.
const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl() | |||||
const { | |||||
const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl() const { | |||||
mgb_assert(false); | mgb_assert(false); | ||||
} | } | ||||
// Stub variant of set_impl(graph, impl): both parameters are unnamed and
// ignored, and the call always fails via mgb_assert(false).
// The signature appears twice below: the removed form, then the reformatted
// form.
void WorkspaceLimitHook::set_impl(ComputingGraph* /* graph */, | |||||
GetWorkspaceLimitImpl /* impl */) { | |||||
void WorkspaceLimitHook::set_impl( | |||||
ComputingGraph* /* graph */, GetWorkspaceLimitImpl /* impl */) { | |||||
mgb_assert(false); | mgb_assert(false); | ||||
} | } | ||||
@@ -447,13 +446,11 @@ void WorkspaceLimitHook::set_impl(GetWorkspaceLimitImpl impl) { | |||||
m_impl = std::move(impl); | m_impl = std::move(impl); | ||||
} | } | ||||
// Returns a const reference to the stored callback m_impl.
// The signature appears twice below: the removed two-line form, then the
// reformatted one-line form.
const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl() | |||||
const { | |||||
const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl() const { | |||||
return m_impl; | return m_impl; | ||||
} | } | ||||
void WorkspaceLimitHook::set_impl(ComputingGraph* graph, | |||||
GetWorkspaceLimitImpl impl) { | |||||
void WorkspaceLimitHook::set_impl(ComputingGraph* graph, GetWorkspaceLimitImpl impl) { | |||||
mgb_assert(graph->options().imperative_proxy_graph); | mgb_assert(graph->options().imperative_proxy_graph); | ||||
auto maker = []() { return std::make_shared<WorkspaceLimitHook>(); }; | auto maker = []() { return std::make_shared<WorkspaceLimitHook>(); }; | ||||
graph->options() | graph->options() | ||||
@@ -464,8 +461,7 @@ void WorkspaceLimitHook::set_impl(ComputingGraph* graph, | |||||
// Looks up the workspace-limit callback attached to `graph`.
// Asserts that graph->options().imperative_proxy_graph is set, fetches the
// WorkspaceLimitHook entries from the graph options' user_data, asserts that
// exactly one entry exists, and returns that entry's get_impl().
// The `auto container = ...` statement appears twice below: the removed
// two-line form, then the reformatted one-line form.
const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl( | const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl( | ||||
ComputingGraph* graph) { | ComputingGraph* graph) { | ||||
mgb_assert(graph->options().imperative_proxy_graph); | mgb_assert(graph->options().imperative_proxy_graph); | ||||
auto container = | |||||
graph->options().user_data.get_user_data<WorkspaceLimitHook>(); | |||||
auto container = graph->options().user_data.get_user_data<WorkspaceLimitHook>(); | |||||
mgb_assert(container.second == 1); | mgb_assert(container.second == 1); | ||||
return container.first[0]->get_impl(); | return container.first[0]->get_impl(); | ||||
} | } | ||||