Browse Source

feat(mge/profiler): add compatible and graphviz mode

GitOrigin-RevId: 40ca9ea80e
release-1.1
Megvii Engine Team 4 years ago
parent
commit
2c9fa7f650
5 changed files with 435 additions and 134 deletions
  1. +205
    -87
      imperative/python/megengine/utils/profiler.py
  2. +18
    -8
      imperative/python/src/utils.cpp
  3. +16
    -2
      imperative/src/impl/function_hook.h
  4. +137
    -16
      imperative/src/impl/profiler.cpp
  5. +59
    -21
      imperative/src/include/megbrain/imperative/profiler.h

+ 205
- 87
imperative/python/megengine/utils/profiler.py View File

@@ -9,13 +9,155 @@
import base64 import base64
import json import json
import os import os
from typing import List, Optional
import re
from typing import Iterable, List, Optional


from ..core._imperative_rt import OperatorNodeConfig, ProfileEntry from ..core._imperative_rt import OperatorNodeConfig, ProfileEntry
from ..core._imperative_rt import ProfilerImpl as _Profiler from ..core._imperative_rt import ProfilerImpl as _Profiler
from ..core._imperative_rt.imperative import sync from ..core._imperative_rt.imperative import sync
from ..core._imperative_rt.ops import CollectiveCommMode from ..core._imperative_rt.ops import CollectiveCommMode
from ..core.ops.builtin import GetVarShape


def _make_dict(**kwargs):
unused_keys = []
for k, v in kwargs.items():
if v is None:
unused_keys.append(k)
for k in unused_keys:
del kwargs[k]
return kwargs


def _print_opnode_config(config):
    """Summarize an operator node config as a plain dict.

    Reads ``name``, ``dtype`` and ``comp_node_arr`` from ``config``; any of
    them that is ``None`` is dropped by ``_make_dict``.
    """
    fields = {
        "name": config.name,
        "dtype": config.dtype,
        "comp_node_arr": config.comp_node_arr,
    }
    return _make_dict(**fields)


def _dump_chrome_timeline(entries: List[ProfileEntry], path: str):
    """Write the profile as trace events to ``<path>.chrome_timeline.json``.

    Every entry produces a begin ("B") / end ("E") event pair on the
    ``"host"`` track, followed by one such pair per device listed in
    ``entry.device_list``. Timestamps are multiplied by 1000 before being
    written.

    :param entries: profile entries to serialize.
    :param path: path prefix of the output file.
    """
    process_id = os.getpid()
    events = []

    def record(name, ph, ts, tid, args, cat=None):
        # ``cat`` stays None for device events, so _make_dict leaves it out.
        events.append(
            _make_dict(
                name=name,
                ph=ph,
                ts=ts * 1000,
                pid=process_id,
                tid=tid,
                args=args,
                cat=cat,
            )
        )

    for index, entry in enumerate(entries):
        op = entry.op
        op_name = type(op).__name__
        host_begin, host_end = entry.host
        devices = entry.device_list
        op_args = Profiler.fetch_attrs(op)
        op_args["__id__"] = "[{}]".format(index)

        record(op_name, "B", host_begin, "host", op_args, cat=op_name)
        record(op_name, "E", host_end, "host", op_args, cat=op_name)

        for device, device_begin, device_end in devices:
            # Resolve both device timestamps before emitting either event.
            begin_ts = device_begin()
            end_ts = device_end()
            record(op_name, "B", begin_ts, str(device), op_args)
            record(op_name, "E", end_ts, str(device), op_args)

    out_name = "{}.chrome_timeline.json".format(path)
    with open(out_name, "w") as f:
        json.dump(events, f, indent=2)


