GitOrigin-RevId: 40ca9ea80e
release-1.1
@@ -9,13 +9,155 @@ | |||
import base64 | |||
import json | |||
import os | |||
from typing import List, Optional | |||
import re | |||
from typing import Iterable, List, Optional | |||
from ..core._imperative_rt import OperatorNodeConfig, ProfileEntry | |||
from ..core._imperative_rt import ProfilerImpl as _Profiler | |||
from ..core._imperative_rt.imperative import sync | |||
from ..core._imperative_rt.ops import CollectiveCommMode | |||
from ..core.ops.builtin import GetVarShape | |||
def _make_dict(**kwargs):
    """Return the keyword arguments as a dict without the ``None``-valued entries.

    Only ``None`` is filtered; other falsy values (``0``, ``""``, ``[]``) are
    kept. The relative order of the remaining keys is preserved.
    """
    return {key: value for key, value in kwargs.items() if value is not None}
def _print_opnode_config(config):
    """Turn an operator node config into a plain dict of its populated fields.

    Reads ``name``, ``dtype`` and ``comp_node_arr`` from *config* and hands
    them to ``_make_dict``, so fields whose value is ``None`` are omitted.
    """
    fields = {
        "name": config.name,
        "dtype": config.dtype,
        "comp_node_arr": config.comp_node_arr,
    }
    return _make_dict(**fields)
def _dump_chrome_timeline(entries: List[ProfileEntry], path: str):
    """Write *entries* as a list of begin/end trace events to a JSON file.

    For every entry one ``"B"``/``"E"`` event pair is emitted on the ``"host"``
    thread, followed by one pair per item of ``entry.device_list`` on a thread
    named after the device. Timestamps are multiplied by 1000 before being
    written (presumably ms -> us for the Chrome tracing viewer; confirm
    against the timer source).

    :param entries: profile entries to export.
    :param path: path prefix; output goes to ``<path>.chrome_timeline.json``.
    """
    process_id = os.getpid()
    events = []

    def emit(op_name, phase, timestamp, thread, op_args, category=None):
        # _make_dict drops None values, so leaving ``category`` unset yields
        # an event without a "cat" key (as for device events).
        events.append(
            _make_dict(
                name=op_name,
                ph=phase,
                ts=timestamp * 1000,
                pid=process_id,
                tid=thread,
                args=op_args,
                cat=category,
            )
        )

    for index, entry in enumerate(entries):
        op = entry.op
        op_name = type(op).__name__
        host_begin, host_end = entry.host
        devices = entry.device_list
        op_args = Profiler.fetch_attrs(op)
        op_args["__id__"] = "[{}]".format(index)

        # Host events carry the op type as their category.
        emit(op_name, "B", host_begin, "host", op_args, category=op_name)
        emit(op_name, "E", host_end, "host", op_args, category=op_name)

        for device, device_begin, device_end in devices:
            # Resolve both timestamps before emitting either event.
            begin_ts, end_ts = device_begin(), device_end()
            emit(op_name, "B", begin_ts, str(device), op_args)
            emit(op_name, "E", end_ts, str(device), op_args)

    with open("{}.chrome_timeline.json".format(path), "w") as f:
        json.dump(events, f, indent=2)
