From 84d1a440f0be69c174e1f3bd720a4b85da95cd8b Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Thu, 3 Mar 2022 15:02:56 +0800
Subject: [PATCH] fix(imperative): do not use output_desc in rng ops

GitOrigin-RevId: e6a399be171ea93d8b1a79842a6c066b06e3843d
---
 imperative/src/impl/ops/rng.cpp | 64 +++++++++++++++++++++++++----------------
 1 file changed, 40 insertions(+), 24 deletions(-)

diff --git a/imperative/src/impl/ops/rng.cpp b/imperative/src/impl/ops/rng.cpp
index d6c99308..6a8e67b2 100644
--- a/imperative/src/impl/ops/rng.cpp
+++ b/imperative/src/impl/ops/rng.cpp
@@ -419,7 +419,8 @@ _INST_RNG_MAKER(2)
 template <typename Op>
 void exec(
         const OpDef& op, const SmallVector<TensorPtr>& inputs,
-        const SmallVector<TensorPtr>& outputs) {
+        const SmallVector<TensorPtr>& outputs,
+        const SmallVector<TensorPtr>& workspace) {
     auto&& rng = op.cast_final_safe<Op>();
     auto dest = outputs[0];
 
@@ -450,56 +451,71 @@
 }
 
 template <typename Op>
-SmallVector<CompNode> infer_output_cns(
+SmallVector<LogicalTensorDesc> infer_output_attrs(
         const OpDef& op, const SmallVector<TensorPtr>& inputs) {
-    CompNode cn;
+    LogicalTensorDesc dest;
     auto&& rng = op.cast_final_safe<Op>();
     auto handle = rng.handle;
     if (handle) {
-        cn = RNGDnnOpManager::get_comp_node(handle);
+        dest.comp_node = RNGDnnOpManager::get_comp_node(handle);
     } else {
-        cn = inputs[0]->comp_node();
+        dest.comp_node = inputs[0]->comp_node();
     }
     constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0;
     if (!rng_with_shape) {
         for (int i = 0; i < inputs.size(); ++i) {
             mgb_assert(
-                    inputs[i]->comp_node() == cn,
+                    inputs[i]->comp_node() == dest.comp_node,
                     "%s expects the device of inputs[%d] to be same as the device of "
                     "handle; "
                     "got %s and %s actually",
                     rng.dyn_typeinfo()->name, i,
-                    inputs[i]->comp_node().to_string().c_str(), cn.to_string().c_str());
+                    inputs[i]->comp_node().to_string().c_str(),
+                    dest.comp_node.to_string().c_str());
         }
     }
-    return {cn};
+    dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], rng);
+    return {dest};
 }
 
 template <>
-SmallVector<CompNode> infer_output_cns<ShuffleRNG>(
+SmallVector<LogicalTensorDesc> infer_output_attrs<ShuffleRNG>(
         const OpDef& op, const SmallVector<TensorPtr>& inputs) {
-    SmallVector<CompNode> cns(2);
+    SmallVector<LogicalTensorDesc> dests(2);
     auto&& rng = op.cast_final_safe<ShuffleRNG>();
     auto handle = rng.handle;
     if (handle) {
-        cns[0] = RNGDnnOpManager::get_comp_node(handle);
-        cns[1] = RNGDnnOpManager::get_comp_node(handle);
+        dests[0].comp_node = RNGDnnOpManager::get_comp_node(handle);
+        dests[1].comp_node = RNGDnnOpManager::get_comp_node(handle);
     } else {
-        cns[0] = inputs[0]->comp_node();
-        cns[1] = inputs[0]->comp_node();
+        dests[0].comp_node = inputs[0]->comp_node();
+        dests[1].comp_node = inputs[0]->comp_node();
     }
-    return cns;
+    dests[0].layout = TensorLayout(inputs[0]->layout());
+    dests[0].layout.dtype = inputs[0]->layout().dtype;
+    dests[1].layout =
+            TensorLayout(TensorShape({inputs[0]->layout()[0]}), dtype::Int32());
+    return dests;
 }
 
 template <>
-SmallVector<CompNode> infer_output_cns<Dropout>(
+SmallVector<LogicalTensorDesc> infer_output_attrs<Dropout>(
         const OpDef& op, const SmallVector<TensorPtr>& inputs) {
-    SmallVector<CompNode> cns(2);
+    SmallVector<LogicalTensorDesc> dests(2);
     auto&& cn = inputs[0]->comp_node();
-    cns[0] = cn;
-    cns[1] = cn;
-    return cns;
+    dests[0].comp_node = cn;
+    dests[0].layout = TensorLayout(inputs[0]->layout());
+    dests[0].layout.dtype = inputs[0]->layout().dtype;
+
+    auto get_mask_size = [&]() -> size_t {
+        auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
+        return dnn_handle->create_operator<megdnn::Dropout>()->get_mask_size_in_bytes(
+                inputs[0]->layout());
+    };
+    dests[1].comp_node = cn;
+    dests[1].layout = TensorLayout(TensorShape({get_mask_size()}), dtype::Byte());
+    return dests;
 }
 
 template <typename Op>
@@ -507,11 +523,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
         const OpDef& def, const SmallVector<TensorPtr>& inputs,
         SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
     SmallVector<TensorPtr> outputs;
-    SmallVector<CompNode> cns = infer_output_cns<Op>(def, inputs);
-    for (size_t i = 0; i < cns.size(); i++) {
-        outputs.push_back(Tensor::make(output_descs[i].layout, cns[i]));
+    SmallVector<LogicalTensorDesc> desc = infer_output_attrs<Op>(def, inputs);
+    for (auto&& i : desc) {
+        outputs.push_back(Tensor::make(i.layout, i.comp_node));
     }
-    exec<Op>(def, inputs, outputs);
+    exec<Op>(def, inputs, outputs, {});
     return outputs;
 }
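
Note for readers outside the MegEngine codebase: the core of the fix is in
apply_on_physical_tensor. RNG outputs are now allocated from descriptors that
infer_output_attrs recomputes from the inputs, instead of from the
caller-provided output_descs, whose layouts are what the subject line says
should no longer be trusted for these ops. The following is a minimal,
self-contained C++ sketch of that allocation pattern only; Desc, Tensor,
infer_output_descs and apply are hypothetical stand-ins, not MegEngine types.

    // Hypothetical stand-ins for CompNode / TensorLayout / LogicalTensorDesc;
    // only the shape of the pattern matches the patch, none of the names do.
    #include <cstddef>
    #include <vector>

    struct Desc {
        int device;                      // stand-in for CompNode
        std::vector<std::size_t> shape;  // stand-in for TensorLayout
    };

    struct Tensor {
        Desc desc;
    };

    // Recompute every output descriptor from the inputs. Before the patch,
    // only the device was inferred here and the layout came from cached
    // descriptors; after it, both fields are derived from the inputs, so a
    // stale cache can no longer produce a wrongly sized allocation.
    std::vector<Desc> infer_output_descs(const std::vector<Tensor>& inputs) {
        Desc d;
        d.device = inputs[0].desc.device;  // output lives with its input
        d.shape = inputs[0].desc.shape;    // e.g. elementwise RNG keeps the shape
        return {d};
    }

    std::vector<Tensor> apply(const std::vector<Tensor>& inputs) {
        std::vector<Tensor> outputs;
        for (auto&& d : infer_output_descs(inputs)) {
            outputs.push_back(Tensor{d});  // allocate from the inferred desc
        }
        return outputs;
    }

    int main() {
        std::vector<Tensor> in{Tensor{Desc{0, {4, 3}}}};
        auto out = apply(in);
        return out[0].desc.shape == in[0].desc.shape ? 0 : 1;
    }

The specializations in the patch follow the same rule: whatever extra output
metadata an op needs (the Int32 index table of ShuffleRNG, the byte mask of
Dropout) is derived on the spot from the inputs rather than read back from a
descriptor cache.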
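
Dropout is the one case above where an output layout cannot be derived from
input shapes alone: the mask byte count depends on the backend, so the patch
queries it through the DNN handle (get_mask_size_in_bytes) and allocates a
1-D Byte tensor of that length. Below is a hedged sketch of that two-output
contract, where mask_size_in_bytes is a hypothetical stand-in for the backend
query; its bit-per-element formula is purely illustrative.

    // Output 0 mirrors the input layout; output 1 is an opaque 1-D byte mask
    // whose length has to be asked of the backend rather than computed from
    // the shape. Real backends may store the mask differently.
    #include <cstddef>
    #include <vector>

    struct Layout {
        std::vector<std::size_t> shape;
        std::size_t total_elems() const {
            std::size_t n = 1;
            for (auto d : shape) n *= d;
            return n;
        }
    };

    // Hypothetical stand-in for megdnn's get_mask_size_in_bytes.
    std::size_t mask_size_in_bytes(const Layout& input) {
        return (input.total_elems() + 7) / 8;  // one bit per element, rounded up
    }

    int main() {
        Layout in{{32, 128}};
        Layout data_out = in;                   // output 0: same layout as input
        Layout mask{{mask_size_in_bytes(in)}};  // output 1: 1-D byte mask
        return (data_out.total_elems() == in.total_elems() && mask.shape[0] == 512)
                ? 0
                : 1;
    }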