@@ -0,0 +1,154 @@ | |||||
<html> | |||||
<title>Visualizer</title> | |||||
<head> | |||||
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no" /> | |||||
</head> | |||||
<script> | |||||
window.onload = () => { | |||||
var board = document.getElementById('board'); | |||||
var fileInput = document.getElementById('fileInput'); | |||||
var desc = document.getElementById('desc'); | |||||
var hRange = document.getElementById('hRange'); | |||||
var vRange = document.getElementById('vRange'); | |||||
var lastColor = undefined; | |||||
var lastElem = undefined; | |||||
var scale = 1; | |||||
var svg = undefined; | |||||
var svgWidth = undefined; | |||||
var svgHeight = undefined; | |||||
var loadDesc = (svgElem) => { | |||||
var mgeType = svgElem.attributes['mge:type']; | |||||
if (mgeType === undefined) { | |||||
return; | |||||
} | |||||
var elemList = []; | |||||
for (attrName of svgElem.getAttributeNames()) { | |||||
var prefix = 'mge:'; | |||||
if (!attrName.startsWith(prefix)) { | |||||
continue; | |||||
} | |||||
var elem = '<p>' + attrName.substr(prefix.length) + ': ' + svgElem.attributes[attrName].value + '</p>' | |||||
elemList.push(elem); | |||||
} | |||||
desc.innerHTML = elemList.join(''); | |||||
}; | |||||
var selectElem = svgElem => { | |||||
loadDesc(svgElem); | |||||
lastColor = svgElem.attributes['fill'].value; | |||||
lastElem = svgElem; | |||||
svgElem.attributes['fill'].value = 'green'; | |||||
}; | |||||
var unselectLast = svgElem => { | |||||
if (lastElem) { | |||||
lastElem.attributes['fill'].value = lastColor; | |||||
} | |||||
lastElem = undefined; | |||||
lastColor = undefined; | |||||
}; | |||||
function recLoadSVG(svgElem) { | |||||
if (svgElem.children === undefined) { | |||||
return; | |||||
} | |||||
svgElem.onmousedown = e => { | |||||
var mgeType = svgElem.attributes['mge:type']; | |||||
if (mgeType === undefined) { | |||||
return; | |||||
} | |||||
unselectLast(); | |||||
selectElem(svgElem); | |||||
e.stopPropagation(); | |||||
}; | |||||
for (child of svgElem.children) { | |||||
recLoadSVG(child); | |||||
} | |||||
} | |||||
function loadSVG() { | |||||
var file = fileInput.files[0]; | |||||
var reader = new FileReader(); | |||||
reader.readAsText(file, "UTF-8"); | |||||
reader.onload = e => { | |||||
board.innerHTML = '<p style="margin: 0;">' + e.target.result + '</p>'; | |||||
svg = board.children[0].children[0]; | |||||
svgWidth = svg.attributes['width'].value; | |||||
svgHeight = svg.attributes['height'].value; | |||||
for (child of board.children) { | |||||
recLoadSVG(child); | |||||
var svgInfo = child.attributes['svg:info']; | |||||
if (svgInfo !== undefined) { | |||||
var elemList = []; | |||||
for (attrName of child.getAttributeNames()) { | |||||
var prefix = 'svg:'; | |||||
if (!attrName.startsWith(prefix)) { | |||||
continue; | |||||
} | |||||
var elem = '<p>' + attrName.substr(prefix.length) + ': ' + child.attributes[attrName].value + '</p>' | |||||
elemList.push(elem); | |||||
} | |||||
info.innerHTML = elemList.join(''); | |||||
} | |||||
} | |||||
}; | |||||
} | |||||
function scaleBoard(x, y) { | |||||
var transform = 'scale(' + x + ',' + y + ')'; | |||||
svg.setAttribute('transform', transform); | |||||
board.style['width'] = svgWidth * x; | |||||
board.style['height'] = svgHeight * y; | |||||
} | |||||
function autoScaleBoard() { | |||||
var hRangeValue = Math.sqrt(Number(hRange.value) / 10); | |||||
var vRangeValue = Math.sqrt(Number(vRange.value) / 10); | |||||
scaleBoard(Number(hRangeValue), Number(vRangeValue)); | |||||
} | |||||
fileInput.onchange = loadSVG; | |||||
var zoomBoard = dScale => { | |||||
scale *= dScale; | |||||
scaleBoard(scale, scale); | |||||
}; | |||||
window.addEventListener('wheel', e => { | |||||
console.log(e); | |||||
if (e.ctrlKey) { | |||||
e.preventDefault(); | |||||
e.stopPropagation(); | |||||
var factor = 1; | |||||
if (e.deltaY < 0) { | |||||
factor = 1.1; | |||||
} else if (e.deltaY > 0) { | |||||
factor = 1 / 1.1; | |||||
} | |||||
zoomBoard(factor); | |||||
var newPageX = e.pageX * factor; | |||||
var newPageY = e.pageY * factor; | |||||
x = newPageX - e.x; | |||||
y = newPageY - e.y; | |||||
window.scrollTo({ | |||||
top: y, | |||||
left: x, | |||||
}); | |||||
console.log('scroll', [x, y]); | |||||
} | |||||
}, { 'passive': false }); | |||||
}; | |||||
</script> | |||||
<body> | |||||
<p id="desc" style="position: fixed;bottom: 0; background-color: white;">desc</p> | |||||
<p id="info" style="position: fixed;top: 0; right: 0; background-color: white;">info</p> | |||||
<p id="board" | |||||
style="white-space: nowrap; display: flex; justify-content: center; align-content: center; align-items: center; margin: 0;opacity: 0.7;"> | |||||
</p> | |||||
<input type='file' id='fileInput' style="position: fixed; top: 0; background-color: white;"></input> | |||||
</body> | |||||
</html> |
@@ -492,6 +492,38 @@ AsyncExecutable& ComputingGraphImpl::ComputingSequence::execute() { | |||||
return *this; | return *this; | ||||
} | } | ||||
void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info( | |||||
const std::string& svg_name) { | |||||
check_not_finalized(); | |||||
auto& recorder = StaticMemRecorder::Instance(); | |||||
recorder.active(); | |||||
ExecContext exec_ctx{this}; | |||||
// regist weights | |||||
size_t addr_base = recorder.peak_mem_size(); | |||||
size_t chunk_id = recorder.set_weight_chunk_id(); | |||||
for (auto&& i : *(this->m_opr_seq)) { | |||||
auto op = i->output(); | |||||
for (auto&& j : op) { | |||||
auto& mp = j->mem_plan(); | |||||
if (mp.valid()) { | |||||
auto& mc = mp.chunk(); | |||||
if (mp.valid() && mc.mem_alloc_status.is_from_owner_var()) { | |||||
recorder.regist_memory_chunk( | |||||
{chunk_id++, mc.size(), 0, this->m_opr_seq->size(), | |||||
addr_base, addr_base + mc.size(), 0, false, | |||||
mc.owner_var->name()}); | |||||
addr_base += mc.size(); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
recorder.set_sum_mem_size(addr_base); | |||||
mgb_assert(svg_name.length() > 4, "svg_name must be end with \".svg\"\n"); | |||||
mgb_assert(svg_name.compare(svg_name.length() - 4, 4, ".svg") == 0, | |||||
"svg_name must be end with \".svg\"\n"); | |||||
recorder.show(svg_name); | |||||
} | |||||
AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() { | AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() { | ||||
do_wait(true); | do_wait(true); | ||||
return *this; | return *this; | ||||
@@ -16,6 +16,7 @@ | |||||
#include "megbrain/comp_node_env.h" | #include "megbrain/comp_node_env.h" | ||||
#include "megbrain/plugin/var_sanity_check.h" | #include "megbrain/plugin/var_sanity_check.h" | ||||
#include "megbrain/utils/arith_helper.h" | #include "megbrain/utils/arith_helper.h" | ||||
#include "megbrain/plugin/static_mem_record.h" | |||||
namespace mgb { | namespace mgb { | ||||
namespace cg { | namespace cg { | ||||
@@ -169,6 +170,9 @@ public: | |||||
} | } | ||||
std::unique_ptr<RecordedComputingSequence> as_recorded_seq(); | std::unique_ptr<RecordedComputingSequence> as_recorded_seq(); | ||||
void get_static_memory_alloc_info( | |||||
const std::string& svg_name = "static_mem_record.svg") override; | |||||
}; | }; | ||||
class ComputingGraphImpl::MegDNNDtorCheck : public NonCopyableObj { | class ComputingGraphImpl::MegDNNDtorCheck : public NonCopyableObj { | ||||
@@ -178,9 +178,18 @@ bool SeqMemOptimizer::run_static_mem_alloc() { | |||||
ThinHashMap<MemAllocPlan::Chunk*, MemChunkLifeInterval> chk2interval; | ThinHashMap<MemAllocPlan::Chunk*, MemChunkLifeInterval> chk2interval; | ||||
// get all memory chunks | // get all memory chunks | ||||
if (StaticMemRecorder::Instance().valid()) { | |||||
StaticMemRecorder::Instance().clear_opr_seq(); | |||||
} | |||||
for (size_t idx = 0; idx < m_cur_seq_full->size(); ++ idx) { | for (size_t idx = 0; idx < m_cur_seq_full->size(); ++ idx) { | ||||
OperatorNodeBase *opr = m_cur_seq_full->at(idx); | OperatorNodeBase *opr = m_cur_seq_full->at(idx); | ||||
if (StaticMemRecorder::Instance().valid()) { | |||||
StaticMemRecorder::Instance().regist_opr_seq( | |||||
{idx, 0, opr->name()}); | |||||
} | |||||
auto &&dep_map = opr->node_prop().dep_map(); | auto &&dep_map = opr->node_prop().dep_map(); | ||||
if (in_sys_alloc(opr)) { | if (in_sys_alloc(opr)) { | ||||
@@ -349,6 +358,14 @@ bool SeqMemOptimizer::run_static_mem_alloc_on_comp_node( | |||||
chk.chunk->mem_alloc_status.set_static_offset( | chk.chunk->mem_alloc_status.set_static_offset( | ||||
allocator->get_start_addr(&chk)); | allocator->get_start_addr(&chk)); | ||||
} | } | ||||
auto& recorder = StaticMemRecorder::Instance(); | |||||
if (recorder.valid()) { | |||||
for (size_t i = 0; i < chunks.size(); i++) { | |||||
recorder.regist_memory_chunk_owner_var_name( | |||||
i, chunks.at(i).chunk->owner_var->name()); | |||||
} | |||||
recorder.regist_peak_mem_size(size); | |||||
} | |||||
} | } | ||||
return should_realloc; | return should_realloc; | ||||
@@ -11,6 +11,7 @@ | |||||
#pragma once | #pragma once | ||||
#include "megbrain/plugin/static_mem_record.h" | |||||
#include "megbrain_build_config.h" | #include "megbrain_build_config.h" | ||||
#include <cstddef> | #include <cstddef> | ||||
@@ -120,6 +120,22 @@ StaticMemAlloc& StaticMemAllocImplHelper::solve() { | |||||
check_result_and_calc_lower_bound(); | check_result_and_calc_lower_bound(); | ||||
if (StaticMemRecorder::Instance().valid()) { | |||||
StaticMemRecorder::Instance().clear_memory_chunk(); | |||||
for (auto&& i : m_interval) { | |||||
size_t overwrite_dest_id = 0; | |||||
bool is_overwrite = !i->is_overwrite_root(); | |||||
if (is_overwrite) { | |||||
overwrite_dest_id = i->overwrite_dest_root()->id; | |||||
} | |||||
StaticMemRecorder::Instance().regist_memory_chunk( | |||||
{i->id, i->size_orig, i->time_begin, i->time_end, | |||||
i->addr_begin, i->addr_end(), overwrite_dest_id, | |||||
is_overwrite, ""}); | |||||
} | |||||
} | |||||
return *this; | return *this; | ||||
} | } | ||||
@@ -190,6 +190,11 @@ class AsyncExecutable : public json::Serializable, | |||||
m_user_data.get_user_data<OutputVarsUserData>(); | m_user_data.get_user_data<OutputVarsUserData>(); | ||||
return (*(output_vars_pair.first))->get_output_vars(); | return (*(output_vars_pair.first))->get_output_vars(); | ||||
} | } | ||||
virtual void get_static_memory_alloc_info(const std::string& svg_name) { | |||||
mgb_assert(svg_name.length() < 0, | |||||
"can't call this function directly\n"); | |||||
} | |||||
}; | }; | ||||
@@ -0,0 +1,319 @@ | |||||
/** | |||||
* \file src/plugin/impl/static_mem_record.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megbrain/plugin/static_mem_record.h" | |||||
#include <fstream> | |||||
#include <iostream> | |||||
using namespace mgb; | |||||
using namespace cg; | |||||
namespace { | |||||
#define SVG_WIDTH 20000.0 | |||||
#define SVG_HEIGHT 15000.0 | |||||
#define OPR_RECT_WIDTH 40.0 | |||||
#define OPR_RECT_HEIGHT 20.0 | |||||
const std::string rect = | |||||
"<rect x=\"{}\" y=\"{}\" width=\"{}\" height=\"{}\" fill=\"{}\" " | |||||
" {}></rect>"; | |||||
const std::string text = "<text x=\"{}\" y=\"{}\" font-size=\"{}\">{}</text>"; | |||||
const std::string polyline = | |||||
"<polyline points=\"{}\" style=\"fill:none;stroke:{};stroke-width:{}\" " | |||||
"/>"; | |||||
const std::string opr_info = | |||||
"mge:type=\"opr\" mge:id=\"{}\" mge:size=\"{}\" mge:name=\"{}\""; | |||||
const std::string chunk_info = | |||||
"mge:type=\"chunk\" mge:id=\"{}\" mge:time=\"{}\" mge:addr=\"{}\" " | |||||
"mge:size=\"{}\" mge:owner_var_name=\"{}\""; | |||||
const std::string animate = | |||||
"<animate attributeName=\"opacity\" from=\"0\" to=\"1\" " | |||||
"begin=\"{}.mouseover\" fill=\"freeze\" dur=\"1s\"/>\n<animate " | |||||
"attributeName=\"opacity\" from=\"1\" to=\"0\" begin=\"{}.mouseout\" " | |||||
"fill=\"freeze\" dur=\"1s\"/>"; | |||||
std::string& replace_by_parameter(std::string& original_str, size_t index) { | |||||
return original_str; | |||||
} | |||||
template <typename... Args> | |||||
std::string& replace_by_parameter(std::string& original_str, size_t index, | |||||
const std::string& parameter, | |||||
const Args&... args) { | |||||
index = original_str.find("{}", index); | |||||
original_str.replace(index, 2, parameter); | |||||
index += parameter.length(); | |||||
replace_by_parameter(original_str, index, args...); | |||||
return original_str; | |||||
} | |||||
std::string set_opr_info(std::string id, std::string size, std::string name, | |||||
std::string info = opr_info) { | |||||
return replace_by_parameter(info, 0, id, size, name); | |||||
} | |||||
std::string set_chunk_info(std::string id, std::string time, std::string addr, | |||||
std::string size, std::string owner_var_name, | |||||
std::string info = chunk_info) { | |||||
return replace_by_parameter(info, 0, id, time, addr, size, owner_var_name); | |||||
} | |||||
std::string draw_rect(std::string x, std::string y, std::string widith, | |||||
std::string height, std::string color, std::string info, | |||||
std::string r = rect) { | |||||
return replace_by_parameter(r, 0, x, y, widith, height, color, info); | |||||
} | |||||
std::string draw_text(std::string x, std::string y, std::string font_size, | |||||
std::string txt, std::string t = text) { | |||||
return replace_by_parameter(t, 0, x, y, font_size, txt); | |||||
} | |||||
std::string draw_polyline(std::string point_seq, std::string color, | |||||
std::string width, std::string p = polyline) { | |||||
return replace_by_parameter(p, 0, point_seq, color, width); | |||||
} | |||||
} // namespace | |||||
void StaticMemRecorder::dump_svg(std::string svg_name) { | |||||
float svg_width = SVG_WIDTH, svg_height = SVG_HEIGHT, | |||||
opr_rect_width = OPR_RECT_WIDTH, opr_rect_height = OPR_RECT_HEIGHT; | |||||
float address_scale = 1; | |||||
size_t opr_nr = m_opr_seq_recorder.size(); | |||||
if (opr_nr * OPR_RECT_WIDTH > SVG_WIDTH) { | |||||
svg_width = SVG_WIDTH; | |||||
opr_rect_width = svg_width / opr_nr; | |||||
opr_rect_height = opr_rect_width / 2; | |||||
} else { | |||||
opr_rect_width = OPR_RECT_WIDTH; | |||||
svg_width = opr_nr * opr_rect_width; | |||||
} | |||||
if (m_sum_mem_size > SVG_HEIGHT) { | |||||
svg_height = SVG_HEIGHT; | |||||
address_scale = svg_height / m_sum_mem_size; | |||||
} else { | |||||
svg_height = m_sum_mem_size; | |||||
} | |||||
// Rescale | |||||
float aspect_ratio = SVG_WIDTH / SVG_HEIGHT; | |||||
if (svg_width / svg_height < 1) { | |||||
svg_width = svg_height * aspect_ratio; | |||||
opr_rect_width = svg_width / opr_nr; | |||||
opr_rect_height = opr_rect_width / 2; | |||||
} else if (svg_width / svg_height > aspect_ratio) { | |||||
svg_height = svg_width / aspect_ratio; | |||||
address_scale = svg_height / m_sum_mem_size; | |||||
} | |||||
svg_height = svg_height + opr_rect_height * 2; | |||||
std::ofstream outfile; | |||||
outfile.open(svg_name); | |||||
outfile << "<?xml version=\"1.0\" standalone=\"no\"?>" << std::endl; | |||||
outfile << "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN/\" " | |||||
"\"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">" | |||||
<< std::endl; | |||||
outfile << "<svg width=\"" + std::to_string(svg_width) + "\" height=\"" + | |||||
std::to_string(svg_height) + | |||||
"\" version=\"1.1\" " | |||||
"xmlns=\"http://www.w3.org/2000/svg\">" | |||||
<< std::endl; | |||||
float base_height = svg_height - opr_rect_height; | |||||
std::string peak_mem_polyline = | |||||
"0," + | |||||
std::to_string(base_height - m_peak_mem_size * address_scale) + | |||||
" " + std::to_string(m_opr_seq_recorder.size() * opr_rect_width) + | |||||
"," + std::to_string(base_height - m_peak_mem_size * address_scale); | |||||
std::string sum_mem_polyline = | |||||
"0," + | |||||
std::to_string(base_height - m_sum_mem_size * address_scale) + " " + | |||||
std::to_string(m_opr_seq_recorder.size() * opr_rect_width) + "," + | |||||
std::to_string(base_height - m_sum_mem_size * address_scale); | |||||
std::string memory_polyline = ""; | |||||
for (size_t i = 0; i < m_opr_seq_recorder.size(); i++) { | |||||
auto&& opr = m_opr_seq_recorder.at(i); | |||||
memory_polyline += | |||||
std::to_string((i + 0.5) * opr_rect_width) + "," + | |||||
std::to_string(base_height - opr.size * address_scale) + " "; | |||||
outfile << draw_text(std::to_string(i * opr_rect_width), | |||||
std::to_string(svg_height - opr_rect_height * 0.5), | |||||
std::to_string(opr_rect_height * 0.5), | |||||
"opr" + std::to_string(i)) | |||||
<< std::endl; | |||||
std::string opr_info = | |||||
set_opr_info( | |||||
std::to_string(opr.id), | |||||
std::to_string(opr.size) + "B(" + | |||||
std::to_string(opr.size / 1024.0 / 1024.0) + | |||||
"MiB)", | |||||
opr.name) + | |||||
" opacity=\"0\""; | |||||
outfile << draw_rect(std::to_string(i * opr_rect_width), | |||||
std::to_string(base_height), | |||||
std::to_string(opr_rect_width), | |||||
std::to_string(opr_rect_height), "white", opr_info) | |||||
<< std::endl; | |||||
} | |||||
for (size_t i = 0; i < m_memory_chunk_recorder.size(); i++) { | |||||
auto&& chunk = m_memory_chunk_recorder.at(i); | |||||
std::string chunk_info = set_chunk_info( | |||||
std::to_string(chunk.id), | |||||
"[" + std::to_string(chunk.time_begin) + "," + | |||||
std::to_string(chunk.time_end) + ")", | |||||
"[" + std::to_string(chunk.addr_begin) + "," + | |||||
std::to_string(chunk.addr_end) + ")", | |||||
std::to_string(chunk.addr_end - chunk.addr_begin) + "B(" + | |||||
std::to_string((chunk.addr_end - chunk.addr_begin) / | |||||
1024.0 / 1024.0) + | |||||
"MiB)", | |||||
chunk.owner_var_name); | |||||
outfile << draw_rect( | |||||
std::to_string(chunk.time_begin * opr_rect_width), | |||||
std::to_string(base_height - | |||||
chunk.addr_end * address_scale), | |||||
std::to_string((chunk.time_end - chunk.time_begin) * | |||||
opr_rect_width), | |||||
std::to_string((chunk.addr_end - chunk.addr_begin) * | |||||
address_scale), | |||||
"gray", chunk_info) | |||||
<< std::endl; | |||||
outfile << draw_text(std::to_string(chunk.time_begin * opr_rect_width), | |||||
std::to_string(base_height - | |||||
chunk.addr_end * address_scale + 9), | |||||
std::to_string(9), | |||||
"chunk" + std::to_string(chunk.id)) | |||||
<< std::endl; | |||||
} | |||||
outfile << draw_text("0", | |||||
std::to_string(base_height - | |||||
m_peak_mem_size * address_scale + | |||||
opr_rect_height * 0.5), | |||||
std::to_string(opr_rect_height * 0.5), | |||||
"peak_memory_size:" + std::to_string(m_peak_mem_size) + | |||||
"B(" + | |||||
std::to_string(m_peak_mem_size / 1024.0 / | |||||
1024.0) + | |||||
"MiB)") | |||||
<< std::endl; | |||||
outfile << draw_text("0", | |||||
std::to_string(base_height - | |||||
m_sum_mem_size * address_scale + | |||||
opr_rect_height * 0.5), | |||||
std::to_string(opr_rect_height * 0.5), | |||||
"sum_memory_size:" + std::to_string(m_sum_mem_size) + | |||||
"B(" + | |||||
std::to_string(m_sum_mem_size / 1024.0 / | |||||
1024.0) + | |||||
"MiB)") | |||||
<< std::endl; | |||||
outfile << draw_polyline(memory_polyline, "blue", | |||||
std::to_string(opr_rect_height * 0.1)) | |||||
<< std::endl; | |||||
outfile << draw_polyline(peak_mem_polyline, "green", | |||||
std::to_string(opr_rect_height * 0.1)) | |||||
<< std::endl; | |||||
outfile << draw_polyline(sum_mem_polyline, "red", | |||||
std::to_string(opr_rect_height * 0.1)) | |||||
<< std::endl; | |||||
outfile << "<text svg:info=\"The abscissa represents the opr sequence, the " | |||||
"ordinate represents the logical address.\" " | |||||
"svg:chunk_time=\"[opra,oprb) means the chunk is created when " | |||||
"opra execute and is freed before oprb\" " | |||||
"svg:chunk_oner_var_name=\"var that first creates this " | |||||
"chunk\"></text>" | |||||
<< std::endl; | |||||
outfile << "</svg>" << std::endl; | |||||
outfile.close(); | |||||
} | |||||
void StaticMemRecorder::show(std::string svg_name) { | |||||
for (auto&& i : m_memory_chunk_recorder) { | |||||
if (i.id >= m_weight_chunk_id) { | |||||
break; | |||||
} | |||||
size_t begin = i.time_begin, end = i.time_end; | |||||
if (i.is_overwrite) { | |||||
begin++; | |||||
} | |||||
for (size_t j = begin; j < end; j++) { | |||||
m_opr_seq_recorder.at(j).size += i.size_orig; | |||||
} | |||||
} | |||||
// log peak memory size, where it is reached and which chunks constitute it. | |||||
mgb_log("peak_mem_size = %zu\n", m_peak_mem_size); | |||||
size_t max_size = 0; | |||||
std::vector<size_t> opr_ids; | |||||
for (auto&& i : m_opr_seq_recorder) { | |||||
if (i.size == max_size) { | |||||
opr_ids.push_back(i.id); | |||||
} else if (i.size > max_size) { | |||||
max_size = i.size; | |||||
opr_ids.clear(); | |||||
opr_ids.push_back(i.id); | |||||
} | |||||
} | |||||
auto opr2chunk = get_chunk_construct(opr_ids); | |||||
mgb_log("oprs reach the peak memory:\n"); | |||||
for (auto&& i : opr_ids) { | |||||
mgb_log("opr id = %zu\n", i); | |||||
} | |||||
mgb_log("More details:\n"); | |||||
for (size_t i = 0; i < opr2chunk.size(); i++) { | |||||
mgb_log("opr id = %zu\n", opr_ids.at(i)); | |||||
if (i + 1 < opr2chunk.size() && | |||||
opr2chunk.at(i) == opr2chunk.at(i + 1)) { | |||||
continue; | |||||
} | |||||
for (size_t j = 0; j < opr2chunk.at(i).size(); j++) { | |||||
auto&& chunk = m_memory_chunk_recorder.at(opr2chunk.at(i).at(j)); | |||||
mgb_log("[memory_chunk_id=%zu, size=%zu B, " | |||||
"[life_begin=%zu,life_end=%zu), owner_opr_name=%s]\n", | |||||
chunk.id, chunk.size_orig, chunk.time_begin, chunk.time_end, | |||||
m_opr_seq_recorder.at(chunk.time_begin).name.c_str()); | |||||
} | |||||
} | |||||
dump_svg(svg_name); | |||||
} | |||||
std::vector<std::vector<size_t>> StaticMemRecorder::get_chunk_construct( | |||||
std::vector<size_t> opr_ids) { | |||||
std::vector<std::vector<size_t>> chunk_ids; | |||||
chunk_ids.resize(opr_ids.size()); | |||||
for (auto&& i : m_memory_chunk_recorder) { | |||||
if (i.id >= m_weight_chunk_id) { | |||||
break; | |||||
} | |||||
size_t begin = i.time_begin, end = i.time_end; | |||||
if (i.is_overwrite) { | |||||
begin = begin + 1; | |||||
} | |||||
if (opr_ids.front() >= end || opr_ids.back() < begin) { | |||||
continue; | |||||
} | |||||
for (size_t k = 0; k < opr_ids.size(); k++) { | |||||
if (opr_ids.at(k) >= end) { | |||||
break; | |||||
} else if (opr_ids.at(k) >= begin) { | |||||
chunk_ids.at(k).push_back(i.id); | |||||
} | |||||
} | |||||
} | |||||
return chunk_ids; | |||||
} |
@@ -0,0 +1,85 @@ | |||||
/** | |||||
* \file src/plugin/include/megbrain/plugin/static_mem_record.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megbrain/utils/metahelper.h" | |||||
namespace mgb { | |||||
namespace cg { | |||||
class StaticMemRecorder : public NonCopyableObj { | |||||
public: | |||||
static StaticMemRecorder& Instance() { | |||||
static StaticMemRecorder StaticMemRecorder; | |||||
return StaticMemRecorder; | |||||
} | |||||
struct opr_record { | |||||
size_t id, size; | |||||
std::string name; | |||||
}; | |||||
struct memory_chunk_record { | |||||
size_t id, size_orig, time_begin, time_end, addr_begin, | |||||
addr_end, overwrite_dest_id; | |||||
bool is_overwrite; | |||||
std::string owner_var_name; | |||||
}; | |||||
void active() { m_is_record = true; } | |||||
bool valid() { return m_is_record; } | |||||
void clear_opr_seq() { m_opr_seq_recorder.clear(); } | |||||
void regist_opr_seq(opr_record opr) { m_opr_seq_recorder.push_back(opr); } | |||||
void clear_memory_chunk() { m_memory_chunk_recorder.clear(); } | |||||
void regist_memory_chunk(memory_chunk_record mcr) { | |||||
m_memory_chunk_recorder.push_back(mcr); | |||||
} | |||||
void regist_memory_chunk_owner_var_name(size_t id, std::string name) { | |||||
m_memory_chunk_recorder.at(id).owner_var_name = name; | |||||
} | |||||
void regist_peak_mem_size(size_t size) { m_peak_mem_size = size; } | |||||
const size_t& peak_mem_size() { return m_peak_mem_size; } | |||||
void set_sum_mem_size(size_t size) { m_sum_mem_size = size; } | |||||
const size_t& sum_mem_size() { return m_sum_mem_size; } | |||||
const size_t& set_weight_chunk_id() { | |||||
m_weight_chunk_id = m_memory_chunk_recorder.size(); | |||||
return m_weight_chunk_id; | |||||
} | |||||
const size_t& weight_chunk_id() { return m_weight_chunk_id; } | |||||
void dump_svg(std::string svg_name); | |||||
void show(std::string svg_name); | |||||
private: | |||||
bool m_is_record = false; | |||||
// All chunks after m_memory_chunk_recorder.at(m_weight_chunk_id) are | |||||
// weights memory chunks | |||||
size_t m_peak_mem_size, m_sum_mem_size, m_weight_chunk_id; | |||||
std::vector<opr_record> m_opr_seq_recorder; | |||||
std::vector<memory_chunk_record> m_memory_chunk_recorder; | |||||
std::vector<std::vector<size_t>> get_chunk_construct( | |||||
std::vector<size_t> opr_ids); | |||||
}; | |||||
} // namespace cg | |||||
} // namespace mgb |