GitOrigin-RevId: 843dc3a790
tags/v1.3.0
@@ -506,10 +506,66 @@ struct DynOutMallocPolicyCall { | |||||
} | } | ||||
}; | }; | ||||
template <typename T> | |||||
class EnumClassBit { | |||||
std::underlying_type_t<T> m_val; | |||||
constexpr EnumClassBit(std::underlying_type_t<T> v) : m_val(v) {} | |||||
public: | |||||
constexpr EnumClassBit(T v) | |||||
: m_val(static_cast<std::underlying_type_t<T>>(v)) {} | |||||
constexpr operator T() const { return static_cast<T>(m_val); } | |||||
constexpr explicit operator bool() const { return m_val; } | |||||
#define DEF_OPR(op) \ | |||||
constexpr EnumClassBit operator op(const EnumClassBit& rhs) const { \ | |||||
return m_val op rhs.m_val; \ | |||||
} | |||||
DEF_OPR(&) | |||||
DEF_OPR(|) | |||||
DEF_OPR (^) | |||||
constexpr EnumClassBit operator~() const { return ~m_val; } | |||||
#undef DEF_OPR | |||||
}; | |||||
#endif // MEGDNN_CC_HOST | #endif // MEGDNN_CC_HOST | ||||
} // namespace megdnn | } // namespace megdnn | ||||
#define _MEGDNN_DECBO_SINGLE_OPR(cls, op) \ | |||||
inline constexpr ::megdnn::EnumClassBit<cls> operator op(cls x, cls y) { \ | |||||
return ::megdnn::EnumClassBit<cls>(x) \ | |||||
op ::megdnn::EnumClassBit<cls>(y); \ | |||||
} \ | |||||
inline constexpr ::megdnn::EnumClassBit<cls> operator op( \ | |||||
::megdnn::EnumClassBit<cls> x, cls y) { \ | |||||
return x op ::megdnn::EnumClassBit<cls>(y); \ | |||||
} | |||||
#define _MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, op) \ | |||||
inline constexpr cls& operator op##=(cls& x, cls y) { \ | |||||
x = x op ::megdnn::EnumClassBit<cls>(y); \ | |||||
return x; \ | |||||
} | |||||
#define MEGDNN_DEF_ENUM_CLASS_BIT_OPR(cls) \ | |||||
_MEGDNN_DECBO_SINGLE_OPR(cls, &) \ | |||||
_MEGDNN_DECBO_SINGLE_OPR(cls, |) \ | |||||
_MEGDNN_DECBO_SINGLE_OPR(cls, ^) \ | |||||
_MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, &) \ | |||||
_MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, |) \ | |||||
_MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, ^) \ | |||||
inline constexpr ::megdnn::EnumClassBit<cls> operator~(cls x) { \ | |||||
return ~::megdnn::EnumClassBit<cls>(x); \ | |||||
} | |||||
#include "megdnn/internal/visibility_epilogue.h" | #include "megdnn/internal/visibility_epilogue.h" | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -251,6 +251,8 @@ protected: | |||||
Handle::HandleType m_handle_type = Handle::HandleType::NAIVE; | Handle::HandleType m_handle_type = Handle::HandleType::NAIVE; | ||||
}; | }; | ||||
MEGDNN_DEF_ENUM_CLASS_BIT_OPR(Algorithm::Attribute) | |||||
//! policy for executing the operator | //! policy for executing the operator | ||||
struct ExecutionPolicy { | struct ExecutionPolicy { | ||||
//! INVALID_ALGO_TYPE algo_type means using heuristic | //! INVALID_ALGO_TYPE algo_type means using heuristic | ||||
@@ -53,9 +53,13 @@ class FlatBuffersWriter(IndentWriterBase): | |||||
e = self._enums[(p, e)] | e = self._enums[(p, e)] | ||||
self._write_doc(e.name) | self._write_doc(e.name) | ||||
self._write("enum %s%s : uint {", p, e.name, indent=1) | self._write("enum %s%s : uint {", p, e.name, indent=1) | ||||
for member in e.members: | |||||
for idx, member in enumerate(e.members): | |||||
self._write_doc(member) | self._write_doc(member) | ||||
self._write("%s,", scramble_enum_member_name(str(member))) | |||||
if e.combined: | |||||
self._write("%s=%d,", scramble_enum_member_name(str(member)), | |||||
1<<idx) | |||||
else: | |||||
self._write("%s,", scramble_enum_member_name(str(member))) | |||||
self._write("}\n", indent=-1) | self._write("}\n", indent=-1) | ||||
def _write_doc(self, doc): | def _write_doc(self, doc): | ||||
@@ -80,13 +80,13 @@ class member_defs: | |||||
:attr member_alias: list of (member, alias) pairs | :attr member_alias: list of (member, alias) pairs | ||||
""" | """ | ||||
__slots__ = ['name', 'name_field', 'members', 'default', | __slots__ = ['name', 'name_field', 'members', 'default', | ||||
'member_alias'] | |||||
'member_alias', 'combined'] | |||||
all_enums = {} | all_enums = {} | ||||
"""(param_name, name) => enum""" | """(param_name, name) => enum""" | ||||
def __init__(self, param_name, name, name_field, members, default, | def __init__(self, param_name, name, name_field, members, default, | ||||
member_alias): | |||||
member_alias, combined = False): | |||||
name = member_defs.Doc.make(name) | name = member_defs.Doc.make(name) | ||||
assert name.id[0].isupper() | assert name.id[0].isupper() | ||||
members = tuple(map(member_defs.Doc.make, members)) | members = tuple(map(member_defs.Doc.make, members)) | ||||
@@ -97,6 +97,7 @@ class member_defs: | |||||
default = name_field.index(default) | default = name_field.index(default) | ||||
assert isinstance(default, int) | assert isinstance(default, int) | ||||
self.name = name | self.name = name | ||||
self.combined = combined | |||||
self.name_field = self.get_name_field(name.id, name_field) | self.name_field = self.get_name_field(name.id, name_field) | ||||
self.members = members | self.members = members | ||||
self.default = default | self.default = default | ||||
@@ -197,6 +198,12 @@ class ParamDef: | |||||
self.name.id, name, name_field, members, default, member_alias)) | self.name.id, name, name_field, members, default, member_alias)) | ||||
return self | return self | ||||
def add_bit_combination_enum(self, name, *members, default=0, | |||||
name_field=None, member_alias=[]): | |||||
self.members.append(member_defs.Enum( | |||||
self.name.id, name, name_field, members, default, member_alias, True)) | |||||
return self | |||||
def add_enum_alias(self, name, src_class, src_name=None, name_field=None, | def add_enum_alias(self, name, src_class, src_name=None, name_field=None, | ||||
default=None): | default=None): | ||||
self.members.append(member_defs.EnumAlias( | self.members.append(member_defs.EnumAlias( | ||||
@@ -463,8 +470,12 @@ class SerializedDType(_ParamDefBase): | |||||
for idx, emem in enumerate(e.members): | for idx, emem in enumerate(e.members): | ||||
self._write('%s = "%s"', emem, emem) | self._write('%s = "%s"', emem, emem) | ||||
self._write_doc(emem) | self._write_doc(emem) | ||||
self._enum_member2num.append('id({}.{}):{}'.format( | |||||
qualname, emem, idx)) | |||||
if e.combined: | |||||
self._enum_member2num.append('id({}.{}):{}'.format( | |||||
qualname, emem, 1<<idx)) | |||||
else: | |||||
self._enum_member2num.append('id({}.{}):{}'.format( | |||||
qualname, emem, idx)) | |||||
for emem, emem_alis in e.member_alias: | for emem, emem_alis in e.member_alias: | ||||
self._write('%s = %s', emem_alis, emem) | self._write('%s = %s', emem_alis, emem) | ||||
@@ -622,6 +633,8 @@ class CPPWriter(IndentWriterBase): | |||||
for idx, i in enumerate(e.members): | for idx, i in enumerate(e.members): | ||||
self._write_doc(i) | self._write_doc(i) | ||||
v = '{} = {}'.format(i, idx) | v = '{} = {}'.format(i, idx) | ||||
if e.combined: | |||||
v = '{} = 1 << {}'.format(i, idx) | |||||
if i is not e.members[-1] or e.member_alias: | if i is not e.members[-1] or e.member_alias: | ||||
v += ',' | v += ',' | ||||
self._write(v) | self._write(v) | ||||
@@ -672,7 +685,6 @@ class CPPEnumValueWriter(CPPWriter): | |||||
self._write('static const uint32_t %s = %s;', alias, mem) | self._write('static const uint32_t %s = %s;', alias, mem) | ||||
self._write('};', indent=-1) | self._write('};', indent=-1) | ||||
def _on_member_enum_alias(self, e): | def _on_member_enum_alias(self, e): | ||||
s = e.src_enum | s = e.src_enum | ||||
self._write('typedef %s::%s %s;', e.src_class, e.src_name, e.name) | self._write('typedef %s::%s %s;', e.src_class, e.src_name, e.name) | ||||
@@ -91,12 +91,17 @@ class ConverterWriter(IndentWriterBase): | |||||
def format(v): | def format(v): | ||||
return '\"{}\"'.format(str(v)) | return '\"{}\"'.format(str(v)) | ||||
enum_def += ','.join(format(i) for i in e.members) | enum_def += ','.join(format(i) for i in e.members) | ||||
enum_def += "]" | |||||
if e.combined: | |||||
enum_def += "], 1" | |||||
else: | |||||
enum_def += "], 0" | |||||
if ENUM_TO_STRING_SPECIAL_RULES.count((p.name, e.name)): | if ENUM_TO_STRING_SPECIAL_RULES.count((p.name, e.name)): | ||||
enum_def += ", 1" # whether generate ToStringTrait | enum_def += ", 1" # whether generate ToStringTrait | ||||
enum_def += ">" | enum_def += ">" | ||||
self._write("def {} : {};".format(td_class, enum_def)) | |||||
self._write("def {} : {};".format(td_class, enum_def)) | |||||
if self._skip_current_param: | if self._skip_current_param: | ||||
return | return | ||||
@@ -21,8 +21,6 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
MEGDNN_DEF_ENUM_CLASS_BIT_OPR(AlgoAttribute) | |||||
#define MEGDNN_DECL_ALGO_TYPE(_type) \ | #define MEGDNN_DECL_ALGO_TYPE(_type) \ | ||||
uint32_t type() const override { \ | uint32_t type() const override { \ | ||||
return static_cast<std::underlying_type<AlgoType>::type>( \ | return static_cast<std::underlying_type<AlgoType>::type>( \ | ||||
@@ -692,61 +692,6 @@ inline void* get_origin_ptr(const TensorND* tensor, void* ptr) { | |||||
tensor->layout.span().low_byte); | tensor->layout.span().low_byte); | ||||
} | } | ||||
template <typename T> | |||||
class EnumClassBit { | |||||
std::underlying_type_t<T> m_val; | |||||
constexpr EnumClassBit(std::underlying_type_t<T> v) : m_val(v) {} | |||||
public: | |||||
constexpr EnumClassBit(T v) | |||||
: m_val(static_cast<std::underlying_type_t<T>>(v)) {} | |||||
constexpr operator T() const { return static_cast<T>(m_val); } | |||||
constexpr explicit operator bool() const { return m_val; } | |||||
#define DEF_OPR(op) \ | |||||
constexpr EnumClassBit operator op(const EnumClassBit& rhs) const { \ | |||||
return m_val op rhs.m_val; \ | |||||
} | |||||
DEF_OPR(&) | |||||
DEF_OPR(|) | |||||
DEF_OPR (^) | |||||
constexpr EnumClassBit operator~() const { return ~m_val; } | |||||
#undef DEF_OPR | |||||
}; | |||||
#define _MEGDNN_DECBO_SINGLE_OPR(cls, op) \ | |||||
inline constexpr ::megdnn::EnumClassBit<cls> operator op(cls x, cls y) { \ | |||||
return ::megdnn::EnumClassBit<cls>(x) \ | |||||
op ::megdnn::EnumClassBit<cls>(y); \ | |||||
} \ | |||||
inline constexpr ::megdnn::EnumClassBit<cls> operator op( \ | |||||
::megdnn::EnumClassBit<cls> x, cls y) { \ | |||||
return x op ::megdnn::EnumClassBit<cls>(y); \ | |||||
} | |||||
#define _MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, op) \ | |||||
inline constexpr cls& operator op##=(cls& x, cls y) { \ | |||||
x = x op ::megdnn::EnumClassBit<cls>(y); \ | |||||
return x; \ | |||||
} | |||||
#define MEGDNN_DEF_ENUM_CLASS_BIT_OPR(cls) \ | |||||
_MEGDNN_DECBO_SINGLE_OPR(cls, &) \ | |||||
_MEGDNN_DECBO_SINGLE_OPR(cls, |) \ | |||||
_MEGDNN_DECBO_SINGLE_OPR(cls, ^) \ | |||||
_MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, &) \ | |||||
_MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, |) \ | |||||
_MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, ^) \ | |||||
inline constexpr ::megdnn::EnumClassBit<cls> operator~(cls x) { \ | |||||
return ~::megdnn::EnumClassBit<cls>(x); \ | |||||
} | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -218,4 +218,3 @@ public: | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
@@ -8,9 +8,12 @@ | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import os | import os | ||||
from ..core.ops import builtin | |||||
from ..logger import get_logger | from ..logger import get_logger | ||||
from ..utils.deprecation import deprecated | from ..utils.deprecation import deprecated | ||||
Strategy = builtin.ops.Convolution.Strategy | |||||
_execution_strategy = os.getenv("MEGENGINE_EXECUTION_STRATEGY", "HEURISTIC") | _execution_strategy = os.getenv("MEGENGINE_EXECUTION_STRATEGY", "HEURISTIC") | ||||
if os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY") != None: | if os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY") != None: | ||||
@@ -19,7 +22,7 @@ if os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY") != None: | |||||
) | ) | ||||
def get_execution_strategy() -> str: | |||||
def get_execution_strategy() -> Strategy: | |||||
""" | """ | ||||
Returns the execution strategy of :class:`~.Conv2d` and :func:'~.matmul' | Returns the execution strategy of :class:`~.Conv2d` and :func:'~.matmul' | ||||
@@ -28,12 +31,22 @@ def get_execution_strategy() -> str: | |||||
return _execution_strategy | return _execution_strategy | ||||
def set_execution_strategy(option: str): | |||||
def set_execution_strategy(option): | |||||
""" | """ | ||||
Sets the execution strategy of :class:`~.Conv2d` and :func:'~.matmul' | Sets the execution strategy of :class:`~.Conv2d` and :func:'~.matmul' | ||||
:param option: Decides how :class:`~.Conv2d` and :func:'~.matmul' algorithms are chosen. | |||||
Available values: | |||||
:param option: Decides how :class:`~.Conv2d`and :func:'~.matmul' algorithms are chosen. | |||||
Available value Strategy | |||||
* HEURISTIC uses heuristic to choose the fastest algorithm. | |||||
* PROFILE runs possible algorithms on real device to find the best one. | |||||
* REPRODUCIBLE uses the algorithms that is reproducible. | |||||
* OPTMIZED uses the algorithms that is optimized. | |||||
The default strategy is HEURISTIC, this options can be combined to | |||||
form a combination option, e.g. PROFILE | REPRODUCIBLE | |||||
can combined a option that uses the fastest of profiling result that is also reproducible. | |||||
Available values string: | |||||
* 'HEURISTIC' uses heuristic to choose the fastest algorithm. | * 'HEURISTIC' uses heuristic to choose the fastest algorithm. | ||||
* 'PROFILE' runs possible algorithms on real device to find the best one. | * 'PROFILE' runs possible algorithms on real device to find the best one. | ||||
@@ -45,18 +58,29 @@ def set_execution_strategy(option: str): | |||||
It can also be set through the environment variable 'MEGENGINE_EXECUTION_STRATEGY'. | It can also be set through the environment variable 'MEGENGINE_EXECUTION_STRATEGY'. | ||||
""" | """ | ||||
valid_option = ( | |||||
"HEURISTIC", | |||||
"PROFILE", | |||||
"PROFILE_HEURISTIC", | |||||
"PROFILE_REPRODUCIBLE", | |||||
"HEURISTIC_REPRODUCIBLE", | |||||
) | |||||
if not option in valid_option: | |||||
raise ValueError("Valid option can only be one of {}".format(valid_option)) | |||||
valid_string_option = { | |||||
"REPRODUCIBLE": Strategy.REPRODUCIBLE, | |||||
"HEURISTIC": Strategy.HEURISTIC, | |||||
"PROFILE": Strategy.PROFILE, | |||||
} | |||||
global _execution_strategy # pylint: disable=global-statement | global _execution_strategy # pylint: disable=global-statement | ||||
_execution_strategy = option | |||||
if isinstance(option, Strategy): | |||||
_execution_strategy = option | |||||
return | |||||
assert isinstance(option, str) | |||||
strategy_tmp = Strategy(0) | |||||
for opt in option.split("_"): | |||||
if not opt in valid_string_option: | |||||
raise ValueError( | |||||
"Valid option can only be one of {}, or combine them with '_'.".format( | |||||
valid_string_option.keys() | |||||
) | |||||
) | |||||
strategy_tmp = strategy_tmp | valid_string_option[opt] | |||||
_execution_strategy = strategy_tmp | |||||
@deprecated(version="1.3", reason="use get_execution_strategy() instead") | @deprecated(version="1.3", reason="use get_execution_strategy() instead") | ||||
@@ -19,6 +19,7 @@ import megengine.autodiff as ad | |||||
import megengine.functional as F | import megengine.functional as F | ||||
from megengine import jit | from megengine import jit | ||||
from megengine.core._trace_option import set_symbolic_shape | from megengine.core._trace_option import set_symbolic_shape | ||||
from megengine.core.ops import builtin | |||||
from megengine.core.tensor.utils import make_shape_tuple | from megengine.core.tensor.utils import make_shape_tuple | ||||
from megengine.functional.debug_param import set_execution_strategy | from megengine.functional.debug_param import set_execution_strategy | ||||
from megengine.jit import SublinearMemoryConfig | from megengine.jit import SublinearMemoryConfig | ||||
@@ -33,6 +34,8 @@ from megengine.module import ( | |||||
from megengine.optimizer import SGD | from megengine.optimizer import SGD | ||||
from megengine.tensor import Tensor | from megengine.tensor import Tensor | ||||
Strategy = builtin.ops.Convolution.Strategy | |||||
def get_gpu_name(): | def get_gpu_name(): | ||||
try: | try: | ||||
@@ -242,7 +245,7 @@ def test_correctness(): | |||||
else: | else: | ||||
model_name = "mnist_model_with_test_cpu.mge" | model_name = "mnist_model_with_test_cpu.mge" | ||||
model_path = os.path.join(os.path.dirname(__file__), model_name) | model_path = os.path.join(os.path.dirname(__file__), model_name) | ||||
set_execution_strategy("HEURISTIC_REPRODUCIBLE") | |||||
set_execution_strategy(Strategy.HEURISTIC | Strategy.REPRODUCIBLE) | |||||
run_train(model_path, False, False, max_err=1e-5) | run_train(model_path, False, False, max_err=1e-5) | ||||
run_train(model_path, True, False, max_err=1e-5) | run_train(model_path, True, False, max_err=1e-5) | ||||
@@ -337,6 +337,20 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& | |||||
className, attr->getEnumName(), i | className, attr->getEnumName(), i | ||||
)); | )); | ||||
} | } | ||||
if (attr->getEnumCombinedFlag()) { | |||||
//! define operator | | |||||
os << formatv( | |||||
"\n .def(\"__or__\", []({0}::{1} s0, {0}::{1} s1) {{ " | |||||
"\n return static_cast<{0}::{1}>(uint32_t(s0) | uint32_t(s1));" | |||||
"\n })", | |||||
className, attr->getEnumName()); | |||||
//! define operator & | |||||
os << formatv( | |||||
"\n .def(\"__and__\", []({0}::{1} s0, {0}::{1} s1) {{" | |||||
"\n return static_cast<{0}::{1}>(uint32_t(s0) & uint32_t(s1));" | |||||
"\n })", | |||||
className, attr->getEnumName()); | |||||
} | |||||
os << formatv( | os << formatv( | ||||
"\n .def(py::init([](const std::string& in) {" | "\n .def(py::init([](const std::string& in) {" | ||||
"\n auto&& str = normalize_enum(in);" | "\n auto&& str = normalize_enum(in);" | ||||
@@ -77,6 +77,9 @@ struct MgbEnumAttrMixin : public MgbAttrWrapperBase { | |||||
bool supportToString() const { | bool supportToString() const { | ||||
return getBaseRecord()->getValueAsBit("supportToString"); | return getBaseRecord()->getValueAsBit("supportToString"); | ||||
} | } | ||||
bool getEnumCombinedFlag() const { | |||||
return getBaseRecord()->getValueAsBit("enumCombined"); | |||||
} | |||||
}; | }; | ||||
struct MgbHashableAttrMixin : public MgbAttrWrapperBase { | struct MgbHashableAttrMixin : public MgbAttrWrapperBase { | ||||
@@ -142,8 +142,16 @@ R"__usage__( | |||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
R"__usage__( | R"__usage__( | ||||
--fast-run | --fast-run | ||||
Enable fast-run mode. Operators with multiple algorithms would be profiled | |||||
on the real device with actual input shapes. | |||||
This param will be deperated later, please replace with param --full-profile. | |||||
--full-profile | |||||
Enable full-profile mode. Operators with multiple algorithms would be profiled | |||||
on the real device with actual input shapes, all algorithms will be profiled | |||||
include naive algorithms. | |||||
See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details. | |||||
--fast-profile | |||||
Enable fast-profile mode. Operators with multiple algorithms would be profiled | |||||
on the real device with actual input shapes, this mode will only profile the | |||||
well optimized algorithms to get the profile result fast. | |||||
See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details. | See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details. | ||||
)__usage__" | )__usage__" | ||||
#endif | #endif | ||||
@@ -511,7 +519,8 @@ struct Args { | |||||
bool disable_assert_throw = false; | bool disable_assert_throw = false; | ||||
bool share_param_mem = false; | bool share_param_mem = false; | ||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
bool use_fast_run = false; | |||||
bool use_full_profile = false; | |||||
bool use_fast_profile = false; | |||||
#endif | #endif | ||||
bool reproducible = false; | bool reproducible = false; | ||||
std::string fast_run_cache_path; | std::string fast_run_cache_path; | ||||
@@ -695,18 +704,20 @@ void run_test_st(Args &env) { | |||||
using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; | using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; | ||||
S strategy = S::HEURISTIC; | S strategy = S::HEURISTIC; | ||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
if (env.use_fast_run) { | |||||
if (env.use_full_profile) { | |||||
if (env.reproducible) { | if (env.reproducible) { | ||||
strategy = S::PROFILE_REPRODUCIBLE; | |||||
strategy = S::PROFILE | S::REPRODUCIBLE; | |||||
} else { | } else { | ||||
strategy = S::PROFILE; | strategy = S::PROFILE; | ||||
} | } | ||||
} else if (env.use_fast_profile) { | |||||
strategy = S::PROFILE | S::OPTMIZED; | |||||
} else if (env.reproducible) { | } else if (env.reproducible) { | ||||
strategy = S::HEURISTIC_REPRODUCIBLE; | |||||
strategy = S::HEURISTIC | S::REPRODUCIBLE; | |||||
} | } | ||||
#else | #else | ||||
if (env.reproducible) { | if (env.reproducible) { | ||||
strategy = S::HEURISTIC_REPRODUCIBLE; | |||||
strategy = S::HEURISTIC | S::REPRODUCIBLE; | |||||
} | } | ||||
#endif | #endif | ||||
mgb::gopt::modify_opr_algo_strategy_inplace(vars, strategy); | mgb::gopt::modify_opr_algo_strategy_inplace(vars, strategy); | ||||
@@ -729,11 +740,12 @@ void run_test_st(Args &env) { | |||||
std::make_shared<InFilePersistentCache>(buf.get(), flen)); | std::make_shared<InFilePersistentCache>(buf.get(), flen)); | ||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
} else { | } else { | ||||
mgb_assert(env.use_fast_run, "fast-run should be enabled"); | |||||
mgb_assert(env.use_full_profile || env.use_fast_profile, | |||||
"fast-run or fast-profile should be enabled"); | |||||
PersistentCache::set_impl( | PersistentCache::set_impl( | ||||
std::make_shared<InFilePersistentCache>()); | std::make_shared<InFilePersistentCache>()); | ||||
} | } | ||||
if (!env.use_fast_run) | |||||
if (!env.use_full_profile && !env.use_fast_profile) | |||||
#endif | #endif | ||||
mgb::gopt::enable_opr_use_profiling_cache_inplace(vars); | mgb::gopt::enable_opr_use_profiling_cache_inplace(vars); | ||||
} | } | ||||
@@ -1314,7 +1326,18 @@ Args Args::from_argv(int argc, char **argv) { | |||||
} | } | ||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
if (!strcmp(argv[i], "--fast-run")) { | if (!strcmp(argv[i], "--fast-run")) { | ||||
ret.use_fast_run = true; | |||||
mgb_log_warn( | |||||
"--fast-run param will be deperated later, please replace " | |||||
"with --full-profile or --fast-profile."); | |||||
ret.use_full_profile = true; | |||||
continue; | |||||
} | |||||
if (!strcmp(argv[i], "--full-profile")) { | |||||
ret.use_full_profile = true; | |||||
continue; | |||||
} | |||||
if (!strcmp(argv[i], "--fast-profile")) { | |||||
ret.use_fast_profile = true; | |||||
continue; | continue; | ||||
} | } | ||||
#endif | #endif | ||||
@@ -188,7 +188,7 @@ AlgoChooserProfileCache::get(const Key &key) { | |||||
auto entry_len = read_uint32(); | auto entry_len = read_uint32(); | ||||
mgb_assert(buf + entry_len <= buf_end); | mgb_assert(buf + entry_len <= buf_end); | ||||
auto nr = sscanf(reinterpret_cast<const char*>(buf), ENTRY_FMT, | auto nr = sscanf(reinterpret_cast<const char*>(buf), ENTRY_FMT, | ||||
&i.reproducible, &i.time, &i.workspace); | |||||
&i.attribute, &i.time, &i.workspace); | |||||
mgb_assert(nr == 3); | mgb_assert(nr == 3); | ||||
buf += entry_len; | buf += entry_len; | ||||
} | } | ||||
@@ -210,10 +210,10 @@ void AlgoChooserProfileCache::put(const Key &key, Result &result) { | |||||
auto &&cur = result[i]; | auto &&cur = result[i]; | ||||
if (prev.workspace <= cur.workspace && | if (prev.workspace <= cur.workspace && | ||||
prev.reproducible == cur.reproducible) { | |||||
prev.attribute == cur.attribute) { | |||||
result.erase(result.begin() + i); | result.erase(result.begin() + i); | ||||
} else { | } else { | ||||
++ i; | |||||
++i; | |||||
} | } | ||||
} | } | ||||
@@ -235,8 +235,8 @@ void AlgoChooserProfileCache::put(const Key &key, Result &result) { | |||||
write_uint32(0); | write_uint32(0); | ||||
pos = val.size(); | pos = val.size(); | ||||
val.resize(pos + SPR_SIZE); | val.resize(pos + SPR_SIZE); | ||||
uint32_t nr = snprintf(&val[pos], SPR_SIZE, | |||||
ENTRY_FMT, i.reproducible, i.time, i.workspace); | |||||
uint32_t nr = snprintf(&val[pos], SPR_SIZE, ENTRY_FMT, i.attribute, | |||||
i.time, i.workspace); | |||||
//! for memory boundary failed, snprintf ret do not contain \0 | //! for memory boundary failed, snprintf ret do not contain \0 | ||||
nr += 1; | nr += 1; | ||||
mgb_assert(nr < SPR_SIZE); | mgb_assert(nr < SPR_SIZE); | ||||
@@ -12,6 +12,8 @@ | |||||
#pragma once | #pragma once | ||||
#include "megbrain_build_config.h" | #include "megbrain_build_config.h" | ||||
#include "megbrain/opr/param_defs.h" | |||||
#include "megdnn/basic_types.h" | |||||
#include <memory> | #include <memory> | ||||
#include <string> | #include <string> | ||||
@@ -242,6 +244,16 @@ inline constexpr std::size_t operator"" _z(unsigned long long n) { | |||||
return n; | return n; | ||||
} | } | ||||
#endif | #endif | ||||
#define MGB_DEF_ENUM_CLASS_BIT_OPR(cls) \ | |||||
MEGDNN_DEF_ENUM_CLASS_BIT_OPR(cls) | |||||
} // namespace mgb | } // namespace mgb | ||||
namespace megdnn { | |||||
namespace param { | |||||
MGB_DEF_ENUM_CLASS_BIT_OPR(ExecutionPolicy::Strategy) | |||||
} | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -12,7 +12,6 @@ | |||||
#pragma once | #pragma once | ||||
#include "megbrain/utils/hash.h" | #include "megbrain/utils/hash.h" | ||||
#include "megbrain/utils/enum_class_bit.h" | |||||
#include "megbrain/utils/metahelper.h" | #include "megbrain/utils/metahelper.h" | ||||
#include "megbrain/utils/thin/hash_table.h" | #include "megbrain/utils/thin/hash_table.h" | ||||
#include "megbrain/utils/thread.h" | #include "megbrain/utils/thread.h" | ||||
@@ -16,7 +16,6 @@ | |||||
#include "megbrain/graph/symbol_var.h" | #include "megbrain/graph/symbol_var.h" | ||||
#include "megbrain/utils/hashable.h" | #include "megbrain/utils/hashable.h" | ||||
#include "megbrain/utils/enum_class_bit.h" | |||||
#include "megbrain/utils/thin/hash_table.h" | #include "megbrain/utils/thin/hash_table.h" | ||||
#include "megbrain/utils/small_vector.h" | #include "megbrain/utils/small_vector.h" | ||||
@@ -12,7 +12,6 @@ | |||||
#pragma once | #pragma once | ||||
#include "megbrain/graph/bases.h" | #include "megbrain/graph/bases.h" | ||||
#include "megbrain/utils/enum_class_bit.h" | |||||
#include "megbrain/utils/comp_node_sync_manager.h" | #include "megbrain/utils/comp_node_sync_manager.h" | ||||
#include "megbrain/utils/small_vector.h" | #include "megbrain/utils/small_vector.h" | ||||
#include "megbrain/utils/mempool.h" | #include "megbrain/utils/mempool.h" | ||||
@@ -33,10 +33,11 @@ class MgbHashableAttrMixin { | |||||
string reprFunction = "std::to_string($0)"; | string reprFunction = "std::to_string($0)"; | ||||
} | } | ||||
class MgbEnumAttrMixin<string namespace, string name, list<string> members, bit toString> { | |||||
class MgbEnumAttrMixin<string namespace, string name, list<string> members, bit combined, bit toString> { | |||||
string parentNamespace = namespace; | string parentNamespace = namespace; | ||||
string enumName = name; | string enumName = name; | ||||
list<string> enumMembers = members; | list<string> enumMembers = members; | ||||
bit enumCombined = combined; | |||||
bit supportToString = toString; | bit supportToString = toString; | ||||
} | } | ||||
@@ -166,8 +167,8 @@ class MgbTupleAttr<list<MgbAttrWrapper> args>: | |||||
} | } | ||||
// -- enum types | // -- enum types | ||||
class MgbEnumAttr<string namespace, string enumName, list<string> members, bit toString=0>: | |||||
HashableAttr<namespace # "::" # enumName>, MgbEnumAttrMixin<namespace, enumName, members, toString> { | |||||
class MgbEnumAttr<string namespace, string enumName, list<string> members, bit combined, bit toString=0>: | |||||
HashableAttr<namespace # "::" # enumName>, MgbEnumAttrMixin<namespace, enumName, members, combined, toString> { | |||||
let storageType = "::mlir::IntegerAttr"; | let storageType = "::mlir::IntegerAttr"; | ||||
let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; | let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; | ||||
let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))"; | let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))"; | ||||
@@ -176,7 +177,7 @@ class MgbEnumAttr<string namespace, string enumName, list<string> members, bit t | |||||
} | } | ||||
class MgbEnumAliasAttr<string namespace, string enumName, MgbEnumAttr base>: | class MgbEnumAliasAttr<string namespace, string enumName, MgbEnumAttr base>: | ||||
MgbEnumAttr<namespace, enumName, base.enumMembers>, MgbAliasAttrMixin<base>; | |||||
MgbEnumAttr<namespace, enumName, base.enumMembers, 0>, MgbAliasAttrMixin<base>; | |||||
// -- other types | // -- other types | ||||
def MgbDTypeAttr: HashableAttr<"::megdnn::DType"> { | def MgbDTypeAttr: HashableAttr<"::megdnn::DType"> { | ||||
@@ -1,89 +0,0 @@ | |||||
/** | |||||
* \file src/core/include/megbrain/utils/enum_class_bit.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include <type_traits> | |||||
namespace mgb { | |||||
template<typename T> | |||||
class EnumClassBit { | |||||
std::underlying_type_t<T> m_val; | |||||
constexpr EnumClassBit(std::underlying_type_t<T> v): | |||||
m_val(v) | |||||
{ | |||||
} | |||||
public: | |||||
constexpr EnumClassBit(T v): | |||||
m_val(static_cast<std::underlying_type_t<T>>(v)) | |||||
{ | |||||
} | |||||
constexpr operator T() const { | |||||
return static_cast<T>(m_val); | |||||
} | |||||
constexpr explicit operator bool() const { | |||||
return m_val; | |||||
} | |||||
#define DEF_OPR(op) \ | |||||
constexpr EnumClassBit operator op (\ | |||||
const EnumClassBit &rhs) const { \ | |||||
return m_val op rhs.m_val; \ | |||||
} | |||||
DEF_OPR(&) | |||||
DEF_OPR(|) | |||||
DEF_OPR(^) | |||||
constexpr EnumClassBit operator ~() const { | |||||
return ~m_val; | |||||
} | |||||
#undef DEF_OPR | |||||
}; | |||||
} | |||||
#define _MGB_DECBO_SINGLE_OPR(cls, op) \ | |||||
inline constexpr ::mgb::EnumClassBit<cls> operator op (cls x, cls y) { \ | |||||
return ::mgb::EnumClassBit<cls>(x) op ::mgb::EnumClassBit<cls>(y); \ | |||||
} \ | |||||
inline constexpr ::mgb::EnumClassBit<cls> operator op ( \ | |||||
::mgb::EnumClassBit<cls> x, cls y) { \ | |||||
return x op ::mgb::EnumClassBit<cls>(y); \ | |||||
} | |||||
#define _MGB_DECBO_SINGLE_OPR_ASSIGN(cls, op) \ | |||||
inline constexpr cls& operator op##= (cls& x, cls y) { \ | |||||
x = x op ::mgb::EnumClassBit<cls>(y); \ | |||||
return x; \ | |||||
} | |||||
#define MGB_DEF_ENUM_CLASS_BIT_OPR(cls) \ | |||||
_MGB_DECBO_SINGLE_OPR(cls, &) \ | |||||
_MGB_DECBO_SINGLE_OPR(cls, |) \ | |||||
_MGB_DECBO_SINGLE_OPR(cls, ^) \ | |||||
_MGB_DECBO_SINGLE_OPR_ASSIGN(cls, &) \ | |||||
_MGB_DECBO_SINGLE_OPR_ASSIGN(cls, |) \ | |||||
_MGB_DECBO_SINGLE_OPR_ASSIGN(cls, ^) \ | |||||
inline constexpr ::mgb::EnumClassBit<cls> operator ~ (cls x) { \ | |||||
return ~::mgb::EnumClassBit<cls>(x); \ | |||||
} \ | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
@@ -100,8 +100,7 @@ namespace mgb { | |||||
struct ResultEntry { | struct ResultEntry { | ||||
std::string algo; //! identifier of the algorithm | std::string algo; //! identifier of the algorithm | ||||
//! sscanf will up bool as int | |||||
int reproducible; //! whether algorithm is reproducible | |||||
uint32_t attribute; //! algo attribute, e.g. reproducible | |||||
double time; //! execution time in seconds | double time; //! execution time in seconds | ||||
size_t workspace; //! workspace in bytes | size_t workspace; //! workspace in bytes | ||||
}; | }; | ||||
@@ -54,7 +54,6 @@ using namespace gopt; | |||||
namespace { | namespace { | ||||
template <typename SharedDeviceTensor, typename MultipleDeviceTensorHolder> | template <typename SharedDeviceTensor, typename MultipleDeviceTensorHolder> | ||||
void param_merge(OptState& opt_state) { | void param_merge(OptState& opt_state) { | ||||
auto rewriter = opt_state.graph().make_rewriter(); | auto rewriter = opt_state.graph().make_rewriter(); | ||||
@@ -102,7 +101,7 @@ void param_merge(OptState& opt_state) { | |||||
rewriter.apply_inplace(); | rewriter.apply_inplace(); | ||||
} | } | ||||
} | |||||
} // namespace | |||||
/* ================ global functions ================ */ | /* ================ global functions ================ */ | ||||
@@ -190,12 +189,10 @@ void gopt::enable_opr_algo_profiling_inplace( | |||||
void gopt::enable_opr_use_profiling_cache_inplace( | void gopt::enable_opr_use_profiling_cache_inplace( | ||||
const VarNodeArrayView& dest_vars) { | const VarNodeArrayView& dest_vars) { | ||||
modify_opr_algo_strategy_inplace( | |||||
dest_vars, opr::mixin::AlgoChooserHelper::ExecutionPolicy:: | |||||
Strategy::PROFILE_HEURISTIC); | |||||
using S = megdnn::param::ExecutionPolicy::Strategy; | |||||
modify_opr_algo_strategy_inplace(dest_vars, S::PROFILE | S::HEURISTIC); | |||||
} | } | ||||
void gopt::set_opr_algo_workspace_limit_inplace( | void gopt::set_opr_algo_workspace_limit_inplace( | ||||
const VarNodeArrayView& dest_vars, size_t workspace_limit) { | const VarNodeArrayView& dest_vars, size_t workspace_limit) { | ||||
static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&, size_t)> | static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&, size_t)> | ||||
@@ -1693,7 +1693,22 @@ TEST(TestGoptInference, ProfileCache) { | |||||
using S = opr::Convolution::ExecutionPolicy::Strategy; | using S = opr::Convolution::ExecutionPolicy::Strategy; | ||||
ASSERT_EQ(S::HEURISTIC, conv.execution_policy_transient().strategy); | ASSERT_EQ(S::HEURISTIC, conv.execution_policy_transient().strategy); | ||||
gopt::enable_opr_use_profiling_cache_inplace({z + 2.3f}); | gopt::enable_opr_use_profiling_cache_inplace({z + 2.3f}); | ||||
ASSERT_EQ(S::PROFILE_HEURISTIC, conv.execution_policy().strategy); | |||||
ASSERT_EQ(S::PROFILE | S::HEURISTIC, conv.execution_policy().strategy); | |||||
} | |||||
TEST(TestGoptInference, FastProfileCache) { | |||||
HostTensorGenerator<> gen; | |||||
auto graph = ComputingGraph::make(); | |||||
auto host_x = gen({4, 3, 8, 9}), host_y = gen({2, 3, 3, 3}); | |||||
auto x = opr::Host2DeviceCopy::make(*graph, host_x), | |||||
y = opr::Host2DeviceCopy::make(*graph, host_y), | |||||
z = opr::Convolution::make(x, y); | |||||
auto&& conv = z.node()->owner_opr()->cast_final_safe<opr::Convolution>(); | |||||
using S = opr::Convolution::ExecutionPolicy::Strategy; | |||||
ASSERT_EQ(S::HEURISTIC, conv.execution_policy_transient().strategy); | |||||
gopt::modify_opr_algo_strategy_inplace({z + 2.3f}, | |||||
S::PROFILE | S::OPTMIZED); | |||||
ASSERT_EQ(S::PROFILE | S::OPTMIZED, conv.execution_policy().strategy); | |||||
} | } | ||||
TEST(TestGoptInference, AlgoWorkspaceLimit) { | TEST(TestGoptInference, AlgoWorkspaceLimit) { | ||||
@@ -20,7 +20,6 @@ | |||||
#include "megbrain/opr/dnn/lrn.h" | #include "megbrain/opr/dnn/lrn.h" | ||||
#include "megbrain/opr/dnn/fake_quant.h" | #include "megbrain/opr/dnn/fake_quant.h" | ||||
#include "megbrain/opr/dnn/tqt.h" | #include "megbrain/opr/dnn/tqt.h" | ||||
#include "megbrain/serialization/sereg.h" | #include "megbrain/serialization/sereg.h" | ||||
#include "megdnn/opr_param_defs.h" | #include "megdnn/opr_param_defs.h" | ||||
#include "megdnn/oprs/nn.h" | #include "megdnn/oprs/nn.h" | ||||
@@ -284,8 +284,9 @@ namespace mgb { | |||||
namespace opr { | namespace opr { | ||||
template <typename Opr> | template <typename Opr> | ||||
void AlgoChooser<Opr>::profile(ExeContext& ctx, bool require_reproducible) { | |||||
if (ctx.get_profile_result_from_cache(require_reproducible).valid()) | |||||
void AlgoChooser<Opr>::profile(ExeContext& ctx, | |||||
ExecutionStrategy select_strategy) { | |||||
if (ctx.get_profile_result_from_cache(select_strategy).valid()) | |||||
return; | return; | ||||
AlgoChooserProfileCache::Result prof_rst; | AlgoChooserProfileCache::Result prof_rst; | ||||
@@ -305,7 +306,7 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, bool require_reproducible) { | |||||
algo.name.c_str(), str_on_inp_shape.c_str()); | algo.name.c_str(), str_on_inp_shape.c_str()); | ||||
ImplExecutionPolicy policy; | ImplExecutionPolicy policy; | ||||
policy.algo = algo.desc; | policy.algo = algo.desc; | ||||
ctx.construct_execution_policy(require_reproducible, policy); | |||||
ctx.construct_execution_policy(select_strategy, policy); | |||||
if (ctx.get_workspace_size_bytes(policy) >= workspace_limit) | if (ctx.get_workspace_size_bytes(policy) >= workspace_limit) | ||||
continue; | continue; | ||||
@@ -354,7 +355,8 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, bool require_reproducible) { | |||||
template <typename Opr> | template <typename Opr> | ||||
typename AlgoChooser<Opr>::ImplExecutionPolicy | typename AlgoChooser<Opr>::ImplExecutionPolicy | ||||
AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, bool require_reproducible, | |||||
AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, | |||||
ExecutionStrategy select_strategy, | |||||
bool enable_update) { | bool enable_update) { | ||||
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile"))) | MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile"))) | ||||
if (ctx.owner_graph()->options().no_profiling_on_shape_change) { | if (ctx.owner_graph()->options().no_profiling_on_shape_change) { | ||||
@@ -376,11 +378,11 @@ AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, bool require_reproducible, | |||||
to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(), | to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(), | ||||
_item.param, ctx.mgb_opr(), ctx.comp_node(), | _item.param, ctx.mgb_opr(), ctx.comp_node(), | ||||
ctx.execution_policy(), ctx.allow_weight_preprocess()); | ctx.execution_policy(), ctx.allow_weight_preprocess()); | ||||
AlgoChooser<_Opr>::profile(sub_ctx, require_reproducible); | |||||
AlgoChooser<_Opr>::profile(sub_ctx, select_strategy); | |||||
}); | }); | ||||
} | } | ||||
typename AlgoChooser<Opr>::ImplExecutionPolicy policy; | typename AlgoChooser<Opr>::ImplExecutionPolicy policy; | ||||
ctx.construct_execution_policy(require_reproducible, policy); | |||||
ctx.construct_execution_policy(select_strategy, policy); | |||||
return policy; | return policy; | ||||
MIDOUT_E | MIDOUT_E | ||||
} | } | ||||
@@ -402,11 +404,9 @@ size_t AlgoChooser<Opr>::setup_algo(const FixedTensorLayouts& layouts, | |||||
ImplExecutionPolicy policy; | ImplExecutionPolicy policy; | ||||
if (auto algo_choose_hook = mgb_opr->algo_chooser()) { | if (auto algo_choose_hook = mgb_opr->algo_chooser()) { | ||||
policy = algo_choose_hook(mgb_opr); | policy = algo_choose_hook(mgb_opr); | ||||
ctx.construct_execution_policy( | |||||
mgb_opr->execution_policy().strategy == | |||||
mixin::AlgoChooserHelper::ExecutionPolicy::Strategy:: | |||||
HEURISTIC_REPRODUCIBLE, | |||||
policy, false); | |||||
ctx.construct_execution_policy((ExecutionStrategy::HEURISTIC | | |||||
ExecutionStrategy::REPRODUCIBLE), | |||||
policy, false); | |||||
} | } | ||||
if (!policy.algo.valid()) { | if (!policy.algo.valid()) { | ||||
policy = get_policy(ctx); | policy = get_policy(ctx); | ||||
@@ -419,10 +419,9 @@ size_t AlgoChooser<Opr>::setup_algo(const FixedTensorLayouts& layouts, | |||||
Algorithm* palgo = megdnn_opr->get_algorithm_from_desc(policy.algo); | Algorithm* palgo = megdnn_opr->get_algorithm_from_desc(policy.algo); | ||||
mgb_assert(palgo, "Unknown algo description"); | mgb_assert(palgo, "Unknown algo description"); | ||||
ret.append("): algo=" + std::string(palgo->name())); | ret.append("): algo=" + std::string(palgo->name())); | ||||
ret.append(ssprintf(" workspace=%.2fMiB reproducible=%d", | |||||
ret.append(ssprintf(" workspace=%.2fMiB attirbute=%d", | |||||
workspace / (1024 * 1024.0), | workspace / (1024 * 1024.0), | ||||
palgo->contain_attribute( | |||||
megdnn::AlgoAttribute::REPRODUCIBLE))); | |||||
static_cast<uint32_t>(palgo->attribute()))); | |||||
mgb_log_debug("%s", ret.c_str()); | mgb_log_debug("%s", ret.c_str()); | ||||
megdnn_opr->execution_policy() = policy; | megdnn_opr->execution_policy() = policy; | ||||
@@ -432,41 +431,39 @@ size_t AlgoChooser<Opr>::setup_algo(const FixedTensorLayouts& layouts, | |||||
template <typename Opr> | template <typename Opr> | ||||
typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy( | typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy( | ||||
ExeContext& ctx) { | ExeContext& ctx) { | ||||
using S = mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; | |||||
MGB_MARK_USED_VAR(TIMEOUT_TOLERANCE); | MGB_MARK_USED_VAR(TIMEOUT_TOLERANCE); | ||||
switch (ctx.execution_policy().strategy) { | |||||
case S::HEURISTIC: | |||||
return ctx.choose_by_heuristic(); | |||||
case S::HEURISTIC_REPRODUCIBLE: | |||||
return ctx.choose_by_heuristic(true); | |||||
case S::PROFILE_HEURISTIC: { | |||||
ImplExecutionPolicy policy = choose_by_profile(ctx, false, false); | |||||
if (!policy.algo.valid()) | |||||
policy = ctx.choose_by_heuristic(); | |||||
return policy; | |||||
} | |||||
auto opr_strategy = ctx.execution_policy().strategy; | |||||
if ((opr_strategy & ExecutionStrategy::HEURISTIC) && | |||||
(opr_strategy & ExecutionStrategy::PROFILE)) { | |||||
ImplExecutionPolicy policy = | |||||
choose_by_profile(ctx, opr_strategy, false); | |||||
if (!policy.algo.valid()) | |||||
policy = ctx.choose_by_heuristic(opr_strategy); | |||||
return policy; | |||||
} else if ((opr_strategy & ExecutionStrategy::HEURISTIC)) { | |||||
return ctx.choose_by_heuristic(opr_strategy); | |||||
} | |||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
case S::PROFILE: | |||||
return choose_by_profile(ctx, false); | |||||
case S::PROFILE_REPRODUCIBLE: | |||||
return choose_by_profile(ctx, true); | |||||
else if (opr_strategy & ExecutionStrategy::PROFILE) { | |||||
return choose_by_profile(ctx, opr_strategy); | |||||
} | |||||
#endif | #endif | ||||
default: | |||||
mgb_throw(GraphError, "bad convolution ExecutionPolicy strategy"); | |||||
else { | |||||
mgb_throw(GraphError, "bad convolution ExecutionPolicy strategy"); | |||||
} | } | ||||
} | } | ||||
#define INST(Opr) \ | |||||
template AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \ | |||||
AlgoChooser<megdnn::Opr>::get_policy(ExeContext& ctx); \ | |||||
template void AlgoChooser<megdnn::Opr>::profile( \ | |||||
ExeContext& ctx, bool require_reproducible); \ | |||||
template AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \ | |||||
AlgoChooser<megdnn::Opr>::choose_by_profile( \ | |||||
ExeContext& ctx, bool require_reproducible, bool enable_update); \ | |||||
template size_t AlgoChooser<megdnn::Opr>::setup_algo( \ | |||||
const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \ | |||||
const MGBOpr* mgb_opr, bool allow_weight_preprocess); \ | |||||
#define INST(Opr) \ | |||||
template AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \ | |||||
AlgoChooser<megdnn::Opr>::get_policy(ExeContext& ctx); \ | |||||
template void AlgoChooser<megdnn::Opr>::profile(ExeContext& ctx, \ | |||||
ExecutionStrategy); \ | |||||
template AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \ | |||||
AlgoChooser<megdnn::Opr>::choose_by_profile( \ | |||||
ExeContext& ctx, ExecutionStrategy, bool enable_update); \ | |||||
template size_t AlgoChooser<megdnn::Opr>::setup_algo( \ | |||||
const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \ | |||||
const MGBOpr* mgb_opr, bool allow_weight_preprocess); | |||||
MGB_FOREACH_FASTRUN_OPR(INST) | MGB_FOREACH_FASTRUN_OPR(INST) | ||||
@@ -498,7 +495,7 @@ AlgoChooser<Opr>::ExeContext::ExeContext( | |||||
template <typename Opr> | template <typename Opr> | ||||
typename AlgoChooser<Opr>::ImplAlgo | typename AlgoChooser<Opr>::ImplAlgo | ||||
AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | ||||
bool require_reproducible) const { | |||||
ExecutionStrategy select_strategy) const { | |||||
MIDOUT_B(Opr, | MIDOUT_B(Opr, | ||||
midout_iv(MGB_HASH_STR( | midout_iv(MGB_HASH_STR( | ||||
"AlgoChooser::ExeContext::get_profile_result_from_cache"))) | "AlgoChooser::ExeContext::get_profile_result_from_cache"))) | ||||
@@ -522,7 +519,9 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | |||||
if (prof.empty()) | if (prof.empty()) | ||||
return {}; | return {}; | ||||
for (auto&& i : prof) { | for (auto&& i : prof) { | ||||
if ((!require_reproducible || i.reproducible)) { | |||||
if (!(select_strategy & ExecutionStrategy::REPRODUCIBLE) || | |||||
static_cast<AlgoAttribute>(i.attribute) & | |||||
AlgoAttribute::REPRODUCIBLE) { | |||||
auto iter = algo_map.find(i.algo); | auto iter = algo_map.find(i.algo); | ||||
mgb_assert(iter != algo_map.end(), | mgb_assert(iter != algo_map.end(), | ||||
"algorithm %s exists in " | "algorithm %s exists in " | ||||
@@ -550,7 +549,8 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | |||||
template <typename Opr> | template <typename Opr> | ||||
typename AlgoChooser<Opr>::ImplExecutionPolicy | typename AlgoChooser<Opr>::ImplExecutionPolicy | ||||
AlgoChooser<Opr>::ExeContext::choose_by_heuristic(bool reproducible) const { | |||||
AlgoChooser<Opr>::ExeContext::choose_by_heuristic( | |||||
ExecutionStrategy select_strategy) const { | |||||
if (m_execution_policy.workspace_limit != | if (m_execution_policy.workspace_limit != | ||||
std::numeric_limits<decltype( | std::numeric_limits<decltype( | ||||
m_execution_policy.workspace_limit)>::max()) { | m_execution_policy.workspace_limit)>::max()) { | ||||
@@ -558,6 +558,8 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(bool reproducible) const { | |||||
"workspace_limit should not be setted if choose algo by " | "workspace_limit should not be setted if choose algo by " | ||||
"heuristic"); | "heuristic"); | ||||
} | } | ||||
bool reproducible = static_cast<bool>(select_strategy & | |||||
ExecutionStrategy::REPRODUCIBLE); | |||||
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | ||||
owner_graph(), m_cn, m_execution_policy.workspace_limit); | owner_graph(), m_cn, m_execution_policy.workspace_limit); | ||||
ImplExecutionPolicy policy; | ImplExecutionPolicy policy; | ||||
@@ -579,7 +581,8 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(bool reproducible) const { | |||||
to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(), | to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(), | ||||
_item.param, m_base_mgb_opr, m_cn, m_execution_policy, | _item.param, m_base_mgb_opr, m_cn, m_execution_policy, | ||||
m_allow_weight_preprocess); | m_allow_weight_preprocess); | ||||
policy.sub_policy.push_back(sub_ctx.choose_by_heuristic(reproducible)); | |||||
policy.sub_policy.push_back( | |||||
sub_ctx.choose_by_heuristic(select_strategy)); | |||||
}); | }); | ||||
return policy; | return policy; | ||||
@@ -588,9 +591,8 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(bool reproducible) const { | |||||
template <typename Opr> | template <typename Opr> | ||||
std::vector<typename AlgoChooser<Opr>::ImplAlgo> | std::vector<typename AlgoChooser<Opr>::ImplAlgo> | ||||
AlgoChooser<Opr>::ExeContext::get_all_candidates() const { | AlgoChooser<Opr>::ExeContext::get_all_candidates() const { | ||||
auto heu = choose_by_heuristic(); | |||||
auto&& ret = | |||||
APPLY(m_megdnn_opr->get_all_algorithms_info(args...), m_layouts); | |||||
auto heu = choose_by_heuristic(ExecutionStrategy::HEURISTIC); | |||||
auto&& ret = APPLY(m_megdnn_opr->get_all_algorithms_info(args...), m_layouts); | |||||
bool found = false; | bool found = false; | ||||
for (size_t i = 0; i < ret.size(); ++i) { | for (size_t i = 0; i < ret.size(); ++i) { | ||||
if (ret[i].desc == heu.algo) { | if (ret[i].desc == heu.algo) { | ||||
@@ -611,19 +613,21 @@ AlgoChooser<Opr>::ExeContext::get_all_candidates() const { | |||||
template <typename Opr> | template <typename Opr> | ||||
void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | ||||
bool require_reproducible, | |||||
ExecutionStrategy select_strategy, | |||||
typename AlgoChooser<Opr>::ImplExecutionPolicy& policy, | typename AlgoChooser<Opr>::ImplExecutionPolicy& policy, | ||||
bool retrive_from_cache) const { | bool retrive_from_cache) const { | ||||
bool reproducible = static_cast<bool>(select_strategy & | |||||
ExecutionStrategy::REPRODUCIBLE); | |||||
if (!policy.algo.valid()) { | if (!policy.algo.valid()) { | ||||
if (retrive_from_cache) { | if (retrive_from_cache) { | ||||
policy.algo = | policy.algo = | ||||
get_profile_result_from_cache(require_reproducible).desc; | |||||
get_profile_result_from_cache(select_strategy).desc; | |||||
} else { | } else { | ||||
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | ||||
owner_graph(), m_cn, m_execution_policy.workspace_limit); | owner_graph(), m_cn, m_execution_policy.workspace_limit); | ||||
policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | ||||
args..., workspace_limit, | args..., workspace_limit, | ||||
require_reproducible), | |||||
reproducible), | |||||
m_layouts) | m_layouts) | ||||
.desc; | .desc; | ||||
} | } | ||||
@@ -647,7 +651,7 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | |||||
_item.param, m_base_mgb_opr, m_cn, m_execution_policy, | _item.param, m_base_mgb_opr, m_cn, m_execution_policy, | ||||
m_allow_weight_preprocess); | m_allow_weight_preprocess); | ||||
policy.sub_policy.push_back({}); | policy.sub_policy.push_back({}); | ||||
sub_ctx.construct_execution_policy(require_reproducible, | |||||
sub_ctx.construct_execution_policy(select_strategy, | |||||
policy.sub_policy.back(), | policy.sub_policy.back(), | ||||
retrive_from_cache); | retrive_from_cache); | ||||
}); | }); | ||||
@@ -718,8 +722,7 @@ AlgoChooser<Opr>::ExeContext::profile_single_algo( | |||||
return None; | return None; | ||||
return AlgoChooserProfileCache::ResultEntry{ | return AlgoChooserProfileCache::ResultEntry{ | ||||
palgo->name(), | palgo->name(), | ||||
palgo->contain_attribute( | |||||
megdnn::AlgoAttribute::REPRODUCIBLE), | |||||
static_cast<uint32_t>(palgo->attribute()), | |||||
rst.val().time, param.workspace}; | rst.val().time, param.workspace}; | ||||
} | } | ||||
@@ -768,10 +771,10 @@ AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const { | |||||
bool allow_weight_preprocess); \ | bool allow_weight_preprocess); \ | ||||
template typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \ | template typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \ | ||||
AlgoChooser<megdnn::Opr>::ExeContext::choose_by_heuristic( \ | AlgoChooser<megdnn::Opr>::ExeContext::choose_by_heuristic( \ | ||||
bool reproducible) const; \ | |||||
ExecutionStrategy select_strategy) const; \ | |||||
template typename AlgoChooser<megdnn::Opr>::ImplAlgo \ | template typename AlgoChooser<megdnn::Opr>::ImplAlgo \ | ||||
AlgoChooser<megdnn::Opr>::ExeContext::get_profile_result_from_cache( \ | AlgoChooser<megdnn::Opr>::ExeContext::get_profile_result_from_cache( \ | ||||
bool require_reproducible) const; \ | |||||
ExecutionStrategy select_strategy) const; \ | |||||
template std::vector<typename AlgoChooser<megdnn::Opr>::ImplAlgo> \ | template std::vector<typename AlgoChooser<megdnn::Opr>::ImplAlgo> \ | ||||
AlgoChooser<megdnn::Opr>::ExeContext::get_all_candidates() const; \ | AlgoChooser<megdnn::Opr>::ExeContext::get_all_candidates() const; \ | ||||
template size_t \ | template size_t \ | ||||
@@ -780,7 +783,7 @@ AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const { | |||||
policy) const; \ | policy) const; \ | ||||
template void \ | template void \ | ||||
AlgoChooser<megdnn::Opr>::ExeContext::construct_execution_policy( \ | AlgoChooser<megdnn::Opr>::ExeContext::construct_execution_policy( \ | ||||
bool require_reproducible, \ | |||||
ExecutionStrategy select_strategy, \ | |||||
typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& policy, \ | typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& policy, \ | ||||
bool retrive_from_cache) const; \ | bool retrive_from_cache) const; \ | ||||
template Maybe<AlgoChooserProfileCache::ResultEntry> \ | template Maybe<AlgoChooserProfileCache::ResultEntry> \ | ||||
@@ -35,6 +35,13 @@ MGB_FOREACH_FASTRUN_OPR(cb) | |||||
#undef cb | #undef cb | ||||
namespace mgb { | namespace mgb { | ||||
//! define logical operation of megdnn::param::ExecutionPolicy::Strategy::Enum | |||||
//! and megdnn::detail::AlgoAttribute enum | |||||
using ExecutionStrategy = megdnn::param::ExecutionPolicy::Strategy; | |||||
using AlgoAttribute = megdnn::AlgoAttribute; | |||||
namespace opr { | namespace opr { | ||||
/* =================== AlgoChooser =================== */ | /* =================== AlgoChooser =================== */ | ||||
@@ -103,7 +110,7 @@ public: | |||||
const FixedTensorLayouts& layouts() const { return m_layouts; } | const FixedTensorLayouts& layouts() const { return m_layouts; } | ||||
ImplExecutionPolicy choose_by_heuristic( | ImplExecutionPolicy choose_by_heuristic( | ||||
bool reproducible = false) const; | |||||
ExecutionStrategy select_strategy) const; | |||||
//! get all candidate algos, and the one choose_by_heuristic() is | //! get all candidate algos, and the one choose_by_heuristic() is | ||||
//! put first | //! put first | ||||
@@ -126,19 +133,20 @@ public: | |||||
const ImplExecutionPolicy& policy, double& timeout) const; | const ImplExecutionPolicy& policy, double& timeout) const; | ||||
//! get all profile algorithm from cache, return invalid if not exists | //! get all profile algorithm from cache, return invalid if not exists | ||||
ImplAlgo get_profile_result_from_cache(bool require_reproducible) const; | |||||
ImplAlgo get_profile_result_from_cache( | |||||
ExecutionStrategy select_strategy) const; | |||||
/** | /** | ||||
* \brief construct execution policy from cache or heuristic. | * \brief construct execution policy from cache or heuristic. | ||||
* | * | ||||
* \param require_reproducible select algo which is reproducible | |||||
* \param select_strategy select algo which matched this strategy | |||||
* \param policy execution policy | * \param policy execution policy | ||||
* \param retrive_from_cache retrive algo from cache if set True, get | * \param retrive_from_cache retrive algo from cache if set True, get | ||||
* from heuristic otherwise. | * from heuristic otherwise. | ||||
*/ | */ | ||||
void construct_execution_policy( | |||||
bool require_reproducible, ImplExecutionPolicy& policy, | |||||
bool retrive_from_cache = true) const; | |||||
void construct_execution_policy(ExecutionStrategy select_strategy, | |||||
ImplExecutionPolicy& policy, | |||||
bool retrive_from_cache = true) const; | |||||
private: | private: | ||||
Maybe<PreprocessFilter<Opr>> construct_fake_preprocess_filter() const; | Maybe<PreprocessFilter<Opr>> construct_fake_preprocess_filter() const; | ||||
@@ -153,11 +161,11 @@ private: | |||||
//! profile and save to cache | //! profile and save to cache | ||||
static void profile(ExeContext& ctx, bool require_reproducible); | |||||
static void profile(ExeContext& ctx, ExecutionStrategy select_strategy); | |||||
static ImplExecutionPolicy choose_by_profile(ExeContext& ctx, | |||||
bool require_reproducible, | |||||
bool enable_update = true); | |||||
static ImplExecutionPolicy choose_by_profile( | |||||
ExeContext& ctx, ExecutionStrategy select_strategy, | |||||
bool enable_update = true); | |||||
public: | public: | ||||
/*! | /*! | ||||
@@ -13,7 +13,6 @@ | |||||
#pragma once | #pragma once | ||||
#include "megbrain/graph/operator_node.h" | #include "megbrain/graph/operator_node.h" | ||||
#include "megbrain/opr/param_defs.h" | |||||
#include "megdnn/oprs/base.h" | #include "megdnn/oprs/base.h" | ||||
#include "megdnn/oprs/nn.h" | #include "megdnn/oprs/nn.h" | ||||
@@ -73,7 +72,6 @@ protected: | |||||
}; | }; | ||||
} // namespace mixin | } // namespace mixin | ||||
} // namespace opr | } // namespace opr | ||||
} // namespace mgb | } // namespace mgb | ||||
@@ -429,10 +429,11 @@ TEST(TestOprDNN, MatrixMulExePolicy) { | |||||
auto cn = CompNode::load("cpux"); | auto cn = CompNode::load("cpux"); | ||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
for (auto strategy : {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE, | |||||
S::PROFILE_HEURISTIC}) { | |||||
for (auto strategy : | |||||
SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE, | |||||
S::PROFILE | S::HEURISTIC}) { | |||||
#else | #else | ||||
for (auto strategy: {S:HEURISTIC, S::PROFILE_HEURISTIC}) { | |||||
for (auto strategy: {S:HEURISTIC, S::PROFILE | S::HEURISTIC}) { | |||||
#endif | #endif | ||||
auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
@@ -355,11 +355,13 @@ TEST(TestOprDNN, ConvBiasExePolicy) { | |||||
auto cn = CompNode::load("cpux"); | auto cn = CompNode::load("cpux"); | ||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
for (auto strategy: {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE, S::PROFILE_HEURISTIC}) { | |||||
for (auto strategy : | |||||
SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE, | |||||
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) { | |||||
#else | #else | ||||
for (auto strategy: {S:HEURISTIC, S::PROFILE_HEURISTIC}) { | |||||
for (auto strategy : | |||||
SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) { | |||||
#endif | #endif | ||||
auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
@@ -397,7 +399,8 @@ TEST(TestOprDNN, ConvBiasExePolicy_Quantized8Asym) { | |||||
auto cn = CompNode::load("cpux"); | auto cn = CompNode::load("cpux"); | ||||
for (auto strategy: {S::PROFILE, S::PROFILE_REPRODUCIBLE}) { | |||||
for (auto strategy : | |||||
SmallVector<S>{S::PROFILE, S::PROFILE | S::REPRODUCIBLE}) { | |||||
auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
@@ -439,10 +442,12 @@ TEST(TestOprDNN, ConvolutionExePolicy) { | |||||
PersistentCacheHook cache_hook{on_get}; | PersistentCacheHook cache_hook{on_get}; | ||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
for (auto strategy : {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE, | |||||
S::PROFILE_HEURISTIC}) { | |||||
for (auto strategy : | |||||
SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE, | |||||
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) { | |||||
#else | #else | ||||
for (auto strategy: {S:HEURISTIC, S::PROFILE_HEURISTIC}) { | |||||
for (auto strategy : | |||||
SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) { | |||||
#endif | #endif | ||||
using Checker = AutoOprChecker<2, 1>; | using Checker = AutoOprChecker<2, 1>; | ||||
@@ -522,10 +527,11 @@ TEST(TestOprDNN, ConvolutionBackwardDataBfloat16ExePolicy) { | |||||
PersistentCacheHook cache_hook{on_get}; | PersistentCacheHook cache_hook{on_get}; | ||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
for (auto strategy : {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE, | |||||
S::PROFILE_HEURISTIC}) { | |||||
for (auto strategy : | |||||
{S::PROFILE, S::HEURISTIC, S(S::PROFILE | S::REPRODUCIBLE), | |||||
S(S::PROFILE | S::HEURISTIC)}) { | |||||
#else | #else | ||||
for (auto strategy: {S:HEURISTIC, S::PROFILE_HEURISTIC}) { | |||||
for (auto strategy: {S:HEURISTIC, S(S::PROFILE | S::HEURISTIC)}) { | |||||
#endif | #endif | ||||
using Checker = AutoOprChecker<2, 1>; | using Checker = AutoOprChecker<2, 1>; | ||||
@@ -1183,9 +1189,12 @@ TEST(TestOprDNN, Convolution3DExePolicy) { | |||||
using S = Policy::Strategy; | using S = Policy::Strategy; | ||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
for (auto strategy: {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE, S::PROFILE_HEURISTIC}) { | |||||
for (auto strategy : | |||||
SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE, | |||||
S::PROFILE | S::HEURISTIC}) { | |||||
#else | #else | ||||
for (auto strategy: {S:HEURISTIC, S::PROFILE_HEURISTIC}) { | |||||
for (auto strategy : | |||||
SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) { | |||||
#endif | #endif | ||||
using Checker = AutoOprChecker<2, 1>; | using Checker = AutoOprChecker<2, 1>; | ||||
@@ -1660,10 +1669,12 @@ TEST(TestOprDNN, LocalShareForwardExecPolicy) { | |||||
PersistentCacheHook cache_hook{on_get}; | PersistentCacheHook cache_hook{on_get}; | ||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
for (auto strategy : {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE, | |||||
S::PROFILE_HEURISTIC}) { | |||||
for (auto strategy : | |||||
SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE, | |||||
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) { | |||||
#else | #else | ||||
for (auto strategy: {S:HEURISTIC, S::PROFILE_HEURISTIC}) { | |||||
for (auto strategy : | |||||
SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) { | |||||
#endif | #endif | ||||
auto make_graph = [&](const Checker::SymInpArray& inputs) | auto make_graph = [&](const Checker::SymInpArray& inputs) | ||||
-> Checker::SymOutArray { | -> Checker::SymOutArray { | ||||
@@ -1769,10 +1780,12 @@ TEST(TestOprDNN, DeformableConvForward) { | |||||
Param param; | Param param; | ||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
for (auto strategy : {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE, | |||||
S::PROFILE_HEURISTIC}) { | |||||
for (auto strategy : | |||||
SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE, | |||||
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) { | |||||
#else | #else | ||||
for (auto strategy : {S : HEURISTIC, S::PROFILE_HEURISTIC}) { | |||||
for (auto strategy : | |||||
SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) { | |||||
#endif | #endif | ||||
auto make_graph = [&](const Checker::SymInpArray& inputs) | auto make_graph = [&](const Checker::SymInpArray& inputs) | ||||
-> Checker::SymOutArray { | -> Checker::SymOutArray { | ||||
@@ -1936,10 +1949,12 @@ TEST(TestOprDNN, BatchConvBiasForward) { | |||||
param.sparse = Param::Sparse::DENSE; | param.sparse = Param::Sparse::DENSE; | ||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
for (auto strategy : {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE, | |||||
S::PROFILE_HEURISTIC}) { | |||||
for (auto strategy : | |||||
SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE, | |||||
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) { | |||||
#else | #else | ||||
for (auto strategy : {S : HEURISTIC, S::PROFILE_HEURISTIC}) { | |||||
for (auto strategy : | |||||
SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) { | |||||
#endif | #endif | ||||
auto make_quantized = [&](SymbolVar x, const DType& dtype) { | auto make_quantized = [&](SymbolVar x, const DType& dtype) { | ||||
@@ -2080,7 +2095,8 @@ TEST(TestOprDNN, HeuristicReproducible) { | |||||
constexpr size_t PH = 1, PW = 1, SH = 1, SW = 1; | constexpr size_t PH = 1, PW = 1, SH = 1, SW = 1; | ||||
for (auto strategy : {S::HEURISTIC, S::HEURISTIC_REPRODUCIBLE}) { | |||||
for (auto strategy : | |||||
SmallVector<S>{S::HEURISTIC, S::HEURISTIC | S::REPRODUCIBLE}) { | |||||
VarNode* bwd_flt; | VarNode* bwd_flt; | ||||
auto make_graph = [&](const Checker::SymInpArray& inputs) | auto make_graph = [&](const Checker::SymInpArray& inputs) | ||||
-> Checker::SymOutArray { | -> Checker::SymOutArray { | ||||
@@ -2126,7 +2142,7 @@ TEST(TestOprDNN, HeuristicReproducible) { | |||||
megdnn::Algorithm* palgo = | megdnn::Algorithm* palgo = | ||||
megdnn_opr->get_algorithm_from_desc(algo); | megdnn_opr->get_algorithm_from_desc(algo); | ||||
mgb_assert(palgo, "Unknown algo description"); | mgb_assert(palgo, "Unknown algo description"); | ||||
if (strategy == S::HEURISTIC_REPRODUCIBLE) { | |||||
if (strategy == S(S::HEURISTIC | S::REPRODUCIBLE)) { | |||||
EXPECT_TRUE(palgo->contain_attribute( | EXPECT_TRUE(palgo->contain_attribute( | ||||
megdnn::AlgoAttribute::REPRODUCIBLE)); | megdnn::AlgoAttribute::REPRODUCIBLE)); | ||||
} | } | ||||
@@ -43,6 +43,7 @@ namespace megdnn { | |||||
std::ostream &ostr, const DType &dt) { | std::ostream &ostr, const DType &dt) { | ||||
return ostr << dt.name(); | return ostr << dt.name(); | ||||
} | } | ||||
} // namespace megdnn | } // namespace megdnn | ||||
namespace mgb { | namespace mgb { | ||||
@@ -18,7 +18,7 @@ pdef('PersistentOutputStorage').add_fields( | |||||
add_const('int32', 'INVALID_AXIS', 'MAX_NDIM'). | add_const('int32', 'INVALID_AXIS', 'MAX_NDIM'). | ||||
add_fields('int32', 'axis', 'INVALID_AXIS')) | add_fields('int32', 'axis', 'INVALID_AXIS')) | ||||
(pdef('ExecutionPolicy', 'specify how to select an algorithm for an operator'). | |||||
(pdef('ExecutionPolicy', version=0, is_legacy=True). | |||||
add_enum('Strategy', | add_enum('Strategy', | ||||
Doc('HEURISTIC', 'use heuristic to choose the fastest algorithm'), | Doc('HEURISTIC', 'use heuristic to choose the fastest algorithm'), | ||||
Doc('HEURISTIC_REPRODUCIBLE', 'use heuristic to choose the fastest algorithm, ' | Doc('HEURISTIC_REPRODUCIBLE', 'use heuristic to choose the fastest algorithm, ' | ||||
@@ -33,6 +33,20 @@ pdef('PersistentOutputStorage').add_fields( | |||||
Doc('workspace_limit', 'workspace limit in bytes'), | Doc('workspace_limit', 'workspace limit in bytes'), | ||||
str(2**64-1)+'ull')) | str(2**64-1)+'ull')) | ||||
(pdef('ExecutionPolicy', 'specify how to select an algorithm for an operator', version=1). | |||||
add_bit_combination_enum('Strategy', | |||||
Doc('HEURISTIC', 'use heuristic to choose the fastest algorithm'), | |||||
Doc('PROFILE', | |||||
'run possible algorithms on real device to find the best'), | |||||
Doc('REPRODUCIBLE', | |||||
'when profile or heuristic algo selection it require the algos' | |||||
'must be reproducible'), | |||||
Doc('OPTMIZED', | |||||
'profile require algos are optmized to achieve fast-profile')). | |||||
add_fields('uint64', | |||||
Doc('workspace_limit', 'workspace limit in bytes'), | |||||
str(2**64-1)+'ull')) | |||||
(pdef('AssertEqual'). | (pdef('AssertEqual'). | ||||
add_fields('float32', | add_fields('float32', | ||||
Doc('maxerr', 'max allowed error; error is defined as the minimal ' | Doc('maxerr', 'max allowed error; error is defined as the minimal ' | ||||