Browse Source

feat(mgb/opr): add fast profile and combined Execution strategy

GitOrigin-RevId: 843dc3a790
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
a3ea1f153c
31 changed files with 361 additions and 301 deletions
  1. +56
    -0
      dnn/include/megdnn/basic_types.h
  2. +2
    -0
      dnn/include/megdnn/oprs/base.h
  3. +6
    -2
      dnn/scripts/gen_flatbuffers_schema.py
  4. +17
    -5
      dnn/scripts/gen_param_defs.py
  5. +7
    -2
      dnn/scripts/gen_tablegen.py
  6. +0
    -2
      dnn/src/common/algo_base.h
  7. +0
    -55
      dnn/src/common/utils.h
  8. +0
    -1
      dnn/src/cuda/convolution3d/backward_filter/algo.h
  9. +38
    -14
      imperative/python/megengine/functional/debug_param.py
  10. +4
    -1
      imperative/python/test/integration/test_correctness_mnistnet.py
  11. +14
    -0
      imperative/tablegen/autogen.cpp
  12. +3
    -0
      imperative/tablegen/helper.h
  13. +33
    -10
      sdk/load-and-run/src/mgblar.cpp
  14. +5
    -5
      src/core/impl/utils/persistent_cache.cpp
  15. +12
    -0
      src/core/include/megbrain/common.h
  16. +0
    -1
      src/core/include/megbrain/comp_node.h
  17. +0
    -1
      src/core/include/megbrain/graph/operator_node.h
  18. +0
    -1
      src/core/include/megbrain/graph/var_node.h
  19. +5
    -4
      src/core/include/megbrain/ir/base.td
  20. +0
    -89
      src/core/include/megbrain/utils/enum_class_bit.h
  21. +1
    -2
      src/core/include/megbrain/utils/persistent_cache.h
  22. +3
    -6
      src/gopt/impl/inference.cpp
  23. +16
    -1
      src/gopt/test/inference.cpp
  24. +0
    -1
      src/opr/impl/dnn/dnn.sereg.h
  25. +62
    -59
      src/opr/impl/search_policy/algo_chooser.cpp
  26. +18
    -10
      src/opr/include/megbrain/opr/search_policy/algo_chooser.h
  27. +0
    -2
      src/opr/include/megbrain/opr/search_policy/algo_chooser_helper.h
  28. +4
    -3
      src/opr/test/blas.cpp
  29. +39
    -23
      src/opr/test/dnn/convolution.cpp
  30. +1
    -0
      test/src/include/megbrain/test/helper.h
  31. +15
    -1
      tools/param_defs/mgb_opr_param_defs.py

+ 56
- 0
dnn/include/megdnn/basic_types.h View File

@@ -506,10 +506,66 @@ struct DynOutMallocPolicyCall {
} }
}; };



/*!
 * \brief thin wrapper around the underlying integer of enum class \p T that
 *      makes bitwise operators available
 *
 * A value is implicitly constructible from \p T and implicitly convertible
 * back to \p T. Conversion to bool is explicit and yields true iff any bit is
 * set.
 */
template <typename T>
class EnumClassBit {
    using Raw = std::underlying_type_t<T>;

    Raw m_val;

    //! build directly from raw bits; only used for results of the operators
    constexpr EnumClassBit(Raw v) : m_val(v) {}

public:
    constexpr EnumClassBit(T v) : m_val(static_cast<Raw>(v)) {}

    constexpr operator T() const { return static_cast<T>(m_val); }

    constexpr explicit operator bool() const { return m_val; }

    //! bitwise and
    constexpr EnumClassBit operator&(const EnumClassBit& rhs) const {
        return m_val & rhs.m_val;
    }

    //! bitwise or
    constexpr EnumClassBit operator|(const EnumClassBit& rhs) const {
        return m_val | rhs.m_val;
    }

    //! bitwise xor
    constexpr EnumClassBit operator^(const EnumClassBit& rhs) const {
        return m_val ^ rhs.m_val;
    }

    //! bitwise complement of all bits of the underlying type
    constexpr EnumClassBit operator~() const { return ~m_val; }
};

#endif // MEGDNN_CC_HOST #endif // MEGDNN_CC_HOST


} // namespace megdnn } // namespace megdnn


//! helper: define the binary operator \p op as free functions for the operand
//! pairs (cls, cls) and (EnumClassBit<cls>, cls); both overloads return
//! ::megdnn::EnumClassBit<cls>
#define _MEGDNN_DECBO_SINGLE_OPR(cls, op) \
inline constexpr ::megdnn::EnumClassBit<cls> operator op(cls x, cls y) { \
return ::megdnn::EnumClassBit<cls>(x) \
op ::megdnn::EnumClassBit<cls>(y); \
} \
inline constexpr ::megdnn::EnumClassBit<cls> operator op( \
::megdnn::EnumClassBit<cls> x, cls y) { \
return x op ::megdnn::EnumClassBit<cls>(y); \
}

//! helper: define the compound assignment operator `op=` for \p cls; the
//! result of `x op y` is stored back into \p x, which is returned by reference
#define _MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, op) \
inline constexpr cls& operator op##=(cls& x, cls y) { \
x = x op ::megdnn::EnumClassBit<cls>(y); \
return x; \
}

//! define the bitwise operators &, |, ^, their compound assignments &=, |=, ^=
//! and the unary ~ for the enum class \p cls; all generated functions are
//! inline constexpr free functions placed in the scope where the macro is
//! expanded
#define MEGDNN_DEF_ENUM_CLASS_BIT_OPR(cls) \
_MEGDNN_DECBO_SINGLE_OPR(cls, &) \
_MEGDNN_DECBO_SINGLE_OPR(cls, |) \
_MEGDNN_DECBO_SINGLE_OPR(cls, ^) \
_MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, &) \
_MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, |) \
_MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, ^) \
inline constexpr ::megdnn::EnumClassBit<cls> operator~(cls x) { \
return ~::megdnn::EnumClassBit<cls>(x); \
}

#include "megdnn/internal/visibility_epilogue.h" #include "megdnn/internal/visibility_epilogue.h"


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 2
- 0
dnn/include/megdnn/oprs/base.h View File

@@ -251,6 +251,8 @@ protected:
Handle::HandleType m_handle_type = Handle::HandleType::NAIVE; Handle::HandleType m_handle_type = Handle::HandleType::NAIVE;
}; };


MEGDNN_DEF_ENUM_CLASS_BIT_OPR(Algorithm::Attribute)

//! policy for executing the operator //! policy for executing the operator
struct ExecutionPolicy { struct ExecutionPolicy {
//! INVALID_ALGO_TYPE algo_type means using heuristic //! INVALID_ALGO_TYPE algo_type means using heuristic


+ 6
- 2
dnn/scripts/gen_flatbuffers_schema.py View File

@@ -53,9 +53,13 @@ class FlatBuffersWriter(IndentWriterBase):
e = self._enums[(p, e)] e = self._enums[(p, e)]
self._write_doc(e.name) self._write_doc(e.name)
self._write("enum %s%s : uint {", p, e.name, indent=1) self._write("enum %s%s : uint {", p, e.name, indent=1)
for member in e.members:
for idx, member in enumerate(e.members):
self._write_doc(member) self._write_doc(member)
self._write("%s,", scramble_enum_member_name(str(member)))
if e.combined:
self._write("%s=%d,", scramble_enum_member_name(str(member)),
1<<idx)
else:
self._write("%s,", scramble_enum_member_name(str(member)))
self._write("}\n", indent=-1) self._write("}\n", indent=-1)


def _write_doc(self, doc): def _write_doc(self, doc):


+ 17
- 5
dnn/scripts/gen_param_defs.py View File

@@ -80,13 +80,13 @@ class member_defs:
:attr member_alias: list of (member, alias) pairs :attr member_alias: list of (member, alias) pairs
""" """
__slots__ = ['name', 'name_field', 'members', 'default', __slots__ = ['name', 'name_field', 'members', 'default',
'member_alias']
'member_alias', 'combined']


all_enums = {} all_enums = {}
"""(param_name, name) => enum""" """(param_name, name) => enum"""


def __init__(self, param_name, name, name_field, members, default, def __init__(self, param_name, name, name_field, members, default,
member_alias):
member_alias, combined = False):
name = member_defs.Doc.make(name) name = member_defs.Doc.make(name)
assert name.id[0].isupper() assert name.id[0].isupper()
members = tuple(map(member_defs.Doc.make, members)) members = tuple(map(member_defs.Doc.make, members))
@@ -97,6 +97,7 @@ class member_defs:
default = name_field.index(default) default = name_field.index(default)
assert isinstance(default, int) assert isinstance(default, int)
self.name = name self.name = name
self.combined = combined
self.name_field = self.get_name_field(name.id, name_field) self.name_field = self.get_name_field(name.id, name_field)
self.members = members self.members = members
self.default = default self.default = default
@@ -197,6 +198,12 @@ class ParamDef:
self.name.id, name, name_field, members, default, member_alias)) self.name.id, name, name_field, members, default, member_alias))
return self return self


