You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

function_dft.h 7.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. /**
  2. * \file src/mge/function_dft.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
#pragma once
#if LITE_BUILD_WITH_MGE

#include <memory>
#include <string>
#include <utility>

#include "function_base.h"
#include "network_impl.h"
#include "network_impl_base.h"
#include "tensor_impl.h"
  17. namespace lite {
  18. #define THROW_FUNC_ERROR(func_name) \
  19. auto msg_info = func_name + " is not aviliable in Dft backend."; \
  20. LITE_THROW(msg_info.c_str())
// the functions used for dft's tensor.cpp are as follows:
  22. template <>
  23. inline std::shared_ptr<Tensor::TensorImplBase> call_func<
  24. TensorImplDft, std::shared_ptr<Tensor::TensorImplBase>>(std::string func_name) {
  25. if (func_name == "create_tensor") {
  26. return std::make_shared<TensorImplDft>();
  27. }
  28. THROW_FUNC_ERROR(func_name);
  29. }
  30. template <>
  31. inline std::shared_ptr<Tensor::TensorImplBase> call_func<
  32. TensorImplDft, std::shared_ptr<Tensor::TensorImplBase>>(
  33. std::string func_name, LiteDeviceType device_type, bool is_pinned_host) {
  34. if (func_name == "create_tensor") {
  35. return std::make_shared<TensorImplDft>(device_type, is_pinned_host);
  36. }
  37. THROW_FUNC_ERROR(func_name);
  38. }
  39. template <>
  40. inline std::shared_ptr<Tensor::TensorImplBase> call_func<
  41. TensorImplDft, std::shared_ptr<Tensor::TensorImplBase>>(
  42. std::string func_name, int device_id, LiteDeviceType device_type,
  43. const Layout layout, bool is_pinned_host) {
  44. if (func_name == "create_tensor") {
  45. return std::make_shared<TensorImplDft>(
  46. device_id, device_type, layout, is_pinned_host);
  47. }
  48. THROW_FUNC_ERROR(func_name);
  49. }
  50. template <>
  51. inline std::shared_ptr<Tensor::TensorImplBase> call_func<
  52. TensorImplDft, std::shared_ptr<Tensor::TensorImplBase>>(
  53. std::string func_name, LiteDeviceType device_type, const Layout layout,
  54. bool is_pinned_host) {
  55. if (func_name == "create_tensor") {
  56. return std::make_shared<TensorImplDft>(device_type, layout, is_pinned_host);
  57. }
  58. THROW_FUNC_ERROR(func_name);
  59. }
  60. template <>
  61. inline std::shared_ptr<Tensor::TensorImplBase> call_func<
  62. TensorImplDft, std::shared_ptr<Tensor::TensorImplBase>>(
  63. std::string func_name, int device_id, int stream_id, LiteDeviceType device_type,
  64. bool is_pinned_host) {
  65. if (func_name == "create_tensor") {
  66. return std::make_shared<TensorImplDft>(
  67. device_id, stream_id, device_type, is_pinned_host);
  68. }
  69. THROW_FUNC_ERROR(func_name);
  70. }
// the functions used for dft's network.cpp are as follows:
  72. template <>
  73. inline std::unique_ptr<Network::NetworkImplBase> call_func<
  74. NetworkImplDft, std::unique_ptr<Network::NetworkImplBase>>(
  75. std::string func_name) {
  76. if (func_name == "create_network") {
  77. return std::make_unique<NetworkImplDft>();
  78. }
  79. THROW_FUNC_ERROR(func_name);
  80. }
  81. template <>
  82. inline Network::NetworkImplBase* try_call_func<
  83. NetworkImplDft, Network::NetworkImplBase*>(std::string func_name) {
  84. if (func_name == "parse_model") {
  85. return new NetworkImplDft();
  86. }
  87. THROW_FUNC_ERROR(func_name);
  88. }
// Invoke member `func_name` on the concrete NetworkImplDft after a
// `cast_final_safe` downcast (presumably a checked cast — verify its failure
// behavior in the network_impl_base header).
// NOTE: the macro captures a variable literally named `network_impl` from the
// expansion site, so it only works inside functions with that parameter name.
#define CALL_FUNC(func_name, ...) \
network_impl->cast_final_safe<NetworkImplDft>().func_name(__VA_ARGS__)
  91. template <>
  92. inline void call_func<NetworkImplDft, void>(
  93. std::string func_name, Network::NetworkImplBase* network_impl, size_t num) {
  94. if (func_name == "set_cpu_threads_number") {
  95. CALL_FUNC(set_cpu_threads_number, num);
  96. } else if (func_name == "set_network_algo_workspace_limit") {
  97. CALL_FUNC(set_network_algo_workspace_limit, num);
  98. } else {
  99. THROW_FUNC_ERROR(func_name);
  100. }
  101. }
  102. template <>
  103. inline void call_func<NetworkImplDft, void>(
  104. std::string func_name, Network::NetworkImplBase* network_impl) {
  105. if (func_name == "use_tensorrt") {
  106. CALL_FUNC(use_tensorrt);
  107. } else if (func_name == "set_cpu_inplace_mode") {
  108. CALL_FUNC(set_cpu_inplace_mode);
  109. } else if (func_name == "enable_global_layout_transform") {
  110. CALL_FUNC(enable_global_layout_transform);
  111. } else {
  112. THROW_FUNC_ERROR(func_name);
  113. }
  114. }
  115. template <>
  116. inline size_t call_func<NetworkImplDft, size_t>(
  117. std::string func_name, Network::NetworkImplBase* network_impl) {
  118. if (func_name == "get_cpu_threads_number") {
  119. return CALL_FUNC(get_cpu_threads_number);
  120. }
  121. THROW_FUNC_ERROR(func_name);
  122. }
  123. template <>
  124. inline bool call_func<NetworkImplDft, bool>(
  125. std::string func_name, Network::NetworkImplBase* network_impl) {
  126. if (func_name == "is_cpu_inplace_mode") {
  127. return CALL_FUNC(is_cpu_inplace_mode);
  128. }
  129. THROW_FUNC_ERROR(func_name);
  130. }
  131. template <>
  132. inline void call_func<NetworkImplDft, void>(
  133. std::string func_name, Network::NetworkImplBase* network_impl,
  134. ThreadAffinityCallback thread_affinity_callback) {
  135. if (func_name == "set_runtime_thread_affinity") {
  136. return CALL_FUNC(
  137. set_runtime_thread_affinity, std::move(thread_affinity_callback));
  138. }
  139. THROW_FUNC_ERROR(func_name);
  140. }
  141. template <>
  142. inline void call_func<NetworkImplDft, void>(
  143. std::string func_name, Network::NetworkImplBase* network_impl,
  144. LiteAlgoSelectStrategy strategy, uint32_t shared_batch_size,
  145. bool binary_equal_between_batch) {
  146. if (func_name == "set_network_algo_policy") {
  147. return CALL_FUNC(
  148. set_network_algo_policy, strategy, shared_batch_size,
  149. binary_equal_between_batch);
  150. }
  151. THROW_FUNC_ERROR(func_name);
  152. }
  153. template <>
  154. inline void call_func<NetworkImplDft, void>(
  155. std::string func_name, Network::NetworkImplBase* network_impl,
  156. std::shared_ptr<Allocator> user_allocator) {
  157. if (func_name == "set_memory_allocator") {
  158. return CALL_FUNC(set_memory_allocator, user_allocator);
  159. }
  160. THROW_FUNC_ERROR(func_name);
  161. }
  162. template <>
  163. inline void call_func<NetworkImplDft, void>(
  164. std::string func_name, Network::NetworkImplBase* network_impl,
  165. std::string file_name) {
  166. if (func_name == "enable_io_txt_dump") {
  167. return CALL_FUNC(enable_io_txt_dump, file_name);
  168. } else if (func_name == "enable_io_bin_dump") {
  169. return CALL_FUNC(enable_io_bin_dump, file_name);
  170. } else if (func_name == "dump_layout_transform_model") {
  171. return CALL_FUNC(dump_layout_transform_model, file_name);
  172. }
  173. THROW_FUNC_ERROR(func_name);
  174. }
  175. template <>
  176. inline void call_func<NetworkImplDft, void>(
  177. std::string func_name, Network::NetworkImplBase* network_impl,
  178. Network::NetworkImplBase* src_network_impl) {
  179. if (func_name == "share_runtime_memory_with") {
  180. CALL_FUNC(share_runtime_memory_with, src_network_impl);
  181. } else if (func_name == "shared_weight_with") {
  182. CALL_FUNC(shared_weight_with, src_network_impl);
  183. } else {
  184. THROW_FUNC_ERROR(func_name);
  185. }
  186. }
  187. #undef THROW_FUNC_ERROR
  188. } // namespace lite
  189. #endif
  190. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}