@@ -141,9 +141,13 @@ R"__usage__( | |||||
level 2 the computing graph can be destructed to reduce memory usage. Read | level 2 the computing graph can be destructed to reduce memory usage. Read | ||||
the doc of `ComputingGraph::Options::comp_node_seq_record_level` for more | the doc of `ComputingGraph::Options::comp_node_seq_record_level` for more | ||||
details. | details. | ||||
)__usage__" | |||||
#ifndef __IN_TEE_ENV__ | |||||
R"__usage__( | |||||
--get-static-mem-info <svgname> | --get-static-mem-info <svgname> | ||||
Record the static graph's static memory info. | Record the static graph's static memory info. | ||||
)__usage__" | )__usage__" | ||||
#endif | |||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
R"__usage__( | R"__usage__( | ||||
--full-run | --full-run | ||||
@@ -538,7 +542,9 @@ struct Args { | |||||
#endif | #endif | ||||
bool reproducible = false; | bool reproducible = false; | ||||
std::string fast_run_cache_path; | std::string fast_run_cache_path; | ||||
#ifndef __IN_TEE_ENV__ | |||||
std::string static_mem_svg_path; | std::string static_mem_svg_path; | ||||
#endif | |||||
bool copy_to_host = false; | bool copy_to_host = false; | ||||
int nr_run = 10; | int nr_run = 10; | ||||
int nr_warmup = 1; | int nr_warmup = 1; | ||||
@@ -797,9 +803,11 @@ void run_test_st(Args &env) { | |||||
} | } | ||||
auto func = env.load_ret.graph_compile(out_spec); | auto func = env.load_ret.graph_compile(out_spec); | ||||
#ifndef __IN_TEE_ENV__ | |||||
if (!env.static_mem_svg_path.empty()) { | if (!env.static_mem_svg_path.empty()) { | ||||
func->get_static_memory_alloc_info(env.static_mem_svg_path); | func->get_static_memory_alloc_info(env.static_mem_svg_path); | ||||
} | } | ||||
#endif | |||||
auto warmup = [&]() { | auto warmup = [&]() { | ||||
printf("=== prepare: %.3fms; going to warmup\n", | printf("=== prepare: %.3fms; going to warmup\n", | ||||
timer.get_msecs_reset()); | timer.get_msecs_reset()); | ||||
@@ -1383,6 +1391,7 @@ Args Args::from_argv(int argc, char **argv) { | |||||
graph_opt.comp_node_seq_record_level = 2; | graph_opt.comp_node_seq_record_level = 2; | ||||
continue; | continue; | ||||
} | } | ||||
#ifndef __IN_TEE_ENV__ | |||||
if (!strcmp(argv[i], "--get-static-mem-info")) { | if (!strcmp(argv[i], "--get-static-mem-info")) { | ||||
++i; | ++i; | ||||
mgb_assert(i < argc, "value not given for --get-static-mem-info"); | mgb_assert(i < argc, "value not given for --get-static-mem-info"); | ||||
@@ -1393,6 +1402,7 @@ Args Args::from_argv(int argc, char **argv) { | |||||
ret.static_mem_svg_path.c_str()); | ret.static_mem_svg_path.c_str()); | ||||
continue; | continue; | ||||
} | } | ||||
#endif | |||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
if (!strcmp(argv[i], "--fast-run")) { | if (!strcmp(argv[i], "--fast-run")) { | ||||
ret.use_fast_run = true; | ret.use_fast_run = true; | ||||
@@ -491,7 +491,7 @@ AsyncExecutable& ComputingGraphImpl::ComputingSequence::execute() { | |||||
do_execute(nullptr); | do_execute(nullptr); | ||||
return *this; | return *this; | ||||
} | } | ||||
#ifndef __IN_TEE_ENV__ | |||||
void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info( | void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info( | ||||
const std::string& svg_name) { | const std::string& svg_name) { | ||||
check_not_finalized(); | check_not_finalized(); | ||||
@@ -523,7 +523,7 @@ void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info( | |||||
"svg_name must be end with \".svg\"\n"); | "svg_name must be end with \".svg\"\n"); | ||||
recorder.show(svg_name); | recorder.show(svg_name); | ||||
} | } | ||||
#endif | |||||
AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() { | AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() { | ||||
do_wait(true); | do_wait(true); | ||||
return *this; | return *this; | ||||
@@ -170,9 +170,10 @@ public: | |||||
} | } | ||||
std::unique_ptr<RecordedComputingSequence> as_recorded_seq(); | std::unique_ptr<RecordedComputingSequence> as_recorded_seq(); | ||||
#ifndef __IN_TEE_ENV__ | |||||
/// Override of the inherited virtual `get_static_memory_alloc_info`
/// (see the `override` specifier).
///
/// @param svg_name name of the SVG output; defaults to
///        "static_mem_record.svg" when the caller passes nothing.
///
/// NOTE(review): default arguments of virtual functions are selected from
/// the static type of the call expression, so this default only applies
/// when the call is made through this class, not through a base
/// reference or pointer.
void get_static_memory_alloc_info( | void get_static_memory_alloc_info( | ||||
const std::string& svg_name = "static_mem_record.svg") override; | const std::string& svg_name = "static_mem_record.svg") override; | ||||
#endif | |||||
}; | }; | ||||
class ComputingGraphImpl::MegDNNDtorCheck : public NonCopyableObj { | class ComputingGraphImpl::MegDNNDtorCheck : public NonCopyableObj { | ||||
@@ -178,18 +178,19 @@ bool SeqMemOptimizer::run_static_mem_alloc() { | |||||
ThinHashMap<MemAllocPlan::Chunk*, MemChunkLifeInterval> chk2interval; | ThinHashMap<MemAllocPlan::Chunk*, MemChunkLifeInterval> chk2interval; | ||||
// get all memory chunks | // get all memory chunks | ||||
#ifndef __IN_TEE_ENV__ | |||||
if (StaticMemRecorder::Instance().valid()) { | if (StaticMemRecorder::Instance().valid()) { | ||||
StaticMemRecorder::Instance().clear_opr_seq(); | StaticMemRecorder::Instance().clear_opr_seq(); | ||||
} | } | ||||
#endif | |||||
for (size_t idx = 0; idx < m_cur_seq_full->size(); ++ idx) { | for (size_t idx = 0; idx < m_cur_seq_full->size(); ++ idx) { | ||||
OperatorNodeBase *opr = m_cur_seq_full->at(idx); | OperatorNodeBase *opr = m_cur_seq_full->at(idx); | ||||
#ifndef __IN_TEE_ENV__ | |||||
if (StaticMemRecorder::Instance().valid()) { | if (StaticMemRecorder::Instance().valid()) { | ||||
StaticMemRecorder::Instance().regist_opr_seq( | StaticMemRecorder::Instance().regist_opr_seq( | ||||
{idx, 0, opr->name()}); | {idx, 0, opr->name()}); | ||||
} | } | ||||
#endif | |||||
auto &&dep_map = opr->node_prop().dep_map(); | auto &&dep_map = opr->node_prop().dep_map(); | ||||
if (in_sys_alloc(opr)) { | if (in_sys_alloc(opr)) { | ||||
@@ -358,6 +359,7 @@ bool SeqMemOptimizer::run_static_mem_alloc_on_comp_node( | |||||
chk.chunk->mem_alloc_status.set_static_offset( | chk.chunk->mem_alloc_status.set_static_offset( | ||||
allocator->get_start_addr(&chk)); | allocator->get_start_addr(&chk)); | ||||
} | } | ||||
#ifndef __IN_TEE_ENV__ | |||||
auto& recorder = StaticMemRecorder::Instance(); | auto& recorder = StaticMemRecorder::Instance(); | ||||
if (recorder.valid()) { | if (recorder.valid()) { | ||||
for (size_t i = 0; i < chunks.size(); i++) { | for (size_t i = 0; i < chunks.size(); i++) { | ||||
@@ -366,6 +368,7 @@ bool SeqMemOptimizer::run_static_mem_alloc_on_comp_node( | |||||
} | } | ||||
recorder.regist_peak_mem_size(size); | recorder.regist_peak_mem_size(size); | ||||
} | } | ||||
#endif | |||||
} | } | ||||
return should_realloc; | return should_realloc; | ||||
@@ -119,7 +119,7 @@ StaticMemAlloc& StaticMemAllocImplHelper::solve() { | |||||
do_solve(); | do_solve(); | ||||
check_result_and_calc_lower_bound(); | check_result_and_calc_lower_bound(); | ||||
#ifndef __IN_TEE_ENV__ | |||||
if (StaticMemRecorder::Instance().valid()) { | if (StaticMemRecorder::Instance().valid()) { | ||||
StaticMemRecorder::Instance().clear_memory_chunk(); | StaticMemRecorder::Instance().clear_memory_chunk(); | ||||
for (auto&& i : m_interval) { | for (auto&& i : m_interval) { | ||||
@@ -135,7 +135,7 @@ StaticMemAlloc& StaticMemAllocImplHelper::solve() { | |||||
is_overwrite, ""}); | is_overwrite, ""}); | ||||
} | } | ||||
} | } | ||||
#endif | |||||
return *this; | return *this; | ||||
} | } | ||||
@@ -194,11 +194,12 @@ class AsyncExecutable : public json::Serializable, | |||||
m_user_data.get_user_data<OutputVarsUserData>(); | m_user_data.get_user_data<OutputVarsUserData>(); | ||||
return (*(output_vars_pair.first))->get_output_vars(); | return (*(output_vars_pair.first))->get_output_vars(); | ||||
} | } | ||||
#ifndef __IN_TEE_ENV__ | |||||
/// Placeholder implementation of the static-memory-info hook.
///
/// The asserted condition `svg_name.length() < 0` compares an unsigned
/// length against zero and therefore can never be true, so `mgb_assert`
/// is always handed a false condition together with the message
/// "can't call this function directly". Using `svg_name` inside the
/// condition also keeps the parameter referenced.
/// NOTE(review): presumably derived classes are expected to override
/// this and callers should never reach this body -- confirm.
///
/// @param svg_name only read to form the (always false) assert condition.
virtual void get_static_memory_alloc_info(const std::string& svg_name) { | virtual void get_static_memory_alloc_info(const std::string& svg_name) { | ||||
mgb_assert(svg_name.length() < 0, | mgb_assert(svg_name.length() < 0, | ||||
"can't call this function directly\n"); | "can't call this function directly\n"); | ||||
} | } | ||||
#endif | |||||
}; | }; | ||||
@@ -14,13 +14,11 @@ | |||||
#ifndef __IN_TEE_ENV__ | #ifndef __IN_TEE_ENV__ | ||||
#include <fstream> | #include <fstream> | ||||
#include <iostream> | #include <iostream> | ||||
#endif | |||||
using namespace mgb; | using namespace mgb; | ||||
using namespace cg; | using namespace cg; | ||||
namespace { | namespace { | ||||
#ifndef __IN_TEE_ENV__ | |||||
#define SVG_WIDTH 20000.0 | #define SVG_WIDTH 20000.0 | ||||
#define SVG_HEIGHT 15000.0 | #define SVG_HEIGHT 15000.0 | ||||
#define OPR_RECT_WIDTH 40.0 | #define OPR_RECT_WIDTH 40.0 | ||||
@@ -86,13 +84,9 @@ std::string draw_polyline(std::string point_seq, std::string color, | |||||
std::string width, std::string p = polyline) { | std::string width, std::string p = polyline) { | ||||
return replace_by_parameter(p, 0, point_seq, color, width); | return replace_by_parameter(p, 0, point_seq, color, width); | ||||
} | } | ||||
#endif | |||||
} // namespace | } // namespace | ||||
void StaticMemRecorder::dump_svg(std::string svg_name) { | void StaticMemRecorder::dump_svg(std::string svg_name) { | ||||
#ifdef __IN_TEE_ENV__ | |||||
MGB_MARK_USED_VAR(svg_name); | |||||
#else | |||||
float svg_width = SVG_WIDTH, svg_height = SVG_HEIGHT, | float svg_width = SVG_WIDTH, svg_height = SVG_HEIGHT, | ||||
opr_rect_width = OPR_RECT_WIDTH, opr_rect_height = OPR_RECT_HEIGHT; | opr_rect_width = OPR_RECT_WIDTH, opr_rect_height = OPR_RECT_HEIGHT; | ||||
float address_scale = 1; | float address_scale = 1; | ||||
@@ -247,7 +241,6 @@ void StaticMemRecorder::dump_svg(std::string svg_name) { | |||||
<< std::endl; | << std::endl; | ||||
outfile << "</svg>" << std::endl; | outfile << "</svg>" << std::endl; | ||||
outfile.close(); | outfile.close(); | ||||
#endif | |||||
} | } | ||||
void StaticMemRecorder::show(std::string svg_name) { | void StaticMemRecorder::show(std::string svg_name) { | ||||
@@ -326,3 +319,4 @@ std::vector<std::vector<size_t>> StaticMemRecorder::get_chunk_construct( | |||||
} | } | ||||
return chunk_ids; | return chunk_ids; | ||||
} | } | ||||
#endif |
@@ -12,7 +12,7 @@ | |||||
#pragma once | #pragma once | ||||
#include "megbrain/utils/metahelper.h" | #include "megbrain/utils/metahelper.h" | ||||
#ifndef __IN_TEE_ENV__ | |||||
namespace mgb { | namespace mgb { | ||||
namespace cg { | namespace cg { | ||||
@@ -83,3 +83,4 @@ private: | |||||
}; | }; | ||||
} // namespace cg | } // namespace cg | ||||
} // namespace mgb | } // namespace mgb | ||||
#endif |