GitOrigin-RevId: d7877f2e32
release-1.6
@@ -0,0 +1,30 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
from ..._imperative_rt.ops import _custom | |||||
__all__ = [] | |||||
for k, v in _custom.__dict__.items(): | |||||
globals()[k] = v | |||||
__all__.append(k) | |||||
def gen_custom_op_maker(custom_op_name): | |||||
def op_maker(**kwargs): | |||||
return make_custom_op(custom_op_name, kwargs) | |||||
return op_maker | |||||
def load(lib_path): | |||||
op_in_this_lib = install(lib_path[0:-3], lib_path) | |||||
for op in op_in_this_lib: | |||||
op_maker = gen_custom_op_maker(op) | |||||
globals()[op] = op_maker | |||||
__all__.append(op) |
@@ -13,6 +13,7 @@ from collections import OrderedDict | |||||
import numpy as np | import numpy as np | ||||
import megengine as mge | import megengine as mge | ||||
from megengine.core.ops import custom | |||||
from megengine.core.tensor import megbrain_graph as G | from megengine.core.tensor import megbrain_graph as G | ||||
from megengine.device import get_device_count, set_default_device | from megengine.device import get_device_count, set_default_device | ||||
from megengine.functional.debug_param import set_execution_strategy | from megengine.functional.debug_param import set_execution_strategy | ||||
@@ -397,6 +398,10 @@ def main(): | |||||
type=str, | type=str, | ||||
help="Record the static graph's static memory info.", | help="Record the static graph's static memory info.", | ||||
) | ) | ||||
parser.add_argument( | |||||
"--custom-op-lib", type=str, help="path of the custom op", | |||||
) | |||||
args = parser.parse_args() | args = parser.parse_args() | ||||
if args.verbose: | if args.verbose: | ||||
@@ -409,6 +414,8 @@ def main(): | |||||
if args.dump_cpp_model: | if args.dump_cpp_model: | ||||
args.embed_input = True | args.embed_input = True | ||||
if args.custom_op_lib is not None: | |||||
custom.load(args.custom_op_lib) | |||||
logger.info("loading model ...") | logger.info("loading model ...") | ||||
ret = G.load_graph(args.net) | ret = G.load_graph(args.net) | ||||
@@ -607,4 +607,107 @@ void init_ops(py::module m) { | |||||
.def("compile", [](PySubgraphBuilder& self, int gopt_level){ | .def("compile", [](PySubgraphBuilder& self, int gopt_level){ | ||||
return (std::shared_ptr<OpDef>)CompiledOp::make(self.build(), gopt_level); | return (std::shared_ptr<OpDef>)CompiledOp::make(self.build(), gopt_level); | ||||
}); | }); | ||||
auto custom = submodule(m, "_custom"); | |||||
init_custom(custom); | |||||
} | |||||
#define CUSTOM_CASE_TO_PARSE_NON_LIST(dyn_type, static_type) \ | |||||
case mgb::custom::ParamDynType::dyn_type: { \ | |||||
param_val = py::handle(kv.second).cast<static_type>(); \ | |||||
break; \ | |||||
} | |||||
#define CUSTOM_CASE_TO_PARSE_LIST(dyn_type, static_type) \ | |||||
case mgb::custom::ParamDynType::dyn_type: { \ | |||||
auto pyvals = py::handle(kv.second).cast<py::list>(); \ | |||||
static_type vals; \ | |||||
using basic_type = \ | |||||
mgb::custom::get_vector_template_arg_type<static_type>::type; \ | |||||
for (auto &pyval: pyvals) { \ | |||||
vals.push_back(py::handle(pyval).cast<basic_type>()); \ | |||||
} \ | |||||
param_val = vals; \ | |||||
break; \ | |||||
} | |||||
PyObject *make_custom_op(PyObject *self, PyObject **args, Py_ssize_t nargs, PyObject *kwnames) { | |||||
auto op_name = py::handle(args[0]).cast<std::string>(); | |||||
auto kwargs = py::handle(args[1]).cast<py::dict>(); | |||||
std::shared_ptr<OpDef> opdef = CustomOpDefFactory::inst()->create_opdef(op_name); | |||||
auto &custom_opdef = static_cast<mgb::imperative::CustomOpDef&>(*opdef); | |||||
auto ¶m = custom_opdef.param(); | |||||
for (auto &&kv: kwargs) { | |||||
std::string param_name = py::handle(kv.first).cast<std::string>(); | |||||
std::string type_name = py::handle(kv.second).ptr()->ob_type->tp_name; | |||||
if (!param.exist(param_name)) { | |||||
mgb_log_warn( | |||||
"op %s have no param named %s, ignore this param parsed from python", | |||||
op_name.c_str(), param_name.c_str() | |||||
); | |||||
continue; | |||||
} | |||||
auto& param_val = param[param_name]; | |||||
switch (param_val.type()) { | |||||
CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_PARSE_NON_LIST) | |||||
CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_PARSE_NON_LIST) | |||||
CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST) | |||||
CUSTOM_FOR_BOOL_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST) | |||||
CUSTOM_FOR_STRING_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST) | |||||
default: { | |||||
mgb_assert( | |||||
false, "param dtype of %s:%s is invalid", | |||||
op_name.c_str(), param_name.c_str() | |||||
); | |||||
} | |||||
} | |||||
} | |||||
PyTypeObject* pytype; | |||||
pytype = &PyOpType(OpDef); | |||||
PyObject* obj = pytype->tp_alloc(pytype, 0); | |||||
reinterpret_cast<PyOp(OpDef)*>(obj)->op = opdef; | |||||
return obj; | |||||
} | |||||
#undef CUSTOM_CASE_TO_PARSE_LIST | |||||
#undef CUSTOM_CASE_TO_PARSE_NON_LIST | |||||
py::list install_custom(const std::string &name, const std::string &path) { | |||||
py::list ret; | |||||
const auto &ops_in_lib = mgb::custom::LibManager::inst()->install(name, path); | |||||
for (const auto &op: ops_in_lib) { | |||||
ret.append(op); | |||||
} | |||||
return std::move(ret); | |||||
} | |||||
bool uninstall_custom(const std::string &name) { | |||||
return mgb::custom::LibManager::inst()->uninstall(name); | |||||
} | |||||
py::list get_custom_op_list(void) { | |||||
std::vector<std::string> all_ops = CustomOpDefFactory::inst()->op_list(); | |||||
py::list ret; | |||||
for (auto &op: all_ops) { | |||||
ret.append(op); | |||||
} | |||||
return std::move(ret); | |||||
} | |||||
void init_custom(pybind11::module m) { | |||||
m.def("install", &install_custom); | |||||
m.def("uninstall", &uninstall_custom); | |||||
m.def("get_custom_op_list", &get_custom_op_list); | |||||
static PyMethodDef method_def = { | |||||
"make_custom_op", (PyCFunction)make_custom_op, METH_FASTCALL, "" | |||||
}; | |||||
auto* func = PyCFunction_NewEx(&method_def, nullptr, nullptr); | |||||
pybind11::setattr(m, method_def.ml_name, func); | |||||
} | } |
@@ -16,6 +16,7 @@ | |||||
#include "megdnn/opr_param_defs.h" | #include "megdnn/opr_param_defs.h" | ||||
#include "megbrain/opr/param_defs.h" | #include "megbrain/opr/param_defs.h" | ||||
#include "megbrain/imperative/ops/custom_opdef.h" | |||||
namespace PYBIND11_NAMESPACE { | namespace PYBIND11_NAMESPACE { | ||||
namespace detail { | namespace detail { | ||||
@@ -35,3 +36,4 @@ FOR_EACH_BIT_COMBINED_ENUM_PARAM(ENUM_CASTER_DEF) | |||||
} // PYBIND11_NAMESPACE | } // PYBIND11_NAMESPACE | ||||
void init_ops(pybind11::module m); | void init_ops(pybind11::module m); | ||||
void init_custom(pybind11::module m); |
@@ -0,0 +1,304 @@ | |||||
/** | |||||
* \file imperative/src/impl/ops/custom_opdef.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megbrain/imperative/ops/custom_opdef.h" | |||||
#include "megbrain/opr/custom_opnode.h" | |||||
#include "megbrain/custom/data_adaptor.h" | |||||
#include "../op_trait.h" | |||||
namespace mgb { | |||||
namespace imperative { | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CustomOpDef); | |||||
CustomOpDef::CustomOpDef(const std::shared_ptr<const custom::CustomOp> &op) | |||||
: m_op(op), m_param(op->param_info()) {} | |||||
CustomOpDef::CustomOpDef(const std::shared_ptr<const custom::CustomOp> &op, | |||||
const custom::Param ¶m) | |||||
: m_op(op), m_param(param) {} | |||||
void CustomOpDef::param(const custom::Param &rhs) { | |||||
m_param = rhs; | |||||
} | |||||
custom::Param &CustomOpDef::param(void) { | |||||
return m_param; | |||||
} | |||||
custom::Param CustomOpDef::param(void) const { | |||||
return m_param; | |||||
} | |||||
size_t CustomOpDef::input_num(void) const { | |||||
return m_op->input_num(); | |||||
} | |||||
size_t CustomOpDef::output_num(void) const { | |||||
return m_op->output_num(); | |||||
} | |||||
std::string CustomOpDef::name(void) const { | |||||
return m_op->op_type(); | |||||
} | |||||
custom::RunTimeId CustomOpDef::runtime_id(void) const { | |||||
return m_op->runtime_id(); | |||||
} | |||||
const std::shared_ptr<const custom::CustomOp> &CustomOpDef::impl(void) const { | |||||
return m_op; | |||||
} | |||||
void CustomOpDef::compute(const SmallVector<DeviceTensorND> &inputs, | |||||
SmallVector<DeviceTensorND> *outputs) const { | |||||
std::vector<custom::Tensor> custom_inputs = | |||||
custom::to_custom<DeviceTensorND, custom::Tensor>(inputs); | |||||
std::vector<custom::Tensor> custom_outputs = | |||||
custom::to_custom<DeviceTensorND, custom::Tensor>(*outputs); | |||||
m_op->compute(custom_inputs, this->m_param, custom_outputs); | |||||
} | |||||
std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs( | |||||
const SmallVector<TensorPtr> &inputs) const { | |||||
SmallVector<LogicalTensorDesc> input_descs(inputs.size()); | |||||
for (int i=0; i<inputs.size(); i++) { | |||||
input_descs[i].comp_node = inputs[i]->comp_node(); | |||||
input_descs[i].layout = inputs[i]->layout(); | |||||
} | |||||
return std::move(this->infer_output_attrs(input_descs)); | |||||
} | |||||
std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs( | |||||
const SmallVector<LogicalTensorDesc> &inputs) const { | |||||
SmallVector<CompNode> i_devices(inputs.size()); | |||||
SmallVector<TensorShape> i_shapes(inputs.size()); | |||||
SmallVector<megdnn::DType> i_dtypes(inputs.size()); | |||||
SmallVector<TensorFormat> i_formats(inputs.size()); | |||||
for (int i=0; i<inputs.size(); i++) { | |||||
i_devices[i] = inputs[i].comp_node; | |||||
i_shapes[i] = inputs[i].layout; // TensorLayout is derived from TensorShape | |||||
i_dtypes[i] = inputs[i].layout.dtype; | |||||
i_formats[i] = inputs[i].layout.format; | |||||
} | |||||
bool success = true; | |||||
for (auto i_shape: i_shapes) { | |||||
if (i_shape.ndim == 0) { | |||||
success = false; | |||||
} | |||||
} | |||||
SmallVector<CompNode> o_devices; | |||||
SmallVector<megdnn::DType> o_dtypes; | |||||
SmallVector<TensorFormat> o_formats; | |||||
SmallVector<TensorShape> o_shapes; | |||||
o_devices = custom::to_builtin<CompNode, custom::Device>( | |||||
m_op->infer_output_device( | |||||
custom::to_custom<CompNode, custom::Device>(i_devices), this->m_param | |||||
) | |||||
); | |||||
o_dtypes = custom::to_builtin<megdnn::DType, custom::DType>( | |||||
m_op->infer_output_dtype( | |||||
custom::to_custom<megdnn::DType, custom::DType>(i_dtypes), this->m_param | |||||
) | |||||
); | |||||
o_formats = custom::to_builtin<TensorFormat, custom::Format>( | |||||
m_op->infer_output_format( | |||||
custom::to_custom<TensorFormat, custom::Format>(i_formats), this->m_param | |||||
) | |||||
); | |||||
if (success) { | |||||
o_shapes = custom::to_builtin<TensorShape, custom::Shape>( | |||||
m_op->infer_output_shape( | |||||
custom::to_custom<TensorShape, custom::Shape>(i_shapes), this->m_param | |||||
) | |||||
); | |||||
} | |||||
else { | |||||
o_shapes = SmallVector<TensorShape>(this->output_num()); | |||||
} | |||||
SmallVector<LogicalTensorDesc> outputs(this->output_num()); | |||||
for (int i=0; i<this->output_num(); i++) { | |||||
outputs[i].comp_node = std::move(o_devices[i]); | |||||
outputs[i].layout = std::move( | |||||
TensorLayout(o_shapes[i], o_dtypes[i], o_formats[i]) | |||||
); | |||||
} | |||||
return std::tuple<SmallVector<LogicalTensorDesc>, bool>(outputs, success); | |||||
} | |||||
CustomOpDefFactory *CustomOpDefFactory::inst(void) { | |||||
static CustomOpDefFactory factory; | |||||
return &factory; | |||||
} | |||||
bool CustomOpDefFactory::is_custom_op(const OpDef &op) { | |||||
return op.dyn_typeinfo() == CustomOpDef::typeinfo(); | |||||
} | |||||
CustomOpDefFactory::CustomOpDefFactory() { | |||||
ops = custom::CustomOpManager::inst(); | |||||
} | |||||
std::vector<std::string> CustomOpDefFactory::op_list(void) const { | |||||
return ops->op_name_list(); | |||||
} | |||||
std::shared_ptr<OpDef> CustomOpDefFactory::create_opdef(const std::string &op_type) const { | |||||
auto op = ops->find(op_type); | |||||
return std::make_shared<CustomOpDef>(op); | |||||
} | |||||
std::shared_ptr<OpDef> CustomOpDefFactory::create_opdef(const custom::RunTimeId &op_id) const { | |||||
auto op = ops->find(op_id); | |||||
return std::make_shared<CustomOpDef>(op); | |||||
} | |||||
std::shared_ptr<OpDef> CustomOpDefFactory::create_opdef(const std::string &op_type, const custom::Param ¶m) const { | |||||
auto op = ops->find(op_type); | |||||
return std::make_shared<CustomOpDef>(op, param); | |||||
} | |||||
std::shared_ptr<OpDef> CustomOpDefFactory::create_opdef(const custom::RunTimeId &op_id, const custom::Param ¶m) const { | |||||
auto op = ops->find(op_id); | |||||
return std::make_shared<CustomOpDef>(op, param); | |||||
} | |||||
namespace custom_opdef { // avoid name conflict | |||||
void apply_on_device_tensornd(const OpDef& def, | |||||
const SmallVector<DeviceTensorND>& inputs, | |||||
SmallVector<DeviceTensorND>* outputs) { | |||||
for (auto &&output: (*outputs)) { | |||||
auto cn = output.comp_node(); | |||||
cn.activate(); | |||||
} | |||||
CompNode::sync_all(); | |||||
auto&& op = static_cast<const CustomOpDef&>(def); | |||||
op.compute(inputs, outputs); | |||||
// for (auto &&output: (*outputs)) { | |||||
// auto cn = output.comp_node(); | |||||
// cn.sync(); // cannot sync ?????????? | |||||
// } | |||||
CompNode::sync_all(); | |||||
} | |||||
SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
const OpDef& def, const SmallVector<TensorPtr> &inputs) { | |||||
auto&& op = static_cast<const CustomOpDef&>(def); | |||||
auto [output_descs, success] = op.infer_output_attrs(inputs); | |||||
mgb_assert(success == true, "infer output attributes fall\n"); | |||||
SmallVector<TensorPtr> outputs(output_descs.size()); | |||||
for (size_t i=0; i<outputs.size(); ++i) { | |||||
auto& output = outputs[i]; | |||||
auto& output_desc = output_descs[i]; | |||||
output = Tensor::make(output_desc.layout, output_desc.comp_node); | |||||
} | |||||
SmallVector<DeviceTensorND> inp_tensornds(inputs.size()); | |||||
SmallVector<DeviceTensorND> oup_tensornds(outputs.size()); | |||||
for (size_t i = 0; i < inputs.size(); ++i) | |||||
inp_tensornds[i] = inputs[i]->dev_tensor(); | |||||
for (size_t i = 0; i < outputs.size(); ++i) | |||||
oup_tensornds[i] = outputs[i]->dev_tensor(); | |||||
apply_on_device_tensornd(def, inp_tensornds, &oup_tensornds); | |||||
return outputs; | |||||
} | |||||
VarNodeArray apply_on_var_node(const OpDef &def, const cg::VarNodeArray &inputs) { | |||||
SymbolVarArray input_syms; | |||||
for (auto &input_var: inputs) | |||||
input_syms.emplace_back(input_var); | |||||
auto&& op = static_cast<const CustomOpDef&>(def); | |||||
OperatorNodeConfig config; | |||||
SymbolVarArray output_syms = opr::CustomOpNode::make( | |||||
op.impl(), input_syms, op.param(), config | |||||
); | |||||
VarNodeArray outputs; | |||||
for (auto &output_sym: output_syms) | |||||
outputs.push_back(output_sym.node()); | |||||
return outputs; | |||||
} | |||||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||||
auto&& op = static_cast<const CustomOpDef&>(def); | |||||
return op.infer_output_attrs(inputs); | |||||
} | |||||
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc( | |||||
const OpDef& def, | |||||
const SmallVector<TensorPtr>& inputs_tensors, | |||||
const SmallVector<MemoryDesc>& inputs_mems) { | |||||
return {{}, {}}; | |||||
} | |||||
size_t hash(const OpDef& def) { | |||||
auto&& op = static_cast<const CustomOpDef&>(def); | |||||
const custom::Param ¶m = op.param(); | |||||
size_t val = mgb::hash(op.runtime_id()); | |||||
std::string hash_str = ""; | |||||
for (auto &&val: param.raw()) { | |||||
hash_str += val.first; | |||||
hash_str += val.second.str(); | |||||
} | |||||
val = mgb::hash_pair_combine(val, mgb::hash(hash_str)); | |||||
return val; | |||||
} | |||||
bool is_same_st(const OpDef& lhs, const OpDef& rhs) { | |||||
auto &&a = static_cast<const CustomOpDef&>(lhs), | |||||
&&b = static_cast<const CustomOpDef&>(rhs); | |||||
return a.param() == b.param() && a.runtime_id() == b.runtime_id(); | |||||
} | |||||
std::vector<std::pair<const char*, std::string>> props(const OpDef& def) { | |||||
mgb_assert(false, "Custom OpDef Props Function is not IMPLEMENTED now"); | |||||
// can be implement with param schema | |||||
// auto&& custom_opdef = def.cast_final_safe<CustomOpDef>(); | |||||
std::vector<std::pair<const char*, std::string>> props_; | |||||
return props_; | |||||
} | |||||
std::string make_name(const OpDef& def) { | |||||
auto&& op = static_cast<const CustomOpDef&>(def); | |||||
return op.name(); | |||||
} | |||||
} // custom_opdef | |||||
OP_TRAIT_REG(CustomOpDef, CustomOpDef) | |||||
.apply_on_physical_tensor(imperative::custom_opdef::apply_on_physical_tensor) | |||||
.apply_on_var_node(imperative::custom_opdef::apply_on_var_node) | |||||
.apply_on_device_tensornd(imperative::custom_opdef::apply_on_device_tensornd) | |||||
.infer_output_attrs_fallible(imperative::custom_opdef::infer_output_attrs_fallible) | |||||
.infer_output_mem_desc(imperative::custom_opdef::infer_output_mem_desc) | |||||
.hash(imperative::custom_opdef::hash) | |||||
.is_same_st(imperative::custom_opdef::is_same_st) | |||||
.props(imperative::custom_opdef::props) | |||||
.make_name(imperative::custom_opdef::make_name) | |||||
.fallback(); | |||||
} // imperative | |||||
} // mgb |
@@ -0,0 +1,77 @@ | |||||
/** | |||||
* \file imperative/src/include/megbrain/imperative/ops/custom_opdef.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megbrain/custom/custom.h" | |||||
#include "megbrain/custom/manager.h" | |||||
#include "megbrain/imperative/op_def.h" | |||||
namespace mgb { | |||||
namespace imperative { | |||||
class CustomOpDef: public OpDefImplBase<CustomOpDef> { | |||||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||||
const std::shared_ptr<const custom::CustomOp> m_op; | |||||
custom::Param m_param; | |||||
public: | |||||
CustomOpDef(const std::shared_ptr<const custom::CustomOp> &op); | |||||
CustomOpDef(const std::shared_ptr<const custom::CustomOp> &op, | |||||
const custom::Param&); | |||||
void param(const custom::Param&); | |||||
custom::Param ¶m(void); | |||||
custom::Param param(void) const; | |||||
size_t input_num(void) const; | |||||
size_t output_num(void) const; | |||||
std::string name(void) const; | |||||
custom::RunTimeId runtime_id(void) const; | |||||
const std::shared_ptr<const custom::CustomOp> &impl(void) const; | |||||
void compute(const SmallVector<DeviceTensorND>&, SmallVector<DeviceTensorND>*) const; | |||||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs( | |||||
const SmallVector<TensorPtr> &inputs) const; | |||||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs( | |||||
const SmallVector<LogicalTensorDesc>&) const; | |||||
}; | |||||
class CustomOpDefFactory { | |||||
custom::CustomOpManager *ops; | |||||
CustomOpDefFactory(); | |||||
public: | |||||
PREVENT_COPY_AND_ASSIGN(CustomOpDefFactory); | |||||
static CustomOpDefFactory *inst(void); | |||||
static bool is_custom_op(const OpDef &op); | |||||
std::vector<std::string> op_list(void) const; | |||||
std::shared_ptr<OpDef> create_opdef(const std::string&) const; | |||||
std::shared_ptr<OpDef> create_opdef(const custom::RunTimeId&) const; | |||||
std::shared_ptr<OpDef> create_opdef(const std::string&, const custom::Param&) const; | |||||
std::shared_ptr<OpDef> create_opdef(const custom::RunTimeId&, const custom::Param&) const; | |||||
}; | |||||
namespace custom_opdef { // avoid name conflict | |||||
void apply_on_device_tensornd(const OpDef&, const SmallVector<DeviceTensorND>&, SmallVector<DeviceTensorND>*); | |||||
SmallVector<TensorPtr> apply_on_physical_tensor(const OpDef&, const SmallVector<TensorPtr>&); | |||||
VarNodeArray apply_on_var_node(const OpDef&, const cg::VarNodeArray&); | |||||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef&, const SmallVector<LogicalTensorDesc>&); | |||||
size_t hash(const OpDef&); | |||||
bool is_same_st(const OpDef&, const OpDef&); | |||||
std::vector<std::pair<const char*, std::string>> props(const OpDef&); | |||||
std::string make_name(const OpDef&); | |||||
} // custom_opdef | |||||
} // imperative | |||||
} // mgb |