GitOrigin-RevId: ccc984acbd
tags/v1.3.0
@@ -3,6 +3,7 @@ from collections import defaultdict | |||||
from contextlib import contextmanager | from contextlib import contextmanager | ||||
from typing import Callable | from typing import Callable | ||||
from ..core._imperative_rt.core2 import pop_scope, push_scope | |||||
from ..core.autodiff.grad import Grad | from ..core.autodiff.grad import Grad | ||||
from ..logger import get_logger | from ..logger import get_logger | ||||
from ..tensor import Tensor | from ..tensor import Tensor | ||||
@@ -239,6 +240,7 @@ class GradManager: | |||||
:param y: tensor or list of tensors | :param y: tensor or list of tensors | ||||
:param dy: tensor or list of tensors. Defaults to 1 if y is scalar | :param dy: tensor or list of tensors. Defaults to 1 if y is scalar | ||||
""" | """ | ||||
push_scope("backward") | |||||
from ..functional import ones_like | from ..functional import ones_like | ||||
global backwarding_grad_manager | global backwarding_grad_manager | ||||
@@ -280,6 +282,7 @@ class GradManager: | |||||
finally: | finally: | ||||
self.release() | self.release() | ||||
backwarding_grad_manager = cache | backwarding_grad_manager = cache | ||||
pop_scope("backward") | |||||
def record(self): | def record(self): | ||||
r""" | r""" | ||||
@@ -8,5 +8,17 @@ | |||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import os | import os | ||||
import sys | import sys | ||||
from contextlib import contextmanager | |||||
from ._imperative_rt.core2 import get_option, set_option | |||||
from .tensor.megbrain_graph import Graph | from .tensor.megbrain_graph import Graph | ||||
@contextmanager
def option(key, value):
    r"""
    Temporarily override an interpreter option for the duration of a ``with`` block.

    :param key: option name, forwarded to ``get_option`` / ``set_option``.
    :param value: new value; coerced with ``int`` before it is applied.

    The value read before the override is written back on exit, including
    when the body raises. On normal exit the option is asserted to still
    hold ``value``.
    """
    value = int(value)
    old = get_option(key)
    set_option(key, value)
    try:
        yield
        # Sanity check: the body is not expected to change this option itself.
        assert get_option(key) == value
    finally:
        # Restore unconditionally so a failing body cannot leak the override.
        set_option(key, old)
@@ -12,6 +12,7 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union | |||||
import numpy as np | import numpy as np | ||||
from ..core._imperative_rt.core2 import pop_scope, push_scope | |||||
from ..core.tensor.utils import make_shape_tuple | from ..core.tensor.utils import make_shape_tuple | ||||
from ..logger import get_logger | from ..logger import get_logger | ||||
from ..tensor import Parameter, Tensor | from ..tensor import Parameter, Tensor | ||||
@@ -78,6 +79,7 @@ class Module(metaclass=ABCMeta): | |||||
self._forward_hooks = OrderedDict() | self._forward_hooks = OrderedDict() | ||||
self._modules = [] | self._modules = [] | ||||
self._name = "{anonymous}" | |||||
@abstractmethod | @abstractmethod | ||||
def forward(self, inputs): | def forward(self, inputs): | ||||
@@ -103,6 +105,7 @@ class Module(metaclass=ABCMeta): | |||||
return HookHandler(self._forward_hooks, hook) | return HookHandler(self._forward_hooks, hook) | ||||
def __call__(self, *inputs, **kwargs): | def __call__(self, *inputs, **kwargs): | ||||
push_scope(self._name) | |||||
for hook in self._forward_pre_hooks.values(): | for hook in self._forward_pre_hooks.values(): | ||||
modified_inputs = hook(self, inputs) | modified_inputs = hook(self, inputs) | ||||
if modified_inputs is not None: | if modified_inputs is not None: | ||||
@@ -116,6 +119,7 @@ class Module(metaclass=ABCMeta): | |||||
modified_outputs = hook(self, inputs, outputs) | modified_outputs = hook(self, inputs, outputs) | ||||
if modified_outputs is not None: | if modified_outputs is not None: | ||||
outputs = modified_outputs | outputs = modified_outputs | ||||
pop_scope(self._name) | |||||
return outputs | return outputs | ||||
def _flatten( | def _flatten( | ||||
@@ -571,6 +575,14 @@ class Module(metaclass=ABCMeta): | |||||
return set(loaded), set(skipped) | return set(loaded), set(skipped) | ||||
def __getattribute__(self, name: str):
    r"""
    Look up ``name`` normally and, when the result is a module (as decided
    by ``_is_module``), record ``name`` on it as its ``_name``.

    Reading ``_name`` itself returns the stored value without that check.
    """
    attr = super().__getattribute__(name)
    if name != "_name" and _is_module(attr):
        attr._name = name
    return attr
def __setattr__(self, name: str, value): | def __setattr__(self, name: str, value): | ||||
if _is_module(value): | if _is_module(value): | ||||
modules = self.__dict__.get("_modules") | modules = self.__dict__.get("_modules") | ||||
@@ -15,6 +15,7 @@ from typing import Union | |||||
import numpy as np | import numpy as np | ||||
from ..core._imperative_rt.core2 import pop_scope, push_scope | |||||
from ..core.tensor.utils import set_convert_inputs | from ..core.tensor.utils import set_convert_inputs | ||||
from ..tensor import Parameter, Tensor | from ..tensor import Parameter, Tensor | ||||
from ..utils.deprecation import deprecated | from ..utils.deprecation import deprecated | ||||
@@ -155,7 +156,9 @@ class Optimizer(metaclass=ABCMeta): | |||||
"but the ordering of parameters in sets will change between runs. " | "but the ordering of parameters in sets will change between runs. " | ||||
"Please use a list instead." | "Please use a list instead." | ||||
) | ) | ||||
push_scope("step") | |||||
self._updates(group) | self._updates(group) | ||||
pop_scope("step") | |||||
# restore the global state `_enable_convert_inputs` | # restore the global state `_enable_convert_inputs` | ||||
set_convert_inputs(backup) | set_convert_inputs(backup) | ||||
return self | return self | ||||
@@ -172,8 +175,10 @@ class Optimizer(metaclass=ABCMeta): | |||||
Set the grad attribute to None for all parameters. | Set the grad attribute to None for all parameters. | ||||
""" | """ | ||||
for param_group in self.param_groups: | for param_group in self.param_groups: | ||||
push_scope("clear_grad") | |||||
for param in param_group["params"]: | for param in param_group["params"]: | ||||
param.grad = None | param.grad = None | ||||
pop_scope("clear_grad") | |||||
def state_dict(self) -> Dict: | def state_dict(self) -> Dict: | ||||
r""" | r""" | ||||
@@ -6,159 +6,17 @@ | |||||
# Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import base64 | |||||
import json | import json | ||||
import os | |||||
import re | |||||
from typing import Iterable, List, Optional | |||||
from contextlib import contextmanager | |||||
from typing import List | |||||
from ..core._imperative_rt import OperatorNodeConfig, ProfileEntry | |||||
from ..core._imperative_rt import ProfilerImpl as _Profiler | |||||
from ..core._imperative_rt.core2 import sync | |||||
from ..core._imperative_rt.ops import CollectiveComm | |||||
def _make_dict(**kwargs): | |||||
unused_keys = [] | |||||
for k, v in kwargs.items(): | |||||
if v is None: | |||||
unused_keys.append(k) | |||||
for k in unused_keys: | |||||
del kwargs[k] | |||||
return kwargs | |||||
def _print_opnode_config(config):
    """Collect ``name``, ``dtype`` and ``comp_node_arr`` of ``config`` into a dict via ``_make_dict``."""
    fields = {
        "name": config.name,
        "dtype": config.dtype,
        "comp_node_arr": config.comp_node_arr,
    }
    return _make_dict(**fields)
def _dump_chrome_timeline(entries: List[ProfileEntry], path: str):
    """
    Write begin/end trace events for every entry to ``<path>.chrome_timeline.json``.

    :param entries: profile entries; each one provides ``op``, ``host`` (a
        begin/end pair) and ``device_list`` (triples of device plus begin and
        end callables).
    :param path: output prefix; ``.chrome_timeline.json`` is appended.

    Each entry produces a "B"/"E" pair on thread ``"host"`` and one more pair
    per device. Timestamps are multiplied by 1000 before being written.
    """
    process_id = os.getpid()
    events = []

    def emit(**fields):
        # _make_dict drops fields that are None.
        events.append(_make_dict(**fields))

    for index, entry in enumerate(entries):
        operator = entry.op
        op_name = type(operator).__name__
        host_begin, host_end = entry.host
        devices = entry.device_list
        attrs = Profiler.fetch_attrs(operator)
        attrs["__id__"] = "[{}]".format(index)

        # Host events carry a category; device events below do not.
        for phase, stamp in (("B", host_begin), ("E", host_end)):
            emit(
                name=op_name,
                ph=phase,
                ts=stamp * 1000,
                pid=process_id,
                tid="host",
                args=attrs,
                cat=op_name,
            )

        for device, device_begin, device_end in devices:
            # Both callables are evaluated before either event is recorded.
            begin_stamp = device_begin()
            end_stamp = device_end()
            for phase, stamp in (("B", begin_stamp), ("E", end_stamp)):
                emit(
                    name=op_name,
                    ph=phase,
                    ts=stamp * 1000,
                    pid=process_id,
                    tid=str(device),
                    args=attrs,
                )

    with open("{}.chrome_timeline.json".format(path), "w") as f:
        json.dump(events, f, indent=2)
def _dump_compatible(entries: List[ProfileEntry], path: str):
    """
    Write the entries to ``<path>.compatible.json`` as a ``graph_exec`` / ``profiler`` document.

    :param entries: profile entries; ``inputs`` / ``outputs`` items are indexed
        at ``[1]`` for the shape and ``[2]`` for the comp node.
    :param path: output prefix; ``.compatible.json`` is appended.

    Every input and output occurrence gets a fresh variable id (its position
    in the ``var`` list), so ids follow the order in which entries are visited.
    """
    variables = []
    operators = {}
    device_times = {}
    host_times = {}
    footprints = {}

    def add_var(var) -> int:
        # The id is the index the new record will occupy.
        new_id = len(variables)
        variables.append({"comp_node": str(var[2])})
        return new_id

    for op_id, entry in enumerate(entries):
        input_ids = [add_var(var) for var in entry.inputs]
        output_ids = [add_var(var) for var in entry.outputs]
        operators[op_id] = {
            "input": input_ids,
            "output": output_ids,
            "name": str(entry.op.ctype()),
            "type": "imperative",
            "id": entry.id,
        }

        per_device = {}
        for device, device_begin, device_end in entry.device_list:
            # "kern" is filled from a second call to the begin callable.
            timing = {
                "start": device_begin(),
                "kern": device_begin(),
                "end": device_end(),
            }
            per_device[str(device)] = timing
        device_times[op_id] = per_device

        host_begin, host_end = entry.host
        host_times[op_id] = {
            "host": {"start": host_begin, "kern": host_begin, "end": host_end}
        }

        footprint = {
            "out_shapes": [out[1] for out in entry.outputs],
            "in_shapes": [inp[1] for inp in entry.inputs],
            "params": {},
        }
        # memory / computation are only recorded when positive.
        if entry.memory > 0:
            footprint["memory"] = entry.memory
        if entry.computation > 0:
            footprint["computation"] = entry.computation
        footprints[op_id] = footprint

    obj = {
        "graph_exec": {"var": variables, "operator": operators},
        "profiler": {
            "device": device_times,
            "host": host_times,
            "opr_footprint": footprints,
        },
    }
    with open("{}.compatible.json".format(path), "w") as f:
        json.dump(obj, f, indent=2)
def _dump_graphviz(entries: List[ProfileEntry], path: str):
    """
    Save the entries as a Graphviz digraph in ``<path>.graphviz.dot``.

    :param entries: profile entries; ``inputs`` / ``outputs`` items unpack as
        ``(var_id, shape, device)``.
    :param path: output prefix; ``.graphviz.dot`` is appended.

    Each entry becomes a rectangular node labelled with its type, attributes,
    JSON-decoded params and host/device times. Each distinct variable id
    becomes one node, linked to the operators that consume or produce it.
    """
    # Imported here so graphviz is only needed when this dumper is called.
    import graphviz

    graph = graphviz.Digraph()
    graph.graph_attr["ordering"] = "out"
    var_cache = {}

    def cache_var(var_id, var_shape):
        # Create the variable node on first sight and reuse its name afterwards.
        if var_id not in var_cache:
            var_name = "var({})".format(var_id)
            # Use the shape passed in, not a variable from the enclosing loop.
            var_label = "{}\nshape:{}\n".format(var_name, var_shape)
            graph.node(var_name, var_label)
            var_cache[var_id] = var_name
        return var_cache[var_id]

    for op_id, entry in enumerate(entries):
        op = entry.op
        op_name = "op({})".format(op_id)
        op_type = type(op).__name__
        op_attrs = Profiler.fetch_attrs(op)
        # "param" is reported from entry.param below, so drop the attribute copy.
        op_attrs.pop("param", None)

        label_lines = ["{}:{}".format(op_name, op_type)]
        label_lines.extend(
            "attr[{}]: {}".format(k, v) for k, v in op_attrs.items()
        )

        op_param_str = entry.param
        if len(op_param_str) > 0:
            op_param = json.loads(op_param_str)
            label_lines.extend(
                "param[{}]:{}".format(k, v) for k, v in op_param.items()
            )

        host_begin, host_end = entry.host
        label_lines.append("time[host]: {:f}ms".format(host_end - host_begin))
        for device, device_begin, device_end in entry.device_list:
            device_time = device_end() - device_begin()
            label_lines.append("time[{}]: {:f}ms".format(device, device_time))

        graph.node(op_name, "\n".join(label_lines), shape="rectangle")
        for var_id, shape, _device in entry.inputs:
            graph.edge(cache_var(var_id, shape), op_name)
        for var_id, shape, _device in entry.outputs:
            graph.edge(op_name, cache_var(var_id, shape))

    graph.save("{}.graphviz.dot".format(path))
from ..core._imperative_rt.core2 import ( | |||||
pop_scope, | |||||
push_scope, | |||||
start_profile, | |||||
stop_profile, | |||||
sync, | |||||
) | |||||
class Profiler: | class Profiler: | ||||
@@ -181,85 +39,45 @@ class Profiler: | |||||
# Only profile record of last iter would be saved | # Only profile record of last iter would be saved | ||||
with Profiler("profile"): | with Profiler("profile"): | ||||
# your code here | # your code here | ||||
# Then open the profile file in chrome timeline window | # Then open the profile file in chrome timeline window | ||||
""" | """ | ||||
CHROME_TIMELINE = "chrome_timeline" | |||||
COMPATIBLE = "compatible" | |||||
GRAPHVIZ = "graphviz" | |||||
WITH_FOOTPRINT = 1 | |||||
CHROME_TIMELINE = "chrome_timeline.json" | |||||
_type_map = { | |||||
OperatorNodeConfig: lambda x: _print_opnode_config(x), | |||||
bytes: lambda x: base64.encodebytes(x).decode("ascii"), | |||||
CollectiveComm.Mode: lambda x: str(x), | |||||
} | |||||
_dumper_map = { | |||||
CHROME_TIMELINE: _dump_chrome_timeline, | |||||
COMPATIBLE: _dump_compatible, | |||||
GRAPHVIZ: _dump_graphviz, | |||||
} | |||||
COMMAND = 1 << 0 | |||||
OPERATOR = 1 << 1 | |||||
TENSOR_LIFETIME = 1 << 2 | |||||
TENSOR_PROP = 1 << 3 | |||||
SYNC = 1 << 4 | |||||
SCOPE = 1 << 5 | |||||
ALL = (1 << 6) - 1 | |||||
def __init__( | def __init__( | ||||
self, | self, | ||||
path: str = "profile", | path: str = "profile", | ||||
format: str = CHROME_TIMELINE, | |||||
*, | *, | ||||
formats: Iterable[str] = (CHROME_TIMELINE,), | |||||
type_filter: str = ".*", | |||||
exit_dump: bool = True | |||||
topic=OPERATOR | SCOPE, | |||||
align_time=True, | |||||
show_operator_name=True | |||||
) -> None: | ) -> None: | ||||
self._impl = _Profiler() | |||||
self._path = path | self._path = path | ||||
if isinstance(formats, str): | |||||
formats = (formats,) | |||||
self._filter = type_filter | |||||
self._dumpers = [Profiler._dumper_map[fmt] for fmt in formats] | |||||
self._exit_dump = exit_dump | |||||
self._format = format | |||||
self._options = { | |||||
"topic": int(topic), | |||||
"align_time": int(align_time), | |||||
"show_operator_name": int(show_operator_name), | |||||
} | |||||
def __enter__(self): | def __enter__(self): | ||||
sync() | |||||
self._impl.start(Profiler.WITH_FOOTPRINT) | |||||
start_profile(self._options) | |||||
return self | return self | ||||
def __exit__(self, val, tp, trace): | def __exit__(self, val, tp, trace): | ||||
if self._exit_dump: | |||||
self.dump() | |||||
sync() | |||||
self._impl.stop() | |||||
self._impl.clear() | |||||
@classmethod | |||||
def fetch_attrs(cls, op): | |||||
attrs = dir(op) | |||||
results = {} | |||||
for attr in attrs: | |||||
if attr.startswith("_"): | |||||
continue | |||||
value = op.__getattribute__(attr) | |||||
if callable(value): | |||||
continue | |||||
value_type = type(value) | |||||
if value_type in cls._type_map: | |||||
value = cls._type_map[value_type](value) | |||||
results[attr] = str(value) | |||||
return results | |||||
def dump(self, path: Optional[str] = None): | |||||
stop_profile(self._path, self._format) | |||||
# dump is async, so it's necessary to sync interpreter | |||||
sync() | sync() | ||||
raw = [ | |||||
entry | |||||
for entry in self._impl.dump() | |||||
if re.match(self._filter, type(entry.op).__name__) | |||||
] | |||||
if path is None: | |||||
path = self._path | |||||
for dumper in self._dumpers: | |||||
dumper(raw, path) | |||||
def __call__(self, func): | def __call__(self, func): | ||||
def wrapper(*args, **kwargs): | def wrapper(*args, **kwargs): | ||||
@@ -269,4 +87,23 @@ class Profiler: | |||||
return wrapper | return wrapper | ||||
@contextmanager
def scope(name):
    r"""
    Wrap a ``with`` block in a ``push_scope(name)`` / ``pop_scope(name)`` pair.

    :param name: scope name, forwarded to both calls.

    The pop is issued even when the body raises, so pushes and pops stay balanced.
    """
    push_scope(name)
    try:
        yield
    finally:
        pop_scope(name)
profile = Profiler | profile = Profiler | ||||
def merge_trace_events(sources: List[str], target: str):
    r"""
    Concatenate several chrome-timeline dumps into a single file.

    :param sources: path prefixes; ``.chrome_timeline.json`` is appended to
        each and the file is read as JSON.
    :param target: output path prefix; ``.chrome_timeline.json`` is appended.

    Events keep the order of ``sources`` and their order within each file.
    """
    merged = []
    for source in sources:
        with open(source + ".chrome_timeline.json", "r", encoding="utf-8") as f:
            merged.extend(json.load(f))
    # ensure_ascii=False emits raw non-ASCII characters, so the output encoding
    # is fixed to UTF-8 instead of depending on the platform default.
    with open(target + ".chrome_timeline.json", "w", encoding="utf-8") as f:
        json.dump(merged, f, ensure_ascii=False, indent=4)
