
refactor(profiler): integrate profiler into interpreter

GitOrigin-RevId: ccc984acbd
tags/v1.3.0
Megvii Engine Team
parent commit dbb3dd681f
31 changed files with 2058 additions and 958 deletions
  1. +3    -0    imperative/python/megengine/autodiff/grad_manager.py
  2. +12   -0    imperative/python/megengine/core/__init__.py
  3. +12   -0    imperative/python/megengine/module/module.py
  4. +5    -0    imperative/python/megengine/optimizer/optimizer.py
  5. +50   -213  imperative/python/megengine/utils/profiler.py
  6. +23   -5    imperative/python/src/tensor.cpp
  7. +0    -27   imperative/python/src/utils.cpp
  8. +54   -0    imperative/python/test/integration/test_profiler.py
  9. +17   -32   imperative/src/impl/function_hook.h
 10. +231  -0    imperative/src/impl/interpreter/commands.h
 11. +92   -0    imperative/src/impl/interpreter/events.h
 12. +226  -76   imperative/src/impl/interpreter/interpreter_impl.cpp
 13. +205  -0    imperative/src/impl/interpreter/interpreter_impl.h
 14. +61   -0    imperative/src/impl/interpreter/option_manager.h
 15. +280  -0    imperative/src/impl/interpreter/profiler.cpp
 16. +97   -0    imperative/src/impl/interpreter/profiler.h
 17. +135  -0    imperative/src/impl/interpreter/tensor_info.h
 18. +0    -351  imperative/src/impl/interpreter_impl.h
 19. +20   -0    imperative/src/impl/op_def.cpp
 20. +3    -0    imperative/src/impl/op_trait.h
 21. +6    -0    imperative/src/impl/ops/backward_graph.cpp
 22. +5    -0    imperative/src/impl/ops/opr_attr.cpp
 23. +22   -178  imperative/src/impl/profiler.cpp
 24. +1    -0    imperative/src/impl/proxy_graph/mini_graph.h
 25. +9    -5    imperative/src/include/megbrain/imperative/interpreter.h
 26. +18   -0    imperative/src/include/megbrain/imperative/op_def.h
 27. +279  -68   imperative/src/include/megbrain/imperative/profiler.h
 28. +125  -0    imperative/src/include/megbrain/imperative/utils/to_string.h
 29. +16   -1    imperative/tablegen/autogen.cpp
 30. +44   -2    imperative/tablegen/helper.h
 31. +7    -0    src/core/include/megbrain/ir/base.td

+3 -0 imperative/python/megengine/autodiff/grad_manager.py

@@ -3,6 +3,7 @@ from collections import defaultdict
from contextlib import contextmanager
from typing import Callable

from ..core._imperative_rt.core2 import pop_scope, push_scope
from ..core.autodiff.grad import Grad
from ..logger import get_logger
from ..tensor import Tensor
@@ -239,6 +240,7 @@ class GradManager:
:param y: tensor or list of tensors
:param dy: tensor or list of tensors. Defaults to 1 if y is scalar
"""
push_scope("backward")
from ..functional import ones_like

global backwarding_grad_manager
@@ -280,6 +282,7 @@ class GradManager:
finally:
self.release()
backwarding_grad_manager = cache
pop_scope("backward")

def record(self):
r"""


+12 -0 imperative/python/megengine/core/__init__.py

@@ -8,5 +8,17 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os
import sys
from contextlib import contextmanager

from ._imperative_rt.core2 import get_option, set_option
from .tensor.megbrain_graph import Graph


@contextmanager
def option(key, value):
value = int(value)
old = get_option(key)
set_option(key, value)
yield
assert get_option(key) == value
set_option(key, old)
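
The option helper above temporarily overrides a single interpreter option and restores the previous value on exit. A minimal usage sketch (assuming a built megengine; "async_level" is one of the keys registered by the interpreter's OptionManager further below):

    from megengine.core import option

    # Temporarily force fully synchronous dispatch; the previous
    # async_level is restored when the block exits.
    with option("async_level", 0):
        pass  # ops executed here report errors synchronously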

+12 -0 imperative/python/megengine/module/module.py

@@ -12,6 +12,7 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union

import numpy as np

from ..core._imperative_rt.core2 import pop_scope, push_scope
from ..core.tensor.utils import make_shape_tuple
from ..logger import get_logger
from ..tensor import Parameter, Tensor
@@ -78,6 +79,7 @@ class Module(metaclass=ABCMeta):
self._forward_hooks = OrderedDict()

self._modules = []
self._name = "{anonymous}"

@abstractmethod
def forward(self, inputs):
@@ -103,6 +105,7 @@ class Module(metaclass=ABCMeta):
return HookHandler(self._forward_hooks, hook)

def __call__(self, *inputs, **kwargs):
push_scope(self._name)
for hook in self._forward_pre_hooks.values():
modified_inputs = hook(self, inputs)
if modified_inputs is not None:
@@ -116,6 +119,7 @@ class Module(metaclass=ABCMeta):
modified_outputs = hook(self, inputs, outputs)
if modified_outputs is not None:
outputs = modified_outputs
pop_scope(self._name)
return outputs

def _flatten(
@@ -571,6 +575,14 @@ class Module(metaclass=ABCMeta):

return set(loaded), set(skipped)

def __getattribute__(self, name: str):
value = super().__getattribute__(name)
if name == "_name":
return value
if _is_module(value):
value._name = name
return value

def __setattr__(self, name: str, value):
if _is_module(value):
modules = self.__dict__.get("_modules")
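
Together these hooks make modules self-labelling: __getattribute__ stamps a submodule with the attribute name it was reached through, and __call__ brackets the forward pass with push_scope/pop_scope, so each module shows up as a named scope in the profile. A small sketch of the resulting behavior (hypothetical class names; assumes a working megengine build):

    from megengine import Parameter, tensor
    from megengine.module import Module

    class Block(Module):
        def __init__(self):
            super().__init__()
            self.scale = Parameter([2.0], dtype="float32")

        def forward(self, x):
            return x * self.scale

    class Net(Module):
        def __init__(self):
            super().__init__()
            self.block = Block()

        def forward(self, x):
            # reading self.block stamps Block._name = "block" via __getattribute__
            return self.block(x)

    # Under a Profiler, this call emits nested scopes: "{anonymous}"
    # (the top-level Net is never read as an attribute) enclosing "block".
    out = Net()(tensor([1.0], dtype="float32"))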


+5 -0 imperative/python/megengine/optimizer/optimizer.py

@@ -15,6 +15,7 @@ from typing import Union

import numpy as np

from ..core._imperative_rt.core2 import pop_scope, push_scope
from ..core.tensor.utils import set_convert_inputs
from ..tensor import Parameter, Tensor
from ..utils.deprecation import deprecated
@@ -155,7 +156,9 @@ class Optimizer(metaclass=ABCMeta):
"but the ordering of parameters in sets will change between runs. "
"Please use a list instead."
)
push_scope("step")
self._updates(group)
pop_scope("step")
# restore the global state `_enable_convert_inputs`
set_convert_inputs(backup)
return self
@@ -172,8 +175,10 @@ class Optimizer(metaclass=ABCMeta):
Set the grad attribute to None for all parameters.
"""
for param_group in self.param_groups:
push_scope("clear_grad")
for param in param_group["params"]:
param.grad = None
pop_scope("clear_grad")

def state_dict(self) -> Dict:
r"""


+50 -213 imperative/python/megengine/utils/profiler.py

@@ -6,159 +6,17 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import base64
import json
import os
import re
from typing import Iterable, List, Optional
from contextlib import contextmanager
from typing import List

from ..core._imperative_rt import OperatorNodeConfig, ProfileEntry
from ..core._imperative_rt import ProfilerImpl as _Profiler
from ..core._imperative_rt.core2 import sync
from ..core._imperative_rt.ops import CollectiveComm


def _make_dict(**kwargs):
unused_keys = []
for k, v in kwargs.items():
if v is None:
unused_keys.append(k)
for k in unused_keys:
del kwargs[k]
return kwargs


def _print_opnode_config(config):
return _make_dict(
name=config.name, dtype=config.dtype, comp_node_arr=config.comp_node_arr,
)


def _dump_chrome_timeline(entries: List[ProfileEntry], path: str):
pid = os.getpid()
trace_events = []

def append_event(**kwargs):
trace_events.append(_make_dict(**kwargs))

for id, entry in enumerate(entries):
op = entry.op
name = type(op).__name__
host_begin, host_end = entry.host
device_list = entry.device_list
args = Profiler.fetch_attrs(op)
args["__id__"] = "[{}]".format(id)
cat = name
for ts, ph in [(host_begin, "B"), (host_end, "E")]:
append_event(
name=name, ph=ph, ts=ts * 1000, pid=pid, tid="host", args=args, cat=cat,
)
for device, device_begin, device_end in device_list:
for ts, ph in [(device_begin(), "B"), (device_end(), "E")]:
append_event(
name=name, ph=ph, ts=ts * 1000, pid=pid, tid=str(device), args=args,
)
with open("{}.chrome_timeline.json".format(path), "w") as f:
json.dump(trace_events, f, indent=2)


def _dump_compatible(entries: List[ProfileEntry], path: str):
obj = {
"graph_exec": {"var": [], "operator": {}},
"profiler": {"device": {}, "host": {}, "opr_footprint": {}},
}
var_list = obj["graph_exec"]["var"]
operator_dict = obj["graph_exec"]["operator"]
device_dict = obj["profiler"]["device"]
host_dict = obj["profiler"]["host"]
opr_foot_print_dict = obj["profiler"]["opr_footprint"]

def add_var(var) -> int:
var_id = len(var_list)
var_list.append(
{"comp_node": str(var[2]),}
)
return var_id

for op_id, entry in enumerate(entries):
operator_dict[op_id] = {
"input": [add_var(var) for var in entry.inputs],
"output": [add_var(var) for var in entry.outputs],
"name": str(entry.op.ctype()),
"type": "imperative",
"id": entry.id,
}
op_device_dict = {}
for device, device_begin, device_end in entry.device_list:
op_device_dict[str(device)] = {
"start": device_begin(),
"kern": device_begin(),
"end": device_end(),
}
device_dict[op_id] = op_device_dict
host_begin, host_end = entry.host
host_dict[op_id] = {
"host": {"start": host_begin, "kern": host_begin, "end": host_end}
}
opr_footprint = {
"out_shapes": [oup[1] for oup in entry.outputs],
"in_shapes": [inp[1] for inp in entry.inputs],
"params": {},
}
if entry.memory > 0:
opr_footprint["memory"] = entry.memory
if entry.computation > 0:
opr_footprint["computation"] = entry.computation
opr_foot_print_dict[op_id] = opr_footprint
with open("{}.compatible.json".format(path), "w") as f:
json.dump(obj, f, indent=2)


def _dump_graphviz(entries: List[ProfileEntry], path: str):
import json

import graphviz

graph = graphviz.Digraph()
graph.graph_attr["ordering"] = "out"
var_cache = {}

def cache_var(var_id, var_shape):
if var_id not in var_cache:
var_name = "var({})".format(var_id)
var_label = "{}\nshape:{}\n".format(var_name, var_shape)
graph.node(var_name, var_label)
var_cache[var_id] = var_name
return var_cache[var_id]

for op_id, entry in enumerate(entries):
op = entry.op
op_name = "op({})".format(op_id)
op_type = type(op).__name__
op_attrs = Profiler.fetch_attrs(op)
label_lines = []
if "param" in op_attrs:
del op_attrs["param"]
label_lines.append("{}:{}".format(op_name, op_type))
for k, v in op_attrs.items():
label_lines.append("attr[{}]: {}".format(k, v))
op_param_str = entry.param
if len(op_param_str) > 0:
op_param = json.loads(op_param_str)
for k, v in op_param.items():
label_lines.append("param[{}]:{}".format(k, v))
host_begin, host_end = entry.host
label_lines.append("time[host]: {:f}ms".format(host_end - host_begin))
for device, device_begin, device_end in entry.device_list:
device_time = device_end() - device_begin()
label_lines.append("time[{}]: {:f}ms".format(device, device_time))
op_label = "\n".join(label_lines)
graph.node(op_name, op_label, shape="rectangle")
for var_id, shape, device in entry.inputs:
graph.edge(cache_var(var_id, shape), op_name)
for var_id, shape, device in entry.outputs:
graph.edge(op_name, cache_var(var_id, shape))
graph.save("{}.graphviz.dot".format(path))
from ..core._imperative_rt.core2 import (
pop_scope,
push_scope,
start_profile,
stop_profile,
sync,
)


class Profiler:
@@ -181,85 +39,45 @@ class Profiler:
# Only the profile record of the last iteration is saved
with Profiler("profile"):
# your code here
# Then open the profile file in Chrome's tracing viewer (chrome://tracing)
"""

CHROME_TIMELINE = "chrome_timeline"
COMPATIBLE = "compatible"
GRAPHVIZ = "graphviz"

WITH_FOOTPRINT = 1
CHROME_TIMELINE = "chrome_timeline.json"

_type_map = {
OperatorNodeConfig: lambda x: _print_opnode_config(x),
bytes: lambda x: base64.encodebytes(x).decode("ascii"),
CollectiveComm.Mode: lambda x: str(x),
}

_dumper_map = {
CHROME_TIMELINE: _dump_chrome_timeline,
COMPATIBLE: _dump_compatible,
GRAPHVIZ: _dump_graphviz,
}
COMMAND = 1 << 0
OPERATOR = 1 << 1
TENSOR_LIFETIME = 1 << 2
TENSOR_PROP = 1 << 3
SYNC = 1 << 4
SCOPE = 1 << 5
ALL = (1 << 6) - 1

def __init__(
self,
path: str = "profile",
format: str = CHROME_TIMELINE,
*,
formats: Iterable[str] = (CHROME_TIMELINE,),
type_filter: str = ".*",
exit_dump: bool = True
topic=OPERATOR | SCOPE,
align_time=True,
show_operator_name=True
) -> None:
self._impl = _Profiler()
self._path = path

if isinstance(formats, str):
formats = (formats,)

self._filter = type_filter
self._dumpers = [Profiler._dumper_map[fmt] for fmt in formats]
self._exit_dump = exit_dump
self._format = format
self._options = {
"topic": int(topic),
"align_time": int(align_time),
"show_operator_name": int(show_operator_name),
}

def __enter__(self):
sync()
self._impl.start(Profiler.WITH_FOOTPRINT)
start_profile(self._options)
return self

def __exit__(self, val, tp, trace):
if self._exit_dump:
self.dump()
sync()
self._impl.stop()
self._impl.clear()

@classmethod
def fetch_attrs(cls, op):
attrs = dir(op)
results = {}
for attr in attrs:
if attr.startswith("_"):
continue
value = op.__getattribute__(attr)
if callable(value):
continue
value_type = type(value)
if value_type in cls._type_map:
value = cls._type_map[value_type](value)
results[attr] = str(value)
return results

def dump(self, path: Optional[str] = None):
stop_profile(self._path, self._format)
# dump is async, so it's necessary to sync the interpreter
sync()
raw = [
entry
for entry in self._impl.dump()
if re.match(self._filter, type(entry.op).__name__)
]
if path is None:
path = self._path
for dumper in self._dumpers:
dumper(raw, path)

def __call__(self, func):
def wrapper(*args, **kwargs):
@@ -269,4 +87,23 @@ class Profiler:
return wrapper


@contextmanager
def scope(name):
push_scope(name)
yield
pop_scope(name)


profile = Profiler
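
Because a Profiler instance is callable (the wrapper body is elided in the hunk above), the lowercase profile alias also works as a decorator. A sketch under that assumption:

    from megengine.utils.profiler import profile

    @profile()  # Profiler used as a decorator; wrapper body elided above
    def train_step(x):
        return x * 2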


def merge_trace_events(sources: List[str], target: str):
names = list(map(lambda x: x + ".chrome_timeline.json", sources))
result = []
for name in names:
with open(name, "r", encoding="utf-8") as f:
content = json.load(f)
for entry in content:
result.append(entry)
with open(target + ".chrome_timeline.json", "w") as f:
json.dump(result, f, ensure_ascii=False, indent=4)
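
merge_trace_events stitches several chrome_timeline dumps into a single trace, e.g. to inspect all ranks of a distributed job in one timeline view. A sketch with hypothetical file prefixes:

    from megengine.utils.profiler import merge_trace_events

    # Reads rank0.chrome_timeline.json and rank1.chrome_timeline.json and
    # writes the concatenated events to merged.chrome_timeline.json.
    merge_trace_events(["rank0", "rank1"], "merged")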

+23 -5 imperative/python/src/tensor.cpp

@@ -807,16 +807,34 @@ void init_tensor(py::module m) {
}
}

m.def("set_option",
[](std::string name, int value){ interpreter_for_py->set_option(name, value); });
m.def("get_option",
[](std::string name){ return interpreter_for_py->get_option(name); });
m.def("_set_swap_flag",
[](bool flag) { interpreter_for_py->set_swap_flag(flag); });
[](bool flag) { interpreter_for_py->set_option("enable_swap", flag); });
m.def("_set_drop_flag",
[](bool flag) { interpreter_for_py->set_drop_flag(flag); });
[](bool flag) { interpreter_for_py->set_option("enable_drop", flag); });
m.def("config_async_level",
[](int level) { interpreter_for_py->config_async_level(level); });
[](int level) {
mgb_assert(level >= 0 and level <= 2, "async_level should be 0, 1 or 2");
interpreter_for_py->set_option("async_level", level);
});
m.def("get_async_level",
[]() { return interpreter_for_py->get_async_level(); });
[]() { return interpreter_for_py->get_option("async_level"); });
m.def("set_buffer_length",
[](int length) { interpreter_for_py->set_buffer_length(length); });
[](int length) {
mgb_assert(length >= 0 and length < 100, "buffer_length should be in [0, 100)");
interpreter_for_py->set_option("buffer_length", length);
});
m.def("push_scope",
[](std::string name) { interpreter_for_py->push_scope(name); });
m.def("pop_scope",
[](std::string name) { interpreter_for_py->pop_scope(name); });
m.def("start_profile",
[](std::unordered_map<std::string, int> option) { return interpreter_for_py->start_profile(option); });
m.def("stop_profile",
[](std::string basename, std::string format) { interpreter_for_py->stop_profile(basename, format); });
m.def("sync",
[]() {
interpreter_for_py->sync();
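
These bindings are the low-level hooks the Python Profiler drives: start_profile receives the options dict assembled in profiler.py, and stop_profile enqueues an asynchronous dump. A minimal sketch of calling them directly (topic constants come from the Profiler class above):

    from megengine.core._imperative_rt.core2 import start_profile, stop_profile, sync
    from megengine.utils.profiler import Profiler

    start_profile({"topic": Profiler.OPERATOR | Profiler.SCOPE,
                   "align_time": 1, "show_operator_name": 1})
    # ... run some ops here ...
    stop_profile("profile", "chrome_timeline.json")
    sync()  # the dump is asynchronous; wait for the worker to finish writing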


+0 -27 imperative/python/src/utils.cpp

@@ -200,33 +200,6 @@ void init_utils(py::module m) {
m.def("_get_device_count", &mgb::CompNode::get_device_count,
"Get total number of specific devices on this system");

using mgb::imperative::ProfileEntry;

py::class_<ProfileEntry>(m, "ProfileEntry")
.def_readwrite("op", &ProfileEntry::op)
.def_readwrite("host", &ProfileEntry::host)
.def_readwrite("device_list", &ProfileEntry::device_list)
.def_readwrite("inputs", &ProfileEntry::inputs)
.def_readwrite("outputs", &ProfileEntry::outputs)
.def_readwrite("id", &ProfileEntry::id)
.def_readwrite("parent", &ProfileEntry::parent)
.def_readwrite("memory", &ProfileEntry::memory)
.def_readwrite("computation", &ProfileEntry::computation)
.def_property_readonly("param", [](ProfileEntry& self)->std::string{
if(self.param){
return self.param->to_string();
} else {
return {};
}
});

py::class_<mgb::imperative::Profiler>(m, "ProfilerImpl")
.def(py::init<>())
.def("start", &mgb::imperative::Profiler::start)
.def("stop", &mgb::imperative::Profiler::stop)
.def("clear", &mgb::imperative::Profiler::clear)
.def("dump", &mgb::imperative::Profiler::get_profile);

using mgb::imperative::TensorSanityCheck;
py::class_<TensorSanityCheck>(m, "TensorSanityCheckImpl")
.def(py::init<>())


+54 -0 imperative/python/test/integration/test_profiler.py

@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied
import json
import os

import pytest

from megengine import Parameter, tensor
from megengine.core import option
from megengine.module import Module
from megengine.utils.profiler import Profiler, scope


class Simple(Module):
def __init__(self):
super().__init__()
self.a = Parameter([1.23], dtype="float32")

def forward(self, x):
x = x * self.a
return x


def test_profiler():
profile_prefix = "pytest_profile"
profile_format = "chrome_timeline.json"
profile_path = "{}.{}".format(profile_prefix, profile_format)
with Profiler(profile_prefix, format=profile_format):
with scope("my_scope"):
oup = Simple()(tensor([1.23], dtype="float32"))
with open(profile_path, "r") as f:
events = json.load(f)
os.remove(profile_path)
prev_ts = {}
scope_count = 0
for event in events:
if "dur" in event:
assert event["dur"] >= 0
elif "ts" in event and "tid" in event:
ts = event["ts"]
tid = event["tid"]
if ts == 0:
continue
assert (tid not in prev_ts) or prev_ts[tid] <= ts
prev_ts[tid] = ts
if "name" in event and event["name"] == "my_scope":
scope_count += 1
assert scope_count > 0 and scope_count % 2 == 0

+17 -32 imperative/src/impl/function_hook.h

@@ -17,52 +17,37 @@ namespace mgb {
namespace imperative {

template <typename TFunction>
class FunctionHooker;
class FunctionHook;

template <typename TRet, typename... TArgs>
class FunctionHooker<TRet(TArgs...)> {
template <template <typename> class TFunction, typename TRet, typename... TArgs>
class FunctionHook<TFunction<TRet(TArgs...)>> {
public:
using FunctionType = thin_function<TRet(TArgs...)>;
//Type of hooks. Hook should accept a real function as argument
//and invoke it at an appropriate time
using HookType = thin_function<TRet(FunctionType, TArgs...)>;
explicit FunctionHooker(FunctionType* fptr) : m_fptr{fptr} {
m_backup = {nullptr, [](FunctionType*){}};
using FunctionType = TFunction<TRet(TArgs...)>;
explicit FunctionHook(FunctionType* fptr) : m_fptr{fptr} {
m_backup = *fptr;
}

public:
FunctionHooker& apply_hook(HookType&& hook) {
if (!m_backup) {
FunctionType* backup = new FunctionType(*m_fptr);
//Restore the hooked function; invoked when the hooker is destructed
std::function<void(FunctionType*)> restorer =
[fptr = m_fptr](FunctionType* bkp) -> void {
*fptr = *bkp;
delete bkp;
};
m_backup = decltype(m_backup)(backup, restorer);
}
template <typename THook, typename=std::enable_if_t<std::is_invocable_r_v<TRet, THook, FunctionType, TArgs...>, void>>
FunctionHook& apply_hook(THook&& hook) {
//Replace with hooked version
*m_fptr = [func = *m_fptr, hook](TArgs... args) -> TRet {
*m_fptr = [func = *m_fptr, hook=std::forward<THook>(hook)](TArgs... args) -> TRet {
return hook(func, std::forward<TArgs>(args)...);
};
//Convenient for chained calls
return *this;
}

private:
FunctionType* m_fptr;
std::unique_ptr<FunctionType, std::function<void(FunctionType*)>> m_backup;
FunctionType m_backup;
public:
~FunctionHook() {
*m_fptr = std::move(m_backup);
}
};

//Helps to deduce template args
template <typename TRet, typename... TArgs>
FunctionHooker(thin_function<TRet(TArgs...)>* f)
-> FunctionHooker<TRet(TArgs...)>;

template<typename TSignature>
auto make_shared_hook(thin_function<TSignature>* fptr){
return std::make_shared<FunctionHooker<TSignature>>(fptr);
template<typename TFunction>
auto make_shared_hook(TFunction* fptr){
return std::make_shared<FunctionHook<TFunction>>(fptr);
}

} // namespace imperative


+231 -0 imperative/src/impl/interpreter/commands.h

@@ -0,0 +1,231 @@
/**
* \file imperative/src/impl/interpreter/commands.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include <string>
#include <variant>

#include "megbrain/tensor.h"
#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/utils/to_string.h"

namespace mgb::imperative {

namespace interpreter::intl {

struct TensorInfo;
class InterpreterProfiler;

struct Put {
TensorInfo* dest;
HostTensorND value;
bool no_cache = false;

template <typename TFunctor>
void get_props(TFunctor&& functor) const {
functor("dest", dest);
functor("no_cache", no_cache);
//functor("value", value);
}

const char* get_name() const {
return "Put";
}
};

struct ApplyOp {
std::shared_ptr<OpDef> op;
SmallVector<TensorInfo*> inputs;
SmallVector<TensorInfo*> outputs;
SmallVector<TensorInfo*> dels;

template <typename TFunctor>
void get_props(TFunctor&& functor) const {
functor("op", op);
functor("inputs", inputs);
functor("outputs", outputs);
functor("dels", dels);
}

const char* get_name() const {
return "ApplyOp";
}
};

struct Del {
TensorInfo* dest;

template <typename TFunctor>
void get_props(TFunctor&& functor) const {
functor("dest", dest);
}

const char* get_name() const {
return "Del";
}
};

struct GetValue {
TensorInfo* dest;

template <typename TFunctor>
void get_props(TFunctor&& functor) const {
functor("dest", dest);
}

const char* get_name() const {
return "GetValue";
}
};

struct SwapIn {
TensorInfo* dest;

template <typename TFunctor>
void get_props(TFunctor&& functor) const {
functor("dest", dest);
}

const char* get_name() const {
return "SwapIn";
}
};

struct SwapOut {
TensorInfo* dest;

template <typename TFunctor>
void get_props(TFunctor&& functor) const {
functor("dest", dest);
}

const char* get_name() const {
return "SwapOut";
}
};

struct Drop {
TensorInfo* dest;

template <typename TFunctor>
void get_props(TFunctor&& functor) const {
functor("dest", dest);
}

const char* get_name() const {
return "Drop";
}
};

struct SetOption {
std::string key;
int value;

template <typename TFunctor>
void get_props(TFunctor&& functor) const {
functor("key", key);
functor("value", value);
}

const char* get_name() const {
return "SetOption";
}
};

struct StartProfile {
InterpreterProfiler* profiler;

template <typename TFunctor>
void get_props(TFunctor&& functor) const {}

const char* get_name() const {
return "StartProfile";
}
};

struct StopProfile {
std::string basename;
std::string format;

template <typename TFunctor>
void get_props(TFunctor&& functor) const {
functor("basename", basename);
functor("format", format);
}

const char* get_name() const {
return "StopProfile";
}
};

struct PushScope {
std::string scope_name;

template <typename TFunctor>
void get_props(TFunctor&& functor) const {
functor("scope_name", scope_name);
}

const char* get_name() const {
return "PushScope";
}
};

struct PopScope {
std::string scope_name;

template <typename TFunctor>
void get_props(TFunctor&& functor) const {
functor("scope_name", scope_name);
}

const char* get_name() const {
return "PopScope";
}
};

using Command = std::variant<Put,
ApplyOp,
Del,
GetValue,
SwapIn,
SwapOut,
Drop,
SetOption,
StartProfile,
StopProfile,
PushScope,
PopScope>;

using IdentifiedCommand = std::pair<uint64_t, Command>;

}

template <>
struct ToStringTrait<interpreter::intl::Command>{
std::string operator()(const interpreter::intl::Command& cmd) const {
return std::visit([](auto& cmd){
std::string result = cmd.get_name();
result += "{";
cmd.get_props([&](const char* key, auto&& value) {
result += key;
result += ": ";
result += to_string(value);
result += ",";
});
result += "}";
return result;
}, cmd);
}
};

}

+92 -0 imperative/src/impl/interpreter/events.h

@@ -0,0 +1,92 @@
/**
* \file imperative/src/impl/interpreter/events.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include "./commands.h"
#include "./tensor_info.h"

namespace mgb::imperative::interpreter::intl {

struct CommandEvent {
IdentifiedCommand icmd;
};

struct CommandEnqueueEvent: CommandEvent {};

struct CommandExecuteEvent: CommandEvent {};

struct CommandFinishEvent: CommandEvent {};

struct OpEvent {
uint64_t id;
std::shared_ptr<OpDef> op;
SmallVector<uint64_t> inputs;
SmallVector<uint64_t> outputs;
};

struct HostOpExecuteEvent: OpEvent {};

struct DeviceOpExecuteEvent: OpEvent {};

struct HostOpFinishEvent: OpEvent {};

struct DeviceOpFinishEvent: OpEvent {};

struct TensorDeclareEvent {
uint64_t tensor_id;
};

struct TensorProduceEvent {
uint64_t tensor_id;
TensorLayout layout;
CompNode device;
};

struct TensorEraseEvent {
uint64_t tensor_id;
};

struct TensorPropEvent {
uint64_t tensor_id;
TensorInfo::Prop prop;
std::string prop_desc;
};

struct TensorGetPropEvent: TensorPropEvent{};

struct TensorWaitPropEvent: TensorPropEvent{};

struct TensorNotifyPropEvent: TensorPropEvent{};

struct TensorWaitPropFinishEvent: TensorPropEvent{};

struct SyncStartEvent {};

struct SyncFinishEvent {};

struct ScopeEvent {
std::string name;
};

struct ChannelBeginScope: ScopeEvent {};

struct ChannelEndScope: ScopeEvent {};

struct WorkerBeginScope: ScopeEvent {};

struct WorkerEndScope: ScopeEvent {};

struct DeviceBeginScope: ScopeEvent {};

struct DeviceEndScope: ScopeEvent {};

}

+226 -76 imperative/src/impl/interpreter_impl.cpp → imperative/src/impl/interpreter/interpreter_impl.cpp

@@ -1,5 +1,5 @@
/**
* \file imperative/src/impl/interpreter_impl.cpp
* \file imperative/src/impl/interpreter/interpreter_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,10 +10,14 @@
*/

#include "./interpreter_impl.h"

#include "megbrain/common.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/utils/to_string.h"

#include "../op_trait.h"

using namespace mgb;
using namespace imperative;
@@ -48,6 +52,7 @@ Handle ChannelImpl::put(const DeviceTensorND& data) {
info->desc.layout = data.layout();
info->desc.comp_node = data.comp_node();
info->ptr = Tensor::make(data);
m_channel_state.profiler->record_host<TensorProduceEvent>(info->id, info->desc.layout, info->desc.comp_node);
return info;
}

@@ -61,7 +66,7 @@ void ChannelImpl::del(Handle handle) {
}

void ChannelImpl::swap_in(Handle handle) {
if (m_enable_evict & SWAP) {
if (m_worker_state.options.enable_swap) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto* info = reinterpret_cast<TensorInfo*>(handle);
@@ -71,7 +76,7 @@ void ChannelImpl::swap_in(Handle handle) {
}

void ChannelImpl::swap_out(Handle handle) {
if (m_enable_evict & SWAP) {
if (m_worker_state.options.enable_swap) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto* info = reinterpret_cast<TensorInfo*>(handle);
@@ -81,7 +86,7 @@ void ChannelImpl::swap_out(Handle handle) {
}

void ChannelImpl::drop(Handle handle) {
if (m_enable_evict & DROP) {
if (m_worker_state.options.enable_drop) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto* info = reinterpret_cast<TensorInfo*>(handle);
@@ -100,6 +105,7 @@ void ChannelImpl::dispatch_default_cpu(
const SmallVector<LogicalTensorDesc>& input_descs,
SmallVector<Handle>* outputs) {
auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
MGB_MARK_USED_VAR(validated);

SmallVector<DeviceTensorND> input_tensornds;
input_tensornds.reserve(input_descs.size());
@@ -133,6 +139,17 @@ void ChannelImpl::dispatch_default_cpu(
output_tensornds.emplace_back(HostTensorND(output_cn, desc.layout).proxy_to_default_cpu());
}

auto tinfo_to_tid = [&](SmallVector<TensorInfo*> tinfo) {
SmallVector<uint64_t> tid;
for (auto* ptinfo: tinfo) {
tid.push_back(ptinfo->id);
}
return tid;
};
OpEvent event_data = {++m_last_id, op, tinfo_to_tid(input_infos), {}};

m_channel_state.profiler->record_host<HostOpExecuteEvent>(event_data);

OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);

SmallVector<TensorInfo*> output_infos;
@@ -146,9 +163,14 @@ void ChannelImpl::dispatch_default_cpu(
output_infos.push_back(info);
outputs->push_back(info);
}
if (m_enable_evict & DROP) {

if (m_channel_state.options.enable_drop) {
TensorInfo::ComputePath::make(op, input_infos, output_infos);
}

event_data.outputs = tinfo_to_tid(output_infos);

m_channel_state.profiler->record_host<HostOpFinishEvent>(event_data);
}

void ChannelImpl::dispatch_kernel(
@@ -173,13 +195,13 @@ void ChannelImpl::dispatch_kernel(
cmd.outputs.push_back(info);
outputs->push_back(info);
}
if (m_enable_evict & DROP) {
if (m_channel_state.options.enable_drop) {
TensorInfo::ComputePath::make(cmd.op, cmd.inputs, cmd.outputs);
}
m_buffer.enqueue(std::move(cmd));
if (!validated && m_async_level == 1) {
if (!validated && m_channel_state.options.async_level == 1) {
sync();
} else if (m_async_level == 0) {
} else if (m_channel_state.options.async_level == 0) {
sync();
// check device error
for (auto&& oup : *outputs) {
@@ -212,7 +234,10 @@ SmallVector<Handle> ChannelImpl::apply_op(
}

SmallVector<Handle> outputs;
switch (OpDef::decide_dispatch_mode(*op, input_descs)) {
DispatchMode dispatch_mode = m_channel_state.options.enable_host_compute
? OpDef::decide_dispatch_mode(*op, input_descs)
: DispatchMode::KERNEL;
switch (dispatch_mode) {
case DEFAULT_CPU: {
dispatch_default_cpu(op, input_infos, input_descs, &outputs);
break;
@@ -242,11 +267,13 @@ HostTensorND ChannelImpl::get_value(Handle handle) {
m_waitee = info;
regenerate(info);
m_buffer.enqueue(GetValue{info});
m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::HostValue);
m_cv.wait(lock, [&]() {
check_worker_exc_unsafe();
tensor_ptr = info->ptr;
return value_fetched();
});
m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::HostValue);
m_waitee = nullptr;
}
return tensor_ptr->get_value();
@@ -262,11 +289,13 @@ TensorShape ChannelImpl::get_shape(Handle handle) {
std::unique_lock<decltype(m_mutex)> lock(m_mutex);
mgb_assert(!m_waitee);
m_waitee = info;
m_buffer.enqueue(Flush{info});
m_buffer.flush();
m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::Shape);
m_cv.wait(lock, [&]() {
check_worker_exc_unsafe();
return static_cast<bool>(info->ptr);
});
m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::Shape);
m_waitee = nullptr;
TensorShape ret = info->ptr->layout();
mgb_assert(ret.ndim != 0);
@@ -277,6 +306,7 @@ DType ChannelImpl::get_dtype(Handle handle) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle);
m_channel_state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::DType);
auto ret = info->desc.layout.dtype;
mgb_assert(ret.valid());
return ret;
@@ -286,6 +316,7 @@ CompNode ChannelImpl::get_device(Handle handle) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle);
m_channel_state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::Device);
auto ret = info->desc.comp_node;
mgb_assert(ret.valid());
return ret;
@@ -299,20 +330,23 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
mgb_assert(!m_waitee);
m_waitee = info;
regenerate(info);
m_buffer.enqueue(Flush{info});
m_buffer.flush();
m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::DevValue);
m_cv.wait(lock, [&]() {
check_worker_exc_unsafe();
return static_cast<bool>(info->ptr);
});
m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::DevValue);
m_waitee = nullptr;
return info->ptr->dev_tensor();
}

void ChannelImpl::sync() {
if (!m_buffer.empty()) {
m_buffer.enqueue(Flush{});
}
m_buffer.flush();
m_channel_state.profiler->record_host<SyncStartEvent>();
m_worker.wait_all_task_finish();
CompNode::sync_all();
m_channel_state.profiler->record_host<SyncFinishEvent>();
MGB_LOCK_GUARD(m_mutex);
check_worker_exc_unsafe();
}
@@ -321,33 +355,41 @@ void ChannelImpl::close() {
sync();
}

void ChannelImpl::config_async_level(int level) {
mgb_assert(level <= 2 && level >= 0, "async_level should be 0, 1 or 2");
m_async_level = level;
int ChannelImpl::get_option(std::string name) {
return m_channel_state.options.get_option(name);
}

int ChannelImpl::get_async_level() {
return m_async_level;
void ChannelImpl::set_option(std::string name, int value) {
m_channel_state.options.set_option(name, value);
m_buffer.enqueue(SetOption{name, value});
}

TensorInfo* ChannelImpl::alloc() {
MGB_LOCK_GUARD(m_mutex);
auto info = m_pool.alloc();
m_valid_handle.insert(info);
info->id = m_last_id++;
m_channel_state.profiler->record_host<TensorDeclareEvent>(info->id);
return info;
}

void ChannelImpl::free(TensorInfo* ptr) {
MGB_LOCK_GUARD(m_mutex);
m_channel_state.profiler->record_host<TensorEraseEvent>(ptr->id);
m_pool.free(ptr);
}

ChannelImpl::ChannelImpl() : m_worker(this), m_buffer(this){
m_channel_state.tid = std::this_thread::get_id();
}

ChannelImpl::~ChannelImpl() {
close();
}

void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
MGB_LOCK_GUARD(m_mutex);
m_worker_state.profiler->record_host<TensorProduceEvent>(dest->id, ptr->layout(), ptr->comp_node());
dest->value_fetched = ptr->value_fetched();
// update tensor desc for static infer
dest->desc.layout = ptr->layout();
@@ -397,55 +439,57 @@ void ChannelImpl::detach_users(TensorInfo* dest) {
output->detach_producer();
}
}
dest->users.clear();
mgb_assert(dest->users.size() == 0);
//dest->users.clear();
}

void ChannelImpl::set_swap_flag(bool flag) {
if ((!flag) && (m_enable_evict & SWAP)) {
for (auto handle: m_valid_handle) {
auto* info = reinterpret_cast<TensorInfo*>(handle);
if (info->evict_type == SWAP) {
swap_in(info);
}
void ChannelImpl::sync_device_scope(CompNode device) {
auto& prev = m_worker_state.device_scope_map[device];
auto& current = m_worker_state.scopes;
auto push_scope = [&](std::string name) {
m_worker_state.profiler->record_device<DeviceBeginScope>(device, name);
};
auto pop_scope = [&](std::string name) {
m_worker_state.profiler->record_device<DeviceEndScope>(device, name);
};
size_t similarity = 0;
for (size_t i = 0; i < prev.size() && i < current.size(); i++) {
if (prev[i] == current[i]) {
similarity++;
} else {
break;
}
}
if (flag) {
m_enable_evict |= SWAP;
} else {
m_enable_evict &= ~SWAP;
while (prev.size() > similarity) {
pop_scope(prev.back());
prev.pop_back();
}
}

void ChannelImpl::set_drop_flag(bool flag) {
if ((!flag) && (m_enable_evict & DROP)) {
for (auto handle: m_valid_handle) {
auto* info = reinterpret_cast<TensorInfo*>(handle);
if (info->evict_type == DROP) {
recompute(info->producer);
}
}
}
if (flag) {
m_enable_evict |= DROP;
} else {
m_enable_evict &= ~DROP;
while (prev.size() < current.size()) {
prev.push_back(current[prev.size()]);
push_scope(prev.back());
}
}
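
sync_device_scope reconciles a device's scope stack with the worker's current one: it keeps the common prefix, closes device scopes that are no longer active, then opens the newly entered ones. The same logic as a small Python sketch (push_scope/pop_scope stand in for the record_device calls):

    def sync_scopes(prev, current, push_scope, pop_scope):
        # length of the longest common prefix of the two scope stacks
        similarity = 0
        for a, b in zip(prev, current):
            if a != b:
                break
            similarity += 1
        while len(prev) > similarity:    # leave scopes that were exited
            pop_scope(prev.pop())
        while len(prev) < len(current):  # enter newly opened scopes
            prev.append(current[len(prev)])
            push_scope(prev[-1])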

void ChannelImpl::set_buffer_length(int length) {
m_buffer.set_capacity(length);
}

void ChannelImpl::process_one_task(Command& cmd) {
void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
m_worker_state.profiler->record_host<CommandExecuteEvent>(icmd);
bool finished = false;
auto do_finish_command = [&]{
if (finished) {
return;
}
m_worker_state.profiler->record_host<CommandFinishEvent>(icmd);
finished = true;
};
//TODO: remove std::visit for support osx 10.12
std::visit([this](auto& cmd) {
using T = std::remove_reference_t<decltype(cmd)>;
try {
auto cmd_visitor = [&](auto& cmd) {
using T = std::remove_reference_t<decltype(cmd)>;
if constexpr (std::is_same_v<T, Put>) {
auto value = cmd.no_cache ? std::make_shared<Tensor>(cmd.value) : Tensor::make(cmd.value);
produce_tensor(cmd.dest, std::move(value));
} else if constexpr (std::is_same_v<T, ApplyOp>) {
uint64_t apply_id = ++m_last_id;
SmallVector<TensorPtr> tensor_inputs;
SmallVector<CompNode> devices;
tensor_inputs.reserve(cmd.inputs.size());
// refcnt == 1, owners: [TensorInfo::ptr]
for (auto i : cmd.inputs) {
@@ -453,6 +497,23 @@ void ChannelImpl::process_one_task(Command& cmd) {
// refcnt ++, owners: [i->ptr, tensor_inputs]
tensor_inputs.push_back(i->ptr);
}
// Begin profiling operator
auto tinfo_to_tid = [&](SmallVector<TensorInfo*> tinfo) {
SmallVector<uint64_t> tid;
for (auto* ptinfo: tinfo) {
tid.push_back(ptinfo->id);
}
return tid;
};
OpEvent event_data = {apply_id, cmd.op, tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs)};
// Collecting devices
for (auto i : cmd.inputs) {
devices.push_back(i->desc.comp_node);
}
for (auto i : cmd.outputs) {
devices.push_back(i->desc.comp_node);
}
devices.erase(std::unique(devices.begin(), devices.end()), devices.end());
// Fused by command buffer. @see: CommandBuffer::fuse_del
// Now if dest is inplacable, it's refcnt would be decreased to 1 and owned by tensor_inputs after Del.
// Note for exprs like 'y = x op x', inplace is unsupported yet but Del would be also fused.
@@ -461,9 +522,24 @@ void ChannelImpl::process_one_task(Command& cmd) {
// if it's decreased to 1, would be detected at @see: proxy_graph_detail::apply_on_physical_tensor
free(del);
}
// Before wait
//TODO: split operator wait and execute so that OpWait could be correctly recorded.
// Before execute
m_worker_state.profiler->record_host<HostOpExecuteEvent>(event_data);
for (auto&& device: devices) {
sync_device_scope(device);
m_worker_state.profiler->record_device<DeviceOpExecuteEvent>(device, event_data);
}
// Apply op
// Here std::move is REQUIRED for removing duplicated references.
auto tensor_outputs = OpDef::apply_on_physical_tensor(
*cmd.op, std::move(tensor_inputs));
// After execute
m_worker_state.profiler->record_host<HostOpFinishEvent>(event_data);
for (auto&& device: devices) {
m_worker_state.profiler->record_device<DeviceOpFinishEvent>(device, event_data);
}
// End profiling operator
mgb_assert(tensor_outputs.size() == cmd.outputs.size());
for (size_t i = 0; i < tensor_outputs.size(); ++i) {
if (cmd.outputs[i] == nullptr) {
@@ -488,13 +564,51 @@ void ChannelImpl::process_one_task(Command& cmd) {
release_tensor(cmd.dest);
} else if constexpr (std::is_same_v<T, Drop>) {
release_tensor(cmd.dest);
} else if constexpr (std::is_same_v<T, Move>) {
produce_tensor(cmd.dest, cmd.src->ptr);
free(cmd.src);
} else if constexpr (std::is_same_v<T, SetOption>) {
m_worker_state.options.set_option(cmd.key, cmd.value);
} else if constexpr (std::is_same_v<T, StartProfile>) {
CompNode::sync_all();
m_worker_state.profiler.reset(cmd.profiler);
} else if constexpr (std::is_same_v<T, StopProfile>) {
for (auto&& [device, scopes]: m_worker_state.device_scope_map) {
MGB_MARK_USED_VAR(scopes);
sync_device_scope(device);
}
do_finish_command();
auto profiler = std::make_unique<InterpreterProfiler>();
std::swap(profiler, m_worker_state.profiler);
auto records = profiler->stop();
auto host_map = [this](std::thread::id tid) {
if (tid == m_channel_state.tid) {
return "channel";
} else if (tid == m_worker_state.tid) {
return "worker";
} else {
return "unknown";
}
};
InterpreterProfiler::dump_data(cmd.basename, cmd.format, records, profiler->get_option(), host_map);
} else if constexpr (std::is_same_v<T, PushScope>) {
m_worker_state.scopes.push_back(cmd.scope_name);
do_finish_command();
m_worker_state.profiler->record_host<WorkerBeginScope>(cmd.scope_name);
} else if constexpr (std::is_same_v<T, PopScope>) {
mgb_assert(m_worker_state.scopes.back() == cmd.scope_name, "scope name mismatch");
m_worker_state.scopes.pop_back();
do_finish_command();
m_worker_state.profiler->record_host<WorkerEndScope>(cmd.scope_name);
} else {
static_assert(std::is_same_v<T, Flush> ||
std::is_same_v<T, Nop>);
static_assert(std::is_same_v<T, T>);
}
};
std::visit([&](auto& cmd){
using T = std::decay_t<decltype(cmd)>;
if (!m_worker_state.options.catch_worker_execption) {
cmd_visitor(cmd);
return;
}
try {
cmd_visitor(cmd);
} catch (...) {
MGB_LOCK_GUARD(m_mutex);
if constexpr (std::is_same_v<T, ApplyOp>) {
@@ -507,7 +621,8 @@ void ChannelImpl::process_one_task(Command& cmd) {
m_worker_exc = std::current_exception();
m_cv.notify_all();
}
}, cmd);
}, icmd.second);
do_finish_command();
}

void ChannelImpl::check_worker_exc_unsafe() {
@@ -524,18 +639,22 @@ void ChannelImpl::CommandBuffer::enqueue(Command cmd) {
if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) {
return;
}
auto command_repr = std::visit([](auto& cmd){ return cmd.to_string(); }, cmd);
mgb_log_debug("%s Enqueued", command_repr.c_str());
mgb_log_debug("%s Enqueued", to_string(cmd).c_str());
m_commands.push_back(std::move(cmd));
auto flush_pos = flush_pos_for(m_commands.back());
flush(flush_pos);
}

void ChannelImpl::CommandBuffer::flush() {
flush(m_commands.end());
}

void ChannelImpl::CommandBuffer::flush(Handle pos) {
for (auto iter = m_commands.begin(); iter != pos; ++iter) {
auto command_repr = std::visit([](auto& cmd){ return cmd.to_string(); }, *iter);
mgb_log_debug("%s Flushed", command_repr.c_str());
m_owner->m_worker.add_task(std::move(*iter));
mgb_log_debug("%s Flushed", to_string(*iter).c_str());
IdentifiedCommand icmd{++m_owner->m_last_id, std::move(*iter)};
m_owner->m_channel_state.profiler->record_host<CommandEnqueueEvent>(icmd);
m_owner->m_worker.add_task(std::move(icmd));
}
m_commands.erase(m_commands.begin(), pos);
}
@@ -555,17 +674,10 @@ auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
}
} else if constexpr (std::is_same_v<T, GetValue>) {
return m_commands.end();
} else if constexpr (std::is_same_v<T, Flush>) {
if (cmd.dest == nullptr) {
return m_commands.end();
}
auto produce_iter = find_produce(cmd.dest, {m_commands.begin(), m_commands.end()});
if (produce_iter != m_commands.end()) {
return produce_iter + 1;
}
}
if (m_commands.size() > m_capacity) {
return m_commands.begin() + (m_commands.size() - m_capacity);
size_t buffer_length = m_owner->m_channel_state.options.buffer_length;
if (m_commands.size() > buffer_length) {
return m_commands.begin() + (m_commands.size() - buffer_length);
}
return m_commands.begin();
}, cmd);
@@ -589,7 +701,7 @@ bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) {
if (apply_iter == end || find_last_usage(dest, {apply_iter+1, end}) != end) {
return false;
}
mgb_log_debug("%s Fused", cmd.to_string().c_str());
mgb_log_debug("%s Fused", to_string(Command{cmd}).c_str());
std::get<ApplyOp>(*apply_iter).dels.push_back(dest);
return true;
}
@@ -636,3 +748,41 @@ auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range)
}, cmd);
});
}

void ChannelImpl::start_profile(std::unordered_map<std::string, int> option) {
auto profiler_option = InterpreterProfiler::Option::from_dict(option);
auto profiler = std::make_unique<InterpreterProfiler>();
profiler->set_option(profiler_option);
profiler->start(InterpreterProfiler::topic_to_mask(profiler_option.topic));
std::swap(profiler, m_channel_state.profiler);
m_buffer.enqueue(StartProfile{m_channel_state.profiler.get()});
}

void ChannelImpl::stop_profile(std::string basename, std::string format) {
m_buffer.flush();
auto profiler = std::make_unique<InterpreterProfiler>();
std::swap(profiler, m_channel_state.profiler);
profiler.release();
m_buffer.enqueue(StopProfile{basename, format});
}

void ChannelImpl::push_scope(std::string name) {
m_channel_state.profiler->record_host<ChannelBeginScope>(name);
m_channel_state.scopes.push_back(name);
m_buffer.enqueue(PushScope{name});
}

void ChannelImpl::pop_scope(std::string name) {
mgb_assert((!m_channel_state.scopes.empty()) && m_channel_state.scopes.back() == name, "scope name mismatch");
m_channel_state.scopes.pop_back();
m_channel_state.profiler->record_host<ChannelEndScope>(name);
m_buffer.enqueue(PopScope{name});
}

void ChannelImpl::assert_in_channel() {
mgb_assert(m_channel_state.tid != std::this_thread::get_id());
}

void ChannelImpl::assert_in_worker() {
mgb_assert(m_worker_state.tid == std::this_thread::get_id());
}

+205 -0 imperative/src/impl/interpreter/interpreter_impl.h

@@ -0,0 +1,205 @@
/**
* \file imperative/src/impl/interpreter/interpreter_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include <deque>
#include <future>
#include <list>
#include <thread>
#include <unordered_set>
#include <variant>

#include "megbrain/utils/mempool.h"
#include "megbrain/imperative/interpreter.h"
#include "megbrain/imperative/profiler.h"

#include "./commands.h"
#include "./events.h"
#include "./tensor_info.h"
#include "./option_manager.h"
#include "./profiler.h"

namespace mgb::imperative::interpreter::intl {

using Handle = Interpreter::Handle;

struct InterpreterImpl : Interpreter {
std::unique_ptr<Channel> create_channel() override;
};


struct ChannelImpl : Interpreter::Channel {
ChannelImpl();
~ChannelImpl() override;

Handle put(const HostTensorND& value, bool no_cache) override;
Handle put(const DeviceTensorND& value) override;

void del(Handle) override;
void swap_in(Handle) override;
void swap_out(Handle) override;
void drop(Handle) override;

SmallVector<Handle> apply_op(
std::shared_ptr<OpDef> op,
const SmallVector<Handle>& inputs) override;

HostTensorND get_value(Handle) override;
TensorShape get_shape(Handle) override;
DType get_dtype(Handle) override;
CompNode get_device(Handle) override;

DeviceTensorND get_dev_tensor(Handle) override;

void sync() override;
void close() override;

int get_option(std::string name) override;
void set_option(std::string name, int value) override;

void start_profile(std::unordered_map<std::string, int> option) override;
void stop_profile(std::string basename, std::string format) override;

void push_scope(std::string) override;
void pop_scope(std::string) override;
private:
TensorInfo* alloc();
void free(TensorInfo*);
void detach_users(TensorInfo*);

void process_one_task(IdentifiedCommand&);

void check_worker_exc_unsafe();

void produce_tensor(TensorInfo* dest, TensorPtr ptr);

void release_tensor(TensorInfo* dest);

void regenerate(TensorInfo* dest);
void recompute(TensorInfo::ComputePath* path);

void dispatch_default_cpu(
std::shared_ptr<OpDef> op,
const SmallVector<TensorInfo*>& input_infos,
const SmallVector<LogicalTensorDesc>& input_descs,
SmallVector<Handle>* outputs);
void dispatch_kernel(
std::shared_ptr<OpDef> op,
const SmallVector<TensorInfo*>& input_infos,
const SmallVector<LogicalTensorDesc>& input_descs,
SmallVector<Handle>* outputs);

void assert_in_channel();
void assert_in_worker();

void sync_device_scope(CompNode device);

template <typename TCommand>
void enqueue_command(TCommand&& cmd) {
m_buffer.enqueue(Command{std::forward<TCommand>(cmd)});
}

std::mutex m_mutex;
std::condition_variable m_cv;
MemPool<TensorInfo> m_pool;
std::unordered_set<Handle> m_valid_handle;
TensorInfo* m_waitee = nullptr;
std::exception_ptr m_worker_exc;
std::atomic_uint64_t m_last_id = 0;

struct WorkQueue : AsyncQueueSC<IdentifiedCommand, WorkQueue> {
// set max_spin=0 to prevent the queue from fetching tasks in a busy-wait manner.
// this won't affect throughput when the python interpreter is sending enough tasks,
// but will significantly save CPU time when waiting for tasks, e.g. waiting for data input
WorkQueue(ChannelImpl* owner)
: AsyncQueueSC<IdentifiedCommand, WorkQueue>(0), m_owner(owner) {
sys::set_thread_name("interpreter");
}
void process_one_task(IdentifiedCommand& icmd) {
m_owner->process_one_task(icmd);
}
void on_async_queue_worker_thread_start() override {
sys::set_thread_name("worker");
m_owner->m_worker_state.tid = std::this_thread::get_id();
}
private:
ChannelImpl* m_owner;
} m_worker;

/**
* Buffer a window of recent commands for later fusion
* example:
* ---------------------------------------------------------------------
* | ..., Apply{in: (i0, i1), out: (o0, o1)}, ... + Del{i0} + Del{i1} |
* ---------------------------------------------------------------------
* | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0)}, ... + Del{i1} |
* ---------------------------------------------------------------------
* | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0, i1)}, ... |
* ---------------------------------------------------------------------
* Then the fused Apply may be invoked inplace. see: ChannelImpl::process_one_task
*/
struct CommandBuffer {
CommandBuffer(ChannelImpl* owner) : m_owner(owner) {}
void enqueue(Command cmd);
bool empty() const {
return m_commands.empty();
}
void flush();
private:
ChannelImpl* m_owner;
std::deque<Command> m_commands;

using Handle = decltype(m_commands)::iterator;
// [begin, end)
using Range = std::array<Handle, 2>;

// Launch commands in range [m_commands.begin(), pos)
void flush(Handle pos);
// Select flush position for incoming cmd
Handle flush_pos_for(const Command& cmd);
// Fuse del command into suitable ApplyOp
bool fuse_del(const Del& cmd);
// Returns the last position within range where dest is used. If dest is not used, returns range[1]
Handle find_last_usage(TensorInfo* dest, Range range);
// Returns the produce position of dest. If not found, returns range[1]
Handle find_produce(TensorInfo* dest, Range range);
} m_buffer;

//! config whether to raise errors exactly when invoking an op.
//! level 2: both device and user side errors are async;
//! level 1: user side errors are sync;
//! level 0: both sync.
int m_async_level = 2;
int m_max_recompute_time = 1;

struct State {
std::thread::id tid;
OptionManager options;
std::vector<std::string> scopes;
std::unique_ptr<InterpreterProfiler> profiler;

State() {
profiler = std::make_unique<InterpreterProfiler>();
}
};

struct ChannelState: State {};

struct WorkerState: State {
CompNode::UnorderedMap<std::vector<std::string>> device_scope_map;
};

ChannelState m_channel_state;
WorkerState m_worker_state;
};

} // namespace mgb::imperative::interpreter::intl

+61 -0 imperative/src/impl/interpreter/option_manager.h

@@ -0,0 +1,61 @@
/**
* \file imperative/src/impl/interpreter/option_manager.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include <string>
#include <unordered_map>

#include "megbrain/common.h"

namespace mgb::imperative::interpreter::intl {

struct OptionManager {
private:
std::unordered_map<std::string, int*> m_option_map = {};
public:
#define DEF_OPTION(name, env_key, default_value, desc) \
int name = (m_option_map[#name]=&name, get_option_from_env(env_key, default_value));

DEF_OPTION(async_level, "MEGENGINE_INTERP_ASYNC_LEVEL", 2,
"config whether raise error exactly when invoking op.\n"
"level 2: both device and user side errors are async;\n"
"level 1: user side errors are sync;\n"
"level 0: both sync.");
DEF_OPTION(enable_swap, "MEGENGINE_ENABLE_SWAP", 0, "");
DEF_OPTION(enable_drop, "MEGENGINE_ENABLE_DROP", 0, "");
DEF_OPTION(max_recompute_time, "MEGENGINE_MAX_RECOMP_TIME", 1, "");
DEF_OPTION(catch_worker_execption, "MEGENGINE_CATCH_WORKER_EXEC", 1,
"catch worker exception if enabled, close it when debugging");
DEF_OPTION(buffer_length, "MEGENGINE_COMMAND_BUFFER_LENGTH", 3,
"set command buffer length.");
DEF_OPTION(enable_host_compute, "MEGENGINE_HOST_COMPUTE", 1,
"enable host compute, thus computation may be done in host event if it's device is gpu.");

#undef DEF_OPTION

void set_option(const std::string& name, int value) {
*m_option_map[name] = value;
}

int get_option(const std::string& name) const {
return *m_option_map.at(name);
}

static int get_option_from_env(const std::string& name, int default_value) {
if (const char* env_val = MGB_GETENV(name.c_str())) {
default_value = std::atoi(env_val);
}
return default_value;
}
};

}
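
Each option also carries an environment key, read through MGB_GETENV when the OptionManager is constructed, so defaults can be overridden without code changes. A sketch (assumption: the variables must be set before megengine is imported, since the channel and its OptionManager are created during import):

    import os

    # Assumed to take effect only if set before importing megengine.
    os.environ["MEGENGINE_INTERP_ASYNC_LEVEL"] = "0"
    os.environ["MEGENGINE_COMMAND_BUFFER_LENGTH"] = "0"

    import megengine  # OptionManager picks up the overridden defaults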

+280 -0 imperative/src/impl/interpreter/profiler.cpp

@@ -0,0 +1,280 @@
/**
* \file imperative/src/impl/interpreter/profiler.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "./profiler.h"

#include <sstream>
#include <cinttypes>

#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__))
#include <unistd.h>
#elif defined(_WIN32)
#include <process.h>
#else
#error Unsupported platform
#endif

#include "../op_trait.h"

namespace mgb::imperative::interpreter::intl {

namespace {

struct InterpreterProfilerDumpChromeTimelineContext {
// either host_thread(std::thread::id) or device_thread(CompNode)
using Thread = std::variant<std::thread::id, CompNode>;

// input params
std::string base_name;
std::string format;
InterpreterProfiler::Data profile_data;
InterpreterProfiler::Option option;
std::function<std::string(std::thread::id)> host_map;

// internal states
decltype(getpid()) pid;
CompNode::UnorderedMap<std::map<double, CompNode::Event*>> device_sync_map;
SmallVector<Thread> thread_list;
double time_start;
// options
bool show_operator_name;
// results
ChromeTraceEventList event_list;

InterpreterProfilerDumpChromeTimelineContext(
std::string base_name,
std::string format,
InterpreterProfiler::Data profile_data,
InterpreterProfiler::Option option,
std::function<std::string(std::thread::id)> host_map)
: base_name{base_name}, format{format}, profile_data{profile_data}, option{option}, host_map{host_map} {
pid = getpid();
time_start = option.align_time ? time_start : 0;
show_operator_name = option.show_operator_name;
}

// get device time from event
double get_device_time(CompNode::Event* device_event, double host_time) {
device_event->host_wait();
auto& sync_map = device_sync_map[device_event->comp_node()];
// find sync point
auto iter = sync_map.begin();
auto sync_current = [&] {
iter = sync_map.insert(iter, {host_time, device_event});
return host_time;
};
if (iter == sync_map.end()) {
// not found, insert sync
return sync_current();
}
auto& [base_time, base] = *iter;
// calculate elapsed time
double delta_time = base->elapsed_time_until(*device_event) * 1e3;
return base_time + delta_time;
};

template <typename T>
size_t get_tid(T t) {
for (size_t i = 0; i < thread_list.size(); i++) {
if (thread_list[i] == Thread{t}) {
return i;
}
}
thread_list.push_back(t);
return thread_list.size() - 1;
};

ChromeTraceEvent& new_event(std::string name, char ph, uint64_t tid, double ts) {
return event_list.new_event().name(name).ph(ph).tid(tid).ts(ts).pid(pid);
};

// convert Command to a json object. Has to be a callable object
static auto constexpr cmd_to_args = [](auto&& cmd) {
auto args = json::Object::make();
cmd.get_props([&](const char* key, auto&& value){
(*args)[key] = json::String::make(to_string(value));
});
(*args)["__type__"] = json::String::make(typeid(cmd).name());
return args;
};

void process() {
// enumerate and process each record
for (auto&& record: profile_data.records) {
std::visit([this](auto& record){
using TEvent = std::decay_t<decltype(record.data)>;
Session<TEvent>(*this, record).process();
}, record);
}
for (size_t tid = 0; tid < thread_list.size(); ++tid) {
auto tname = std::visit([&](auto& host_or_device) -> std::string{
using T = std::decay_t<decltype(host_or_device)>;
if constexpr (std::is_same_v<T, std::thread::id>) {
// take name from host_map
return host_map(host_or_device);
} else {
// use CompNode::to_string
return host_or_device.to_string();
}
}, thread_list[tid]);
// assign thread name
new_event("thread_name", 'M', tid, 0)
.arg("name", tname);
}
        // write output to file
std::string out_buf;
event_list.to_json()->writeto(out_buf, 4);
std::ofstream output_stream;
output_stream.open(base_name + "." + format);
output_stream << out_buf;
output_stream.flush();
output_stream.close();
}

template <typename TEvent>
struct Session {
InterpreterProfilerDumpChromeTimelineContext& ctx;
ProfilerBase::EventRecord<TEvent>& record;
TEvent& data;

Session(InterpreterProfilerDumpChromeTimelineContext& ctx,
ProfilerBase::EventRecord<TEvent>& record)
: ctx{ctx}, record{record}, data{record.data} {}

uint64_t get_host_tid() {
return ctx.get_tid(record.host().tid);
};
double get_host_ts() {
return (ctx.time_start + record.host().time) * 1e3;
};
uint64_t get_device_tid() {
return ctx.get_tid(record.device().event->comp_node());
};
double get_device_ts() {
return (ctx.time_start + ctx.get_device_time(record.device().event.get(), record.device().after)) * 1e3;
};
ChromeTraceEvent& new_host_event(std::string name, char ph) {
return ctx.new_event(std::move(name), ph, get_host_tid(), get_host_ts());
};
ChromeTraceEvent& new_device_event(std::string name, char ph) {
return ctx.new_event(std::move(name), ph, get_device_tid(), get_device_ts());
};

void process() {
// dispatch event by type
if constexpr (std::is_same_v<TEvent, CommandEnqueueEvent>) {
auto args = std::visit(cmd_to_args, data.icmd.second);
new_host_event("CommandEnqueue", 'X').dur(0).args(args);
} else if constexpr (std::is_same_v<TEvent, CommandExecuteEvent>) {
auto args = std::visit(cmd_to_args, data.icmd.second);
new_host_event("CommandExecute", 'B').args(args);
} else if constexpr (std::is_same_v<TEvent, CommandFinishEvent>) {
new_host_event("CommandExecute", 'E');
} else if constexpr (std::is_same_v<TEvent, HostOpExecuteEvent>) {
auto args = json::Object::make();
auto props = OpDef::props(*data.op);
auto name = data.op->trait()->name;
for (auto&& [prop_name, prop_val]: props) {
(*args)[std::string("op.") + prop_name] = json::String::make(prop_val);
}
(*args)["name"] = json::String::make(name);
(*args)["id"] = json::Number::make(data.id);
(*args)["inputs"] = json::String::make(to_string(data.inputs));
(*args)["outputs"] = json::String::make(to_string(data.outputs));
new_host_event(ctx.show_operator_name ? name : "OpExecute", 'B').args(args);
} else if constexpr (std::is_same_v<TEvent, DeviceOpExecuteEvent>) {
auto args = json::Object::make();
auto props = OpDef::props(*data.op);
auto name = data.op->trait()->name;
for (auto&& [prop_name, prop_val]: props) {
(*args)[std::string("op.") + prop_name] = json::String::make(prop_val);
}
(*args)["name"] = json::String::make(name);
(*args)["id"] = json::Number::make(data.id);
(*args)["inputs"] = json::String::make(to_string(data.inputs));
(*args)["outputs"] = json::String::make(to_string(data.outputs));
new_device_event(ctx.show_operator_name ? name : "OpExecute", 'B').args(args);
} else if constexpr (std::is_same_v<TEvent, HostOpFinishEvent>) {
auto name = data.op->trait()->name;
new_host_event(ctx.show_operator_name ? name : "OpExecute", 'E');
} else if constexpr (std::is_same_v<TEvent, DeviceOpFinishEvent>) {
auto name = data.op->trait()->name;
new_device_event(ctx.show_operator_name ? name : "OpExecute", 'E');
} else if constexpr (std::is_same_v<TEvent, TensorDeclareEvent>) {
new_host_event("TensorLifetime", 'N').id(data.tensor_id);
} else if constexpr (std::is_same_v<TEvent, TensorProduceEvent>) {
auto snapshot = json::Object::make();
(*snapshot)["shape"] = json::String::make(to_string((TensorShape)data.layout));
(*snapshot)["dtype"] = json::String::make(to_string(data.layout.dtype));
(*snapshot)["device"] = json::String::make(to_string(data.device));
new_host_event("TensorLifetime", 'O').id(data.tensor_id).arg("snapshot", snapshot);
} else if constexpr (std::is_same_v<TEvent, TensorEraseEvent>) {
new_host_event("TensorLifetime", 'D').id(data.tensor_id);
} else if constexpr (std::is_same_v<TEvent, TensorGetPropEvent>) {
auto args = json::Object::make();
(*args)["id"] = json::Number::make(data.tensor_id);
(*args)["prop"] = json::String::make(to_string(data.prop));
(*args)["prop_desc"] = json::String::make(data.prop_desc);
new_host_event("TensorGetProp", 'X').dur(0).args(args);
} else if constexpr (std::is_same_v<TEvent, TensorNotifyPropEvent>) {
// TODO
} else if constexpr (std::is_same_v<TEvent, TensorWaitPropEvent>) {
auto args = json::Object::make();
(*args)["id"] = json::Number::make(data.tensor_id);
(*args)["prop"] = json::String::make(to_string(data.prop));
(*args)["prop_desc"] = json::String::make(data.prop_desc);
new_host_event("TensorWaitProp", 'B').args(args);
} else if constexpr (std::is_same_v<TEvent, TensorWaitPropFinishEvent>) {
auto args = json::Object::make();
(*args)["id"] = json::Number::make(data.tensor_id);
(*args)["prop"] = json::String::make(to_string(data.prop));
(*args)["prop_desc"] = json::String::make(data.prop_desc);
new_host_event("TensorWaitProp", 'E').args(args);
} else if constexpr (std::is_same_v<TEvent, SyncStartEvent>) {
new_host_event("SyncEvent", 'B');
} else if constexpr (std::is_same_v<TEvent, SyncFinishEvent>) {
new_host_event("SyncEvent", 'E');
} else if constexpr (std::is_same_v<TEvent, ChannelBeginScope>) {
new_host_event(data.name, 'B');
} else if constexpr (std::is_same_v<TEvent, ChannelEndScope>) {
new_host_event(data.name, 'E');
} else if constexpr (std::is_same_v<TEvent, WorkerBeginScope>) {
new_host_event(data.name, 'B');
} else if constexpr (std::is_same_v<TEvent, WorkerEndScope>) {
new_host_event(data.name, 'E');
} else if constexpr (std::is_same_v<TEvent, DeviceBeginScope>) {
new_device_event(data.name, 'B');
} else if constexpr (std::is_same_v<TEvent, DeviceEndScope>) {
new_device_event(data.name, 'E');
} else {
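                // dependent always-false assert: compilation fails here only if an event type above is unhandled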
static_assert(!std::is_same_v<TEvent, TEvent>);
}
}
};
};

}

void InterpreterProfiler::dump_data(
std::string basename,
std::string format,
InterpreterProfiler::Data profile_data,
const InterpreterProfiler::Option& option,
std::function<std::string(std::thread::id)> host_map) {
InterpreterProfilerDumpChromeTimelineContext{
basename, format, profile_data, option, host_map
}.process();
}

}
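
A minimal sketch of the builder chain this dump context drives (the event values are illustrative, not part of the diff; the emitted JSON is the Chrome trace event format and loads in chrome://tracing):

    ChromeTraceEventList list;
    list.new_event().name("OpExecute").ph('B').tid(0).ts(0.0).pid(getpid());
    list.new_event().name("OpExecute").ph('E').tid(0).ts(1.5).pid(getpid());
    std::string out;
    list.to_json()->writeto(out, 4);  // same serialization call as in process()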

+ 97
- 0
imperative/src/impl/interpreter/profiler.h View File

@@ -0,0 +1,97 @@
/**
* \file imperative/src/impl/interpreter/profiler.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include "megbrain/imperative/profiler.h"

#include "./commands.h"
#include "./events.h"
#include "./option_manager.h"

namespace mgb::imperative::interpreter::intl {

class InterpreterProfiler: public Profiler<
CommandEnqueueEvent, CommandExecuteEvent, CommandFinishEvent,
HostOpExecuteEvent, HostOpFinishEvent,
DeviceOpExecuteEvent, DeviceOpFinishEvent,
TensorDeclareEvent, TensorProduceEvent, TensorEraseEvent,
TensorGetPropEvent, TensorWaitPropEvent, TensorNotifyPropEvent, TensorWaitPropFinishEvent,
SyncStartEvent, SyncFinishEvent,
ChannelBeginScope, ChannelEndScope,
WorkerBeginScope, WorkerEndScope,
DeviceBeginScope, DeviceEndScope> {
    /* 22 event types so far; an enum-based encoding may scale better */

public:
enum Topic {
Command = 0b000001,
Operator = 0b000010,
TensorLifetime = 0b000100,
TensorProp = 0b001000,
Sync = 0b010000,
Scope = 0b100000,
};

struct Option {
Topic topic;
bool align_time;
bool show_operator_name;

static Option from_dict(std::unordered_map<std::string, int> dict) {
Option option;
option.topic = Topic(dict.at("topic"));
option.align_time = bool(dict.at("align_time"));
option.show_operator_name = bool(dict.at("show_operator_name"));
return option;
}
};

Option get_option() const {
return m_option;
}

void set_option(const Option& option) {
m_option = option;
}

    static void dump_data(
            std::string basename, std::string format,
            InterpreterProfiler::Data profile_data, const Option& option,
            std::function<std::string(std::thread::id)> host_map);

static Mask topic_to_mask(Topic topic) {
Mask result;
if (topic & Command) {
result |= mask_of<CommandEnqueueEvent, CommandExecuteEvent, CommandFinishEvent>();
}
if (topic & Operator) {
result |= mask_of<HostOpExecuteEvent, HostOpFinishEvent>();
result |= mask_of<DeviceOpExecuteEvent, DeviceOpFinishEvent>();
}
if (topic & TensorLifetime) {
result |= mask_of<TensorDeclareEvent, TensorProduceEvent, TensorEraseEvent>();
}
if (topic & TensorProp) {
result |= mask_of<TensorGetPropEvent, TensorWaitPropEvent, TensorNotifyPropEvent, TensorWaitPropFinishEvent>();
}
if (topic & Sync) {
result |= mask_of<SyncStartEvent, SyncFinishEvent>();
}
if (topic & Scope) {
result |= mask_of<ChannelBeginScope, ChannelEndScope, WorkerBeginScope, WorkerEndScope>();
result |= mask_of<DeviceBeginScope, DeviceEndScope>();
}
return result;
}

private:
Option m_option;
};

}
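
A quick sketch of how profiling options become an event mask (field values are illustrative):

    InterpreterProfiler::Option opt;
    opt.topic = InterpreterProfiler::Topic(
            InterpreterProfiler::Command | InterpreterProfiler::Operator);
    opt.align_time = true;
    opt.show_operator_name = true;
    // bits set: the 3 Command* events plus the 4 host/device op execute/finish events
    auto mask = InterpreterProfiler::topic_to_mask(opt.topic);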

+ 135
- 0
imperative/src/impl/interpreter/tensor_info.h View File

@@ -0,0 +1,135 @@
/**
* \file imperative/src/impl/interpreter/tensor_info.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include "megbrain/imperative/physical_tensor.h"
#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/utils/to_string.h"

namespace mgb::imperative {

namespace interpreter::intl {

enum EvictType {
NONE = 0,
SWAP = 1,
DROP = 2,
};

struct TensorInfo;
using TensorInfoPtr = std::shared_ptr<TensorInfo>;

struct TensorInfo {
enum Prop {
Device, Shape, DType, DevValue, HostValue
};

uint64_t id;
TensorPtr ptr;
LogicalTensorDesc desc;

// FIXME: broken by drop
bool value_fetched = false;
bool invalid = false;
bool allow_delete = false;

EvictType evict_type = NONE;

HostTensorND h_value;

// reserved for auto drop
size_t pinned = 0;
size_t recompute_times = 0;

struct ComputePath {
std::shared_ptr<OpDef> op;
SmallVector<TensorInfo*> inputs;
SmallVector<TensorInfo*> unique_inputs;
SmallVector<TensorInfo*> outputs;

size_t ref_cnt() {
return outputs.size() - std::count(outputs.begin(), outputs.end(), nullptr);
}

static ComputePath* make(std::shared_ptr<OpDef> op, SmallVector<TensorInfo*> inputs, SmallVector<TensorInfo*> outputs) {
auto* path = new TensorInfo::ComputePath();
path->op = op;
path->inputs = inputs;
path->outputs = outputs;
// dedup
SmallVector<TensorInfo*> unique_inputs = inputs;
std::sort(unique_inputs.begin(), unique_inputs.end());
unique_inputs.erase(std::unique(unique_inputs.begin(), unique_inputs.end()), unique_inputs.end());
path->unique_inputs = unique_inputs;
// attach users
for (auto input: unique_inputs) {
input->users.push_back(path);
}
// attach producer
for (auto output: outputs) {
output->producer = path;
}
return path;
}
}* producer = nullptr;

void pin() {
++pinned;
}

void unpin() {
--pinned;
}

void detach_producer() {
if (!producer) {
return;
}
auto output = std::find(producer->outputs.begin(), producer->outputs.end(), this);
mgb_assert(output != producer->outputs.end());
*output = nullptr;
if (producer->ref_cnt() == 0) {
for (auto* input: producer->unique_inputs) {
input->users.erase(std::find(input->users.begin(), input->users.end(), producer));
}
delete producer;
}
producer = nullptr;
}

SmallVector<ComputePath*> users;
};
}

template <>
struct ToStringTrait<interpreter::intl::TensorInfo::Prop>{
using TensorInfo = interpreter::intl::TensorInfo;

std::string operator()(TensorInfo::Prop prop) const {
switch(prop) {
case TensorInfo::DType:
return "dtype";
case TensorInfo::DevValue:
return "dev_value";
case TensorInfo::Device:
return "device";
case TensorInfo::HostValue:
return "host_value";
case TensorInfo::Shape:
return "shape";
default:
return "unknown";
}
}
};

}
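
A small sketch of the producer/user bookkeeping above (allocation is simplified: the interpreter actually draws TensorInfo from a MemPool, and `op` stands in for a real OpDef):

    std::shared_ptr<OpDef> op = /* some op */;
    auto* a = new TensorInfo;  // input, used twice
    auto* b = new TensorInfo;  // output
    auto* path = TensorInfo::ComputePath::make(op, {a, a}, {b});
    // `a` is deduplicated into unique_inputs, so it gains exactly one user entry
    mgb_assert(a->users.size() == 1 && b->producer == path);
    b->detach_producer();      // ref_cnt hits 0: path unlinks itself and is deleted
    mgb_assert(a->users.empty() && b->producer == nullptr);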

+ 0
- 351
imperative/src/impl/interpreter_impl.h View File

@@ -1,351 +0,0 @@
/**
* \file imperative/src/impl/interpreter_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include <deque>
#include <future>
#include <list>
#include <unordered_set>
#include <variant>

#include "megbrain/utils/mempool.h"
#include "megbrain/imperative/interpreter.h"

namespace mgb::imperative::interpreter::intl {

using Handle = Interpreter::Handle;

struct InterpreterImpl : Interpreter {
std::unique_ptr<Channel> create_channel() override;
};

enum EvictType {
NONE = 0,
SWAP = 1,
DROP = 2,
};

struct TensorInfo;
using TensorInfoPtr = std::shared_ptr<TensorInfo>;

struct TensorInfo {
TensorPtr ptr;
LogicalTensorDesc desc;

// FIXME: broken by drop
bool value_fetched = false;
bool invalid = false;

EvictType evict_type = NONE;

HostTensorND h_value;

// reserved for auto drop
size_t pinned = 0;
size_t recompute_times = 0;

struct ComputePath {
std::shared_ptr<OpDef> op;
SmallVector<TensorInfo*> inputs;
SmallVector<TensorInfo*> unique_inputs;
SmallVector<TensorInfo*> outputs;

size_t ref_cnt() {
return outputs.size() - std::count(outputs.begin(), outputs.end(), nullptr);
}

static ComputePath* make(std::shared_ptr<OpDef> op, SmallVector<TensorInfo*> inputs, SmallVector<TensorInfo*> outputs) {
auto* path = new TensorInfo::ComputePath();
path->op = op;
path->inputs = inputs;
path->outputs = outputs;
// dedup
SmallVector<TensorInfo*> unique_inputs = inputs;
std::sort(unique_inputs.begin(), unique_inputs.end());
unique_inputs.erase(std::unique(unique_inputs.begin(), unique_inputs.end()), unique_inputs.end());
path->unique_inputs = unique_inputs;
// attach users
for (auto input: unique_inputs) {
input->users.push_back(path);
}
// attach producer
for (auto output: outputs) {
output->producer = path;
}
return path;
}
}* producer = nullptr;

void pin() {
++pinned;
}

void unpin() {
--pinned;
}

void detach_producer() {
if (!producer) {
return;
}
auto output = std::find(producer->outputs.begin(), producer->outputs.end(), this);
mgb_assert(output != producer->outputs.end());
*output = nullptr;
if (producer->ref_cnt() == 0) {
for (auto* input: producer->unique_inputs) {
input->users.erase(std::find(input->users.begin(), input->users.end(), producer));
}
delete producer;
}
producer = nullptr;
}

SmallVector<ComputePath*> users;

};

struct Put {
TensorInfo* dest;
HostTensorND value;
bool no_cache = false;

std::string to_string() const { return ssprintf("Command: Put %p", dest); }
};
struct ApplyOp {
std::shared_ptr<OpDef> op;
SmallVector<TensorInfo*> inputs;
SmallVector<TensorInfo*> outputs;
SmallVector<TensorInfo*> dels;

std::string to_string() const {
std::string builder{"Command: ApplyOp {"};
builder += "inputs [";
for (auto* input : inputs) {
builder += ssprintf("%p, ", input);
}
builder += "], outputs [";
for (auto* output : outputs) {
builder += ssprintf("%p, ", output);
}
builder += "], dels [";
for (auto* del : dels) {
builder += ssprintf("%p, ", del);
}
builder += "]";
return builder;
}
};
struct Del {
TensorInfo* dest;

std::string to_string() const { return ssprintf("Command: Del %p", dest); }
};
struct GetValue {
TensorInfo* dest;

std::string to_string() const {
return ssprintf("Command: GetValue %p", dest);
}
};
struct SwapIn {
TensorInfo* dest;

std::string to_string() const {
return ssprintf("Command: SwapIn %p", dest);
}
};
struct SwapOut {
TensorInfo* dest;

std::string to_string() const {
return ssprintf("Command: SwapOut %p", dest);
}
};
struct Drop {
TensorInfo* dest;

std::string to_string() const {
return ssprintf("Command: Drop %p", dest);
}
};
struct Move {
TensorInfo* src;
TensorInfo* dest;

std::string to_string() const {
return ssprintf("Command: Move %s to %s",
src->desc.layout.to_string().c_str(),
dest->desc.layout.to_string().c_str());
}
};
struct Flush {
TensorInfo* dest = nullptr;

std::string to_string() const {
return ssprintf("Command: Flush %p", dest);
}
};
struct Nop {
std::string to_string() const { return "Command: Nop"; }
};
using Command = std::variant<Put,
ApplyOp,
Del,
GetValue,
SwapIn,
SwapOut,
Drop,
Move,
Flush,
Nop>;

struct ChannelImpl : Interpreter::Channel {
ChannelImpl() : m_worker(this), m_buffer(this) {}
~ChannelImpl() override;

Handle put(const HostTensorND& value, bool no_cache) override;
Handle put(const DeviceTensorND& value) override;

void del(Handle) override;
void swap_in(Handle) override;
void swap_out(Handle) override;
void drop(Handle) override;

SmallVector<Handle> apply_op(
std::shared_ptr<OpDef> op,
const SmallVector<Handle>& inputs) override;

HostTensorND get_value(Handle) override;
TensorShape get_shape(Handle) override;
DType get_dtype(Handle) override;
CompNode get_device(Handle) override;

DeviceTensorND get_dev_tensor(Handle) override;

void sync() override;
void close() override;
void set_swap_flag(bool) override;
void set_drop_flag(bool) override;
void set_buffer_length(int) override;

void config_async_level(int level) override;
int get_async_level() override;

private:
TensorInfo* alloc();
void free(TensorInfo*);
void detach_users(TensorInfo*);

void process_one_task(Command&);

void check_worker_exc_unsafe();

void produce_tensor(TensorInfo* dest, TensorPtr ptr);

void release_tensor(TensorInfo* dest);

void regenerate(TensorInfo* dest);
void recompute(TensorInfo::ComputePath* path);

void dispatch_default_cpu(
std::shared_ptr<OpDef> op,
const SmallVector<TensorInfo*>& input_infos,
const SmallVector<LogicalTensorDesc>& input_descs,
SmallVector<Handle>* outputs);
void dispatch_kernel(
std::shared_ptr<OpDef> op,
const SmallVector<TensorInfo*>& input_infos,
const SmallVector<LogicalTensorDesc>& input_descs,
SmallVector<Handle>* outputs);

std::mutex m_mutex;
std::condition_variable m_cv;
MemPool<TensorInfo> m_pool;
std::unordered_set<Handle> m_valid_handle;
TensorInfo* m_waitee = nullptr;
std::exception_ptr m_worker_exc;
size_t m_enable_evict = 0;

struct WorkQueue : AsyncQueueSC<Command, WorkQueue> {
// set max_spin=0 to prevent Queue fetch task in busy wait manner.
// this won't affect throughput when python interpreter is sending enough task,
// but will significantly save CPU time when waiting for task, e.g. wait for data input
WorkQueue(ChannelImpl* owner)
: AsyncQueueSC<Command, WorkQueue>(0), m_owner(owner) {
sys::set_thread_name("interpreter");
}
void process_one_task(Command& cmd) {
m_owner->process_one_task(cmd);
}
void on_async_queue_worker_thread_start() override {
sys::set_thread_name("worker");
}
private:
ChannelImpl* m_owner;
} m_worker;

/**
* Buf a command window for following fuse
* example:
* ---------------------------------------------------------------------
* | ..., Apply{in: (i0, i1), out: (o0, o1)}, ... + Del{i0} + Del{i1} |
* ---------------------------------------------------------------------
* | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0)}, ... + Del{i1} |
* ---------------------------------------------------------------------
* | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0, i1)}, ... |
* ---------------------------------------------------------------------
* Then the fused Apply may be invoked inplace. see: ChannelImpl::process_one_task
*/
struct CommandBuffer {
CommandBuffer(ChannelImpl* owner) : m_owner(owner) {
int capacity = 3;
if(const char* capacity_str = MGB_GETENV("MEGENGINE_COMMAND_BUFFER_LENGTH")) {
capacity = atoi(capacity_str);
}
set_capacity(capacity);
}
void enqueue(Command cmd);
bool empty() const {
return m_commands.empty();
}
void set_capacity(int capacity) {
mgb_assert(capacity >= 0 && capacity < 100, "invalid command buffer length");
m_capacity = capacity;
}
private:
ChannelImpl* m_owner;
size_t m_capacity;
std::deque<Command> m_commands;

using Handle = decltype(m_commands)::iterator;
// [begin, end)
using Range = std::array<Handle, 2>;

// Launch commands in range [m_commands.begin(), pos)
void flush(Handle pos);
// Select flush position for incoming cmd
Handle flush_pos_for(const Command& cmd);
// Fuse del command into suitable ApplyOp
bool fuse_del(const Del& cmd);
// Returns the last handle that dest is used within range. If dest is not used, returns range[1]
Handle find_last_usage(TensorInfo* dest, Range range);
// Returns the produce position of dest. If not found, returns range[1]
Handle find_produce(TensorInfo* dest, Range range);
} m_buffer;

//! config whether raise error exactly when invoking op.
//! level 2: both device and user side errors are async;
//! level 1: user side errors are sync;
//! level 0: both sync.
int m_async_level = 2;
int m_max_recompute_time = 1;
};

} // namespace mgb::imperative::interpreter::intl
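
The Del-fusion window documented in the removed CommandBuffer comment presumably lives on in the new commands.h; its core rule as a standalone sketch (simplified: the real fuse_del also verifies that `dest` has no later use inside the window):

    bool try_fuse_del(std::deque<Command>& window, const Del& del) {
        for (auto it = window.rbegin(); it != window.rend(); ++it) {
            if (auto* apply = std::get_if<ApplyOp>(&*it)) {
                auto& in = apply->inputs;
                if (std::find(in.begin(), in.end(), del.dest) != in.end()) {
                    apply->dels.push_back(del.dest);  // enables in-place execution
                    return true;
                }
            }
        }
        return false;  // no matching ApplyOp: enqueue the Del as-is
    }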

+ 20
- 0
imperative/src/impl/op_def.cpp View File

@@ -70,6 +70,26 @@ BackwardGraphResult OpDef::make_backward_graph(
return def.trait()->make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
}

std::vector<std::pair<const char*, std::string>> OpDef::props(
const OpDef& def) {
return def.trait()->props(def);
}

const char* OpDef::name() const {
return trait()->name;
}

std::string OpDef::to_string() const {
std::string builder = "{";
for (auto&& [name, value]: props(*this)) {
builder += name;
builder += ": ";
builder += value;
builder += ",";
}
return builder + "}";
}

size_t OpDef::hash() const {
return trait()->hash(*this);
}
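
With props() routed through OpTrait, to_string() gives a compact one-line dump. For a hypothetical op whose props() returns {{"axis", "1"}}:

    std::shared_ptr<OpDef> op = /* ... */;
    mgb_log("%s", op->to_string().c_str());  // prints: {axis: 1,}
    // the trailing comma comes from the loop above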


+ 3
- 0
imperative/src/impl/op_trait.h View File

@@ -72,6 +72,7 @@ using InferOutputAttrsFallible = detail::OpMeth<
decltype(OpDef::infer_output_attrs_fallible)>;
using GradMaker = detail::OpMeth<
decltype(OpDef::make_backward_graph)>;
using Props = detail::OpMeth<decltype(OpDef::props)>;
using HashFunc = detail::OpMeth<size_t(const OpDef&)>;
using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>;

@@ -84,6 +85,7 @@ struct OpTrait {
ApplyOnVarNode apply_on_var_node;
InferOutputAttrsFallible infer_output_attrs_fallible;
GradMaker make_backward_graph;
Props props;
HashFunc hash;
IsSame is_same_st;
OpTrait(const char* name);
@@ -100,6 +102,7 @@ struct OpTrait {
cb(apply_on_var_node) \
cb(infer_output_attrs_fallible) \
cb(make_backward_graph) \
cb(props) \
cb(hash) \
cb(is_same_st)



+ 6
- 0
imperative/src/impl/ops/backward_graph.cpp View File

@@ -148,9 +148,15 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_tensor_attrs(
.graph().infer_attrs(inputs);
}

std::vector<std::pair<const char*, std::string>> props(
const OpDef& backward_graph) {
return {};
}

OP_TRAIT_REG(BackwardGraph, BackwardGraph)
.apply_on_physical_tensor(backward_impl)
.infer_output_attrs_fallible(infer_tensor_attrs)
.props(props)
.fallback();
} // anonymous namespace



+ 5
- 0
imperative/src/impl/ops/opr_attr.cpp View File

@@ -95,9 +95,14 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) {
return OprAttr::make(registry->name, std::move(ctx.m_param), opr->config());
}

std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
return {};
}

OP_TRAIT_REG(OprAttr, OprAttr)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.props(props)
.fallback();

} // anonymous namespace


+ 22
- 178
imperative/src/impl/profiler.cpp View File

@@ -11,12 +11,14 @@

#include "megbrain/imperative/profiler.h"

#include "./function_hook.h"
#include <chrono>

#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/physical_tensor.h"

#include "megbrain/plugin/opr_footprint.h"

#include "./function_hook.h"
#include "./event_pool.h"
#include "./op_trait.h"

@@ -25,200 +27,42 @@ namespace imperative {

namespace {

CompNode::UnorderedSet collect_comp_nodes(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
CompNode::UnorderedSet comp_nodes;
SmallVector<LogicalTensorDesc> inp_descs;
for (auto&& i : inputs) {
comp_nodes.insert(i->comp_node());
inp_descs.push_back({i->layout(), i->comp_node(), {}});
}
SmallVector<LogicalTensorDesc> oup_descs = std::get<0>(def.infer_output_attrs_fallible(def, inp_descs));
for (auto&& output_attr : oup_descs) {
comp_nodes.insert(output_attr.comp_node);
}
return comp_nodes;
}

DeviceTimer::SharedEvent alloc_recorded_event(CompNode device) {
auto event = EventPool::with_timer().alloc_shared(device);
event->record();
return event;
}

OprFootprint footprint{};

} // namespace

void DeviceTimer::reset(thin_function<double()> host_timer) {
CompNode::foreach ([this, host_timer](CompNode device) {
m_base_event_table[device] = {alloc_recorded_event(device), host_timer()};
});
m_host_timer = host_timer;
DeviceTimer::SharedEvent DeviceTimer::get_device_time(CompNode device) {
return alloc_recorded_event(device);
}

thin_function<double()> DeviceTimer::get_device_time(CompNode device) {
auto event = EventPool::with_timer().alloc_shared(device);
event->record();
if(m_base_event_table.count(device) == 0) {
m_base_event_table[device] = {alloc_recorded_event(device), m_host_timer()};
SmallVector<DeviceTimer::SharedEvent> DeviceTimer::get_all(SmallVector<CompNode> device_list) {
SmallVector<DeviceTimer::SharedEvent> results;
for (auto&& device: device_list) {
results.push_back(alloc_recorded_event(device));
}
auto base = m_base_event_table[device];
return [base, event] {
auto [base_event, host_time] = base;
// TODO: sync once for each compnode
event->host_wait();
return base_event->elapsed_time_until(*event) * 1000 + host_time;
};
return results;
}

void DeviceTimer::clear() {
m_base_event_table.clear();
double HostTimer::get_msecs() {
using namespace std::chrono;
auto finish = steady_clock::now();
auto duration = duration_cast<microseconds>(finish - m_start);
return (double)duration.count() / 1e3;
}

size_t TensorRecorder::record_tensor(const TensorPtr& tensor) {
if (m_tensor_map.count(tensor.get()) > 0) {
auto& [prev, id] = m_tensor_map[tensor.get()];
if (prev.lock() != tensor) {
prev = tensor;
id = m_next_id++;
}
return id;
} else {
auto id = m_next_id++;
m_tensor_map.insert(
{tensor.get(), {std::weak_ptr<Tensor>{tensor}, id}});
return id;
}
}

void TensorRecorder::clear() {
m_next_id = 0;
m_tensor_map.clear();
}

Profile& Profiler::get_profile() {
for (auto& entry : m_profile) {
for (auto& [device, device_begin, device_end] : entry.device_list) {
MGB_MARK_USED_VAR(device);
device_begin = [value = device_begin()] { return value; };
device_end = [value = device_end()] { return value; };
}
}
return m_profile;
}

void Profiler::start(uint32_t flags) {
m_host_timer.reset();
m_device_timer.reset([&] { return m_host_timer.get_msecs(); });
OpTrait::for_each_trait([this, flags](OpTrait& trait) {
auto hook_apply_on_physical_tensor =
make_shared_hook(&trait.apply_on_physical_tensor);
auto hook_apply_on_var_node =
make_shared_hook(&trait.apply_on_var_node);
hook_apply_on_physical_tensor->apply_hook([this, flags]
(auto&& apply, const OpDef& def, SmallVector<TensorPtr> inputs) {
auto shape2vector = [](const TensorShape& shape) {
std::vector<size_t> vector_shape;
for (size_t i = 0; i < shape.ndim; i++) {
vector_shape.push_back(shape[i]);
}
return vector_shape;
};
ProfileEntry entry;
entry.id = m_entry_count++;
// TODO: assign parent
entry.parent = 0;
// Record apply context and save to m_profile
entry.op = const_cast<OpDef&>(def).shared_from_this();
for (auto&& input : inputs) {
entry.inputs.push_back({m_tensor_recorder.record_tensor(input),
shape2vector(input->layout()),
input->comp_node()});
}
double host_begin = m_host_timer.get_msecs();
auto&& comp_nodes = collect_comp_nodes(def, inputs);
for (auto&& comp_node : comp_nodes) {
entry.device_list.push_back(
{comp_node,
m_device_timer.get_device_time(comp_node),
{}});
}
if (flags & PROFILE_FOOTPRINT) {
MGB_LOCK_GUARD(m_lock);
m_entry_stack.push({&def, &entry, std::this_thread::get_id()});
}
// Do real apply
auto outputs = apply(def, inputs);
for (auto& [cn, dev_begin, dev_end] : entry.device_list) {
MGB_MARK_USED_VAR(cn);
MGB_MARK_USED_VAR(dev_begin);
dev_end = m_device_timer.get_device_time(cn);
}
entry.host = {host_begin, m_host_timer.get_msecs()};
for (auto&& output : outputs) {
entry.outputs.push_back(
{m_tensor_recorder.record_tensor(output),
shape2vector(output->layout()), output->comp_node()});
}
if (flags & PROFILE_FOOTPRINT) {
mgb_assert(std::get<1>(m_entry_stack.top()) == &entry);
MGB_LOCK_GUARD(m_lock);
m_entry_stack.pop();
}
m_profile.push_back(std::move(entry));
return outputs;
});
if (flags & PROFILE_FOOTPRINT) {
hook_apply_on_var_node->apply_hook(
[this](auto&& apply, const OpDef& def,
VarNodeArray inputs) -> VarNodeArray {
auto vars = apply(def, std::move(inputs));
std::remove_reference_t<decltype(m_entry_stack.top())>
top;
{
MGB_LOCK_GUARD(m_lock);
if (m_entry_stack.empty()) {
return vars;
}
top = m_entry_stack.top();
}
auto [current_op, current_entry, thread_id] = top;
if (current_op != &def ||
thread_id != std::this_thread::get_id()) {
return vars;
}
auto&& footprint_result =
footprint.calc_footprint(vars[0]->owner_opr());
current_entry->memory = footprint_result.memory;
current_entry->computation =
footprint_result.computation;
#if MGB_ENABLE_JSON
current_entry->param = footprint_result.param;
#endif
return vars;
});
}
m_hooker_list.push_back(std::move(hook_apply_on_physical_tensor));
m_hooker_list.push_back(std::move(hook_apply_on_var_node));
});
}

void Profiler::stop() {
m_hooker_list.clear();
for (auto& entry : m_profile) {
entry.wait_device();
}
double HostTimer::get_started_at() {
return m_started_at;
}

void Profiler::clear() {
mgb_assert(m_entry_stack.empty(),
"entry_stack should be empty after profile");
mgb_assert(m_hooker_list.empty(), "hooks should be released");
m_profile.clear();
m_entry_count = 0;
m_device_timer.clear();
m_tensor_recorder.clear();
void HostTimer::reset() {
using namespace std::chrono;
m_start = steady_clock::now();
auto now_us = duration_cast<microseconds>(std::chrono::system_clock::now().time_since_epoch());
m_started_at = (double)(now_us.count()) / 1e3;
}

} // namespace imperative
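
The rewritten HostTimer is just the std::chrono pairing shown above; the same pattern in isolation:

    using namespace std::chrono;
    auto start = steady_clock::now();  // monotonic clock for measuring intervals
    /* ... region being timed ... */
    double msecs =
            duration_cast<microseconds>(steady_clock::now() - start).count() / 1e3;
    // system_clock (wall time) is sampled only once, in reset(), to anchor started_at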


+ 1
- 0
imperative/src/impl/proxy_graph/mini_graph.h View File

@@ -471,6 +471,7 @@ class ExecMiniGraph : public ProxyGraph::MiniGraph {
}
if (can_pop) {
for (auto _ : comp_node_trackers) {
MGB_MARK_USED_VAR(_);
busy_oprs.pop_front();
}
m_opr = busy_oprs.front().opr;


+ 9
- 5
imperative/src/include/megbrain/imperative/interpreter.h View File

@@ -10,6 +10,7 @@
*/

#include <atomic>
#include <any>

#include "megbrain/imperative/op_def.h"

@@ -42,12 +43,15 @@ struct Interpreter {

virtual void sync() = 0;
virtual void close() = 0;
virtual void set_swap_flag(bool) = 0;
virtual void set_drop_flag(bool) = 0;
virtual void set_buffer_length(int) = 0;

virtual void config_async_level(int level) = 0;
virtual int get_async_level() = 0;
virtual int get_option(std::string name) = 0;
virtual void set_option(std::string name, int value) = 0;

virtual void start_profile(std::unordered_map<std::string, int> option) = 0;
virtual void stop_profile(std::string basename, std::string format) = 0;

virtual void push_scope(std::string name) = 0;
virtual void pop_scope(std::string name) = 0;
};

virtual std::unique_ptr<Channel> create_channel() = 0;
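
Taken together, the enlarged Channel surface is used roughly like this (a sketch: the option key and file names are illustrative, while the profile dict keys match Option::from_dict in interpreter/profiler.h):

    auto channel = interpreter.create_channel();
    channel->set_option("async_level", 1);  // hypothetical option key
    channel->start_profile({{"topic", 0b111111},
                            {"align_time", 1},
                            {"show_operator_name", 1}});
    channel->push_scope("train_step");
    /* ... dispatch ops ... */
    channel->pop_scope("train_step");
    channel->stop_profile("profile", "chrome_timeline.json");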


+ 18
- 0
imperative/src/include/megbrain/imperative/op_def.h View File

@@ -13,6 +13,7 @@

#include "megbrain/graph.h"
#include "megbrain/imperative/physical_tensor.h"
#include "megbrain/imperative/utils/to_string.h"

namespace mgb {
namespace imperative {
@@ -80,8 +81,15 @@ public:
const SmallVector<bool>& input_requires_grad,
const SmallVector<bool>& output_has_grad);

static std::vector<std::pair<const char*, std::string>> props(
const OpDef& def);

const OpTrait* trait() const;

const char* name() const;

std::string to_string() const;

virtual size_t hash() const;

virtual bool is_same_st(const Hashable&) const;
@@ -96,6 +104,16 @@ public:
}
};

template <>
struct ToStringTrait<OpDef*>{
std::string operator()(OpDef* op) const {
if (op == nullptr) {
return "nullptr";
}
return op->to_string();
}
};

} // namespace imperative
} // namespace mgb



+ 279
- 68
imperative/src/include/megbrain/imperative/profiler.h View File

@@ -11,10 +11,12 @@

#pragma once

#include <any>
#include <optional>
#include <stack>
#include <list>
#include <map>
#include <variant>
#include <fstream>
#include <chrono>
#include <bitset>

#include "megbrain/comp_node.h"
#include "megbrain/graph/event.h"
@@ -27,89 +29,298 @@
namespace mgb {
namespace imperative {

using ProfileTensor = std::tuple<size_t, std::vector<size_t>, CompNode>;

struct ProfileEntry {
using TimeClosure = std::function<double()>;
size_t id;
size_t parent;
std::shared_ptr<OpDef> op;
//(host_begin, host_end)
std::tuple<double, double> host;
//[(device, device_begin, device_end)]
std::vector<std::tuple<CompNode, TimeClosure, TimeClosure>> device_list;
std::vector<ProfileTensor> inputs;
std::vector<ProfileTensor> outputs;
long long memory = 0;
long long computation = 0;
#if MGB_ENABLE_JSON
std::shared_ptr<json::Value> param;
#endif
void wait_device() {
for (auto& [cn, begin, end] : device_list) {
MGB_MARK_USED_VAR(cn);
begin = [begin = begin()] { return begin; };
end = [end = end()] { return end; };
}
}
};

using Profile = std::list<ProfileEntry>;

class DeviceTimer {
public:
using SharedEvent = std::shared_ptr<CompNode::Event>;
DeviceTimer() = default;
void reset(thin_function<double()> host_timer);
thin_function<double()> get_device_time(CompNode device);
void clear();
SharedEvent get_device_time(CompNode device);
SmallVector<SharedEvent> get_all(SmallVector<CompNode> device_list);
};

class HostTimer {
public:
void reset();
double get_msecs();
double get_started_at();
private:
CompNode::UnorderedMap<std::tuple<SharedEvent, double>> m_base_event_table;
thin_function<double()> m_host_timer;
decltype(std::chrono::steady_clock::now()) m_start;
double m_started_at;
};

class TensorRecorder {
private:
// active tensors
std::unordered_map<Tensor*, std::tuple<std::weak_ptr<Tensor>, size_t>>
m_tensor_map;
size_t m_next_id;

class ProfilerBase {
public:
size_t record_tensor(const TensorPtr& tensor);
void clear();
using Host = std::thread::id;
using Device = CompNode;

struct HostInstant {
Host tid;
double time;

void wait() {}
};

struct DeviceInstant {
double before;
std::shared_ptr<CompNode::Event> event;
double after;

void wait() {
event->host_wait();
}
};

using Instant = std::variant<HostInstant, DeviceInstant>;

template <typename TEvent>
struct EventRecord {
Instant instant;
TEvent data;

HostInstant& host() {
return std::get<HostInstant>(instant);
}

        DeviceInstant& device() {
return std::get<DeviceInstant>(instant);
}

void wait() {
std::visit([&](auto& instant){ instant.wait(); }, instant);
}
};
protected:
HostInstant record_host() {
return {std::this_thread::get_id(), m_host_timer.get_msecs()};
}
DeviceInstant record_device(Device device) {
auto before = m_host_timer.get_msecs();
auto event = m_device_timer.get_device_time(device);
auto after = m_host_timer.get_msecs();
return {before, event, after};
}
protected:
std::atomic_int64_t m_last_id = 0;
HostTimer m_host_timer;
DeviceTimer m_device_timer;
Spinlock m_lock;
};

class Profiler {

template <typename... TEvents>
class Profiler: public ProfilerBase {
public:
enum Flags {
PROFILE_FOOTPRINT = 1,
using Record = std::variant<EventRecord<TEvents>...>;
using Mask = std::bitset<sizeof...(TEvents)>;

struct Data {
std::vector<Record> records;
double started_at;
};

template <typename TEvent, size_t index = 0>
static constexpr size_t index_of() {
if constexpr (index == std::variant_size_v<Record>) {
return index;
} else if constexpr (std::is_same_v<EventRecord<TEvent>, std::variant_alternative_t<index, Record>>) {
return index;
} else {
return index_of<TEvent, index+1>();
}
};

template <typename... TEvents2>
static Mask mask_of() {
return Mask{} | (Mask{}.set(index_of<TEvents2>()) |...);
}

enum Status {
NotStarted, Profiling, Stopped
};
public:
Profiler() = default;
// Start profiler by hook OpTrait
void start(uint32_t flags);
// Stop profiler and clean environment
void stop();
void clear();
Profile& get_profile();
template <typename TEvent, typename... TArgs>
void record_host(TArgs&&... args) {
auto instant = HostInstant{std::this_thread::get_id(), m_host_timer.get_msecs()};
MGB_LOCK_GUARD(m_lock);
if (!m_event_mask.test(index_of<TEvent>())) {
return;
}
mgb_assert(m_status != Stopped, "record after stop");
m_record_list.emplace_back(EventRecord<TEvent>{std::move(instant), {std::forward<TArgs>(args)...}});
}
template <typename TEvent, typename... TArgs>
void record_device(Device device, TArgs&&... args) {
auto before = m_host_timer.get_msecs();
auto event = m_device_timer.get_device_time(device);
auto after = m_host_timer.get_msecs();
auto instant = DeviceInstant{before, event, after};
MGB_LOCK_GUARD(m_lock);
if (!m_event_mask.test(index_of<TEvent>())) {
return;
}
mgb_assert(m_status != Stopped, "record after stop");
m_record_list.emplace_back(EventRecord<TEvent>{std::move(instant), {std::forward<TArgs>(args)...}});
}
void start(Mask mask) {
MGB_LOCK_GUARD(m_lock);
mgb_assert(m_status == NotStarted, "profiler already started");
m_status = Profiling;
m_event_mask = mask;
m_host_timer.reset();
}
Data stop() {
MGB_LOCK_GUARD(m_lock);
mgb_assert(m_status == Profiling, "profiler not active");
m_status = Stopped;
for (auto&& record: m_record_list) {
std::visit([&](auto& record){
record.wait();
}, record);
}
auto records = std::move(m_record_list);
return { records, m_host_timer.get_started_at() };
}
protected:
std::vector<Record> m_record_list;
Mask m_event_mask;
Status m_status = NotStarted;
};


class ChromeTraceEvent {
public:
ChromeTraceEvent& name(std::string name) {
m_name = std::move(name);
return *this;
}
ChromeTraceEvent& tid(uint64_t tid) {
m_tid = std::move(tid);
return *this;
}
ChromeTraceEvent& cat(std::string cat) {
m_cat = std::move(cat);
return *this;
}
ChromeTraceEvent& pid(uint64_t pid) {
m_pid = pid;
return *this;
}
ChromeTraceEvent& id(uint64_t id) {
m_id = id;
return *this;
}
ChromeTraceEvent& idx(uint64_t idx) {
m_idx = idx;
return *this;
}
ChromeTraceEvent& ts(double ts) {
m_ts = ts;
return *this;
}
ChromeTraceEvent& dur(double dur) {
m_dur = dur;
return *this;
}
ChromeTraceEvent& ph(char ph) {
m_ph = ph;
return *this;
}
ChromeTraceEvent& bp(char bp) {
m_bp = bp;
return *this;
}
ChromeTraceEvent& args(std::shared_ptr<json::Object> args) {
m_args = std::move(args);
return *this;
}
ChromeTraceEvent& arg(std::string key, std::string value) {
if (!m_args) {
m_args = json::Object::make();
}
(*m_args)[key] = json::String::make(value);
return *this;
}
ChromeTraceEvent& arg(std::string key, double value) {
if (!m_args) {
m_args = json::Object::make();
}
(*m_args)[key] = json::Number::make(value);
return *this;
}
ChromeTraceEvent& arg(std::string key, std::shared_ptr<json::Value> value) {
if (!m_args) {
m_args = json::Object::make();
}
(*m_args)[key] = value;
return *this;
}

std::shared_ptr<json::Object> to_json() const {
auto result = json::Object::make();
auto prop_str = [&](auto key, auto value) {
if (value.empty()) {
return;
}
(*result)[key] = json::String::make(value);
};
auto prop_num = [&](auto key, auto value) {
if (!value) {
return;
}
(*result)[key] = json::Number::make(value.value());
};
auto prop_char = [&](auto key, auto value) {
if (!value) {
return;
}
(*result)[key] = json::String::make(std::string{} + value.value());
};
prop_str("name", m_name);
prop_num("tid", m_tid);
prop_str("cat", m_cat);
prop_num("pid", m_pid);
prop_num("id", m_id);
prop_num("idx", m_idx);
prop_num("ts", m_ts);
prop_num("dur", m_dur);
prop_char("ph", m_ph);
prop_char("bp", m_bp);
if (m_args) {
(*result)["args"] = m_args;
}
return result;
}
private:
DeviceTimer m_device_timer;
RealTimer m_host_timer;
Profile m_profile;
TensorRecorder m_tensor_recorder;
std::stack<std::tuple<const OpDef*, ProfileEntry*, std::thread::id>>
m_entry_stack;
// Hold profile owned by this Profiler
std::unique_ptr<Profile> m_owned_profile;
// Hold hooks, cleared when stop
std::vector<std::any> m_hooker_list;
size_t m_entry_count = 0;
Spinlock m_lock;
std::unordered_map<Tensor*, std::weak_ptr<Tensor>> m_recorded_tensors;
std::string m_name;
std::string m_cat;

std::optional<uint64_t> m_tid;
std::optional<uint64_t> m_pid;
std::optional<uint64_t> m_id;
std::optional<uint64_t> m_idx;
std::optional<double> m_ts;
std::optional<double> m_dur;
std::optional<char> m_ph;
std::optional<char> m_bp;
std::shared_ptr<json::Object> m_args;
};

class ChromeTraceEventList {
public:
ChromeTraceEvent& new_event() {
m_content.emplace_back();
return m_content.back();
}

std::shared_ptr<json::Array> to_json() {
auto result = json::Array::make();
for (auto&& event: m_content) {
result->add(event.to_json());
}
return result;
}
private:
std::vector<ChromeTraceEvent> m_content;
};

} // namespace imperative
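
The variadic Profiler resolves event types to variant indices at compile time; a toy instantiation (EvA/EvB are stand-ins for the real event structs):

    struct EvA { int x; };
    struct EvB { int y; };
    using ToyProfiler = Profiler<EvA, EvB>;
    static_assert(ToyProfiler::index_of<EvA>() == 0);
    static_assert(ToyProfiler::index_of<EvB>() == 1);
    // mask_of<EvB>() yields a bitset with only bit 1 set, i.e. 0b10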


+ 125
- 0
imperative/src/include/megbrain/imperative/utils/to_string.h View File

@@ -0,0 +1,125 @@
/**
* \file imperative/src/include/megbrain/imperative/utils/to_string.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include <string>
#include <type_traits>
#include <memory>
#include <tuple>

#include "megbrain/utils/small_vector.h"
#include "megbrain/tensor.h"

namespace mgb::imperative {

template <typename T>
struct ToStringTrait;

template <typename T>
std::string to_string(const T& value) {
return ToStringTrait<T>{}(value);
}

template <typename T>
struct ToStringTrait{
std::string operator()(const T& value) const {
return std::to_string(value);
}
};

template <>
struct ToStringTrait<std::string>{
std::string operator()(const std::string& value) const {
return value;
}
};

template <typename T, unsigned N>
struct ToStringTrait<SmallVector<T, N>>{
std::string operator()(const SmallVector<T, N>& sv) const {
if (sv.empty()) {
return "[]";
}
std::string result = "[";
result += to_string(sv[0]);
for (size_t i = 1; i < sv.size(); ++i) {
result += ", ";
result += to_string(sv[i]);
}
return result + "]";
}
};

template <typename T>
struct ToStringTrait<std::shared_ptr<T>>{
std::string operator()(const std::shared_ptr<T>& sp) const {
return to_string(sp.get());
}
};

template <typename TKey, typename TValue>
struct ToStringTrait<std::pair<TKey, TValue>>{
std::string operator()(const std::pair<TKey, TValue>& pr) const {
return "(" + to_string(pr.first) + ", " + to_string(pr.second) + ")";
}
};

template <typename TItem, typename... TItems>
struct ToStringTrait<std::tuple<TItem, TItems...>>{
std::string operator()(const std::tuple<TItem, TItems...>& tp) const {
        auto folder = [&](auto&& head, auto&&... tail) {
            return (to_string(head) + ... + (", " + to_string(tail)));
        };
        return "(" + std::apply(folder, tp) + ")";
}
};

template <typename T>
struct ToStringTrait<T*>{
std::string operator()(T* p) const {
return ssprintf("%p", p);
}
};

template <>
struct ToStringTrait<TensorShape>{
std::string operator()(TensorShape shape) const {
        if (shape.ndim > TensorShape::MAX_NDIM) {
            // invalid or uninitialized shape: avoid reading out-of-range dims
            return "[]";
        }
if (shape.ndim == 0) {
return "[ ]";
}
std::string result = "[ " + std::to_string(shape[0]);
for (size_t i = 1; i < shape.ndim; i++) {
result += ", ";
result += std::to_string(shape[i]);
}
return result + " ]";
}
};

template <>
struct ToStringTrait<DType>{
std::string operator()(DType dtype) const {
return dtype.name();
}
};

template <>
struct ToStringTrait<CompNode>{
std::string operator()(CompNode device) const {
return device.to_string();
}
};

}
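
The trait dispatch reads naturally at call sites; a few illustrative results:

    to_string(42);                         // "42"  (std::to_string fallback)
    to_string(SmallVector<int>{1, 2, 3});  // "[1, 2, 3]"
    to_string(std::make_pair(1, 2));       // "(1, 2)"
    to_string(TensorShape{2, 3});          // "[ 2, 3 ]"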

+ 16
- 1
imperative/tablegen/autogen.cpp View File

@@ -222,10 +222,25 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) {
os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_");
os << "}\n";

// generate props()
os << formatv(
"std::vector<std::pair<const char*, std::string>> {0}(const OpDef& def_) {{\n",
formatMethImpl("props")
);
os << formatv(
" auto&& op_ = def_.cast_final_safe<{0}>();\n"
" static_cast<void>(op_);\n",
className
);
ctx.withSelf("op_");
os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx);
os << "}\n";

os << "} // anonymous namespace\n";

methods.push_back("hash");
methods.push_back("is_same_st");
methods.push_back("props");
}
if (!methods.empty()) {
os << formatv(
@@ -423,7 +438,7 @@ EnumWrapper<{0}::{1}>::type2str = {{
std::vector<std::string> getsetters;
for (auto &&i : op.getMgbAttributes()) {
getsetters.push_back(formatv(
"{{\"{1}\", py_get_generic({0}, {1}), py_set_generic({0}, {1}), \"{1}\", NULL},",
"{{const_cast<char*>(\"{1}\"), py_get_generic({0}, {1}), py_set_generic({0}, {1}), const_cast<char*>(\"{1}\"), NULL},",
className, i.name));
}



+ 44
- 2
imperative/tablegen/helper.h View File

@@ -66,7 +66,7 @@ struct MgbEnumAttrMixin : public MgbAttrWrapperBase {
}

llvm::StringRef getParentNamespace() const {
return getBaseRecord()->getValueAsString("parentNamespce");
return getBaseRecord()->getValueAsString("parentNamespace");
}
llvm::StringRef getEnumName() const {
return getBaseRecord()->getValueAsString("enumName");
@@ -87,6 +87,9 @@ struct MgbHashableAttrMixin : public MgbAttrWrapperBase {
llvm::StringRef getCmpFunctionTemplate() const {
return getBaseRecord()->getValueAsString("cmpFunction");
}
llvm::StringRef getReprFunctionTemplate() const {
return getBaseRecord()->getValueAsString("reprFunction");
}
};

struct MgbAliasAttrMixin : public MgbAttrWrapperBase {
@@ -205,6 +208,39 @@ private:
body += " return true;\n";
return body;
}
std::string getDefaultPropsFunction() const {
std::string body = " std::vector<std::pair<const char*, std::string>> props_;\n";
if (!getMgbAttributes().empty()) {
mlir::tblgen::FmtContext ctx;
for (auto&& it : getMgbAttributes()) {
if (auto* enumAttr = llvm::dyn_cast<MgbEnumAttrMixin>(&it.attr)) {
body += formatv(" switch ({0}){{\n", "$_self." + it.name);
for (auto&& enumMember: enumAttr->getEnumMembers()) {
body += formatv(
" case {0}::{1}::{2}:\n",
getCppClassName(), enumAttr->getEnumName(), enumMember
);
body += formatv(
" props_.emplace_back(\"{0}\", \"{1}\");\n",
it.name, enumMember
);
body += " break;\n";
}
body += " default: break;\n";
body += " }\n";
} else {
auto&& attr = llvm::cast<MgbHashableAttrMixin>(it.attr);
body += formatv(
" props_.emplace_back(\"{0}\", {1});\n", it.name,
mlir::tblgen::tgfmt(attr.getReprFunctionTemplate(),
&ctx, "$_self." + it.name)
);
}
}
}
body += " return props_;\n";
return body;
}
public:
static bool classof(const Operator* op) {
return op->getDef().isSubClassOf("MgbHashableOpMixin");
@@ -222,7 +258,13 @@ public:
}
return getDefaultCmpFunction();
}
std::string getPropsFunctionTemplate() const {
if (auto f = getDef().getValueAsOptionalString("propsFunction")) {
return f.getValue().str();
}
return getDefaultPropsFunction();
}
};

} // namespace tblgen
} // namespace mlir
} // namespace mlir
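
For a concrete picture, the default body produced by getDefaultPropsFunction for a hypothetical op with enum attr `mode` (ADD/MUL) and int attr `axis` expands to roughly:

    std::vector<std::pair<const char*, std::string>> props_;
    switch (op_.mode) {
        case Hypothetical::Mode::ADD: props_.emplace_back("mode", "ADD"); break;
        case Hypothetical::Mode::MUL: props_.emplace_back("mode", "MUL"); break;
        default: break;
    }
    props_.emplace_back("axis", std::to_string(op_.axis));  // reprFunction default
    return props_;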

+ 7
- 0
src/core/include/megbrain/ir/base.td View File

@@ -30,6 +30,7 @@ class MgbHashableAttrMixin {
string hashFunction = "mgb::hash($0)";
// return 0 for eq, else for ne
string cmpFunction = "$0 != $1";
string reprFunction = "std::to_string($0)";
}

class MgbEnumAttrMixin<string namespace, string name, list<string> members> {
@@ -98,6 +99,7 @@ def MgbStringAttr : HashableAttr<"std::string"> {
let storageType = "::mlir::StringAttr";
let convertFromStorage = "$_self.getValue().str()";
let constBuilderCall = "$_builder.getStringAttr($0)"; // llvm::StringRef implicit ctor
string reprFunction = "$0";
}

class MgbArrayAttr<MgbAttrWrapper elem>:
@@ -123,6 +125,7 @@ class MgbArrayAttr<MgbAttrWrapper elem>:
" });\n"
" return $_builder.getArrayAttr(ret" # recursionDepth # ");"
"}()";
let reprFunction = "\"{std::vector}\"";
}

defvar EmptyStrList = !listsplat("", 0);
@@ -168,6 +171,7 @@ class MgbEnumAttr<string namespace, string enumName, list<string> members>:
let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))";
let hashFunction = "mgb::enumhash()($0)";
string reprFunction = "std::to_string((int)$0)";
}

class MgbEnumAliasAttr<string namespace, string enumName, MgbEnumAttr base>:
@@ -179,12 +183,14 @@ def MgbDTypeAttr: HashableAttr<"::megdnn::DType"> {
let convertFromStorage = underlyingType # "::from_enum(static_cast<::megdnn::DTypeEnum>($_self.getInt()))";
let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0.enumv()))";
let hashFunction = "mgb::hash($0.handle())";
let reprFunction = "$0.name()";
}

def MgbCompNodeAttr: HashableAttr<"::mgb::CompNode"> {
let storageType = "::mlir::StringAttr";
let convertFromStorage = underlyingType # "::load($_self.getValue().str())";
let constBuilderCall = "$_builder.getStringAttr($0.to_string_logical())";
string reprFunction = "$0.to_string()";
}

def MgbTensorShapeAttr: HashableAttr<"::megdnn::TensorShape"> {
@@ -209,6 +215,7 @@ def MgbTensorShapeAttr: HashableAttr<"::megdnn::TensorShape"> {
" }\n"
" return $_builder.getArrayAttr(ret);"
"}()";
let reprFunction = "$0.to_string()";
}

class MgbDefaultValuedAttr<MgbAttrWrapper attr, string value>:

