|
|
@@ -0,0 +1,304 @@ |
|
|
|
/** |
|
|
|
* \file imperative/src/impl/ops/custom_opdef.cpp |
|
|
|
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") |
|
|
|
* |
|
|
|
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. |
|
|
|
* |
|
|
|
* Unless required by applicable law or agreed to in writing, |
|
|
|
* software distributed under the License is distributed on an |
|
|
|
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
*/ |
|
|
|
|
|
|
|
#include "megbrain/imperative/ops/custom_opdef.h" |
|
|
|
#include "megbrain/opr/custom_opnode.h" |
|
|
|
#include "megbrain/custom/data_adaptor.h" |
|
|
|
#include "../op_trait.h" |
|
|
|
|
|
|
|
namespace mgb { |
|
|
|
namespace imperative { |
|
|
|
|
|
|
|
// Register CustomOpDef with MegBrain's dynamic-type system so that
// dyn_typeinfo()/typeinfo() comparisons (see CustomOpDefFactory::is_custom_op)
// and cast_final_safe<> work on this OpDef subclass.
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CustomOpDef);
|
|
|
|
|
|
|
// Wrap a registered custom op; the runtime parameter set starts from the
// op's registered defaults (op->param_info()).
CustomOpDef::CustomOpDef(const std::shared_ptr<const custom::CustomOp> &op)
        : m_op(op), m_param(op->param_info()) {}
|
|
|
|
|
|
|
// Wrap a registered custom op with an explicit (caller-provided) parameter
// set instead of the op's registered defaults.
CustomOpDef::CustomOpDef(const std::shared_ptr<const custom::CustomOp> &op,
                         const custom::Param &param)
        : m_op(op), m_param(param) {}
|
|
|
|
|
|
|
void CustomOpDef::param(const custom::Param &rhs) { |
|
|
|
m_param = rhs; |
|
|
|
} |
|
|
|
|
|
|
|
//! Mutable access to this op instance's runtime parameter set.
custom::Param &CustomOpDef::param() {
    return m_param;
}
|
|
|
|
|
|
|
//! Copy of this op instance's runtime parameter set (by-value const overload).
custom::Param CustomOpDef::param() const {
    return m_param;
}
|
|
|
|
|
|
|
//! Number of input tensors declared by the wrapped custom op.
size_t CustomOpDef::input_num() const {
    return m_op->input_num();
}
|
|
|
|
|
|
|
//! Number of output tensors declared by the wrapped custom op.
size_t CustomOpDef::output_num() const {
    return m_op->output_num();
}
|
|
|
|
|
|
|
//! The op's registered type name, used as this OpDef's display name.
std::string CustomOpDef::name() const {
    return m_op->op_type();
}
|
|
|
|
|
|
|
//! Unique runtime identifier of the wrapped custom op (used for hashing/equality).
custom::RunTimeId CustomOpDef::runtime_id() const {
    return m_op->runtime_id();
}
|
|
|
|
|
|
|
//! Borrow the underlying custom-op implementation (shared, immutable).
const std::shared_ptr<const custom::CustomOp> &CustomOpDef::impl() const {
    return m_op;
}
|
|
|
|
|
|
|
void CustomOpDef::compute(const SmallVector<DeviceTensorND> &inputs, |
|
|
|
SmallVector<DeviceTensorND> *outputs) const { |
|
|
|
std::vector<custom::Tensor> custom_inputs = |
|
|
|
custom::to_custom<DeviceTensorND, custom::Tensor>(inputs); |
|
|
|
std::vector<custom::Tensor> custom_outputs = |
|
|
|
custom::to_custom<DeviceTensorND, custom::Tensor>(*outputs); |
|
|
|
m_op->compute(custom_inputs, this->m_param, custom_outputs); |
|
|
|
} |
|
|
|
|
|
|
|
// Infer output descriptors from concrete input tensors by projecting each
// input to a LogicalTensorDesc (comp node + layout) and delegating to the
// desc-based overload.
// Returns {output descs, success}; success is false if any input shape is
// still unknown (ndim == 0) in the delegate.
std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs(
        const SmallVector<TensorPtr> &inputs) const {
    SmallVector<LogicalTensorDesc> input_descs(inputs.size());
    // size_t index: avoids the signed/unsigned comparison of the old `int i`.
    for (size_t i = 0; i < inputs.size(); ++i) {
        input_descs[i].comp_node = inputs[i]->comp_node();
        input_descs[i].layout = inputs[i]->layout();
    }
    // No std::move here: wrapping the returned prvalue in std::move is
    // pessimizing — it blocks copy elision.
    return this->infer_output_attrs(input_descs);
}
|
|
|
|
|
|
|
// Infer output descriptors from symbolic input descriptors.
//
// The inputs are decomposed into per-attribute vectors (device, shape, dtype,
// format) which are adapted to the custom:: view types and fed to the op's
// user-provided inference hooks. Device/dtype/format are always inferred;
// shapes are inferred only when every input shape is known (ndim != 0) —
// otherwise `success` is false and output shapes are left empty.
//
// Returns {output descs, success}.
std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs(
        const SmallVector<LogicalTensorDesc> &inputs) const {
    size_t nr_inputs = inputs.size();
    SmallVector<CompNode> i_devices(nr_inputs);
    SmallVector<TensorShape> i_shapes(nr_inputs);
    SmallVector<megdnn::DType> i_dtypes(nr_inputs);
    SmallVector<TensorFormat> i_formats(nr_inputs);

    // size_t index: the old `int i` compared signed against size_t.
    for (size_t i = 0; i < nr_inputs; ++i) {
        i_devices[i] = inputs[i].comp_node;
        i_shapes[i] = inputs[i].layout; // TensorLayout is derived from TensorShape
        i_dtypes[i] = inputs[i].layout.dtype;
        i_formats[i] = inputs[i].layout.format;
    }

    // Shape inference is only possible when all input shapes are known.
    // Take each shape by const ref (the old loop copied a TensorShape per
    // iteration) and stop at the first unknown one.
    bool success = true;
    for (const auto &i_shape : i_shapes) {
        if (i_shape.ndim == 0) {
            success = false;
            break;
        }
    }

    SmallVector<CompNode> o_devices;
    SmallVector<megdnn::DType> o_dtypes;
    SmallVector<TensorFormat> o_formats;
    SmallVector<TensorShape> o_shapes;

    o_devices = custom::to_builtin<CompNode, custom::Device>(
            m_op->infer_output_device(
                    custom::to_custom<CompNode, custom::Device>(i_devices),
                    this->m_param));
    o_dtypes = custom::to_builtin<megdnn::DType, custom::DType>(
            m_op->infer_output_dtype(
                    custom::to_custom<megdnn::DType, custom::DType>(i_dtypes),
                    this->m_param));
    o_formats = custom::to_builtin<TensorFormat, custom::Format>(
            m_op->infer_output_format(
                    custom::to_custom<TensorFormat, custom::Format>(i_formats),
                    this->m_param));

    if (success) {
        o_shapes = custom::to_builtin<TensorShape, custom::Shape>(
                m_op->infer_output_shape(
                        custom::to_custom<TensorShape, custom::Shape>(i_shapes),
                        this->m_param));
    } else {
        // Unknown input shapes: emit empty (ndim == 0) output shapes.
        o_shapes = SmallVector<TensorShape>(this->output_num());
    }

    SmallVector<LogicalTensorDesc> outputs(this->output_num());
    for (size_t i = 0; i < this->output_num(); ++i) {
        outputs[i].comp_node = o_devices[i];
        // Plain prvalue assignment; the old std::move(TensorLayout(...))
        // was a no-op wrapper around an already-movable temporary.
        outputs[i].layout = TensorLayout(o_shapes[i], o_dtypes[i], o_formats[i]);
    }
    return std::tuple<SmallVector<LogicalTensorDesc>, bool>(outputs, success);
}
|
|
|
|
|
|
|
//! Process-wide singleton accessor (Meyers singleton, thread-safe init).
CustomOpDefFactory *CustomOpDefFactory::inst() {
    static CustomOpDefFactory singleton;
    return &singleton;
}
|
|
|
|
|
|
|
bool CustomOpDefFactory::is_custom_op(const OpDef &op) { |
|
|
|
return op.dyn_typeinfo() == CustomOpDef::typeinfo(); |
|
|
|
} |
|
|
|
|
|
|
|
//! Bind the factory to the global custom-op registry.
CustomOpDefFactory::CustomOpDefFactory()
        : ops(custom::CustomOpManager::inst()) {}
|
|
|
|
|
|
|
//! Names of every custom op currently registered with the manager.
std::vector<std::string> CustomOpDefFactory::op_list() const {
    return ops->op_name_list();
}
|
|
|
|
|
|
|
//! Build a CustomOpDef (default params) for the op registered under \p op_type.
std::shared_ptr<OpDef> CustomOpDefFactory::create_opdef(const std::string &op_type) const {
    return std::make_shared<CustomOpDef>(ops->find(op_type));
}
|
|
|
|
|
|
|
//! Build a CustomOpDef (default params) for the op with runtime id \p op_id.
std::shared_ptr<OpDef> CustomOpDefFactory::create_opdef(const custom::RunTimeId &op_id) const {
    return std::make_shared<CustomOpDef>(ops->find(op_id));
}
|
|
|
|
|
|
|
//! Build a CustomOpDef with explicit params for the op registered under \p op_type.
std::shared_ptr<OpDef> CustomOpDefFactory::create_opdef(const std::string &op_type, const custom::Param &param) const {
    return std::make_shared<CustomOpDef>(ops->find(op_type), param);
}
|
|
|
|
|
|
|
//! Build a CustomOpDef with explicit params for the op with runtime id \p op_id.
std::shared_ptr<OpDef> CustomOpDefFactory::create_opdef(const custom::RunTimeId &op_id, const custom::Param &param) const {
    return std::make_shared<CustomOpDef>(ops->find(op_id), param);
}
|
|
|
|
|
|
|
namespace custom_opdef { // avoid name conflict |
|
|
|
|
|
|
|
// Execute a CustomOpDef's kernel on concrete device tensors.
//
// Activates each output's comp node, then does a global sync before and
// after the kernel. NOTE(review): the commented-out per-output cn.sync()
// ("cannot sync") suggests per-node sync was attempted and failed for some
// backend, so CompNode::sync_all() is used as a conservative whole-device
// barrier — confirm before narrowing either sync.
void apply_on_device_tensornd(const OpDef& def,
        const SmallVector<DeviceTensorND>& inputs,
        SmallVector<DeviceTensorND>* outputs) {
    // Make sure every output's comp node is active before launching.
    for (auto &&output: (*outputs)) {
        auto cn = output.comp_node();
        cn.activate();
    }
    CompNode::sync_all();
    auto&& op = static_cast<const CustomOpDef&>(def);
    op.compute(inputs, outputs);

    // for (auto &&output: (*outputs)) {
    //     auto cn = output.comp_node();
    //     cn.sync(); // cannot sync ??????????
    // }
    CompNode::sync_all();
}
|
|
|
|
|
|
|
// Allocate output tensors from inferred descriptors and run the custom op's
// kernel on the inputs' device tensors.
//
// Asserts that full attribute inference succeeds — with concrete input
// tensors every shape is known, so failure here is a bug in the op's
// inference hooks.
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr> &inputs) {
    auto&& op = static_cast<const CustomOpDef&>(def);
    auto [output_descs, success] = op.infer_output_attrs(inputs);
    // Fixed message typo ("fall" -> "failed"); compare bool directly.
    mgb_assert(success, "infer output attributes failed\n");
    SmallVector<TensorPtr> outputs(output_descs.size());

    for (size_t i = 0; i < outputs.size(); ++i) {
        auto& output_desc = output_descs[i];
        outputs[i] = Tensor::make(output_desc.layout, output_desc.comp_node);
    }

    // Bridge TensorPtr -> DeviceTensorND for the device-level kernel entry.
    SmallVector<DeviceTensorND> inp_tensornds(inputs.size());
    SmallVector<DeviceTensorND> oup_tensornds(outputs.size());

    for (size_t i = 0; i < inputs.size(); ++i)
        inp_tensornds[i] = inputs[i]->dev_tensor();
    for (size_t i = 0; i < outputs.size(); ++i)
        oup_tensornds[i] = outputs[i]->dev_tensor();

    apply_on_device_tensornd(def, inp_tensornds, &oup_tensornds);
    return outputs;
}
|
|
|
|
|
|
|
// Lower a CustomOpDef onto the computing graph: wrap input VarNodes as
// SymbolVars, create a CustomOpNode, and unwrap the resulting SymbolVars
// back into VarNodes.
VarNodeArray apply_on_var_node(const OpDef &def, const cg::VarNodeArray &inputs) {
    auto&& op = static_cast<const CustomOpDef&>(def);

    SymbolVarArray input_syms;
    for (auto &var_node: inputs)
        input_syms.emplace_back(var_node);

    OperatorNodeConfig config;
    SymbolVarArray output_syms = opr::CustomOpNode::make(
            op.impl(), input_syms, op.param(), config);

    VarNodeArray outputs;
    for (auto &sym: output_syms)
        outputs.push_back(sym.node());
    return outputs;
}
|
|
|
|
|
|
|
// Trait hook: forward fallible attribute inference to the op instance.
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    const auto& op = static_cast<const CustomOpDef&>(def);
    return op.infer_output_attrs(inputs);
}
|
|
|
|
|
|
|
// Trait hook: memory-descriptor inference is not supported for custom ops;
// returning two empty lists signals the framework to fall back to the
// default allocation path. All parameters are intentionally unused.
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
        const OpDef& def,
        const SmallVector<TensorPtr>& inputs_tensors,
        const SmallVector<MemoryDesc>& inputs_mems) {
    return {{}, {}};
}
|
|
|
|
|
|
|
// Trait hook: hash a CustomOpDef by its op identity (runtime id) combined
// with a serialization of its parameter key/value pairs, so two defs with
// the same op but different params hash differently.
size_t hash(const OpDef& def) {
    auto&& op = static_cast<const CustomOpDef&>(def);
    const custom::Param &param = op.param();
    size_t val = mgb::hash(op.runtime_id());
    std::string hash_str;  // default-constructed; "= \"\"" was redundant
    // Renamed loop variable: the old `val` shadowed the hash accumulator.
    for (auto &&kv : param.raw()) {
        hash_str += kv.first;
        hash_str += kv.second.str();
    }

    val = mgb::hash_pair_combine(val, mgb::hash(hash_str));
    return val;
}
|
|
|
|
|
|
|
bool is_same_st(const OpDef& lhs, const OpDef& rhs) { |
|
|
|
auto &&a = static_cast<const CustomOpDef&>(lhs), |
|
|
|
&&b = static_cast<const CustomOpDef&>(rhs); |
|
|
|
return a.param() == b.param() && a.runtime_id() == b.runtime_id(); |
|
|
|
} |
|
|
|
|
|
|
|
// Trait hook: property listing for CustomOpDef is not implemented; this
// always aborts via mgb_assert(false). The unreachable empty vector below
// only satisfies the signature.
std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
    mgb_assert(false, "Custom OpDef Props Function is not IMPLEMENTED now");
    // can be implement with param schema
    // auto&& custom_opdef = def.cast_final_safe<CustomOpDef>();
    std::vector<std::pair<const char*, std::string>> props_;
    return props_;
}
|
|
|
|
|
|
|
// Trait hook: a CustomOpDef's display name is its registered op type name.
std::string make_name(const OpDef& def) {
    return static_cast<const CustomOpDef&>(def).name();
}
|
|
|
|
|
|
|
} // custom_opdef |
|
|
|
|
|
|
|
// Register CustomOpDef's trait table: wires the hooks defined in the
// custom_opdef namespace above into the imperative OpDef dispatch machinery.
// .fallback() fills any hook not listed here with the default implementation.
OP_TRAIT_REG(CustomOpDef, CustomOpDef)
    .apply_on_physical_tensor(imperative::custom_opdef::apply_on_physical_tensor)
    .apply_on_var_node(imperative::custom_opdef::apply_on_var_node)
    .apply_on_device_tensornd(imperative::custom_opdef::apply_on_device_tensornd)
    .infer_output_attrs_fallible(imperative::custom_opdef::infer_output_attrs_fallible)
    .infer_output_mem_desc(imperative::custom_opdef::infer_output_mem_desc)
    .hash(imperative::custom_opdef::hash)
    .is_same_st(imperative::custom_opdef::is_same_st)
    .props(imperative::custom_opdef::props)
    .make_name(imperative::custom_opdef::make_name)
    .fallback();
|
|
|
|
|
|
|
} // imperative |
|
|
|
} // mgb |