def add_bit_combination_enum(self, name, *members, default=0,
                             name_field=None, member_alias=[]):
    """Add an enum member whose values are meant to be combined as bit flags.

    Takes the same arguments as a regular enum definition, but the created
    :class:`member_defs.Enum` is flagged with ``combined=True``.

    :return: ``self``, so that definitions can be chained
    """
    bit_enum = member_defs.Enum(
        self.name.id, name, name_field, members, default, member_alias,
        combined=True)
    self.members.append(bit_enum)
    return self

def add_enum_alias(self, name, src_class, src_name=None, name_field=None, def add_enum_alias(self, name, src_class, src_name=None, name_field=None,
default=None): default=None):
self.members.append(member_defs.EnumAlias( self.members.append(member_defs.EnumAlias(
@@ -463,8 +470,12 @@ class SerializedDType(_ParamDefBase):
for idx, emem in enumerate(e.members): for idx, emem in enumerate(e.members):
self._write('%s = "%s"', emem, emem) self._write('%s = "%s"', emem, emem)
self._write_doc(emem) self._write_doc(emem)
self._enum_member2num.append('id({}.{}):{}'.format(
qualname, emem, idx))
if e.combined:
self._enum_member2num.append('id({}.{}):{}'.format(
qualname, emem, 1<<idx))
else:
self._enum_member2num.append('id({}.{}):{}'.format(
qualname, emem, idx))


for emem, emem_alis in e.member_alias: for emem, emem_alis in e.member_alias:
self._write('%s = %s', emem_alis, emem) self._write('%s = %s', emem_alis, emem)
@@ -622,6 +633,8 @@ class CPPWriter(IndentWriterBase):
for idx, i in enumerate(e.members): for idx, i in enumerate(e.members):
self._write_doc(i) self._write_doc(i)
v = '{} = {}'.format(i, idx) v = '{} = {}'.format(i, idx)
if e.combined:
v = '{} = 1 << {}'.format(i, idx)
if i is not e.members[-1] or e.member_alias: if i is not e.members[-1] or e.member_alias:
v += ',' v += ','
self._write(v) self._write(v)
@@ -672,7 +685,6 @@ class CPPEnumValueWriter(CPPWriter):
self._write('static const uint32_t %s = %s;', alias, mem) self._write('static const uint32_t %s = %s;', alias, mem)
self._write('};', indent=-1) self._write('};', indent=-1)



def _on_member_enum_alias(self, e): def _on_member_enum_alias(self, e):
s = e.src_enum s = e.src_enum
self._write('typedef %s::%s %s;', e.src_class, e.src_name, e.name) self._write('typedef %s::%s %s;', e.src_class, e.src_name, e.name)


+ 7
- 2
dnn/scripts/gen_tablegen.py View File

@@ -91,12 +91,17 @@ class ConverterWriter(IndentWriterBase):
def format(v): def format(v):
return '\"{}\"'.format(str(v)) return '\"{}\"'.format(str(v))
enum_def += ','.join(format(i) for i in e.members) enum_def += ','.join(format(i) for i in e.members)
enum_def += "]"

if e.combined:
enum_def += "], 1"
else:
enum_def += "], 0"

if ENUM_TO_STRING_SPECIAL_RULES.count((p.name, e.name)): if ENUM_TO_STRING_SPECIAL_RULES.count((p.name, e.name)):
enum_def += ", 1" # whether generate ToStringTrait enum_def += ", 1" # whether generate ToStringTrait
enum_def += ">" enum_def += ">"
self._write("def {} : {};".format(td_class, enum_def))


self._write("def {} : {};".format(td_class, enum_def))
if self._skip_current_param: if self._skip_current_param:
return return




+ 0
- 2
dnn/src/common/algo_base.h View File

@@ -21,8 +21,6 @@


namespace megdnn { namespace megdnn {


MEGDNN_DEF_ENUM_CLASS_BIT_OPR(AlgoAttribute)

#define MEGDNN_DECL_ALGO_TYPE(_type) \ #define MEGDNN_DECL_ALGO_TYPE(_type) \
uint32_t type() const override { \ uint32_t type() const override { \
return static_cast<std::underlying_type<AlgoType>::type>( \ return static_cast<std::underlying_type<AlgoType>::type>( \


+ 0
- 55
dnn/src/common/utils.h View File

@@ -692,61 +692,6 @@ inline void* get_origin_ptr(const TensorND* tensor, void* ptr) {
tensor->layout.span().low_byte); tensor->layout.span().low_byte);
} }


template <typename T>
class EnumClassBit {
std::underlying_type_t<T> m_val;

constexpr EnumClassBit(std::underlying_type_t<T> v) : m_val(v) {}

public:
constexpr EnumClassBit(T v)
: m_val(static_cast<std::underlying_type_t<T>>(v)) {}

constexpr operator T() const { return static_cast<T>(m_val); }

constexpr explicit operator bool() const { return m_val; }

#define DEF_OPR(op) \
constexpr EnumClassBit operator op(const EnumClassBit& rhs) const { \
return m_val op rhs.m_val; \
}

DEF_OPR(&)
DEF_OPR(|)
DEF_OPR (^)

constexpr EnumClassBit operator~() const { return ~m_val; }

#undef DEF_OPR
};

#define _MEGDNN_DECBO_SINGLE_OPR(cls, op) \
inline constexpr ::megdnn::EnumClassBit<cls> operator op(cls x, cls y) { \
return ::megdnn::EnumClassBit<cls>(x) \
op ::megdnn::EnumClassBit<cls>(y); \
} \
inline constexpr ::megdnn::EnumClassBit<cls> operator op( \
::megdnn::EnumClassBit<cls> x, cls y) { \
return x op ::megdnn::EnumClassBit<cls>(y); \
}

#define _MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, op) \
inline constexpr cls& operator op##=(cls& x, cls y) { \
x = x op ::megdnn::EnumClassBit<cls>(y); \
return x; \
}

#define MEGDNN_DEF_ENUM_CLASS_BIT_OPR(cls) \
_MEGDNN_DECBO_SINGLE_OPR(cls, &) \
_MEGDNN_DECBO_SINGLE_OPR(cls, |) \
_MEGDNN_DECBO_SINGLE_OPR(cls, ^) \
_MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, &) \
_MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, |) \
_MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, ^) \
inline constexpr ::megdnn::EnumClassBit<cls> operator~(cls x) { \
return ~::megdnn::EnumClassBit<cls>(x); \
}

} // namespace megdnn } // namespace megdnn


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 0
- 1
dnn/src/cuda/convolution3d/backward_filter/algo.h View File

@@ -218,4 +218,3 @@ public:
} // namespace megdnn } // namespace megdnn


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen


+ 38
- 14
imperative/python/megengine/functional/debug_param.py View File

@@ -8,9 +8,12 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os import os


from ..core.ops import builtin
from ..logger import get_logger from ..logger import get_logger
from ..utils.deprecation import deprecated from ..utils.deprecation import deprecated


Strategy = builtin.ops.Convolution.Strategy

_execution_strategy = os.getenv("MEGENGINE_EXECUTION_STRATEGY", "HEURISTIC") _execution_strategy = os.getenv("MEGENGINE_EXECUTION_STRATEGY", "HEURISTIC")


if os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY") != None: if os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY") != None:
@@ -19,7 +22,7 @@ if os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY") != None:
) )




def get_execution_strategy() -> str:
def get_execution_strategy() -> Strategy:
""" """
Returns the execution strategy of :class:`~.Conv2d` and :func:'~.matmul' Returns the execution strategy of :class:`~.Conv2d` and :func:'~.matmul'


@@ -28,12 +31,22 @@ def get_execution_strategy() -> str:
return _execution_strategy return _execution_strategy




def set_execution_strategy(option: str):
def set_execution_strategy(option):
""" """
Sets the execution strategy of :class:`~.Conv2d` and :func:'~.matmul' Sets the execution strategy of :class:`~.Conv2d` and :func:'~.matmul'


:param option: Decides how :class:`~.Conv2d` and :func:'~.matmul' algorithms are chosen.
Available values:
:param option: Decides how :class:`~.Conv2d` and :func:`~.matmul` algorithms are chosen.
Available Strategy values:
* HEURISTIC uses heuristics to choose the fastest algorithm.
* PROFILE runs possible algorithms on the real device to find the best one.
* REPRODUCIBLE uses only algorithms that are reproducible.
* OPTMIZED uses only algorithms that are optimized.

The default strategy is HEURISTIC. These options can be combined to
form a combined option, e.g. PROFILE | REPRODUCIBLE
selects the fastest algorithm in the profiling result that is also reproducible.

Available string values:


* 'HEURISTIC' uses heuristic to choose the fastest algorithm. * 'HEURISTIC' uses heuristic to choose the fastest algorithm.
* 'PROFILE' runs possible algorithms on real device to find the best one. * 'PROFILE' runs possible algorithms on real device to find the best one.
@@ -45,18 +58,29 @@ def set_execution_strategy(option: str):


