
tensor_manip.cpp

  1. #include "megbrain/opr/tensor_manip.h"
  2. #include "megbrain/imperative/ops/autogen.h"
  3. #include "megbrain/imperative/ops/opr_attr.h"
  4. #include "../dnn_op_helper.h"
  5. #include "../op_trait.h"
  6. namespace mgb::imperative {
  7. namespace get_var_shape {
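// Imperative op traits for GetVarShape: the op returns its input's shape (or
// the broadcast-deduced shape of several inputs) as an Int32 tensor, optionally
// reduced to a single requested axis.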
cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op_def = def.cast_final_safe<GetVarShape>();
    OperatorNodeConfig config{op_def.make_name()};
    return opr::GetVarShape::make(inputs, op_def.param(), config).node()->owner_opr();
}
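// Dispatch on the host (DEFAULT_CPU) only when every input already carries a
// valid host value; otherwise fall back to KERNEL dispatch.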
DispatchMode decide_dispatch_mode(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    bool host_computable = true;
    for (auto&& inp : inputs) {
        // FIXME(czh): remove the value check once the proxy graph supports
        // apply_on_device_tensornd and the output Tensor is created before
        // add_task; then a valid layout implies ptr->layout is ready.
        if (inp.value.empty() || inp.value.layout().ndim == 0) {
            host_computable = false;
            break;
        }
    }
    return host_computable ? DEFAULT_CPU : KERNEL;
}
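// Compute the shape eagerly on default_cpu: either the full shape vector or,
// if an axis is given, the single extent along that (possibly negative) axis.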
void apply_on_device_tensornd(
        const OpDef& def, const SmallVector<DeviceTensorND>& inputs,
        SmallVector<DeviceTensorND>* outputs) {
    auto&& op_def = def.cast_final_safe<GetVarShape>();
    TensorShape shp;
    if (inputs.size() == 1) {
        shp = inputs[0].layout();
    } else {
        TensorShapeArray src(inputs.size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            src[i] = inputs[i].layout();
        }
        megdnn::Elemwise::deduce_shape(src, shp);
    }
    mgb_assert(shp.ndim != 0, "input shape invalid");
    mgb_assert(
            (*outputs)[0].comp_node() == CompNode::default_cpu(),
            "GetVarShape's apply_on_device_tensornd should receive default_cpu "
            "outputs.");
    HostTensorND hv;
    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) {
        hv = HostTensorND(CompNode::default_cpu(), {shp.ndim}, dtype::Int32());
        auto* ptr = hv.ptr<dt_int32>();
        for (size_t i = 0; i < shp.ndim; ++i) {
            ptr[i] = shp.shape[i];
        }
    } else {
        int32_t axis = op_def.axis;
        if (axis < 0) {
            axis += shp.ndim;
        }
        mgb_assert(axis >= 0 && axis < (int32_t)shp.ndim);
        hv = HostTensorND(CompNode::default_cpu(), {1}, dtype::Int32());
        auto* ptr = hv.ptr<dt_int32>();
        ptr[0] = shp.shape[axis];
    }
    (*outputs)[0] = DeviceTensorND::make_proxy(hv);
}
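// Shared helper: run apply_on_device_tensornd on default_cpu and proxy the
// resulting host tensor back onto the first input's comp_node.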
HostTensorND get_var_shape_host_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<DeviceTensorND> input_tensornds;
    for (auto&& inp : inputs) {
        input_tensornds.push_back(inp->dev_tensor(false));
    }
    SmallVector<DeviceTensorND> output_tensornds = {
            {CompNode::default_cpu(), dtype::Int32()}};
    apply_on_device_tensornd(def, input_tensornds, &output_tensornds);
    // restore to input comp_node
    return HostTensorND::make_proxy(output_tensornds[0])
            .proxy_to_comp_node(inputs[0]->comp_node());
}
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    return {Tensor::make(get_var_shape_host_tensor(def, inputs))};
}
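// Static inference: return an unvalidated Int32 layout when any input shape is
// still unknown; otherwise materialize the shape value on default_cpu so the
// result is fully known without running a kernel.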
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<GetVarShape>();
    auto&& desc = inputs[0];
    TensorShape shp;
    if (inputs.size() == 1) {
        shp = desc.layout;
    } else {
        TensorShapeArray src(inputs.size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            src[i] = inputs[i].layout;
            if (!src[i].ndim) {
                return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false};
            }
        }
        megdnn::Elemwise::deduce_shape(src, shp);
    }
    if (!shp.ndim) {
        return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false};
    }
    DeviceTensorND value;
    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) {
        value = DeviceTensorND(CompNode::default_cpu(), {shp.ndim}, dtype::Int32());
        auto* ptr = value.ptr<dt_int32>();
        for (size_t i = 0; i < shp.ndim; ++i) {
            ptr[i] = shp[i];
        }
    } else {
        int32_t axis = op_def.axis;
        if (axis < 0) {
            axis += shp.ndim;
        }
        mgb_assert(axis >= 0 && axis < (int32_t)shp.ndim);
        value = DeviceTensorND(CompNode::default_cpu(), {1}, dtype::Int32());
        auto* ptr = value.ptr<dt_int32>();
        ptr[0] = shp[axis];
    }
    return {{{value.layout(), desc.comp_node, std::move(value)}}, true};
}
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::GetVarShape>();
    return GetVarShape::make(node->param());
}

OP_TRAIT_REG(GetVarShape, GetVarShape, opr::GetVarShape)
        .make_from_op_node(make_from_op_node)
        .decide_dispatch_mode(decide_dispatch_mode)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_device_tensornd(apply_on_device_tensornd)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();

}  // namespace get_var_shape
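// ParamPackConcat / ParamPackSplit: concatenate many tensors into one flat 1-D
// buffer, and split such a buffer back into per-tensor views at given offsets.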
namespace param_pack {

TensorShapeArray get_shapes(const std::vector<std::vector<size_t>>& shapes) {
    TensorShapeArray ret;
    for (auto&& i : shapes) {
        SmallVector<size_t> shape(i.begin(), i.end());
        TensorShape shp(shape);
        ret.push_back(shp);
    }
    return ret;
}
cg::OperatorNodeBase* param_pack_split_apply_on_var_node(
        const OpDef& def, const VarNodeArray& inputs) {
    auto&& param = def.cast_final_safe<ParamPackSplit>();
    auto&& graph = inputs[0]->owner_graph();
    auto&& shapes = get_shapes(param.shapes);
    OperatorNodeConfig config(param.make_name());
    cg::OperatorNodeBase* opr =
            graph->insert_opr(std::make_unique<mgb::opr::ParamPackSplit>(
                    inputs[0], param.offsets, shapes, config));
    return opr;
}
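// Split the flat 1-D input into per-shape tensors without copying: each output
// is a memory-forwarded sub-tensor starting at offsets[2 * i] elements.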
SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& param = def.cast_final_safe<ParamPackSplit>();
    mgb_assert(
            inputs.size() == 1, "ParamPackSplit takes 1 input, got %lu", inputs.size());
    auto&& inp = inputs[0];
    auto&& shp = inp->layout();
    mgb_assert(shp.ndim == 1, "ParamPackSplit input shape invalid, ndim should be 1");
    mgb_assert(param.shapes.size() * 2 == param.offsets.size());
    SmallVector<TensorPtr> ret;
    auto&& shapes = get_shapes(param.shapes);
    size_t dtype_size = inputs[0]->layout().dtype.size();
    for (size_t i = 0; i < shapes.size(); ++i) {
        // memory forward
        ret.push_back(inputs[0]->sub(param.offsets[i * 2] * dtype_size, shapes[i]));
    }
    return ret;
}

OP_TRAIT_REG(ParamPackSplit, ParamPackSplit, mgb::opr::ParamPackSplit)
        .apply_on_var_node(param_pack_split_apply_on_var_node)
        .apply_on_physical_tensor(param_pack_split_apply_on_physical_tensor)
        .fallback();
cg::OperatorNodeBase* param_pack_concat_apply_on_var_node(
        const OpDef& def, const VarNodeArray& inputs) {
    auto&& param = def.cast_final_safe<ParamPackConcat>();
    auto&& graph = inputs[0]->owner_graph();
    VarNodeArray inps(inputs.begin(), inputs.end() - 1);
    OperatorNodeConfig config{param.make_name()};
    cg::OperatorNodeBase* opr =
            graph->insert_opr(std::make_unique<mgb::opr::ParamPackConcat>(
                    inps, inputs.back(), param.offsets, config));
    return opr;
}
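// Concatenate all inputs (the last input is the offsets tensor) into one flat
// output. The per-input device pointers are staged in a host-allocated array,
// which is released asynchronously once the megdnn kernel has been issued.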
SmallVector<TensorPtr> param_pack_concat_apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    def.cast_final_safe<ParamPackConcat>();
    mgb_assert(inputs.size() > 1, "param_pack should have at least one input");
    auto comp_node = inputs.front()->comp_node();
    auto dtype = inputs.front()->dtype();
    size_t nr_inputs = inputs.size() - 1;
    size_t nr_elems = 0;
    for (size_t i = 0; i < nr_inputs; ++i) {
        auto& input = inputs[i];
        mgb_assert(
                comp_node == input->comp_node(),
                "inputs for param_pack_concat must be on the same comp_node");
        mgb_assert(
                dtype == input->dtype(),
                "inputs for param_pack_concat must have the same dtype");
        nr_elems += input->layout().total_nr_elems();
    }
    auto dest_layout = TensorLayout({nr_elems}, dtype);
    auto output = Tensor::make(dest_layout, comp_node);
    auto caller = DnnOprCaller<megdnn::ParamPackConcat>(comp_node);
    size_t srcs_size = sizeof(void*) * nr_inputs;
    void** srcs_raw_ptr = (void**)comp_node.alloc_host(srcs_size);
    std::shared_ptr<dt_byte> srcs_ptr = {
            (dt_byte*)srcs_raw_ptr,
            [comp_node](dt_byte* ptr) { comp_node.free_host(ptr); }};
    TensorLayout srcs_layout = TensorLayout{{nr_inputs}, dtype::Int32()};
    size_t ws_size;
    {
        TensorShapeArray src_shapes;
        for (size_t i = 0; i < nr_inputs; ++i) {
            src_shapes.push_back(inputs[i]->shape());
        }
        ws_size = caller.op->get_workspace_in_bytes(
                src_shapes, inputs.back()->shape(), TensorShape{});
    }
    for (size_t i = 0; i < nr_inputs; ++i) {
        srcs_raw_ptr[i] = inputs[i]->dev_tensor().as_megdnn().raw_ptr();
    }
    HostTensorStorage srcs_storage;
    srcs_storage.reset(comp_node, srcs_size, srcs_ptr);
    caller.op->exec(
            {srcs_raw_ptr, srcs_layout}, inputs.back()->dnn_tensor(),
            output->dnn_tensor(), caller.create_workspace(ws_size));
    async_release(HostTensorND{comp_node, srcs_layout}.storage(srcs_storage));
    return {output};
}

OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat)
        .apply_on_var_node(param_pack_concat_apply_on_var_node)
        .apply_on_physical_tensor(param_pack_concat_apply_on_physical_tensor)
        .fallback();

}  // namespace param_pack
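// Split: map the imperative Split op onto opr::Split, either with nsections
// (average partition) or with explicit per-part sizes given as extra inputs.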
namespace split {

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    using Options = opr::Split::Options;
    auto* node = &node_->cast_final_safe<opr::Split>();
    auto&& opt = node->options();
    int axis = opt.axis;
    mgb_assert(
            opt.method == Options::Method::SPECIFY,
            "only Split with SPECIFY output shapes is supported");
    mgb_assert(opt.partition.size() == opt.nr_part);
    return Split::make(axis, 0);
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    using Options = opr::Split::Options;
    auto&& sp = static_cast<const Split&>(def);
    OperatorNodeConfig config{sp.make_name()};
    opr::Split::Options opt;
    if (sp.nsections) {
        opt = Options::make_average(sp.axis, sp.nsections);
        opt.method = Options::Method::CALL_BACK;
    } else {
        opt.axis = sp.axis;
        opt.method = Options::Method::SPECIFY;
        mgb_assert(inputs.size() > 1);
        opt.nr_part = inputs.size() - 1;
        opt.partition.resize(opt.nr_part);
        for (size_t i = 1; i < inputs.size(); ++i)
            opt.partition[i - 1] = inputs[i];
    }
    return opr::Split::make(inputs[0], opt, config);
}

OP_TRAIT_REG(Split, Split, opr::Split)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .fallback();

}  // namespace split

}  // namespace mgb::imperative