diff --git a/imperative/python/megengine/tools/svg_viewer.html b/imperative/python/megengine/tools/svg_viewer.html
new file mode 100644
index 00000000..885091c1
--- /dev/null
+++ b/imperative/python/megengine/tools/svg_viewer.html
@@ -0,0 +1,154 @@
+
+
+
Visualizer
+
+
+
+
+
+
+
+ desc
+ info
+
+
+
+
+
+
\ No newline at end of file
diff --git a/src/core/impl/graph/cg_impl_seq.cpp b/src/core/impl/graph/cg_impl_seq.cpp
index 2a2587e7..926bf821 100644
--- a/src/core/impl/graph/cg_impl_seq.cpp
+++ b/src/core/impl/graph/cg_impl_seq.cpp
@@ -492,6 +492,38 @@ AsyncExecutable& ComputingGraphImpl::ComputingSequence::execute() {
return *this;
}
+void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info(
+ const std::string& svg_name) {
+ check_not_finalized();
+ auto& recorder = StaticMemRecorder::Instance();
+ recorder.active();
+ ExecContext exec_ctx{this};
+ // regist weights
+ size_t addr_base = recorder.peak_mem_size();
+ size_t chunk_id = recorder.set_weight_chunk_id();
+ for (auto&& i : *(this->m_opr_seq)) {
+ auto op = i->output();
+ for (auto&& j : op) {
+ auto& mp = j->mem_plan();
+ if (mp.valid()) {
+ auto& mc = mp.chunk();
+ if (mp.valid() && mc.mem_alloc_status.is_from_owner_var()) {
+ recorder.regist_memory_chunk(
+ {chunk_id++, mc.size(), 0, this->m_opr_seq->size(),
+ addr_base, addr_base + mc.size(), 0, false,
+ mc.owner_var->name()});
+ addr_base += mc.size();
+ }
+ }
+ }
+ }
+ recorder.set_sum_mem_size(addr_base);
+ mgb_assert(svg_name.length() > 4, "svg_name must be end with \".svg\"\n");
+ mgb_assert(svg_name.compare(svg_name.length() - 4, 4, ".svg") == 0,
+ "svg_name must be end with \".svg\"\n");
+ recorder.show(svg_name);
+}
+
AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() {
do_wait(true);
return *this;
diff --git a/src/core/impl/graph/cg_impl_seq.h b/src/core/impl/graph/cg_impl_seq.h
index a999542b..8cdc3bee 100644
--- a/src/core/impl/graph/cg_impl_seq.h
+++ b/src/core/impl/graph/cg_impl_seq.h
@@ -16,6 +16,7 @@
#include "megbrain/comp_node_env.h"
#include "megbrain/plugin/var_sanity_check.h"
#include "megbrain/utils/arith_helper.h"
+#include "megbrain/plugin/static_mem_record.h"
namespace mgb {
namespace cg {
@@ -169,6 +170,9 @@ public:
}
std::unique_ptr as_recorded_seq();
+
+ void get_static_memory_alloc_info(
+ const std::string& svg_name = "static_mem_record.svg") override;
};
class ComputingGraphImpl::MegDNNDtorCheck : public NonCopyableObj {
diff --git a/src/core/impl/graph/var_node_mem_mgr/seq_mem_opt.cpp b/src/core/impl/graph/var_node_mem_mgr/seq_mem_opt.cpp
index 348df67d..c5f41656 100644
--- a/src/core/impl/graph/var_node_mem_mgr/seq_mem_opt.cpp
+++ b/src/core/impl/graph/var_node_mem_mgr/seq_mem_opt.cpp
@@ -178,9 +178,18 @@ bool SeqMemOptimizer::run_static_mem_alloc() {
ThinHashMap chk2interval;
// get all memory chunks
+ if (StaticMemRecorder::Instance().valid()) {
+ StaticMemRecorder::Instance().clear_opr_seq();
+ }
+
for (size_t idx = 0; idx < m_cur_seq_full->size(); ++ idx) {
OperatorNodeBase *opr = m_cur_seq_full->at(idx);
+ if (StaticMemRecorder::Instance().valid()) {
+ StaticMemRecorder::Instance().regist_opr_seq(
+ {idx, 0, opr->name()});
+ }
+
auto &&dep_map = opr->node_prop().dep_map();
if (in_sys_alloc(opr)) {
@@ -349,6 +358,14 @@ bool SeqMemOptimizer::run_static_mem_alloc_on_comp_node(
chk.chunk->mem_alloc_status.set_static_offset(
allocator->get_start_addr(&chk));
}
+ auto& recorder = StaticMemRecorder::Instance();
+ if (recorder.valid()) {
+ for (size_t i = 0; i < chunks.size(); i++) {
+ recorder.regist_memory_chunk_owner_var_name(
+ i, chunks.at(i).chunk->owner_var->name());
+ }
+ recorder.regist_peak_mem_size(size);
+ }
}
return should_realloc;
diff --git a/src/core/impl/graph/var_node_mem_mgr/static_mem_alloc.h b/src/core/impl/graph/var_node_mem_mgr/static_mem_alloc.h
index aae006a0..77c0ad4e 100644
--- a/src/core/impl/graph/var_node_mem_mgr/static_mem_alloc.h
+++ b/src/core/impl/graph/var_node_mem_mgr/static_mem_alloc.h
@@ -11,6 +11,7 @@
#pragma once
+#include "megbrain/plugin/static_mem_record.h"
#include "megbrain_build_config.h"
#include
diff --git a/src/core/impl/graph/var_node_mem_mgr/static_mem_alloc/impl.cpp b/src/core/impl/graph/var_node_mem_mgr/static_mem_alloc/impl.cpp
index 2834aa35..69308008 100644
--- a/src/core/impl/graph/var_node_mem_mgr/static_mem_alloc/impl.cpp
+++ b/src/core/impl/graph/var_node_mem_mgr/static_mem_alloc/impl.cpp
@@ -120,6 +120,22 @@ StaticMemAlloc& StaticMemAllocImplHelper::solve() {
check_result_and_calc_lower_bound();
+ if (StaticMemRecorder::Instance().valid()) {
+ StaticMemRecorder::Instance().clear_memory_chunk();
+ for (auto&& i : m_interval) {
+ size_t overwrite_dest_id = 0;
+ bool is_overwrite = !i->is_overwrite_root();
+ if (is_overwrite) {
+ overwrite_dest_id = i->overwrite_dest_root()->id;
+ }
+
+ StaticMemRecorder::Instance().regist_memory_chunk(
+ {i->id, i->size_orig, i->time_begin, i->time_end,
+ i->addr_begin, i->addr_end(), overwrite_dest_id,
+ is_overwrite, ""});
+ }
+ }
+
return *this;
}
diff --git a/src/core/include/megbrain/graph/bases.h b/src/core/include/megbrain/graph/bases.h
index f8ee265c..ed8b5888 100644
--- a/src/core/include/megbrain/graph/bases.h
+++ b/src/core/include/megbrain/graph/bases.h
@@ -190,6 +190,11 @@ class AsyncExecutable : public json::Serializable,
m_user_data.get_user_data();
return (*(output_vars_pair.first))->get_output_vars();
}
+
+ virtual void get_static_memory_alloc_info(const std::string& svg_name) {
+ mgb_assert(svg_name.length() < 0,
+ "can't call this function directly\n");
+ }
};
diff --git a/src/plugin/impl/static_mem_record.cpp b/src/plugin/impl/static_mem_record.cpp
new file mode 100644
index 00000000..cb228569
--- /dev/null
+++ b/src/plugin/impl/static_mem_record.cpp
@@ -0,0 +1,319 @@
+/**
+ * \file src/plugin/impl/static_mem_record.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "megbrain/plugin/static_mem_record.h"
+#include
+#include
+
+using namespace mgb;
+using namespace cg;
+
+namespace {
+#define SVG_WIDTH 20000.0
+#define SVG_HEIGHT 15000.0
+#define OPR_RECT_WIDTH 40.0
+#define OPR_RECT_HEIGHT 20.0
+
+const std::string rect =
+ "";
+const std::string text = "{}";
+const std::string polyline =
+ "";
+const std::string opr_info =
+ "mge:type=\"opr\" mge:id=\"{}\" mge:size=\"{}\" mge:name=\"{}\"";
+const std::string chunk_info =
+ "mge:type=\"chunk\" mge:id=\"{}\" mge:time=\"{}\" mge:addr=\"{}\" "
+ "mge:size=\"{}\" mge:owner_var_name=\"{}\"";
+const std::string animate =
+ "\n";
+
+std::string& replace_by_parameter(std::string& original_str, size_t index) {
+ return original_str;
+}
+
+template
+std::string& replace_by_parameter(std::string& original_str, size_t index,
+ const std::string& parameter,
+ const Args&... args) {
+ index = original_str.find("{}", index);
+ original_str.replace(index, 2, parameter);
+ index += parameter.length();
+ replace_by_parameter(original_str, index, args...);
+ return original_str;
+}
+
+std::string set_opr_info(std::string id, std::string size, std::string name,
+ std::string info = opr_info) {
+ return replace_by_parameter(info, 0, id, size, name);
+}
+
+std::string set_chunk_info(std::string id, std::string time, std::string addr,
+ std::string size, std::string owner_var_name,
+ std::string info = chunk_info) {
+ return replace_by_parameter(info, 0, id, time, addr, size, owner_var_name);
+}
+
+std::string draw_rect(std::string x, std::string y, std::string widith,
+ std::string height, std::string color, std::string info,
+ std::string r = rect) {
+ return replace_by_parameter(r, 0, x, y, widith, height, color, info);
+}
+
+std::string draw_text(std::string x, std::string y, std::string font_size,
+ std::string txt, std::string t = text) {
+ return replace_by_parameter(t, 0, x, y, font_size, txt);
+}
+
+std::string draw_polyline(std::string point_seq, std::string color,
+ std::string width, std::string p = polyline) {
+ return replace_by_parameter(p, 0, point_seq, color, width);
+}
+} // namespace
+
+void StaticMemRecorder::dump_svg(std::string svg_name) {
+ float svg_width = SVG_WIDTH, svg_height = SVG_HEIGHT,
+ opr_rect_width = OPR_RECT_WIDTH, opr_rect_height = OPR_RECT_HEIGHT;
+ float address_scale = 1;
+ size_t opr_nr = m_opr_seq_recorder.size();
+ if (opr_nr * OPR_RECT_WIDTH > SVG_WIDTH) {
+ svg_width = SVG_WIDTH;
+ opr_rect_width = svg_width / opr_nr;
+ opr_rect_height = opr_rect_width / 2;
+ } else {
+ opr_rect_width = OPR_RECT_WIDTH;
+ svg_width = opr_nr * opr_rect_width;
+ }
+ if (m_sum_mem_size > SVG_HEIGHT) {
+ svg_height = SVG_HEIGHT;
+ address_scale = svg_height / m_sum_mem_size;
+ } else {
+ svg_height = m_sum_mem_size;
+ }
+
+ // Rescale
+ float aspect_ratio = SVG_WIDTH / SVG_HEIGHT;
+ if (svg_width / svg_height < 1) {
+ svg_width = svg_height * aspect_ratio;
+ opr_rect_width = svg_width / opr_nr;
+ opr_rect_height = opr_rect_width / 2;
+ } else if (svg_width / svg_height > aspect_ratio) {
+ svg_height = svg_width / aspect_ratio;
+ address_scale = svg_height / m_sum_mem_size;
+ }
+
+ svg_height = svg_height + opr_rect_height * 2;
+
+ std::ofstream outfile;
+ outfile.open(svg_name);
+ outfile << "" << std::endl;
+ outfile << ""
+ << std::endl;
+ outfile << "" << std::endl;
+ outfile.close();
+}
+
+void StaticMemRecorder::show(std::string svg_name) {
+ for (auto&& i : m_memory_chunk_recorder) {
+ if (i.id >= m_weight_chunk_id) {
+ break;
+ }
+ size_t begin = i.time_begin, end = i.time_end;
+ if (i.is_overwrite) {
+ begin++;
+ }
+ for (size_t j = begin; j < end; j++) {
+ m_opr_seq_recorder.at(j).size += i.size_orig;
+ }
+ }
+
+ // log peak memory size, where it is reached and which chunks constitute it.
+ mgb_log("peak_mem_size = %zu\n", m_peak_mem_size);
+ size_t max_size = 0;
+ std::vector opr_ids;
+ for (auto&& i : m_opr_seq_recorder) {
+ if (i.size == max_size) {
+ opr_ids.push_back(i.id);
+ } else if (i.size > max_size) {
+ max_size = i.size;
+ opr_ids.clear();
+ opr_ids.push_back(i.id);
+ }
+ }
+
+ auto opr2chunk = get_chunk_construct(opr_ids);
+ mgb_log("oprs reach the peak memory:\n");
+ for (auto&& i : opr_ids) {
+ mgb_log("opr id = %zu\n", i);
+ }
+ mgb_log("More details:\n");
+ for (size_t i = 0; i < opr2chunk.size(); i++) {
+ mgb_log("opr id = %zu\n", opr_ids.at(i));
+ if (i + 1 < opr2chunk.size() &&
+ opr2chunk.at(i) == opr2chunk.at(i + 1)) {
+ continue;
+ }
+ for (size_t j = 0; j < opr2chunk.at(i).size(); j++) {
+ auto&& chunk = m_memory_chunk_recorder.at(opr2chunk.at(i).at(j));
+ mgb_log("[memory_chunk_id=%zu, size=%zu B, "
+ "[life_begin=%zu,life_end=%zu), owner_opr_name=%s]\n",
+ chunk.id, chunk.size_orig, chunk.time_begin, chunk.time_end,
+ m_opr_seq_recorder.at(chunk.time_begin).name.c_str());
+ }
+ }
+ dump_svg(svg_name);
+}
+
+std::vector> StaticMemRecorder::get_chunk_construct(
+ std::vector opr_ids) {
+ std::vector> chunk_ids;
+ chunk_ids.resize(opr_ids.size());
+ for (auto&& i : m_memory_chunk_recorder) {
+ if (i.id >= m_weight_chunk_id) {
+ break;
+ }
+ size_t begin = i.time_begin, end = i.time_end;
+ if (i.is_overwrite) {
+ begin = begin + 1;
+ }
+ if (opr_ids.front() >= end || opr_ids.back() < begin) {
+ continue;
+ }
+ for (size_t k = 0; k < opr_ids.size(); k++) {
+ if (opr_ids.at(k) >= end) {
+ break;
+ } else if (opr_ids.at(k) >= begin) {
+ chunk_ids.at(k).push_back(i.id);
+ }
+ }
+ }
+ return chunk_ids;
+}
\ No newline at end of file
diff --git a/src/plugin/include/megbrain/plugin/static_mem_record.h b/src/plugin/include/megbrain/plugin/static_mem_record.h
new file mode 100644
index 00000000..6276227d
--- /dev/null
+++ b/src/plugin/include/megbrain/plugin/static_mem_record.h
@@ -0,0 +1,85 @@
+/**
+ * \file src/plugin/include/megbrain/plugin/static_mem_record.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#pragma once
+#include "megbrain/utils/metahelper.h"
+
+namespace mgb {
+namespace cg {
+
+class StaticMemRecorder : public NonCopyableObj {
+public:
+ static StaticMemRecorder& Instance() {
+ static StaticMemRecorder StaticMemRecorder;
+ return StaticMemRecorder;
+ }
+
+ struct opr_record {
+ size_t id, size;
+ std::string name;
+ };
+ struct memory_chunk_record {
+ size_t id, size_orig, time_begin, time_end, addr_begin,
+ addr_end, overwrite_dest_id;
+ bool is_overwrite;
+ std::string owner_var_name;
+ };
+
+ void active() { m_is_record = true; }
+
+ bool valid() { return m_is_record; }
+
+ void clear_opr_seq() { m_opr_seq_recorder.clear(); }
+
+ void regist_opr_seq(opr_record opr) { m_opr_seq_recorder.push_back(opr); }
+
+ void clear_memory_chunk() { m_memory_chunk_recorder.clear(); }
+
+ void regist_memory_chunk(memory_chunk_record mcr) {
+ m_memory_chunk_recorder.push_back(mcr);
+ }
+
+ void regist_memory_chunk_owner_var_name(size_t id, std::string name) {
+ m_memory_chunk_recorder.at(id).owner_var_name = name;
+ }
+
+ void regist_peak_mem_size(size_t size) { m_peak_mem_size = size; }
+
+ const size_t& peak_mem_size() { return m_peak_mem_size; }
+
+ void set_sum_mem_size(size_t size) { m_sum_mem_size = size; }
+
+ const size_t& sum_mem_size() { return m_sum_mem_size; }
+
+ const size_t& set_weight_chunk_id() {
+ m_weight_chunk_id = m_memory_chunk_recorder.size();
+ return m_weight_chunk_id;
+ }
+
+ const size_t& weight_chunk_id() { return m_weight_chunk_id; }
+
+ void dump_svg(std::string svg_name);
+
+ void show(std::string svg_name);
+
+private:
+ bool m_is_record = false;
+ // All chunks after m_memory_chunk_recorder.at(m_weight_chunk_id) are
+ // weights memory chunks
+ size_t m_peak_mem_size, m_sum_mem_size, m_weight_chunk_id;
+ std::vector m_opr_seq_recorder;
+ std::vector m_memory_chunk_recorder;
+ std::vector> get_chunk_construct(
+ std::vector opr_ids);
+};
+} // namespace cg
+} // namespace mgb