It can also be set through the environment variable 'MEGENGINE_EXECUTION_STRATEGY'. It can also be set through the environment variable 'MEGENGINE_EXECUTION_STRATEGY'.
""" """
valid_option = (
"HEURISTIC",
"PROFILE",
"PROFILE_HEURISTIC",
"PROFILE_REPRODUCIBLE",
"HEURISTIC_REPRODUCIBLE",
)
if not option in valid_option:
raise ValueError("Valid option can only be one of {}".format(valid_option))
valid_string_option = {
"REPRODUCIBLE": Strategy.REPRODUCIBLE,
"HEURISTIC": Strategy.HEURISTIC,
"PROFILE": Strategy.PROFILE,
}


global _execution_strategy # pylint: disable=global-statement global _execution_strategy # pylint: disable=global-statement
_execution_strategy = option
if isinstance(option, Strategy):
_execution_strategy = option
return

assert isinstance(option, str)

strategy_tmp = Strategy(0)
for opt in option.split("_"):
if not opt in valid_string_option:
raise ValueError(
"Valid option can only be one of {}, or combine them with '_'.".format(
valid_string_option.keys()
)
)
strategy_tmp = strategy_tmp | valid_string_option[opt]
_execution_strategy = strategy_tmp




@deprecated(version="1.3", reason="use get_execution_strategy() instead") @deprecated(version="1.3", reason="use get_execution_strategy() instead")


+ 4
- 1
imperative/python/test/integration/test_correctness_mnistnet.py View File

@@ -19,6 +19,7 @@ import megengine.autodiff as ad
import megengine.functional as F import megengine.functional as F
from megengine import jit from megengine import jit
from megengine.core._trace_option import set_symbolic_shape from megengine.core._trace_option import set_symbolic_shape
from megengine.core.ops import builtin
from megengine.core.tensor.utils import make_shape_tuple from megengine.core.tensor.utils import make_shape_tuple
from megengine.functional.debug_param import set_execution_strategy from megengine.functional.debug_param import set_execution_strategy
from megengine.jit import SublinearMemoryConfig from megengine.jit import SublinearMemoryConfig
@@ -33,6 +34,8 @@ from megengine.module import (
from megengine.optimizer import SGD from megengine.optimizer import SGD
from megengine.tensor import Tensor from megengine.tensor import Tensor


Strategy = builtin.ops.Convolution.Strategy



def get_gpu_name(): def get_gpu_name():
try: try:
@@ -242,7 +245,7 @@ def test_correctness():
else: else:
model_name = "mnist_model_with_test_cpu.mge" model_name = "mnist_model_with_test_cpu.mge"
model_path = os.path.join(os.path.dirname(__file__), model_name) model_path = os.path.join(os.path.dirname(__file__), model_name)
set_execution_strategy("HEURISTIC_REPRODUCIBLE")
set_execution_strategy(Strategy.HEURISTIC | Strategy.REPRODUCIBLE)


run_train(model_path, False, False, max_err=1e-5) run_train(model_path, False, False, max_err=1e-5)
run_train(model_path, True, False, max_err=1e-5) run_train(model_path, True, False, max_err=1e-5)


+ 14
- 0
imperative/tablegen/autogen.cpp View File

@@ -337,6 +337,20 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext&
className, attr->getEnumName(), i className, attr->getEnumName(), i
)); ));
} }
if (attr->getEnumCombinedFlag()) {
//! define operator |
os << formatv(
"\n .def(\"__or__\", []({0}::{1} s0, {0}::{1} s1) {{ "
"\n return static_cast<{0}::{1}>(uint32_t(s0) | uint32_t(s1));"
"\n })",
className, attr->getEnumName());
//! define operator &
os << formatv(
"\n .def(\"__and__\", []({0}::{1} s0, {0}::{1} s1) {{"
"\n return static_cast<{0}::{1}>(uint32_t(s0) & uint32_t(s1));"
"\n })",
className, attr->getEnumName());
}
os << formatv( os << formatv(
"\n .def(py::init([](const std::string& in) {" "\n .def(py::init([](const std::string& in) {"
"\n auto&& str = normalize_enum(in);" "\n auto&& str = normalize_enum(in);"


+ 3
- 0
imperative/tablegen/helper.h View File

@@ -77,6 +77,9 @@ struct MgbEnumAttrMixin : public MgbAttrWrapperBase {
bool supportToString() const { bool supportToString() const {
return getBaseRecord()->getValueAsBit("supportToString"); return getBaseRecord()->getValueAsBit("supportToString");
} }
//! whether the "enumCombined" bit of the underlying record is set, i.e. the
//! enum was declared as a bit-combination enum
bool getEnumCombinedFlag() const {
    const bool combined = getBaseRecord()->getValueAsBit("enumCombined");
    return combined;
}
}; };


struct MgbHashableAttrMixin : public MgbAttrWrapperBase { struct MgbHashableAttrMixin : public MgbAttrWrapperBase {


+ 33
- 10
sdk/load-and-run/src/mgblar.cpp View File

@@ -142,8 +142,16 @@ R"__usage__(
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
R"__usage__( R"__usage__(
--fast-run --fast-run
Enable fast-run mode. Operators with multiple algorithms would be profiled
on the real device with actual input shapes.
This param will be deperated later, please replace with param --full-profile.
--full-profile
Enable full-profile mode. Operators with multiple algorithms would be profiled
on the real device with actual input shapes, all algorithms will be profiled
include naive algorithms.
See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details.
--fast-profile
Enable fast-profile mode. Operators with multiple algorithms would be profiled
on the real device with actual input shapes, this mode will only profile the
well optimized algorithms to get the profile result fast.
See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details. See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details.
)__usage__" )__usage__"
#endif #endif
@@ -511,7 +519,8 @@ struct Args {
bool disable_assert_throw = false; bool disable_assert_throw = false;
bool share_param_mem = false; bool share_param_mem = false;
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
bool use_fast_run = false;
bool use_full_profile = false;
bool use_fast_profile = false;
#endif #endif
bool reproducible = false; bool reproducible = false;
std::string fast_run_cache_path; std::string fast_run_cache_path;
@@ -695,18 +704,20 @@ void run_test_st(Args &env) {
using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
S strategy = S::HEURISTIC; S strategy = S::HEURISTIC;
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
if (env.use_fast_run) {
if (env.use_full_profile) {
if (env.reproducible) { if (env.reproducible) {
strategy = S::PROFILE_REPRODUCIBLE;
strategy = S::PROFILE | S::REPRODUCIBLE;
} else { } else {
strategy = S::PROFILE; strategy = S::PROFILE;
} }
} else if (env.use_fast_profile) {
strategy = S::PROFILE | S::OPTMIZED;
} else if (env.reproducible) { } else if (env.reproducible) {
strategy = S::HEURISTIC_REPRODUCIBLE;
strategy = S::HEURISTIC | S::REPRODUCIBLE;
} }
#else #else
if (env.reproducible) { if (env.reproducible) {
strategy = S::HEURISTIC_REPRODUCIBLE;
strategy = S::HEURISTIC | S::REPRODUCIBLE;
} }
#endif #endif
mgb::gopt::modify_opr_algo_strategy_inplace(vars, strategy); mgb::gopt::modify_opr_algo_strategy_inplace(vars, strategy);
@@ -729,11 +740,12 @@ void run_test_st(Args &env) {
std::make_shared<InFilePersistentCache>(buf.get(), flen)); std::make_shared<InFilePersistentCache>(buf.get(), flen));
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
} else { } else {
mgb_assert(env.use_fast_run, "fast-run should be enabled");
mgb_assert(env.use_full_profile || env.use_fast_profile,
"fast-run or fast-profile should be enabled");
PersistentCache::set_impl( PersistentCache::set_impl(
std::make_shared<InFilePersistentCache>()); std::make_shared<InFilePersistentCache>());
} }
if (!env.use_fast_run)
if (!env.use_full_profile && !env.use_fast_profile)
#endif #endif
mgb::gopt::enable_opr_use_profiling_cache_inplace(vars); mgb::gopt::enable_opr_use_profiling_cache_inplace(vars);
} }
@@ -1314,7 +1326,18 @@ Args Args::from_argv(int argc, char **argv) {
} }
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
if (!strcmp(argv[i], "--fast-run")) { if (!strcmp(argv[i], "--fast-run")) {
ret.use_fast_run = true;
mgb_log_warn(
"--fast-run param will be deperated later, please replace "
"with --full-profile or --fast-profile.");
ret.use_full_profile = true;
continue;
}
if (!strcmp(argv[i], "--full-profile")) {
ret.use_full_profile = true;
continue;
}
if (!strcmp(argv[i], "--fast-profile")) {
ret.use_fast_profile = true;
continue; continue;
} }
#endif #endif


+ 5
- 5
src/core/impl/utils/persistent_cache.cpp View File

