Browse Source

fix(mgb): remove static mem record from tee

GitOrigin-RevId: ac61b2a5eb
release-1.5
Megvii Engine Team huangxinda 4 years ago
parent
commit
07de15713c
8 changed files with 27 additions and 17 deletions
  1. +10
    -0
      sdk/load-and-run/src/mgblar.cpp
  2. +2
    -2
      src/core/impl/graph/cg_impl_seq.cpp
  3. +2
    -1
      src/core/impl/graph/cg_impl_seq.h
  4. +6
    -3
      src/core/impl/graph/var_node_mem_mgr/seq_mem_opt.cpp
  5. +2
    -2
      src/core/impl/graph/var_node_mem_mgr/static_mem_alloc/impl.cpp
  6. +2
    -1
      src/core/include/megbrain/graph/bases.h
  7. +1
    -7
      src/plugin/impl/static_mem_record.cpp
  8. +2
    -1
      src/plugin/include/megbrain/plugin/static_mem_record.h

+ 10
- 0
sdk/load-and-run/src/mgblar.cpp View File

@@ -141,9 +141,13 @@ R"__usage__(
level 2 the computing graph can be destructed to reduce memory usage. Read level 2 the computing graph can be destructed to reduce memory usage. Read
the doc of `ComputingGraph::Options::comp_node_seq_record_level` for more the doc of `ComputingGraph::Options::comp_node_seq_record_level` for more
details. details.
)__usage__"
#ifndef __IN_TEE_ENV__
R"__usage__(
--get-static-mem-info <svgname> --get-static-mem-info <svgname>
Record the static graph's static memory info. Record the static graph's static memory info.
)__usage__" )__usage__"
#endif
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
R"__usage__( R"__usage__(
--full-run --full-run
@@ -538,7 +542,9 @@ struct Args {
#endif #endif
bool reproducible = false; bool reproducible = false;
std::string fast_run_cache_path; std::string fast_run_cache_path;
#ifndef __IN_TEE_ENV__
std::string static_mem_svg_path; std::string static_mem_svg_path;
#endif
bool copy_to_host = false; bool copy_to_host = false;
int nr_run = 10; int nr_run = 10;
int nr_warmup = 1; int nr_warmup = 1;
@@ -797,9 +803,11 @@ void run_test_st(Args &env) {
} }


auto func = env.load_ret.graph_compile(out_spec); auto func = env.load_ret.graph_compile(out_spec);
#ifndef __IN_TEE_ENV__
if (!env.static_mem_svg_path.empty()) { if (!env.static_mem_svg_path.empty()) {
func->get_static_memory_alloc_info(env.static_mem_svg_path); func->get_static_memory_alloc_info(env.static_mem_svg_path);
} }
#endif
auto warmup = [&]() { auto warmup = [&]() {
printf("=== prepare: %.3fms; going to warmup\n", printf("=== prepare: %.3fms; going to warmup\n",
timer.get_msecs_reset()); timer.get_msecs_reset());
@@ -1383,6 +1391,7 @@ Args Args::from_argv(int argc, char **argv) {
graph_opt.comp_node_seq_record_level = 2; graph_opt.comp_node_seq_record_level = 2;
continue; continue;
} }
#ifndef __IN_TEE_ENV__
if (!strcmp(argv[i], "--get-static-mem-info")) { if (!strcmp(argv[i], "--get-static-mem-info")) {
++i; ++i;
mgb_assert(i < argc, "value not given for --get-static-mem-info"); mgb_assert(i < argc, "value not given for --get-static-mem-info");
@@ -1393,6 +1402,7 @@ Args Args::from_argv(int argc, char **argv) {
ret.static_mem_svg_path.c_str()); ret.static_mem_svg_path.c_str());
continue; continue;
} }
#endif
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
if (!strcmp(argv[i], "--fast-run")) { if (!strcmp(argv[i], "--fast-run")) {
ret.use_fast_run = true; ret.use_fast_run = true;


+ 2
- 2
src/core/impl/graph/cg_impl_seq.cpp View File

@@ -491,7 +491,7 @@ AsyncExecutable& ComputingGraphImpl::ComputingSequence::execute() {
do_execute(nullptr); do_execute(nullptr);
return *this; return *this;
} }
#ifndef __IN_TEE_ENV__
void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info( void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info(
const std::string& svg_name) { const std::string& svg_name) {
check_not_finalized(); check_not_finalized();
@@ -523,7 +523,7 @@ void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info(
"svg_name must be end with \".svg\"\n"); "svg_name must be end with \".svg\"\n");
recorder.show(svg_name); recorder.show(svg_name);
} }
#endif
AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() { AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() {
do_wait(true); do_wait(true);
return *this; return *this;


+ 2
- 1
src/core/impl/graph/cg_impl_seq.h View File

@@ -170,9 +170,10 @@ public:
} }


std::unique_ptr<RecordedComputingSequence> as_recorded_seq(); std::unique_ptr<RecordedComputingSequence> as_recorded_seq();
#ifndef __IN_TEE_ENV__
void get_static_memory_alloc_info( void get_static_memory_alloc_info(
const std::string& svg_name = "static_mem_record.svg") override; const std::string& svg_name = "static_mem_record.svg") override;
#endif
}; };


