GitOrigin-RevId: 92307dd2ca
tags/v1.3.0
@@ -73,7 +73,7 @@ PyTypeObject PyOpType(name); | |||||
} \ | } \ | ||||
} while (0) | } while (0) | ||||
template<typename T, typename SFINAE=void> | |||||
template <typename T, typename SFINAE = void> | |||||
struct pyobj_convert_generic { | struct pyobj_convert_generic { | ||||
static T from(PyObject* obj) { | static T from(PyObject* obj) { | ||||
// TODO: remove this guard which is used for pybind11 implicit conversion | // TODO: remove this guard which is used for pybind11 implicit conversion | ||||
@@ -87,7 +87,12 @@ struct pyobj_convert_generic { | |||||
} | } | ||||
}; | }; | ||||
template<typename T> | |||||
//! Compile-time traits of an enum type exposed to Python.
//!
//! The primary template marks an enum as a plain (non bit-combined) enum.
//! A specialization that sets is_bit_combined to true selects the
//! bit-combined pyobj_convert_generic specialization instead of the plain
//! enum one.
template <typename T>
struct EnumTrait {
    static constexpr bool is_bit_combined{false};
};
template <typename T> | |||||
PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) { | PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) { | ||||
PyObject* obj = type->tp_alloc(type, 0); | PyObject* obj = type->tp_alloc(type, 0); | ||||
T* self = reinterpret_cast<T*>(obj); | T* self = reinterpret_cast<T*>(obj); | ||||
@@ -203,9 +208,10 @@ struct EnumWrapper { | |||||
} | } | ||||
}; | }; | ||||
template<typename T> | |||||
template <typename T> | |||||
struct pyobj_convert_generic<T, | struct pyobj_convert_generic<T, | ||||
std::enable_if_t<std::is_enum_v<std::decay_t<T>>>> { | |||||
std::enable_if_t<std::is_enum_v<std::decay_t<T>> && | |||||
!EnumTrait<T>::is_bit_combined>> { | |||||
using Wrapper = EnumWrapper<T>; | using Wrapper = EnumWrapper<T>; | ||||
static T from(PyObject* obj) { | static T from(PyObject* obj) { | ||||
if (PyObject_TypeCheck(obj, &Wrapper::type)) { | if (PyObject_TypeCheck(obj, &Wrapper::type)) { | ||||
@@ -223,6 +229,115 @@ struct pyobj_convert_generic<T, | |||||
} | } | ||||
}; | }; | ||||
template<typename T> | |||||
struct BitCombinedEnumWrapper { | |||||
static_assert(std::is_enum_v<T>); | |||||
PyObject_HEAD | |||||
T value; | |||||
static const char* name; | |||||
static PyTypeObject type; | |||||
static std::unordered_map<T, std::string> type2str; | |||||
static std::unordered_map<std::string, T> str2type; | |||||
static PyNumberMethods number_methods; | |||||
BitCombinedEnumWrapper() = default; | |||||
BitCombinedEnumWrapper(T v): value(v) {} | |||||
BitCombinedEnumWrapper(std::string&& str) | |||||
: BitCombinedEnumWrapper(str2type.at(normalize_enum(str))) {} | |||||
std::string to_string() const { | |||||
if (static_cast<uint32_t>(value) == 0) { | |||||
return "None"; | |||||
} else { | |||||
auto ret = std::string(); | |||||
bool first = true; | |||||
for (uint32_t i = 0; i < 32; i++) { | |||||
uint32_t value_int = static_cast<uint32_t>(value); | |||||
auto it = type2str.find(static_cast<T>((1 << i) & value_int)); | |||||
if (it != type2str.end()) { | |||||
if (!first) { | |||||
ret += " + "; | |||||
} else { | |||||
first = false; | |||||
} | |||||
ret += (std::string(name) + "." + it->second); | |||||
} | |||||
} | |||||
return ret; | |||||
} | |||||
} | |||||
static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject*, PyObject*) { | |||||
PyObject* obj = type->tp_alloc(type, 0); | |||||
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>(1); | |||||
return obj; | |||||
} | |||||
static int py_init(PyObject* self, PyObject* args, PyObject*) { | |||||
int input = 1; | |||||
if (PyArg_ParseTuple(args, "|i", &input)){ | |||||
reinterpret_cast<BitCombinedEnumWrapper*>(self)->value = | |||||
static_cast<T>(input); | |||||
} | |||||
return 0; | |||||
} | |||||
static PyObject* py_repr(PyObject* self) { | |||||
return pyobj_convert_generic<std::string>::to( | |||||
reinterpret_cast<BitCombinedEnumWrapper*>(self)->to_string()); | |||||
} | |||||
static PyObject* py_or(PyObject* self, PyObject* other) { | |||||
if(!(self->ob_type == other->ob_type)){ | |||||
return PyErr_Format( | |||||
PyExc_RuntimeError, | |||||
"Operand in or operator must be the same type."); | |||||
} | |||||
PyObject* obj = type.tp_alloc(&type, 0); | |||||
T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value, | |||||
rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value; | |||||
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>( | |||||
static_cast<uint32_t>(lhs) | static_cast<uint32_t>(rhs)); | |||||
return obj; | |||||
} | |||||
static PyObject* py_and(PyObject* self, PyObject* other) { | |||||
if (!(self->ob_type == other->ob_type)) { | |||||
return PyErr_Format( | |||||
PyExc_RuntimeError, | |||||
"Operand in and operator must be the same type."); | |||||
} | |||||
PyObject* obj = type.tp_alloc(&type, 0); | |||||
T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value, | |||||
rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value; | |||||
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>( | |||||
static_cast<uint32_t>(lhs) & static_cast<uint32_t>(rhs)); | |||||
return obj; | |||||
} | |||||
static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) { | |||||
T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value, | |||||
rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value; | |||||
if (op == Py_EQ || op == Py_NE) { | |||||
RETURN_RICHCOMPARE(lhs, rhs, op); | |||||
} | |||||
Py_RETURN_NOTIMPLEMENTED; | |||||
} | |||||
}; | |||||
template <typename T> | |||||
struct pyobj_convert_generic<T, | |||||
std::enable_if_t<std::is_enum_v<std::decay_t<T>> && | |||||
EnumTrait<T>::is_bit_combined>> { | |||||
using Wrapper = BitCombinedEnumWrapper<T>; | |||||
static T from(PyObject* obj) { | |||||
if (PyObject_TypeCheck(obj, &Wrapper::type)) { | |||||
return reinterpret_cast<Wrapper*>(obj)->value; | |||||
} | |||||
// try as string | |||||
// TODO: type checkcd | |||||
return Wrapper(pyobj_convert_generic<std::string>::from(obj)).value; | |||||
} | |||||
static PyObject* to(T t) { | |||||
PyTypeObject* pytype = &Wrapper::type; | |||||
PyObject* obj = pytype->tp_alloc(pytype, 0); | |||||
reinterpret_cast<Wrapper*>(obj)->value = t; | |||||
return obj; | |||||
} | |||||
}; | |||||
void _init_py_op_def(py::module m) { | void _init_py_op_def(py::module m) { | ||||
using py_op = PyOp(OpDef); | using py_op = PyOp(OpDef); | ||||
auto& py_type = PyOpType(OpDef); | auto& py_type = PyOpType(OpDef); | ||||
@@ -408,61 +408,58 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& | |||||
os << ";\n\n"; | os << ";\n\n"; | ||||
} | } | ||||
static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { | |||||
auto className = op.getCppClassName(); | |||||
static std::string gen_op_def_python_c_extension_enum( | |||||
raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr, | |||||
llvm::StringRef className) { | |||||
std::string body; | std::string body; | ||||
// generate PyType for enum class member | |||||
for (auto&& i : op.getMgbAttributes()) { | |||||
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||||
unsigned int enumID; | |||||
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||||
auto&& aliasBase = alias->getAliasBase(); | |||||
enumID = | |||||
llvm::cast<MgbEnumAttr>(aliasBase) | |||||
.getBaseRecord()->getID(); | |||||
} else { | |||||
enumID = attr->getBaseRecord()->getID(); | |||||
} | |||||
auto&& enumAlias = ctx.enumAlias; | |||||
auto&& iter = enumAlias.find(enumID); | |||||
auto enumName = attr->getEnumName(); | |||||
body += "{\n"; | |||||
body += formatv( | |||||
"auto& e_type = EnumWrapper<{0}::{1}>::type;", className, enumName | |||||
); | |||||
if (iter == enumAlias.end()) { | |||||
os << formatv( | |||||
"template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n", | |||||
className, enumName); | |||||
os << formatv( | |||||
"template<> const char* EnumWrapper<{0}::{1}>::name = \"{0}.{1}\";\n", | |||||
className, enumName); | |||||
std::vector<std::string> pairStr; | |||||
for (auto&& i: attr->getEnumMembers()) { | |||||
pairStr.push_back(formatv( | |||||
"{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", | |||||
className, enumName, i)); | |||||
} | |||||
os << formatv(R"( | |||||
unsigned int enumID; | |||||
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||||
auto&& aliasBase = alias->getAliasBase(); | |||||
enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID(); | |||||
} else { | |||||
enumID = attr->getBaseRecord()->getID(); | |||||
} | |||||
auto&& enumAlias = ctx.enumAlias; | |||||
auto&& iter = enumAlias.find(enumID); | |||||
auto enumName = attr->getEnumName(); | |||||
body += "{\n"; | |||||
body += formatv("auto& e_type = EnumWrapper<{0}::{1}>::type;", className, | |||||
enumName); | |||||
if (iter == enumAlias.end()) { | |||||
os << formatv( | |||||
"template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n", | |||||
className, enumName); | |||||
os << formatv( | |||||
"template<> const char* EnumWrapper<{0}::{1}>::name = " | |||||
"\"{0}.{1}\";\n", | |||||
className, enumName); | |||||
std::vector<std::string> pairStr; | |||||
for (auto&& i : attr->getEnumMembers()) { | |||||
pairStr.push_back( | |||||
formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", | |||||
className, enumName, i)); | |||||
} | |||||
os << formatv(R"( | |||||
template<> std::unordered_map<std::string, {0}::{1}> | template<> std::unordered_map<std::string, {0}::{1}> | ||||
EnumWrapper<{0}::{1}>::str2type = {{ | EnumWrapper<{0}::{1}>::str2type = {{ | ||||
{2} | {2} | ||||
}; | }; | ||||
)", className, enumName, llvm::join(pairStr, ", ")); | |||||
pairStr.clear(); | |||||
for (auto&& i: attr->getEnumMembers()) { | |||||
pairStr.push_back(formatv( | |||||
"{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", | |||||
className, enumName, i)); | |||||
} | |||||
os << formatv(R"( | |||||
)", | |||||
className, enumName, llvm::join(pairStr, ", ")); | |||||
pairStr.clear(); | |||||
for (auto&& i : attr->getEnumMembers()) { | |||||
pairStr.push_back( | |||||
formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", | |||||
className, enumName, i)); | |||||
} | |||||
os << formatv(R"( | |||||
template<> std::unordered_map<{0}::{1}, std::string> | template<> std::unordered_map<{0}::{1}, std::string> | ||||
EnumWrapper<{0}::{1}>::type2str = {{ | EnumWrapper<{0}::{1}>::type2str = {{ | ||||
{2} | {2} | ||||
}; | }; | ||||
)", className, enumName, llvm::join(pairStr, ", ")); | |||||
body += formatv(R"( | |||||
)", | |||||
className, enumName, llvm::join(pairStr, ", ")); | |||||
body += formatv(R"( | |||||
e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | ||||
e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; | e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; | ||||
e_type.tp_basicsize = sizeof(EnumWrapper<{0}::{1}>); | e_type.tp_basicsize = sizeof(EnumWrapper<{0}::{1}>); | ||||
@@ -472,22 +469,140 @@ EnumWrapper<{0}::{1}>::type2str = {{ | |||||
e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr; | e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr; | ||||
e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare; | e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare; | ||||
mgb_assert(PyType_Ready(&e_type) >= 0); | mgb_assert(PyType_Ready(&e_type) >= 0); | ||||
)", className, enumName); | |||||
for (auto&& i: attr->getEnumMembers()) { | |||||
body += formatv(R"({{ | |||||
)", | |||||
className, enumName); | |||||
for (auto&& i : attr->getEnumMembers()) { | |||||
body += formatv(R"({{ | |||||
PyObject* inst = e_type.tp_alloc(&e_type, 0); | PyObject* inst = e_type.tp_alloc(&e_type, 0); | ||||
reinterpret_cast<EnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2}; | reinterpret_cast<EnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2}; | ||||
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); | mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); | ||||
})", className, enumName, i); | |||||
} | |||||
enumAlias.emplace(enumID, std::make_pair(className, enumName)); | |||||
} | |||||
body += formatv(R"( | |||||
})", | |||||
className, enumName, i); | |||||
} | |||||
enumAlias.emplace(enumID, std::make_pair(className, enumName)); | |||||
} | |||||
body += formatv(R"( | |||||
PyType_Modified(&e_type); | |||||
mgb_assert(PyDict_SetItemString( | |||||
py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0); | |||||
)", | |||||
enumName); | |||||
body += "}\n"; | |||||
return body; | |||||
} | |||||
static std::string gen_op_def_python_c_extension_bit_combined_enum( | |||||
raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr, | |||||
llvm::StringRef className) { | |||||
std::string body; | |||||
unsigned int enumID; | |||||
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||||
auto&& aliasBase = alias->getAliasBase(); | |||||
enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID(); | |||||
} else { | |||||
enumID = attr->getBaseRecord()->getID(); | |||||
} | |||||
auto&& enumAlias = ctx.enumAlias; | |||||
auto&& iter = enumAlias.find(enumID); | |||||
auto enumName = attr->getEnumName(); | |||||
body += "{\n"; | |||||
body += formatv("auto& e_type = BitCombinedEnumWrapper<{0}::{1}>::type;", | |||||
className, enumName); | |||||
if (iter == enumAlias.end()) { | |||||
os << formatv( | |||||
"template<> PyTypeObject " | |||||
"BitCombinedEnumWrapper<{0}::{1}>::type={{};\n", | |||||
className, enumName); | |||||
os << formatv( | |||||
"template<> PyNumberMethods " | |||||
"BitCombinedEnumWrapper<{0}::{1}>::number_methods={{};\n", | |||||
className, enumName); | |||||
os << formatv( | |||||
"template<> const char* BitCombinedEnumWrapper<{0}::{1}>::name " | |||||
"= \"{0}.{1}\";\n", | |||||
className, enumName); | |||||
os << formatv( | |||||
"template<> struct EnumTrait<{0}::{1}> {{ static constexpr " | |||||
"bool is_bit_combined = true;};\n", | |||||
className, enumName); | |||||
std::vector<std::string> pairStr; | |||||
for (auto&& i : attr->getEnumMembers()) { | |||||
pairStr.push_back( | |||||
formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", | |||||
className, enumName, i)); | |||||
} | |||||
os << formatv(R"( | |||||
template<> std::unordered_map<std::string, {0}::{1}> | |||||
BitCombinedEnumWrapper<{0}::{1}>::str2type = {{ | |||||
{2} | |||||
}; | |||||
)", | |||||
className, enumName, llvm::join(pairStr, ", ")); | |||||
pairStr.clear(); | |||||
for (auto&& i : attr->getEnumMembers()) { | |||||
pairStr.push_back( | |||||
formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", | |||||
className, enumName, i)); | |||||
} | |||||
os << formatv(R"( | |||||
template<> std::unordered_map<{0}::{1}, std::string> | |||||
BitCombinedEnumWrapper<{0}::{1}>::type2str = {{ | |||||
{2} | |||||
}; | |||||
)", | |||||
className, enumName, llvm::join(pairStr, ", ")); | |||||
body += formatv(R"( | |||||
e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | |||||
e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; | |||||
e_type.tp_basicsize = sizeof(BitCombinedEnumWrapper<{0}::{1}>); | |||||
e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||||
e_type.tp_doc = "{0}.{1}"; | |||||
e_type.tp_base = &PyBaseObject_Type; | |||||
e_type.tp_new = BitCombinedEnumWrapper<{0}::{1}>::py_new_combined_enum; | |||||
e_type.tp_init = BitCombinedEnumWrapper<{0}::{1}>::py_init; | |||||
e_type.tp_repr = BitCombinedEnumWrapper<{0}::{1}>::py_repr; | |||||
e_type.tp_richcompare = BitCombinedEnumWrapper<{0}::{1}>::tp_richcompare; | |||||
auto& number_method = BitCombinedEnumWrapper<{0}::{1}>::number_methods; | |||||
number_method.nb_or = BitCombinedEnumWrapper<{0}::{1}>::py_or; | |||||
number_method.nb_and = BitCombinedEnumWrapper<{0}::{1}>::py_and; | |||||
e_type.tp_as_number = &number_method; | |||||
mgb_assert(PyType_Ready(&e_type) >= 0); | |||||
)", | |||||
className, enumName); | |||||
for (auto&& i : attr->getEnumMembers()) { | |||||
body += formatv(R"({{ | |||||
PyObject* inst = e_type.tp_alloc(&e_type, 0); | |||||
reinterpret_cast<BitCombinedEnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2}; | |||||
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); | |||||
})", | |||||
className, enumName, i); | |||||
} | |||||
enumAlias.emplace(enumID, std::make_pair(className, enumName)); | |||||
} | |||||
body += formatv(R"( | |||||
PyType_Modified(&e_type); | PyType_Modified(&e_type); | ||||
mgb_assert(PyDict_SetItemString( | mgb_assert(PyDict_SetItemString( | ||||
py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0); | py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0); | ||||
)", enumName); | |||||
body += "}\n"; | |||||
)", | |||||
enumName); | |||||
body += "}\n"; | |||||
return body; | |||||
} | |||||
static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { | |||||
auto className = op.getCppClassName(); | |||||
std::string body; | |||||
// generate PyType for enum class member | |||||
for (auto&& i : op.getMgbAttributes()) { | |||||
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||||
if (attr->getEnumCombinedFlag()) { | |||||
body += gen_op_def_python_c_extension_bit_combined_enum( | |||||
os, ctx, attr, className); | |||||
} else { | |||||
body += gen_op_def_python_c_extension_enum(os, ctx, attr, | |||||
className); | |||||
} | |||||
} | } | ||||
} | } | ||||
@@ -141,15 +141,13 @@ R"__usage__( | |||||
)__usage__" | )__usage__" | ||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
R"__usage__( | R"__usage__( | ||||
--fast-run | |||||
This param will be deperated later, please replace with param --full-profile. | |||||
--full-profile | |||||
Enable full-profile mode. Operators with multiple algorithms would be profiled | |||||
--full-run | |||||
Enable full-run mode. Operators with multiple algorithms would be profiled | |||||
on the real device with actual input shapes, all algorithms will be profiled | on the real device with actual input shapes, all algorithms will be profiled | ||||
include naive algorithms. | include naive algorithms. | ||||
See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details. | See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details. | ||||
--fast-profile | |||||
Enable fast-profile mode. Operators with multiple algorithms would be profiled | |||||
--fast-run | |||||
Enable fast-run mode. Operators with multiple algorithms would be profiled | |||||
on the real device with actual input shapes, this mode will only profile the | on the real device with actual input shapes, this mode will only profile the | ||||
well optimized algorithms to get the profile result fast. | well optimized algorithms to get the profile result fast. | ||||
See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details. | See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details. | ||||
@@ -519,8 +517,8 @@ struct Args { | |||||
bool disable_assert_throw = false; | bool disable_assert_throw = false; | ||||
bool share_param_mem = false; | bool share_param_mem = false; | ||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
bool use_full_profile = false; | |||||
bool use_fast_profile = false; | |||||
bool use_full_run = false; | |||||
bool use_fast_run = false; | |||||
#endif | #endif | ||||
bool reproducible = false; | bool reproducible = false; | ||||
std::string fast_run_cache_path; | std::string fast_run_cache_path; | ||||
@@ -704,13 +702,13 @@ void run_test_st(Args &env) { | |||||
using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; | using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; | ||||
S strategy = S::HEURISTIC; | S strategy = S::HEURISTIC; | ||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
if (env.use_full_profile) { | |||||
if (env.use_full_run) { | |||||
if (env.reproducible) { | if (env.reproducible) { | ||||
strategy = S::PROFILE | S::REPRODUCIBLE; | strategy = S::PROFILE | S::REPRODUCIBLE; | ||||
} else { | } else { | ||||
strategy = S::PROFILE; | strategy = S::PROFILE; | ||||
} | } | ||||
} else if (env.use_fast_profile) { | |||||
} else if (env.use_fast_run) { | |||||
strategy = S::PROFILE | S::OPTMIZED; | strategy = S::PROFILE | S::OPTMIZED; | ||||
} else if (env.reproducible) { | } else if (env.reproducible) { | ||||
strategy = S::HEURISTIC | S::REPRODUCIBLE; | strategy = S::HEURISTIC | S::REPRODUCIBLE; | ||||
@@ -740,12 +738,12 @@ void run_test_st(Args &env) { | |||||
std::make_shared<InFilePersistentCache>(buf.get(), flen)); | std::make_shared<InFilePersistentCache>(buf.get(), flen)); | ||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
} else { | } else { | ||||
mgb_assert(env.use_full_profile || env.use_fast_profile, | |||||
"fast-run or fast-profile should be enabled"); | |||||
// Without a cache file to read, profiling must actually be requested via
// --full-run or --fast-run.
mgb_assert(env.use_full_run || env.use_fast_run,
           "full-run or fast-run should be enabled");
PersistentCache::set_impl( | PersistentCache::set_impl( | ||||
std::make_shared<InFilePersistentCache>()); | std::make_shared<InFilePersistentCache>()); | ||||
} | } | ||||
if (!env.use_full_profile && !env.use_fast_profile) | |||||
if (!env.use_full_run && !env.use_fast_run) | |||||
#endif | #endif | ||||
mgb::gopt::enable_opr_use_profiling_cache_inplace(vars); | mgb::gopt::enable_opr_use_profiling_cache_inplace(vars); | ||||
} | } | ||||
@@ -1326,18 +1324,11 @@ Args Args::from_argv(int argc, char **argv) { | |||||
} | } | ||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
if (!strcmp(argv[i], "--fast-run")) { | if (!strcmp(argv[i], "--fast-run")) { | ||||
mgb_log_warn( | |||||
"--fast-run param will be deperated later, please replace " | |||||
"with --full-profile or --fast-profile."); | |||||
ret.use_full_profile = true; | |||||
ret.use_fast_run = true; | |||||
continue; | continue; | ||||
} | } | ||||
if (!strcmp(argv[i], "--full-profile")) { | |||||
ret.use_full_profile = true; | |||||
continue; | |||||
} | |||||
if (!strcmp(argv[i], "--fast-profile")) { | |||||
ret.use_fast_profile = true; | |||||
if (!strcmp(argv[i], "--full-run")) { | |||||
ret.use_full_run = true; | |||||
continue; | continue; | ||||
} | } | ||||
#endif | #endif | ||||
@@ -12,7 +12,6 @@ | |||||
#pragma once | #pragma once | ||||
#include "megbrain_build_config.h" | #include "megbrain_build_config.h" | ||||
#include "megbrain/opr/param_defs.h" | |||||
#include "megdnn/basic_types.h" | #include "megdnn/basic_types.h" | ||||
#include <memory> | #include <memory> | ||||
@@ -250,10 +249,4 @@ inline constexpr std::size_t operator"" _z(unsigned long long n) { | |||||
} // namespace mgb | } // namespace mgb | ||||
namespace megdnn { | |||||
namespace param { | |||||
MGB_DEF_ENUM_CLASS_BIT_OPR(ExecutionPolicy::Strategy) | |||||
} | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -18,6 +18,7 @@ | |||||
#include "megbrain/utils/hashable.h" | #include "megbrain/utils/hashable.h" | ||||
#include "megbrain/utils/thin/hash_table.h" | #include "megbrain/utils/thin/hash_table.h" | ||||
#include "megbrain/utils/small_vector.h" | #include "megbrain/utils/small_vector.h" | ||||
#include "megbrain/opr/param_defs.h" | |||||
#include <type_traits> | #include <type_traits> | ||||
@@ -1043,4 +1044,10 @@ MGB_DEFINE_CLS_WITH_SUPER(_name final, _base ,##__VA_ARGS__) \ | |||||
} // namespace cg | } // namespace cg | ||||
} // namespace mgb | } // namespace mgb | ||||
namespace megdnn { | |||||
namespace param { | |||||
MGB_DEF_ENUM_CLASS_BIT_OPR(ExecutionPolicy::Strategy) | |||||
} | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -278,6 +278,19 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space( | |||||
return ret; | return ret; | ||||
} | } | ||||
//! Test whether the attribute of an algorithm matches the requirements of
//! the selected execution strategy.
//!
//! OPTMIZED rejects algorithms carrying the NAIVE attribute; REPRODUCIBLE
//! requires the REPRODUCIBLE attribute. The two flags are independent bits,
//! so both requirements are checked when both flags are set.
//!
//! \return true if the algorithm satisfies every requested requirement
static bool algo_attribute_match_strategy(AlgoAttribute attribute,
                                          ExecutionStrategy selected_strategy) {
    if (selected_strategy & ExecutionStrategy::OPTMIZED) {
        if (static_cast<bool>(AlgoAttribute::NAIVE & attribute)) {
            return false;
        }
    }
    if (selected_strategy & ExecutionStrategy::REPRODUCIBLE) {
        if (!static_cast<bool>(AlgoAttribute::REPRODUCIBLE & attribute)) {
            return false;
        }
    }
    return true;
}
} // namespace | } // namespace | ||||
namespace mgb { | namespace mgb { | ||||
@@ -285,8 +298,8 @@ namespace opr { | |||||
template <typename Opr> | template <typename Opr> | ||||
void AlgoChooser<Opr>::profile(ExeContext& ctx, | void AlgoChooser<Opr>::profile(ExeContext& ctx, | ||||
ExecutionStrategy select_strategy) { | |||||
if (ctx.get_profile_result_from_cache(select_strategy).valid()) | |||||
ExecutionStrategy selected_strategy) { | |||||
if (ctx.get_profile_result_from_cache(selected_strategy).valid()) | |||||
return; | return; | ||||
AlgoChooserProfileCache::Result prof_rst; | AlgoChooserProfileCache::Result prof_rst; | ||||
@@ -306,9 +319,19 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, | |||||
algo.name.c_str(), str_on_inp_shape.c_str()); | algo.name.c_str(), str_on_inp_shape.c_str()); | ||||
ImplExecutionPolicy policy; | ImplExecutionPolicy policy; | ||||
policy.algo = algo.desc; | policy.algo = algo.desc; | ||||
ctx.construct_execution_policy(select_strategy, policy); | |||||
if (ctx.get_workspace_size_bytes(policy) >= workspace_limit) | |||||
ctx.construct_execution_policy(selected_strategy, policy); | |||||
if (ctx.get_workspace_size_bytes(policy) >= workspace_limit) { | |||||
continue; | continue; | ||||
} | |||||
auto algo_attribute = ctx.megdnn_opr() | |||||
->get_algorithm_from_desc(policy.algo) | |||||
->attribute(); | |||||
if (!algo_attribute_match_strategy(algo_attribute, selected_strategy)) { | |||||
mgb_log_debug( | |||||
"skip algo %s, which is not match the profile strategy.", | |||||
algo.name.c_str()); | |||||
continue; | |||||
} | |||||
timer.reset(); | timer.reset(); | ||||
MGB_TRY { cur_rst = ctx.profile_single_algo(policy, cur_timeout); } | MGB_TRY { cur_rst = ctx.profile_single_algo(policy, cur_timeout); } | ||||
@@ -356,7 +379,7 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, | |||||
template <typename Opr> | template <typename Opr> | ||||
typename AlgoChooser<Opr>::ImplExecutionPolicy | typename AlgoChooser<Opr>::ImplExecutionPolicy | ||||
AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, | AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, | ||||
ExecutionStrategy select_strategy, | |||||
ExecutionStrategy selected_strategy, | |||||
bool enable_update) { | bool enable_update) { | ||||
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile"))) | MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile"))) | ||||
if (ctx.owner_graph()->options().no_profiling_on_shape_change) { | if (ctx.owner_graph()->options().no_profiling_on_shape_change) { | ||||
@@ -378,11 +401,11 @@ AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, | |||||
to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(), | to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(), | ||||
_item.param, ctx.mgb_opr(), ctx.comp_node(), | _item.param, ctx.mgb_opr(), ctx.comp_node(), | ||||
ctx.execution_policy(), ctx.allow_weight_preprocess()); | ctx.execution_policy(), ctx.allow_weight_preprocess()); | ||||
AlgoChooser<_Opr>::profile(sub_ctx, select_strategy); | |||||
AlgoChooser<_Opr>::profile(sub_ctx, selected_strategy); | |||||
}); | }); | ||||
} | } | ||||
typename AlgoChooser<Opr>::ImplExecutionPolicy policy; | typename AlgoChooser<Opr>::ImplExecutionPolicy policy; | ||||
ctx.construct_execution_policy(select_strategy, policy); | |||||
ctx.construct_execution_policy(selected_strategy, policy); | |||||
return policy; | return policy; | ||||
MIDOUT_E | MIDOUT_E | ||||
} | } | ||||
@@ -440,7 +463,8 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy( | |||||
if (!policy.algo.valid()) | if (!policy.algo.valid()) | ||||
policy = ctx.choose_by_heuristic(opr_strategy); | policy = ctx.choose_by_heuristic(opr_strategy); | ||||
return policy; | return policy; | ||||
} else if ((opr_strategy & ExecutionStrategy::HEURISTIC)) { | |||||
} else if (!static_cast<int>(opr_strategy) || | |||||
(opr_strategy & ExecutionStrategy::HEURISTIC)) { | |||||
return ctx.choose_by_heuristic(opr_strategy); | return ctx.choose_by_heuristic(opr_strategy); | ||||
} | } | ||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
@@ -449,7 +473,7 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy( | |||||
} | } | ||||
#endif | #endif | ||||
else { | else { | ||||
mgb_throw(GraphError, "bad convolution ExecutionPolicy strategy"); | |||||
mgb_throw(GraphError, "bad ExecutionPolicy strategy"); | |||||
} | } | ||||
} | } | ||||
@@ -495,7 +519,7 @@ AlgoChooser<Opr>::ExeContext::ExeContext( | |||||
template <typename Opr> | template <typename Opr> | ||||
typename AlgoChooser<Opr>::ImplAlgo | typename AlgoChooser<Opr>::ImplAlgo | ||||
AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | ||||
ExecutionStrategy select_strategy) const { | |||||
ExecutionStrategy selected_strategy) const { | |||||
MIDOUT_B(Opr, | MIDOUT_B(Opr, | ||||
midout_iv(MGB_HASH_STR( | midout_iv(MGB_HASH_STR( | ||||
"AlgoChooser::ExeContext::get_profile_result_from_cache"))) | "AlgoChooser::ExeContext::get_profile_result_from_cache"))) | ||||
@@ -519,7 +543,7 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | |||||
if (prof.empty()) | if (prof.empty()) | ||||
return {}; | return {}; | ||||
for (auto&& i : prof) { | for (auto&& i : prof) { | ||||
if (!(select_strategy & ExecutionStrategy::REPRODUCIBLE) || | |||||
if (!(selected_strategy & ExecutionStrategy::REPRODUCIBLE) || | |||||
static_cast<AlgoAttribute>(i.attribute) & | static_cast<AlgoAttribute>(i.attribute) & | ||||
AlgoAttribute::REPRODUCIBLE) { | AlgoAttribute::REPRODUCIBLE) { | ||||
auto iter = algo_map.find(i.algo); | auto iter = algo_map.find(i.algo); | ||||
@@ -550,7 +574,7 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | |||||
template <typename Opr> | template <typename Opr> | ||||
typename AlgoChooser<Opr>::ImplExecutionPolicy | typename AlgoChooser<Opr>::ImplExecutionPolicy | ||||
AlgoChooser<Opr>::ExeContext::choose_by_heuristic( | AlgoChooser<Opr>::ExeContext::choose_by_heuristic( | ||||
ExecutionStrategy select_strategy) const { | |||||
ExecutionStrategy selected_strategy) const { | |||||
if (m_execution_policy.workspace_limit != | if (m_execution_policy.workspace_limit != | ||||
std::numeric_limits<decltype( | std::numeric_limits<decltype( | ||||
m_execution_policy.workspace_limit)>::max()) { | m_execution_policy.workspace_limit)>::max()) { | ||||
@@ -558,7 +582,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic( | |||||
"workspace_limit should not be setted if choose algo by " | "workspace_limit should not be setted if choose algo by " | ||||
"heuristic"); | "heuristic"); | ||||
} | } | ||||
bool reproducible = static_cast<bool>(select_strategy & | |||||
bool reproducible = static_cast<bool>(selected_strategy & | |||||
ExecutionStrategy::REPRODUCIBLE); | ExecutionStrategy::REPRODUCIBLE); | ||||
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | ||||
owner_graph(), m_cn, m_execution_policy.workspace_limit); | owner_graph(), m_cn, m_execution_policy.workspace_limit); | ||||
@@ -582,7 +606,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic( | |||||
_item.param, m_base_mgb_opr, m_cn, m_execution_policy, | _item.param, m_base_mgb_opr, m_cn, m_execution_policy, | ||||
m_allow_weight_preprocess); | m_allow_weight_preprocess); | ||||
policy.sub_policy.push_back( | policy.sub_policy.push_back( | ||||
sub_ctx.choose_by_heuristic(select_strategy)); | |||||
sub_ctx.choose_by_heuristic(selected_strategy)); | |||||
}); | }); | ||||
return policy; | return policy; | ||||
@@ -613,15 +637,15 @@ AlgoChooser<Opr>::ExeContext::get_all_candidates() const { | |||||
template <typename Opr> | template <typename Opr> | ||||
void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | ||||
ExecutionStrategy select_strategy, | |||||
ExecutionStrategy selected_strategy, | |||||
typename AlgoChooser<Opr>::ImplExecutionPolicy& policy, | typename AlgoChooser<Opr>::ImplExecutionPolicy& policy, | ||||
bool retrive_from_cache) const { | bool retrive_from_cache) const { | ||||
bool reproducible = static_cast<bool>(select_strategy & | |||||
bool reproducible = static_cast<bool>(selected_strategy & | |||||
ExecutionStrategy::REPRODUCIBLE); | ExecutionStrategy::REPRODUCIBLE); | ||||
if (!policy.algo.valid()) { | if (!policy.algo.valid()) { | ||||
if (retrive_from_cache) { | if (retrive_from_cache) { | ||||
policy.algo = | policy.algo = | ||||
get_profile_result_from_cache(select_strategy).desc; | |||||
get_profile_result_from_cache(selected_strategy).desc; | |||||
} else { | } else { | ||||
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | ||||
owner_graph(), m_cn, m_execution_policy.workspace_limit); | owner_graph(), m_cn, m_execution_policy.workspace_limit); | ||||
@@ -651,7 +675,7 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | |||||
_item.param, m_base_mgb_opr, m_cn, m_execution_policy, | _item.param, m_base_mgb_opr, m_cn, m_execution_policy, | ||||
m_allow_weight_preprocess); | m_allow_weight_preprocess); | ||||
policy.sub_policy.push_back({}); | policy.sub_policy.push_back({}); | ||||
sub_ctx.construct_execution_policy(select_strategy, | |||||
sub_ctx.construct_execution_policy(selected_strategy, | |||||
policy.sub_policy.back(), | policy.sub_policy.back(), | ||||
retrive_from_cache); | retrive_from_cache); | ||||
}); | }); | ||||
@@ -110,7 +110,7 @@ public: | |||||
const FixedTensorLayouts& layouts() const { return m_layouts; } | const FixedTensorLayouts& layouts() const { return m_layouts; } | ||||
ImplExecutionPolicy choose_by_heuristic( | ImplExecutionPolicy choose_by_heuristic( | ||||
ExecutionStrategy select_strategy) const; | |||||
ExecutionStrategy selected_strategy) const; | |||||
//! get all candidate algos, and the one choose_by_heuristic() is | //! get all candidate algos, and the one choose_by_heuristic() is | ||||
//! put first | //! put first | ||||
@@ -134,17 +134,17 @@ public: | |||||
//! get all profile algorithm from cache, return invalid if not exists | //! get all profile algorithm from cache, return invalid if not exists | ||||
ImplAlgo get_profile_result_from_cache( | ImplAlgo get_profile_result_from_cache( | ||||
ExecutionStrategy select_strategy) const; | |||||
ExecutionStrategy selected_strategy) const; | |||||
/** | /** | ||||
* \brief construct execution policy from cache or heuristic. | * \brief construct execution policy from cache or heuristic. | ||||
* | * | ||||
* \param select_strategy select algo which matched this strategy | |||||
* \param selected_strategy select algo which matched this strategy | |||||
* \param policy execution policy | * \param policy execution policy | ||||
* \param retrive_from_cache retrive algo from cache if set True, get | * \param retrive_from_cache retrive algo from cache if set True, get | ||||
* from heuristic otherwise. | * from heuristic otherwise. | ||||
*/ | */ | ||||
void construct_execution_policy(ExecutionStrategy select_strategy, | |||||
void construct_execution_policy(ExecutionStrategy selected_strategy, | |||||
ImplExecutionPolicy& policy, | ImplExecutionPolicy& policy, | ||||
bool retrive_from_cache = true) const; | bool retrive_from_cache = true) const; | ||||
@@ -161,10 +161,10 @@ private: | |||||
//! profile and save to cache | //! profile and save to cache | ||||
static void profile(ExeContext& ctx, ExecutionStrategy select_strategy); | |||||
static void profile(ExeContext& ctx, ExecutionStrategy selected_strategy); | |||||
static ImplExecutionPolicy choose_by_profile( | static ImplExecutionPolicy choose_by_profile( | ||||
ExeContext& ctx, ExecutionStrategy select_strategy, | |||||
ExeContext& ctx, ExecutionStrategy selected_strategy, | |||||
bool enable_update = true); | bool enable_update = true); | ||||
public: | public: | ||||