@@ -0,0 +1,154 @@ | |||
<html> | |||
<title>Visualizer</title> | |||
<head> | |||
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no" /> | |||
</head> | |||
<script> | |||
window.onload = () => { | |||
var board = document.getElementById('board'); | |||
var fileInput = document.getElementById('fileInput'); | |||
var desc = document.getElementById('desc'); | |||
var hRange = document.getElementById('hRange'); | |||
var vRange = document.getElementById('vRange'); | |||
var lastColor = undefined; | |||
var lastElem = undefined; | |||
var scale = 1; | |||
var svg = undefined; | |||
var svgWidth = undefined; | |||
var svgHeight = undefined; | |||
var loadDesc = (svgElem) => { | |||
var mgeType = svgElem.attributes['mge:type']; | |||
if (mgeType === undefined) { | |||
return; | |||
} | |||
var elemList = []; | |||
for (attrName of svgElem.getAttributeNames()) { | |||
var prefix = 'mge:'; | |||
if (!attrName.startsWith(prefix)) { | |||
continue; | |||
} | |||
var elem = '<p>' + attrName.substr(prefix.length) + ': ' + svgElem.attributes[attrName].value + '</p>' | |||
elemList.push(elem); | |||
} | |||
desc.innerHTML = elemList.join(''); | |||
}; | |||
var selectElem = svgElem => { | |||
loadDesc(svgElem); | |||
lastColor = svgElem.attributes['fill'].value; | |||
lastElem = svgElem; | |||
svgElem.attributes['fill'].value = 'green'; | |||
}; | |||
var unselectLast = svgElem => { | |||
if (lastElem) { | |||
lastElem.attributes['fill'].value = lastColor; | |||
} | |||
lastElem = undefined; | |||
lastColor = undefined; | |||
}; | |||
function recLoadSVG(svgElem) { | |||
if (svgElem.children === undefined) { | |||
return; | |||
} | |||
svgElem.onmousedown = e => { | |||
var mgeType = svgElem.attributes['mge:type']; | |||
if (mgeType === undefined) { | |||
return; | |||
} | |||
unselectLast(); | |||
selectElem(svgElem); | |||
e.stopPropagation(); | |||
}; | |||
for (child of svgElem.children) { | |||
recLoadSVG(child); | |||
} | |||
} | |||
function loadSVG() { | |||
var file = fileInput.files[0]; | |||
var reader = new FileReader(); | |||
reader.readAsText(file, "UTF-8"); | |||
reader.onload = e => { | |||
board.innerHTML = '<p style="margin: 0;">' + e.target.result + '</p>'; | |||
svg = board.children[0].children[0]; | |||
svgWidth = svg.attributes['width'].value; | |||
svgHeight = svg.attributes['height'].value; | |||
for (child of board.children) { | |||
recLoadSVG(child); | |||
var svgInfo = child.attributes['svg:info']; | |||
if (svgInfo !== undefined) { | |||
var elemList = []; | |||
for (attrName of child.getAttributeNames()) { | |||
var prefix = 'svg:'; | |||
if (!attrName.startsWith(prefix)) { | |||
continue; | |||
} | |||
var elem = '<p>' + attrName.substr(prefix.length) + ': ' + child.attributes[attrName].value + '</p>' | |||
elemList.push(elem); | |||
} | |||
info.innerHTML = elemList.join(''); | |||
} | |||
} | |||
}; | |||
} | |||
function scaleBoard(x, y) { | |||
var transform = 'scale(' + x + ',' + y + ')'; | |||
svg.setAttribute('transform', transform); | |||
board.style['width'] = svgWidth * x; | |||
board.style['height'] = svgHeight * y; | |||
} | |||
function autoScaleBoard() { | |||
var hRangeValue = Math.sqrt(Number(hRange.value) / 10); | |||
var vRangeValue = Math.sqrt(Number(vRange.value) / 10); | |||
scaleBoard(Number(hRangeValue), Number(vRangeValue)); | |||
} | |||
fileInput.onchange = loadSVG; | |||
var zoomBoard = dScale => { | |||
scale *= dScale; | |||
scaleBoard(scale, scale); | |||
}; | |||
window.addEventListener('wheel', e => { | |||
console.log(e); | |||
if (e.ctrlKey) { | |||
e.preventDefault(); | |||
e.stopPropagation(); | |||
var factor = 1; | |||
if (e.deltaY < 0) { | |||
factor = 1.1; | |||
} else if (e.deltaY > 0) { | |||
factor = 1 / 1.1; | |||
} | |||
zoomBoard(factor); | |||
var newPageX = e.pageX * factor; | |||
var newPageY = e.pageY * factor; | |||
x = newPageX - e.x; | |||
y = newPageY - e.y; | |||
window.scrollTo({ | |||
top: y, | |||
left: x, | |||
}); | |||
console.log('scroll', [x, y]); | |||
} | |||
}, { 'passive': false }); | |||
}; | |||
</script> | |||
<body> | |||
<p id="desc" style="position: fixed;bottom: 0; background-color: white;">desc</p> | |||
<p id="info" style="position: fixed;top: 0; right: 0; background-color: white;">info</p> | |||
<p id="board" | |||
style="white-space: nowrap; display: flex; justify-content: center; align-content: center; align-items: center; margin: 0;opacity: 0.7;"> | |||
</p> | |||
<input type='file' id='fileInput' style="position: fixed; top: 0; background-color: white;"></input> | |||
</body> | |||
</html> |
@@ -492,6 +492,38 @@ AsyncExecutable& ComputingGraphImpl::ComputingSequence::execute() { | |||
return *this; | |||
} | |||
void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info( | |||
const std::string& svg_name) { | |||
check_not_finalized(); | |||
auto& recorder = StaticMemRecorder::Instance(); | |||
recorder.active(); | |||
ExecContext exec_ctx{this}; | |||
// regist weights | |||
size_t addr_base = recorder.peak_mem_size(); | |||
size_t chunk_id = recorder.set_weight_chunk_id(); | |||
for (auto&& i : *(this->m_opr_seq)) { | |||
auto op = i->output(); | |||
for (auto&& j : op) { | |||
auto& mp = j->mem_plan(); | |||
if (mp.valid()) { | |||
auto& mc = mp.chunk(); | |||
if (mp.valid() && mc.mem_alloc_status.is_from_owner_var()) { | |||
recorder.regist_memory_chunk( | |||
{chunk_id++, mc.size(), 0, this->m_opr_seq->size(), | |||
addr_base, addr_base + mc.size(), 0, false, | |||
mc.owner_var->name()}); | |||
addr_base += mc.size(); | |||
} | |||
} | |||
} | |||
} | |||
recorder.set_sum_mem_size(addr_base); | |||
mgb_assert(svg_name.length() > 4, "svg_name must be end with \".svg\"\n"); | |||
mgb_assert(svg_name.compare(svg_name.length() - 4, 4, ".svg") == 0, | |||
"svg_name must be end with \".svg\"\n"); | |||
recorder.show(svg_name); | |||
} | |||
AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() { | |||
do_wait(true); | |||
return *this; | |||
@@ -16,6 +16,7 @@ | |||
#include "megbrain/comp_node_env.h" | |||
#include "megbrain/plugin/var_sanity_check.h" | |||
#include "megbrain/utils/arith_helper.h" | |||
#include "megbrain/plugin/static_mem_record.h" | |||
namespace mgb { | |||
namespace cg { | |||
@@ -169,6 +170,9 @@ public: | |||
} | |||
std::unique_ptr<RecordedComputingSequence> as_recorded_seq(); | |||
void get_static_memory_alloc_info( | |||
const std::string& svg_name = "static_mem_record.svg") override; | |||
}; | |||
class ComputingGraphImpl::MegDNNDtorCheck : public NonCopyableObj { | |||
@@ -178,9 +178,18 @@ bool SeqMemOptimizer::run_static_mem_alloc() { | |||
ThinHashMap<MemAllocPlan::Chunk*, MemChunkLifeInterval> chk2interval; | |||
// get all memory chunks | |||
if (StaticMemRecorder::Instance().valid()) { | |||
StaticMemRecorder::Instance().clear_opr_seq(); | |||
} | |||
for (size_t idx = 0; idx < m_cur_seq_full->size(); ++ idx) { | |||
OperatorNodeBase *opr = m_cur_seq_full->at(idx); | |||
if (StaticMemRecorder::Instance().valid()) { | |||
StaticMemRecorder::Instance().regist_opr_seq( | |||
{idx, 0, opr->name()}); | |||
} | |||
auto &&dep_map = opr->node_prop().dep_map(); | |||
if (in_sys_alloc(opr)) { | |||
@@ -349,6 +358,14 @@ bool SeqMemOptimizer::run_static_mem_alloc_on_comp_node( | |||
chk.chunk->mem_alloc_status.set_static_offset( | |||
allocator->get_start_addr(&chk)); | |||
} | |||
auto& recorder = StaticMemRecorder::Instance(); | |||
if (recorder.valid()) { | |||
for (size_t i = 0; i < chunks.size(); i++) { | |||
recorder.regist_memory_chunk_owner_var_name( | |||
i, chunks.at(i).chunk->owner_var->name()); | |||
} | |||
recorder.regist_peak_mem_size(size); | |||
} | |||
} | |||
return should_realloc; | |||
@@ -11,6 +11,7 @@ | |||
#pragma once | |||
#include "megbrain/plugin/static_mem_record.h" | |||
#include "megbrain_build_config.h" | |||
#include <cstddef> | |||
@@ -120,6 +120,22 @@ StaticMemAlloc& StaticMemAllocImplHelper::solve() { | |||
check_result_and_calc_lower_bound(); | |||
if (StaticMemRecorder::Instance().valid()) { | |||
StaticMemRecorder::Instance().clear_memory_chunk(); | |||
for (auto&& i : m_interval) { | |||
size_t overwrite_dest_id = 0; | |||
bool is_overwrite = !i->is_overwrite_root(); | |||
if (is_overwrite) { | |||
overwrite_dest_id = i->overwrite_dest_root()->id; | |||
} | |||
StaticMemRecorder::Instance().regist_memory_chunk( | |||
{i->id, i->size_orig, i->time_begin, i->time_end, | |||
i->addr_begin, i->addr_end(), overwrite_dest_id, | |||
is_overwrite, ""}); | |||
} | |||
} | |||
return *this; | |||
} | |||
@@ -190,6 +190,11 @@ class AsyncExecutable : public json::Serializable, | |||
m_user_data.get_user_data<OutputVarsUserData>(); | |||
return (*(output_vars_pair.first))->get_output_vars(); | |||
} | |||
virtual void get_static_memory_alloc_info(const std::string& svg_name) { | |||
mgb_assert(svg_name.length() < 0, | |||
"can't call this function directly\n"); | |||
} | |||
}; | |||
@@ -0,0 +1,319 @@ | |||
/** | |||
* \file src/plugin/impl/static_mem_record.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "megbrain/plugin/static_mem_record.h" | |||
#include <fstream> | |||
#include <iostream> | |||
using namespace mgb; | |||
using namespace cg; | |||
namespace { | |||
#define SVG_WIDTH 20000.0 | |||
#define SVG_HEIGHT 15000.0 | |||
#define OPR_RECT_WIDTH 40.0 | |||
#define OPR_RECT_HEIGHT 20.0 | |||
const std::string rect = | |||
"<rect x=\"{}\" y=\"{}\" width=\"{}\" height=\"{}\" fill=\"{}\" " | |||
" {}></rect>"; | |||
const std::string text = "<text x=\"{}\" y=\"{}\" font-size=\"{}\">{}</text>"; | |||
const std::string polyline = | |||
"<polyline points=\"{}\" style=\"fill:none;stroke:{};stroke-width:{}\" " | |||
"/>"; | |||
const std::string opr_info = | |||
"mge:type=\"opr\" mge:id=\"{}\" mge:size=\"{}\" mge:name=\"{}\""; | |||
const std::string chunk_info = | |||
"mge:type=\"chunk\" mge:id=\"{}\" mge:time=\"{}\" mge:addr=\"{}\" " | |||
"mge:size=\"{}\" mge:owner_var_name=\"{}\""; | |||
const std::string animate = | |||
"<animate attributeName=\"opacity\" from=\"0\" to=\"1\" " | |||
"begin=\"{}.mouseover\" fill=\"freeze\" dur=\"1s\"/>\n<animate " | |||
"attributeName=\"opacity\" from=\"1\" to=\"0\" begin=\"{}.mouseout\" " | |||
"fill=\"freeze\" dur=\"1s\"/>"; | |||
std::string& replace_by_parameter(std::string& original_str, size_t index) { | |||
return original_str; | |||
} | |||
template <typename... Args> | |||
std::string& replace_by_parameter(std::string& original_str, size_t index, | |||
const std::string& parameter, | |||
const Args&... args) { | |||
index = original_str.find("{}", index); | |||
original_str.replace(index, 2, parameter); | |||
index += parameter.length(); | |||
replace_by_parameter(original_str, index, args...); | |||
return original_str; | |||
} | |||
std::string set_opr_info(std::string id, std::string size, std::string name, | |||
std::string info = opr_info) { | |||
return replace_by_parameter(info, 0, id, size, name); | |||
} | |||
std::string set_chunk_info(std::string id, std::string time, std::string addr, | |||
std::string size, std::string owner_var_name, | |||
std::string info = chunk_info) { | |||
return replace_by_parameter(info, 0, id, time, addr, size, owner_var_name); | |||
} | |||
std::string draw_rect(std::string x, std::string y, std::string widith, | |||
std::string height, std::string color, std::string info, | |||
std::string r = rect) { | |||
return replace_by_parameter(r, 0, x, y, widith, height, color, info); | |||
} | |||
std::string draw_text(std::string x, std::string y, std::string font_size, | |||
std::string txt, std::string t = text) { | |||
return replace_by_parameter(t, 0, x, y, font_size, txt); | |||
} | |||
std::string draw_polyline(std::string point_seq, std::string color, | |||
std::string width, std::string p = polyline) { | |||
return replace_by_parameter(p, 0, point_seq, color, width); | |||
} | |||
} // namespace | |||
void StaticMemRecorder::dump_svg(std::string svg_name) { | |||
float svg_width = SVG_WIDTH, svg_height = SVG_HEIGHT, | |||
opr_rect_width = OPR_RECT_WIDTH, opr_rect_height = OPR_RECT_HEIGHT; | |||
float address_scale = 1; | |||
size_t opr_nr = m_opr_seq_recorder.size(); | |||
if (opr_nr * OPR_RECT_WIDTH > SVG_WIDTH) { | |||
svg_width = SVG_WIDTH; | |||
opr_rect_width = svg_width / opr_nr; | |||
opr_rect_height = opr_rect_width / 2; | |||
} else { | |||
opr_rect_width = OPR_RECT_WIDTH; | |||
svg_width = opr_nr * opr_rect_width; | |||
} | |||
if (m_sum_mem_size > SVG_HEIGHT) { | |||
svg_height = SVG_HEIGHT; | |||
address_scale = svg_height / m_sum_mem_size; | |||
} else { | |||
svg_height = m_sum_mem_size; | |||
} | |||
// Rescale | |||
float aspect_ratio = SVG_WIDTH / SVG_HEIGHT; | |||
if (svg_width / svg_height < 1) { | |||
svg_width = svg_height * aspect_ratio; | |||
opr_rect_width = svg_width / opr_nr; | |||
opr_rect_height = opr_rect_width / 2; | |||
} else if (svg_width / svg_height > aspect_ratio) { | |||
svg_height = svg_width / aspect_ratio; | |||
address_scale = svg_height / m_sum_mem_size; | |||
} | |||
svg_height = svg_height + opr_rect_height * 2; | |||
std::ofstream outfile; | |||
outfile.open(svg_name); | |||
outfile << "<?xml version=\"1.0\" standalone=\"no\"?>" << std::endl; | |||
outfile << "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN/\" " | |||
"\"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">" | |||
<< std::endl; | |||
outfile << "<svg width=\"" + std::to_string(svg_width) + "\" height=\"" + | |||
std::to_string(svg_height) + | |||
"\" version=\"1.1\" " | |||
"xmlns=\"http://www.w3.org/2000/svg\">" | |||
<< std::endl; | |||
float base_height = svg_height - opr_rect_height; | |||
std::string peak_mem_polyline = | |||
"0," + | |||
std::to_string(base_height - m_peak_mem_size * address_scale) + | |||
" " + std::to_string(m_opr_seq_recorder.size() * opr_rect_width) + | |||
"," + std::to_string(base_height - m_peak_mem_size * address_scale); | |||
std::string sum_mem_polyline = | |||
"0," + | |||
std::to_string(base_height - m_sum_mem_size * address_scale) + " " + | |||
std::to_string(m_opr_seq_recorder.size() * opr_rect_width) + "," + | |||
std::to_string(base_height - m_sum_mem_size * address_scale); | |||
std::string memory_polyline = ""; | |||
for (size_t i = 0; i < m_opr_seq_recorder.size(); i++) { | |||
auto&& opr = m_opr_seq_recorder.at(i); | |||
memory_polyline += | |||
std::to_string((i + 0.5) * opr_rect_width) + "," + | |||
std::to_string(base_height - opr.size * address_scale) + " "; | |||
outfile << draw_text(std::to_string(i * opr_rect_width), | |||
std::to_string(svg_height - opr_rect_height * 0.5), | |||
std::to_string(opr_rect_height * 0.5), | |||
"opr" + std::to_string(i)) | |||
<< std::endl; | |||
std::string opr_info = | |||
set_opr_info( | |||
std::to_string(opr.id), | |||
std::to_string(opr.size) + "B(" + | |||
std::to_string(opr.size / 1024.0 / 1024.0) + | |||
"MiB)", | |||
opr.name) + | |||
" opacity=\"0\""; | |||
outfile << draw_rect(std::to_string(i * opr_rect_width), | |||
std::to_string(base_height), | |||
std::to_string(opr_rect_width), | |||
std::to_string(opr_rect_height), "white", opr_info) | |||
<< std::endl; | |||
} | |||
for (size_t i = 0; i < m_memory_chunk_recorder.size(); i++) { | |||
auto&& chunk = m_memory_chunk_recorder.at(i); | |||
std::string chunk_info = set_chunk_info( | |||
std::to_string(chunk.id), | |||
"[" + std::to_string(chunk.time_begin) + "," + | |||
std::to_string(chunk.time_end) + ")", | |||
"[" + std::to_string(chunk.addr_begin) + "," + | |||
std::to_string(chunk.addr_end) + ")", | |||
std::to_string(chunk.addr_end - chunk.addr_begin) + "B(" + | |||
std::to_string((chunk.addr_end - chunk.addr_begin) / | |||
1024.0 / 1024.0) + | |||
"MiB)", | |||
chunk.owner_var_name); | |||
outfile << draw_rect( | |||
std::to_string(chunk.time_begin * opr_rect_width), | |||
std::to_string(base_height - | |||
chunk.addr_end * address_scale), | |||
std::to_string((chunk.time_end - chunk.time_begin) * | |||
opr_rect_width), | |||
std::to_string((chunk.addr_end - chunk.addr_begin) * | |||
address_scale), | |||
"gray", chunk_info) | |||
<< std::endl; | |||
outfile << draw_text(std::to_string(chunk.time_begin * opr_rect_width), | |||
std::to_string(base_height - | |||
chunk.addr_end * address_scale + 9), | |||
std::to_string(9), | |||
"chunk" + std::to_string(chunk.id)) | |||
<< std::endl; | |||
} | |||
outfile << draw_text("0", | |||
std::to_string(base_height - | |||
m_peak_mem_size * address_scale + | |||
opr_rect_height * 0.5), | |||
std::to_string(opr_rect_height * 0.5), | |||
"peak_memory_size:" + std::to_string(m_peak_mem_size) + | |||
"B(" + | |||
std::to_string(m_peak_mem_size / 1024.0 / | |||
1024.0) + | |||
"MiB)") | |||
<< std::endl; | |||
outfile << draw_text("0", | |||
std::to_string(base_height - | |||
m_sum_mem_size * address_scale + | |||
opr_rect_height * 0.5), | |||
std::to_string(opr_rect_height * 0.5), | |||
"sum_memory_size:" + std::to_string(m_sum_mem_size) + | |||
"B(" + | |||
std::to_string(m_sum_mem_size / 1024.0 / | |||
1024.0) + | |||
"MiB)") | |||
<< std::endl; | |||
outfile << draw_polyline(memory_polyline, "blue", | |||
std::to_string(opr_rect_height * 0.1)) | |||
<< std::endl; | |||
outfile << draw_polyline(peak_mem_polyline, "green", | |||
std::to_string(opr_rect_height * 0.1)) | |||
<< std::endl; | |||
outfile << draw_polyline(sum_mem_polyline, "red", | |||
std::to_string(opr_rect_height * 0.1)) | |||
<< std::endl; | |||
outfile << "<text svg:info=\"The abscissa represents the opr sequence, the " | |||
"ordinate represents the logical address.\" " | |||
"svg:chunk_time=\"[opra,oprb) means the chunk is created when " | |||
"opra execute and is freed before oprb\" " | |||
"svg:chunk_oner_var_name=\"var that first creates this " | |||
"chunk\"></text>" | |||
<< std::endl; | |||
outfile << "</svg>" << std::endl; | |||
outfile.close(); | |||
} | |||
void StaticMemRecorder::show(std::string svg_name) { | |||
for (auto&& i : m_memory_chunk_recorder) { | |||
if (i.id >= m_weight_chunk_id) { | |||
break; | |||
} | |||
size_t begin = i.time_begin, end = i.time_end; | |||
if (i.is_overwrite) { | |||
begin++; | |||
} | |||
for (size_t j = begin; j < end; j++) { | |||
m_opr_seq_recorder.at(j).size += i.size_orig; | |||
} | |||
} | |||
// log peak memory size, where it is reached and which chunks constitute it. | |||
mgb_log("peak_mem_size = %zu\n", m_peak_mem_size); | |||
size_t max_size = 0; | |||
std::vector<size_t> opr_ids; | |||
for (auto&& i : m_opr_seq_recorder) { | |||
if (i.size == max_size) { | |||
opr_ids.push_back(i.id); | |||
} else if (i.size > max_size) { | |||
max_size = i.size; | |||
opr_ids.clear(); | |||
opr_ids.push_back(i.id); | |||
} | |||
} | |||
auto opr2chunk = get_chunk_construct(opr_ids); | |||
mgb_log("oprs reach the peak memory:\n"); | |||
for (auto&& i : opr_ids) { | |||
mgb_log("opr id = %zu\n", i); | |||
} | |||
mgb_log("More details:\n"); | |||
for (size_t i = 0; i < opr2chunk.size(); i++) { | |||
mgb_log("opr id = %zu\n", opr_ids.at(i)); | |||
if (i + 1 < opr2chunk.size() && | |||
opr2chunk.at(i) == opr2chunk.at(i + 1)) { | |||
continue; | |||
} | |||
for (size_t j = 0; j < opr2chunk.at(i).size(); j++) { | |||
auto&& chunk = m_memory_chunk_recorder.at(opr2chunk.at(i).at(j)); | |||
mgb_log("[memory_chunk_id=%zu, size=%zu B, " | |||
"[life_begin=%zu,life_end=%zu), owner_opr_name=%s]\n", | |||
chunk.id, chunk.size_orig, chunk.time_begin, chunk.time_end, | |||
m_opr_seq_recorder.at(chunk.time_begin).name.c_str()); | |||
} | |||
} | |||
dump_svg(svg_name); | |||
} | |||
std::vector<std::vector<size_t>> StaticMemRecorder::get_chunk_construct( | |||
std::vector<size_t> opr_ids) { | |||
std::vector<std::vector<size_t>> chunk_ids; | |||
chunk_ids.resize(opr_ids.size()); | |||
for (auto&& i : m_memory_chunk_recorder) { | |||
if (i.id >= m_weight_chunk_id) { | |||
break; | |||
} | |||
size_t begin = i.time_begin, end = i.time_end; | |||
if (i.is_overwrite) { | |||
begin = begin + 1; | |||
} | |||
if (opr_ids.front() >= end || opr_ids.back() < begin) { | |||
continue; | |||
} | |||
for (size_t k = 0; k < opr_ids.size(); k++) { | |||
if (opr_ids.at(k) >= end) { | |||
break; | |||
} else if (opr_ids.at(k) >= begin) { | |||
chunk_ids.at(k).push_back(i.id); | |||
} | |||
} | |||
} | |||
return chunk_ids; | |||
} |
@@ -0,0 +1,85 @@ | |||
/** | |||
* \file src/plugin/include/megbrain/plugin/static_mem_record.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megbrain/utils/metahelper.h" | |||
namespace mgb { | |||
namespace cg { | |||
class StaticMemRecorder : public NonCopyableObj { | |||
public: | |||
static StaticMemRecorder& Instance() { | |||
static StaticMemRecorder StaticMemRecorder; | |||
return StaticMemRecorder; | |||
} | |||
struct opr_record { | |||
size_t id, size; | |||
std::string name; | |||
}; | |||
struct memory_chunk_record { | |||
size_t id, size_orig, time_begin, time_end, addr_begin, | |||
addr_end, overwrite_dest_id; | |||
bool is_overwrite; | |||
std::string owner_var_name; | |||
}; | |||
void active() { m_is_record = true; } | |||
bool valid() { return m_is_record; } | |||
void clear_opr_seq() { m_opr_seq_recorder.clear(); } | |||
void regist_opr_seq(opr_record opr) { m_opr_seq_recorder.push_back(opr); } | |||
void clear_memory_chunk() { m_memory_chunk_recorder.clear(); } | |||
void regist_memory_chunk(memory_chunk_record mcr) { | |||
m_memory_chunk_recorder.push_back(mcr); | |||
} | |||
void regist_memory_chunk_owner_var_name(size_t id, std::string name) { | |||
m_memory_chunk_recorder.at(id).owner_var_name = name; | |||
} | |||
void regist_peak_mem_size(size_t size) { m_peak_mem_size = size; } | |||
const size_t& peak_mem_size() { return m_peak_mem_size; } | |||
void set_sum_mem_size(size_t size) { m_sum_mem_size = size; } | |||
const size_t& sum_mem_size() { return m_sum_mem_size; } | |||
const size_t& set_weight_chunk_id() { | |||
m_weight_chunk_id = m_memory_chunk_recorder.size(); | |||
return m_weight_chunk_id; | |||
} | |||
const size_t& weight_chunk_id() { return m_weight_chunk_id; } | |||
void dump_svg(std::string svg_name); | |||
void show(std::string svg_name); | |||
private: | |||
bool m_is_record = false; | |||
// All chunks after m_memory_chunk_recorder.at(m_weight_chunk_id) are | |||
// weights memory chunks | |||
size_t m_peak_mem_size, m_sum_mem_size, m_weight_chunk_id; | |||
std::vector<opr_record> m_opr_seq_recorder; | |||
std::vector<memory_chunk_record> m_memory_chunk_recorder; | |||
std::vector<std::vector<size_t>> get_chunk_construct( | |||
std::vector<size_t> opr_ids); | |||
}; | |||
} // namespace cg | |||
} // namespace mgb |