@@ -188,7 +188,7 @@ AlgoChooserProfileCache::get(const Key &key) {
auto entry_len = read_uint32(); auto entry_len = read_uint32();
mgb_assert(buf + entry_len <= buf_end); mgb_assert(buf + entry_len <= buf_end);
auto nr = sscanf(reinterpret_cast<const char*>(buf), ENTRY_FMT, auto nr = sscanf(reinterpret_cast<const char*>(buf), ENTRY_FMT,
&i.reproducible, &i.time, &i.workspace);
&i.attribute, &i.time, &i.workspace);
mgb_assert(nr == 3); mgb_assert(nr == 3);
buf += entry_len; buf += entry_len;
} }
@@ -210,10 +210,10 @@ void AlgoChooserProfileCache::put(const Key &key, Result &result) {
auto &&cur = result[i]; auto &&cur = result[i];


if (prev.workspace <= cur.workspace && if (prev.workspace <= cur.workspace &&
prev.reproducible == cur.reproducible) {
prev.attribute == cur.attribute) {
result.erase(result.begin() + i); result.erase(result.begin() + i);
} else { } else {
++ i;
++i;
} }
} }


@@ -235,8 +235,8 @@ void AlgoChooserProfileCache::put(const Key &key, Result &result) {
write_uint32(0); write_uint32(0);
pos = val.size(); pos = val.size();
val.resize(pos + SPR_SIZE); val.resize(pos + SPR_SIZE);
uint32_t nr = snprintf(&val[pos], SPR_SIZE,
ENTRY_FMT, i.reproducible, i.time, i.workspace);
uint32_t nr = snprintf(&val[pos], SPR_SIZE, ENTRY_FMT, i.attribute,
i.time, i.workspace);
//! for memory boundary failed, snprintf ret do not contain \0 //! for memory boundary failed, snprintf ret do not contain \0
nr += 1; nr += 1;
mgb_assert(nr < SPR_SIZE); mgb_assert(nr < SPR_SIZE);


+ 12
- 0
src/core/include/megbrain/common.h View File

@@ -12,6 +12,8 @@
#pragma once #pragma once


#include "megbrain_build_config.h" #include "megbrain_build_config.h"
#include "megbrain/opr/param_defs.h"
#include "megdnn/basic_types.h"


#include <memory> #include <memory>
#include <string> #include <string>
@@ -242,6 +244,16 @@ inline constexpr std::size_t operator"" _z(unsigned long long n) {
return n; return n;
} }
#endif #endif

#define MGB_DEF_ENUM_CLASS_BIT_OPR(cls) \
MEGDNN_DEF_ENUM_CLASS_BIT_OPR(cls)

} // namespace mgb } // namespace mgb


// Bitwise operators (&, |, ^, ~ and the compound assignments) for the
// combinable ExecutionPolicy::Strategy enum. They are defined inside
// megdnn::param, the namespace ExecutionPolicy is referenced from; presumably
// this is so that argument-dependent lookup finds them from any namespace.
namespace megdnn {
namespace param {
MGB_DEF_ENUM_CLASS_BIT_OPR(ExecutionPolicy::Strategy)
} // namespace param
} // namespace megdnn

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 0
- 1
src/core/include/megbrain/comp_node.h View File

@@ -12,7 +12,6 @@
#pragma once #pragma once


#include "megbrain/utils/hash.h" #include "megbrain/utils/hash.h"
#include "megbrain/utils/enum_class_bit.h"
#include "megbrain/utils/metahelper.h" #include "megbrain/utils/metahelper.h"
#include "megbrain/utils/thin/hash_table.h" #include "megbrain/utils/thin/hash_table.h"
#include "megbrain/utils/thread.h" #include "megbrain/utils/thread.h"


+ 0
- 1
src/core/include/megbrain/graph/operator_node.h View File

@@ -16,7 +16,6 @@
#include "megbrain/graph/symbol_var.h" #include "megbrain/graph/symbol_var.h"


#include "megbrain/utils/hashable.h" #include "megbrain/utils/hashable.h"
#include "megbrain/utils/enum_class_bit.h"
#include "megbrain/utils/thin/hash_table.h" #include "megbrain/utils/thin/hash_table.h"
#include "megbrain/utils/small_vector.h" #include "megbrain/utils/small_vector.h"




+ 0
- 1
src/core/include/megbrain/graph/var_node.h View File

@@ -12,7 +12,6 @@
#pragma once #pragma once


#include "megbrain/graph/bases.h" #include "megbrain/graph/bases.h"
#include "megbrain/utils/enum_class_bit.h"
#include "megbrain/utils/comp_node_sync_manager.h" #include "megbrain/utils/comp_node_sync_manager.h"
#include "megbrain/utils/small_vector.h" #include "megbrain/utils/small_vector.h"
#include "megbrain/utils/mempool.h" #include "megbrain/utils/mempool.h"


+ 5
- 4
src/core/include/megbrain/ir/base.td View File

@@ -33,10 +33,11 @@ class MgbHashableAttrMixin {
string reprFunction = "std::to_string($0)"; string reprFunction = "std::to_string($0)";
} }


class MgbEnumAttrMixin<string namespace, string name, list<string> members, bit toString> {
class MgbEnumAttrMixin<string namespace, string name, list<string> members, bit combined, bit toString> {
string parentNamespace = namespace; string parentNamespace = namespace;
string enumName = name; string enumName = name;
list<string> enumMembers = members; list<string> enumMembers = members;
bit enumCombined = combined;
bit supportToString = toString; bit supportToString = toString;
} }


@@ -166,8 +167,8 @@ class MgbTupleAttr<list<MgbAttrWrapper> args>:
} }


// -- enum types // -- enum types
class MgbEnumAttr<string namespace, string enumName, list<string> members, bit toString=0>:
HashableAttr<namespace # "::" # enumName>, MgbEnumAttrMixin<namespace, enumName, members, toString> {
class MgbEnumAttr<string namespace, string enumName, list<string> members, bit combined, bit toString=0>:
HashableAttr<namespace # "::" # enumName>, MgbEnumAttrMixin<namespace, enumName, members, combined, toString> {
let storageType = "::mlir::IntegerAttr"; let storageType = "::mlir::IntegerAttr";
let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))"; let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))";
@@ -176,7 +177,7 @@ class MgbEnumAttr<string namespace, string enumName, list<string> members, bit t
} }


class MgbEnumAliasAttr<string namespace, string enumName, MgbEnumAttr base>: class MgbEnumAliasAttr<string namespace, string enumName, MgbEnumAttr base>:
MgbEnumAttr<namespace, enumName, base.enumMembers>, MgbAliasAttrMixin<base>;
MgbEnumAttr<namespace, enumName, base.enumMembers, 0>, MgbAliasAttrMixin<base>;