class ComputingGraphImpl::MegDNNDtorCheck : public NonCopyableObj { class ComputingGraphImpl::MegDNNDtorCheck : public NonCopyableObj {


+ 6
- 3
src/core/impl/graph/var_node_mem_mgr/seq_mem_opt.cpp View File

@@ -178,18 +178,19 @@ bool SeqMemOptimizer::run_static_mem_alloc() {
ThinHashMap<MemAllocPlan::Chunk*, MemChunkLifeInterval> chk2interval; ThinHashMap<MemAllocPlan::Chunk*, MemChunkLifeInterval> chk2interval;


// get all memory chunks // get all memory chunks
#ifndef __IN_TEE_ENV__
if (StaticMemRecorder::Instance().valid()) { if (StaticMemRecorder::Instance().valid()) {
StaticMemRecorder::Instance().clear_opr_seq(); StaticMemRecorder::Instance().clear_opr_seq();
} }
#endif
for (size_t idx = 0; idx < m_cur_seq_full->size(); ++ idx) { for (size_t idx = 0; idx < m_cur_seq_full->size(); ++ idx) {
OperatorNodeBase *opr = m_cur_seq_full->at(idx); OperatorNodeBase *opr = m_cur_seq_full->at(idx);
#ifndef __IN_TEE_ENV__
if (StaticMemRecorder::Instance().valid()) { if (StaticMemRecorder::Instance().valid()) {
StaticMemRecorder::Instance().regist_opr_seq( StaticMemRecorder::Instance().regist_opr_seq(
{idx, 0, opr->name()}); {idx, 0, opr->name()});
} }
#endif
auto &&dep_map = opr->node_prop().dep_map(); auto &&dep_map = opr->node_prop().dep_map();


if (in_sys_alloc(opr)) { if (in_sys_alloc(opr)) {
@@ -358,6 +359,7 @@ bool SeqMemOptimizer::run_static_mem_alloc_on_comp_node(
chk.chunk->mem_alloc_status.set_static_offset( chk.chunk->mem_alloc_status.set_static_offset(
allocator->get_start_addr(&chk)); allocator->get_start_addr(&chk));
} }
#ifndef __IN_TEE_ENV__
auto& recorder = StaticMemRecorder::Instance(); auto& recorder = StaticMemRecorder::Instance();
if (recorder.valid()) { if (recorder.valid()) {
for (size_t i = 0; i < chunks.size(); i++) { for (size_t i = 0; i < chunks.size(); i++) {
@@ -366,6 +368,7 @@ bool SeqMemOptimizer::run_static_mem_alloc_on_comp_node(
} }
recorder.regist_peak_mem_size(size); recorder.regist_peak_mem_size(size);
} }
#endif
} }


return should_realloc; return should_realloc;


+ 2
- 2
src/core/impl/graph/var_node_mem_mgr/static_mem_alloc/impl.cpp View File

@@ -119,7 +119,7 @@ StaticMemAlloc& StaticMemAllocImplHelper::solve() {
do_solve(); do_solve();


check_result_and_calc_lower_bound(); check_result_and_calc_lower_bound();
#ifndef __IN_TEE_ENV__
if (StaticMemRecorder::Instance().valid()) { if (StaticMemRecorder::Instance().valid()) {
StaticMemRecorder::Instance().clear_memory_chunk(); StaticMemRecorder::Instance().clear_memory_chunk();
for (auto&& i : m_interval) { for (auto&& i : m_interval) {
@@ -135,7 +135,7 @@ StaticMemAlloc& StaticMemAllocImplHelper::solve() {
is_overwrite, ""}); is_overwrite, ""});
} }
} }
#endif
return *this; return *this;
} }




+ 2
- 1
src/core/include/megbrain/graph/bases.h View File

@@ -194,11 +194,12 @@ class AsyncExecutable : public json::Serializable,
m_user_data.get_user_data<OutputVarsUserData>(); m_user_data.get_user_data<OutputVarsUserData>();
return (*(output_vars_pair.first))->get_output_vars(); return (*(output_vars_pair.first))->get_output_vars();
} }
#ifndef __IN_TEE_ENV__
virtual void get_static_memory_alloc_info(const std::string& svg_name) { virtual void get_static_memory_alloc_info(const std::string& svg_name) {
mgb_assert(svg_name.length() < 0, mgb_assert(svg_name.length() < 0,
"can't call this function directly\n"); "can't call this function directly\n");
} }
#endif
}; };