def _dump_compatible(entries: List[ProfileEntry], path: str):
    """Write the profile to ``<path>.compatible.json``.

    The output object has two sections: ``graph_exec`` (a flat ``var`` list
    and an ``operator`` table) and ``profiler`` (``device``, ``host`` and
    ``opr_footprint`` tables), each table keyed by the operator's position in
    ``entries``.
    NOTE(review): presumably this mirrors the layout of an existing profiler
    dump so existing tooling can read it -- confirm which format is meant.

    :param entries: profile entries to serialize.
    :param path: path prefix of the output file.
    """
    obj = {
        "graph_exec": {"var": [], "operator": {}},
        "profiler": {"device": {}, "host": {}, "opr_footprint": {}},
    }
    # Shortcuts into the sections of ``obj`` that are filled in below.
    var_list = obj["graph_exec"]["var"]
    operator_dict = obj["graph_exec"]["operator"]
    device_dict = obj["profiler"]["device"]
    host_dict = obj["profiler"]["host"]
    opr_foot_print_dict = obj["profiler"]["opr_footprint"]

    def add_var(var) -> int:
        # Appends a new var record and returns its index in ``var_list``.
        # ``var`` is an (id, shape, comp_node) tuple; only the comp node is
        # stored. Every call creates a new record -- the tensor id in
        # ``var[0]`` is not used to deduplicate.
        var_id = len(var_list)
        var_list.append(
            {"comp_node": str(var[2]),}
        )
        return var_id

    for op_id, entry in enumerate(entries):
        # Inputs are registered before outputs, so an operator's input var
        # ids always precede its output var ids.
        operator_dict[op_id] = {
            "input": [add_var(var) for var in entry.inputs],
            "output": [add_var(var) for var in entry.outputs],
            "name": str(entry.op.ctype()),
            "type": "imperative",
            "id": entry.id,
        }
        op_device_dict = {}
        for device, device_begin, device_end in entry.device_list:
            # "kern" is given the same value as "start".
            op_device_dict[str(device)] = {
                "start": device_begin(),
                "kern": device_begin(),
                "end": device_end(),
            }
        device_dict[op_id] = op_device_dict
        host_begin, host_end = entry.host
        host_dict[op_id] = {
            "host": {"start": host_begin, "kern": host_begin, "end": host_end}
        }
        opr_footprint = {
            "out_shapes": [oup[1] for oup in entry.outputs],
            "in_shapes": [inp[1] for inp in entry.inputs],
            "params": {},
        }
        # memory / computation are only written when they are positive.
        if entry.memory > 0:
            opr_footprint["memory"] = entry.memory
        if entry.computation > 0:
            opr_footprint["computation"] = entry.computation
        opr_foot_print_dict[op_id] = opr_footprint
    with open("{}.compatible.json".format(path), "w") as f:
        json.dump(obj, f, indent=2)


def _dump_graphviz(entries: List[ProfileEntry], path: str):
    """Write the profile as a Graphviz digraph to ``<path>.graphviz.dot``.

    Each operator becomes a rectangular node labelled with its type, its
    attributes, its parameters (when ``entry.param`` is non-empty) and its
    host / device durations. Each tensor becomes a node labelled with its id
    and shape; edges run from input tensors to the operator and from the
    operator to its output tensors.

    :param entries: profile entries to render.
    :param path: path prefix of the output file.
    """
    # Local imports: graphviz is only required when this dumper runs.
    import graphviz
    import json

    graph = graphviz.Digraph()
    graph.graph_attr["ordering"] = "out"
    # tensor id -> name of the graph node already created for it
    var_cache = {}

    def cache_var(var_id, var_shape):
        """Return the node name for ``var_id``, creating the node on first use."""
        if var_id not in var_cache:
            var_name = "var({})".format(var_id)
            # Use the parameter, not a loop variable of the enclosing
            # function, so the helper does not depend on how it is called.
            var_label = "{}\nshape:{}\n".format(var_name, var_shape)
            graph.node(var_name, var_label)
            var_cache[var_id] = var_name
        return var_cache[var_id]

    for op_id, entry in enumerate(entries):
        op = entry.op
        op_name = "op({})".format(op_id)
        op_type = type(op).__name__
        op_attrs = Profiler.fetch_attrs(op)
        label_lines = []
        # A "param" attribute is dropped here; parameters are listed from
        # entry.param further down.
        if "param" in op_attrs:
            del op_attrs["param"]
        label_lines.append("{}:{}".format(op_name, op_type))
        for k, v in op_attrs.items():
            label_lines.append("attr[{}]: {}".format(k, v))
        op_param_str = entry.param
        if len(op_param_str) > 0:
            op_param = json.loads(op_param_str)
            for k, v in op_param.items():
                label_lines.append("param[{}]:{}".format(k, v))
        host_begin, host_end = entry.host
        label_lines.append("time[host]: {:f}ms".format(host_end - host_begin))
        for device, device_begin, device_end in entry.device_list:
            device_time = device_end() - device_begin()
            label_lines.append("time[{}]: {:f}ms".format(device, device_time))
        op_label = "\n".join(label_lines)
        graph.node(op_name, op_label, shape="rectangle")
        for var_id, shape, device in entry.inputs:
            graph.edge(cache_var(var_id, shape), op_name)
        for var_id, shape, device in entry.outputs:
            graph.edge(op_name, cache_var(var_id, shape))
    graph.save("{}.graphviz.dot".format(path))