def _dump_compatible(entries: List[ProfileEntry], path: str):
    """Write *entries* to ``<path>.compatible.json``.

    The document has two top-level sections: ``graph_exec`` (a flat ``var``
    list plus one ``operator`` record per entry) and ``profiler`` (per-entry
    ``device`` times, ``host`` times and ``opr_footprint``). Entries are keyed
    by their position in *entries*.

    :param entries: profile entries to export.
    :param path: path prefix of the output file.
    """
    variables = []
    operators = {}
    device_times = {}
    host_times = {}
    footprints = {}

    def register(var) -> int:
        # Every occurrence of a variable gets its own slot; the returned id
        # is the slot index. Element 2 of the tuple is stored as comp node.
        slot = len(variables)
        variables.append({"comp_node": str(var[2])})
        return slot

    for index, entry in enumerate(entries):
        input_ids = [register(var) for var in entry.inputs]
        output_ids = [register(var) for var in entry.outputs]
        operators[index] = {
            "input": input_ids,
            "output": output_ids,
            "name": str(entry.op.ctype()),
            "type": "imperative",
            "id": entry.id,
        }

        per_device = {}
        for device, device_begin, device_end in entry.device_list:
            # "kern" is filled with the begin time, same as "start".
            per_device[str(device)] = {
                "start": device_begin(),
                "kern": device_begin(),
                "end": device_end(),
            }
        device_times[index] = per_device

        host_begin, host_end = entry.host
        host_times[index] = {
            "host": {"start": host_begin, "kern": host_begin, "end": host_end}
        }

        footprint = {
            "out_shapes": [out[1] for out in entry.outputs],
            "in_shapes": [inp[1] for inp in entry.inputs],
            "params": {},
        }
        # memory / computation are only written when positive.
        if entry.memory > 0:
            footprint["memory"] = entry.memory
        if entry.computation > 0:
            footprint["computation"] = entry.computation
        footprints[index] = footprint

    document = {
        "graph_exec": {"var": variables, "operator": operators},
        "profiler": {
            "device": device_times,
            "host": host_times,
            "opr_footprint": footprints,
        },
    }
    with open("{}.compatible.json".format(path), "w") as f:
        json.dump(document, f, indent=2)
def _dump_graphviz(entries: List[ProfileEntry], path: str):
    """Render *entries* as a graphviz digraph saved to ``<path>.graphviz.dot``.

    Each entry becomes a rectangular operator node labelled with its type,
    attributes, decoded params and host/device durations. Each distinct
    variable id becomes one variable node; edges run input -> op -> output.

    :param entries: profile entries to export.
    :param path: path prefix of the output file.
    :raises ImportError: if the ``graphviz`` package is not installed.
    """
    # Imported lazily so the package is only needed for this output format.
    import graphviz

    graph = graphviz.Digraph()
    graph.graph_attr["ordering"] = "out"
    var_cache = {}

    def cache_var(var_id, var_shape):
        # Create the variable node on first sight, then reuse its name.
        if var_id not in var_cache:
            var_name = "var({})".format(var_id)
            # Use the parameter: previously this read ``shape`` from the
            # enclosing loop variable and only worked by coincidence.
            var_label = "{}\nshape:{}\n".format(var_name, var_shape)
            graph.node(var_name, var_label)
            var_cache[var_id] = var_name
        return var_cache[var_id]

    for op_id, entry in enumerate(entries):
        op = entry.op
        op_name = "op({})".format(op_id)
        op_type = type(op).__name__
        op_attrs = Profiler.fetch_attrs(op)
        # "param" is reported from entry.param below, not as a raw attribute.
        op_attrs.pop("param", None)

        label_lines = ["{}:{}".format(op_name, op_type)]
        for k, v in op_attrs.items():
            label_lines.append("attr[{}]: {}".format(k, v))

        op_param_str = entry.param
        if op_param_str:
            # entry.param is a JSON string; it is empty when no param was recorded.
            for k, v in json.loads(op_param_str).items():
                label_lines.append("param[{}]:{}".format(k, v))

        host_begin, host_end = entry.host
        label_lines.append("time[host]: {:f}ms".format(host_end - host_begin))
        for device, device_begin, device_end in entry.device_list:
            device_time = device_end() - device_begin()
            label_lines.append("time[{}]: {:f}ms".format(device, device_time))

        graph.node(op_name, "\n".join(label_lines), shape="rectangle")
        for var_id, var_shape, _device in entry.inputs:
            graph.edge(cache_var(var_id, var_shape), op_name)
        for var_id, var_shape, _device in entry.outputs:
            graph.edge(op_name, cache_var(var_id, var_shape))

    graph.save("{}.graphviz.dot".format(path))
