Browse Source

refactor(imperative): add TODO tag for some functions

GitOrigin-RevId: e295a1fa55
release-1.6
Megvii Engine Team 3 years ago
parent
commit
cdb692d2fa
3 changed files with 18 additions and 18 deletions
  1. +8
    -1
      imperative/python/megengine/core/ops/custom.py
  2. +6
    -15
      imperative/src/impl/ops/custom_opdef.cpp
  3. +4
    -2
      src/opr/impl/custom_opnode.cpp

+ 8
- 1
imperative/python/megengine/core/ops/custom.py View File

@@ -7,13 +7,20 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from .._imperative_rt.ops._custom import _install, _uninstall, _get_custom_op_list, _make_custom_op
from .._imperative_rt.ops._custom import (
_get_custom_op_list,
_install,
_make_custom_op,
_uninstall,
)

__all__ = ["load"]


def _gen_custom_op_maker(custom_op_name):
def op_maker(**kwargs):
return _make_custom_op(custom_op_name, kwargs)

return op_maker




+ 6
- 15
imperative/src/impl/ops/custom_opdef.cpp View File

@@ -95,6 +95,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs
for (auto i_shape: i_shapes) {
if (i_shape.ndim == 0) {
success = false;
break;
}
}

@@ -187,14 +188,11 @@ void apply_on_device_tensornd(const OpDef& def,
auto cn = output.comp_node();
cn.activate();
}

// [TODO] sync should be modified
CompNode::sync_all();
auto&& op = static_cast<const CustomOpDef&>(def);
op.compute(inputs, outputs);

// for (auto &&output: (*outputs)) {
// auto cn = output.comp_node();
// cn.sync(); // cannot sync ??????????
// }
CompNode::sync_all();
}

@@ -224,19 +222,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
}

VarNodeArray apply_on_var_node(const OpDef &def, const cg::VarNodeArray &inputs) {
SymbolVarArray input_syms;
for (auto &input_var: inputs)
input_syms.emplace_back(input_var);
auto&& op = static_cast<const CustomOpDef&>(def);
OperatorNodeConfig config;
SymbolVarArray output_syms = opr::CustomOpNode::make(
op.impl(), input_syms, op.param(), config
VarNodeArray outputs = opr::CustomOpNode::make(
op.impl(), inputs, op.param(), config
);

VarNodeArray outputs;
for (auto &output_sym: output_syms)
outputs.push_back(output_sym.node());
return outputs;
}

@@ -273,6 +263,7 @@ bool is_same_st(const OpDef& lhs, const OpDef& rhs) {
return a.param() == b.param() && a.runtime_id() == b.runtime_id();
}

// [TODO] to be implemented
std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
mgb_assert(false, "Custom OpDef Props Function is not IMPLEMENTED now");
// can be implement with param schema


+ 4
- 2
src/opr/impl/custom_opnode.cpp View File

@@ -140,7 +140,8 @@ void CustomOpNode::do_execute(ExecEnv &env) {
std::vector<custom::Tensor> custom_inputs = custom::to_custom<DeviceTensorND, custom::Tensor>(inputs);
std::vector<custom::Tensor> custom_outputs = custom::to_custom<DeviceTensorND, custom::Tensor>(outputs);
m_op->compute(custom_inputs, m_param, custom_outputs);
CompNode::sync_all(); // whether reasonable
// [TODO] sync should be modified
CompNode::sync_all();

this->owner_graph()->event().signal_inplace<cg::event::AfterKernel>(
this, m_comp_node
@@ -157,7 +158,8 @@ void CustomOpNode::init_output_static_infer_desc() {
auto &&mgr = owner_graph()->static_infer_manager();

DepVal dep;
if (true) { // need design a function to allow user to decide it
// [TODO] need design a interface to allow user to decide it
if (true) {
for (auto input_var: input())
dep.push_back({input_var, DepType::SHAPE});
}


Loading…
Cancel
Save