class Profiler: class Profiler:
@@ -23,7 +165,7 @@ class Profiler:
Profile graph execution in imperative mode. Profile graph execution in imperative mode.


:type path: Optional[str] :type path: Optional[str]
:param path: default path for profiler to dump.
:param path: default path prefix for profiler to dump.


Examples: Examples:


@@ -31,59 +173,67 @@ class Profiler:


import megengine as mge import megengine as mge
import megengine.module as M import megengine.module as M
import megengine.utils.profiler.Profiler
from megengine.utils.profiler import Profiler


# With Learnable Parameters # With Learnable Parameters
for iter in range(0, 10): for iter in range(0, 10):
# Only profile record of last iter would be saved # Only profile record of last iter would be saved
with Profiler("profile.json"):
with Profiler("profile"):
# your code here # your code here
# Then open the profile file in chrome timeline window # Then open the profile file in chrome timeline window
""" """


# see https://github.com/catapult-project/catapult/blob/master/tracing/tracing/base/color_scheme.html
GOOD = "good"
BAD = "bad"
TERRIBLE = "terrible"
CHROME_TIMELINE = "chrome_timeline"
COMPATIBLE = "compatible"
GRAPHVIZ = "graphviz"

WITH_FOOTPRINT = 1


BLACK = "black"
GREY = "grey"
WHITE = "white"
YELLOW = "yellow"
OLIVE = "olive"
_type_map = {
OperatorNodeConfig: lambda x: _print_opnode_config(x),
bytes: lambda x: base64.encodebytes(x).decode("ascii"),
CollectiveCommMode: lambda x: str(x),
}


def __init__(self, path: str = "profile.json"):
_dumper_map = {
CHROME_TIMELINE: _dump_chrome_timeline,
COMPATIBLE: _dump_compatible,
GRAPHVIZ: _dump_graphviz,
}

def __init__(
self,
path: str = "profile",
*,
formats: Iterable[str] = (CHROME_TIMELINE,),
type_filter: str = ".*",
exit_dump: bool = True
) -> None:
self._impl = _Profiler() self._impl = _Profiler()
self._path = path self._path = path
self._color_map = {}
self._type_map = {
OperatorNodeConfig: lambda x: self.print_opnode_config(x),
bytes: lambda x: base64.encodebytes(x).decode("ascii"),
CollectiveCommMode: lambda x: str(x),
}

if isinstance(formats, str):
formats = (formats,)

self._filter = type_filter
self._dumpers = [Profiler._dumper_map[fmt] for fmt in formats]
self._exit_dump = exit_dump


def __enter__(self): def __enter__(self):
sync() sync()
self._impl.start()
self._impl.start(Profiler.WITH_FOOTPRINT)
return self return self


def __exit__(self, val, type, trace):
def __exit__(self, val, tp, trace):
if self._exit_dump:
self.dump()
sync() sync()
self._impl.stop() self._impl.stop()
if self._path is not None:
self.dump()

def recolor(self, target: str, color: str):
self._color_map[target] = color
return self
self._impl.clear()


def print_opnode_config(self, config):
return self.make_dict(
name=config.name, dtype=config.dtype, comp_node_arr=config.comp_node_arr,
)

def fetch_attrs(self, op):
@classmethod
def fetch_attrs(cls, op):
attrs = dir(op) attrs = dir(op)
results = {} results = {}
for attr in attrs: for attr in attrs:
@@ -93,61 +243,29 @@ class Profiler:
if callable(value): if callable(value):
continue continue
value_type = type(value) value_type = type(value)
if value_type in self._type_map:
value = self._type_map[value_type](value)
if value_type in cls._type_map:
value = cls._type_map[value_type](value)
results[attr] = value results[attr] = value
return results return results


def make_dict(self, **kwargs):
unused_keys = []
for k, v in kwargs.items():
if v is None:
unused_keys.append(k)
for k in unused_keys:
del kwargs[k]
return kwargs

def dump(self, path: Optional[str] = None): def dump(self, path: Optional[str] = None):
pid = os.getpid()
sync()
raw = [
entry
for entry in self._impl.dump()
if re.match(self._filter, type(entry.op).__name__)
]
if path is None: if path is None:
path = self._path path = self._path
trace_events = []