class Profiler: | |||
@@ -23,7 +165,7 @@ class Profiler: | |||
Profile graph execution in imperative mode. | |||
:type path: Optional[str] | |||
:param path: default path for profiler to dump. | |||
:param path: default path prefix for profiler to dump. | |||
Examples: | |||
@@ -31,59 +173,67 @@ class Profiler: | |||
import megengine as mge | |||
import megengine.module as M | |||
import megengine.utils.profiler.Profiler | |||
from megengine.utils.profiler import Profiler | |||
# With Learnable Parameters | |||
for iter in range(0, 10): | |||
# Only profile record of last iter would be saved | |||
with Profiler("profile.json"): | |||
with Profiler("profile"): | |||
# your code here | |||
# Then open the profile file in chrome timeline window | |||
""" | |||
# see https://github.com/catapult-project/catapult/blob/master/tracing/tracing/base/color_scheme.html | |||
GOOD = "good" | |||
BAD = "bad" | |||
TERRIBLE = "terrible" | |||
CHROME_TIMELINE = "chrome_timeline" | |||
COMPATIBLE = "compatible" | |||
GRAPHVIZ = "graphviz" | |||
WITH_FOOTPRINT = 1 | |||
BLACK = "black" | |||
GREY = "grey" | |||
WHITE = "white" | |||
YELLOW = "yellow" | |||
OLIVE = "olive" | |||
_type_map = { | |||
OperatorNodeConfig: lambda x: _print_opnode_config(x), | |||
bytes: lambda x: base64.encodebytes(x).decode("ascii"), | |||
CollectiveCommMode: lambda x: str(x), | |||
} | |||
def __init__(self, path: str = "profile.json"): | |||
_dumper_map = { | |||
CHROME_TIMELINE: _dump_chrome_timeline, | |||
COMPATIBLE: _dump_compatible, | |||
GRAPHVIZ: _dump_graphviz, | |||
} | |||
def __init__( | |||
self, | |||
path: str = "profile", | |||
*, | |||
formats: Iterable[str] = (CHROME_TIMELINE,), | |||
type_filter: str = ".*", | |||
exit_dump: bool = True | |||
) -> None: | |||
self._impl = _Profiler() | |||
self._path = path | |||
self._color_map = {} | |||
self._type_map = { | |||
OperatorNodeConfig: lambda x: self.print_opnode_config(x), | |||
bytes: lambda x: base64.encodebytes(x).decode("ascii"), | |||
CollectiveCommMode: lambda x: str(x), | |||
} | |||
if isinstance(formats, str): | |||
formats = (formats,) | |||
self._filter = type_filter | |||
self._dumpers = [Profiler._dumper_map[fmt] for fmt in formats] | |||
self._exit_dump = exit_dump | |||
def __enter__(self): | |||
sync() | |||
self._impl.start() | |||
self._impl.start(Profiler.WITH_FOOTPRINT) | |||
return self | |||
def __exit__(self, val, type, trace): | |||
def __exit__(self, val, tp, trace): | |||
if self._exit_dump: | |||
self.dump() | |||
sync() | |||
self._impl.stop() | |||
if self._path is not None: | |||
self.dump() | |||
def recolor(self, target: str, color: str): | |||
self._color_map[target] = color | |||
return self | |||
self._impl.clear() | |||
def print_opnode_config(self, config): | |||
return self.make_dict( | |||
name=config.name, dtype=config.dtype, comp_node_arr=config.comp_node_arr, | |||
) | |||
def fetch_attrs(self, op): | |||
@classmethod | |||
def fetch_attrs(cls, op): | |||
attrs = dir(op) | |||
results = {} | |||
for attr in attrs: | |||
@@ -93,61 +243,29 @@ class Profiler: | |||
if callable(value): | |||
continue | |||
value_type = type(value) | |||
if value_type in self._type_map: | |||
value = self._type_map[value_type](value) | |||
if value_type in cls._type_map: | |||
value = cls._type_map[value_type](value) | |||
results[attr] = value | |||
return results | |||
def make_dict(self, **kwargs): | |||
unused_keys = [] | |||
for k, v in kwargs.items(): | |||
if v is None: | |||
unused_keys.append(k) | |||
for k in unused_keys: | |||
del kwargs[k] | |||
return kwargs | |||
def dump(self, path: Optional[str] = None): | |||
pid = os.getpid() | |||
sync() | |||
raw = [ | |||
entry | |||
for entry in self._impl.dump() | |||
if re.match(self._filter, type(entry.op).__name__) | |||
] | |||
if path is None: | |||
path = self._path | |||
trace_events = [] | |||
def append_event(**kwargs): | |||
trace_events.append(self.make_dict(**kwargs)) | |||
entries: List[ProfileEntry] = self._impl.dump() | |||
for id, entry in enumerate(entries): | |||
op = entry.op | |||
name = type(op).__name__ | |||
host_begin, host_end = entry.host | |||
device_list = entry.device_list | |||
args = self.fetch_attrs(op) | |||
args["__id__"] = "[{}]".format(id) | |||
cname = self._color_map[name] if name in self._color_map else None | |||
cat = name | |||
for ts, ph in [(host_begin, "B"), (host_end, "E")]: | |||
append_event( | |||
name=name, | |||
ph=ph, | |||
ts=ts * 1000, | |||
pid=pid, | |||
tid="host", | |||
args=args, | |||
cname=cname, | |||
cat=cat, | |||
) | |||
for device, device_begin, device_end in device_list: | |||
for ts, ph in [(device_begin(), "B"), (device_end(), "E")]: | |||
append_event( | |||
name=name, | |||
ph=ph, | |||
ts=ts * 1000, | |||
pid=pid, | |||
tid=str(device), | |||
args=args, | |||
cname=cname, | |||
) | |||
with open(path, "w") as f: | |||
json.dump(trace_events, f, indent=2) | |||
for dumper in self._dumpers: | |||
dumper(raw, path) | |||
def __call__(self, func):
    """Use the profiler as a decorator.

    :param func: callable to profile.
    :return: a wrapper that runs *func* inside ``with self:`` and returns
        its result. The wrapper keeps *func*'s name, docstring and other
        metadata.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with self:
            return func(*args, **kwargs)

    return wrapper
