You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

layout_options.cpp 7.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. #include <gflags/gflags.h>
  2. #include "misc.h"
  3. #include "models/model_lite.h"
  4. #include "models/model_mdl.h"
  5. #include "layout_options.h"
  6. namespace lar {
  7. template <>
  8. void LayoutOption::config_model_internel<ModelLite>(
  9. RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
  10. if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
  11. #define ENABLE_LAYOUT(layout) \
  12. LITE_WARN("enable " #layout " optimization"); \
  13. model->get_config().options.enable_##layout = true; \
  14. break;
  15. switch (m_option_flag) {
  16. case OptLayoutType::NCHW4:
  17. ENABLE_LAYOUT(nchw4)
  18. case OptLayoutType::CHWN4:
  19. LITE_THROW("lite model unsupport chwn4 layout");
  20. break;
  21. case OptLayoutType::NCHW44:
  22. ENABLE_LAYOUT(nchw44)
  23. case OptLayoutType::NCHW88:
  24. ENABLE_LAYOUT(nchw88)
  25. case OptLayoutType::NCHW32:
  26. ENABLE_LAYOUT(nchw32)
  27. case OptLayoutType::NCHW64:
  28. ENABLE_LAYOUT(nchw64)
  29. case OptLayoutType::NHWCD4:
  30. ENABLE_LAYOUT(nhwcd4)
  31. case OptLayoutType::NCHW44_DOT:
  32. ENABLE_LAYOUT(nchw44_dot)
  33. default:
  34. break;
  35. }
  36. #undef ENABLE_LAYOUT
  37. }
  38. }
  39. template <>
  40. void lar::LayoutOption::config_model_internel<ModelMdl>(
  41. RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
  42. if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
  43. #define ENABLE_LAYOUT(layout) \
  44. mgb_log_warn("enable " #layout " optimization"); \
  45. model->get_mdl_config().comp_graph->options().graph_opt.enable_##layout(); \
  46. break;
  47. switch (m_option_flag) {
  48. case OptLayoutType::NCHW4:
  49. ENABLE_LAYOUT(nchw4)
  50. case OptLayoutType::CHWN4:
  51. ENABLE_LAYOUT(chwn4)
  52. case OptLayoutType::NCHW44:
  53. ENABLE_LAYOUT(nchw44)
  54. case OptLayoutType::NCHW88:
  55. ENABLE_LAYOUT(nchw88)
  56. case OptLayoutType::NCHW32:
  57. ENABLE_LAYOUT(nchw32)
  58. case OptLayoutType::NCHW64:
  59. ENABLE_LAYOUT(nchw64)
  60. case OptLayoutType::NHWCD4:
  61. ENABLE_LAYOUT(nhwcd4)
  62. case OptLayoutType::NCHW44_DOT:
  63. ENABLE_LAYOUT(nchw44_dot)
  64. default:
  65. break;
  66. }
  67. #undef ENABLE_LAYOUT
  68. }
  69. }
  70. } // namespace lar
  71. using namespace lar;
  72. bool LayoutOption::m_valid;
  73. LayoutOption::LayoutOption() {
  74. m_option_name = "layout";
  75. m_option_flag = static_cast<OptLayoutType>(0);
  76. m_option = {
  77. {"enable_nchw4", lar::Bool::make(false)},
  78. {"enable_chwn4", lar::Bool::make(false)},
  79. {"enable_nchw44", lar::Bool::make(false)},
  80. {"enable_nchw88", lar::Bool::make(false)},
  81. {"enable_nchw32", lar::Bool::make(false)},
  82. {"enable_nchw64", lar::Bool::make(false)},
  83. {"enable_nhwcd4", lar::Bool::make(false)},
  84. {"enable_nchw44_dot", lar::Bool::make(false)},
  85. };
  86. std::static_pointer_cast<lar::Bool>(m_option["enable_nchw4"])
  87. ->set_value(FLAGS_enable_nchw4);
  88. std::static_pointer_cast<lar::Bool>(m_option["enable_chwn4"])
  89. ->set_value(FLAGS_enable_chwn4);
  90. std::static_pointer_cast<lar::Bool>(m_option["enable_nchw44"])
  91. ->set_value(FLAGS_enable_nchw44);
  92. std::static_pointer_cast<lar::Bool>(m_option["enable_nchw88"])
  93. ->set_value(FLAGS_enable_nchw88);
  94. std::static_pointer_cast<lar::Bool>(m_option["enable_nchw32"])
  95. ->set_value(FLAGS_enable_nchw32);
  96. std::static_pointer_cast<lar::Bool>(m_option["enable_nchw64"])
  97. ->set_value(FLAGS_enable_nchw64);
  98. std::static_pointer_cast<lar::Bool>(m_option["enable_nhwcd4"])
  99. ->set_value(FLAGS_enable_nhwcd4);
  100. std::static_pointer_cast<lar::Bool>(m_option["enable_nchw44_dot"])
  101. ->set_value(FLAGS_enable_nchw44_dot);
  102. }
  103. bool LayoutOption::is_valid() {
  104. size_t valid_flag = 0;
  105. if (FLAGS_enable_nchw4) {
  106. valid_flag |= static_cast<size_t>(OptLayoutType::NCHW4);
  107. }
  108. if (FLAGS_enable_chwn4) {
  109. valid_flag |= static_cast<size_t>(OptLayoutType::CHWN4);
  110. }
  111. if (FLAGS_enable_nchw44) {
  112. valid_flag |= static_cast<size_t>(OptLayoutType::NCHW44);
  113. }
  114. if (FLAGS_enable_nchw88) {
  115. valid_flag |= static_cast<size_t>(OptLayoutType::NCHW88);
  116. }
  117. if (FLAGS_enable_nchw32) {
  118. valid_flag |= static_cast<size_t>(OptLayoutType::NCHW32);
  119. }
  120. if (FLAGS_enable_nchw64) {
  121. valid_flag |= static_cast<size_t>(OptLayoutType::NCHW64);
  122. }
  123. if (FLAGS_enable_nhwcd4) {
  124. valid_flag |= static_cast<size_t>(OptLayoutType::NHWCD4);
  125. }
  126. if (FLAGS_enable_nchw44_dot) {
  127. valid_flag |= static_cast<size_t>(OptLayoutType::NCHW44_DOT);
  128. }
  129. //! only one flag is valid
  130. bool ret = valid_flag && !(valid_flag & (valid_flag - 1));
  131. return ret | m_valid;
  132. };
  133. std::shared_ptr<OptionBase> LayoutOption::create_option() {
  134. static std::shared_ptr<LayoutOption> option(new LayoutOption);
  135. if (LayoutOption::is_valid()) {
  136. return std::static_pointer_cast<OptionBase>(option);
  137. } else {
  138. return nullptr;
  139. }
  140. }
  141. void LayoutOption::config_model(
  142. RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
  143. size_t valid_flag = 0;
  144. if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw4"])->get_value()) {
  145. valid_flag |= static_cast<size_t>(OptLayoutType::NCHW4);
  146. }
  147. if (std::static_pointer_cast<lar::Bool>(m_option["enable_chwn4"])->get_value()) {
  148. valid_flag |= static_cast<size_t>(OptLayoutType::CHWN4);
  149. }
  150. if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw44"])->get_value()) {
  151. valid_flag |= static_cast<size_t>(OptLayoutType::NCHW44);
  152. }
  153. if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw88"])->get_value()) {
  154. valid_flag |= static_cast<size_t>(OptLayoutType::NCHW88);
  155. }
  156. if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw32"])->get_value()) {
  157. valid_flag |= static_cast<size_t>(OptLayoutType::NCHW32);
  158. }
  159. if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw64"])->get_value()) {
  160. valid_flag |= static_cast<size_t>(OptLayoutType::NCHW64);
  161. }
  162. if (std::static_pointer_cast<lar::Bool>(m_option["enable_nhwcd4"])->get_value()) {
  163. valid_flag |= static_cast<size_t>(OptLayoutType::NHWCD4);
  164. }
  165. if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw44_dot"])
  166. ->get_value()) {
  167. valid_flag |= static_cast<size_t>(OptLayoutType::NCHW44_DOT);
  168. }
  169. mgb_throw_if(
  170. valid_flag && (valid_flag & (valid_flag - 1)), mgb::AssertionError,
  171. "invalid options of layout transform 0x%lx", valid_flag);
  172. m_option_flag = static_cast<OptLayoutType>(valid_flag);
  173. CONFIG_MODEL_FUN;
  174. }
  175. DEFINE_bool(enable_nchw4, false, "enable nchw4 layout optimization!!");
  176. DEFINE_bool(enable_chwn4, false, "enable chwn4 layout optimization!!");
  177. DEFINE_bool(enable_nchw44, false, "enable nchw44 layout optimization!!");
  178. DEFINE_bool(enable_nchw88, false, "enable nchw88 layout optimization!!");
  179. DEFINE_bool(enable_nchw32, false, "enable nchw32 layout optimization!!");
  180. DEFINE_bool(enable_nchw64, false, "enable nchw64 layout optimization!!");
  181. DEFINE_bool(enable_nhwcd4, false, "enable nhwcd4 layout optimization!!");
  182. DEFINE_bool(enable_nchw44_dot, false, "enable nchw444-dot layout optimization!!");
  183. REGIST_OPTION_CREATOR(layout, lar::LayoutOption::create_option);
  184. REGIST_OPTION_VALIDATER(layout, lar::LayoutOption::set_valid);