// -- other types // -- other types
def MgbDTypeAttr: HashableAttr<"::megdnn::DType"> { def MgbDTypeAttr: HashableAttr<"::megdnn::DType"> {


+ 0
- 89
src/core/include/megbrain/utils/enum_class_bit.h View File

@@ -1,89 +0,0 @@
/**
* \file src/core/include/megbrain/utils/enum_class_bit.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include <type_traits>

namespace mgb {
template<typename T>
class EnumClassBit {
std::underlying_type_t<T> m_val;

constexpr EnumClassBit(std::underlying_type_t<T> v):
m_val(v)
{
}

public:
constexpr EnumClassBit(T v):
m_val(static_cast<std::underlying_type_t<T>>(v))
{
}

constexpr operator T() const {
return static_cast<T>(m_val);
}

constexpr explicit operator bool() const {
return m_val;
}

#define DEF_OPR(op) \
constexpr EnumClassBit operator op (\
const EnumClassBit &rhs) const { \
return m_val op rhs.m_val; \
}

DEF_OPR(&)
DEF_OPR(|)
DEF_OPR(^)

constexpr EnumClassBit operator ~() const {
return ~m_val;
}


#undef DEF_OPR
};

}

#define _MGB_DECBO_SINGLE_OPR(cls, op) \
inline constexpr ::mgb::EnumClassBit<cls> operator op (cls x, cls y) { \
return ::mgb::EnumClassBit<cls>(x) op ::mgb::EnumClassBit<cls>(y); \
} \
inline constexpr ::mgb::EnumClassBit<cls> operator op ( \
::mgb::EnumClassBit<cls> x, cls y) { \
return x op ::mgb::EnumClassBit<cls>(y); \
}

#define _MGB_DECBO_SINGLE_OPR_ASSIGN(cls, op) \
inline constexpr cls& operator op##= (cls& x, cls y) { \
x = x op ::mgb::EnumClassBit<cls>(y); \
return x; \
}

#define MGB_DEF_ENUM_CLASS_BIT_OPR(cls) \
_MGB_DECBO_SINGLE_OPR(cls, &) \
_MGB_DECBO_SINGLE_OPR(cls, |) \
_MGB_DECBO_SINGLE_OPR(cls, ^) \
_MGB_DECBO_SINGLE_OPR_ASSIGN(cls, &) \
_MGB_DECBO_SINGLE_OPR_ASSIGN(cls, |) \
_MGB_DECBO_SINGLE_OPR_ASSIGN(cls, ^) \
inline constexpr ::mgb::EnumClassBit<cls> operator ~ (cls x) { \
return ~::mgb::EnumClassBit<cls>(x); \
} \



// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}


+ 1
- 2
src/core/include/megbrain/utils/persistent_cache.h View File

@@ -100,8 +100,7 @@ namespace mgb {


struct ResultEntry { struct ResultEntry {
std::string algo; //! identifier of the algorithm std::string algo; //! identifier of the algorithm
//! sscanf will up bool as int
int reproducible; //! whether algorithm is reproducible
uint32_t attribute; //! algo attribute, e.g. reproducible
double time; //! execution time in seconds double time; //! execution time in seconds
size_t workspace; //! workspace in bytes size_t workspace; //! workspace in bytes
}; };


+ 3
- 6
src/gopt/impl/inference.cpp View File

@@ -54,7 +54,6 @@ using namespace gopt;


namespace { namespace {



template <typename SharedDeviceTensor, typename MultipleDeviceTensorHolder> template <typename SharedDeviceTensor, typename MultipleDeviceTensorHolder>
void param_merge(OptState& opt_state) { void param_merge(OptState& opt_state) {
auto rewriter = opt_state.graph().make_rewriter(); auto rewriter = opt_state.graph().make_rewriter();
@@ -102,7 +101,7 @@ void param_merge(OptState& opt_state) {
rewriter.apply_inplace(); rewriter.apply_inplace();
} }


}
} // namespace


/* ================ global functions ================ */ /* ================ global functions ================ */


@@ -190,12 +189,10 @@ void gopt::enable_opr_algo_profiling_inplace(


void gopt::enable_opr_use_profiling_cache_inplace( void gopt::enable_opr_use_profiling_cache_inplace(
const VarNodeArrayView& dest_vars) { const VarNodeArrayView& dest_vars) {
modify_opr_algo_strategy_inplace(
dest_vars, opr::mixin::AlgoChooserHelper::ExecutionPolicy::
Strategy::PROFILE_HEURISTIC);
using S = megdnn::param::ExecutionPolicy::Strategy;
modify_opr_algo_strategy_inplace(dest_vars, S::PROFILE | S::HEURISTIC);
} }



void gopt::set_opr_algo_workspace_limit_inplace( void gopt::set_opr_algo_workspace_limit_inplace(
const VarNodeArrayView& dest_vars, size_t workspace_limit) { const VarNodeArrayView& dest_vars, size_t workspace_limit) {
static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&, size_t)> static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&, size_t)>


+ 16
- 1
src/gopt/test/inference.cpp View File

@@ -1693,7 +1693,22 @@ TEST(TestGoptInference, ProfileCache) {
using S = opr::Convolution::ExecutionPolicy::Strategy; using S = opr::Convolution::ExecutionPolicy::Strategy;
ASSERT_EQ(S::HEURISTIC, conv.execution_policy_transient().strategy); ASSERT_EQ(S::HEURISTIC, conv.execution_policy_transient().strategy);
gopt::enable_opr_use_profiling_cache_inplace({z + 2.3f}); gopt::enable_opr_use_profiling_cache_inplace({z + 2.3f});
ASSERT_EQ(S::PROFILE_HEURISTIC, conv.execution_policy().strategy);
ASSERT_EQ(S::PROFILE | S::HEURISTIC, conv.execution_policy().strategy);
}

//! check that gopt::modify_opr_algo_strategy_inplace() installs the combined
//! strategy PROFILE | OPTMIZED on a Convolution operator
TEST(TestGoptInference, FastProfileCache) {
HostTensorGenerator<> gen;
auto graph = ComputingGraph::make();
// host tensors generated with shapes {4, 3, 8, 9} and {2, 3, 3, 3}
auto host_x = gen({4, 3, 8, 9}), host_y = gen({2, 3, 3, 3});
auto x = opr::Host2DeviceCopy::make(*graph, host_x),
y = opr::Host2DeviceCopy::make(*graph, host_y),
z = opr::Convolution::make(x, y);
// fetch the Convolution operator that owns z so its policy can be inspected
auto&& conv = z.node()->owner_opr()->cast_final_safe<opr::Convolution>();
using S = opr::Convolution::ExecutionPolicy::Strategy;
// before modification the transient policy reports plain HEURISTIC
ASSERT_EQ(S::HEURISTIC, conv.execution_policy_transient().strategy);
// the strategy is applied through a destination var (z + 2.3f) that is
// computed from the convolution output
gopt::modify_opr_algo_strategy_inplace({z + 2.3f},
S::PROFILE | S::OPTMIZED);
// the operator must now carry exactly the requested bit combination
ASSERT_EQ(S::PROFILE | S::OPTMIZED, conv.execution_policy().strategy);
} }


TEST(TestGoptInference, AlgoWorkspaceLimit) { TEST(TestGoptInference, AlgoWorkspaceLimit) {


+ 0
- 1
src/opr/impl/dnn/dnn.sereg.h View File

@@ -20,7 +20,6 @@
#include "megbrain/opr/dnn/lrn.h" #include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/fake_quant.h" #include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/tqt.h" #include "megbrain/opr/dnn/tqt.h"

#include "megbrain/serialization/sereg.h" #include "megbrain/serialization/sereg.h"
#include "megdnn/opr_param_defs.h" #include "megdnn/opr_param_defs.h"
#include "megdnn/oprs/nn.h" #include "megdnn/oprs/nn.h"


+ 62
- 59
src/opr/impl/search_policy/algo_chooser.cpp View File

@@ -284,8 +284,9 @@ namespace mgb {
namespace opr { namespace opr {


template <typename Opr> template <typename Opr>
void AlgoChooser<Opr>::profile(ExeContext& ctx, bool require_reproducible) {
if (ctx.get_profile_result_from_cache(require_reproducible).valid())
void AlgoChooser<Opr>::profile(ExeContext& ctx,
ExecutionStrategy select_strategy) {
if (ctx.get_profile_result_from_cache(select_strategy).valid())
return; return;
AlgoChooserProfileCache::Result prof_rst; AlgoChooserProfileCache::Result prof_rst;


@@ -305,7 +306,7 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, bool require_reproducible) {
algo.name.c_str(), str_on_inp_shape.c_str()); algo.name.c_str(), str_on_inp_shape.c_str());
ImplExecutionPolicy policy; ImplExecutionPolicy policy;
policy.algo = algo.desc; policy.algo = algo.desc;
ctx.construct_execution_policy(require_reproducible, policy);
ctx.construct_execution_policy(select_strategy, policy);
if (ctx.get_workspace_size_bytes(policy) >= workspace_limit) if (ctx.get_workspace_size_bytes(policy) >= workspace_limit)
continue; continue;


@@ -354,7 +355,8 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, bool require_reproducible) {


template <typename Opr> template <typename Opr>
typename AlgoChooser<Opr>::ImplExecutionPolicy typename AlgoChooser<Opr>::ImplExecutionPolicy
AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, bool require_reproducible,
AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx,
ExecutionStrategy select_strategy,
bool enable_update) { bool enable_update) {
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile"))) MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile")))
if (ctx.owner_graph()->options().no_profiling_on_shape_change) { if (ctx.owner_graph()->options().no_profiling_on_shape_change) {
@@ -376,11 +378,11 @@ AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, bool require_reproducible,
to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(), to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(),
_item.param, ctx.mgb_opr(), ctx.comp_node(), _item.param, ctx.mgb_opr(), ctx.comp_node(),
ctx.execution_policy(), ctx.allow_weight_preprocess()); ctx.execution_policy(), ctx.allow_weight_preprocess());
AlgoChooser<_Opr>::profile(sub_ctx, require_reproducible);
AlgoChooser<_Opr>::profile(sub_ctx, select_strategy);
}); });
} }
typename AlgoChooser<Opr>::ImplExecutionPolicy policy; typename AlgoChooser<Opr>::ImplExecutionPolicy policy;
ctx.construct_execution_policy(require_reproducible, policy);
ctx.construct_execution_policy(select_strategy, policy);
return policy; return policy;
MIDOUT_E MIDOUT_E
} }
@@ -402,11 +404,9 @@ size_t AlgoChooser<Opr>::setup_algo(const FixedTensorLayouts& layouts,
ImplExecutionPolicy policy; ImplExecutionPolicy policy;
if (auto algo_choose_hook = mgb_opr->algo_chooser()) { if (auto algo_choose_hook = mgb_opr->algo_chooser()) {
policy = algo_choose_hook(mgb_opr); policy = algo_choose_hook(mgb_opr);
ctx.construct_execution_policy(
mgb_opr->execution_policy().strategy ==
mixin::AlgoChooserHelper::ExecutionPolicy::Strategy::
HEURISTIC_REPRODUCIBLE,
policy, false);
ctx.construct_execution_policy((ExecutionStrategy::HEURISTIC |
ExecutionStrategy::REPRODUCIBLE),
policy, false);
} }
if (!policy.algo.valid()) { if (!policy.algo.valid()) {
policy = get_policy(ctx); policy = get_policy(ctx);
@@ -419,10 +419,9 @@ size_t AlgoChooser<Opr>::setup_algo(const FixedTensorLayouts& layouts,
Algorithm* palgo = megdnn_opr->get_algorithm_from_desc(policy.algo); Algorithm* palgo = megdnn_opr->get_algorithm_from_desc(policy.algo);
mgb_assert(palgo, "Unknown algo description"); mgb_assert(palgo, "Unknown algo description");
ret.append("): algo=" + std::string(palgo->name())); ret.append("): algo=" + std::string(palgo->name()));
ret.append(ssprintf(" workspace=%.2fMiB reproducible=%d",
ret.append(ssprintf(" workspace=%.2fMiB attirbute=%d",
workspace / (1024 * 1024.0), workspace / (1024 * 1024.0),
palgo->contain_attribute(
megdnn::AlgoAttribute::REPRODUCIBLE)));
static_cast<uint32_t>(palgo->attribute())));
mgb_log_debug("%s", ret.c_str()); mgb_log_debug("%s", ret.c_str());


megdnn_opr->execution_policy() = policy; megdnn_opr->execution_policy() = policy;
@@ -432,41 +431,39 @@ size_t AlgoChooser<Opr>::setup_algo(const FixedTensorLayouts& layouts,
template <typename Opr> template <typename Opr>
typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy( typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy(
ExeContext& ctx) { ExeContext& ctx) {
using S = mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
MGB_MARK_USED_VAR(TIMEOUT_TOLERANCE); MGB_MARK_USED_VAR(TIMEOUT_TOLERANCE);
switch (ctx.execution_policy().strategy) {
case S::HEURISTIC:
return ctx.choose_by_heuristic();
case S::HEURISTIC_REPRODUCIBLE:
return ctx.choose_by_heuristic(true);
case S::PROFILE_HEURISTIC: {
ImplExecutionPolicy policy = choose_by_profile(ctx, false, false);
if (!policy.algo.valid())
policy = ctx.choose_by_heuristic();
return policy;
}
auto opr_strategy = ctx.execution_policy().strategy;
if ((opr_strategy & ExecutionStrategy::HEURISTIC) &&
(opr_strategy & ExecutionStrategy::PROFILE)) {
ImplExecutionPolicy policy =
choose_by_profile(ctx, opr_strategy, false);
if (!policy.algo.valid())
policy = ctx.choose_by_heuristic(opr_strategy);
return policy;
} else if ((opr_strategy & ExecutionStrategy::HEURISTIC)) {
return ctx.choose_by_heuristic(opr_strategy);
}
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
case S::PROFILE:
return choose_by_profile(ctx, false);
case S::PROFILE_REPRODUCIBLE:
return choose_by_profile(ctx, true);
else if (opr_strategy & ExecutionStrategy::PROFILE) {
return choose_by_profile(ctx, opr_strategy);
}
#endif #endif
default:
mgb_throw(GraphError, "bad convolution ExecutionPolicy strategy");
else {
mgb_throw(GraphError, "bad convolution ExecutionPolicy strategy");
} }
} }


#define INST(Opr) \
template AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \
AlgoChooser<megdnn::Opr>::get_policy(ExeContext& ctx); \
template void AlgoChooser<megdnn::Opr>::profile( \
ExeContext& ctx, bool require_reproducible); \
template AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \
AlgoChooser<megdnn::Opr>::choose_by_profile( \
ExeContext& ctx, bool require_reproducible, bool enable_update); \
template size_t AlgoChooser<megdnn::Opr>::setup_algo( \
const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \
const MGBOpr* mgb_opr, bool allow_weight_preprocess); \
#define INST(Opr) \
template AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \
AlgoChooser<megdnn::Opr>::get_policy(ExeContext& ctx); \
template void AlgoChooser<megdnn::Opr>::profile(ExeContext& ctx, \
ExecutionStrategy); \
template AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \
AlgoChooser<megdnn::Opr>::choose_by_profile( \
ExeContext& ctx, ExecutionStrategy, bool enable_update); \
template size_t AlgoChooser<megdnn::Opr>::setup_algo( \
const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \
const MGBOpr* mgb_opr, bool allow_weight_preprocess);


MGB_FOREACH_FASTRUN_OPR(INST) MGB_FOREACH_FASTRUN_OPR(INST)


@@ -498,7 +495,7 @@ AlgoChooser<Opr>::ExeContext::ExeContext(
template <typename Opr> template <typename Opr>
typename AlgoChooser<Opr>::ImplAlgo typename AlgoChooser<Opr>::ImplAlgo
AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(
bool require_reproducible) const {
ExecutionStrategy select_strategy) const {
MIDOUT_B(Opr, MIDOUT_B(Opr,
midout_iv(MGB_HASH_STR( midout_iv(MGB_HASH_STR(
"AlgoChooser::ExeContext::get_profile_result_from_cache"))) "AlgoChooser::ExeContext::get_profile_result_from_cache")))
@@ -522,7 +519,9 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(
if (prof.empty()) if (prof.empty())
return {}; return {};
for (auto&& i : prof) { for (auto&& i : prof) {
if ((!require_reproducible || i.reproducible)) {
if (!(select_strategy & ExecutionStrategy::REPRODUCIBLE) ||
static_cast<AlgoAttribute>(i.attribute) &
AlgoAttribute::REPRODUCIBLE) {
auto iter = algo_map.find(i.algo); auto iter = algo_map.find(i.algo);
mgb_assert(iter != algo_map.end(), mgb_assert(iter != algo_map.end(),
"algorithm %s exists in " "algorithm %s exists in "
@@ -550,7 +549,8 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(


template <typename Opr> template <typename Opr>
typename AlgoChooser<Opr>::ImplExecutionPolicy typename AlgoChooser<Opr>::ImplExecutionPolicy
AlgoChooser<Opr>::ExeContext::choose_by_heuristic(bool reproducible) const {
AlgoChooser<Opr>::ExeContext::choose_by_heuristic(
ExecutionStrategy select_strategy) const {
if (m_execution_policy.workspace_limit != if (m_execution_policy.workspace_limit !=
std::numeric_limits<decltype( std::numeric_limits<decltype(
m_execution_policy.workspace_limit)>::max()) { m_execution_policy.workspace_limit)>::max()) {
@@ -558,6 +558,8 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(bool reproducible) const {
"workspace_limit should not be setted if choose algo by " "workspace_limit should not be setted if choose algo by "
"heuristic"); "heuristic");
} }
bool reproducible = static_cast<bool>(select_strategy &
ExecutionStrategy::REPRODUCIBLE);
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
owner_graph(), m_cn, m_execution_policy.workspace_limit); owner_graph(), m_cn, m_execution_policy.workspace_limit);
ImplExecutionPolicy policy; ImplExecutionPolicy policy;
@@ -579,7 +581,8 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(bool reproducible) const {
to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(), to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(),
_item.param, m_base_mgb_opr, m_cn, m_execution_policy, _item.param, m_base_mgb_opr, m_cn, m_execution_policy,
m_allow_weight_preprocess); m_allow_weight_preprocess);
policy.sub_policy.push_back(sub_ctx.choose_by_heuristic(reproducible));
policy.sub_policy.push_back(
sub_ctx.choose_by_heuristic(select_strategy));
}); });


return policy; return policy;
@@ -588,9 +591,8 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(bool reproducible) const {
template <typename Opr> template <typename Opr>
std::vector<typename AlgoChooser<Opr>::ImplAlgo> std::vector<typename AlgoChooser<Opr>::ImplAlgo>
AlgoChooser<Opr>::ExeContext::get_all_candidates() const { AlgoChooser<Opr>::ExeContext::get_all_candidates() const {
auto heu = choose_by_heuristic();
auto&& ret =
APPLY(m_megdnn_opr->get_all_algorithms_info(args...), m_layouts);
auto heu = choose_by_heuristic(ExecutionStrategy::HEURISTIC);
auto&& ret = APPLY(m_megdnn_opr->get_all_algorithms_info(args...), m_layouts);
bool found = false; bool found = false;
for (size_t i = 0; i < ret.size(); ++i) { for (size_t i = 0; i < ret.size(); ++i) {
if (ret[i].desc == heu.algo) { if (ret[i].desc == heu.algo) {
@@ -611,19 +613,21 @@ AlgoChooser<Opr>::ExeContext::get_all_candidates() const {


template <typename Opr> template <typename Opr>
void AlgoChooser<Opr>::ExeContext::construct_execution_policy( void AlgoChooser<Opr>::ExeContext::construct_execution_policy(
bool require_reproducible,
ExecutionStrategy select_strategy,
typename AlgoChooser<Opr>::ImplExecutionPolicy& policy, typename AlgoChooser<Opr>::ImplExecutionPolicy& policy,
bool retrive_from_cache) const { bool retrive_from_cache) const {
bool reproducible = static_cast<bool>(select_strategy &
ExecutionStrategy::REPRODUCIBLE);
if (!policy.algo.valid()) { if (!policy.algo.valid()) {
if (retrive_from_cache) { if (retrive_from_cache) {
policy.algo = policy.algo =
get_profile_result_from_cache(require_reproducible).desc;
get_profile_result_from_cache(select_strategy).desc;
} else { } else {
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
owner_graph(), m_cn, m_execution_policy.workspace_limit); owner_graph(), m_cn, m_execution_policy.workspace_limit);
policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic( policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic(
args..., workspace_limit, args..., workspace_limit,
require_reproducible),
reproducible),
m_layouts) m_layouts)
.desc; .desc;
} }
@@ -647,7 +651,7 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy(
_item.param, m_base_mgb_opr, m_cn, m_execution_policy, _item.param, m_base_mgb_opr, m_cn, m_execution_policy,
m_allow_weight_preprocess); m_allow_weight_preprocess);
policy.sub_policy.push_back({}); policy.sub_policy.push_back({});
sub_ctx.construct_execution_policy(require_reproducible,
sub_ctx.construct_execution_policy(select_strategy,
policy.sub_policy.back(), policy.sub_policy.back(),
retrive_from_cache); retrive_from_cache);
}); });
@@ -718,8 +722,7 @@ AlgoChooser<Opr>::ExeContext::profile_single_algo(
return None; return None;
return AlgoChooserProfileCache::ResultEntry{ return AlgoChooserProfileCache::ResultEntry{
palgo->name(), palgo->name(),
palgo->contain_attribute(
megdnn::AlgoAttribute::REPRODUCIBLE),
static_cast<uint32_t>(palgo->attribute()),
rst.val().time, param.workspace}; rst.val().time, param.workspace};
} }


@@ -768,10 +771,10 @@ AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const {
bool allow_weight_preprocess); \ bool allow_weight_preprocess); \
template typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \ template typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \
AlgoChooser<megdnn::Opr>::ExeContext::choose_by_heuristic( \ AlgoChooser<megdnn::Opr>::ExeContext::choose_by_heuristic( \
bool reproducible) const; \
ExecutionStrategy select_strategy) const; \
template typename AlgoChooser<megdnn::Opr>::ImplAlgo \ template typename AlgoChooser<megdnn::Opr>::ImplAlgo \
AlgoChooser<megdnn::Opr>::ExeContext::get_profile_result_from_cache( \ AlgoChooser<megdnn::Opr>::ExeContext::get_profile_result_from_cache( \
bool require_reproducible) const; \
ExecutionStrategy select_strategy) const; \
template std::vector<typename AlgoChooser<megdnn::Opr>::ImplAlgo> \ template std::vector<typename AlgoChooser<megdnn::Opr>::ImplAlgo> \
AlgoChooser<megdnn::Opr>::ExeContext::get_all_candidates() const; \ AlgoChooser<megdnn::Opr>::ExeContext::get_all_candidates() const; \
template size_t \ template size_t \
@@ -780,7 +783,7 @@ AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const {
policy) const; \ policy) const; \
template void \ template void \
AlgoChooser<megdnn::Opr>::ExeContext::construct_execution_policy( \ AlgoChooser<megdnn::Opr>::ExeContext::construct_execution_policy( \
bool require_reproducible, \
ExecutionStrategy select_strategy, \
typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& policy, \ typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& policy, \
bool retrive_from_cache) const; \ bool retrive_from_cache) const; \
template Maybe<AlgoChooserProfileCache::ResultEntry> \ template Maybe<AlgoChooserProfileCache::ResultEntry> \


+ 18
- 10
src/opr/include/megbrain/opr/search_policy/algo_chooser.h View File

@@ -35,6 +35,13 @@ MGB_FOREACH_FASTRUN_OPR(cb)
#undef cb #undef cb


namespace mgb { namespace mgb {

//! alias of megdnn::param::ExecutionPolicy::Strategy, made available in
//! namespace mgb
//! NOTE(review): the bitwise operators used to combine and test its values
//! are not defined here; presumably they come with the megdnn enum -- confirm
using ExecutionStrategy = megdnn::param::ExecutionPolicy::Strategy;

//! alias of megdnn::AlgoAttribute, made available in namespace mgb
using AlgoAttribute = megdnn::AlgoAttribute;

namespace opr { namespace opr {


/* =================== AlgoChooser =================== */ /* =================== AlgoChooser =================== */
@@ -103,7 +110,7 @@ public:
const FixedTensorLayouts& layouts() const { return m_layouts; } const FixedTensorLayouts& layouts() const { return m_layouts; }


ImplExecutionPolicy choose_by_heuristic( ImplExecutionPolicy choose_by_heuristic(
bool reproducible = false) const;
ExecutionStrategy select_strategy) const;


//! get all candidate algos, and the one choose_by_heuristic() is //! get all candidate algos, and the one choose_by_heuristic() is
//! put first //! put first
@@ -126,19 +133,20 @@ public:
const ImplExecutionPolicy& policy, double& timeout) const; const ImplExecutionPolicy& policy, double& timeout) const;


//! get all profile algorithm from cache, return invalid if not exists //! get all profile algorithm from cache, return invalid if not exists
ImplAlgo get_profile_result_from_cache(bool require_reproducible) const;
ImplAlgo get_profile_result_from_cache(
ExecutionStrategy select_strategy) const;


/** /**
* \brief construct execution policy from cache or heuristic. * \brief construct execution policy from cache or heuristic.
* *
* \param require_reproducible select algo which is reproducible
* \param select_strategy select algo which matched this strategy
* \param policy execution policy * \param policy execution policy
* \param retrive_from_cache retrive algo from cache if set True, get * \param retrive_from_cache retrive algo from cache if set True, get
* from heuristic otherwise. * from heuristic otherwise.
*/ */
void construct_execution_policy(
bool require_reproducible, ImplExecutionPolicy& policy,
bool retrive_from_cache = true) const;
void construct_execution_policy(ExecutionStrategy select_strategy,
ImplExecutionPolicy& policy,
bool retrive_from_cache = true) const;


private: private:
Maybe<PreprocessFilter<Opr>> construct_fake_preprocess_filter() const; Maybe<PreprocessFilter<Opr>> construct_fake_preprocess_filter() const;
@@ -153,11 +161,11 @@ private:




//! profile and save to cache //! profile and save to cache
static void profile(ExeContext& ctx, bool require_reproducible);
static void profile(ExeContext& ctx, ExecutionStrategy select_strategy);


static ImplExecutionPolicy choose_by_profile(ExeContext& ctx,
bool require_reproducible,
bool enable_update = true);
static ImplExecutionPolicy choose_by_profile(
ExeContext& ctx, ExecutionStrategy select_strategy,
bool enable_update = true);


public: public:
/*! /*!


+ 0
- 2
src/opr/include/megbrain/opr/search_policy/algo_chooser_helper.h View File

@@ -13,7 +13,6 @@
#pragma once #pragma once


#include "megbrain/graph/operator_node.h" #include "megbrain/graph/operator_node.h"
#include "megbrain/opr/param_defs.h"
#include "megdnn/oprs/base.h" #include "megdnn/oprs/base.h"
#include "megdnn/oprs/nn.h" #include "megdnn/oprs/nn.h"


@@ -73,7 +72,6 @@ protected:


}; };
} // namespace mixin } // namespace mixin

} // namespace opr } // namespace opr
} // namespace mgb } // namespace mgb




+ 4
- 3
src/opr/test/blas.cpp View File

@@ -429,10 +429,11 @@ TEST(TestOprDNN, MatrixMulExePolicy) {
auto cn = CompNode::load("cpux"); auto cn = CompNode::load("cpux");


#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
for (auto strategy : {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE,
S::PROFILE_HEURISTIC}) {
for (auto strategy :
SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
S::PROFILE | S::HEURISTIC}) {
#else #else
for (auto strategy: {S:HEURISTIC, S::PROFILE_HEURISTIC}) {
for (auto strategy: {S:HEURISTIC, S::PROFILE | S::HEURISTIC}) {
#endif #endif


auto graph = ComputingGraph::make(); auto graph = ComputingGraph::make();


+ 39
- 23
src/opr/test/dnn/convolution.cpp View File

@@ -355,11 +355,13 @@ TEST(TestOprDNN, ConvBiasExePolicy) {
auto cn = CompNode::load("cpux"); auto cn = CompNode::load("cpux");


#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
for (auto strategy: {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE, S::PROFILE_HEURISTIC}) {
for (auto strategy :
SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) {
#else #else
for (auto strategy: {S:HEURISTIC, S::PROFILE_HEURISTIC}) {
for (auto strategy :
SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) {
#endif #endif

auto graph = ComputingGraph::make(); auto graph = ComputingGraph::make();
HostTensorGenerator<> gen; HostTensorGenerator<> gen;


@@ -397,7 +399,8 @@ TEST(TestOprDNN, ConvBiasExePolicy_Quantized8Asym) {


auto cn = CompNode::load("cpux"); auto cn = CompNode::load("cpux");


for (auto strategy: {S::PROFILE, S::PROFILE_REPRODUCIBLE}) {
for (auto strategy :
SmallVector<S>{S::PROFILE, S::PROFILE | S::REPRODUCIBLE}) {


auto graph = ComputingGraph::make(); auto graph = ComputingGraph::make();
HostTensorGenerator<> gen; HostTensorGenerator<> gen;
@@ -439,10 +442,12 @@ TEST(TestOprDNN, ConvolutionExePolicy) {
PersistentCacheHook cache_hook{on_get}; PersistentCacheHook cache_hook{on_get};


#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
for (auto strategy : {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE,
S::PROFILE_HEURISTIC}) {
for (auto strategy :
SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) {
#else #else
for (auto strategy: {S:HEURISTIC, S::PROFILE_HEURISTIC}) {
for (auto strategy :
SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) {
#endif #endif
using Checker = AutoOprChecker<2, 1>; using Checker = AutoOprChecker<2, 1>;


@@ -522,10 +527,11 @@ TEST(TestOprDNN, ConvolutionBackwardDataBfloat16ExePolicy) {
PersistentCacheHook cache_hook{on_get}; PersistentCacheHook cache_hook{on_get};


#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
for (auto strategy : {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE,
S::PROFILE_HEURISTIC}) {
for (auto strategy :
{S::PROFILE, S::HEURISTIC, S(S::PROFILE | S::REPRODUCIBLE),
S(S::PROFILE | S::HEURISTIC)}) {
#else #else
for (auto strategy: {S:HEURISTIC, S::PROFILE_HEURISTIC}) {
for (auto strategy: {S:HEURISTIC, S(S::PROFILE | S::HEURISTIC)}) {
#endif #endif
using Checker = AutoOprChecker<2, 1>; using Checker = AutoOprChecker<2, 1>;


@@ -1183,9 +1189,12 @@ TEST(TestOprDNN, Convolution3DExePolicy) {
using S = Policy::Strategy; using S = Policy::Strategy;


#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
for (auto strategy: {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE, S::PROFILE_HEURISTIC}) {
for (auto strategy :
SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
S::PROFILE | S::HEURISTIC}) {
#else #else
for (auto strategy: {S:HEURISTIC, S::PROFILE_HEURISTIC}) {
for (auto strategy :
SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) {
#endif #endif


using Checker = AutoOprChecker<2, 1>; using Checker = AutoOprChecker<2, 1>;
@@ -1660,10 +1669,12 @@ TEST(TestOprDNN, LocalShareForwardExecPolicy) {
PersistentCacheHook cache_hook{on_get}; PersistentCacheHook cache_hook{on_get};


#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
for (auto strategy : {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE,
S::PROFILE_HEURISTIC}) {
for (auto strategy :
SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) {
#else #else
for (auto strategy: {S:HEURISTIC, S::PROFILE_HEURISTIC}) {
for (auto strategy :
SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) {
#endif #endif
auto make_graph = [&](const Checker::SymInpArray& inputs) auto make_graph = [&](const Checker::SymInpArray& inputs)
-> Checker::SymOutArray { -> Checker::SymOutArray {
@@ -1769,10 +1780,12 @@ TEST(TestOprDNN, DeformableConvForward) {
Param param; Param param;


#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
for (auto strategy : {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE,
S::PROFILE_HEURISTIC}) {
for (auto strategy :
SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) {
#else #else
for (auto strategy : {S : HEURISTIC, S::PROFILE_HEURISTIC}) {
for (auto strategy :
SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) {
#endif #endif
auto make_graph = [&](const Checker::SymInpArray& inputs) auto make_graph = [&](const Checker::SymInpArray& inputs)
-> Checker::SymOutArray { -> Checker::SymOutArray {
@@ -1936,10 +1949,12 @@ TEST(TestOprDNN, BatchConvBiasForward) {
param.sparse = Param::Sparse::DENSE; param.sparse = Param::Sparse::DENSE;


#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
for (auto strategy : {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE,
S::PROFILE_HEURISTIC}) {
for (auto strategy :
SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) {
#else #else
for (auto strategy : {S : HEURISTIC, S::PROFILE_HEURISTIC}) {
for (auto strategy :
SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) {
#endif #endif


auto make_quantized = [&](SymbolVar x, const DType& dtype) { auto make_quantized = [&](SymbolVar x, const DType& dtype) {
@@ -2080,7 +2095,8 @@ TEST(TestOprDNN, HeuristicReproducible) {


constexpr size_t PH = 1, PW = 1, SH = 1, SW = 1; constexpr size_t PH = 1, PW = 1, SH = 1, SW = 1;


for (auto strategy : {S::HEURISTIC, S::HEURISTIC_REPRODUCIBLE}) {
for (auto strategy :
SmallVector<S>{S::HEURISTIC, S::HEURISTIC | S::REPRODUCIBLE}) {
VarNode* bwd_flt; VarNode* bwd_flt;
auto make_graph = [&](const Checker::SymInpArray& inputs) auto make_graph = [&](const Checker::SymInpArray& inputs)
-> Checker::SymOutArray { -> Checker::SymOutArray {
@@ -2126,7 +2142,7 @@ TEST(TestOprDNN, HeuristicReproducible) {
megdnn::Algorithm* palgo = megdnn::Algorithm* palgo =
megdnn_opr->get_algorithm_from_desc(algo); megdnn_opr->get_algorithm_from_desc(algo);
mgb_assert(palgo, "Unknown algo description"); mgb_assert(palgo, "Unknown algo description");
if (strategy == S::HEURISTIC_REPRODUCIBLE) {
if (strategy == S(S::HEURISTIC | S::REPRODUCIBLE)) {
EXPECT_TRUE(palgo->contain_attribute( EXPECT_TRUE(palgo->contain_attribute(
megdnn::AlgoAttribute::REPRODUCIBLE)); megdnn::AlgoAttribute::REPRODUCIBLE));
} }


+ 1
- 0
test/src/include/megbrain/test/helper.h View File

@@ -43,6 +43,7 @@ namespace megdnn {
std::ostream &ostr, const DType &dt) { std::ostream &ostr, const DType &dt) {
return ostr << dt.name(); return ostr << dt.name();
} }

} // namespace megdnn } // namespace megdnn


namespace mgb { namespace mgb {


+ 15
- 1
tools/param_defs/mgb_opr_param_defs.py View File

@@ -18,7 +18,7 @@ pdef('PersistentOutputStorage').add_fields(
add_const('int32', 'INVALID_AXIS', 'MAX_NDIM'). add_const('int32', 'INVALID_AXIS', 'MAX_NDIM').
add_fields('int32', 'axis', 'INVALID_AXIS')) add_fields('int32', 'axis', 'INVALID_AXIS'))


(pdef('ExecutionPolicy', 'specify how to select an algorithm for an operator').
(pdef('ExecutionPolicy', version=0, is_legacy=True).
add_enum('Strategy', add_enum('Strategy',
Doc('HEURISTIC', 'use heuristic to choose the fastest algorithm'), Doc('HEURISTIC', 'use heuristic to choose the fastest algorithm'),
Doc('HEURISTIC_REPRODUCIBLE', 'use heuristic to choose the fastest algorithm, ' Doc('HEURISTIC_REPRODUCIBLE', 'use heuristic to choose the fastest algorithm, '
@@ -33,6 +33,20 @@ pdef('PersistentOutputStorage').add_fields(
Doc('workspace_limit', 'workspace limit in bytes'), Doc('workspace_limit', 'workspace limit in bytes'),
str(2**64-1)+'ull')) str(2**64-1)+'ull'))


# Version 1 of the ExecutionPolicy param definition.
#
# 'Strategy' is declared with add_bit_combination_enum and has four members:
# HEURISTIC, PROFILE, REPRODUCIBLE and OPTMIZED. 'workspace_limit' is a uint64
# field whose default is 2**64-1, written with a 'ull' suffix.
#
# NOTE(review): the member name 'OPTMIZED' is misspelled, but it is kept
# unchanged because it is an identifier that other code refers to; only the
# description text below is corrected.
(pdef('ExecutionPolicy', 'specify how to select an algorithm for an operator', version=1).
 add_bit_combination_enum('Strategy',
     Doc('HEURISTIC', 'use heuristic to choose the fastest algorithm'),
     Doc('PROFILE',
         'run possible algorithms on real device to find the best'),
     Doc('REPRODUCIBLE',
         # Adjacent string literals are concatenated verbatim, so the first
         # piece must end with a space to avoid emitting "algosmust".
         'when profile or heuristic algo selection it require the algos '
         'must be reproducible'),
     Doc('OPTMIZED',
         'profile require algos are optmized to achieve fast-profile')).
 add_fields('uint64',
     Doc('workspace_limit', 'workspace limit in bytes'),
     str(2**64-1)+'ull'))

(pdef('AssertEqual'). (pdef('AssertEqual').
add_fields('float32', add_fields('float32',
Doc('maxerr', 'max allowed error; error is defined as the minimal ' Doc('maxerr', 'max allowed error; error is defined as the minimal '


Loading…
Cancel
Save