def append_event(**kwargs):
trace_events.append(self.make_dict(**kwargs))

entries: List[ProfileEntry] = self._impl.dump()

for id, entry in enumerate(entries):
op = entry.op
name = type(op).__name__
host_begin, host_end = entry.host
device_list = entry.device_list
args = self.fetch_attrs(op)
args["__id__"] = "[{}]".format(id)
cname = self._color_map[name] if name in self._color_map else None
cat = name
for ts, ph in [(host_begin, "B"), (host_end, "E")]:
append_event(
name=name,
ph=ph,
ts=ts * 1000,
pid=pid,
tid="host",
args=args,
cname=cname,
cat=cat,
)
for device, device_begin, device_end in device_list:
for ts, ph in [(device_begin(), "B"), (device_end(), "E")]:
append_event(
name=name,
ph=ph,
ts=ts * 1000,
pid=pid,
tid=str(device),
args=args,
cname=cname,
)
with open(path, "w") as f:
json.dump(trace_events, f, indent=2)
for dumper in self._dumpers:
dumper(raw, path)

def __call__(self, func):
    """Use the profiler as a decorator.

    Every call of the returned wrapper runs ``func`` inside ``with self``.

    :param func: callable to profile.
    :return: wrapper that forwards all arguments to ``func`` and returns its
        result.
    """
    import functools

    # functools.wraps keeps __name__, __doc__ and the other metadata of the
    # decorated function on the wrapper.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with self:
            return func(*args, **kwargs)

    return wrapper


profile = Profiler

+ 18
- 8
imperative/python/src/utils.cpp View File

