GitOrigin-RevId: e82e5de480
release-1.6
@@ -7,24 +7,19 @@ | |||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
from ..._imperative_rt.ops import _custom | |||||
from .._imperative_rt.ops._custom import _install, _uninstall, _get_custom_op_list, _make_custom_op | |||||
__all__ = [] | |||||
__all__ = ["load"] | |||||
for k, v in _custom.__dict__.items(): | |||||
globals()[k] = v | |||||
__all__.append(k) | |||||
def gen_custom_op_maker(custom_op_name): | |||||
def _gen_custom_op_maker(custom_op_name): | |||||
def op_maker(**kwargs): | def op_maker(**kwargs): | ||||
return make_custom_op(custom_op_name, kwargs) | |||||
return _make_custom_op(custom_op_name, kwargs) | |||||
return op_maker | return op_maker | ||||
def load(lib_path): | def load(lib_path): | ||||
op_in_this_lib = install(lib_path[0:-3], lib_path) | |||||
op_in_this_lib = _install(lib_path[0:-3], lib_path) | |||||
for op in op_in_this_lib: | for op in op_in_this_lib: | ||||
op_maker = gen_custom_op_maker(op) | |||||
op_maker = _gen_custom_op_maker(op) | |||||
globals()[op] = op_maker | globals()[op] = op_maker | ||||
__all__.append(op) | __all__.append(op) |
@@ -684,7 +684,7 @@ py::list install_custom(const std::string &name, const std::string &path) { | |||||
for (const auto &op: ops_in_lib) { | for (const auto &op: ops_in_lib) { | ||||
ret.append(op); | ret.append(op); | ||||
} | } | ||||
return std::move(ret); | |||||
return ret; | |||||
} | } | ||||
bool uninstall_custom(const std::string &name) { | bool uninstall_custom(const std::string &name) { | ||||
@@ -701,12 +701,12 @@ py::list get_custom_op_list(void) { | |||||
} | } | ||||
void init_custom(pybind11::module m) { | void init_custom(pybind11::module m) { | ||||
m.def("install", &install_custom); | |||||
m.def("uninstall", &uninstall_custom); | |||||
m.def("get_custom_op_list", &get_custom_op_list); | |||||
m.def("_install", &install_custom); | |||||
m.def("_uninstall", &uninstall_custom); | |||||
m.def("_get_custom_op_list", &get_custom_op_list); | |||||
static PyMethodDef method_def = { | static PyMethodDef method_def = { | ||||
"make_custom_op", (PyCFunction)make_custom_op, METH_FASTCALL, "" | |||||
"_make_custom_op", (PyCFunction)make_custom_op, METH_FASTCALL, "" | |||||
}; | }; | ||||
auto* func = PyCFunction_NewEx(&method_def, nullptr, nullptr); | auto* func = PyCFunction_NewEx(&method_def, nullptr, nullptr); | ||||
pybind11::setattr(m, method_def.ml_name, func); | pybind11::setattr(m, method_def.ml_name, func); | ||||
@@ -286,19 +286,19 @@ std::string make_name(const OpDef& def) { | |||||
return op.name(); | return op.name(); | ||||
} | } | ||||
} // custom_opdef | |||||
OP_TRAIT_REG(CustomOpDef, CustomOpDef) | OP_TRAIT_REG(CustomOpDef, CustomOpDef) | ||||
.apply_on_physical_tensor(imperative::custom_opdef::apply_on_physical_tensor) | |||||
.apply_on_var_node(imperative::custom_opdef::apply_on_var_node) | |||||
.apply_on_device_tensornd(imperative::custom_opdef::apply_on_device_tensornd) | |||||
.infer_output_attrs_fallible(imperative::custom_opdef::infer_output_attrs_fallible) | |||||
.infer_output_mem_desc(imperative::custom_opdef::infer_output_mem_desc) | |||||
.hash(imperative::custom_opdef::hash) | |||||
.is_same_st(imperative::custom_opdef::is_same_st) | |||||
.props(imperative::custom_opdef::props) | |||||
.make_name(imperative::custom_opdef::make_name) | |||||
.apply_on_physical_tensor(apply_on_physical_tensor) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.apply_on_device_tensornd(apply_on_device_tensornd) | |||||
.infer_output_attrs_fallible(infer_output_attrs_fallible) | |||||
.infer_output_mem_desc(infer_output_mem_desc) | |||||
.hash(hash) | |||||
.is_same_st(is_same_st) | |||||
.props(props) | |||||
.make_name(make_name) | |||||
.fallback(); | .fallback(); | ||||
} // custom_opdef | |||||
} // imperative | } // imperative | ||||
} // mgb | } // mgb |
@@ -60,18 +60,5 @@ public: | |||||
std::shared_ptr<OpDef> create_opdef(const custom::RunTimeId&, const custom::Param&) const; | std::shared_ptr<OpDef> create_opdef(const custom::RunTimeId&, const custom::Param&) const; | ||||
}; | }; | ||||
namespace custom_opdef { // avoid name conflict | |||||
void apply_on_device_tensornd(const OpDef&, const SmallVector<DeviceTensorND>&, SmallVector<DeviceTensorND>*); | |||||
SmallVector<TensorPtr> apply_on_physical_tensor(const OpDef&, const SmallVector<TensorPtr>&); | |||||
VarNodeArray apply_on_var_node(const OpDef&, const cg::VarNodeArray&); | |||||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef&, const SmallVector<LogicalTensorDesc>&); | |||||
size_t hash(const OpDef&); | |||||
bool is_same_st(const OpDef&, const OpDef&); | |||||
std::vector<std::pair<const char*, std::string>> props(const OpDef&); | |||||
std::string make_name(const OpDef&); | |||||
} // custom_opdef | |||||
} // imperative | } // imperative | ||||
} // mgb | } // mgb |
@@ -214,11 +214,6 @@ void CustomOpNode::on_output_comp_node_stream_changed() { | |||||
} | } | ||||
cg::OperatorNodeBase::NodeProp* CustomOpNode::do_make_node_prop() const { | cg::OperatorNodeBase::NodeProp* CustomOpNode::do_make_node_prop() const { | ||||
// auto ret = &const_cast<OperatorNodeBase::NodeProp&>(node_prop()); | |||||
// for (auto &&inp_var: input()) | |||||
// ret->add_dep_type(inp_var, NodeProp::DepType::DEV_VALUE); | |||||
// ret->add_flag(NodeProp::Flag::SINGLE_COMP_NODE); | |||||
// return ret; | |||||
return OperatorNodeBase::do_make_node_prop(); | return OperatorNodeBase::do_make_node_prop(); | ||||
} | } | ||||