Browse Source

fix(mgb): fix getting static memory alloc info

GitOrigin-RevId: dfc69c3b3f
tags/v1.6.0-rc1
Megvii Engine Team 4 years ago
parent
commit
6070f1272d
5 changed files with 61 additions and 33 deletions
  1. +33
    -22
      src/core/impl/graph/cg_impl_seq.cpp
  2. +4
    -1
      src/core/impl/graph/cg_impl_seq.h
  3. +2
    -1
      src/core/include/megbrain/graph/bases.h
  4. +4
    -4
      src/plugin/impl/static_mem_record.cpp
  5. +18
    -5
      src/plugin/include/megbrain/plugin/static_mem_record.h

+ 33
- 22
src/core/impl/graph/cg_impl_seq.cpp View File

@@ -12,6 +12,7 @@
#include "./cg_impl_seq.h"
#include "megbrain/graph/exc_extra_info.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/utils/arith_helper.h"

using namespace mgb;
using namespace cg;
@@ -298,6 +299,9 @@ void ComputingGraphImpl::ComputingSequence::do_execute(
}

exec_ctx.perform(&m_exec_env);
#ifndef __IN_TEE_ENV__
do_regist();
#endif
}

void ComputingGraphImpl::ComputingSequence::preprocess(ExecContext* ctx) {
@@ -511,35 +515,42 @@ AsyncExecutable& ComputingGraphImpl::ComputingSequence::execute() {
}
#ifndef __IN_TEE_ENV__
void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info(
const std::string& svg_name) {
check_not_finalized();
const std::string& svg_name) const {
auto& recorder = StaticMemRecorder::Instance();
recorder.active();
ExecContext exec_ctx{this};
recorder.set_svg_name(svg_name);
}

void ComputingGraphImpl::ComputingSequence::do_regist() const {
// regist weights
size_t addr_base = recorder.peak_mem_size();
size_t chunk_id = recorder.set_weight_chunk_id();
for (auto&& i : *(this->m_opr_seq)) {
auto op = i->output();
for (auto&& j : op) {
auto& mp = j->mem_plan();
if (mp.valid()) {
auto& mc = mp.chunk();
if (mp.valid() && mc.mem_alloc_status.is_from_owner_var()) {
recorder.regist_memory_chunk(
{chunk_id++, mc.size(), 0, this->m_opr_seq->size(),
addr_base, addr_base + mc.size(), 0, false,
mc.owner_var->name()});
addr_base += mc.size();
auto& recorder = StaticMemRecorder::Instance();
if (recorder.valid()) {
size_t addr_base = recorder.peak_mem_size();
size_t chunk_id = recorder.set_weight_chunk_id();
for (auto&& i : *(this->m_opr_seq)) {
auto op = i->output();
for (auto&& j : op) {
auto& mp = j->mem_plan();
if (mp.valid()) {
auto& mc = mp.chunk();
if (mp.valid() && mc.mem_alloc_status.is_from_owner_var()) {
auto size = mgb::get_aligned_power2(
mc.size(),
j->comp_node().get_mem_addr_alignment());

recorder.regist_memory_chunk(
{chunk_id++, size, 0, this->m_opr_seq->size(),
addr_base, addr_base + size, 0, false,
mc.owner_var->name()});

addr_base += size;
}
}
}
}
recorder.set_sum_mem_size(addr_base);
recorder.show();
}
recorder.set_sum_mem_size(addr_base);
mgb_assert(svg_name.length() > 4, "svg_name must be end with \".svg\"\n");
mgb_assert(svg_name.compare(svg_name.length() - 4, 4, ".svg") == 0,
"svg_name must be end with \".svg\"\n");
recorder.show(svg_name);
}
#endif
AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() {


+ 4
- 1
src/core/impl/graph/cg_impl_seq.h View File

@@ -174,7 +174,10 @@ public:
std::unique_ptr<RecordedComputingSequence> as_recorded_seq();
#ifndef __IN_TEE_ENV__
void get_static_memory_alloc_info(
const std::string& svg_name = "static_mem_record.svg") override;
const std::string& svg_name =
"static_mem_record.svg") const override;

void do_regist() const;
#endif
};



+ 2
- 1
src/core/include/megbrain/graph/bases.h View File

@@ -195,7 +195,8 @@ class AsyncExecutable : public json::Serializable,
return (*(output_vars_pair.first))->get_output_vars();
}
#ifndef __IN_TEE_ENV__
virtual void get_static_memory_alloc_info(const std::string& svg_name) {
virtual void get_static_memory_alloc_info(
const std::string& svg_name) const {
mgb_assert(svg_name.length() < 0,
"can't call this function directly\n");
}


+ 4
- 4
src/plugin/impl/static_mem_record.cpp View File

@@ -86,7 +86,7 @@ std::string draw_polyline(std::string point_seq, std::string color,
}
} // namespace

void StaticMemRecorder::dump_svg(std::string svg_name) {
void StaticMemRecorder::dump_svg() {
float svg_width = SVG_WIDTH, svg_height = SVG_HEIGHT,
opr_rect_width = OPR_RECT_WIDTH, opr_rect_height = OPR_RECT_HEIGHT;
float address_scale = 1;
@@ -120,7 +120,7 @@ void StaticMemRecorder::dump_svg(std::string svg_name) {
svg_height = svg_height + opr_rect_height * 2;

std::ofstream outfile;
outfile.open(svg_name);
outfile.open(m_svg_name);
outfile << "<?xml version=\"1.0\" standalone=\"no\"?>" << std::endl;
outfile << "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN/\" "
"\"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">"
@@ -243,7 +243,7 @@ void StaticMemRecorder::dump_svg(std::string svg_name) {
outfile.close();
}

void StaticMemRecorder::show(std::string svg_name) {
void StaticMemRecorder::show() {
for (auto&& i : m_memory_chunk_recorder) {
if (i.id >= m_weight_chunk_id) {
break;
@@ -291,7 +291,7 @@ void StaticMemRecorder::show(std::string svg_name) {
m_opr_seq_recorder.at(chunk.time_begin).name.c_str());
}
}
dump_svg(svg_name);
dump_svg();
}

std::vector<std::vector<size_t>> StaticMemRecorder::get_chunk_construct(


+ 18
- 5
src/plugin/include/megbrain/plugin/static_mem_record.h View File

@@ -54,25 +54,38 @@ public:

void regist_peak_mem_size(size_t size) { m_peak_mem_size = size; }

const size_t& peak_mem_size() { return m_peak_mem_size; }
const size_t& peak_mem_size() const { return m_peak_mem_size; }

void set_sum_mem_size(size_t size) { m_sum_mem_size = size; }

const size_t& sum_mem_size() { return m_sum_mem_size; }
const size_t& sum_mem_size() const { return m_sum_mem_size; }

const size_t& set_weight_chunk_id() {
m_weight_chunk_id = m_memory_chunk_recorder.size();
return m_weight_chunk_id;
}

const size_t& weight_chunk_id() { return m_weight_chunk_id; }
const size_t& weight_chunk_id() const { return m_weight_chunk_id; }

void dump_svg(std::string svg_name);
void dump_svg();

void show(std::string svg_name);
void show();

void set_svg_name(const std::string& svg_name) {
mgb_assert(svg_name.length() > 4,
"svg_name must be end with \".svg\"\n");
mgb_assert(svg_name.compare(svg_name.length() - 4, 4, ".svg") == 0,
"svg_name must be end with \".svg\"\n");
m_svg_name = svg_name;
}

const std::string& get_svg_name() const{
return m_svg_name;
}

private:
bool m_is_record = false;
std::string m_svg_name;
// All chunks after m_memory_chunk_recorder.at(m_weight_chunk_id) are
// weights memory chunks
size_t m_peak_mem_size, m_sum_mem_size, m_weight_chunk_id;


Loading…
Cancel
Save