@@ -204,17 +204,27 @@ void init_utils(py::module m) {
py::class_<ProfileEntry>(m, "ProfileEntry") py::class_<ProfileEntry>(m, "ProfileEntry")
.def_readwrite("op", &ProfileEntry::op) .def_readwrite("op", &ProfileEntry::op)
.def_readwrite("host", &ProfileEntry::host) .def_readwrite("host", &ProfileEntry::host)
.def_readwrite("device_list", &ProfileEntry::device_list);
.def_readwrite("device_list", &ProfileEntry::device_list)
.def_readwrite("inputs", &ProfileEntry::inputs)
.def_readwrite("outputs", &ProfileEntry::outputs)
.def_readwrite("id", &ProfileEntry::id)
.def_readwrite("parent", &ProfileEntry::parent)
.def_readwrite("memory", &ProfileEntry::memory)
.def_readwrite("computation", &ProfileEntry::computation)
.def_property_readonly("param", [](ProfileEntry& self)->std::string{
if(self.param){
return self.param->to_string();
} else {
return {};
}
});


py::class_<mgb::imperative::Profiler>(m, "ProfilerImpl") py::class_<mgb::imperative::Profiler>(m, "ProfilerImpl")
.def(py::init<>()) .def(py::init<>())
.def("start",
[](mgb::imperative::Profiler& profiler) { profiler.start(); })
.def("stop",
[](mgb::imperative::Profiler& profiler) { profiler.stop(); })
.def("dump", [](mgb::imperative::Profiler& profiler) {
return profiler.get_profile();
});
.def("start", &mgb::imperative::Profiler::start)
.def("stop", &mgb::imperative::Profiler::stop)
.def("clear", &mgb::imperative::Profiler::clear)
.def("dump", &mgb::imperative::Profiler::get_profile);


using mgb::imperative::TensorSanityCheck; using mgb::imperative::TensorSanityCheck;
py::class_<TensorSanityCheck>(m, "TensorSanityCheckImpl") py::class_<TensorSanityCheck>(m, "TensorSanityCheckImpl")


imperative/src/include/megbrain/imperative/function_hook.h → imperative/src/impl/function_hook.h View File

@@ -15,6 +15,7 @@


namespace mgb { namespace mgb {
namespace imperative { namespace imperative {

template <typename TFunction> template <typename TFunction>
class FunctionHooker; class FunctionHooker;


@@ -22,13 +23,18 @@ template <typename TRet, typename... TArgs>
class FunctionHooker<TRet(TArgs...)> { class FunctionHooker<TRet(TArgs...)> {
public: public:
using FunctionType = thin_function<TRet(TArgs&&...)>; using FunctionType = thin_function<TRet(TArgs&&...)>;
//Type of hooks. A hook should accept the real function as its first argument
//and invoke it at an appropriate time
using HookType = thin_function<TRet(FunctionType, TArgs&&...)>; using HookType = thin_function<TRet(FunctionType, TArgs&&...)>;
explicit FunctionHooker(FunctionType* fptr) : m_fptr{fptr} {}
explicit FunctionHooker(FunctionType* fptr) : m_fptr{fptr} {
m_backup = {nullptr, [](FunctionType*){}};
}


public: public:
FunctionHooker& apply_hook(HookType&& hook) { FunctionHooker& apply_hook(HookType&& hook) {
if (!m_backup) { if (!m_backup) {
FunctionType* backup = new FunctionType(*m_fptr); FunctionType* backup = new FunctionType(*m_fptr);
//Restore hooked function, would be invoked when destructed
std::function<void(FunctionType*)> restorer = std::function<void(FunctionType*)> restorer =
[fptr = m_fptr](FunctionType* bkp) -> void { [fptr = m_fptr](FunctionType* bkp) -> void {
*fptr = *bkp; *fptr = *bkp;
@@ -36,9 +42,11 @@ public:
}; };
m_backup = decltype(m_backup)(backup, restorer); m_backup = decltype(m_backup)(backup, restorer);
} }
//Replace with hooked version
*m_fptr = [func = *m_fptr, hook](TArgs&&... args) -> TRet { *m_fptr = [func = *m_fptr, hook](TArgs&&... args) -> TRet {
return hook(func, std::forward<TArgs>(args)...); return hook(func, std::forward<TArgs>(args)...);
}; };
//Convenient for chained calls
return *this; return *this;
} }


@@ -47,9 +55,15 @@ private:
std::unique_ptr<FunctionType, std::function<void(FunctionType*)>> m_backup; std::unique_ptr<FunctionType, std::function<void(FunctionType*)>> m_backup;
}; };


//Helps to deduce template args
template <typename TRet, typename... TArgs> template <typename TRet, typename... TArgs>
FunctionHooker(thin_function<TRet(TArgs...)>* f) FunctionHooker(thin_function<TRet(TArgs...)>* f)
->FunctionHooker<TRet(TArgs...)>; ->FunctionHooker<TRet(TArgs...)>;
} // namespace imperative


template<typename TSignature>
auto make_shared_hook(thin_function<TSignature>* fptr){
return std::make_shared<FunctionHooker<TSignature>>(fptr);
}

} // namespace imperative
} // namespace mgb } // namespace mgb

+ 137
- 16
imperative/src/impl/profiler.cpp View File

@@ -11,19 +11,20 @@


#include "megbrain/imperative/profiler.h" #include "megbrain/imperative/profiler.h"


#include <variant>

#include "./function_hook.h"
#include "megbrain/imperative/ops/opr_attr.h" #include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/physical_tensor.h" #include "megbrain/imperative/physical_tensor.h"


#include "megbrain/plugin/opr_footprint.h"

#include "./event_pool.h" #include "./event_pool.h"
#include "./op_trait.h" #include "./op_trait.h"


namespace mgb { namespace mgb {

namespace imperative { namespace imperative {


namespace { namespace {

CompNode::UnorderedSet collect_comp_nodes( CompNode::UnorderedSet collect_comp_nodes(
const OpDef& def, const SmallVector<TensorPtr>& inputs) { const OpDef& def, const SmallVector<TensorPtr>& inputs) {
CompNode::UnorderedSet comp_nodes; CompNode::UnorderedSet comp_nodes;
@@ -36,37 +37,101 @@ CompNode::UnorderedSet collect_comp_nodes(
return comp_nodes; return comp_nodes;
} }


// Take a shared event for `device` from the timer-enabled event pool and
// record it before returning it.
DeviceTimer::SharedEvent alloc_recorded_event(CompNode device) {
    auto&& timed_pool = EventPool::with_timer();
    auto recorded = timed_pool.alloc_shared(device);
    recorded->record();
    return recorded;
}

OprFootprint footprint{};

} // namespace } // namespace


void DeviceTimer::reset(thin_function<double()> host_timer) { void DeviceTimer::reset(thin_function<double()> host_timer) {
CompNode::foreach ([this, host_timer](CompNode device) { CompNode::foreach ([this, host_timer](CompNode device) {
auto base_event = EventPool::with_timer().alloc_shared(device);
base_event->record();
m_base_event_table[device] = {std::move(base_event), host_timer()};
m_base_event_table[device] = {alloc_recorded_event(device), host_timer()};
}); });
m_host_timer = host_timer;
} }


thin_function<double()> DeviceTimer::get_device_time(CompNode device) { thin_function<double()> DeviceTimer::get_device_time(CompNode device) {
auto event = EventPool::with_timer().alloc_shared(device); auto event = EventPool::with_timer().alloc_shared(device);
event->record(); event->record();
if(m_base_event_table.count(device) == 0) {
m_base_event_table[device] = {alloc_recorded_event(device), m_host_timer()};
}
auto base = m_base_event_table[device]; auto base = m_base_event_table[device];
return [base, event] { return [base, event] {
auto [base_event, host_time] = base; auto [base_event, host_time] = base;
//TODO: sync once for each compnode
// TODO: sync once for each compnode
event->host_wait(); event->host_wait();
return base_event->elapsed_time_until(*event) * 1000 + host_time; return base_event->elapsed_time_until(*event) * 1000 + host_time;
}; };
} }


void Profiler::start() {
// Drop every per-device base event (and the host timestamp stored with it)
// held in m_base_event_table.
void DeviceTimer::clear() {
    m_base_event_table.clear();
}

// Return the id recorded for `tensor`, assigning the next free id the first
// time a tensor is seen.
//
// Entries are keyed by the raw pointer. If an entry exists but its weak_ptr
// does not lock to `tensor` itself, the entry is rebound to `tensor` and
// given a new id (presumably the address was reused by another tensor --
// TODO confirm).
size_t TensorRecorder::record_tensor(const TensorPtr& tensor) {
    auto found = m_tensor_map.find(tensor.get());
    if (found == m_tensor_map.end()) {
        // First sighting of this address.
        auto fresh_id = m_next_id++;
        m_tensor_map.insert({tensor.get(), {std::weak_ptr{tensor}, fresh_id}});
        return fresh_id;
    }
    auto& [owner, known_id] = found->second;
    if (owner.lock() != tensor) {
        owner = tensor;
        known_id = m_next_id++;
    }
    return known_id;
}

// Forget all recorded tensors and restart id assignment from 0.
void TensorRecorder::clear() {
    m_next_id = 0;
    m_tensor_map.clear();
}

// Return the recorded profile with every device timestamp resolved.
//
// For each entry the device begin/end closures are evaluated once and
// replaced by closures that return the cached value, so reading the
// timestamps again does not re-evaluate the original closures.
Profile& Profiler::get_profile() {
    for (auto& entry : m_profile) {
        // Same resolve-and-cache step Profiler::stop() applies; reuse the
        // ProfileEntry helper instead of duplicating its loop here.
        entry.wait_device();
    }
    return m_profile;
}

void Profiler::start(uint32_t flags) {
m_host_timer.reset(); m_host_timer.reset();
m_device_timer.reset([&]{ return m_host_timer.get_msecs();} );
OpTrait::for_each_trait([this](OpTrait& trait) {
FunctionHooker hooker{&trait.apply_on_physical_tensor};
hooker.apply_hook([this](auto&& apply, const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
m_device_timer.reset([&] { return m_host_timer.get_msecs(); });
OpTrait::for_each_trait([this, flags](OpTrait& trait) {
auto hook_apply_on_physical_tensor =
make_shared_hook(&trait.apply_on_physical_tensor);
auto hook_apply_on_var_node =
make_shared_hook(&trait.apply_on_var_node);
hook_apply_on_physical_tensor->apply_hook([this, flags]
(auto&& apply, const OpDef& def, const SmallVector<TensorPtr>& inputs) {
auto shape2vector = [](const TensorShape& shape) {
std::vector<size_t> vector_shape;
for (size_t i = 0; i < shape.ndim; i++) {
vector_shape.push_back(shape[i]);
}
return vector_shape;
};
ProfileEntry entry; ProfileEntry entry;
entry.id = m_entry_count++;
// TODO: assign parent
entry.parent = 0;
// Record apply context and save to m_profile
entry.op = def.copy(); entry.op = def.copy();
for (auto&& input : inputs) {
entry.inputs.push_back({m_tensor_recorder.record_tensor(input),
shape2vector(input->layout()),
input->comp_node()});
}
double host_begin = m_host_timer.get_msecs(); double host_begin = m_host_timer.get_msecs();
auto&& comp_nodes = collect_comp_nodes(def, inputs); auto&& comp_nodes = collect_comp_nodes(def, inputs);
for (auto&& comp_node : comp_nodes) { for (auto&& comp_node : comp_nodes) {
@@ -75,6 +140,11 @@ void Profiler::start() {
m_device_timer.get_device_time(comp_node), m_device_timer.get_device_time(comp_node),
{}}); {}});
} }
if (flags & PROFILE_FOOTPRINT) {
MGB_LOCK_GUARD(m_lock);
m_entry_stack.push({&def, &entry, std::this_thread::get_id()});
}
// Do real apply
auto outputs = apply(def, inputs); auto outputs = apply(def, inputs);
for (auto& [cn, dev_begin, dev_end] : entry.device_list) { for (auto& [cn, dev_begin, dev_end] : entry.device_list) {
MGB_MARK_USED_VAR(cn); MGB_MARK_USED_VAR(cn);
@@ -82,20 +152,71 @@ void Profiler::start() {
dev_end = m_device_timer.get_device_time(cn); dev_end = m_device_timer.get_device_time(cn);
} }
entry.host = {host_begin, m_host_timer.get_msecs()}; entry.host = {host_begin, m_host_timer.get_msecs()};
m_profile->push_back(std::move(entry));
for (auto&& output : outputs) {
entry.outputs.push_back(
{m_tensor_recorder.record_tensor(output),
shape2vector(output->layout()), output->comp_node()});
}
if (flags & PROFILE_FOOTPRINT) {
mgb_assert(std::get<1>(m_entry_stack.top()) == &entry);
MGB_LOCK_GUARD(m_lock);
m_entry_stack.pop();
}
m_profile.push_back(std::move(entry));
return outputs; return outputs;
}); });
m_hooker_list.push_back(std::move(hooker));
if (flags & PROFILE_FOOTPRINT) {
hook_apply_on_var_node->apply_hook(
[this](auto&& apply, const OpDef& def,
VarNodeArray inputs) -> cg::OperatorNodeBase* {
auto* operator_node = apply(def, std::move(inputs));
std::remove_reference_t<decltype(m_entry_stack.top())>
top;
{
MGB_LOCK_GUARD(m_lock);
if (m_entry_stack.empty()) {
return operator_node;
}
top = m_entry_stack.top();
}
auto [current_op, current_entry, thread_id] = top;
if (current_op != &def ||
thread_id != std::this_thread::get_id()) {
return operator_node;
}
auto&& footprint_result =
footprint.calc_footprint(operator_node);
current_entry->memory = footprint_result.memory;
current_entry->computation =
footprint_result.computation;
#if MGB_ENABLE_JSON
current_entry->param = footprint_result.param;
#endif
return operator_node;
});
}
m_hooker_list.push_back(std::move(hook_apply_on_physical_tensor));
m_hooker_list.push_back(std::move(hook_apply_on_var_node));
}); });
} }


void Profiler::stop() { void Profiler::stop() {
m_hooker_list.clear(); m_hooker_list.clear();
for (auto& entry : *m_profile) {
for (auto& entry : m_profile) {
entry.wait_device(); entry.wait_device();
} }
} }


// Discard all recorded data: the profile entries, the entry counter, the
// device timer's base events and the tensor id table.
//
// Asserts that no entry is left on m_entry_stack and that the hooks have
// already been released (stop() clears m_hooker_list).
void Profiler::clear() {
    mgb_assert(m_entry_stack.empty(),
               "entry_stack should be empty after profile");
    mgb_assert(m_hooker_list.empty(), "hooks should be released");
    m_profile.clear();
    m_entry_count = 0;
    m_device_timer.clear();
    m_tensor_recorder.clear();
}

} // namespace imperative } // namespace imperative


} // namespace mgb } // namespace mgb

+ 59
- 21
imperative/src/include/megbrain/imperative/profiler.h View File

@@ -11,7 +11,10 @@


#pragma once #pragma once


#include <variant>
#include <any>
#include <optional>
#include <stack>
#include <list>


#include "megbrain/comp_node.h" #include "megbrain/comp_node.h"
#include "megbrain/graph/event.h" #include "megbrain/graph/event.h"
@@ -19,27 +22,39 @@
#include "megbrain/utils/timer.h" #include "megbrain/utils/timer.h"


#include "megbrain/imperative/op_def.h" #include "megbrain/imperative/op_def.h"

#include "megbrain/imperative/function_hook.h"
#include "megbrain/imperative/physical_tensor.h"


namespace mgb { namespace mgb {
namespace imperative { namespace imperative {


struct ProfileEntry{
using ProfileTensor = std::tuple<size_t, std::vector<size_t>, CompNode>;

struct ProfileEntry {
using TimeClosure = std::function<double()>; using TimeClosure = std::function<double()>;
size_t id;
size_t parent;
std::shared_ptr<OpDef> op; std::shared_ptr<OpDef> op;
//(host_begin, host_end)
std::tuple<double, double> host; std::tuple<double, double> host;
//[(device, device_begin, device_end)]
std::vector<std::tuple<CompNode, TimeClosure, TimeClosure>> device_list; std::vector<std::tuple<CompNode, TimeClosure, TimeClosure>> device_list;
void wait_device(){
for(auto& [cn, begin, end]: device_list){
std::vector<ProfileTensor> inputs;
std::vector<ProfileTensor> outputs;
ssize_t memory = 0;
ssize_t computation = 0;
#if MGB_ENABLE_JSON
std::shared_ptr<json::Value> param;
#endif
void wait_device() {
for (auto& [cn, begin, end] : device_list) {
MGB_MARK_USED_VAR(cn); MGB_MARK_USED_VAR(cn);
begin = [begin=begin()]{ return begin; };
end = [end = end()]{ return end; };
begin = [begin = begin()] { return begin; };
end = [end = end()] { return end; };
} }
} }
}; };


using Profile = std::vector<ProfileEntry>;
using Profile = std::list<ProfileEntry>;


class DeviceTimer { class DeviceTimer {
public: public:
@@ -47,31 +62,54 @@ public:
DeviceTimer() = default; DeviceTimer() = default;
void reset(thin_function<double()> host_timer); void reset(thin_function<double()> host_timer);
thin_function<double()> get_device_time(CompNode device); thin_function<double()> get_device_time(CompNode device);
void clear();


private: private:
CompNode::UnorderedMap<std::tuple<SharedEvent, double>> m_base_event_table; CompNode::UnorderedMap<std::tuple<SharedEvent, double>> m_base_event_table;
thin_function<double()> m_host_timer;
};

// Hands out numeric ids for the tensors seen while profiling.
class TensorRecorder {
private:
    // active tensors
    // Keyed by raw pointer; the weak_ptr lets record_tensor() check whether
    // the stored entry still refers to the tensor being recorded.
    std::unordered_map<Tensor*, std::tuple<std::weak_ptr<Tensor>, size_t>>
            m_tensor_map;
    // Next id to hand out. Initialized here so ids are well-defined even
    // when record_tensor() runs before the first clear().
    size_t m_next_id = 0;

public:
    // Return the id for `tensor`, assigning a new one when needed.
    size_t record_tensor(const TensorPtr& tensor);
    // Forget all tensors and restart ids from 0.
    void clear();
};


class Profiler { class Profiler {
public: public:
Profiler(Profile* profile = nullptr) {
if (!profile) {
m_owned_profile = std::make_unique<Profile>();
profile = m_owned_profile.get();
}
m_profile = profile;
}
void start();
enum Flags {
PROFILE_FOOTPRINT = 1,
};

public:
Profiler() = default;
// Start profiler by hook OpTrait
void start(uint32_t flags);
// Stop profiler and clean environment
void stop(); void stop();
Profile& get_profile() { return *m_profile; }
void clear();
Profile& get_profile();


private: private:
DeviceTimer m_device_timer; DeviceTimer m_device_timer;
RealTimer m_host_timer; RealTimer m_host_timer;
Profile* m_profile;
Profile m_profile;
TensorRecorder m_tensor_recorder;
std::stack<std::tuple<const OpDef*, ProfileEntry*, std::thread::id>>
m_entry_stack;
// Hold profile owned by this Profiler
std::unique_ptr<Profile> m_owned_profile; std::unique_ptr<Profile> m_owned_profile;
std::vector<FunctionHooker<decltype(OpDef::apply_on_physical_tensor)>>
m_hooker_list;
// Hold hooks, cleared when stop
std::vector<std::any> m_hooker_list;
size_t m_entry_count = 0;
Spinlock m_lock;
std::unordered_map<Tensor*, std::weak_ptr<Tensor>> m_recorded_tensors;
}; };


} // namespace imperative } // namespace imperative


Loading…
Cancel
Save