profile = Profiler |
@@ -204,17 +204,27 @@ void init_utils(py::module m) { | |||
py::class_<ProfileEntry>(m, "ProfileEntry") | |||
.def_readwrite("op", &ProfileEntry::op) | |||
.def_readwrite("host", &ProfileEntry::host) | |||
.def_readwrite("device_list", &ProfileEntry::device_list); | |||
.def_readwrite("device_list", &ProfileEntry::device_list) | |||
.def_readwrite("inputs", &ProfileEntry::inputs) | |||
.def_readwrite("outputs", &ProfileEntry::outputs) | |||
.def_readwrite("id", &ProfileEntry::id) | |||
.def_readwrite("parent", &ProfileEntry::parent) | |||
.def_readwrite("memory", &ProfileEntry::memory) | |||
.def_readwrite("computation", &ProfileEntry::computation) | |||
.def_property_readonly("param", [](ProfileEntry& self)->std::string{ | |||
if(self.param){ | |||
return self.param->to_string(); | |||
} else { | |||
return {}; | |||
} | |||
}); | |||
py::class_<mgb::imperative::Profiler>(m, "ProfilerImpl") | |||
.def(py::init<>()) | |||
.def("start", | |||
[](mgb::imperative::Profiler& profiler) { profiler.start(); }) | |||
.def("stop", | |||
[](mgb::imperative::Profiler& profiler) { profiler.stop(); }) | |||
.def("dump", [](mgb::imperative::Profiler& profiler) { | |||
return profiler.get_profile(); | |||
}); | |||
.def("start", &mgb::imperative::Profiler::start) | |||
.def("stop", &mgb::imperative::Profiler::stop) | |||
.def("clear", &mgb::imperative::Profiler::clear) | |||
.def("dump", &mgb::imperative::Profiler::get_profile); | |||
using mgb::imperative::TensorSanityCheck; | |||
py::class_<TensorSanityCheck>(m, "TensorSanityCheckImpl") | |||
@@ -15,6 +15,7 @@ | |||
namespace mgb { | |||
namespace imperative { | |||
template <typename TFunction> | |||
class FunctionHooker; | |||
@@ -22,13 +23,18 @@ template <typename TRet, typename... TArgs> | |||
class FunctionHooker<TRet(TArgs...)> { | |||
public: | |||
using FunctionType = thin_function<TRet(TArgs&&...)>; | |||
//Type of hooks. Hook should accept a real function as argument | |||
//and invoke it on an appropriate time | |||
using HookType = thin_function<TRet(FunctionType, TArgs&&...)>; | |||
explicit FunctionHooker(FunctionType* fptr) : m_fptr{fptr} {} | |||
explicit FunctionHooker(FunctionType* fptr) : m_fptr{fptr} { | |||
m_backup = {nullptr, [](FunctionType*){}}; | |||
} | |||
public: | |||
FunctionHooker& apply_hook(HookType&& hook) { | |||
if (!m_backup) { | |||
FunctionType* backup = new FunctionType(*m_fptr); | |||
//Restore hooked function, would be invoked when destructed | |||
std::function<void(FunctionType*)> restorer = | |||
[fptr = m_fptr](FunctionType* bkp) -> void { | |||
*fptr = *bkp; | |||
@@ -36,9 +42,11 @@ public: | |||
}; | |||
m_backup = decltype(m_backup)(backup, restorer); | |||
} | |||
//Replace with hooked version | |||
*m_fptr = [func = *m_fptr, hook](TArgs&&... args) -> TRet { | |||
return hook(func, std::forward<TArgs>(args)...); | |||
}; | |||
//Convenient for chained calls | |||
return *this; | |||
} | |||
@@ -47,9 +55,15 @@ private: | |||
std::unique_ptr<FunctionType, std::function<void(FunctionType*)>> m_backup; | |||
}; | |||
//Helps to deduce template args | |||
template <typename TRet, typename... TArgs> | |||
FunctionHooker(thin_function<TRet(TArgs...)>* f) | |||
->FunctionHooker<TRet(TArgs...)>; | |||
} // namespace imperative | |||
template<typename TSignature> | |||
auto make_shared_hook(thin_function<TSignature>* fptr){ | |||
return std::make_shared<FunctionHooker<TSignature>>(fptr); | |||
} | |||
} // namespace imperative | |||
} // namespace mgb |
@@ -11,19 +11,20 @@ | |||
#include "megbrain/imperative/profiler.h" | |||
#include <variant> | |||
#include "./function_hook.h" | |||
#include "megbrain/imperative/ops/opr_attr.h" | |||
#include "megbrain/imperative/physical_tensor.h" | |||
#include "megbrain/plugin/opr_footprint.h" | |||
#include "./event_pool.h" | |||
#include "./op_trait.h" | |||
namespace mgb { | |||
namespace imperative { | |||
namespace { | |||
CompNode::UnorderedSet collect_comp_nodes( | |||
const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||
CompNode::UnorderedSet comp_nodes; | |||
@@ -36,37 +37,101 @@ CompNode::UnorderedSet collect_comp_nodes( | |||
return comp_nodes; | |||
} | |||
// Allocates a shared event for `device` from EventPool::with_timer() and
// records it before handing it back.
DeviceTimer::SharedEvent alloc_recorded_event(CompNode device) {
    auto recorded = EventPool::with_timer().alloc_shared(device);
    recorded->record();
    return recorded;
}
OprFootprint footprint{}; | |||
} // namespace | |||
void DeviceTimer::reset(thin_function<double()> host_timer) { | |||
CompNode::foreach ([this, host_timer](CompNode device) { | |||
auto base_event = EventPool::with_timer().alloc_shared(device); | |||
base_event->record(); | |||
m_base_event_table[device] = {std::move(base_event), host_timer()}; | |||
m_base_event_table[device] = {alloc_recorded_event(device), host_timer()}; | |||
}); | |||
m_host_timer = host_timer; | |||
} | |||
thin_function<double()> DeviceTimer::get_device_time(CompNode device) { | |||
auto event = EventPool::with_timer().alloc_shared(device); | |||
event->record(); | |||
if(m_base_event_table.count(device) == 0) { | |||
m_base_event_table[device] = {alloc_recorded_event(device), m_host_timer()}; | |||
} | |||
auto base = m_base_event_table[device]; | |||
return [base, event] { | |||
auto [base_event, host_time] = base; | |||
//TODO: sync once for each compnode | |||
// TODO: sync once for each compnode | |||
event->host_wait(); | |||
return base_event->elapsed_time_until(*event) * 1000 + host_time; | |||
}; | |||
} | |||
void Profiler::start() { | |||
void DeviceTimer::clear() { | |||
m_base_event_table.clear(); | |||
} | |||
size_t TensorRecorder::record_tensor(const TensorPtr& tensor) { | |||
if (m_tensor_map.count(tensor.get()) > 0) { | |||
auto& [prev, id] = m_tensor_map[tensor.get()]; | |||
if (prev.lock() != tensor) { | |||
prev = tensor; | |||
id = m_next_id++; | |||
} | |||
return id; | |||
} else { | |||
auto id = m_next_id++; | |||
m_tensor_map.insert({tensor.get(), {std::weak_ptr{tensor}, id}}); | |||
return id; | |||
} | |||
} | |||
// Drops every recorded tensor and restarts id assignment from 0.
void TensorRecorder::clear() {
    m_tensor_map.clear();
    m_next_id = 0;
}
// Returns the collected entries after resolving their device timestamps.
Profile& Profiler::get_profile() {
    for (auto& entry : m_profile) {
        // wait_device() evaluates each device begin/end closure once and
        // replaces it with a closure returning that cached value.
        entry.wait_device();
    }
    return m_profile;
}
void Profiler::start(uint32_t flags) { | |||
m_host_timer.reset(); | |||
m_device_timer.reset([&]{ return m_host_timer.get_msecs();} ); | |||
OpTrait::for_each_trait([this](OpTrait& trait) { | |||
FunctionHooker hooker{&trait.apply_on_physical_tensor}; | |||
hooker.apply_hook([this](auto&& apply, const OpDef& def, | |||
const SmallVector<TensorPtr>& inputs) { | |||
m_device_timer.reset([&] { return m_host_timer.get_msecs(); }); | |||
OpTrait::for_each_trait([this, flags](OpTrait& trait) { | |||
auto hook_apply_on_physical_tensor = | |||
make_shared_hook(&trait.apply_on_physical_tensor); | |||
auto hook_apply_on_var_node = | |||
make_shared_hook(&trait.apply_on_var_node); | |||
hook_apply_on_physical_tensor->apply_hook([this, flags] | |||
(auto&& apply, const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||
auto shape2vector = [](const TensorShape& shape) { | |||
std::vector<size_t> vector_shape; | |||
for (size_t i = 0; i < shape.ndim; i++) { | |||
vector_shape.push_back(shape[i]); | |||
} | |||
return vector_shape; | |||
}; | |||
ProfileEntry entry; | |||
entry.id = m_entry_count++; | |||
// TODO: assign parent | |||
entry.parent = 0; | |||
// Record apply context and save to m_profile | |||
entry.op = def.copy(); | |||
for (auto&& input : inputs) { | |||
entry.inputs.push_back({m_tensor_recorder.record_tensor(input), | |||
shape2vector(input->layout()), | |||
input->comp_node()}); | |||
} | |||
double host_begin = m_host_timer.get_msecs(); | |||
auto&& comp_nodes = collect_comp_nodes(def, inputs); | |||
for (auto&& comp_node : comp_nodes) { | |||
@@ -75,6 +140,11 @@ void Profiler::start() { | |||
m_device_timer.get_device_time(comp_node), | |||
{}}); | |||
} | |||
if (flags & PROFILE_FOOTPRINT) { | |||
MGB_LOCK_GUARD(m_lock); | |||
m_entry_stack.push({&def, &entry, std::this_thread::get_id()}); | |||
} | |||
// Do real apply | |||
auto outputs = apply(def, inputs); | |||
for (auto& [cn, dev_begin, dev_end] : entry.device_list) { | |||
MGB_MARK_USED_VAR(cn); | |||
@@ -82,20 +152,71 @@ void Profiler::start() { | |||
dev_end = m_device_timer.get_device_time(cn); | |||
} | |||
entry.host = {host_begin, m_host_timer.get_msecs()}; | |||
m_profile->push_back(std::move(entry)); | |||
for (auto&& output : outputs) { | |||
entry.outputs.push_back( | |||
{m_tensor_recorder.record_tensor(output), | |||
shape2vector(output->layout()), output->comp_node()}); | |||
} | |||
if (flags & PROFILE_FOOTPRINT) { | |||
mgb_assert(std::get<1>(m_entry_stack.top()) == &entry); | |||
MGB_LOCK_GUARD(m_lock); | |||
m_entry_stack.pop(); | |||
} | |||
m_profile.push_back(std::move(entry)); | |||
return outputs; | |||
}); | |||
m_hooker_list.push_back(std::move(hooker)); | |||
if (flags & PROFILE_FOOTPRINT) { | |||
hook_apply_on_var_node->apply_hook( | |||
[this](auto&& apply, const OpDef& def, | |||
VarNodeArray inputs) -> cg::OperatorNodeBase* { | |||
auto* operator_node = apply(def, std::move(inputs)); | |||
std::remove_reference_t<decltype(m_entry_stack.top())> | |||
top; | |||
{ | |||
MGB_LOCK_GUARD(m_lock); | |||
if (m_entry_stack.empty()) { | |||
return operator_node; | |||
} | |||
top = m_entry_stack.top(); | |||
} | |||
auto [current_op, current_entry, thread_id] = top; | |||
if (current_op != &def || | |||
thread_id != std::this_thread::get_id()) { | |||
return operator_node; | |||
} | |||
auto&& footprint_result = | |||
footprint.calc_footprint(operator_node); | |||
current_entry->memory = footprint_result.memory; | |||
current_entry->computation = | |||
footprint_result.computation; | |||
#if MGB_ENABLE_JSON | |||
current_entry->param = footprint_result.param; | |||
#endif | |||
return operator_node; | |||
}); | |||
} | |||
m_hooker_list.push_back(std::move(hook_apply_on_physical_tensor)); | |||
m_hooker_list.push_back(std::move(hook_apply_on_var_node)); | |||
}); | |||
} | |||
void Profiler::stop() { | |||
m_hooker_list.clear(); | |||
for (auto& entry : *m_profile) { | |||
for (auto& entry : m_profile) { | |||
entry.wait_device(); | |||
} | |||
} | |||
// Resets all recorded state. Asserts that no entry is still on the stack and
// that all hooks have been released (stop() clears the hook list).
void Profiler::clear() {
    // Check preconditions before touching any state.
    mgb_assert(m_entry_stack.empty(),
               "entry_stack should be empty after profile");
    mgb_assert(m_hooker_list.empty(), "hooks should be released");

    // Entry ids restart from 0 for the next session.
    m_entry_count = 0;
    m_profile.clear();
    m_device_timer.clear();
    m_tensor_recorder.clear();
}
} // namespace imperative | |||
} // namespace mgb |
@@ -11,7 +11,10 @@ | |||
#pragma once | |||
#include <variant> | |||
#include <any> | |||
#include <optional> | |||
#include <stack> | |||
#include <list> | |||
#include "megbrain/comp_node.h" | |||
#include "megbrain/graph/event.h" | |||
@@ -19,27 +22,39 @@ | |||
#include "megbrain/utils/timer.h" | |||
#include "megbrain/imperative/op_def.h" | |||
#include "megbrain/imperative/function_hook.h" | |||
#include "megbrain/imperative/physical_tensor.h" | |||
namespace mgb { | |||
namespace imperative { | |||
struct ProfileEntry{ | |||
using ProfileTensor = std::tuple<size_t, std::vector<size_t>, CompNode>; | |||
struct ProfileEntry { | |||
using TimeClosure = std::function<double()>; | |||
size_t id; | |||
size_t parent; | |||
std::shared_ptr<OpDef> op; | |||
//(host_begin, host_end) | |||
std::tuple<double, double> host; | |||
//[(device, device_begin, device_end)] | |||
std::vector<std::tuple<CompNode, TimeClosure, TimeClosure>> device_list; | |||
void wait_device(){ | |||
for(auto& [cn, begin, end]: device_list){ | |||
std::vector<ProfileTensor> inputs; | |||
std::vector<ProfileTensor> outputs; | |||
ssize_t memory = 0; | |||
ssize_t computation = 0; | |||
#if MGB_ENABLE_JSON | |||
std::shared_ptr<json::Value> param; | |||
#endif | |||
void wait_device() { | |||
for (auto& [cn, begin, end] : device_list) { | |||
MGB_MARK_USED_VAR(cn); | |||
begin = [begin=begin()]{ return begin; }; | |||
end = [end = end()]{ return end; }; | |||
begin = [begin = begin()] { return begin; }; | |||
end = [end = end()] { return end; }; | |||
} | |||
} | |||
}; | |||
using Profile = std::vector<ProfileEntry>; | |||
using Profile = std::list<ProfileEntry>; | |||
class DeviceTimer { | |||
public: | |||
@@ -47,31 +62,54 @@ public: | |||
DeviceTimer() = default; | |||
void reset(thin_function<double()> host_timer); | |||
thin_function<double()> get_device_time(CompNode device); | |||
void clear(); | |||
private: | |||
CompNode::UnorderedMap<std::tuple<SharedEvent, double>> m_base_event_table; | |||
thin_function<double()> m_host_timer; | |||
}; | |||
class TensorRecorder { | |||
private: | |||
// active tensors | |||
std::unordered_map<Tensor*, std::tuple<std::weak_ptr<Tensor>, size_t>> | |||
m_tensor_map; | |||
size_t m_next_id; | |||
public: | |||
size_t record_tensor(const TensorPtr& tensor); | |||
void clear(); | |||
}; | |||
class Profiler { | |||
public: | |||
Profiler(Profile* profile = nullptr) { | |||
if (!profile) { | |||
m_owned_profile = std::make_unique<Profile>(); | |||
profile = m_owned_profile.get(); | |||
} | |||
m_profile = profile; | |||
} | |||
void start(); | |||
enum Flags { | |||
PROFILE_FOOTPRINT = 1, | |||
}; | |||
public: | |||
Profiler() = default; | |||
// Start profiler by hook OpTrait | |||
void start(uint32_t flags); | |||
// Stop profiler and clean environment | |||
void stop(); | |||
Profile& get_profile() { return *m_profile; } | |||
void clear(); | |||
Profile& get_profile(); | |||
private: | |||
DeviceTimer m_device_timer; | |||
RealTimer m_host_timer; | |||
Profile* m_profile; | |||
Profile m_profile; | |||
TensorRecorder m_tensor_recorder; | |||
std::stack<std::tuple<const OpDef*, ProfileEntry*, std::thread::id>> | |||
m_entry_stack; | |||
// Hold profile owned by this Profiler | |||
std::unique_ptr<Profile> m_owned_profile; | |||
std::vector<FunctionHooker<decltype(OpDef::apply_on_physical_tensor)>> | |||
m_hooker_list; | |||
// Hold hooks, cleared when stop | |||
std::vector<std::any> m_hooker_list; | |||
size_t m_entry_count = 0; | |||
Spinlock m_lock; | |||
std::unordered_map<Tensor*, std::weak_ptr<Tensor>> m_recorded_tensors; | |||
}; | |||
} // namespace imperative | |||