GitOrigin-RevId: 92307dd2ca
tags/v1.3.0
@@ -73,7 +73,7 @@ PyTypeObject PyOpType(name); | |||
} \ | |||
} while (0) | |||
template<typename T, typename SFINAE=void> | |||
template <typename T, typename SFINAE = void> | |||
struct pyobj_convert_generic { | |||
static T from(PyObject* obj) { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
@@ -87,7 +87,12 @@ struct pyobj_convert_generic { | |||
} | |||
}; | |||
template<typename T> | |||
template <typename T> | |||
struct EnumTrait { | |||
static constexpr bool is_bit_combined = false; | |||
}; | |||
template <typename T> | |||
PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) { | |||
PyObject* obj = type->tp_alloc(type, 0); | |||
T* self = reinterpret_cast<T*>(obj); | |||
@@ -203,9 +208,10 @@ struct EnumWrapper { | |||
} | |||
}; | |||
template<typename T> | |||
template <typename T> | |||
struct pyobj_convert_generic<T, | |||
std::enable_if_t<std::is_enum_v<std::decay_t<T>>>> { | |||
std::enable_if_t<std::is_enum_v<std::decay_t<T>> && | |||
!EnumTrait<T>::is_bit_combined>> { | |||
using Wrapper = EnumWrapper<T>; | |||
static T from(PyObject* obj) { | |||
if (PyObject_TypeCheck(obj, &Wrapper::type)) { | |||
@@ -223,6 +229,115 @@ struct pyobj_convert_generic<T, | |||
} | |||
}; | |||
template<typename T> | |||
struct BitCombinedEnumWrapper { | |||
static_assert(std::is_enum_v<T>); | |||
PyObject_HEAD | |||
T value; | |||
static const char* name; | |||
static PyTypeObject type; | |||
static std::unordered_map<T, std::string> type2str; | |||
static std::unordered_map<std::string, T> str2type; | |||
static PyNumberMethods number_methods; | |||
BitCombinedEnumWrapper() = default; | |||
BitCombinedEnumWrapper(T v): value(v) {} | |||
BitCombinedEnumWrapper(std::string&& str) | |||
: BitCombinedEnumWrapper(str2type.at(normalize_enum(str))) {} | |||
std::string to_string() const { | |||
if (static_cast<uint32_t>(value) == 0) { | |||
return "None"; | |||
} else { | |||
auto ret = std::string(); | |||
bool first = true; | |||
for (uint32_t i = 0; i < 32; i++) { | |||
uint32_t value_int = static_cast<uint32_t>(value); | |||
auto it = type2str.find(static_cast<T>((1 << i) & value_int)); | |||
if (it != type2str.end()) { | |||
if (!first) { | |||
ret += " + "; | |||
} else { | |||
first = false; | |||
} | |||
ret += (std::string(name) + "." + it->second); | |||
} | |||
} | |||
return ret; | |||
} | |||
} | |||
static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject*, PyObject*) { | |||
PyObject* obj = type->tp_alloc(type, 0); | |||
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>(1); | |||
return obj; | |||
} | |||
static int py_init(PyObject* self, PyObject* args, PyObject*) { | |||
int input = 1; | |||
if (PyArg_ParseTuple(args, "|i", &input)){ | |||
reinterpret_cast<BitCombinedEnumWrapper*>(self)->value = | |||
static_cast<T>(input); | |||
} | |||
return 0; | |||
} | |||
static PyObject* py_repr(PyObject* self) { | |||
return pyobj_convert_generic<std::string>::to( | |||
reinterpret_cast<BitCombinedEnumWrapper*>(self)->to_string()); | |||
} | |||
static PyObject* py_or(PyObject* self, PyObject* other) { | |||
if(!(self->ob_type == other->ob_type)){ | |||
return PyErr_Format( | |||
PyExc_RuntimeError, | |||
"Operand in or operator must be the same type."); | |||
} | |||
PyObject* obj = type.tp_alloc(&type, 0); | |||
T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value, | |||
rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value; | |||
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>( | |||
static_cast<uint32_t>(lhs) | static_cast<uint32_t>(rhs)); | |||
return obj; | |||
} | |||
static PyObject* py_and(PyObject* self, PyObject* other) { | |||
if (!(self->ob_type == other->ob_type)) { | |||
return PyErr_Format( | |||
PyExc_RuntimeError, | |||
"Operand in and operator must be the same type."); | |||
} | |||
PyObject* obj = type.tp_alloc(&type, 0); | |||
T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value, | |||
rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value; | |||
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>( | |||
static_cast<uint32_t>(lhs) & static_cast<uint32_t>(rhs)); | |||
return obj; | |||
} | |||
static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) { | |||
T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value, | |||
rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value; | |||
if (op == Py_EQ || op == Py_NE) { | |||
RETURN_RICHCOMPARE(lhs, rhs, op); | |||
} | |||
Py_RETURN_NOTIMPLEMENTED; | |||
} | |||
}; | |||
template <typename T> | |||
struct pyobj_convert_generic<T, | |||
std::enable_if_t<std::is_enum_v<std::decay_t<T>> && | |||
EnumTrait<T>::is_bit_combined>> { | |||
using Wrapper = BitCombinedEnumWrapper<T>; | |||
static T from(PyObject* obj) { | |||
if (PyObject_TypeCheck(obj, &Wrapper::type)) { | |||
return reinterpret_cast<Wrapper*>(obj)->value; | |||
} | |||
// try as string | |||
// TODO: type checkcd | |||
return Wrapper(pyobj_convert_generic<std::string>::from(obj)).value; | |||
} | |||
static PyObject* to(T t) { | |||
PyTypeObject* pytype = &Wrapper::type; | |||
PyObject* obj = pytype->tp_alloc(pytype, 0); | |||
reinterpret_cast<Wrapper*>(obj)->value = t; | |||
return obj; | |||
} | |||
}; | |||
void _init_py_op_def(py::module m) { | |||
using py_op = PyOp(OpDef); | |||
auto& py_type = PyOpType(OpDef); | |||
@@ -408,61 +408,58 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& | |||
os << ";\n\n"; | |||
} | |||
static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { | |||
auto className = op.getCppClassName(); | |||
static std::string gen_op_def_python_c_extension_enum( | |||
raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr, | |||
llvm::StringRef className) { | |||
std::string body; | |||
// generate PyType for enum class member | |||
for (auto&& i : op.getMgbAttributes()) { | |||
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
unsigned int enumID; | |||
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||
auto&& aliasBase = alias->getAliasBase(); | |||
enumID = | |||
llvm::cast<MgbEnumAttr>(aliasBase) | |||
.getBaseRecord()->getID(); | |||
} else { | |||
enumID = attr->getBaseRecord()->getID(); | |||
} | |||
auto&& enumAlias = ctx.enumAlias; | |||
auto&& iter = enumAlias.find(enumID); | |||
auto enumName = attr->getEnumName(); | |||
body += "{\n"; | |||
body += formatv( | |||
"auto& e_type = EnumWrapper<{0}::{1}>::type;", className, enumName | |||
); | |||
if (iter == enumAlias.end()) { | |||
os << formatv( | |||
"template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n", | |||
className, enumName); | |||
os << formatv( | |||
"template<> const char* EnumWrapper<{0}::{1}>::name = \"{0}.{1}\";\n", | |||
className, enumName); | |||
std::vector<std::string> pairStr; | |||
for (auto&& i: attr->getEnumMembers()) { | |||
pairStr.push_back(formatv( | |||
"{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", | |||
className, enumName, i)); | |||
} | |||
os << formatv(R"( | |||
unsigned int enumID; | |||
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||
auto&& aliasBase = alias->getAliasBase(); | |||
enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID(); | |||
} else { | |||
enumID = attr->getBaseRecord()->getID(); | |||
} | |||
auto&& enumAlias = ctx.enumAlias; | |||
auto&& iter = enumAlias.find(enumID); | |||
auto enumName = attr->getEnumName(); | |||
body += "{\n"; | |||
body += formatv("auto& e_type = EnumWrapper<{0}::{1}>::type;", className, | |||
enumName); | |||
if (iter == enumAlias.end()) { | |||
os << formatv( | |||
"template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n", | |||
className, enumName); | |||
os << formatv( | |||
"template<> const char* EnumWrapper<{0}::{1}>::name = " | |||
"\"{0}.{1}\";\n", | |||
className, enumName); | |||
std::vector<std::string> pairStr; | |||
for (auto&& i : attr->getEnumMembers()) { | |||
pairStr.push_back( | |||
formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", | |||
className, enumName, i)); | |||
} | |||
os << formatv(R"( | |||
template<> std::unordered_map<std::string, {0}::{1}> | |||
EnumWrapper<{0}::{1}>::str2type = {{ | |||
{2} | |||
}; | |||
)", className, enumName, llvm::join(pairStr, ", ")); | |||
pairStr.clear(); | |||
for (auto&& i: attr->getEnumMembers()) { | |||
pairStr.push_back(formatv( | |||
"{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", | |||
className, enumName, i)); | |||
} | |||
os << formatv(R"( | |||
)", | |||
className, enumName, llvm::join(pairStr, ", ")); | |||
pairStr.clear(); | |||
for (auto&& i : attr->getEnumMembers()) { | |||
pairStr.push_back( | |||
formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", | |||
className, enumName, i)); | |||
} | |||
os << formatv(R"( | |||
template<> std::unordered_map<{0}::{1}, std::string> | |||
EnumWrapper<{0}::{1}>::type2str = {{ | |||
{2} | |||
}; | |||
)", className, enumName, llvm::join(pairStr, ", ")); | |||
body += formatv(R"( | |||
)", | |||
className, enumName, llvm::join(pairStr, ", ")); | |||
body += formatv(R"( | |||
e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | |||
e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; | |||
e_type.tp_basicsize = sizeof(EnumWrapper<{0}::{1}>); | |||
@@ -472,22 +469,140 @@ EnumWrapper<{0}::{1}>::type2str = {{ | |||
e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr; | |||
e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare; | |||
mgb_assert(PyType_Ready(&e_type) >= 0); | |||
)", className, enumName); | |||
for (auto&& i: attr->getEnumMembers()) { | |||
body += formatv(R"({{ | |||
)", | |||
className, enumName); | |||
for (auto&& i : attr->getEnumMembers()) { | |||
body += formatv(R"({{ | |||
PyObject* inst = e_type.tp_alloc(&e_type, 0); | |||
reinterpret_cast<EnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2}; | |||
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); | |||
})", className, enumName, i); | |||
} | |||
enumAlias.emplace(enumID, std::make_pair(className, enumName)); | |||
} | |||
body += formatv(R"( | |||
})", | |||
className, enumName, i); | |||
} | |||
enumAlias.emplace(enumID, std::make_pair(className, enumName)); | |||
} | |||
body += formatv(R"( | |||
PyType_Modified(&e_type); | |||
mgb_assert(PyDict_SetItemString( | |||
py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0); | |||
)", | |||
enumName); | |||
body += "}\n"; | |||
return body; | |||
} | |||
static std::string gen_op_def_python_c_extension_bit_combined_enum( | |||
raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr, | |||
llvm::StringRef className) { | |||
std::string body; | |||
unsigned int enumID; | |||
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||
auto&& aliasBase = alias->getAliasBase(); | |||
enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID(); | |||
} else { | |||
enumID = attr->getBaseRecord()->getID(); | |||
} | |||
auto&& enumAlias = ctx.enumAlias; | |||
auto&& iter = enumAlias.find(enumID); | |||
auto enumName = attr->getEnumName(); | |||
body += "{\n"; | |||
body += formatv("auto& e_type = BitCombinedEnumWrapper<{0}::{1}>::type;", | |||
className, enumName); | |||
if (iter == enumAlias.end()) { | |||
os << formatv( | |||
"template<> PyTypeObject " | |||
"BitCombinedEnumWrapper<{0}::{1}>::type={{};\n", | |||
className, enumName); | |||
os << formatv( | |||
"template<> PyNumberMethods " | |||
"BitCombinedEnumWrapper<{0}::{1}>::number_methods={{};\n", | |||
className, enumName); | |||
os << formatv( | |||
"template<> const char* BitCombinedEnumWrapper<{0}::{1}>::name " | |||
"= \"{0}.{1}\";\n", | |||
className, enumName); | |||
os << formatv( | |||
"template<> struct EnumTrait<{0}::{1}> {{ static constexpr " | |||
"bool is_bit_combined = true;};\n", | |||
className, enumName); | |||
std::vector<std::string> pairStr; | |||
for (auto&& i : attr->getEnumMembers()) { | |||
pairStr.push_back( | |||
formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", | |||
className, enumName, i)); | |||
} | |||
os << formatv(R"( | |||
template<> std::unordered_map<std::string, {0}::{1}> | |||
BitCombinedEnumWrapper<{0}::{1}>::str2type = {{ | |||
{2} | |||
}; | |||
)", | |||
className, enumName, llvm::join(pairStr, ", ")); | |||
pairStr.clear(); | |||
for (auto&& i : attr->getEnumMembers()) { | |||
pairStr.push_back( | |||
formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", | |||
className, enumName, i)); | |||
} | |||
os << formatv(R"( | |||
template<> std::unordered_map<{0}::{1}, std::string> | |||
BitCombinedEnumWrapper<{0}::{1}>::type2str = {{ | |||
{2} | |||
}; | |||
)", | |||
className, enumName, llvm::join(pairStr, ", ")); | |||
body += formatv(R"( | |||
e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | |||
e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; | |||
e_type.tp_basicsize = sizeof(BitCombinedEnumWrapper<{0}::{1}>); | |||
e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
e_type.tp_doc = "{0}.{1}"; | |||
e_type.tp_base = &PyBaseObject_Type; | |||
e_type.tp_new = BitCombinedEnumWrapper<{0}::{1}>::py_new_combined_enum; | |||
e_type.tp_init = BitCombinedEnumWrapper<{0}::{1}>::py_init; | |||
e_type.tp_repr = BitCombinedEnumWrapper<{0}::{1}>::py_repr; | |||
e_type.tp_richcompare = BitCombinedEnumWrapper<{0}::{1}>::tp_richcompare; | |||
auto& number_method = BitCombinedEnumWrapper<{0}::{1}>::number_methods; | |||
number_method.nb_or = BitCombinedEnumWrapper<{0}::{1}>::py_or; | |||
number_method.nb_and = BitCombinedEnumWrapper<{0}::{1}>::py_and; | |||
e_type.tp_as_number = &number_method; | |||
mgb_assert(PyType_Ready(&e_type) >= 0); | |||
)", | |||
className, enumName); | |||
for (auto&& i : attr->getEnumMembers()) { | |||
body += formatv(R"({{ | |||
PyObject* inst = e_type.tp_alloc(&e_type, 0); | |||
reinterpret_cast<BitCombinedEnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2}; | |||
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); | |||
})", | |||
className, enumName, i); | |||
} | |||
enumAlias.emplace(enumID, std::make_pair(className, enumName)); | |||
} | |||
body += formatv(R"( | |||
PyType_Modified(&e_type); | |||
mgb_assert(PyDict_SetItemString( | |||
py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0); | |||
)", enumName); | |||
body += "}\n"; | |||
)", | |||
enumName); | |||
body += "}\n"; | |||
return body; | |||
} | |||
static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { | |||
auto className = op.getCppClassName(); | |||
std::string body; | |||
// generate PyType for enum class member | |||
for (auto&& i : op.getMgbAttributes()) { | |||
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
if (attr->getEnumCombinedFlag()) { | |||
body += gen_op_def_python_c_extension_bit_combined_enum( | |||
os, ctx, attr, className); | |||
} else { | |||
body += gen_op_def_python_c_extension_enum(os, ctx, attr, | |||
className); | |||
} | |||
} | |||
} | |||
@@ -141,15 +141,13 @@ R"__usage__( | |||
)__usage__" | |||
#if MGB_ENABLE_FASTRUN | |||
R"__usage__( | |||
--fast-run | |||
This param will be deperated later, please replace with param --full-profile. | |||
--full-profile | |||
Enable full-profile mode. Operators with multiple algorithms would be profiled | |||
--full-run | |||
Enable full-run mode. Operators with multiple algorithms would be profiled | |||
on the real device with actual input shapes, all algorithms will be profiled | |||
include naive algorithms. | |||
See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details. | |||
--fast-profile | |||
Enable fast-profile mode. Operators with multiple algorithms would be profiled | |||
--fast-run | |||
Enable fast-run mode. Operators with multiple algorithms would be profiled | |||
on the real device with actual input shapes, this mode will only profile the | |||
well optimized algorithms to get the profile result fast. | |||
See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details. | |||
@@ -519,8 +517,8 @@ struct Args { | |||
bool disable_assert_throw = false; | |||
bool share_param_mem = false; | |||
#if MGB_ENABLE_FASTRUN | |||
bool use_full_profile = false; | |||
bool use_fast_profile = false; | |||
bool use_full_run = false; | |||
bool use_fast_run = false; | |||
#endif | |||
bool reproducible = false; | |||
std::string fast_run_cache_path; | |||
@@ -704,13 +702,13 @@ void run_test_st(Args &env) { | |||
using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; | |||
S strategy = S::HEURISTIC; | |||
#if MGB_ENABLE_FASTRUN | |||
if (env.use_full_profile) { | |||
if (env.use_full_run) { | |||
if (env.reproducible) { | |||
strategy = S::PROFILE | S::REPRODUCIBLE; | |||
} else { | |||
strategy = S::PROFILE; | |||
} | |||
} else if (env.use_fast_profile) { | |||
} else if (env.use_fast_run) { | |||
strategy = S::PROFILE | S::OPTMIZED; | |||
} else if (env.reproducible) { | |||
strategy = S::HEURISTIC | S::REPRODUCIBLE; | |||
@@ -740,12 +738,12 @@ void run_test_st(Args &env) { | |||
std::make_shared<InFilePersistentCache>(buf.get(), flen)); | |||
#if MGB_ENABLE_FASTRUN | |||
} else { | |||
mgb_assert(env.use_full_profile || env.use_fast_profile, | |||
"fast-run or fast-profile should be enabled"); | |||
mgb_assert(env.use_full_run || env.use_fast_run, | |||
"fast-run or fast-run should be enabled"); | |||
PersistentCache::set_impl( | |||
std::make_shared<InFilePersistentCache>()); | |||
} | |||
if (!env.use_full_profile && !env.use_fast_profile) | |||
if (!env.use_full_run && !env.use_fast_run) | |||
#endif | |||
mgb::gopt::enable_opr_use_profiling_cache_inplace(vars); | |||
} | |||
@@ -1326,18 +1324,11 @@ Args Args::from_argv(int argc, char **argv) { | |||
} | |||
#if MGB_ENABLE_FASTRUN | |||
if (!strcmp(argv[i], "--fast-run")) { | |||
mgb_log_warn( | |||
"--fast-run param will be deperated later, please replace " | |||
"with --full-profile or --fast-profile."); | |||
ret.use_full_profile = true; | |||
ret.use_fast_run = true; | |||
continue; | |||
} | |||
if (!strcmp(argv[i], "--full-profile")) { | |||
ret.use_full_profile = true; | |||
continue; | |||
} | |||
if (!strcmp(argv[i], "--fast-profile")) { | |||
ret.use_fast_profile = true; | |||
if (!strcmp(argv[i], "--full-run")) { | |||
ret.use_full_run = true; | |||
continue; | |||
} | |||
#endif | |||
@@ -12,7 +12,6 @@ | |||
#pragma once | |||
#include "megbrain_build_config.h" | |||
#include "megbrain/opr/param_defs.h" | |||
#include "megdnn/basic_types.h" | |||
#include <memory> | |||
@@ -250,10 +249,4 @@ inline constexpr std::size_t operator"" _z(unsigned long long n) { | |||
} // namespace mgb | |||
namespace megdnn { | |||
namespace param { | |||
MGB_DEF_ENUM_CLASS_BIT_OPR(ExecutionPolicy::Strategy) | |||
} | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -18,6 +18,7 @@ | |||
#include "megbrain/utils/hashable.h" | |||
#include "megbrain/utils/thin/hash_table.h" | |||
#include "megbrain/utils/small_vector.h" | |||
#include "megbrain/opr/param_defs.h" | |||
#include <type_traits> | |||
@@ -1043,4 +1044,10 @@ MGB_DEFINE_CLS_WITH_SUPER(_name final, _base ,##__VA_ARGS__) \ | |||
} // namespace cg | |||
} // namespace mgb | |||
namespace megdnn { | |||
namespace param { | |||
MGB_DEF_ENUM_CLASS_BIT_OPR(ExecutionPolicy::Strategy) | |||
} | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -278,6 +278,19 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space( | |||
return ret; | |||
} | |||
//! Test whether the algo attribute of a algo match the require | |||
//! algo_strategy | |||
static bool algo_attribute_match_strategy(AlgoAttribute attribute, | |||
ExecutionStrategy selected_strategy) { | |||
bool ret = true; | |||
if (selected_strategy & ExecutionStrategy::OPTMIZED) { | |||
ret &= (!static_cast<bool>(AlgoAttribute::NAIVE & attribute)); | |||
} else if (selected_strategy & ExecutionStrategy::REPRODUCIBLE) { | |||
ret &= static_cast<bool>(AlgoAttribute::REPRODUCIBLE & attribute); | |||
} | |||
return ret; | |||
} | |||
} // namespace | |||
namespace mgb { | |||
@@ -285,8 +298,8 @@ namespace opr { | |||
template <typename Opr> | |||
void AlgoChooser<Opr>::profile(ExeContext& ctx, | |||
ExecutionStrategy select_strategy) { | |||
if (ctx.get_profile_result_from_cache(select_strategy).valid()) | |||
ExecutionStrategy selected_strategy) { | |||
if (ctx.get_profile_result_from_cache(selected_strategy).valid()) | |||
return; | |||
AlgoChooserProfileCache::Result prof_rst; | |||
@@ -306,9 +319,19 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, | |||
algo.name.c_str(), str_on_inp_shape.c_str()); | |||
ImplExecutionPolicy policy; | |||
policy.algo = algo.desc; | |||
ctx.construct_execution_policy(select_strategy, policy); | |||
if (ctx.get_workspace_size_bytes(policy) >= workspace_limit) | |||
ctx.construct_execution_policy(selected_strategy, policy); | |||
if (ctx.get_workspace_size_bytes(policy) >= workspace_limit) { | |||
continue; | |||
} | |||
auto algo_attribute = ctx.megdnn_opr() | |||
->get_algorithm_from_desc(policy.algo) | |||
->attribute(); | |||
if (!algo_attribute_match_strategy(algo_attribute, selected_strategy)) { | |||
mgb_log_debug( | |||
"skip algo %s, which is not match the profile strategy.", | |||
algo.name.c_str()); | |||
continue; | |||
} | |||
timer.reset(); | |||
MGB_TRY { cur_rst = ctx.profile_single_algo(policy, cur_timeout); } | |||
@@ -356,7 +379,7 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, | |||
template <typename Opr> | |||
typename AlgoChooser<Opr>::ImplExecutionPolicy | |||
AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, | |||
ExecutionStrategy select_strategy, | |||
ExecutionStrategy selected_strategy, | |||
bool enable_update) { | |||
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile"))) | |||
if (ctx.owner_graph()->options().no_profiling_on_shape_change) { | |||
@@ -378,11 +401,11 @@ AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, | |||
to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(), | |||
_item.param, ctx.mgb_opr(), ctx.comp_node(), | |||
ctx.execution_policy(), ctx.allow_weight_preprocess()); | |||
AlgoChooser<_Opr>::profile(sub_ctx, select_strategy); | |||
AlgoChooser<_Opr>::profile(sub_ctx, selected_strategy); | |||
}); | |||
} | |||
typename AlgoChooser<Opr>::ImplExecutionPolicy policy; | |||
ctx.construct_execution_policy(select_strategy, policy); | |||
ctx.construct_execution_policy(selected_strategy, policy); | |||
return policy; | |||
MIDOUT_E | |||
} | |||
@@ -440,7 +463,8 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy( | |||
if (!policy.algo.valid()) | |||
policy = ctx.choose_by_heuristic(opr_strategy); | |||
return policy; | |||
} else if ((opr_strategy & ExecutionStrategy::HEURISTIC)) { | |||
} else if (!static_cast<int>(opr_strategy) || | |||
(opr_strategy & ExecutionStrategy::HEURISTIC)) { | |||
return ctx.choose_by_heuristic(opr_strategy); | |||
} | |||
#if MGB_ENABLE_FASTRUN | |||
@@ -449,7 +473,7 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy( | |||
} | |||
#endif | |||
else { | |||
mgb_throw(GraphError, "bad convolution ExecutionPolicy strategy"); | |||
mgb_throw(GraphError, "bad ExecutionPolicy strategy"); | |||
} | |||
} | |||
@@ -495,7 +519,7 @@ AlgoChooser<Opr>::ExeContext::ExeContext( | |||
template <typename Opr> | |||
typename AlgoChooser<Opr>::ImplAlgo | |||
AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | |||
ExecutionStrategy select_strategy) const { | |||
ExecutionStrategy selected_strategy) const { | |||
MIDOUT_B(Opr, | |||
midout_iv(MGB_HASH_STR( | |||
"AlgoChooser::ExeContext::get_profile_result_from_cache"))) | |||
@@ -519,7 +543,7 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | |||
if (prof.empty()) | |||
return {}; | |||
for (auto&& i : prof) { | |||
if (!(select_strategy & ExecutionStrategy::REPRODUCIBLE) || | |||
if (!(selected_strategy & ExecutionStrategy::REPRODUCIBLE) || | |||
static_cast<AlgoAttribute>(i.attribute) & | |||
AlgoAttribute::REPRODUCIBLE) { | |||
auto iter = algo_map.find(i.algo); | |||
@@ -550,7 +574,7 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | |||
template <typename Opr> | |||
typename AlgoChooser<Opr>::ImplExecutionPolicy | |||
AlgoChooser<Opr>::ExeContext::choose_by_heuristic( | |||
ExecutionStrategy select_strategy) const { | |||
ExecutionStrategy selected_strategy) const { | |||
if (m_execution_policy.workspace_limit != | |||
std::numeric_limits<decltype( | |||
m_execution_policy.workspace_limit)>::max()) { | |||
@@ -558,7 +582,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic( | |||
"workspace_limit should not be setted if choose algo by " | |||
"heuristic"); | |||
} | |||
bool reproducible = static_cast<bool>(select_strategy & | |||
bool reproducible = static_cast<bool>(selected_strategy & | |||
ExecutionStrategy::REPRODUCIBLE); | |||
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | |||
owner_graph(), m_cn, m_execution_policy.workspace_limit); | |||
@@ -582,7 +606,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic( | |||
_item.param, m_base_mgb_opr, m_cn, m_execution_policy, | |||
m_allow_weight_preprocess); | |||
policy.sub_policy.push_back( | |||
sub_ctx.choose_by_heuristic(select_strategy)); | |||
sub_ctx.choose_by_heuristic(selected_strategy)); | |||
}); | |||
return policy; | |||
@@ -613,15 +637,15 @@ AlgoChooser<Opr>::ExeContext::get_all_candidates() const { | |||
template <typename Opr> | |||
void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | |||
ExecutionStrategy select_strategy, | |||
ExecutionStrategy selected_strategy, | |||
typename AlgoChooser<Opr>::ImplExecutionPolicy& policy, | |||
bool retrive_from_cache) const { | |||
bool reproducible = static_cast<bool>(select_strategy & | |||
bool reproducible = static_cast<bool>(selected_strategy & | |||
ExecutionStrategy::REPRODUCIBLE); | |||
if (!policy.algo.valid()) { | |||
if (retrive_from_cache) { | |||
policy.algo = | |||
get_profile_result_from_cache(select_strategy).desc; | |||
get_profile_result_from_cache(selected_strategy).desc; | |||
} else { | |||
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | |||
owner_graph(), m_cn, m_execution_policy.workspace_limit); | |||
@@ -651,7 +675,7 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | |||
_item.param, m_base_mgb_opr, m_cn, m_execution_policy, | |||
m_allow_weight_preprocess); | |||
policy.sub_policy.push_back({}); | |||
sub_ctx.construct_execution_policy(select_strategy, | |||
sub_ctx.construct_execution_policy(selected_strategy, | |||
policy.sub_policy.back(), | |||
retrive_from_cache); | |||
}); | |||
@@ -110,7 +110,7 @@ public: | |||
const FixedTensorLayouts& layouts() const { return m_layouts; } | |||
ImplExecutionPolicy choose_by_heuristic( | |||
ExecutionStrategy select_strategy) const; | |||
ExecutionStrategy selected_strategy) const; | |||
//! get all candidate algos, and the one choose_by_heuristic() is | |||
//! put first | |||
@@ -134,17 +134,17 @@ public: | |||
//! get all profile algorithm from cache, return invalid if not exists | |||
ImplAlgo get_profile_result_from_cache( | |||
ExecutionStrategy select_strategy) const; | |||
ExecutionStrategy selected_strategy) const; | |||
/** | |||
* \brief construct execution policy from cache or heuristic. | |||
* | |||
* \param select_strategy select algo which matched this strategy | |||
* \param selected_strategy select algo which matched this strategy | |||
* \param policy execution policy | |||
* \param retrive_from_cache retrive algo from cache if set True, get | |||
* from heuristic otherwise. | |||
*/ | |||
void construct_execution_policy(ExecutionStrategy select_strategy, | |||
void construct_execution_policy(ExecutionStrategy selected_strategy, | |||
ImplExecutionPolicy& policy, | |||
bool retrive_from_cache = true) const; | |||
@@ -161,10 +161,10 @@ private: | |||
//! profile and save to cache | |||
static void profile(ExeContext& ctx, ExecutionStrategy select_strategy); | |||
static void profile(ExeContext& ctx, ExecutionStrategy selected_strategy); | |||
static ImplExecutionPolicy choose_by_profile( | |||
ExeContext& ctx, ExecutionStrategy select_strategy, | |||
ExeContext& ctx, ExecutionStrategy selected_strategy, | |||
bool enable_update = true); | |||
public: | |||