GitOrigin-RevId: dfc69c3b3f
tags/v1.6.0-rc1
@@ -12,6 +12,7 @@ | |||
#include "./cg_impl_seq.h" | |||
#include "megbrain/graph/exc_extra_info.h" | |||
#include "megbrain/opr/tensor_manip.h" | |||
#include "megbrain/utils/arith_helper.h" | |||
using namespace mgb; | |||
using namespace cg; | |||
@@ -298,6 +299,9 @@ void ComputingGraphImpl::ComputingSequence::do_execute( | |||
} | |||
exec_ctx.perform(&m_exec_env); | |||
#ifndef __IN_TEE_ENV__ | |||
do_regist(); | |||
#endif | |||
} | |||
void ComputingGraphImpl::ComputingSequence::preprocess(ExecContext* ctx) { | |||
@@ -511,35 +515,42 @@ AsyncExecutable& ComputingGraphImpl::ComputingSequence::execute() { | |||
} | |||
#ifndef __IN_TEE_ENV__ | |||
void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info( | |||
const std::string& svg_name) { | |||
check_not_finalized(); | |||
const std::string& svg_name) const { | |||
auto& recorder = StaticMemRecorder::Instance(); | |||
recorder.active(); | |||
ExecContext exec_ctx{this}; | |||
recorder.set_svg_name(svg_name); | |||
} | |||
void ComputingGraphImpl::ComputingSequence::do_regist() const { | |||
// regist weights | |||
size_t addr_base = recorder.peak_mem_size(); | |||
size_t chunk_id = recorder.set_weight_chunk_id(); | |||
for (auto&& i : *(this->m_opr_seq)) { | |||
auto op = i->output(); | |||
for (auto&& j : op) { | |||
auto& mp = j->mem_plan(); | |||
if (mp.valid()) { | |||
auto& mc = mp.chunk(); | |||
if (mp.valid() && mc.mem_alloc_status.is_from_owner_var()) { | |||
recorder.regist_memory_chunk( | |||
{chunk_id++, mc.size(), 0, this->m_opr_seq->size(), | |||
addr_base, addr_base + mc.size(), 0, false, | |||
mc.owner_var->name()}); | |||
addr_base += mc.size(); | |||
auto& recorder = StaticMemRecorder::Instance(); | |||
if (recorder.valid()) { | |||
size_t addr_base = recorder.peak_mem_size(); | |||
size_t chunk_id = recorder.set_weight_chunk_id(); | |||
for (auto&& i : *(this->m_opr_seq)) { | |||
auto op = i->output(); | |||
for (auto&& j : op) { | |||
auto& mp = j->mem_plan(); | |||
if (mp.valid()) { | |||
auto& mc = mp.chunk(); | |||
if (mp.valid() && mc.mem_alloc_status.is_from_owner_var()) { | |||
auto size = mgb::get_aligned_power2( | |||
mc.size(), | |||
j->comp_node().get_mem_addr_alignment()); | |||
recorder.regist_memory_chunk( | |||
{chunk_id++, size, 0, this->m_opr_seq->size(), | |||
addr_base, addr_base + size, 0, false, | |||
mc.owner_var->name()}); | |||
addr_base += size; | |||
} | |||
} | |||
} | |||
} | |||
recorder.set_sum_mem_size(addr_base); | |||
recorder.show(); | |||
} | |||
recorder.set_sum_mem_size(addr_base); | |||
mgb_assert(svg_name.length() > 4, "svg_name must be end with \".svg\"\n"); | |||
mgb_assert(svg_name.compare(svg_name.length() - 4, 4, ".svg") == 0, | |||
"svg_name must be end with \".svg\"\n"); | |||
recorder.show(svg_name); | |||
} | |||
#endif | |||
AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() { | |||
@@ -174,7 +174,10 @@ public: | |||
std::unique_ptr<RecordedComputingSequence> as_recorded_seq(); | |||
#ifndef __IN_TEE_ENV__ | |||
void get_static_memory_alloc_info( | |||
const std::string& svg_name = "static_mem_record.svg") override; | |||
const std::string& svg_name = | |||
"static_mem_record.svg") const override; | |||
void do_regist() const; | |||
#endif | |||
}; | |||
@@ -195,7 +195,8 @@ class AsyncExecutable : public json::Serializable, | |||
return (*(output_vars_pair.first))->get_output_vars(); | |||
} | |||
#ifndef __IN_TEE_ENV__ | |||
virtual void get_static_memory_alloc_info(const std::string& svg_name) { | |||
virtual void get_static_memory_alloc_info( | |||
const std::string& svg_name) const { | |||
mgb_assert(svg_name.length() < 0, | |||
"can't call this function directly\n"); | |||
} | |||
@@ -86,7 +86,7 @@ std::string draw_polyline(std::string point_seq, std::string color, | |||
} | |||
} // namespace | |||
void StaticMemRecorder::dump_svg(std::string svg_name) { | |||
void StaticMemRecorder::dump_svg() { | |||
float svg_width = SVG_WIDTH, svg_height = SVG_HEIGHT, | |||
opr_rect_width = OPR_RECT_WIDTH, opr_rect_height = OPR_RECT_HEIGHT; | |||
float address_scale = 1; | |||
@@ -120,7 +120,7 @@ void StaticMemRecorder::dump_svg(std::string svg_name) { | |||
svg_height = svg_height + opr_rect_height * 2; | |||
std::ofstream outfile; | |||
outfile.open(svg_name); | |||
outfile.open(m_svg_name); | |||
outfile << "<?xml version=\"1.0\" standalone=\"no\"?>" << std::endl; | |||
outfile << "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN/\" " | |||
"\"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">" | |||
@@ -243,7 +243,7 @@ void StaticMemRecorder::dump_svg(std::string svg_name) { | |||
outfile.close(); | |||
} | |||
void StaticMemRecorder::show(std::string svg_name) { | |||
void StaticMemRecorder::show() { | |||
for (auto&& i : m_memory_chunk_recorder) { | |||
if (i.id >= m_weight_chunk_id) { | |||
break; | |||
@@ -291,7 +291,7 @@ void StaticMemRecorder::show(std::string svg_name) { | |||
m_opr_seq_recorder.at(chunk.time_begin).name.c_str()); | |||
} | |||
} | |||
dump_svg(svg_name); | |||
dump_svg(); | |||
} | |||
std::vector<std::vector<size_t>> StaticMemRecorder::get_chunk_construct( | |||
@@ -54,25 +54,38 @@ public: | |||
void regist_peak_mem_size(size_t size) { m_peak_mem_size = size; } | |||
const size_t& peak_mem_size() { return m_peak_mem_size; } | |||
const size_t& peak_mem_size() const { return m_peak_mem_size; } | |||
void set_sum_mem_size(size_t size) { m_sum_mem_size = size; } | |||
const size_t& sum_mem_size() { return m_sum_mem_size; } | |||
const size_t& sum_mem_size() const { return m_sum_mem_size; } | |||
const size_t& set_weight_chunk_id() { | |||
m_weight_chunk_id = m_memory_chunk_recorder.size(); | |||
return m_weight_chunk_id; | |||
} | |||
const size_t& weight_chunk_id() { return m_weight_chunk_id; } | |||
const size_t& weight_chunk_id() const { return m_weight_chunk_id; } | |||
void dump_svg(std::string svg_name); | |||
void dump_svg(); | |||
void show(std::string svg_name); | |||
void show(); | |||
void set_svg_name(const std::string& svg_name) { | |||
mgb_assert(svg_name.length() > 4, | |||
"svg_name must be end with \".svg\"\n"); | |||
mgb_assert(svg_name.compare(svg_name.length() - 4, 4, ".svg") == 0, | |||
"svg_name must be end with \".svg\"\n"); | |||
m_svg_name = svg_name; | |||
} | |||
const std::string& get_svg_name() const{ | |||
return m_svg_name; | |||
} | |||
private: | |||
bool m_is_record = false; | |||
std::string m_svg_name; | |||
// All chunks after m_memory_chunk_recorder.at(m_weight_chunk_id) are | |||
// weights memory chunks | |||
size_t m_peak_mem_size, m_sum_mem_size, m_weight_chunk_id; | |||