|
@@ -419,7 +419,8 @@ _INST_RNG_MAKER(2) |
|
|
template <typename Op> |
|
|
template <typename Op> |
|
|
void exec( |
|
|
void exec( |
|
|
const OpDef& op, const SmallVector<TensorPtr>& inputs, |
|
|
const OpDef& op, const SmallVector<TensorPtr>& inputs, |
|
|
const SmallVector<TensorPtr>& outputs) { |
|
|
|
|
|
|
|
|
const SmallVector<TensorPtr>& outputs, |
|
|
|
|
|
const SmallVector<TensorPtr>& workspace) { |
|
|
auto&& rng = op.cast_final_safe<Op>(); |
|
|
auto&& rng = op.cast_final_safe<Op>(); |
|
|
|
|
|
|
|
|
auto dest = outputs[0]; |
|
|
auto dest = outputs[0]; |
|
@@ -450,56 +451,71 @@ void exec( |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
template <typename Op> |
|
|
template <typename Op> |
|
|
SmallVector<CompNode> infer_output_cns( |
|
|
|
|
|
|
|
|
SmallVector<LogicalTensorDesc> infer_output_attrs( |
|
|
const OpDef& op, const SmallVector<TensorPtr>& inputs) { |
|
|
const OpDef& op, const SmallVector<TensorPtr>& inputs) { |
|
|
CompNode cn; |
|
|
|
|
|
|
|
|
LogicalTensorDesc dest; |
|
|
auto&& rng = op.cast_final_safe<Op>(); |
|
|
auto&& rng = op.cast_final_safe<Op>(); |
|
|
auto handle = rng.handle; |
|
|
auto handle = rng.handle; |
|
|
if (handle) { |
|
|
if (handle) { |
|
|
cn = RNGDnnOpManager::get_comp_node(handle); |
|
|
|
|
|
|
|
|
dest.comp_node = RNGDnnOpManager::get_comp_node(handle); |
|
|
} else { |
|
|
} else { |
|
|
cn = inputs[0]->comp_node(); |
|
|
|
|
|
|
|
|
dest.comp_node = inputs[0]->comp_node(); |
|
|
} |
|
|
} |
|
|
constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0; |
|
|
constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0; |
|
|
if (!rng_with_shape) { |
|
|
if (!rng_with_shape) { |
|
|
for (int i = 0; i < inputs.size(); ++i) { |
|
|
for (int i = 0; i < inputs.size(); ++i) { |
|
|
mgb_assert( |
|
|
mgb_assert( |
|
|
inputs[i]->comp_node() == cn, |
|
|
|
|
|
|
|
|
inputs[i]->comp_node() == dest.comp_node, |
|
|
"%s expects the device of inputs[%d] to be same as the device of " |
|
|
"%s expects the device of inputs[%d] to be same as the device of " |
|
|
"handle; " |
|
|
"handle; " |
|
|
"got %s and %s actually", |
|
|
"got %s and %s actually", |
|
|
rng.dyn_typeinfo()->name, i, |
|
|
rng.dyn_typeinfo()->name, i, |
|
|
inputs[i]->comp_node().to_string().c_str(), cn.to_string().c_str()); |
|
|
|
|
|
|
|
|
inputs[i]->comp_node().to_string().c_str(), |
|
|
|
|
|
dest.comp_node.to_string().c_str()); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
return {cn}; |
|
|
|
|
|
|
|
|
dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], rng); |
|
|
|
|
|
return {dest}; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
template <> |
|
|
template <> |
|
|
SmallVector<CompNode> infer_output_cns<ShuffleRNG>( |
|
|
|
|
|
|
|
|
SmallVector<LogicalTensorDesc> infer_output_attrs<ShuffleRNG>( |
|
|
const OpDef& op, const SmallVector<TensorPtr>& inputs) { |
|
|
const OpDef& op, const SmallVector<TensorPtr>& inputs) { |
|
|
SmallVector<CompNode> cns(2); |
|
|
|
|
|
|
|
|
SmallVector<LogicalTensorDesc> dests(2); |
|
|
auto&& rng = op.cast_final_safe<ShuffleRNG>(); |
|
|
auto&& rng = op.cast_final_safe<ShuffleRNG>(); |
|
|
auto handle = rng.handle; |
|
|
auto handle = rng.handle; |
|
|
if (handle) { |
|
|
if (handle) { |
|
|
cns[0] = RNGDnnOpManager::get_comp_node(handle); |
|
|
|
|
|
cns[1] = RNGDnnOpManager::get_comp_node(handle); |
|
|
|
|
|
|
|
|
dests[0].comp_node = RNGDnnOpManager::get_comp_node(handle); |
|
|
|
|
|
dests[1].comp_node = RNGDnnOpManager::get_comp_node(handle); |
|
|
} else { |
|
|
} else { |
|
|
cns[0] = inputs[0]->comp_node(); |
|
|
|
|
|
cns[1] = inputs[0]->comp_node(); |
|
|
|
|
|
|
|
|
dests[0].comp_node = inputs[0]->comp_node(); |
|
|
|
|
|
dests[1].comp_node = inputs[0]->comp_node(); |
|
|
} |
|
|
} |
|
|
return cns; |
|
|
|
|
|
|
|
|
dests[0].layout = TensorLayout(inputs[0]->layout()); |
|
|
|
|
|
dests[0].layout.dtype = inputs[0]->layout().dtype; |
|
|
|
|
|
dests[1].layout = |
|
|
|
|
|
TensorLayout(TensorShape({inputs[0]->layout()[0]}), dtype::Int32()); |
|
|
|
|
|
return dests; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
template <> |
|
|
template <> |
|
|
SmallVector<CompNode> infer_output_cns<Dropout>( |
|
|
|
|
|
|
|
|
SmallVector<LogicalTensorDesc> infer_output_attrs<Dropout>( |
|
|
const OpDef& op, const SmallVector<TensorPtr>& inputs) { |
|
|
const OpDef& op, const SmallVector<TensorPtr>& inputs) { |
|
|
SmallVector<CompNode> cns(2); |
|
|
|
|
|
|
|
|
SmallVector<LogicalTensorDesc> dests(2); |
|
|
auto&& cn = inputs[0]->comp_node(); |
|
|
auto&& cn = inputs[0]->comp_node(); |
|
|
|
|
|
|
|
|
cns[0] = cn; |
|
|
|
|
|
cns[1] = cn; |
|
|
|
|
|
return cns; |
|
|
|
|
|
|
|
|
dests[0].comp_node = cn; |
|
|
|
|
|
dests[0].layout = TensorLayout(inputs[0]->layout()); |
|
|
|
|
|
dests[0].layout.dtype = inputs[0]->layout().dtype; |
|
|
|
|
|
|
|
|
|
|
|
auto get_mask_size = [&]() -> size_t { |
|
|
|
|
|
auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle(); |
|
|
|
|
|
return dnn_handle->create_operator<megdnn::Dropout>()->get_mask_size_in_bytes( |
|
|
|
|
|
inputs[0]->layout()); |
|
|
|
|
|
}; |
|
|
|
|
|
dests[1].comp_node = cn; |
|
|
|
|
|
dests[1].layout = TensorLayout(TensorShape({get_mask_size()}), dtype::Byte()); |
|
|
|
|
|
return dests; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
template <typename Op> |
|
|
template <typename Op> |
|
@@ -507,11 +523,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor( |
|
|
const OpDef& def, const SmallVector<TensorPtr>& inputs, |
|
|
const OpDef& def, const SmallVector<TensorPtr>& inputs, |
|
|
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { |
|
|
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { |
|
|
SmallVector<TensorPtr> outputs; |
|
|
SmallVector<TensorPtr> outputs; |
|
|
SmallVector<CompNode> cns = infer_output_cns<Op>(def, inputs); |
|
|
|
|
|
for (size_t i = 0; i < cns.size(); i++) { |
|
|
|
|
|
outputs.push_back(Tensor::make(output_descs[i].layout, cns[i])); |
|
|
|
|
|
|
|
|
SmallVector<LogicalTensorDesc> desc = infer_output_attrs<Op>(def, inputs); |
|
|
|
|
|
for (auto&& i : desc) { |
|
|
|
|
|
outputs.push_back(Tensor::make(i.layout, i.comp_node)); |
|
|
} |
|
|
} |
|
|
exec<Op>(def, inputs, outputs); |
|
|
|
|
|
|
|
|
exec<Op>(def, inputs, outputs, {}); |
|
|
return outputs; |
|
|
return outputs; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|