@@ -807,16 +807,34 @@ void init_tensor(py::module m) { | |||||
} | } | ||||
} | } | ||||
m.def("set_option", | |||||
[](std::string name, int value){ interpreter_for_py->set_option(name, value); }); | |||||
m.def("get_option", | |||||
[](std::string name){ return interpreter_for_py->get_option(name); }); | |||||
m.def("_set_swap_flag", | m.def("_set_swap_flag", | ||||
[](bool flag) { interpreter_for_py->set_swap_flag(flag); }); | |||||
[](bool flag) { interpreter_for_py->set_option("enable_swap", flag); }); | |||||
m.def("_set_drop_flag", | m.def("_set_drop_flag", | ||||
[](bool flag) { interpreter_for_py->set_drop_flag(flag); }); | |||||
[](bool flag) { interpreter_for_py->set_option("enable_drop", flag); }); | |||||
m.def("config_async_level", | m.def("config_async_level", | ||||
[](int level) { interpreter_for_py->config_async_level(level); }); | |||||
[](int level) { | |||||
mgb_assert(level >= 0 and level <= 2, "async_level should be 0, 1 or 2"); | |||||
interpreter_for_py->set_option("async_level", level); | |||||
}); | |||||
m.def("get_async_level", | m.def("get_async_level", | ||||
[]() { return interpreter_for_py->get_async_level(); }); | |||||
[]() { return interpreter_for_py->get_option("async_level"); }); | |||||
m.def("set_buffer_length", | m.def("set_buffer_length", | ||||
[](int length) { interpreter_for_py->set_buffer_length(length); }); | |||||
[](int length) { | |||||
mgb_assert(length >= 0 and length < 100, "buffer_length should be in [0, 100)"); | |||||
interpreter_for_py->set_option("buffer_length", length); | |||||
}); | |||||
m.def("push_scope", | |||||
[](std::string name) { interpreter_for_py->push_scope(name); }); | |||||
m.def("pop_scope", | |||||
[](std::string name) { interpreter_for_py->pop_scope(name); }); | |||||
m.def("start_profile", | |||||
[](std::unordered_map<std::string, int> option) { return interpreter_for_py->start_profile(option); }); | |||||
m.def("stop_profile", | |||||
[](std::string basename, std::string format) { interpreter_for_py->stop_profile(basename, format); }); | |||||
m.def("sync", | m.def("sync", | ||||
[]() { | []() { | ||||
interpreter_for_py->sync(); | interpreter_for_py->sync(); | ||||
@@ -200,33 +200,6 @@ void init_utils(py::module m) { | |||||
m.def("_get_device_count", &mgb::CompNode::get_device_count, | m.def("_get_device_count", &mgb::CompNode::get_device_count, | ||||
"Get total number of specific devices on this system"); | "Get total number of specific devices on this system"); | ||||
using mgb::imperative::ProfileEntry; | |||||
py::class_<ProfileEntry>(m, "ProfileEntry") | |||||
.def_readwrite("op", &ProfileEntry::op) | |||||
.def_readwrite("host", &ProfileEntry::host) | |||||
.def_readwrite("device_list", &ProfileEntry::device_list) | |||||
.def_readwrite("inputs", &ProfileEntry::inputs) | |||||
.def_readwrite("outputs", &ProfileEntry::outputs) | |||||
.def_readwrite("id", &ProfileEntry::id) | |||||
.def_readwrite("parent", &ProfileEntry::parent) | |||||
.def_readwrite("memory", &ProfileEntry::memory) | |||||
.def_readwrite("computation", &ProfileEntry::computation) | |||||
.def_property_readonly("param", [](ProfileEntry& self)->std::string{ | |||||
if(self.param){ | |||||
return self.param->to_string(); | |||||
} else { | |||||
return {}; | |||||
} | |||||
}); | |||||
py::class_<mgb::imperative::Profiler>(m, "ProfilerImpl") | |||||
.def(py::init<>()) | |||||
.def("start", &mgb::imperative::Profiler::start) | |||||
.def("stop", &mgb::imperative::Profiler::stop) | |||||
.def("clear", &mgb::imperative::Profiler::clear) | |||||
.def("dump", &mgb::imperative::Profiler::get_profile); | |||||
using mgb::imperative::TensorSanityCheck; | using mgb::imperative::TensorSanityCheck; | ||||
py::class_<TensorSanityCheck>(m, "TensorSanityCheckImpl") | py::class_<TensorSanityCheck>(m, "TensorSanityCheckImpl") | ||||
.def(py::init<>()) | .def(py::init<>()) | ||||
@@ -0,0 +1,54 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied | |||||
import json | |||||
import os | |||||
import pytest | |||||
from megengine import Parameter, tensor | |||||
from megengine.core import option | |||||
from megengine.module import Module | |||||
from megengine.utils.profiler import Profiler, scope | |||||
class Simple(Module):
    """Minimal module for the profiler test: multiplies its input by one parameter."""

    def __init__(self):
        super().__init__()
        # Single float32 parameter used as the multiplier in forward().
        self.a = Parameter([1.23], dtype="float32")

    def forward(self, x):
        return x * self.a
def test_profiler():
    """
    Profile one ``Simple`` forward pass inside a named scope and check the dumped events.

    Checks that durations are non-negative, that non-zero timestamps never go
    backwards within a thread id, and that "my_scope" appears a positive,
    even number of times.
    """
    profile_prefix = "pytest_profile"
    profile_format = "chrome_timeline.json"
    profile_path = "{}.{}".format(profile_prefix, profile_format)

    try:
        with Profiler(profile_prefix, format=profile_format):
            with scope("my_scope"):
                Simple()(tensor([1.23], dtype="float32"))
        with open(profile_path, "r") as f:
            events = json.load(f)
    finally:
        # Remove the dump even when reading or parsing it failed.
        if os.path.exists(profile_path):
            os.remove(profile_path)

    prev_ts = {}
    scope_count = 0
    for event in events:
        if "dur" in event:
            assert event["dur"] >= 0
        elif "ts" in event and "tid" in event:
            ts = event["ts"]
            tid = event["tid"]
            # Events stamped 0 are skipped entirely, including the scope count.
            if ts == 0:
                continue
            assert (tid not in prev_ts) or prev_ts[tid] <= ts
            prev_ts[tid] = ts
        if event.get("name") == "my_scope":
            scope_count += 1
    assert scope_count > 0 and scope_count % 2 == 0
@@ -17,52 +17,37 @@ namespace mgb { | |||||
namespace imperative { | namespace imperative { | ||||
template <typename TFunction> | template <typename TFunction> | ||||
class FunctionHooker; | |||||
class FunctionHook; | |||||
template <typename TRet, typename... TArgs> | |||||
class FunctionHooker<TRet(TArgs...)> { | |||||
template <template <typename> class TFunction, typename TRet, typename... TArgs> | |||||
class FunctionHook<TFunction<TRet(TArgs...)>> { | |||||
public: | public: | ||||
using FunctionType = thin_function<TRet(TArgs...)>; | |||||
//Type of hooks. Hook should accept a real function as argument | |||||
//and invoke it on an appropriate time | |||||
using HookType = thin_function<TRet(FunctionType, TArgs...)>; | |||||
explicit FunctionHooker(FunctionType* fptr) : m_fptr{fptr} { | |||||
m_backup = {nullptr, [](FunctionType*){}}; | |||||
using FunctionType = TFunction<TRet(TArgs...)>; | |||||
explicit FunctionHook(FunctionType* fptr) : m_fptr{fptr} { | |||||
m_backup = *fptr; | |||||
} | } | ||||
public: | public: | ||||
FunctionHooker& apply_hook(HookType&& hook) { | |||||
if (!m_backup) { | |||||
FunctionType* backup = new FunctionType(*m_fptr); | |||||
//Restore hooked function, would be invoked when destructed | |||||
std::function<void(FunctionType*)> restorer = | |||||
[fptr = m_fptr](FunctionType* bkp) -> void { | |||||
*fptr = *bkp; | |||||
delete bkp; | |||||
}; | |||||
m_backup = decltype(m_backup)(backup, restorer); | |||||
} | |||||
template <typename THook, typename=std::enable_if_t<std::is_invocable_r_v<TRet, THook, FunctionType, TArgs...>, void>> | |||||
FunctionHook& apply_hook(THook&& hook) { | |||||
//Replace with hooked version | //Replace with hooked version | ||||
*m_fptr = [func = *m_fptr, hook](TArgs... args) -> TRet { | |||||
*m_fptr = [func = *m_fptr, hook=std::forward<THook>(hook)](TArgs... args) -> TRet { | |||||
return hook(func, std::forward<TArgs>(args)...); | return hook(func, std::forward<TArgs>(args)...); | ||||
}; | }; | ||||
//Convenient for chain call | //Convenient for chain call | ||||
return *this; | return *this; | ||||
} | } | ||||
private: | private: | ||||
FunctionType* m_fptr; | FunctionType* m_fptr; | ||||
std::unique_ptr<FunctionType, std::function<void(FunctionType*)>> m_backup; | |||||
FunctionType m_backup; | |||||
public: | |||||
~FunctionHook() { | |||||
*m_fptr = std::move(m_backup); | |||||
} | |||||
}; | }; | ||||
//Helps to deduce template args | |||||
template <typename TRet, typename... TArgs> | |||||
FunctionHooker(thin_function<TRet(TArgs...)>* f) | |||||
-> FunctionHooker<TRet(TArgs...)>; | |||||
template<typename TSignature> | |||||
auto make_shared_hook(thin_function<TSignature>* fptr){ | |||||
return std::make_shared<FunctionHooker<TSignature>>(fptr); | |||||
template<typename TFunction> | |||||
auto make_shared_hook(TFunction* fptr){ | |||||
return std::make_shared<FunctionHook<TFunction>>(fptr); | |||||
} | } | ||||
} // namespace imperative | } // namespace imperative | ||||
@@ -0,0 +1,231 @@ | |||||
/** | |||||
* \file imperative/src/impl/interpreter/commands.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include <string> | |||||
#include <variant> | |||||
#include "megbrain/tensor.h" | |||||
#include "megbrain/imperative/op_def.h" | |||||
#include "megbrain/imperative/utils/to_string.h" | |||||
namespace mgb::imperative { | |||||
namespace interpreter::intl { | |||||
struct TensorInfo; | |||||
class InterpreterProfiler; | |||||
struct Put { | |||||
TensorInfo* dest; | |||||
HostTensorND value; | |||||
bool no_cache = false; | |||||
template <typename TFunctor> | |||||
void get_props(TFunctor&& functor) const { | |||||
functor("dest", dest); | |||||
functor("no_cache", no_cache); | |||||
//functor("value", value); | |||||
} | |||||
const char* get_name() const { | |||||
return "Put"; | |||||
} | |||||
}; | |||||
struct ApplyOp { | |||||
std::shared_ptr<OpDef> op; | |||||
SmallVector<TensorInfo*> inputs; | |||||
SmallVector<TensorInfo*> outputs; | |||||
SmallVector<TensorInfo*> dels; | |||||
template <typename TFunctor> | |||||
void get_props(TFunctor&& functor) const { | |||||
functor("op", op); | |||||
functor("inputs", inputs); | |||||
functor("outputs", outputs); | |||||
functor("dels", dels); | |||||
} | |||||
const char* get_name() const { | |||||
return "ApplyOp"; | |||||
} | |||||
}; | |||||
struct Del { | |||||
TensorInfo* dest; | |||||
template <typename TFunctor> | |||||
void get_props(TFunctor&& functor) const { | |||||
functor("dest", dest); | |||||
} | |||||
const char* get_name() const { | |||||
return "Del"; | |||||
} | |||||
}; | |||||
struct GetValue { | |||||
TensorInfo* dest; | |||||
template <typename TFunctor> | |||||
void get_props(TFunctor&& functor) const { | |||||
functor("dest", dest); | |||||
} | |||||
const char* get_name() const { | |||||
return "GetValue"; | |||||
} | |||||
}; | |||||
struct SwapIn { | |||||
TensorInfo* dest; | |||||
template <typename TFunctor> | |||||
void get_props(TFunctor&& functor) const { | |||||
functor("dest", dest); | |||||
} | |||||
const char* get_name() const { | |||||
return "SwapIn"; | |||||
} | |||||
}; | |||||
struct SwapOut { | |||||
TensorInfo* dest; | |||||
template <typename TFunctor> | |||||
void get_props(TFunctor&& functor) const { | |||||
functor("dest", dest); | |||||
} | |||||
const char* get_name() const { | |||||
return "SwapOut"; | |||||
} | |||||
}; | |||||
struct Drop { | |||||
TensorInfo* dest; | |||||
template <typename TFunctor> | |||||
void get_props(TFunctor&& functor) const { | |||||
functor("dest", dest); | |||||
} | |||||
const char* get_name() const { | |||||
return "Drop"; | |||||
} | |||||
}; | |||||
struct SetOption { | |||||
std::string key; | |||||
int value; | |||||
template <typename TFunctor> | |||||
void get_props(TFunctor&& functor) const { | |||||
functor("key", key); | |||||
functor("value", value); | |||||
} | |||||
const char* get_name() const { | |||||
return "SetOption"; | |||||
} | |||||
}; | |||||
struct StartProfile { | |||||
InterpreterProfiler* profiler; | |||||
template <typename TFunctor> | |||||
void get_props(TFunctor&& functor) const {} | |||||
const char* get_name() const { | |||||
return "StartProfile"; | |||||
} | |||||
}; | |||||
struct StopProfile { | |||||
std::string basename; | |||||
std::string format; | |||||
template <typename TFunctor> | |||||
void get_props(TFunctor&& functor) const { | |||||
functor("basename", basename); | |||||
functor("format", format); | |||||
} | |||||
const char* get_name() const { | |||||
return "StopProfile"; | |||||
} | |||||
}; | |||||
struct PushScope { | |||||
std::string scope_name; | |||||
template <typename TFunctor> | |||||
void get_props(TFunctor&& functor) const { | |||||
functor("scope_name", scope_name); | |||||
} | |||||
const char* get_name() const { | |||||
return "PushScope"; | |||||
} | |||||
}; | |||||
struct PopScope { | |||||
std::string scope_name; | |||||
template <typename TFunctor> | |||||
void get_props(TFunctor&& functor) const { | |||||
functor("scope_name", scope_name); | |||||
} | |||||
const char* get_name() const { | |||||
return "PopScope"; | |||||
} | |||||
}; | |||||
using Command = std::variant<Put, | |||||
ApplyOp, | |||||
Del, | |||||
GetValue, | |||||
SwapIn, | |||||
SwapOut, | |||||
Drop, | |||||
SetOption, | |||||
StartProfile, | |||||
StopProfile, | |||||
PushScope, | |||||
PopScope>; | |||||
using IdentifiedCommand = std::pair<uint64_t, Command>; | |||||
} | |||||
template <> | |||||
struct ToStringTrait<interpreter::intl::Command>{ | |||||
std::string operator()(const interpreter::intl::Command& cmd) const { | |||||
return std::visit([](auto& cmd){ | |||||
std::string result = cmd.get_name(); | |||||
result += "{"; | |||||
cmd.get_props([&](const char* key, auto&& value) { | |||||
result += key; | |||||
result += ": "; | |||||
result += to_string(value); | |||||
result += ","; | |||||
}); | |||||
result += "}"; | |||||
return result; | |||||
}, cmd); | |||||
} | |||||
}; | |||||
} |
@@ -0,0 +1,92 @@ | |||||
/** | |||||
* \file imperative/src/impl/interpreter/events.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "./commands.h" | |||||
#include "./tensor_info.h" | |||||
// Payload types for the events the interpreter hands to its profiler.
// Every struct is a plain aggregate. The empty derived structs add no data;
// they only give each point of a lifecycle its own distinct type.
namespace mgb::imperative::interpreter::intl { | |||||
// Base payload of the command lifecycle events: the command plus its id.
struct CommandEvent { | |||||
IdentifiedCommand icmd; | |||||
}; | |||||
// Recorded by CommandBuffer::flush when a command is handed to the worker queue.
struct CommandEnqueueEvent: CommandEvent {}; | |||||
// Recorded at the start of ChannelImpl::process_one_task.
struct CommandExecuteEvent: CommandEvent {}; | |||||
// Recorded once per command by process_one_task when the command is finished.
struct CommandFinishEvent: CommandEvent {}; | |||||
// Base payload of the operator events.
struct OpEvent { | |||||
// Identifier of this operator application (drawn from ChannelImpl::m_last_id).
uint64_t id; | |||||
// The operator being applied.
std::shared_ptr<OpDef> op; | |||||
// TensorInfo::id of every input / output tensor.
SmallVector<uint64_t> inputs; | |||||
SmallVector<uint64_t> outputs; | |||||
}; | |||||
// Host-side begin/end markers are recorded with record_host around the op
// application; device-side markers are recorded with record_device, once per
// device involved in the op.
struct HostOpExecuteEvent: OpEvent {}; | |||||
struct DeviceOpExecuteEvent: OpEvent {}; | |||||
struct HostOpFinishEvent: OpEvent {}; | |||||
struct DeviceOpFinishEvent: OpEvent {}; | |||||
// Recorded by ChannelImpl::alloc when a TensorInfo receives its id.
struct TensorDeclareEvent { | |||||
uint64_t tensor_id; | |||||
}; | |||||
// Recorded when a tensor obtains a value (ChannelImpl::put for device tensors
// and ChannelImpl::produce_tensor), together with its layout and device.
struct TensorProduceEvent { | |||||
uint64_t tensor_id; | |||||
TensorLayout layout; | |||||
CompNode device; | |||||
}; | |||||
// Recorded by ChannelImpl::free just before the TensorInfo returns to the pool.
struct TensorEraseEvent { | |||||
uint64_t tensor_id; | |||||
}; | |||||
// Base payload of the tensor property events: which property of which tensor.
struct TensorPropEvent { | |||||
uint64_t tensor_id; | |||||
TensorInfo::Prop prop; | |||||
// NOTE(review): the recording sites in interpreter_impl.cpp pass only
// (tensor_id, prop); what fills prop_desc is not visible there - TODO confirm.
std::string prop_desc; | |||||
}; | |||||
// Recorded by get_dtype / get_device, which read the property without waiting.
struct TensorGetPropEvent: TensorPropEvent{}; | |||||
// Recorded by get_value / get_shape / get_dev_tensor right before they block
// on the condition variable.
struct TensorWaitPropEvent: TensorPropEvent{}; | |||||
// NOTE(review): no recording site was found in interpreter_impl.cpp; presumably
// marks the worker making the property available - TODO confirm.
struct TensorNotifyPropEvent: TensorPropEvent{}; | |||||
// Recorded by the same getters once the wait has returned.
struct TensorWaitPropFinishEvent: TensorPropEvent{}; | |||||
// Recorded by ChannelImpl::sync before and after it waits for the worker.
struct SyncStartEvent {}; | |||||
struct SyncFinishEvent {}; | |||||
// Base payload of the scope events: the scope name.
struct ScopeEvent { | |||||
std::string name; | |||||
}; | |||||
// Channel*: recorded by ChannelImpl::push_scope / pop_scope.
struct ChannelBeginScope: ScopeEvent {}; | |||||
struct ChannelEndScope: ScopeEvent {}; | |||||
// Worker*: recorded when the worker processes PushScope / PopScope commands.
struct WorkerBeginScope: ScopeEvent {}; | |||||
struct WorkerEndScope: ScopeEvent {}; | |||||
// Device*: recorded per device by ChannelImpl::sync_device_scope.
struct DeviceBeginScope: ScopeEvent {}; | |||||
struct DeviceEndScope: ScopeEvent {}; | |||||
} |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file imperative/src/impl/interpreter_impl.cpp | |||||
* \file imperative/src/impl/interpreter/interpreter_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,10 +10,14 @@ | |||||
*/ | */ | ||||
#include "./interpreter_impl.h" | #include "./interpreter_impl.h" | ||||
#include "megbrain/common.h" | #include "megbrain/common.h" | ||||
#include "megbrain/imperative/opr_utility.h" | #include "megbrain/imperative/opr_utility.h" | ||||
#include "megbrain/imperative/ops/backward_graph.h" | #include "megbrain/imperative/ops/backward_graph.h" | ||||
#include "megbrain/imperative/ops/autogen.h" | #include "megbrain/imperative/ops/autogen.h" | ||||
#include "megbrain/imperative/utils/to_string.h" | |||||
#include "../op_trait.h" | |||||
using namespace mgb; | using namespace mgb; | ||||
using namespace imperative; | using namespace imperative; | ||||
@@ -48,6 +52,7 @@ Handle ChannelImpl::put(const DeviceTensorND& data) { | |||||
info->desc.layout = data.layout(); | info->desc.layout = data.layout(); | ||||
info->desc.comp_node = data.comp_node(); | info->desc.comp_node = data.comp_node(); | ||||
info->ptr = Tensor::make(data); | info->ptr = Tensor::make(data); | ||||
m_channel_state.profiler->record_host<TensorProduceEvent>(info->id, info->desc.layout, info->desc.comp_node); | |||||
return info; | return info; | ||||
} | } | ||||
@@ -61,7 +66,7 @@ void ChannelImpl::del(Handle handle) { | |||||
} | } | ||||
void ChannelImpl::swap_in(Handle handle) { | void ChannelImpl::swap_in(Handle handle) { | ||||
if (m_enable_evict & SWAP) { | |||||
if (m_worker_state.options.enable_swap) { | |||||
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | ||||
"invalid handle: %p", handle); | "invalid handle: %p", handle); | ||||
auto* info = reinterpret_cast<TensorInfo*>(handle); | auto* info = reinterpret_cast<TensorInfo*>(handle); | ||||
@@ -71,7 +76,7 @@ void ChannelImpl::swap_in(Handle handle) { | |||||
} | } | ||||
void ChannelImpl::swap_out(Handle handle) { | void ChannelImpl::swap_out(Handle handle) { | ||||
if (m_enable_evict & SWAP) { | |||||
if (m_worker_state.options.enable_swap) { | |||||
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | ||||
"invalid handle: %p", handle); | "invalid handle: %p", handle); | ||||
auto* info = reinterpret_cast<TensorInfo*>(handle); | auto* info = reinterpret_cast<TensorInfo*>(handle); | ||||
@@ -81,7 +86,7 @@ void ChannelImpl::swap_out(Handle handle) { | |||||
} | } | ||||
void ChannelImpl::drop(Handle handle) { | void ChannelImpl::drop(Handle handle) { | ||||
if (m_enable_evict & DROP) { | |||||
if (m_worker_state.options.enable_drop) { | |||||
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | ||||
"invalid handle: %p", handle); | "invalid handle: %p", handle); | ||||
auto* info = reinterpret_cast<TensorInfo*>(handle); | auto* info = reinterpret_cast<TensorInfo*>(handle); | ||||
@@ -100,6 +105,7 @@ void ChannelImpl::dispatch_default_cpu( | |||||
const SmallVector<LogicalTensorDesc>& input_descs, | const SmallVector<LogicalTensorDesc>& input_descs, | ||||
SmallVector<Handle>* outputs) { | SmallVector<Handle>* outputs) { | ||||
auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs); | auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs); | ||||
MGB_MARK_USED_VAR(validated); | |||||
SmallVector<DeviceTensorND> input_tensornds; | SmallVector<DeviceTensorND> input_tensornds; | ||||
input_tensornds.reserve(input_descs.size()); | input_tensornds.reserve(input_descs.size()); | ||||
@@ -133,6 +139,17 @@ void ChannelImpl::dispatch_default_cpu( | |||||
output_tensornds.emplace_back(HostTensorND(output_cn, desc.layout).proxy_to_default_cpu()); | output_tensornds.emplace_back(HostTensorND(output_cn, desc.layout).proxy_to_default_cpu()); | ||||
} | } | ||||
auto tinfo_to_tid = [&](SmallVector<TensorInfo*> tinfo) { | |||||
SmallVector<uint64_t> tid; | |||||
for (auto* ptinfo: tinfo) { | |||||
tid.push_back(ptinfo->id); | |||||
} | |||||
return tid; | |||||
}; | |||||
OpEvent event_data = {++m_last_id, op, tinfo_to_tid(input_infos), {}}; | |||||
m_channel_state.profiler->record_host<HostOpExecuteEvent>(event_data); | |||||
OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds); | OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds); | ||||
SmallVector<TensorInfo*> output_infos; | SmallVector<TensorInfo*> output_infos; | ||||
@@ -146,9 +163,14 @@ void ChannelImpl::dispatch_default_cpu( | |||||
output_infos.push_back(info); | output_infos.push_back(info); | ||||
outputs->push_back(info); | outputs->push_back(info); | ||||
} | } | ||||
if (m_enable_evict & DROP) { | |||||
if (m_channel_state.options.enable_drop) { | |||||
TensorInfo::ComputePath::make(op, input_infos, output_infos); | TensorInfo::ComputePath::make(op, input_infos, output_infos); | ||||
} | } | ||||
event_data.outputs = tinfo_to_tid(output_infos); | |||||
m_channel_state.profiler->record_host<HostOpFinishEvent>(event_data); | |||||
} | } | ||||
void ChannelImpl::dispatch_kernel( | void ChannelImpl::dispatch_kernel( | ||||
@@ -173,13 +195,13 @@ void ChannelImpl::dispatch_kernel( | |||||
cmd.outputs.push_back(info); | cmd.outputs.push_back(info); | ||||
outputs->push_back(info); | outputs->push_back(info); | ||||
} | } | ||||
if (m_enable_evict & DROP) { | |||||
if (m_channel_state.options.enable_drop) { | |||||
TensorInfo::ComputePath::make(cmd.op, cmd.inputs, cmd.outputs); | TensorInfo::ComputePath::make(cmd.op, cmd.inputs, cmd.outputs); | ||||
} | } | ||||
m_buffer.enqueue(std::move(cmd)); | m_buffer.enqueue(std::move(cmd)); | ||||
if (!validated && m_async_level == 1) { | |||||
if (!validated && m_channel_state.options.async_level == 1) { | |||||
sync(); | sync(); | ||||
} else if (m_async_level == 0) { | |||||
} else if (m_channel_state.options.async_level == 0) { | |||||
sync(); | sync(); | ||||
// check device error | // check device error | ||||
for (auto&& oup : *outputs) { | for (auto&& oup : *outputs) { | ||||
@@ -212,7 +234,10 @@ SmallVector<Handle> ChannelImpl::apply_op( | |||||
} | } | ||||
SmallVector<Handle> outputs; | SmallVector<Handle> outputs; | ||||
switch (OpDef::decide_dispatch_mode(*op, input_descs)) { | |||||
DispatchMode dispatch_mode = m_channel_state.options.enable_host_compute | |||||
? OpDef::decide_dispatch_mode(*op, input_descs) | |||||
: DispatchMode::KERNEL; | |||||
switch (dispatch_mode) { | |||||
case DEFAULT_CPU: { | case DEFAULT_CPU: { | ||||
dispatch_default_cpu(op, input_infos, input_descs, &outputs); | dispatch_default_cpu(op, input_infos, input_descs, &outputs); | ||||
break; | break; | ||||
@@ -242,11 +267,13 @@ HostTensorND ChannelImpl::get_value(Handle handle) { | |||||
m_waitee = info; | m_waitee = info; | ||||
regenerate(info); | regenerate(info); | ||||
m_buffer.enqueue(GetValue{info}); | m_buffer.enqueue(GetValue{info}); | ||||
m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::HostValue); | |||||
m_cv.wait(lock, [&]() { | m_cv.wait(lock, [&]() { | ||||
check_worker_exc_unsafe(); | check_worker_exc_unsafe(); | ||||
tensor_ptr = info->ptr; | tensor_ptr = info->ptr; | ||||
return value_fetched(); | return value_fetched(); | ||||
}); | }); | ||||
m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::HostValue); | |||||
m_waitee = nullptr; | m_waitee = nullptr; | ||||
} | } | ||||
return tensor_ptr->get_value(); | return tensor_ptr->get_value(); | ||||
@@ -262,11 +289,13 @@ TensorShape ChannelImpl::get_shape(Handle handle) { | |||||
std::unique_lock<decltype(m_mutex)> lock(m_mutex); | std::unique_lock<decltype(m_mutex)> lock(m_mutex); | ||||
mgb_assert(!m_waitee); | mgb_assert(!m_waitee); | ||||
m_waitee = info; | m_waitee = info; | ||||
m_buffer.enqueue(Flush{info}); | |||||
m_buffer.flush(); | |||||
m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::Shape); | |||||
m_cv.wait(lock, [&]() { | m_cv.wait(lock, [&]() { | ||||
check_worker_exc_unsafe(); | check_worker_exc_unsafe(); | ||||
return static_cast<bool>(info->ptr); | return static_cast<bool>(info->ptr); | ||||
}); | }); | ||||
m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::Shape); | |||||
m_waitee = nullptr; | m_waitee = nullptr; | ||||
TensorShape ret = info->ptr->layout(); | TensorShape ret = info->ptr->layout(); | ||||
mgb_assert(ret.ndim != 0); | mgb_assert(ret.ndim != 0); | ||||
@@ -277,6 +306,7 @@ DType ChannelImpl::get_dtype(Handle handle) { | |||||
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | ||||
"invalid handle: %p", handle); | "invalid handle: %p", handle); | ||||
auto info = reinterpret_cast<TensorInfo*>(handle); | auto info = reinterpret_cast<TensorInfo*>(handle); | ||||
m_channel_state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::DType); | |||||
auto ret = info->desc.layout.dtype; | auto ret = info->desc.layout.dtype; | ||||
mgb_assert(ret.valid()); | mgb_assert(ret.valid()); | ||||
return ret; | return ret; | ||||
@@ -286,6 +316,7 @@ CompNode ChannelImpl::get_device(Handle handle) { | |||||
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | ||||
"invalid handle: %p", handle); | "invalid handle: %p", handle); | ||||
auto info = reinterpret_cast<TensorInfo*>(handle); | auto info = reinterpret_cast<TensorInfo*>(handle); | ||||
m_channel_state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::Device); | |||||
auto ret = info->desc.comp_node; | auto ret = info->desc.comp_node; | ||||
mgb_assert(ret.valid()); | mgb_assert(ret.valid()); | ||||
return ret; | return ret; | ||||
@@ -299,20 +330,23 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) { | |||||
mgb_assert(!m_waitee); | mgb_assert(!m_waitee); | ||||
m_waitee = info; | m_waitee = info; | ||||
regenerate(info); | regenerate(info); | ||||
m_buffer.enqueue(Flush{info}); | |||||
m_buffer.flush(); | |||||
m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::DevValue); | |||||
m_cv.wait(lock, [&]() { | m_cv.wait(lock, [&]() { | ||||
check_worker_exc_unsafe(); | check_worker_exc_unsafe(); | ||||
return static_cast<bool>(info->ptr); | return static_cast<bool>(info->ptr); | ||||
}); | }); | ||||
m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::DevValue); | |||||
m_waitee = nullptr; | m_waitee = nullptr; | ||||
return info->ptr->dev_tensor(); | return info->ptr->dev_tensor(); | ||||
} | } | ||||
void ChannelImpl::sync() { | void ChannelImpl::sync() { | ||||
if (!m_buffer.empty()) { | |||||
m_buffer.enqueue(Flush{}); | |||||
} | |||||
m_buffer.flush(); | |||||
m_channel_state.profiler->record_host<SyncStartEvent>(); | |||||
m_worker.wait_all_task_finish(); | m_worker.wait_all_task_finish(); | ||||
CompNode::sync_all(); | |||||
m_channel_state.profiler->record_host<SyncFinishEvent>(); | |||||
MGB_LOCK_GUARD(m_mutex); | MGB_LOCK_GUARD(m_mutex); | ||||
check_worker_exc_unsafe(); | check_worker_exc_unsafe(); | ||||
} | } | ||||
@@ -321,33 +355,41 @@ void ChannelImpl::close() { | |||||
sync(); | sync(); | ||||
} | } | ||||
void ChannelImpl::config_async_level(int level) { | |||||
mgb_assert(level <= 2 && level >= 0, "async_level should be 0, 1 or 2"); | |||||
m_async_level = level; | |||||
int ChannelImpl::get_option(std::string name) { | |||||
return m_channel_state.options.get_option(name); | |||||
} | } | ||||
int ChannelImpl::get_async_level() { | |||||
return m_async_level; | |||||
void ChannelImpl::set_option(std::string name, int value) { | |||||
m_channel_state.options.set_option(name, value); | |||||
m_buffer.enqueue(SetOption{name, value}); | |||||
} | } | ||||
TensorInfo* ChannelImpl::alloc() { | TensorInfo* ChannelImpl::alloc() { | ||||
MGB_LOCK_GUARD(m_mutex); | MGB_LOCK_GUARD(m_mutex); | ||||
auto info = m_pool.alloc(); | auto info = m_pool.alloc(); | ||||
m_valid_handle.insert(info); | m_valid_handle.insert(info); | ||||
info->id = m_last_id++; | |||||
m_channel_state.profiler->record_host<TensorDeclareEvent>(info->id); | |||||
return info; | return info; | ||||
} | } | ||||
void ChannelImpl::free(TensorInfo* ptr) { | void ChannelImpl::free(TensorInfo* ptr) { | ||||
MGB_LOCK_GUARD(m_mutex); | MGB_LOCK_GUARD(m_mutex); | ||||
m_channel_state.profiler->record_host<TensorEraseEvent>(ptr->id); | |||||
m_pool.free(ptr); | m_pool.free(ptr); | ||||
} | } | ||||
// Constructs the channel. The worker queue and the command buffer each receive
// a back-pointer to this channel, and the constructing thread is remembered as
// the "channel" thread (used later to label profiler records and by
// assert_in_channel).
ChannelImpl::ChannelImpl()
        : m_worker(this), m_buffer(this) {
    m_channel_state.tid = std::this_thread::get_id();
}
ChannelImpl::~ChannelImpl() { | ChannelImpl::~ChannelImpl() { | ||||
close(); | close(); | ||||
} | } | ||||
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { | void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { | ||||
MGB_LOCK_GUARD(m_mutex); | MGB_LOCK_GUARD(m_mutex); | ||||
m_worker_state.profiler->record_host<TensorProduceEvent>(dest->id, ptr->layout(), ptr->comp_node()); | |||||
dest->value_fetched = ptr->value_fetched(); | dest->value_fetched = ptr->value_fetched(); | ||||
// update tensor desc for static infer | // update tensor desc for static infer | ||||
dest->desc.layout = ptr->layout(); | dest->desc.layout = ptr->layout(); | ||||
@@ -397,55 +439,57 @@ void ChannelImpl::detach_users(TensorInfo* dest) { | |||||
output->detach_producer(); | output->detach_producer(); | ||||
} | } | ||||
} | } | ||||
dest->users.clear(); | |||||
mgb_assert(dest->users.size() == 0); | |||||
//dest->users.clear(); | |||||
} | } | ||||
void ChannelImpl::set_swap_flag(bool flag) { | |||||
if ((!flag) && (m_enable_evict & SWAP)) { | |||||
for (auto handle: m_valid_handle) { | |||||
auto* info = reinterpret_cast<TensorInfo*>(handle); | |||||
if (info->evict_type == SWAP) { | |||||
swap_in(info); | |||||
} | |||||
void ChannelImpl::sync_device_scope(CompNode device) { | |||||
auto& prev = m_worker_state.device_scope_map[device]; | |||||
auto& current = m_worker_state.scopes; | |||||
auto push_scope = [&](std::string name) { | |||||
m_worker_state.profiler->record_device<DeviceBeginScope>(device, name); | |||||
}; | |||||
auto pop_scope = [&](std::string name) { | |||||
m_worker_state.profiler->record_device<DeviceEndScope>(device, name); | |||||
}; | |||||
size_t similarity = 0; | |||||
for (size_t i = 0; i < prev.size() && i < current.size(); i++) { | |||||
if (prev[i] == current[i]) { | |||||
similarity++; | |||||
} else { | |||||
break; | |||||
} | } | ||||
} | } | ||||
if (flag) { | |||||
m_enable_evict |= SWAP; | |||||
} else { | |||||
m_enable_evict &= ~SWAP; | |||||
while (prev.size() > similarity) { | |||||
pop_scope(prev.back()); | |||||
prev.pop_back(); | |||||
} | } | ||||
} | |||||
void ChannelImpl::set_drop_flag(bool flag) { | |||||
if ((!flag) && (m_enable_evict & DROP)) { | |||||
for (auto handle: m_valid_handle) { | |||||
auto* info = reinterpret_cast<TensorInfo*>(handle); | |||||
if (info->evict_type == DROP) { | |||||
recompute(info->producer); | |||||
} | |||||
} | |||||
} | |||||
if (flag) { | |||||
m_enable_evict |= DROP; | |||||
} else { | |||||
m_enable_evict &= ~DROP; | |||||
while (prev.size() < current.size()) { | |||||
prev.push_back(current[prev.size()]); | |||||
push_scope(prev.back()); | |||||
} | } | ||||
} | } | ||||
void ChannelImpl::set_buffer_length(int length) { | |||||
m_buffer.set_capacity(length); | |||||
} | |||||
void ChannelImpl::process_one_task(Command& cmd) { | |||||
void ChannelImpl::process_one_task(IdentifiedCommand& icmd) { | |||||
m_worker_state.profiler->record_host<CommandExecuteEvent>(icmd); | |||||
bool finished = false; | |||||
auto do_finish_command = [&]{ | |||||
if (finished) { | |||||
return; | |||||
} | |||||
m_worker_state.profiler->record_host<CommandFinishEvent>(icmd); | |||||
finished = true; | |||||
}; | |||||
//TODO: remove std::visit for support osx 10.12 | //TODO: remove std::visit for support osx 10.12 | ||||
std::visit([this](auto& cmd) { | |||||
using T = std::remove_reference_t<decltype(cmd)>; | |||||
try { | |||||
auto cmd_visitor = [&](auto& cmd) { | |||||
using T = std::remove_reference_t<decltype(cmd)>; | |||||
if constexpr (std::is_same_v<T, Put>) { | if constexpr (std::is_same_v<T, Put>) { | ||||
auto value = cmd.no_cache ? std::make_shared<Tensor>(cmd.value) : Tensor::make(cmd.value); | auto value = cmd.no_cache ? std::make_shared<Tensor>(cmd.value) : Tensor::make(cmd.value); | ||||
produce_tensor(cmd.dest, std::move(value)); | produce_tensor(cmd.dest, std::move(value)); | ||||
} else if constexpr (std::is_same_v<T, ApplyOp>) { | } else if constexpr (std::is_same_v<T, ApplyOp>) { | ||||
uint64_t apply_id = ++m_last_id; | |||||
SmallVector<TensorPtr> tensor_inputs; | SmallVector<TensorPtr> tensor_inputs; | ||||
SmallVector<CompNode> devices; | |||||
tensor_inputs.reserve(cmd.inputs.size()); | tensor_inputs.reserve(cmd.inputs.size()); | ||||
// refcnt == 1, owners: [TensorInfo::ptr] | // refcnt == 1, owners: [TensorInfo::ptr] | ||||
for (auto i : cmd.inputs) { | for (auto i : cmd.inputs) { | ||||
@@ -453,6 +497,23 @@ void ChannelImpl::process_one_task(Command& cmd) { | |||||
// refcnt ++, owners: [i->ptr, tensor_inputs] | // refcnt ++, owners: [i->ptr, tensor_inputs] | ||||
tensor_inputs.push_back(i->ptr); | tensor_inputs.push_back(i->ptr); | ||||
} | } | ||||
// Begin profiling operator | |||||
auto tinfo_to_tid = [&](SmallVector<TensorInfo*> tinfo) { | |||||
SmallVector<uint64_t> tid; | |||||
for (auto* ptinfo: tinfo) { | |||||
tid.push_back(ptinfo->id); | |||||
} | |||||
return tid; | |||||
}; | |||||
OpEvent event_data = {apply_id, cmd.op, tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs)}; | |||||
// Collecting devices | |||||
for (auto i : cmd.inputs) { | |||||
devices.push_back(i->desc.comp_node); | |||||
} | |||||
for (auto i : cmd.outputs) { | |||||
devices.push_back(i->desc.comp_node); | |||||
} | |||||
devices.erase(std::unique(devices.begin(), devices.end()), devices.end()); | |||||
// Fused by command buffer. @see: CommandBuffer::fuse_del | // Fused by command buffer. @see: CommandBuffer::fuse_del | ||||
// Now if dest is inplacable, it's refcnt would be decreased to 1 and owned by tensor_inputs after Del. | // Now if dest is inplacable, it's refcnt would be decreased to 1 and owned by tensor_inputs after Del. | ||||
// Note for exprs like 'y = x op x', inplace is unsupported yet but Del would be also fused. | // Note for exprs like 'y = x op x', inplace is unsupported yet but Del would be also fused. | ||||
@@ -461,9 +522,24 @@ void ChannelImpl::process_one_task(Command& cmd) { | |||||
// if it's decreased to 1, would be detected at @see: proxy_graph_detail::apply_on_physical_tensor | // if it's decreased to 1, would be detected at @see: proxy_graph_detail::apply_on_physical_tensor | ||||
free(del); | free(del); | ||||
} | } | ||||
// Before wait | |||||
//TODO: split operator wait and execute so that OpWait can be recorded correctly. | |||||
// Before execute | |||||
m_worker_state.profiler->record_host<HostOpExecuteEvent>(event_data); | |||||
for (auto&& device: devices) { | |||||
sync_device_scope(device); | |||||
m_worker_state.profiler->record_device<DeviceOpExecuteEvent>(device, event_data); | |||||
} | |||||
// Apply op | |||||
// Here std::move is REQUIRED for removing duplicated references. | // Here std::move is REQUIRED for removing duplicated references. | ||||
auto tensor_outputs = OpDef::apply_on_physical_tensor( | auto tensor_outputs = OpDef::apply_on_physical_tensor( | ||||
*cmd.op, std::move(tensor_inputs)); | *cmd.op, std::move(tensor_inputs)); | ||||
// After execute | |||||
m_worker_state.profiler->record_host<HostOpFinishEvent>(event_data); | |||||
for (auto&& device: devices) { | |||||
m_worker_state.profiler->record_device<DeviceOpFinishEvent>(device, event_data); | |||||
} | |||||
// End profiling operator | |||||
mgb_assert(tensor_outputs.size() == cmd.outputs.size()); | mgb_assert(tensor_outputs.size() == cmd.outputs.size()); | ||||
for (size_t i = 0; i < tensor_outputs.size(); ++i) { | for (size_t i = 0; i < tensor_outputs.size(); ++i) { | ||||
if (cmd.outputs[i] == nullptr) { | if (cmd.outputs[i] == nullptr) { | ||||
@@ -488,13 +564,51 @@ void ChannelImpl::process_one_task(Command& cmd) { | |||||
release_tensor(cmd.dest); | release_tensor(cmd.dest); | ||||
} else if constexpr (std::is_same_v<T, Drop>) { | } else if constexpr (std::is_same_v<T, Drop>) { | ||||
release_tensor(cmd.dest); | release_tensor(cmd.dest); | ||||
} else if constexpr (std::is_same_v<T, Move>) { | |||||
produce_tensor(cmd.dest, cmd.src->ptr); | |||||
free(cmd.src); | |||||
} else if constexpr (std::is_same_v<T, SetOption>) { | |||||
m_worker_state.options.set_option(cmd.key, cmd.value); | |||||
} else if constexpr (std::is_same_v<T, StartProfile>) { | |||||
CompNode::sync_all(); | |||||
m_worker_state.profiler.reset(cmd.profiler); | |||||
} else if constexpr (std::is_same_v<T, StopProfile>) { | |||||
for (auto&& [device, scopes]: m_worker_state.device_scope_map) { | |||||
MGB_MARK_USED_VAR(scopes); | |||||
sync_device_scope(device); | |||||
} | |||||
do_finish_command(); | |||||
auto profiler = std::make_unique<InterpreterProfiler>(); | |||||
std::swap(profiler, m_worker_state.profiler); | |||||
auto records = profiler->stop(); | |||||
auto host_map = [this](std::thread::id tid) { | |||||
if (tid == m_channel_state.tid) { | |||||
return "channel"; | |||||
} else if (tid == m_worker_state.tid) { | |||||
return "worker"; | |||||
} else { | |||||
return "unknown"; | |||||
} | |||||
}; | |||||
InterpreterProfiler::dump_data(cmd.basename, cmd.format, records, profiler->get_option(), host_map); | |||||
} else if constexpr (std::is_same_v<T, PushScope>) { | |||||
m_worker_state.scopes.push_back(cmd.scope_name); | |||||
do_finish_command(); | |||||
m_worker_state.profiler->record_host<WorkerBeginScope>(cmd.scope_name); | |||||
} else if constexpr (std::is_same_v<T, PopScope>) { | |||||
mgb_assert(m_worker_state.scopes.back() == cmd.scope_name, "scope name mismatch"); | |||||
m_worker_state.scopes.pop_back(); | |||||
do_finish_command(); | |||||
m_worker_state.profiler->record_host<WorkerEndScope>(cmd.scope_name); | |||||
} else { | } else { | ||||
static_assert(std::is_same_v<T, Flush> || | |||||
std::is_same_v<T, Nop>); | |||||
static_assert(std::is_same_v<T, T>); | |||||
} | } | ||||
}; | |||||
std::visit([&](auto& cmd){ | |||||
using T = std::decay_t<decltype(cmd)>; | |||||
if (!m_worker_state.options.catch_worker_execption) { | |||||
cmd_visitor(cmd); | |||||
return; | |||||
} | |||||
try { | |||||
cmd_visitor(cmd); | |||||
} catch (...) { | } catch (...) { | ||||
MGB_LOCK_GUARD(m_mutex); | MGB_LOCK_GUARD(m_mutex); | ||||
if constexpr (std::is_same_v<T, ApplyOp>) { | if constexpr (std::is_same_v<T, ApplyOp>) { | ||||
@@ -507,7 +621,8 @@ void ChannelImpl::process_one_task(Command& cmd) { | |||||
m_worker_exc = std::current_exception(); | m_worker_exc = std::current_exception(); | ||||
m_cv.notify_all(); | m_cv.notify_all(); | ||||
} | } | ||||
}, cmd); | |||||
}, icmd.second); | |||||
do_finish_command(); | |||||
} | } | ||||
void ChannelImpl::check_worker_exc_unsafe() { | void ChannelImpl::check_worker_exc_unsafe() { | ||||
@@ -524,18 +639,22 @@ void ChannelImpl::CommandBuffer::enqueue(Command cmd) { | |||||
if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) { | if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) { | ||||
return; | return; | ||||
} | } | ||||
auto command_repr = std::visit([](auto& cmd){ return cmd.to_string(); }, cmd); | |||||
mgb_log_debug("%s Enqueued", command_repr.c_str()); | |||||
mgb_log_debug("%s Enqueued", to_string(cmd).c_str()); | |||||
m_commands.push_back(std::move(cmd)); | m_commands.push_back(std::move(cmd)); | ||||
auto flush_pos = flush_pos_for(m_commands.back()); | auto flush_pos = flush_pos_for(m_commands.back()); | ||||
flush(flush_pos); | flush(flush_pos); | ||||
} | } | ||||
// Flushes every buffered command: delegates to flush(Handle) with the
// past-the-end position so the whole buffer is handed to the worker.
void ChannelImpl::CommandBuffer::flush() {
    auto tail = m_commands.end();
    flush(tail);
}
void ChannelImpl::CommandBuffer::flush(Handle pos) { | void ChannelImpl::CommandBuffer::flush(Handle pos) { | ||||
for (auto iter = m_commands.begin(); iter != pos; ++iter) { | for (auto iter = m_commands.begin(); iter != pos; ++iter) { | ||||
auto command_repr = std::visit([](auto& cmd){ return cmd.to_string(); }, *iter); | |||||
mgb_log_debug("%s Flushed", command_repr.c_str()); | |||||
m_owner->m_worker.add_task(std::move(*iter)); | |||||
mgb_log_debug("%s Flushed", to_string(*iter).c_str()); | |||||
IdentifiedCommand icmd{++m_owner->m_last_id, std::move(*iter)}; | |||||
m_owner->m_channel_state.profiler->record_host<CommandEnqueueEvent>(icmd); | |||||
m_owner->m_worker.add_task(std::move(icmd)); | |||||
} | } | ||||
m_commands.erase(m_commands.begin(), pos); | m_commands.erase(m_commands.begin(), pos); | ||||
} | } | ||||
@@ -555,17 +674,10 @@ auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle { | |||||
} | } | ||||
} else if constexpr (std::is_same_v<T, GetValue>) { | } else if constexpr (std::is_same_v<T, GetValue>) { | ||||
return m_commands.end(); | return m_commands.end(); | ||||
} else if constexpr (std::is_same_v<T, Flush>) { | |||||
if (cmd.dest == nullptr) { | |||||
return m_commands.end(); | |||||
} | |||||
auto produce_iter = find_produce(cmd.dest, {m_commands.begin(), m_commands.end()}); | |||||
if (produce_iter != m_commands.end()) { | |||||
return produce_iter + 1; | |||||
} | |||||
} | } | ||||
if (m_commands.size() > m_capacity) { | |||||
return m_commands.begin() + (m_commands.size() - m_capacity); | |||||
size_t buffer_length = m_owner->m_channel_state.options.buffer_length; | |||||
if (m_commands.size() > buffer_length) { | |||||
return m_commands.begin() + (m_commands.size() - buffer_length); | |||||
} | } | ||||
return m_commands.begin(); | return m_commands.begin(); | ||||
}, cmd); | }, cmd); | ||||
@@ -589,7 +701,7 @@ bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) { | |||||
if (apply_iter == end || find_last_usage(dest, {apply_iter+1, end}) != end) { | if (apply_iter == end || find_last_usage(dest, {apply_iter+1, end}) != end) { | ||||
return false; | return false; | ||||
} | } | ||||
mgb_log_debug("%s Fused", cmd.to_string().c_str()); | |||||
mgb_log_debug("%s Fused", to_string(Command{cmd}).c_str()); | |||||
std::get<ApplyOp>(*apply_iter).dels.push_back(dest); | std::get<ApplyOp>(*apply_iter).dels.push_back(dest); | ||||
return true; | return true; | ||||
} | } | ||||
@@ -636,3 +748,41 @@ auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range) | |||||
}, cmd); | }, cmd); | ||||
}); | }); | ||||
} | } | ||||
void ChannelImpl::start_profile(std::unordered_map<std::string, int> option) { | |||||
auto profiler_option = InterpreterProfiler::Option::from_dict(option); | |||||
auto profiler = std::make_unique<InterpreterProfiler>(); | |||||
profiler->set_option(profiler_option); | |||||
profiler->start(InterpreterProfiler::topic_to_mask(profiler_option.topic)); | |||||
std::swap(profiler, m_channel_state.profiler); | |||||
m_buffer.enqueue(StartProfile{m_channel_state.profiler.get()}); | |||||
} | |||||
void ChannelImpl::stop_profile(std::string basename, std::string format) { | |||||
m_buffer.flush(); | |||||
auto profiler = std::make_unique<InterpreterProfiler>(); | |||||
std::swap(profiler, m_channel_state.profiler); | |||||
profiler.release(); | |||||
m_buffer.enqueue(StopProfile{basename, format}); | |||||
} | |||||
// Opens a named scope on the channel side.
//
// Records a ChannelBeginScope event, pushes the name onto the channel's scope
// stack and enqueues a PushScope command so the worker opens the same scope.
void ChannelImpl::push_scope(std::string name) {
    auto& channel = m_channel_state;
    channel.profiler->record_host<ChannelBeginScope>(name);
    channel.scopes.push_back(name);
    m_buffer.enqueue(PushScope{name});
}
// Closes the innermost channel-side scope.
//
// \param name  must equal the name of the most recently pushed scope;
//              otherwise (or when no scope is open) the assertion fails with
//              "scope name mismatch".
//
// On success the scope is popped, a ChannelEndScope event is recorded and a
// PopScope command is enqueued for the worker.
void ChannelImpl::pop_scope(std::string name) {
    // Assertion expression kept verbatim so any diagnostic text is unchanged.
    mgb_assert((!m_channel_state.scopes.empty()) && m_channel_state.scopes.back() == name, "scope name mismatch");
    auto& channel = m_channel_state;
    channel.scopes.pop_back();
    channel.profiler->record_host<ChannelEndScope>(name);
    m_buffer.enqueue(PopScope{name});
}
// Asserts that the caller is running on the channel thread, i.e. the thread
// whose id the ChannelImpl constructor stored in m_channel_state.tid.
void ChannelImpl::assert_in_channel() {
    // The comparison must be `==`, mirroring assert_in_worker. The previous
    // `!=` inverted the check, so it failed on the channel thread and passed
    // on every other thread.
    mgb_assert(m_channel_state.tid == std::this_thread::get_id());
}
// Asserts that the caller is running on the thread whose id is stored in
// m_worker_state.tid.
void ChannelImpl::assert_in_worker() { | |||||
mgb_assert(m_worker_state.tid == std::this_thread::get_id()); | |||||
} |
@@ -0,0 +1,205 @@ | |||||
/** | |||||
* \file imperative/src/impl/interpreter/interpreter_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include <deque> | |||||
#include <future> | |||||
#include <list> | |||||
#include <thread> | |||||
#include <unordered_set> | |||||
#include <variant> | |||||
#include "megbrain/utils/mempool.h" | |||||
#include "megbrain/imperative/interpreter.h" | |||||
#include "megbrain/imperative/profiler.h" | |||||
#include "./commands.h" | |||||
#include "./events.h" | |||||
#include "./tensor_info.h" | |||||
#include "./option_manager.h" | |||||
#include "./profiler.h" | |||||
namespace mgb::imperative::interpreter::intl { | |||||
using Handle = Interpreter::Handle; | |||||
// Implementation of the Interpreter interface; overrides its create_channel()
// factory, which returns an owning pointer to a Channel.
struct InterpreterImpl : Interpreter { | |||||
std::unique_ptr<Channel> create_channel() override; | |||||
}; | |||||
struct ChannelImpl : Interpreter::Channel { | |||||
ChannelImpl(); | |||||
~ChannelImpl() override; | |||||
Handle put(const HostTensorND& value, bool no_cache) override; | |||||
Handle put(const DeviceTensorND& value) override; | |||||
void del(Handle) override; | |||||
void swap_in(Handle) override; | |||||
void swap_out(Handle) override; | |||||
void drop(Handle) override; | |||||
SmallVector<Handle> apply_op( | |||||
std::shared_ptr<OpDef> op, | |||||
const SmallVector<Handle>& inputs) override; | |||||
HostTensorND get_value(Handle) override; | |||||
TensorShape get_shape(Handle) override; | |||||
DType get_dtype(Handle) override; | |||||
CompNode get_device(Handle) override; | |||||
DeviceTensorND get_dev_tensor(Handle) override; | |||||
void sync() override; | |||||
void close() override; | |||||
int get_option(std::string name) override; | |||||
void set_option(std::string name, int value) override; | |||||
void start_profile(std::unordered_map<std::string, int> option) override; | |||||
void stop_profile(std::string basename, std::string format) override; | |||||
void push_scope(std::string) override; | |||||
void pop_scope(std::string) override; | |||||
private: | |||||
TensorInfo* alloc(); | |||||
void free(TensorInfo*); | |||||
void detach_users(TensorInfo*); | |||||
void process_one_task(IdentifiedCommand&); | |||||
void check_worker_exc_unsafe(); | |||||
void produce_tensor(TensorInfo* dest, TensorPtr ptr); | |||||
void release_tensor(TensorInfo* dest); | |||||
void regenerate(TensorInfo* dest); | |||||
void recompute(TensorInfo::ComputePath* path); | |||||
void dispatch_default_cpu( | |||||
std::shared_ptr<OpDef> op, | |||||
const SmallVector<TensorInfo*>& input_infos, | |||||
const SmallVector<LogicalTensorDesc>& input_descs, | |||||
SmallVector<Handle>* outputs); | |||||
void dispatch_kernel( | |||||
std::shared_ptr<OpDef> op, | |||||
const SmallVector<TensorInfo*>& input_infos, | |||||
const SmallVector<LogicalTensorDesc>& input_descs, | |||||
SmallVector<Handle>* outputs); | |||||
void assert_in_channel(); | |||||
void assert_in_worker(); | |||||
void sync_device_scope(CompNode device); | |||||
template <typename TCommand> | |||||
void enqueue_command(TCommand&& cmd) { | |||||
m_buffer.enqueue(Command{std::forward<TCommand>(cmd)}); | |||||
} | |||||
std::mutex m_mutex; | |||||
std::condition_variable m_cv; | |||||
MemPool<TensorInfo> m_pool; | |||||
std::unordered_set<Handle> m_valid_handle; | |||||
TensorInfo* m_waitee = nullptr; | |||||
std::exception_ptr m_worker_exc; | |||||
std::atomic_uint64_t m_last_id = 0; | |||||
struct WorkQueue : AsyncQueueSC<IdentifiedCommand, WorkQueue> { | |||||
// set max_spin=0 to prevent Queue fetch task in busy wait manner. | |||||
// this won't affect throughput when python interpreter is sending enough task, | |||||
// but will significantly save CPU time when waiting for task, e.g. wait for data input | |||||
WorkQueue(ChannelImpl* owner) | |||||
: AsyncQueueSC<IdentifiedCommand, WorkQueue>(0), m_owner(owner) { | |||||
sys::set_thread_name("interpreter"); | |||||
} | |||||
void process_one_task(IdentifiedCommand& icmd) { | |||||
m_owner->process_one_task(icmd); | |||||
} | |||||
void on_async_queue_worker_thread_start() override { | |||||
sys::set_thread_name("worker"); | |||||
m_owner->m_worker_state.tid = std::this_thread::get_id(); | |||||
} | |||||
private: | |||||
ChannelImpl* m_owner; | |||||
} m_worker; | |||||
/** | |||||
* Buf a command window for following fuse | |||||
* example: | |||||
* --------------------------------------------------------------------- | |||||
* | ..., Apply{in: (i0, i1), out: (o0, o1)}, ... + Del{i0} + Del{i1} | | |||||
* --------------------------------------------------------------------- | |||||
* | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0)}, ... + Del{i1} | | |||||
* --------------------------------------------------------------------- | |||||
* | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0, i1)}, ... | | |||||
* --------------------------------------------------------------------- | |||||
* Then the fused Apply may be invoked inplace. see: ChannelImpl::process_one_task | |||||
*/ | |||||
struct CommandBuffer { | |||||
CommandBuffer(ChannelImpl* owner) : m_owner(owner) {} | |||||
void enqueue(Command cmd); | |||||
bool empty() const { | |||||
return m_commands.empty(); | |||||
} | |||||
void flush(); | |||||
private: | |||||
ChannelImpl* m_owner; | |||||
std::deque<Command> m_commands; | |||||
using Handle = decltype(m_commands)::iterator; | |||||
// [begin, end) | |||||
using Range = std::array<Handle, 2>; | |||||
// Launch commands in range [m_commands.begin(), pos) | |||||
void flush(Handle pos); | |||||
// Select flush position for incoming cmd | |||||
Handle flush_pos_for(const Command& cmd); | |||||
// Fuse del command into suitable ApplyOp | |||||
bool fuse_del(const Del& cmd); | |||||
// Returns the last handle that dest is used within range. If dest is not used, returns range[1] | |||||
Handle find_last_usage(TensorInfo* dest, Range range); | |||||
// Returns the produce position of dest. If not found, returns range[1] | |||||
Handle find_produce(TensorInfo* dest, Range range); | |||||
} m_buffer; | |||||
//! config whether raise error exactly when invoking op. | |||||
//! level 2: both device and user side errors are async; | |||||
//! level 1: user side errors are sync; | |||||
//! level 0: both sync. | |||||
int m_async_level = 2; | |||||
int m_max_recompute_time = 1; | |||||
struct State { | |||||
std::thread::id tid; | |||||
OptionManager options; | |||||
std::vector<std::string> scopes; | |||||
std::unique_ptr<InterpreterProfiler> profiler; | |||||
State() { | |||||
profiler = std::make_unique<InterpreterProfiler>(); | |||||
} | |||||
}; | |||||
struct ChannelState: State {}; | |||||
struct WorkerState: State { | |||||
CompNode::UnorderedMap<std::vector<std::string>> device_scope_map; | |||||
}; | |||||
ChannelState m_channel_state; | |||||
WorkerState m_worker_state; | |||||
}; | |||||
} // namespace mgb::imperative::interpreter::intl |
@@ -0,0 +1,61 @@ | |||||
/** | |||||
* \file imperative/src/impl/interpreter/option_manager.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include <string> | |||||
#include <unordered_map> | |||||
#include "megbrain/common.h" | |||||
namespace mgb::imperative::interpreter::intl { | |||||
struct OptionManager { | |||||
private: | |||||
std::unordered_map<std::string, int*> m_option_map = {}; | |||||
public: | |||||
#define DEF_OPTION(name, env_key, default_value, desc) \ | |||||
int name = (m_option_map[#name]=&name, get_option_from_env(env_key, default_value)); | |||||
DEF_OPTION(async_level, "MEGENGINE_INTERP_ASYNC_LEVEL", 2, | |||||
"config whether raise error exactly when invoking op.\n" | |||||
"level 2: both device and user side errors are async;\n" | |||||
"level 1: user side errors are sync;\n" | |||||
"level 0: both sync."); | |||||
DEF_OPTION(enable_swap, "MEGENGINE_ENABLE_SWAP", 0, ""); | |||||
DEF_OPTION(enable_drop, "MEGENGINE_ENABLE_DROP", 0, ""); | |||||
DEF_OPTION(max_recompute_time, "MEGENGINE_MAX_RECOMP_TIME", 1, ""); | |||||
DEF_OPTION(catch_worker_execption, "MEGENGINE_CATCH_WORKER_EXEC", 1, | |||||
"catch worker exception if enabled, close it when debugging"); | |||||
DEF_OPTION(buffer_length, "MEGENGINE_COMMAND_BUFFER_LENGTH", 3, | |||||
"set command buffer length."); | |||||
DEF_OPTION(enable_host_compute, "MEGENGINE_HOST_COMPUTE", 1, | |||||
"enable host compute, thus computation may be done in host event if it's device is gpu."); | |||||
#undef DEF_OPTION | |||||
void set_option(const std::string& name, int value) { | |||||
*m_option_map[name] = value; | |||||
} | |||||
int get_option(const std::string& name) const { | |||||
return *m_option_map.at(name); | |||||
} | |||||
static int get_option_from_env(const std::string& name, int default_value) { | |||||
if (const char* env_val = MGB_GETENV(name.c_str())) { | |||||
default_value = std::atoi(env_val); | |||||
} | |||||
return default_value; | |||||
} | |||||
}; | |||||
} |
@@ -0,0 +1,280 @@ | |||||
/** | |||||
* \file imperative/src/impl/interpreter/profiler.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "./profiler.h" | |||||
#include <sstream> | |||||
#include <cinttypes> | |||||
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) | |||||
#include <unistd.h> | |||||
#elif defined(_WIN32) | |||||
#include <process.h> | |||||
#else | |||||
#error Unsupported platform | |||||
#endif | |||||
#include "../op_trait.h" | |||||
namespace mgb::imperative::interpreter::intl { | |||||
namespace { | |||||
struct InterpreterProfilerDumpChromeTimelineContext { | |||||
// either host_thread(std::thread::id) or device_thread(CompNode) | |||||
using Thread = std::variant<std::thread::id, CompNode>; | |||||
// input params | |||||
std::string base_name; | |||||
std::string format; | |||||
InterpreterProfiler::Data profile_data; | |||||
InterpreterProfiler::Option option; | |||||
std::function<std::string(std::thread::id)> host_map; | |||||
// internal states | |||||
decltype(getpid()) pid; | |||||
CompNode::UnorderedMap<std::map<double, CompNode::Event*>> device_sync_map; | |||||
SmallVector<Thread> thread_list; | |||||
double time_start; | |||||
// options | |||||
bool show_operator_name; | |||||
// results | |||||
ChromeTraceEventList event_list; | |||||
InterpreterProfilerDumpChromeTimelineContext( | |||||
std::string base_name, | |||||
std::string format, | |||||
InterpreterProfiler::Data profile_data, | |||||
InterpreterProfiler::Option option, | |||||
std::function<std::string(std::thread::id)> host_map) | |||||
: base_name{base_name}, format{format}, profile_data{profile_data}, option{option}, host_map{host_map} { | |||||
pid = getpid(); | |||||
time_start = option.align_time ? time_start : 0; | |||||
show_operator_name = option.show_operator_name; | |||||
} | |||||
// get device time from event | |||||
double get_device_time(CompNode::Event* device_event, double host_time) { | |||||
device_event->host_wait(); | |||||
auto& sync_map = device_sync_map[device_event->comp_node()]; | |||||
// find sync point | |||||
auto iter = sync_map.begin(); | |||||
auto sync_current = [&] { | |||||
iter = sync_map.insert(iter, {host_time, device_event}); | |||||
return host_time; | |||||
}; | |||||
if (iter == sync_map.end()) { | |||||
// not found, insert sync | |||||
return sync_current(); | |||||
} | |||||
auto& [base_time, base] = *iter; | |||||
// calculate elapsed time | |||||
double delta_time = base->elapsed_time_until(*device_event) * 1e3; | |||||
return base_time + delta_time; | |||||
}; | |||||
template <typename T> | |||||
size_t get_tid(T t) { | |||||
for (size_t i = 0; i < thread_list.size(); i++) { | |||||
if (thread_list[i] == Thread{t}) { | |||||
return i; | |||||
} | |||||
} | |||||
thread_list.push_back(t); | |||||
return thread_list.size() - 1; | |||||
}; | |||||
ChromeTraceEvent& new_event(std::string name, char ph, uint64_t tid, double ts) { | |||||
return event_list.new_event().name(name).ph(ph).tid(tid).ts(ts).pid(pid); | |||||
}; | |||||
// convert Command to json object. Has to be an callable object | |||||
static auto constexpr cmd_to_args = [](auto&& cmd) { | |||||
auto args = json::Object::make(); | |||||
cmd.get_props([&](const char* key, auto&& value){ | |||||
(*args)[key] = json::String::make(to_string(value)); | |||||
}); | |||||
(*args)["__type__"] = json::String::make(typeid(cmd).name()); | |||||
return args; | |||||
}; | |||||
void process() { | |||||
// enumerate and process each record | |||||
for (auto&& record: profile_data.records) { | |||||
std::visit([this](auto& record){ | |||||
using TEvent = std::decay_t<decltype(record.data)>; | |||||
Session<TEvent>(*this, record).process(); | |||||
}, record); | |||||
} | |||||
for (size_t tid = 0; tid < thread_list.size(); ++tid) { | |||||
auto tname = std::visit([&](auto& host_or_device) -> std::string{ | |||||
using T = std::decay_t<decltype(host_or_device)>; | |||||
if constexpr (std::is_same_v<T, std::thread::id>) { | |||||
// take name from host_map | |||||
return host_map(host_or_device); | |||||
} else { | |||||
// use CompNode::to_string | |||||
return host_or_device.to_string(); | |||||
} | |||||
}, thread_list[tid]); | |||||
// assign thread name | |||||
new_event("thread_name", 'M', tid, 0) | |||||
.arg("name", tname); | |||||
} | |||||
// wraite output to file | |||||
std::string out_buf; | |||||
event_list.to_json()->writeto(out_buf, 4); | |||||
std::ofstream output_stream; | |||||
output_stream.open(base_name + "." + format); | |||||
output_stream << out_buf; | |||||
output_stream.flush(); | |||||
output_stream.close(); | |||||
} | |||||
template <typename TEvent> | |||||
struct Session { | |||||
InterpreterProfilerDumpChromeTimelineContext& ctx; | |||||
ProfilerBase::EventRecord<TEvent>& record; | |||||
TEvent& data; | |||||
Session(InterpreterProfilerDumpChromeTimelineContext& ctx, | |||||
ProfilerBase::EventRecord<TEvent>& record) | |||||
: ctx{ctx}, record{record}, data{record.data} {} | |||||
uint64_t get_host_tid() { | |||||
return ctx.get_tid(record.host().tid); | |||||
}; | |||||
double get_host_ts() { | |||||
return (ctx.time_start + record.host().time) * 1e3; | |||||
}; | |||||
uint64_t get_device_tid() { | |||||
return ctx.get_tid(record.device().event->comp_node()); | |||||
}; | |||||
double get_device_ts() { | |||||
return (ctx.time_start + ctx.get_device_time(record.device().event.get(), record.device().after)) * 1e3; | |||||
}; | |||||
ChromeTraceEvent& new_host_event(std::string name, char ph) { | |||||
return ctx.new_event(std::move(name), ph, get_host_tid(), get_host_ts()); | |||||
}; | |||||
ChromeTraceEvent& new_device_event(std::string name, char ph) { | |||||
return ctx.new_event(std::move(name), ph, get_device_tid(), get_device_ts()); | |||||
}; | |||||
void process() { | |||||
// dispatch event by type | |||||
if constexpr (std::is_same_v<TEvent, CommandEnqueueEvent>) { | |||||
auto args = std::visit(cmd_to_args, data.icmd.second); | |||||
new_host_event("CommandEnqueue", 'X').dur(0).args(args); | |||||
} else if constexpr (std::is_same_v<TEvent, CommandExecuteEvent>) { | |||||
auto args = std::visit(cmd_to_args, data.icmd.second); | |||||
new_host_event("CommandExecute", 'B').args(args); | |||||
} else if constexpr (std::is_same_v<TEvent, CommandFinishEvent>) { | |||||
new_host_event("CommandExecute", 'E'); | |||||
} else if constexpr (std::is_same_v<TEvent, HostOpExecuteEvent>) { | |||||
auto args = json::Object::make(); | |||||
auto props = OpDef::props(*data.op); | |||||
auto name = data.op->trait()->name; | |||||
for (auto&& [prop_name, prop_val]: props) { | |||||
(*args)[std::string("op.") + prop_name] = json::String::make(prop_val); | |||||
} | |||||
(*args)["name"] = json::String::make(name); | |||||
(*args)["id"] = json::Number::make(data.id); | |||||
(*args)["inputs"] = json::String::make(to_string(data.inputs)); | |||||
(*args)["outputs"] = json::String::make(to_string(data.outputs)); | |||||
new_host_event(ctx.show_operator_name ? name : "OpExecute", 'B').args(args); | |||||
} else if constexpr (std::is_same_v<TEvent, DeviceOpExecuteEvent>) { | |||||
auto args = json::Object::make(); | |||||
auto props = OpDef::props(*data.op); | |||||
auto name = data.op->trait()->name; | |||||
for (auto&& [prop_name, prop_val]: props) { | |||||
(*args)[std::string("op.") + prop_name] = json::String::make(prop_val); | |||||
} | |||||
(*args)["name"] = json::String::make(name); | |||||
(*args)["id"] = json::Number::make(data.id); | |||||
(*args)["inputs"] = json::String::make(to_string(data.inputs)); | |||||
(*args)["outputs"] = json::String::make(to_string(data.outputs)); | |||||
new_device_event(ctx.show_operator_name ? name : "OpExecute", 'B').args(args); | |||||
} else if constexpr (std::is_same_v<TEvent, HostOpFinishEvent>) { | |||||
auto name = data.op->trait()->name; | |||||
new_host_event(ctx.show_operator_name ? name : "OpExecute", 'E'); | |||||
} else if constexpr (std::is_same_v<TEvent, DeviceOpFinishEvent>) { | |||||
auto name = data.op->trait()->name; | |||||
new_device_event(ctx.show_operator_name ? name : "OpExecute", 'E'); | |||||
} else if constexpr (std::is_same_v<TEvent, TensorDeclareEvent>) { | |||||
json::Number::make(data.tensor_id); | |||||
new_host_event("TensorLifetime", 'N').id(data.tensor_id); | |||||
} else if constexpr (std::is_same_v<TEvent, TensorProduceEvent>) { | |||||
auto snapshot = json::Object::make(); | |||||
(*snapshot)["shape"] = json::String::make(to_string((TensorShape)data.layout)); | |||||
(*snapshot)["dtype"] = json::String::make(to_string(data.layout.dtype)); | |||||
(*snapshot)["device"] = json::String::make(to_string(data.device)); | |||||
json::Number::make(data.tensor_id); | |||||
new_host_event("TensorLifetime", 'O').id(data.tensor_id).arg("snapshot", snapshot); | |||||
} else if constexpr (std::is_same_v<TEvent, TensorEraseEvent>) { | |||||
json::Number::make(data.tensor_id); | |||||
new_host_event("TensorLifetime", 'D').id(data.tensor_id); | |||||
} else if constexpr (std::is_same_v<TEvent, TensorGetPropEvent>) { | |||||
auto args = json::Object::make(); | |||||
(*args)["id"] = json::Number::make(data.tensor_id); | |||||
(*args)["prop"] = json::String::make(to_string(data.prop)); | |||||
(*args)["prop_desc"] = json::String::make(data.prop_desc); | |||||
new_host_event("TensorGetProp", 'X').dur(0).args(args); | |||||
} else if constexpr (std::is_same_v<TEvent, TensorNotifyPropEvent>) { | |||||
// TODO | |||||
} else if constexpr (std::is_same_v<TEvent, TensorWaitPropEvent>) { | |||||
auto args = json::Object::make(); | |||||
(*args)["id"] = json::Number::make(data.tensor_id); | |||||
(*args)["prop"] = json::String::make(to_string(data.prop)); | |||||
(*args)["prop_desc"] = json::String::make(data.prop_desc); | |||||
new_host_event("TensorWaitProp", 'B').args(args); | |||||
} else if constexpr (std::is_same_v<TEvent, TensorWaitPropFinishEvent>) { | |||||
auto args = json::Object::make(); | |||||
(*args)["id"] = json::Number::make(data.tensor_id); | |||||
(*args)["prop"] = json::String::make(to_string(data.prop)); | |||||
(*args)["prop_desc"] = json::String::make(data.prop_desc); | |||||
new_host_event("TensorWaitProp", 'E').args(args); | |||||
} else if constexpr (std::is_same_v<TEvent, SyncStartEvent>) { | |||||
new_host_event("SyncEvent", 'B'); | |||||
} else if constexpr (std::is_same_v<TEvent, SyncFinishEvent>) { | |||||
new_host_event("SyncEvent", 'E'); | |||||
} else if constexpr (std::is_same_v<TEvent, ChannelBeginScope>) { | |||||
new_host_event(data.name, 'B'); | |||||
} else if constexpr (std::is_same_v<TEvent, ChannelEndScope>) { | |||||
new_host_event(data.name, 'E'); | |||||
} else if constexpr (std::is_same_v<TEvent, WorkerBeginScope>) { | |||||
new_host_event(data.name, 'B'); | |||||
} else if constexpr (std::is_same_v<TEvent, WorkerEndScope>) { | |||||
new_host_event(data.name, 'E'); | |||||
} else if constexpr (std::is_same_v<TEvent, DeviceBeginScope>) { | |||||
new_device_event(data.name, 'B'); | |||||
} else if constexpr (std::is_same_v<TEvent, DeviceEndScope>) { | |||||
new_device_event(data.name, 'E'); | |||||
} else { | |||||
static_assert(!std::is_same_v<TEvent, TEvent>); | |||||
} | |||||
} | |||||
}; | |||||
}; | |||||
} | |||||
void InterpreterProfiler::dump_data( | |||||
std::string basename, | |||||
std::string format, | |||||
InterpreterProfiler::Data profile_data, | |||||
const InterpreterProfiler::Option& option, | |||||
std::function<std::string(std::thread::id)> host_map) { | |||||
InterpreterProfilerDumpChromeTimelineContext{ | |||||
basename, format, profile_data, option, host_map | |||||
}.process(); | |||||
} | |||||
} |
@@ -0,0 +1,97 @@ | |||||
/** | |||||
* \file imperative/src/impl/interpreter/profiler.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megbrain/imperative/profiler.h" | |||||
#include "./commands.h" | |||||
#include "./events.h" | |||||
#include "./option_manager.h" | |||||
namespace mgb::imperative::interpreter::intl { | |||||
/*!
 * Profiler instantiated with the interpreter's event types, plus the
 * options and topic -> event-mask mapping that go with it.
 */
