Browse Source

refactor(imperative): modify the python interface of custom op

GitOrigin-RevId: e82e5de480
release-1.6
Megvii Engine Team 3 years ago
parent
commit
90dd07161c
5 changed files with 22 additions and 45 deletions
  1. +6
    -11
      imperative/python/megengine/core/ops/custom.py
  2. +5
    -5
      imperative/python/src/ops.cpp
  3. +11
    -11
      imperative/src/impl/ops/custom_opdef.cpp
  4. +0
    -13
      imperative/src/include/megbrain/imperative/ops/custom_opdef.h
  5. +0
    -5
      src/opr/impl/custom_opnode.cpp

imperative/python/megengine/core/ops/custom/__init__.py → imperative/python/megengine/core/ops/custom.py View File

@@ -7,24 +7,19 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from ..._imperative_rt.ops import _custom
from .._imperative_rt.ops._custom import _install, _uninstall, _get_custom_op_list, _make_custom_op

__all__ = []
__all__ = ["load"]

for k, v in _custom.__dict__.items():
globals()[k] = v
__all__.append(k)


def gen_custom_op_maker(custom_op_name):
def _gen_custom_op_maker(custom_op_name):
def op_maker(**kwargs):
return make_custom_op(custom_op_name, kwargs)
return _make_custom_op(custom_op_name, kwargs)
return op_maker


def load(lib_path):
op_in_this_lib = install(lib_path[0:-3], lib_path)
op_in_this_lib = _install(lib_path[0:-3], lib_path)
for op in op_in_this_lib:
op_maker = gen_custom_op_maker(op)
op_maker = _gen_custom_op_maker(op)
globals()[op] = op_maker
__all__.append(op)

+ 5
- 5
imperative/python/src/ops.cpp View File

@@ -684,7 +684,7 @@ py::list install_custom(const std::string &name, const std::string &path) {
for (const auto &op: ops_in_lib) {
ret.append(op);
}
return std::move(ret);
return ret;
}

bool uninstall_custom(const std::string &name) {
@@ -701,12 +701,12 @@ py::list get_custom_op_list(void) {
}

void init_custom(pybind11::module m) {
m.def("install", &install_custom);
m.def("uninstall", &uninstall_custom);
m.def("get_custom_op_list", &get_custom_op_list);
m.def("_install", &install_custom);
m.def("_uninstall", &uninstall_custom);
m.def("_get_custom_op_list", &get_custom_op_list);

static PyMethodDef method_def = {
"make_custom_op", (PyCFunction)make_custom_op, METH_FASTCALL, ""
"_make_custom_op", (PyCFunction)make_custom_op, METH_FASTCALL, ""
};
auto* func = PyCFunction_NewEx(&method_def, nullptr, nullptr);
pybind11::setattr(m, method_def.ml_name, func);


+ 11
- 11
imperative/src/impl/ops/custom_opdef.cpp View File

@@ -286,19 +286,19 @@ std::string make_name(const OpDef& def) {
return op.name();
}

} // custom_opdef

OP_TRAIT_REG(CustomOpDef, CustomOpDef)
.apply_on_physical_tensor(imperative::custom_opdef::apply_on_physical_tensor)
.apply_on_var_node(imperative::custom_opdef::apply_on_var_node)
.apply_on_device_tensornd(imperative::custom_opdef::apply_on_device_tensornd)
.infer_output_attrs_fallible(imperative::custom_opdef::infer_output_attrs_fallible)
.infer_output_mem_desc(imperative::custom_opdef::infer_output_mem_desc)
.hash(imperative::custom_opdef::hash)
.is_same_st(imperative::custom_opdef::is_same_st)
.props(imperative::custom_opdef::props)
.make_name(imperative::custom_opdef::make_name)
.apply_on_physical_tensor(apply_on_physical_tensor)
.apply_on_var_node(apply_on_var_node)
.apply_on_device_tensornd(apply_on_device_tensornd)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.infer_output_mem_desc(infer_output_mem_desc)
.hash(hash)
.is_same_st(is_same_st)
.props(props)
.make_name(make_name)
.fallback();

} // custom_opdef

} // imperative
} // mgb

+ 0
- 13
imperative/src/include/megbrain/imperative/ops/custom_opdef.h View File

@@ -60,18 +60,5 @@ public:
std::shared_ptr<OpDef> create_opdef(const custom::RunTimeId&, const custom::Param&) const;
};

namespace custom_opdef { // avoid name conflict

void apply_on_device_tensornd(const OpDef&, const SmallVector<DeviceTensorND>&, SmallVector<DeviceTensorND>*);
SmallVector<TensorPtr> apply_on_physical_tensor(const OpDef&, const SmallVector<TensorPtr>&);
VarNodeArray apply_on_var_node(const OpDef&, const cg::VarNodeArray&);
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef&, const SmallVector<LogicalTensorDesc>&);
size_t hash(const OpDef&);
bool is_same_st(const OpDef&, const OpDef&);
std::vector<std::pair<const char*, std::string>> props(const OpDef&);
std::string make_name(const OpDef&);

} // custom_opdef

} // imperative
} // mgb

+ 0
- 5
src/opr/impl/custom_opnode.cpp View File

@@ -214,11 +214,6 @@ void CustomOpNode::on_output_comp_node_stream_changed() {
}

cg::OperatorNodeBase::NodeProp* CustomOpNode::do_make_node_prop() const {
// auto ret = &const_cast<OperatorNodeBase::NodeProp&>(node_prop());
// for (auto &&inp_var: input())
// ret->add_dep_type(inp_var, NodeProp::DepType::DEV_VALUE);
// ret->add_flag(NodeProp::Flag::SINGLE_COMP_NODE);
// return ret;
return OperatorNodeBase::do_make_node_prop();
}



Loading…
Cancel
Save