GitOrigin-RevId: e295a1fa55
release-1.6
@@ -7,13 +7,20 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from .._imperative_rt.ops._custom import _install, _uninstall, _get_custom_op_list, _make_custom_op
+from .._imperative_rt.ops._custom import (
+    _get_custom_op_list,
+    _install,
+    _make_custom_op,
+    _uninstall,
+)
__all__ = ["load"] | __all__ = ["load"] | ||||
def _gen_custom_op_maker(custom_op_name): | def _gen_custom_op_maker(custom_op_name): | ||||
def op_maker(**kwargs): | def op_maker(**kwargs): | ||||
return _make_custom_op(custom_op_name, kwargs) | return _make_custom_op(custom_op_name, kwargs) | ||||
return op_maker | return op_maker | ||||
@@ -95,6 +95,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs
     for (auto i_shape: i_shapes) {
         if (i_shape.ndim == 0) {
             success = false;
+            break;
         }
     }
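The added `break` stops scanning as soon as one input shape is still unknown. A minimal standalone sketch of the same early-exit check, using a simplified `Shape` struct standing in for MegEngine's shape descriptor (the names here are illustrative, not the real API):

    #include <cstddef>
    #include <vector>

    // Simplified stand-in for the real shape type: ndim == 0 means the
    // shape is not yet known (e.g. still pending shape inference).
    struct Shape {
        size_t ndim = 0;
    };

    // Returns false as soon as any input shape is unknown; the `break`
    // added in the diff gives the same early exit instead of scanning
    // the remaining shapes after `success` is already false.
    bool all_shapes_known(const std::vector<Shape>& shapes) {
        for (const auto& s : shapes) {
            if (s.ndim == 0) {
                return false;  // equivalent to `success = false; break;`
            }
        }
        return true;
    }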
@@ -187,14 +188,11 @@ void apply_on_device_tensornd(const OpDef& def,
         auto cn = output.comp_node();
         cn.activate();
     }
+    // [TODO] sync should be modified
    CompNode::sync_all();
     auto&& op = static_cast<const CustomOpDef&>(def);
     op.compute(inputs, outputs);
-    // for (auto &&output: (*outputs)) {
-    //     auto cn = output.comp_node();
-    //     cn.sync();    // cannot sync ??????????
-    // }
     CompNode::sync_all();
 }
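The removed comment records a per-output `cn.sync()` attempt that did not work, and the new `[TODO]` keeps the coarse `CompNode::sync_all()` for now. One narrower shape such a fix could take, sketched with a hypothetical `FakeCompNode` stand-in (nothing below is MegEngine API): wait only on the comp nodes that actually carry an output, each at most once.

    #include <set>
    #include <vector>

    // Hypothetical stand-in for a device queue handle; the real
    // CompNode exposes a comparable sync() member.
    struct FakeCompNode {
        int id = 0;
        void sync() const { /* wait for this queue only */ }
        bool operator<(const FakeCompNode& rhs) const { return id < rhs.id; }
    };

    // Sketch of the narrower synchronization the [TODO] hints at: wait
    // only on the comp nodes involved in the outputs, deduplicated,
    // instead of a global CompNode::sync_all().
    void sync_output_nodes(const std::vector<FakeCompNode>& output_nodes) {
        std::set<FakeCompNode> seen(output_nodes.begin(), output_nodes.end());
        for (const auto& cn : seen)
            cn.sync();
    }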
@@ -224,19 +222,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
 }
 VarNodeArray apply_on_var_node(const OpDef &def, const cg::VarNodeArray &inputs) {
-    SymbolVarArray input_syms;
-    for (auto &input_var: inputs)
-        input_syms.emplace_back(input_var);
     auto&& op = static_cast<const CustomOpDef&>(def);
     OperatorNodeConfig config;
-    SymbolVarArray output_syms = opr::CustomOpNode::make(
-        op.impl(), input_syms, op.param(), config
+    VarNodeArray outputs = opr::CustomOpNode::make(
+        op.impl(), inputs, op.param(), config
     );
-    VarNodeArray outputs;
-    for (auto &output_sym: output_syms)
-        outputs.push_back(output_sym.node());
     return outputs;
 }
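This hunk relies on `opr::CustomOpNode::make` now accepting and returning `VarNodeArray` directly, so the `SymbolVarArray` round trip disappears. For context, a simplified sketch of the wrap/unwrap boilerplate being deleted, using minimal mock-up types (`VarNode` and `SymbolVar` below are stand-ins, not the real graph classes):

    #include <vector>

    // SymbolVar is essentially a thin wrapper around a VarNode*, so
    // converting in both directions is pure boilerplate.
    struct VarNode {};
    struct SymbolVar {
        VarNode* m_node = nullptr;
        explicit SymbolVar(VarNode* n) : m_node(n) {}
        VarNode* node() const { return m_node; }
    };

    // The shape of the code the diff deletes: wrap every VarNode* into a
    // SymbolVar before the call, then unwrap every result afterwards.
    std::vector<VarNode*> round_trip(const std::vector<VarNode*>& inputs) {
        std::vector<SymbolVar> input_syms;
        for (auto* v : inputs)
            input_syms.emplace_back(v);
        // (a real call taking/returning SymbolVars would go here; we just
        // unwrap the inputs to show the second half of the round trip)
        std::vector<VarNode*> outputs;
        for (auto& s : input_syms)
            outputs.push_back(s.node());
        return outputs;
    }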
@@ -273,6 +263,7 @@ bool is_same_st(const OpDef& lhs, const OpDef& rhs) {
     return a.param() == b.param() && a.runtime_id() == b.runtime_id();
 }
+// [TODO] to be implemented
 std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
     mgb_assert(false, "Custom OpDef Props Function is not IMPLEMENTED now");
     // can be implemented with param schema
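The context comment suggests `props` could be derived from the param schema. A hedged sketch of that idea, with a hypothetical flat `FakeParam` map in place of the real schema type (the real param representation is richer than a string map):

    #include <map>
    #include <string>
    #include <utility>
    #include <vector>

    // Hypothetical minimal param type: a flat map from attribute name to
    // a string-serialized value, loosely imitating a param schema.
    using FakeParam = std::map<std::string, std::string>;

    // Sketch of the props() the comment hints at: walk the schema and
    // turn each entry into a (name, value) pair for graph dumping.
    std::vector<std::pair<std::string, std::string>> props_from_param(
            const FakeParam& param) {
        std::vector<std::pair<std::string, std::string>> result;
        for (const auto& [name, value] : param)
            result.emplace_back(name, value);
        return result;
    }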
@@ -140,7 +140,8 @@ void CustomOpNode::do_execute(ExecEnv &env) {
     std::vector<custom::Tensor> custom_inputs = custom::to_custom<DeviceTensorND, custom::Tensor>(inputs);
     std::vector<custom::Tensor> custom_outputs = custom::to_custom<DeviceTensorND, custom::Tensor>(outputs);
     m_op->compute(custom_inputs, m_param, custom_outputs);
-    CompNode::sync_all(); // whether reasonable
+    // [TODO] sync should be modified
+    CompNode::sync_all();
     this->owner_graph()->event().signal_inplace<cg::event::AfterKernel>(
         this, m_comp_node
@@ -157,7 +158,8 @@ void CustomOpNode::init_output_static_infer_desc() {
     auto &&mgr = owner_graph()->static_infer_manager();
     DepVal dep;
-    if (true) { // need design a function to allow user to decide it
+    // [TODO] need to design an interface to let the user decide this
+    if (true) {
         for (auto input_var: input())
             dep.push_back({input_var, DepType::SHAPE});
     }
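The hard-coded `if (true)` registers a SHAPE dependency on every input; the `[TODO]` asks for a way to let the op author choose. A sketch of what that choice could look like, with simplified stand-ins (`Dep`, `collect_deps`, and the `shape_only` flag are all hypothetical, not existing MegEngine API):

    #include <vector>

    // Simplified stand-ins for the static-infer types.
    struct VarNode {};
    enum class DepType { SHAPE, VALUE };
    struct Dep { VarNode* var; DepType type; };

    // Instead of the hard-coded `if (true)`, let the op author declare
    // whether output shapes can be inferred from input shapes alone
    // (shape_only == true) or also need input values.
    std::vector<Dep> collect_deps(const std::vector<VarNode*>& inputs,
                                  bool shape_only) {
        std::vector<Dep> deps;
        for (auto* v : inputs)
            deps.push_back({v, shape_only ? DepType::SHAPE : DepType::VALUE});
        return deps;
    }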