class InterpreterProfiler: public Profiler<
        CommandEnqueueEvent, CommandExecuteEvent, CommandFinishEvent,
        HostOpExecuteEvent, HostOpFinishEvent,
        DeviceOpExecuteEvent, DeviceOpFinishEvent,
        TensorDeclareEvent, TensorProduceEvent, TensorEraseEvent,
        TensorGetPropEvent, TensorWaitPropEvent, TensorNotifyPropEvent, TensorWaitPropFinishEvent,
        SyncStartEvent, SyncFinishEvent,
        ChannelBeginScope, ChannelEndScope,
        WorkerBeginScope, WorkerEndScope,
        DeviceBeginScope, DeviceEndScope> {
    /*22 events now. Enum code may be a better solution*/
public:
    //! Bit flags selecting groups of events; values may be OR-ed together.
    enum Topic {
        Command         = 0b000001,
        Operator        = 0b000010,
        TensorLifetime  = 0b000100,
        TensorProp      = 0b001000,
        Sync            = 0b010000,
        Scope           = 0b100000,
    };

    struct Option {
        // Defaults added so that a default-constructed Option (such as
        // m_option before set_option() is called) never exposes
        // indeterminate values through get_option().
        Topic topic{};
        bool align_time = false;
        bool show_operator_name = false;

        /*!
         * Build an Option from a string -> int dictionary.
         *
         * Requires the keys "topic", "align_time" and
         * "show_operator_name"; a missing key makes dict.at() throw
         * std::out_of_range.
         */
        static Option from_dict(std::unordered_map<std::string, int> dict) {
            Option option;
            option.topic = Topic(dict.at("topic"));
            option.align_time = bool(dict.at("align_time"));
            option.show_operator_name = bool(dict.at("show_operator_name"));
            return option;
        }
    };

    Option get_option() const {
        return m_option;
    }

    void set_option(const Option& option) {
        m_option = option;
    }

    static void dump_data(std::string basename, std::string format, InterpreterProfiler::Data profile_data, const Option& option, std::function<std::string(std::thread::id)> host_map);

    //! Translate a (possibly combined) Topic value into the mask of the
    //! event types that belong to the selected topics.
    static Mask topic_to_mask(Topic topic) {
        // Value-initialised so the |= below starts from an empty mask
        // whatever the underlying Mask type is (Mask is declared by the
        // Profiler base, outside this file).
        Mask result{};
        if (topic & Command) {
            result |= mask_of<CommandEnqueueEvent, CommandExecuteEvent, CommandFinishEvent>();
        }
        if (topic & Operator) {
            result |= mask_of<HostOpExecuteEvent, HostOpFinishEvent>();
            result |= mask_of<DeviceOpExecuteEvent, DeviceOpFinishEvent>();
        }
        if (topic & TensorLifetime) {
            result |= mask_of<TensorDeclareEvent, TensorProduceEvent, TensorEraseEvent>();
        }
        if (topic & TensorProp) {
            result |= mask_of<TensorGetPropEvent, TensorWaitPropEvent, TensorNotifyPropEvent, TensorWaitPropFinishEvent>();
        }
        if (topic & Sync) {
            result |= mask_of<SyncStartEvent, SyncFinishEvent>();
        }
        if (topic & Scope) {
            result |= mask_of<ChannelBeginScope, ChannelEndScope, WorkerBeginScope, WorkerEndScope>();
            result |= mask_of<DeviceBeginScope, DeviceEndScope>();
        }
        return result;
    }

private:
    Option m_option;
};
} |
@@ -0,0 +1,135 @@ | |||||
/** | |||||
* \file imperative/src/impl/interpreter/tensor_info.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megbrain/imperative/physical_tensor.h" | |||||
#include "megbrain/imperative/op_def.h" | |||||
#include "megbrain/imperative/utils/to_string.h" | |||||
namespace mgb::imperative { | |||||
namespace interpreter::intl { | |||||
//! Value stored in TensorInfo::evict_type; NONE is the default.
//! Enumerators are numbered consecutively: NONE = 0, SWAP = 1, DROP = 2.
enum EvictType {
    NONE = 0,
    SWAP,
    DROP,
};
struct TensorInfo; | |||||
using TensorInfoPtr = std::shared_ptr<TensorInfo>; | |||||
struct TensorInfo { | |||||
enum Prop { | |||||
Device, Shape, DType, DevValue, HostValue | |||||
}; | |||||
uint64_t id; | |||||
TensorPtr ptr; | |||||
LogicalTensorDesc desc; | |||||
// FIXME: broken by drop | |||||
bool value_fetched = false; | |||||
bool invalid = false; | |||||
bool allow_delete = false; | |||||
EvictType evict_type = NONE; | |||||
HostTensorND h_value; | |||||
// reserved for auto drop | |||||
size_t pinned = 0; | |||||
size_t recompute_times = 0; | |||||
struct ComputePath { | |||||
std::shared_ptr<OpDef> op; | |||||
SmallVector<TensorInfo*> inputs; | |||||
SmallVector<TensorInfo*> unique_inputs; | |||||
SmallVector<TensorInfo*> outputs; | |||||
size_t ref_cnt() { | |||||
return outputs.size() - std::count(outputs.begin(), outputs.end(), nullptr); | |||||
} | |||||
static ComputePath* make(std::shared_ptr<OpDef> op, SmallVector<TensorInfo*> inputs, SmallVector<TensorInfo*> outputs) { | |||||
auto* path = new TensorInfo::ComputePath(); | |||||
path->op = op; | |||||
path->inputs = inputs; | |||||
path->outputs = outputs; | |||||
// dedup | |||||
SmallVector<TensorInfo*> unique_inputs = inputs; | |||||
std::sort(unique_inputs.begin(), unique_inputs.end()); | |||||
unique_inputs.erase(std::unique(unique_inputs.begin(), unique_inputs.end()), unique_inputs.end()); | |||||
path->unique_inputs = unique_inputs; | |||||
// attach users | |||||
for (auto input: unique_inputs) { | |||||
input->users.push_back(path); | |||||
} | |||||
// attach producer | |||||
for (auto output: outputs) { | |||||
output->producer = path; | |||||
} | |||||
return path; | |||||
} | |||||
}* producer = nullptr; | |||||
void pin() { | |||||
++pinned; | |||||
} | |||||
void unpin() { | |||||
--pinned; | |||||
} | |||||
void detach_producer() { | |||||
if (!producer) { | |||||
return; | |||||
} | |||||
auto output = std::find(producer->outputs.begin(), producer->outputs.end(), this); | |||||
mgb_assert(output != producer->outputs.end()); | |||||
*output = nullptr; | |||||
if (producer->ref_cnt() == 0) { | |||||
for (auto* input: producer->unique_inputs) { | |||||
input->users.erase(std::find(input->users.begin(), input->users.end(), producer)); | |||||
} | |||||
delete producer; | |||||
} | |||||
producer = nullptr; | |||||
} | |||||
SmallVector<ComputePath*> users; | |||||
}; | |||||
} | |||||
template <> | |||||
struct ToStringTrait<interpreter::intl::TensorInfo::Prop>{ | |||||
using TensorInfo = interpreter::intl::TensorInfo; | |||||
std::string operator()(TensorInfo::Prop prop) const { | |||||
switch(prop) { | |||||
case TensorInfo::DType: | |||||
return "dtype"; | |||||
case TensorInfo::DevValue: | |||||
return "dev_value"; | |||||
case TensorInfo::Device: | |||||
return "device"; | |||||
case TensorInfo::HostValue: | |||||
return "host_value"; | |||||
case TensorInfo::Shape: | |||||
return "shape"; | |||||
default: | |||||
return "unknown"; | |||||
} | |||||
} | |||||
}; | |||||
} |
@@ -1,351 +0,0 @@ | |||||
/** | |||||
* \file imperative/src/impl/interpreter_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include <deque> | |||||
#include <future> | |||||
#include <list> | |||||
#include <unordered_set> | |||||
#include <variant> | |||||
#include "megbrain/utils/mempool.h" | |||||
#include "megbrain/imperative/interpreter.h" | |||||
namespace mgb::imperative::interpreter::intl { | |||||
using Handle = Interpreter::Handle; | |||||
struct InterpreterImpl : Interpreter { | |||||
std::unique_ptr<Channel> create_channel() override; | |||||
}; | |||||
//! Value stored in TensorInfo::evict_type; NONE is the default.
//! Enumerators are numbered consecutively: NONE = 0, SWAP = 1, DROP = 2.
enum EvictType {
    NONE = 0,
    SWAP,
    DROP,
};
struct TensorInfo; | |||||
using TensorInfoPtr = std::shared_ptr<TensorInfo>; | |||||
struct TensorInfo { | |||||
TensorPtr ptr; | |||||
LogicalTensorDesc desc; | |||||
// FIXME: broken by drop | |||||
bool value_fetched = false; | |||||
bool invalid = false; | |||||
EvictType evict_type = NONE; | |||||
HostTensorND h_value; | |||||
// reserved for auto drop | |||||
size_t pinned = 0; | |||||
size_t recompute_times = 0; | |||||
struct ComputePath { | |||||
std::shared_ptr<OpDef> op; | |||||
SmallVector<TensorInfo*> inputs; | |||||
SmallVector<TensorInfo*> unique_inputs; | |||||
SmallVector<TensorInfo*> outputs; | |||||
size_t ref_cnt() { | |||||
return outputs.size() - std::count(outputs.begin(), outputs.end(), nullptr); | |||||
} | |||||
static ComputePath* make(std::shared_ptr<OpDef> op, SmallVector<TensorInfo*> inputs, SmallVector<TensorInfo*> outputs) { | |||||
auto* path = new TensorInfo::ComputePath(); | |||||
path->op = op; | |||||
path->inputs = inputs; | |||||
path->outputs = outputs; | |||||
// dedup | |||||
SmallVector<TensorInfo*> unique_inputs = inputs; | |||||
std::sort(unique_inputs.begin(), unique_inputs.end()); | |||||
unique_inputs.erase(std::unique(unique_inputs.begin(), unique_inputs.end()), unique_inputs.end()); | |||||
path->unique_inputs = unique_inputs; | |||||
// attach users | |||||
for (auto input: unique_inputs) { | |||||
input->users.push_back(path); | |||||
} | |||||
// attach producer | |||||
for (auto output: outputs) { | |||||
output->producer = path; | |||||
} | |||||
return path; | |||||
} | |||||
}* producer = nullptr; | |||||
void pin() { | |||||
++pinned; | |||||
} | |||||
void unpin() { | |||||
--pinned; | |||||
} | |||||
void detach_producer() { | |||||
if (!producer) { | |||||
return; | |||||
} | |||||
auto output = std::find(producer->outputs.begin(), producer->outputs.end(), this); | |||||
mgb_assert(output != producer->outputs.end()); | |||||
*output = nullptr; | |||||
if (producer->ref_cnt() == 0) { | |||||
for (auto* input: producer->unique_inputs) { | |||||
input->users.erase(std::find(input->users.begin(), input->users.end(), producer)); | |||||
} | |||||
delete producer; | |||||
} | |||||
producer = nullptr; | |||||
} | |||||
SmallVector<ComputePath*> users; | |||||
}; | |||||
struct Put { | |||||
TensorInfo* dest; | |||||
HostTensorND value; | |||||
bool no_cache = false; | |||||
std::string to_string() const { return ssprintf("Command: Put %p", dest); } | |||||
}; | |||||
struct ApplyOp { | |||||
std::shared_ptr<OpDef> op; | |||||
SmallVector<TensorInfo*> inputs; | |||||
SmallVector<TensorInfo*> outputs; | |||||
SmallVector<TensorInfo*> dels; | |||||
std::string to_string() const { | |||||
std::string builder{"Command: ApplyOp {"}; | |||||
builder += "inputs ["; | |||||
for (auto* input : inputs) { | |||||
builder += ssprintf("%p, ", input); | |||||
} | |||||
builder += "], outputs ["; | |||||
for (auto* output : outputs) { | |||||
builder += ssprintf("%p, ", output); | |||||
} | |||||
builder += "], dels ["; | |||||
for (auto* del : dels) { | |||||
builder += ssprintf("%p, ", del); | |||||
} | |||||
builder += "]"; | |||||
return builder; | |||||
} | |||||
}; | |||||
struct Del { | |||||
TensorInfo* dest; | |||||
std::string to_string() const { return ssprintf("Command: Del %p", dest); } | |||||
}; | |||||
struct GetValue { | |||||
TensorInfo* dest; | |||||
std::string to_string() const { | |||||
return ssprintf("Command: GetValue %p", dest); | |||||
} | |||||
}; | |||||
struct SwapIn { | |||||
TensorInfo* dest; | |||||
std::string to_string() const { | |||||
return ssprintf("Command: SwapIn %p", dest); | |||||
} | |||||
}; | |||||
struct SwapOut { | |||||
TensorInfo* dest; | |||||
std::string to_string() const { | |||||
return ssprintf("Command: SwapOut %p", dest); | |||||
} | |||||
}; | |||||
struct Drop { | |||||
TensorInfo* dest; | |||||
std::string to_string() const { | |||||
return ssprintf("Command: Drop %p", dest); | |||||
} | |||||
}; | |||||
struct Move { | |||||
TensorInfo* src; | |||||
TensorInfo* dest; | |||||
std::string to_string() const { | |||||
return ssprintf("Command: Move %s to %s", | |||||
src->desc.layout.to_string().c_str(), | |||||
dest->desc.layout.to_string().c_str()); | |||||
} | |||||
}; | |||||
// Flush command; `dest` is optional and defaults to nullptr. | |||||
struct Flush { | |||||
    TensorInfo* dest = nullptr; | |||||
    // Human-readable form for logging. | |||||
    std::string to_string() const { | |||||
        auto text = ssprintf("Command: Flush %p", dest); | |||||
        return text; | |||||
    } | |||||
}; | |||||
// Command that carries no payload. | |||||
struct Nop { | |||||
    // Human-readable form for logging. | |||||
    std::string to_string() const { | |||||
        return std::string{"Command: Nop"}; | |||||
    } | |||||
}; | |||||
// Closed set of command kinds: a std::variant over the command structs above. | |||||
// Every alternative provides a to_string() member. | |||||
using Command = std::variant<Put, | |||||
ApplyOp, | |||||
Del, | |||||
GetValue, | |||||
SwapIn, | |||||
SwapOut, | |||||
Drop, | |||||
Move, | |||||
Flush, | |||||
Nop>; | |||||
// Implementation of Interpreter::Channel. | |||||
// The public methods override the Channel interface; the private part declares | |||||
// TensorInfo bookkeeping, a worker queue (m_worker) and a command buffer | |||||
// (m_buffer). Both helpers are constructed with a pointer back to this object. | |||||
struct ChannelImpl : Interpreter::Channel { | |||||
ChannelImpl() : m_worker(this), m_buffer(this) {} | |||||
~ChannelImpl() override; | |||||
// Tensor creation / lifetime. | |||||
Handle put(const HostTensorND& value, bool no_cache) override; | |||||
Handle put(const DeviceTensorND& value) override; | |||||
void del(Handle) override; | |||||
void swap_in(Handle) override; | |||||
void swap_out(Handle) override; | |||||
void drop(Handle) override; | |||||
// Operator application: takes input handles, returns output handles. | |||||
SmallVector<Handle> apply_op( | |||||
std::shared_ptr<OpDef> op, | |||||
const SmallVector<Handle>& inputs) override; | |||||
// Queries on a handle. | |||||
HostTensorND get_value(Handle) override; | |||||
TensorShape get_shape(Handle) override; | |||||
DType get_dtype(Handle) override; | |||||
CompNode get_device(Handle) override; | |||||
DeviceTensorND get_dev_tensor(Handle) override; | |||||
void sync() override; | |||||
void close() override; | |||||
// Configuration knobs. | |||||
void set_swap_flag(bool) override; | |||||
void set_drop_flag(bool) override; | |||||
void set_buffer_length(int) override; | |||||
void config_async_level(int level) override; | |||||
int get_async_level() override; | |||||
private: | |||||
TensorInfo* alloc(); | |||||
void free(TensorInfo*); | |||||
void detach_users(TensorInfo*); | |||||
// Entry point invoked by WorkQueue::process_one_task for each command. | |||||
void process_one_task(Command&); | |||||
void check_worker_exc_unsafe(); | |||||
void produce_tensor(TensorInfo* dest, TensorPtr ptr); | |||||
void release_tensor(TensorInfo* dest); | |||||
void regenerate(TensorInfo* dest); | |||||
void recompute(TensorInfo::ComputePath* path); | |||||
// The two dispatch_* methods share one signature; results are written | |||||
// through the `outputs` pointer. | |||||
void dispatch_default_cpu( | |||||
std::shared_ptr<OpDef> op, | |||||
const SmallVector<TensorInfo*>& input_infos, | |||||
const SmallVector<LogicalTensorDesc>& input_descs, | |||||
SmallVector<Handle>* outputs); | |||||
void dispatch_kernel( | |||||
std::shared_ptr<OpDef> op, | |||||
const SmallVector<TensorInfo*>& input_infos, | |||||
const SmallVector<LogicalTensorDesc>& input_descs, | |||||
SmallVector<Handle>* outputs); | |||||
std::mutex m_mutex; | |||||
std::condition_variable m_cv; | |||||
MemPool<TensorInfo> m_pool; | |||||
std::unordered_set<Handle> m_valid_handle; | |||||
TensorInfo* m_waitee = nullptr; | |||||
std::exception_ptr m_worker_exc; | |||||
size_t m_enable_evict = 0; | |||||
// Single-consumer async queue whose tasks are forwarded to the owning channel. | |||||
struct WorkQueue : AsyncQueueSC<Command, WorkQueue> { | |||||
// set max_spin=0 to prevent Queue fetch task in busy wait manner. | |||||
// this won't affect throughput when python interpreter is sending enough task, | |||||
// but will significantly save CPU time when waiting for task, e.g. wait for data input | |||||
WorkQueue(ChannelImpl* owner) | |||||
: AsyncQueueSC<Command, WorkQueue>(0), m_owner(owner) { | |||||
// NOTE(review): this runs on the constructing thread, not the worker; | |||||
// the worker thread is named in on_async_queue_worker_thread_start(). | |||||
sys::set_thread_name("interpreter"); | |||||
} | |||||
void process_one_task(Command& cmd) { | |||||
m_owner->process_one_task(cmd); | |||||
} | |||||
void on_async_queue_worker_thread_start() override { | |||||
sys::set_thread_name("worker"); | |||||
} | |||||
private: | |||||
ChannelImpl* m_owner; | |||||
} m_worker; | |||||
/** | |||||
* Buffers a window of commands so that later commands can be fused into earlier ones. | |||||
* example: | |||||
* --------------------------------------------------------------------- | |||||
* | ..., Apply{in: (i0, i1), out: (o0, o1)}, ... + Del{i0} + Del{i1} | | |||||
* --------------------------------------------------------------------- | |||||
* | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0)}, ... + Del{i1} | | |||||
* --------------------------------------------------------------------- | |||||
* | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0, i1)}, ... | | |||||
* --------------------------------------------------------------------- | |||||
* Then the fused Apply may be invoked in place; see ChannelImpl::process_one_task | |||||
*/ | |||||
struct CommandBuffer { | |||||
// Capacity defaults to 3 and can be overridden by the environment variable | |||||
// MEGENGINE_COMMAND_BUFFER_LENGTH (parsed with atoi, then range-checked | |||||
// by set_capacity). | |||||
CommandBuffer(ChannelImpl* owner) : m_owner(owner) { | |||||
int capacity = 3; | |||||
if(const char* capacity_str = MGB_GETENV("MEGENGINE_COMMAND_BUFFER_LENGTH")) { | |||||
capacity = atoi(capacity_str); | |||||
} | |||||
set_capacity(capacity); | |||||
} | |||||
void enqueue(Command cmd); | |||||
bool empty() const { | |||||
return m_commands.empty(); | |||||
} | |||||
// Accepts 0 <= capacity < 100; anything else trips the assertion. | |||||
void set_capacity(int capacity) { | |||||
mgb_assert(capacity >= 0 && capacity < 100, "invalid command buffer length"); | |||||
m_capacity = capacity; | |||||
} | |||||
private: | |||||
ChannelImpl* m_owner; | |||||
size_t m_capacity; | |||||
std::deque<Command> m_commands; | |||||
// Position inside m_commands (a deque iterator). | |||||
using Handle = decltype(m_commands)::iterator; | |||||
// [begin, end) | |||||
using Range = std::array<Handle, 2>; | |||||
// Launch commands in range [m_commands.begin(), pos) | |||||
void flush(Handle pos); | |||||
// Select flush position for incoming cmd | |||||
Handle flush_pos_for(const Command& cmd); | |||||
// Fuse del command into suitable ApplyOp | |||||
bool fuse_del(const Del& cmd); | |||||
// Returns the handle of the last command within range that uses dest. If dest is not used, returns range[1] | |||||
Handle find_last_usage(TensorInfo* dest, Range range); | |||||
// Returns the produce position of dest. If not found, returns range[1] | |||||
Handle find_produce(TensorInfo* dest, Range range); | |||||
} m_buffer; | |||||
//! config whether raise error exactly when invoking op. | |||||
//! level 2: both device and user side errors are async; | |||||
//! level 1: user side errors are sync; | |||||
//! level 0: both sync. | |||||
int m_async_level = 2; | |||||
int m_max_recompute_time = 1; | |||||
}; | |||||
} // namespace mgb::imperative::interpreter::intl |
@@ -70,6 +70,26 @@ BackwardGraphResult OpDef::make_backward_graph( | |||||
return def.trait()->make_backward_graph(def, inputs, input_requires_grad, output_has_grad); | return def.trait()->make_backward_graph(def, inputs, input_requires_grad, output_has_grad); | ||||
} | } | ||||
// Returns the (name, value) property pairs of `def` by delegating to its trait. | |||||
std::vector<std::pair<const char*, std::string>> OpDef::props( | |||||
        const OpDef& def) { | |||||
    const OpTrait* op_trait = def.trait(); | |||||
    return op_trait->props(def); | |||||
} | |||||
const char* OpDef::name() const { | |||||
return trait()->name; | |||||
} | |||||
std::string OpDef::to_string() const { | |||||
std::string builder = "{"; | |||||
for (auto&& [name, value]: props(*this)) { | |||||
builder += name; | |||||
builder += ": "; | |||||
builder += value; | |||||
builder += ","; | |||||
} | |||||
return builder + "}"; | |||||
} | |||||
size_t OpDef::hash() const { | size_t OpDef::hash() const { | ||||
return trait()->hash(*this); | return trait()->hash(*this); | ||||
} | } | ||||
@@ -72,6 +72,7 @@ using InferOutputAttrsFallible = detail::OpMeth< | |||||
decltype(OpDef::infer_output_attrs_fallible)>; | decltype(OpDef::infer_output_attrs_fallible)>; | ||||
using GradMaker = detail::OpMeth< | using GradMaker = detail::OpMeth< | ||||
decltype(OpDef::make_backward_graph)>; | decltype(OpDef::make_backward_graph)>; | ||||
using Props = detail::OpMeth<decltype(OpDef::props)>; | |||||
using HashFunc = detail::OpMeth<size_t(const OpDef&)>; | using HashFunc = detail::OpMeth<size_t(const OpDef&)>; | ||||
using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>; | using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>; | ||||
@@ -84,6 +85,7 @@ struct OpTrait { | |||||
ApplyOnVarNode apply_on_var_node; | ApplyOnVarNode apply_on_var_node; | ||||
InferOutputAttrsFallible infer_output_attrs_fallible; | InferOutputAttrsFallible infer_output_attrs_fallible; | ||||
GradMaker make_backward_graph; | GradMaker make_backward_graph; | ||||
Props props; | |||||
HashFunc hash; | HashFunc hash; | ||||
IsSame is_same_st; | IsSame is_same_st; | ||||
OpTrait(const char* name); | OpTrait(const char* name); | ||||
@@ -100,6 +102,7 @@ struct OpTrait { | |||||
cb(apply_on_var_node) \ | cb(apply_on_var_node) \ | ||||
cb(infer_output_attrs_fallible) \ | cb(infer_output_attrs_fallible) \ | ||||
cb(make_backward_graph) \ | cb(make_backward_graph) \ | ||||
cb(props) \ | |||||
cb(hash) \ | cb(hash) \ | ||||
cb(is_same_st) | cb(is_same_st) | ||||
@@ -148,9 +148,15 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_tensor_attrs( | |||||
.graph().infer_attrs(inputs); | .graph().infer_attrs(inputs); | ||||
} | } | ||||
// BackwardGraph exposes no properties: always returns an empty list. | |||||
std::vector<std::pair<const char*, std::string>> props( | |||||
        const OpDef& backward_graph) { | |||||
    std::vector<std::pair<const char*, std::string>> no_props; | |||||
    return no_props; | |||||
} | |||||
OP_TRAIT_REG(BackwardGraph, BackwardGraph) | OP_TRAIT_REG(BackwardGraph, BackwardGraph) | ||||
.apply_on_physical_tensor(backward_impl) | .apply_on_physical_tensor(backward_impl) | ||||
.infer_output_attrs_fallible(infer_tensor_attrs) | .infer_output_attrs_fallible(infer_tensor_attrs) | ||||
.props(props) | |||||
.fallback(); | .fallback(); | ||||
} // anonymous namespace | } // anonymous namespace | ||||
@@ -95,9 +95,14 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) { | |||||
return OprAttr::make(registry->name, std::move(ctx.m_param), opr->config()); | return OprAttr::make(registry->name, std::move(ctx.m_param), opr->config()); | ||||
} | } | ||||
std::vector<std::pair<const char*, std::string>> props(const OpDef& def) { | |||||
return {}; | |||||
} | |||||
OP_TRAIT_REG(OprAttr, OprAttr) | OP_TRAIT_REG(OprAttr, OprAttr) | ||||
.make_from_op_node(make_from_op_node) | .make_from_op_node(make_from_op_node) | ||||
.apply_on_var_node(apply_on_var_node) | .apply_on_var_node(apply_on_var_node) | ||||
.props(props) | |||||
.fallback(); | .fallback(); | ||||
} // anonymous namespace | } // anonymous namespace | ||||
@@ -11,12 +11,14 @@ | |||||
#include "megbrain/imperative/profiler.h" | #include "megbrain/imperative/profiler.h" | ||||
#include "./function_hook.h" | |||||
#include <chrono> | |||||
#include "megbrain/imperative/ops/opr_attr.h" | #include "megbrain/imperative/ops/opr_attr.h" | ||||
#include "megbrain/imperative/physical_tensor.h" | #include "megbrain/imperative/physical_tensor.h" | ||||
#include "megbrain/plugin/opr_footprint.h" | #include "megbrain/plugin/opr_footprint.h" | ||||
#include "./function_hook.h" | |||||
#include "./event_pool.h" | #include "./event_pool.h" | ||||
#include "./op_trait.h" | #include "./op_trait.h" | ||||
@@ -25,200 +27,42 @@ namespace imperative { | |||||
namespace { | namespace { | ||||
// Collects the comp nodes of all inputs plus the comp nodes of the outputs | |||||
// reported by def.infer_output_attrs_fallible for those inputs. | |||||
CompNode::UnorderedSet collect_comp_nodes( | |||||
        const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||||
    CompNode::UnorderedSet devices; | |||||
    SmallVector<LogicalTensorDesc> input_descs; | |||||
    for (const auto& tensor : inputs) { | |||||
        devices.insert(tensor->comp_node()); | |||||
        // Inference is fed layout + comp node only; the value slot stays empty. | |||||
        input_descs.push_back({tensor->layout(), tensor->comp_node(), {}}); | |||||
    } | |||||
    // Only the descriptor list (element 0 of the result tuple) is needed. | |||||
    SmallVector<LogicalTensorDesc> output_descs = | |||||
            std::get<0>(def.infer_output_attrs_fallible(def, input_descs)); | |||||
    for (const auto& desc : output_descs) { | |||||
        devices.insert(desc.comp_node); | |||||
    } | |||||
    return devices; | |||||
} | |||||
DeviceTimer::SharedEvent alloc_recorded_event(CompNode device) { | DeviceTimer::SharedEvent alloc_recorded_event(CompNode device) { | ||||
auto event = EventPool::with_timer().alloc_shared(device); | auto event = EventPool::with_timer().alloc_shared(device); | ||||
event->record(); | event->record(); | ||||
return event; | return event; | ||||
} | } | ||||
OprFootprint footprint{}; | |||||
} // namespace | } // namespace | ||||
void DeviceTimer::reset(thin_function<double()> host_timer) { | |||||
CompNode::foreach ([this, host_timer](CompNode device) { | |||||
m_base_event_table[device] = {alloc_recorded_event(device), host_timer()}; | |||||
}); | |||||
m_host_timer = host_timer; | |||||
DeviceTimer::SharedEvent DeviceTimer::get_device_time(CompNode device) { | |||||
return alloc_recorded_event(device); | |||||
} | } | ||||
thin_function<double()> DeviceTimer::get_device_time(CompNode device) { | |||||
auto event = EventPool::with_timer().alloc_shared(device); | |||||
event->record(); | |||||
if(m_base_event_table.count(device) == 0) { | |||||
m_base_event_table[device] = {alloc_recorded_event(device), m_host_timer()}; | |||||
SmallVector<DeviceTimer::SharedEvent> DeviceTimer::get_all(SmallVector<CompNode> device_list) { | |||||
SmallVector<DeviceTimer::SharedEvent> results; | |||||
for (auto&& device: device_list) { | |||||
results.push_back(alloc_recorded_event(device)); | |||||
} | } | ||||
auto base = m_base_event_table[device]; | |||||
return [base, event] { | |||||
auto [base_event, host_time] = base; | |||||
// TODO: sync once for each compnode | |||||
event->host_wait(); | |||||
return base_event->elapsed_time_until(*event) * 1000 + host_time; | |||||
}; | |||||
return results; | |||||
} | } | ||||
void DeviceTimer::clear() { | |||||
m_base_event_table.clear(); | |||||
double HostTimer::get_msecs() { | |||||
using namespace std::chrono; | |||||
auto finish = steady_clock::now(); | |||||
auto duration = duration_cast<microseconds>(finish - m_start); | |||||
return (double)duration.count() / 1e3; | |||||
} | } | ||||
// Returns the id associated with `tensor`, assigning a fresh one when the raw | |||||
// pointer is unknown or when the remembered weak_ptr no longer refers to the | |||||
// same tensor object (the address has been reused). | |||||
size_t TensorRecorder::record_tensor(const TensorPtr& tensor) { | |||||
    auto found = m_tensor_map.find(tensor.get()); | |||||
    if (found == m_tensor_map.end()) { | |||||
        auto fresh_id = m_next_id++; | |||||
        m_tensor_map.insert( | |||||
                {tensor.get(), {std::weak_ptr<Tensor>{tensor}, fresh_id}}); | |||||
        return fresh_id; | |||||
    } | |||||
    auto& [owner, known_id] = found->second; | |||||
    if (owner.lock() != tensor) { | |||||
        owner = tensor; | |||||
        known_id = m_next_id++; | |||||
    } | |||||
    return known_id; | |||||
} | |||||
// Forgets every recorded tensor and restarts id numbering from zero. | |||||
void TensorRecorder::clear() { | |||||
    m_tensor_map.clear(); | |||||
    m_next_id = 0; | |||||
} | |||||
// Returns the collected profile after replacing every device begin/end time | |||||
// closure with a closure that returns its already-evaluated value. | |||||
Profile& Profiler::get_profile() { | |||||
    for (auto& entry : m_profile) { | |||||
        for (auto& item : entry.device_list) { | |||||
            auto& begin_fn = std::get<1>(item); | |||||
            auto& end_fn = std::get<2>(item); | |||||
            // Evaluate first, then overwrite the closure with the cached value. | |||||
            begin_fn = [value = begin_fn()] { return value; }; | |||||
            end_fn = [value = end_fn()] { return value; }; | |||||
        } | |||||
    } | |||||
    return m_profile; | |||||
} | |||||
// Starts profiling by installing hooks on every registered OpTrait. | |||||
// | |||||
// - apply_on_physical_tensor is wrapped so that each call produces one | |||||
//   ProfileEntry (inputs, outputs, host and device timings). | |||||
// - when PROFILE_FOOTPRINT is set in `flags`, apply_on_var_node is wrapped as | |||||
//   well so that memory/computation footprints are attached to the entry that | |||||
//   is currently on top of m_entry_stack. | |||||
// The hooks are kept alive by m_hooker_list. | |||||
void Profiler::start(uint32_t flags) { | |||||
    m_host_timer.reset(); | |||||
    m_device_timer.reset([&] { return m_host_timer.get_msecs(); }); | |||||
    OpTrait::for_each_trait([this, flags](OpTrait& trait) { | |||||
        auto hook_apply_on_physical_tensor = | |||||
                make_shared_hook(&trait.apply_on_physical_tensor); | |||||
        auto hook_apply_on_var_node = | |||||
                make_shared_hook(&trait.apply_on_var_node); | |||||
        hook_apply_on_physical_tensor->apply_hook([this, flags] | |||||
                (auto&& apply, const OpDef& def, SmallVector<TensorPtr> inputs) { | |||||
            auto shape2vector = [](const TensorShape& shape) { | |||||
                std::vector<size_t> vector_shape; | |||||
                for (size_t i = 0; i < shape.ndim; i++) { | |||||
                    vector_shape.push_back(shape[i]); | |||||
                } | |||||
                return vector_shape; | |||||
            }; | |||||
            ProfileEntry entry; | |||||
            entry.id = m_entry_count++; | |||||
            // TODO: assign parent | |||||
            entry.parent = 0; | |||||
            // Record apply context and save to m_profile | |||||
            entry.op = const_cast<OpDef&>(def).shared_from_this(); | |||||
            for (auto&& input : inputs) { | |||||
                entry.inputs.push_back({m_tensor_recorder.record_tensor(input), | |||||
                                        shape2vector(input->layout()), | |||||
                                        input->comp_node()}); | |||||
            } | |||||
            double host_begin = m_host_timer.get_msecs(); | |||||
            auto&& comp_nodes = collect_comp_nodes(def, inputs); | |||||
            for (auto&& comp_node : comp_nodes) { | |||||
                // The end-time closure is filled in after the real apply. | |||||
                entry.device_list.push_back( | |||||
                        {comp_node, | |||||
                         m_device_timer.get_device_time(comp_node), | |||||
                         {}}); | |||||
            } | |||||
            if (flags & PROFILE_FOOTPRINT) { | |||||
                MGB_LOCK_GUARD(m_lock); | |||||
                m_entry_stack.push({&def, &entry, std::this_thread::get_id()}); | |||||
            } | |||||
            // Do real apply | |||||
            auto outputs = apply(def, inputs); | |||||
            for (auto& [cn, dev_begin, dev_end] : entry.device_list) { | |||||
                MGB_MARK_USED_VAR(cn); | |||||
                MGB_MARK_USED_VAR(dev_begin); | |||||
                dev_end = m_device_timer.get_device_time(cn); | |||||
            } | |||||
            entry.host = {host_begin, m_host_timer.get_msecs()}; | |||||
            for (auto&& output : outputs) { | |||||
                entry.outputs.push_back( | |||||
                        {m_tensor_recorder.record_tensor(output), | |||||
                         shape2vector(output->layout()), output->comp_node()}); | |||||
            } | |||||
            if (flags & PROFILE_FOOTPRINT) { | |||||
                // Take the lock before inspecting the shared stack, so the | |||||
                // check and the pop observe the same state. | |||||
                MGB_LOCK_GUARD(m_lock); | |||||
                mgb_assert(std::get<1>(m_entry_stack.top()) == &entry); | |||||
                m_entry_stack.pop(); | |||||
            } | |||||
            m_profile.push_back(std::move(entry)); | |||||
            return outputs; | |||||
        }); | |||||
        if (flags & PROFILE_FOOTPRINT) { | |||||
            hook_apply_on_var_node->apply_hook( | |||||
                    [this](auto&& apply, const OpDef& def, | |||||
                           VarNodeArray inputs) -> VarNodeArray { | |||||
                        auto vars = apply(def, std::move(inputs)); | |||||
                        std::remove_reference_t<decltype(m_entry_stack.top())> | |||||
                                top; | |||||
                        { | |||||
                            MGB_LOCK_GUARD(m_lock); | |||||
                            if (m_entry_stack.empty()) { | |||||
                                return vars; | |||||
                            } | |||||
                            top = m_entry_stack.top(); | |||||
                        } | |||||
                        auto [current_op, current_entry, thread_id] = top; | |||||
                        // Only annotate the entry if it belongs to this op | |||||
                        // and was pushed by this thread. | |||||
                        if (current_op != &def || | |||||
                            thread_id != std::this_thread::get_id()) { | |||||
                            return vars; | |||||
                        } | |||||
                        auto&& footprint_result = | |||||
                                footprint.calc_footprint(vars[0]->owner_opr()); | |||||
                        current_entry->memory = footprint_result.memory; | |||||
                        current_entry->computation = | |||||
                                footprint_result.computation; | |||||
#if MGB_ENABLE_JSON | |||||
                        current_entry->param = footprint_result.param; | |||||
#endif | |||||
                        return vars; | |||||
                    }); | |||||
        } | |||||
        m_hooker_list.push_back(std::move(hook_apply_on_physical_tensor)); | |||||
        m_hooker_list.push_back(std::move(hook_apply_on_var_node)); | |||||
    }); | |||||
} | |||||
void Profiler::stop() { | |||||
m_hooker_list.clear(); | |||||
for (auto& entry : m_profile) { | |||||
entry.wait_device(); | |||||
} | |||||
double HostTimer::get_started_at() { | |||||
return m_started_at; | |||||
} | } | ||||
void Profiler::clear() { | |||||
mgb_assert(m_entry_stack.empty(), | |||||
"entry_stack should be empty after profile"); | |||||
mgb_assert(m_hooker_list.empty(), "hooks should be released"); | |||||
m_profile.clear(); | |||||
m_entry_count = 0; | |||||
m_device_timer.clear(); | |||||
m_tensor_recorder.clear(); | |||||
void HostTimer::reset() { | |||||
using namespace std::chrono; | |||||
m_start = steady_clock::now(); | |||||
auto now_us = duration_cast<microseconds>(std::chrono::system_clock::now().time_since_epoch()); | |||||
m_started_at = (double)(now_us.count()) / 1e3; | |||||
} | } | ||||
} // namespace imperative | } // namespace imperative | ||||
@@ -471,6 +471,7 @@ class ExecMiniGraph : public ProxyGraph::MiniGraph { | |||||
} | } | ||||
if (can_pop) { | if (can_pop) { | ||||
for (auto _ : comp_node_trackers) { | for (auto _ : comp_node_trackers) { | ||||
MGB_MARK_USED_VAR(_); | |||||
busy_oprs.pop_front(); | busy_oprs.pop_front(); | ||||
} | } | ||||
m_opr = busy_oprs.front().opr; | m_opr = busy_oprs.front().opr; | ||||
@@ -10,6 +10,7 @@ | |||||
*/ | */ | ||||
#include <atomic> | #include <atomic> | ||||
#include <any> | |||||
#include "megbrain/imperative/op_def.h" | #include "megbrain/imperative/op_def.h" | ||||
@@ -42,12 +43,15 @@ struct Interpreter { | |||||
virtual void sync() = 0; | virtual void sync() = 0; | ||||
virtual void close() = 0; | virtual void close() = 0; | ||||
virtual void set_swap_flag(bool) = 0; | |||||
virtual void set_drop_flag(bool) = 0; | |||||
virtual void set_buffer_length(int) = 0; | |||||
virtual void config_async_level(int level) = 0; | |||||
virtual int get_async_level() = 0; | |||||
virtual int get_option(std::string name) = 0; | |||||
virtual void set_option(std::string name, int value) = 0; | |||||
virtual void start_profile(std::unordered_map<std::string, int> option) = 0; | |||||
virtual void stop_profile(std::string basename, std::string format) = 0; | |||||
virtual void push_scope(std::string name) = 0; | |||||
virtual void pop_scope(std::string name) = 0; | |||||
}; | }; | ||||
virtual std::unique_ptr<Channel> create_channel() = 0; | virtual std::unique_ptr<Channel> create_channel() = 0; | ||||
@@ -13,6 +13,7 @@ | |||||
#include "megbrain/graph.h" | #include "megbrain/graph.h" | ||||
#include "megbrain/imperative/physical_tensor.h" | #include "megbrain/imperative/physical_tensor.h" | ||||
#include "megbrain/imperative/utils/to_string.h" | |||||
namespace mgb { | namespace mgb { | ||||
namespace imperative { | namespace imperative { | ||||
@@ -80,8 +81,15 @@ public: | |||||
const SmallVector<bool>& input_requires_grad, | const SmallVector<bool>& input_requires_grad, | ||||
const SmallVector<bool>& output_has_grad); | const SmallVector<bool>& output_has_grad); | ||||
static std::vector<std::pair<const char*, std::string>> props( | |||||
const OpDef& def); | |||||
const OpTrait* trait() const; | const OpTrait* trait() const; | ||||
const char* name() const; | |||||
std::string to_string() const; | |||||
virtual size_t hash() const; | virtual size_t hash() const; | ||||
virtual bool is_same_st(const Hashable&) const; | virtual bool is_same_st(const Hashable&) const; | ||||
@@ -96,6 +104,16 @@ public: | |||||
} | } | ||||
}; | }; | ||||
// String conversion for raw OpDef pointers: "nullptr" for a null pointer, | |||||
// otherwise the op's own to_string(). | |||||
template <> | |||||
struct ToStringTrait<OpDef*> { | |||||
    std::string operator()(OpDef* op) const { | |||||
        return op != nullptr ? op->to_string() : std::string{"nullptr"}; | |||||
    } | |||||
}; | |||||
} // namespace imperative | } // namespace imperative | ||||
} // namespace mgb | } // namespace mgb | ||||
@@ -11,10 +11,12 @@ | |||||
#pragma once | #pragma once | ||||
#include <any> | |||||
#include <optional> | #include <optional> | ||||
#include <stack> | |||||
#include <list> | |||||
#include <map> | |||||
#include <variant> | |||||
#include <fstream> | |||||
#include <chrono> | |||||
#include <bitset> | |||||
#include "megbrain/comp_node.h" | #include "megbrain/comp_node.h" | ||||
#include "megbrain/graph/event.h" | #include "megbrain/graph/event.h" | ||||
@@ -27,89 +29,298 @@ | |||||
namespace mgb { | namespace mgb { | ||||
namespace imperative { | namespace imperative { | ||||
using ProfileTensor = std::tuple<size_t, std::vector<size_t>, CompNode>; | |||||
// One profiled operator application. | |||||
struct ProfileEntry { | |||||
    using TimeClosure = std::function<double()>; | |||||
    size_t id; | |||||
    size_t parent; | |||||
    std::shared_ptr<OpDef> op; | |||||
    // (host_begin, host_end) | |||||
    std::tuple<double, double> host; | |||||
    // [(device, device_begin, device_end)] | |||||
    std::vector<std::tuple<CompNode, TimeClosure, TimeClosure>> device_list; | |||||
    std::vector<ProfileTensor> inputs; | |||||
    std::vector<ProfileTensor> outputs; | |||||
    long long memory = 0; | |||||
    long long computation = 0; | |||||
#if MGB_ENABLE_JSON | |||||
    std::shared_ptr<json::Value> param; | |||||
#endif | |||||
    // Evaluates every device begin/end closure once and replaces it with a | |||||
    // closure returning the cached value. | |||||
    void wait_device() { | |||||
        for (auto& item : device_list) { | |||||
            TimeClosure& begin_fn = std::get<1>(item); | |||||
            TimeClosure& end_fn = std::get<2>(item); | |||||
            begin_fn = [value = begin_fn()] { return value; }; | |||||
            end_fn = [value = end_fn()] { return value; }; | |||||
        } | |||||
    } | |||||
}; | |||||
using Profile = std::list<ProfileEntry>; | |||||
class DeviceTimer { | class DeviceTimer { | ||||
public: | public: | ||||
using SharedEvent = std::shared_ptr<CompNode::Event>; | using SharedEvent = std::shared_ptr<CompNode::Event>; | ||||
DeviceTimer() = default; | DeviceTimer() = default; | ||||
void reset(thin_function<double()> host_timer); | |||||
thin_function<double()> get_device_time(CompNode device); | |||||
void clear(); | |||||
SharedEvent get_device_time(CompNode device); | |||||
SmallVector<SharedEvent> get_all(SmallVector<CompNode> device_list); | |||||
}; | |||||
class HostTimer { | |||||
public: | |||||
void reset(); | |||||
double get_msecs(); | |||||
double get_started_at(); | |||||
private: | private: | ||||
CompNode::UnorderedMap<std::tuple<SharedEvent, double>> m_base_event_table; | |||||
thin_function<double()> m_host_timer; | |||||
decltype(std::chrono::steady_clock::now()) m_start; | |||||
double m_started_at; | |||||
}; | }; | ||||
class TensorRecorder { | |||||
private: | |||||
// active tensors | |||||
std::unordered_map<Tensor*, std::tuple<std::weak_ptr<Tensor>, size_t>> | |||||
m_tensor_map; | |||||
size_t m_next_id; | |||||
class ProfilerBase { | |||||
public: | public: | ||||
size_t record_tensor(const TensorPtr& tensor); | |||||
void clear(); | |||||
using Host = std::thread::id; | |||||
using Device = CompNode; | |||||
struct HostInstant { | |||||
Host tid; | |||||
double time; | |||||
void wait() {} | |||||
}; | |||||
struct DeviceInstant { | |||||
double before; | |||||
std::shared_ptr<CompNode::Event> event; | |||||
double after; | |||||
void wait() { | |||||
event->host_wait(); | |||||
} | |||||
}; | |||||
using Instant = std::variant<HostInstant, DeviceInstant>; | |||||
// An event payload paired with the instant (host or device) at which it was | |||||
// recorded. | |||||
template <typename TEvent> | |||||
struct EventRecord { | |||||
    Instant instant; | |||||
    TEvent data; | |||||
    // Access the instant as a HostInstant. std::get throws | |||||
    // std::bad_variant_access if the record holds a DeviceInstant. | |||||
    HostInstant& host() { | |||||
        return std::get<HostInstant>(instant); | |||||
    } | |||||
    // Access the instant as a DeviceInstant. Returned by reference, like | |||||
    // host(), so that no shared_ptr copy is made per access. std::get throws | |||||
    // std::bad_variant_access if the record holds a HostInstant. | |||||
    DeviceInstant& device() { | |||||
        return std::get<DeviceInstant>(instant); | |||||
    } | |||||
    // Calls wait() on whichever instant alternative is stored. | |||||
    void wait() { | |||||
        std::visit([&](auto& instant){ instant.wait(); }, instant); | |||||
    } | |||||
}; | |||||
protected: | |||||
HostInstant record_host() { | |||||
return {std::this_thread::get_id(), m_host_timer.get_msecs()}; | |||||
} | |||||
DeviceInstant record_device(Device device) { | |||||
auto before = m_host_timer.get_msecs(); | |||||
auto event = m_device_timer.get_device_time(device); | |||||
auto after = m_host_timer.get_msecs(); | |||||
return {before, event, after}; | |||||
} | |||||
protected: | |||||
std::atomic_int64_t m_last_id = 0; | |||||
HostTimer m_host_timer; | |||||
DeviceTimer m_device_timer; | |||||
Spinlock m_lock; | |||||
}; | }; | ||||
class Profiler { | |||||
template <typename... TEvents> | |||||
class Profiler: public ProfilerBase { | |||||
public: | public: | ||||
enum Flags { | |||||
PROFILE_FOOTPRINT = 1, | |||||
using Record = std::variant<EventRecord<TEvents>...>; | |||||
using Mask = std::bitset<sizeof...(TEvents)>; | |||||
struct Data { | |||||
std::vector<Record> records; | |||||
double started_at; | |||||
}; | }; | ||||
// Compile-time lookup of the position of EventRecord<TEvent> among the | |||||
// alternatives of Record. Scans from `index` upward; if no alternative | |||||
// matches, the result equals std::variant_size_v<Record>. | |||||
template <typename TEvent, size_t index = 0> | |||||
static constexpr size_t index_of() { | |||||
if constexpr (index == std::variant_size_v<Record>) { | |||||
return index; | |||||
} else if constexpr (std::is_same_v<EventRecord<TEvent>, std::variant_alternative_t<index, Record>>) { | |||||
return index; | |||||
} else { | |||||
return index_of<TEvent, index+1>(); | |||||
} | |||||
}; | |||||
// Builds a Mask with the bit at index_of<T>() set for every T in TEvents2. | |||||
// The leading `Mask{} |` keeps the expression well-formed alongside the fold. | |||||
template <typename... TEvents2> | |||||
static Mask mask_of() { | |||||
return Mask{} | (Mask{}.set(index_of<TEvents2>()) |...); | |||||
} | |||||
enum Status { | |||||
NotStarted, Profiling, Stopped | |||||
}; | |||||
public: | public: | ||||
Profiler() = default; | |||||
// Start profiler by hook OpTrait | |||||
void start(uint32_t flags); | |||||
// Stop profiler and clean environment | |||||
void stop(); | |||||
void clear(); | |||||
Profile& get_profile(); | |||||
// Records a host-side event of type TEvent built from `args`. | |||||
// The timestamp is taken before the lock is acquired; the record is dropped | |||||
// when TEvent is not enabled in the current event mask. | |||||
template <typename TEvent, typename... TArgs> | |||||
void record_host(TArgs&&... args) { | |||||
    HostInstant instant{std::this_thread::get_id(), m_host_timer.get_msecs()}; | |||||
    MGB_LOCK_GUARD(m_lock); | |||||
    if (m_event_mask.test(index_of<TEvent>())) { | |||||
        mgb_assert(m_status != Stopped, "record after stop"); | |||||
        m_record_list.emplace_back(EventRecord<TEvent>{ | |||||
                std::move(instant), {std::forward<TArgs>(args)...}}); | |||||
    } | |||||
} | |||||
// Records a device-side event of type TEvent built from `args`. | |||||
// A device event is requested from the device timer and bracketed by two host | |||||
// timestamps; the record is dropped when TEvent is not enabled in the mask. | |||||
template <typename TEvent, typename... TArgs> | |||||
void record_device(Device device, TArgs&&... args) { | |||||
    auto host_before = m_host_timer.get_msecs(); | |||||
    auto device_event = m_device_timer.get_device_time(device); | |||||
    auto host_after = m_host_timer.get_msecs(); | |||||
    DeviceInstant instant{host_before, device_event, host_after}; | |||||
    MGB_LOCK_GUARD(m_lock); | |||||
    if (m_event_mask.test(index_of<TEvent>())) { | |||||
        mgb_assert(m_status != Stopped, "record after stop"); | |||||
        m_record_list.emplace_back(EventRecord<TEvent>{ | |||||
                std::move(instant), {std::forward<TArgs>(args)...}}); | |||||
    } | |||||
} | |||||
// Begins a profiling session with the given event mask and restarts the | |||||
// host timer. Asserts that the profiler has not been started before. | |||||
void start(Mask mask) { | |||||
    MGB_LOCK_GUARD(m_lock); | |||||
    mgb_assert(m_status == NotStarted, "profiler already started"); | |||||
    m_host_timer.reset(); | |||||
    m_event_mask = mask; | |||||
    m_status = Profiling; | |||||
} | |||||
// Ends the profiling session and hands the collected records to the caller. | |||||
// Asserts that profiling is active, marks the profiler Stopped, waits on every | |||||
// record, and returns the records together with the host timer's start time. | |||||
Data stop() { | |||||
    MGB_LOCK_GUARD(m_lock); | |||||
    mgb_assert(m_status == Profiling, "profiler not active"); | |||||
    m_status = Stopped; | |||||
    for (auto&& record: m_record_list) { | |||||
        std::visit([&](auto& record){ | |||||
            record.wait(); | |||||
        }, record); | |||||
    } | |||||
    auto records = std::move(m_record_list); | |||||
    // Leave the member in a well-defined empty state after the move. | |||||
    m_record_list.clear(); | |||||
    // Move (not copy) the record vector into the returned aggregate. | |||||
    return { std::move(records), m_host_timer.get_started_at() }; | |||||
} | |||||
protected: | |||||
std::vector<Record> m_record_list; | |||||
Mask m_event_mask; | |||||
Status m_status = NotStarted; | |||||
}; | |||||
class ChromeTraceEvent { | |||||
public: | |||||
ChromeTraceEvent& name(std::string name) { | |||||
m_name = std::move(name); | |||||
return *this; | |||||
} | |||||
ChromeTraceEvent& tid(uint64_t tid) { | |||||
m_tid = std::move(tid); | |||||
return *this; | |||||
} | |||||
ChromeTraceEvent& cat(std::string cat) { | |||||
m_cat = std::move(cat); | |||||
return *this; | |||||
} | |||||
ChromeTraceEvent& pid(uint64_t pid) { | |||||
m_pid = pid; | |||||
return *this; | |||||
} | |||||
ChromeTraceEvent& id(uint64_t id) { | |||||
m_id = id; | |||||
return *this; | |||||
} | |||||
ChromeTraceEvent& idx(uint64_t idx) { | |||||
m_idx = idx; | |||||
return *this; | |||||
} | |||||
ChromeTraceEvent& ts(double ts) { | |||||
m_ts = ts; | |||||
return *this; | |||||
} | |||||
ChromeTraceEvent& dur(double dur) { | |||||
m_dur = dur; | |||||
return *this; | |||||
} | |||||
ChromeTraceEvent& ph(char ph) { | |||||
m_ph = ph; | |||||
return *this; | |||||
} | |||||
ChromeTraceEvent& bp(char bp) { | |||||
m_bp = bp; | |||||
return *this; | |||||
} | |||||
ChromeTraceEvent& args(std::shared_ptr<json::Object> args) { | |||||
m_args = std::move(args); | |||||
return *this; | |||||
} | |||||
ChromeTraceEvent& arg(std::string key, std::string value) { | |||||
if (!m_args) { | |||||
m_args = json::Object::make(); | |||||
} | |||||
(*m_args)[key] = json::String::make(value); | |||||
return *this; | |||||
} | |||||
ChromeTraceEvent& arg(std::string key, double value) { | |||||
if (!m_args) { | |||||
m_args = json::Object::make(); | |||||
} | |||||
(*m_args)[key] = json::Number::make(value); | |||||
return *this; | |||||
} | |||||
ChromeTraceEvent& arg(std::string key, std::shared_ptr<json::Value> value) { | |||||
if (!m_args) { | |||||
m_args = json::Object::make(); | |||||
} | |||||
(*m_args)[key] = value; | |||||
return *this; | |||||
} | |||||
// Serialises this event into a json::Object. | |||||
// Only fields that have been set are emitted: empty strings and unset | |||||
// optionals are skipped, and "args" is added only when m_args is non-null. | |||||
std::shared_ptr<json::Object> to_json() const { | |||||
auto result = json::Object::make(); | |||||
// Emit a string field unless it is empty. | |||||
auto prop_str = [&](auto key, auto value) { | |||||
if (value.empty()) { | |||||
return; | |||||
} | |||||
(*result)[key] = json::String::make(value); | |||||
}; | |||||
// Emit a numeric field if the optional holds a value. | |||||
auto prop_num = [&](auto key, auto value) { | |||||
if (!value) { | |||||
return; | |||||
} | |||||
(*result)[key] = json::Number::make(value.value()); | |||||
}; | |||||
// Emit a single character as a one-character string if the optional holds a value. | |||||
auto prop_char = [&](auto key, auto value) { | |||||
if (!value) { | |||||
return; | |||||
} | |||||
(*result)[key] = json::String::make(std::string{} + value.value()); | |||||
}; | |||||
prop_str("name", m_name); | |||||
prop_num("tid", m_tid); | |||||
prop_str("cat", m_cat); | |||||
prop_num("pid", m_pid); | |||||
prop_num("id", m_id); | |||||
prop_num("idx", m_idx); | |||||
prop_num("ts", m_ts); | |||||
prop_num("dur", m_dur); | |||||
prop_char("ph", m_ph); | |||||
prop_char("bp", m_bp); | |||||
if (m_args) { | |||||
(*result)["args"] = m_args; | |||||
} | |||||
return result; | |||||
} | |||||
private: | private: | ||||
DeviceTimer m_device_timer; | |||||
RealTimer m_host_timer; | |||||
Profile m_profile; | |||||
TensorRecorder m_tensor_recorder; | |||||
std::stack<std::tuple<const OpDef*, ProfileEntry*, std::thread::id>> | |||||
m_entry_stack; | |||||
// Hold profile owned by this Profiler | |||||
std::unique_ptr<Profile> m_owned_profile; | |||||
// Hold hooks, cleared when stop | |||||
std::vector<std::any> m_hooker_list; | |||||
size_t m_entry_count = 0; | |||||
Spinlock m_lock; | |||||
std::unordered_map<Tensor*, std::weak_ptr<Tensor>> m_recorded_tensors; | |||||
std::string m_name; | |||||
std::string m_cat; | |||||
std::optional<uint64_t> m_tid; | |||||
std::optional<uint64_t> m_pid; | |||||
std::optional<uint64_t> m_id; | |||||
std::optional<uint64_t> m_idx; | |||||
std::optional<double> m_ts; | |||||
std::optional<double> m_dur; | |||||
std::optional<char> m_ph; | |||||
std::optional<char> m_bp; | |||||
std::shared_ptr<json::Object> m_args; | |||||
}; | |||||
class ChromeTraceEventList { | |||||
public: | |||||
ChromeTraceEvent& new_event() { | |||||
m_content.emplace_back(); | |||||
return m_content.back(); | |||||
} | |||||
std::shared_ptr<json::Array> to_json() { | |||||
auto result = json::Array::make(); | |||||
for (auto&& event: m_content) { | |||||
result->add(event.to_json()); | |||||
} | |||||
return result; | |||||
} | |||||
private: | |||||
std::vector<ChromeTraceEvent> m_content; | |||||
}; | }; | ||||
} // namespace imperative | } // namespace imperative | ||||
@@ -0,0 +1,125 @@ | |||||
/** | |||||
* \file imperative/src/include/megbrain/imperative/utils/to_string.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include <string> | |||||
#include <type_traits> | |||||
#include <memory> | |||||
#include <tuple> | |||||
#include "megbrain/utils/small_vector.h" | |||||
#include "megbrain/tensor.h" | |||||
namespace mgb::imperative { | |||||
template <typename T> | |||||
struct ToStringTrait; | |||||
template <typename T> | |||||
std::string to_string(const T& value) { | |||||
return ToStringTrait<T>{}(value); | |||||
} | |||||
// Fallback formatter: forwards to std::to_string, so it only compiles for
// types std::to_string accepts.
template <typename T>
struct ToStringTrait {
    std::string operator()(const T& value) const {
        std::string text = std::to_string(value);
        return text;
    }
};
template <> | |||||
struct ToStringTrait<std::string>{ | |||||
std::string operator()(const std::string& value) const { | |||||
return value; | |||||
} | |||||
}; | |||||
template <typename T, unsigned N> | |||||
struct ToStringTrait<SmallVector<T, N>>{ | |||||
std::string operator()(const SmallVector<T, N>& sv) const { | |||||
if (sv.empty()) { | |||||
return "[]"; | |||||
} | |||||
std::string result = "["; | |||||
result += to_string(sv[0]); | |||||
for (size_t i = 1; i < sv.size(); ++i) { | |||||
result += ", "; | |||||
result += to_string(sv[i]); | |||||
} | |||||
return result + "]"; | |||||
} | |||||
}; | |||||
template <typename T> | |||||
struct ToStringTrait<std::shared_ptr<T>>{ | |||||
std::string operator()(const std::shared_ptr<T>& sp) const { | |||||
return to_string(sp.get()); | |||||
} | |||||
}; | |||||
template <typename TKey, typename TValue> | |||||
struct ToStringTrait<std::pair<TKey, TValue>>{ | |||||
std::string operator()(const std::pair<TKey, TValue>& pr) const { | |||||
return "(" + to_string(pr.first) + ", " + to_string(pr.second) + ")"; | |||||
} | |||||
}; | |||||
template <typename TItem, typename... TItems> | |||||
struct ToStringTrait<std::tuple<TItem, TItems...>>{ | |||||
std::string operator()(const std::tuple<TItem, TItems...>& tp) const { | |||||
auto folder = [&](auto... item){ return ( ...+ ("," + to_string(item))); }; | |||||
return "(" + std::apply(folder, tp) + ")"; | |||||
} | |||||
}; | |||||
template <typename T> | |||||
struct ToStringTrait<T*>{ | |||||
std::string operator()(T* p) const { | |||||
return ssprintf("%p", p); | |||||
} | |||||
}; | |||||
template <> | |||||
struct ToStringTrait<TensorShape>{ | |||||
std::string operator()(TensorShape shape) const { | |||||
if (shape.ndim > TensorShape::MAX_NDIM) { | |||||
printf("ndim: %d\n", (int)shape.ndim); | |||||
return "[]"; | |||||
} | |||||
mgb_assert(shape.ndim <= TensorShape::MAX_NDIM); | |||||
if (shape.ndim == 0) { | |||||
return "[ ]"; | |||||
} | |||||
std::string result = "[ " + std::to_string(shape[0]); | |||||
for (size_t i = 1; i < shape.ndim; i++) { | |||||
result += ", "; | |||||
result += std::to_string(shape[i]); | |||||
} | |||||
return result + " ]"; | |||||
} | |||||
}; | |||||
template <> | |||||
struct ToStringTrait<DType>{ | |||||
std::string operator()(DType dtype) const { | |||||
return dtype.name(); | |||||
} | |||||
}; | |||||
template <> | |||||
struct ToStringTrait<CompNode>{ | |||||
std::string operator()(CompNode device) const { | |||||
return device.to_string(); | |||||
} | |||||
}; | |||||
} |
@@ -222,10 +222,25 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | |||||
os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_"); | os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_"); | ||||
os << "}\n"; | os << "}\n"; | ||||
// generate props() | |||||
os << formatv( | |||||
"std::vector<std::pair<const char*, std::string>> {0}(const OpDef& def_) {{\n", | |||||
formatMethImpl("props") | |||||
); | |||||
os << formatv( | |||||
" auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||||
" static_cast<void>(op_);\n", | |||||
className | |||||
); | |||||
ctx.withSelf("op_"); | |||||
os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx); | |||||
os << "}\n"; | |||||
os << "} // anonymous namespace\n"; | os << "} // anonymous namespace\n"; | ||||
methods.push_back("hash"); | methods.push_back("hash"); | ||||
methods.push_back("is_same_st"); | methods.push_back("is_same_st"); | ||||
methods.push_back("props"); | |||||
} | } | ||||
if (!methods.empty()) { | if (!methods.empty()) { | ||||
os << formatv( | os << formatv( | ||||
@@ -423,7 +438,7 @@ EnumWrapper<{0}::{1}>::type2str = {{ | |||||
std::vector<std::string> getsetters; | std::vector<std::string> getsetters; | ||||
for (auto &&i : op.getMgbAttributes()) { | for (auto &&i : op.getMgbAttributes()) { | ||||
getsetters.push_back(formatv( | getsetters.push_back(formatv( | ||||
"{{\"{1}\", py_get_generic({0}, {1}), py_set_generic({0}, {1}), \"{1}\", NULL},", | |||||
"{{const_cast<char*>(\"{1}\"), py_get_generic({0}, {1}), py_set_generic({0}, {1}), const_cast<char*>(\"{1}\"), NULL},", | |||||
className, i.name)); | className, i.name)); | ||||
} | } | ||||
@@ -66,7 +66,7 @@ struct MgbEnumAttrMixin : public MgbAttrWrapperBase { | |||||
} | } | ||||
// Returns the "parentNamespace" string field of the underlying record.
// The stale lookup of the misspelled key "parentNamespce" has been removed;
// it came first and made the corrected lookup unreachable.
llvm::StringRef getParentNamespace() const {
    return getBaseRecord()->getValueAsString("parentNamespace");
}
llvm::StringRef getEnumName() const { | llvm::StringRef getEnumName() const { | ||||
return getBaseRecord()->getValueAsString("enumName"); | return getBaseRecord()->getValueAsString("enumName"); | ||||
@@ -87,6 +87,9 @@ struct MgbHashableAttrMixin : public MgbAttrWrapperBase { | |||||
// Returns the "cmpFunction" string field of the underlying record.
llvm::StringRef getCmpFunctionTemplate() const {
    llvm::StringRef cmpTemplate = getBaseRecord()->getValueAsString("cmpFunction");
    return cmpTemplate;
}
// Returns the "reprFunction" string field of the underlying record.
llvm::StringRef getReprFunctionTemplate() const {
    llvm::StringRef reprTemplate = getBaseRecord()->getValueAsString("reprFunction");
    return reprTemplate;
}
}; | }; | ||||
struct MgbAliasAttrMixin : public MgbAttrWrapperBase { | struct MgbAliasAttrMixin : public MgbAttrWrapperBase { | ||||
@@ -205,6 +208,39 @@ private: | |||||
body += " return true;\n"; | body += " return true;\n"; | ||||
return body; | return body; | ||||
} | } | ||||
std::string getDefaultPropsFunction() const { | |||||
std::string body = " std::vector<std::pair<const char*, std::string>> props_;\n"; | |||||
if (!getMgbAttributes().empty()) { | |||||
mlir::tblgen::FmtContext ctx; | |||||
for (auto&& it : getMgbAttributes()) { | |||||
if (auto* enumAttr = llvm::dyn_cast<MgbEnumAttrMixin>(&it.attr)) { | |||||
body += formatv(" switch ({0}){{\n", "$_self." + it.name); | |||||
for (auto&& enumMember: enumAttr->getEnumMembers()) { | |||||
body += formatv( | |||||
" case {0}::{1}::{2}:\n", | |||||
getCppClassName(), enumAttr->getEnumName(), enumMember | |||||
); | |||||
body += formatv( | |||||
" props_.emplace_back(\"{0}\", \"{1}\");\n", | |||||
it.name, enumMember | |||||
); | |||||
body += " break;\n"; | |||||
} | |||||
body += " default: break;\n"; | |||||
body += " }\n"; | |||||
} else { | |||||
auto&& attr = llvm::cast<MgbHashableAttrMixin>(it.attr); | |||||
body += formatv( | |||||
" props_.emplace_back(\"{0}\", {1});\n", it.name, | |||||
mlir::tblgen::tgfmt(attr.getReprFunctionTemplate(), | |||||
&ctx, "$_self." + it.name) | |||||
); | |||||
} | |||||
} | |||||
} | |||||
body += " return props_;\n"; | |||||
return body; | |||||
} | |||||
public: | public: | ||||
static bool classof(const Operator* op) { | static bool classof(const Operator* op) { | ||||
return op->getDef().isSubClassOf("MgbHashableOpMixin"); | return op->getDef().isSubClassOf("MgbHashableOpMixin"); | ||||
@@ -222,7 +258,13 @@ public: | |||||
} | } | ||||
return getDefaultCmpFunction(); | return getDefaultCmpFunction(); | ||||
} | } | ||||
std::string getPropsFunctionTemplate() const { | |||||
if (auto f = getDef().getValueAsOptionalString("propsFunction")) { | |||||
return f.getValue().str(); | |||||
} | |||||
return getDefaultPropsFunction(); | |||||
} | |||||
}; | }; | ||||
} // namespace tblgen | } // namespace tblgen | ||||
} // namespace mlir | |||||
} // namespace mlir |
@@ -30,6 +30,7 @@ class MgbHashableAttrMixin { | |||||
string hashFunction = "mgb::hash($0)"; | string hashFunction = "mgb::hash($0)"; | ||||
// return 0 for eq, else for ne | // return 0 for eq, else for ne | ||||
string cmpFunction = "$0 != $1"; | string cmpFunction = "$0 != $1"; | ||||
string reprFunction = "std::to_string($0)"; | |||||
} | } | ||||
class MgbEnumAttrMixin<string namespace, string name, list<string> members> { | class MgbEnumAttrMixin<string namespace, string name, list<string> members> { | ||||
@@ -98,6 +99,7 @@ def MgbStringAttr : HashableAttr<"std::string"> { | |||||
let storageType = "::mlir::StringAttr"; | let storageType = "::mlir::StringAttr"; | ||||
let convertFromStorage = "$_self.getValue().str()"; | let convertFromStorage = "$_self.getValue().str()"; | ||||
let constBuilderCall = "$_builder.getStringAttr($0)"; // llvm::StringRef implicit ctor | let constBuilderCall = "$_builder.getStringAttr($0)"; // llvm::StringRef implicit ctor | ||||
string reprFunction = "$0"; | |||||
} | } | ||||
class MgbArrayAttr<MgbAttrWrapper elem>: | class MgbArrayAttr<MgbAttrWrapper elem>: | ||||
@@ -123,6 +125,7 @@ class MgbArrayAttr<MgbAttrWrapper elem>: | |||||
" });\n" | " });\n" | ||||
" return $_builder.getArrayAttr(ret" # recursionDepth # ");" | " return $_builder.getArrayAttr(ret" # recursionDepth # ");" | ||||
"}()"; | "}()"; | ||||
let reprFunction = "\"{std::vector}\""; | |||||
} | } | ||||
defvar EmptyStrList = !listsplat("", 0); | defvar EmptyStrList = !listsplat("", 0); | ||||
@@ -168,6 +171,7 @@ class MgbEnumAttr<string namespace, string enumName, list<string> members>: | |||||
let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; | let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; | ||||
let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))"; | let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))"; | ||||
let hashFunction = "mgb::enumhash()($0)"; | let hashFunction = "mgb::enumhash()($0)"; | ||||
string reprFunction = "std::to_string((int)$0)"; | |||||
} | } | ||||
class MgbEnumAliasAttr<string namespace, string enumName, MgbEnumAttr base>: | class MgbEnumAliasAttr<string namespace, string enumName, MgbEnumAttr base>: | ||||
@@ -179,12 +183,14 @@ def MgbDTypeAttr: HashableAttr<"::megdnn::DType"> { | |||||
let convertFromStorage = underlyingType # "::from_enum(static_cast<::megdnn::DTypeEnum>($_self.getInt()))"; | let convertFromStorage = underlyingType # "::from_enum(static_cast<::megdnn::DTypeEnum>($_self.getInt()))"; | ||||
let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0.enumv()))"; | let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0.enumv()))"; | ||||
let hashFunction = "mgb::hash($0.handle())"; | let hashFunction = "mgb::hash($0.handle())"; | ||||
let reprFunction = "$0.name()"; | |||||
} | } | ||||
// CompNode attribute: stored as a string attribute, loaded back with
// CompNode::load, and rendered for props() through CompNode::to_string().
def MgbCompNodeAttr: HashableAttr<"::mgb::CompNode"> {
  let storageType = "::mlir::StringAttr";
  let convertFromStorage = underlyingType # "::load($_self.getValue().str())";
  let constBuilderCall = "$_builder.getStringAttr($0.to_string_logical())";
  // Override the inherited reprFunction with `let`, as MgbDTypeAttr does,
  // rather than redeclaring the field.
  let reprFunction = "$0.to_string()";
}
def MgbTensorShapeAttr: HashableAttr<"::megdnn::TensorShape"> { | def MgbTensorShapeAttr: HashableAttr<"::megdnn::TensorShape"> { | ||||
@@ -209,6 +215,7 @@ def MgbTensorShapeAttr: HashableAttr<"::megdnn::TensorShape"> { | |||||
" }\n" | " }\n" | ||||
" return $_builder.getArrayAttr(ret);" | " return $_builder.getArrayAttr(ret);" | ||||
"}()"; | "}()"; | ||||
let reprFunction = "$0.to_string()"; | |||||
} | } | ||||
class MgbDefaultValuedAttr<MgbAttrWrapper attr, string value>: | class MgbDefaultValuedAttr<MgbAttrWrapper attr, string value>: | ||||