+ 1
- 7
src/plugin/impl/static_mem_record.cpp View File

@@ -14,13 +14,11 @@
#ifndef __IN_TEE_ENV__ #ifndef __IN_TEE_ENV__
#include <fstream> #include <fstream>
#include <iostream> #include <iostream>
#endif


using namespace mgb; using namespace mgb;
using namespace cg; using namespace cg;


namespace { namespace {
#ifndef __IN_TEE_ENV__
#define SVG_WIDTH 20000.0 #define SVG_WIDTH 20000.0
#define SVG_HEIGHT 15000.0 #define SVG_HEIGHT 15000.0
#define OPR_RECT_WIDTH 40.0 #define OPR_RECT_WIDTH 40.0
@@ -86,13 +84,9 @@ std::string draw_polyline(std::string point_seq, std::string color,
std::string width, std::string p = polyline) { std::string width, std::string p = polyline) {
return replace_by_parameter(p, 0, point_seq, color, width); return replace_by_parameter(p, 0, point_seq, color, width);
} }
#endif
} // namespace } // namespace


void StaticMemRecorder::dump_svg(std::string svg_name) { void StaticMemRecorder::dump_svg(std::string svg_name) {
#ifdef __IN_TEE_ENV__
MGB_MARK_USED_VAR(svg_name);
#else
float svg_width = SVG_WIDTH, svg_height = SVG_HEIGHT, float svg_width = SVG_WIDTH, svg_height = SVG_HEIGHT,
opr_rect_width = OPR_RECT_WIDTH, opr_rect_height = OPR_RECT_HEIGHT; opr_rect_width = OPR_RECT_WIDTH, opr_rect_height = OPR_RECT_HEIGHT;
float address_scale = 1; float address_scale = 1;
@@ -247,7 +241,6 @@ void StaticMemRecorder::dump_svg(std::string svg_name) {
<< std::endl; << std::endl;
outfile << "</svg>" << std::endl; outfile << "</svg>" << std::endl;
outfile.close(); outfile.close();
#endif
} }


void StaticMemRecorder::show(std::string svg_name) { void StaticMemRecorder::show(std::string svg_name) {
@@ -326,3 +319,4 @@ std::vector<std::vector<size_t>> StaticMemRecorder::get_chunk_construct(
} }
return chunk_ids; return chunk_ids;
} }
#endif

+ 2
- 1
src/plugin/include/megbrain/plugin/static_mem_record.h View File

@@ -12,7 +12,7 @@


#pragma once #pragma once
#include "megbrain/utils/metahelper.h" #include "megbrain/utils/metahelper.h"
#ifndef __IN_TEE_ENV__
namespace mgb { namespace mgb {
namespace cg { namespace cg {


@@ -83,3 +83,4 @@ private:
}; };
} // namespace cg } // namespace cg
} // namespace mgb } // namespace mgb
#endif

Loading…
Cancel
Save