GitOrigin-RevId: dfc69c3b3f
tags/v1.6.0-rc1
@@ -12,6 +12,7 @@ | |||||
#include "./cg_impl_seq.h" | #include "./cg_impl_seq.h" | ||||
#include "megbrain/graph/exc_extra_info.h" | #include "megbrain/graph/exc_extra_info.h" | ||||
#include "megbrain/opr/tensor_manip.h" | #include "megbrain/opr/tensor_manip.h" | ||||
#include "megbrain/utils/arith_helper.h" | |||||
using namespace mgb; | using namespace mgb; | ||||
using namespace cg; | using namespace cg; | ||||
@@ -298,6 +299,9 @@ void ComputingGraphImpl::ComputingSequence::do_execute( | |||||
} | } | ||||
exec_ctx.perform(&m_exec_env); | exec_ctx.perform(&m_exec_env); | ||||
#ifndef __IN_TEE_ENV__ | |||||
do_regist(); | |||||
#endif | |||||
} | } | ||||
void ComputingGraphImpl::ComputingSequence::preprocess(ExecContext* ctx) { | void ComputingGraphImpl::ComputingSequence::preprocess(ExecContext* ctx) { | ||||
@@ -511,35 +515,42 @@ AsyncExecutable& ComputingGraphImpl::ComputingSequence::execute() { | |||||
} | } | ||||
#ifndef __IN_TEE_ENV__ | #ifndef __IN_TEE_ENV__ | ||||
void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info( | void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info( | ||||
const std::string& svg_name) { | |||||
check_not_finalized(); | |||||
const std::string& svg_name) const { | |||||
auto& recorder = StaticMemRecorder::Instance(); | auto& recorder = StaticMemRecorder::Instance(); | ||||
recorder.active(); | recorder.active(); | ||||
ExecContext exec_ctx{this}; | |||||
recorder.set_svg_name(svg_name); | |||||
} | |||||
void ComputingGraphImpl::ComputingSequence::do_regist() const { | |||||
// regist weights | // regist weights | ||||
size_t addr_base = recorder.peak_mem_size(); | |||||
size_t chunk_id = recorder.set_weight_chunk_id(); | |||||
for (auto&& i : *(this->m_opr_seq)) { | |||||
auto op = i->output(); | |||||
for (auto&& j : op) { | |||||
auto& mp = j->mem_plan(); | |||||
if (mp.valid()) { | |||||
auto& mc = mp.chunk(); | |||||
if (mp.valid() && mc.mem_alloc_status.is_from_owner_var()) { | |||||
recorder.regist_memory_chunk( | |||||
{chunk_id++, mc.size(), 0, this->m_opr_seq->size(), | |||||
addr_base, addr_base + mc.size(), 0, false, | |||||
mc.owner_var->name()}); | |||||
addr_base += mc.size(); | |||||
auto& recorder = StaticMemRecorder::Instance(); | |||||
if (recorder.valid()) { | |||||
size_t addr_base = recorder.peak_mem_size(); | |||||
size_t chunk_id = recorder.set_weight_chunk_id(); | |||||
for (auto&& i : *(this->m_opr_seq)) { | |||||
auto op = i->output(); | |||||
for (auto&& j : op) { | |||||
auto& mp = j->mem_plan(); | |||||
if (mp.valid()) { | |||||
auto& mc = mp.chunk(); | |||||
if (mp.valid() && mc.mem_alloc_status.is_from_owner_var()) { | |||||
auto size = mgb::get_aligned_power2( | |||||
mc.size(), | |||||
j->comp_node().get_mem_addr_alignment()); | |||||
recorder.regist_memory_chunk( | |||||
{chunk_id++, size, 0, this->m_opr_seq->size(), | |||||
addr_base, addr_base + size, 0, false, | |||||
mc.owner_var->name()}); | |||||
addr_base += size; | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } | ||||
recorder.set_sum_mem_size(addr_base); | |||||
recorder.show(); | |||||
} | } | ||||
recorder.set_sum_mem_size(addr_base); | |||||
mgb_assert(svg_name.length() > 4, "svg_name must be end with \".svg\"\n"); | |||||
mgb_assert(svg_name.compare(svg_name.length() - 4, 4, ".svg") == 0, | |||||
"svg_name must be end with \".svg\"\n"); | |||||
recorder.show(svg_name); | |||||
} | } | ||||
#endif | #endif | ||||
AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() { | AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() { | ||||
@@ -174,7 +174,10 @@ public: | |||||
std::unique_ptr<RecordedComputingSequence> as_recorded_seq(); | std::unique_ptr<RecordedComputingSequence> as_recorded_seq(); | ||||
#ifndef __IN_TEE_ENV__ | #ifndef __IN_TEE_ENV__ | ||||
void get_static_memory_alloc_info( | void get_static_memory_alloc_info( | ||||
const std::string& svg_name = "static_mem_record.svg") override; | |||||
const std::string& svg_name = | |||||
"static_mem_record.svg") const override; | |||||
void do_regist() const; | |||||
#endif | #endif | ||||
}; | }; | ||||
@@ -195,7 +195,8 @@ class AsyncExecutable : public json::Serializable, | |||||
return (*(output_vars_pair.first))->get_output_vars(); | return (*(output_vars_pair.first))->get_output_vars(); | ||||
} | } | ||||
#ifndef __IN_TEE_ENV__ | #ifndef __IN_TEE_ENV__ | ||||
virtual void get_static_memory_alloc_info(const std::string& svg_name) { | |||||
virtual void get_static_memory_alloc_info( | |||||
const std::string& svg_name) const { | |||||
mgb_assert(svg_name.length() < 0, | mgb_assert(svg_name.length() < 0, | ||||
"can't call this function directly\n"); | "can't call this function directly\n"); | ||||
} | } | ||||
@@ -86,7 +86,7 @@ std::string draw_polyline(std::string point_seq, std::string color, | |||||
} | } | ||||
} // namespace | } // namespace | ||||
void StaticMemRecorder::dump_svg(std::string svg_name) { | |||||
void StaticMemRecorder::dump_svg() { | |||||
float svg_width = SVG_WIDTH, svg_height = SVG_HEIGHT, | float svg_width = SVG_WIDTH, svg_height = SVG_HEIGHT, | ||||
opr_rect_width = OPR_RECT_WIDTH, opr_rect_height = OPR_RECT_HEIGHT; | opr_rect_width = OPR_RECT_WIDTH, opr_rect_height = OPR_RECT_HEIGHT; | ||||
float address_scale = 1; | float address_scale = 1; | ||||
@@ -120,7 +120,7 @@ void StaticMemRecorder::dump_svg(std::string svg_name) { | |||||
svg_height = svg_height + opr_rect_height * 2; | svg_height = svg_height + opr_rect_height * 2; | ||||
std::ofstream outfile; | std::ofstream outfile; | ||||
outfile.open(svg_name); | |||||
outfile.open(m_svg_name); | |||||
outfile << "<?xml version=\"1.0\" standalone=\"no\"?>" << std::endl; | outfile << "<?xml version=\"1.0\" standalone=\"no\"?>" << std::endl; | ||||
outfile << "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN/\" " | outfile << "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN/\" " | ||||
"\"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">" | "\"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">" | ||||
@@ -243,7 +243,7 @@ void StaticMemRecorder::dump_svg(std::string svg_name) { | |||||
outfile.close(); | outfile.close(); | ||||
} | } | ||||
void StaticMemRecorder::show(std::string svg_name) { | |||||
void StaticMemRecorder::show() { | |||||
for (auto&& i : m_memory_chunk_recorder) { | for (auto&& i : m_memory_chunk_recorder) { | ||||
if (i.id >= m_weight_chunk_id) { | if (i.id >= m_weight_chunk_id) { | ||||
break; | break; | ||||
@@ -291,7 +291,7 @@ void StaticMemRecorder::show(std::string svg_name) { | |||||
m_opr_seq_recorder.at(chunk.time_begin).name.c_str()); | m_opr_seq_recorder.at(chunk.time_begin).name.c_str()); | ||||
} | } | ||||
} | } | ||||
dump_svg(svg_name); | |||||
dump_svg(); | |||||
} | } | ||||
std::vector<std::vector<size_t>> StaticMemRecorder::get_chunk_construct( | std::vector<std::vector<size_t>> StaticMemRecorder::get_chunk_construct( | ||||
@@ -54,25 +54,38 @@ public: | |||||
void regist_peak_mem_size(size_t size) { m_peak_mem_size = size; } | void regist_peak_mem_size(size_t size) { m_peak_mem_size = size; } | ||||
const size_t& peak_mem_size() { return m_peak_mem_size; } | |||||
const size_t& peak_mem_size() const { return m_peak_mem_size; } | |||||
void set_sum_mem_size(size_t size) { m_sum_mem_size = size; } | void set_sum_mem_size(size_t size) { m_sum_mem_size = size; } | ||||
const size_t& sum_mem_size() { return m_sum_mem_size; } | |||||
const size_t& sum_mem_size() const { return m_sum_mem_size; } | |||||
const size_t& set_weight_chunk_id() { | const size_t& set_weight_chunk_id() { | ||||
m_weight_chunk_id = m_memory_chunk_recorder.size(); | m_weight_chunk_id = m_memory_chunk_recorder.size(); | ||||
return m_weight_chunk_id; | return m_weight_chunk_id; | ||||
} | } | ||||
const size_t& weight_chunk_id() { return m_weight_chunk_id; } | |||||
const size_t& weight_chunk_id() const { return m_weight_chunk_id; } | |||||
void dump_svg(std::string svg_name); | |||||
void dump_svg(); | |||||
void show(std::string svg_name); | |||||
void show(); | |||||
void set_svg_name(const std::string& svg_name) { | |||||
mgb_assert(svg_name.length() > 4, | |||||
"svg_name must be end with \".svg\"\n"); | |||||
mgb_assert(svg_name.compare(svg_name.length() - 4, 4, ".svg") == 0, | |||||
"svg_name must be end with \".svg\"\n"); | |||||
m_svg_name = svg_name; | |||||
} | |||||
const std::string& get_svg_name() const{ | |||||
return m_svg_name; | |||||
} | |||||
private: | private: | ||||
bool m_is_record = false; | bool m_is_record = false; | ||||
std::string m_svg_name; | |||||
// All chunks after m_memory_chunk_recorder.at(m_weight_chunk_id) are | // All chunks after m_memory_chunk_recorder.at(m_weight_chunk_id) are | ||||
// weights memory chunks | // weights memory chunks | ||||
size_t m_peak_mem_size, m_sum_mem_size, m_weight_chunk_id; | size_t m_peak_mem_size, m_sum_mem_size, m_weight_chunk_id; | ||||