|
|
@@ -95,6 +95,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs |
|
|
|
for (auto i_shape: i_shapes) { |
|
|
|
if (i_shape.ndim == 0) { |
|
|
|
success = false; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
@@ -187,14 +188,11 @@ void apply_on_device_tensornd(const OpDef& def, |
|
|
|
auto cn = output.comp_node(); |
|
|
|
cn.activate(); |
|
|
|
} |
|
|
|
|
|
|
|
// [TODO] sync should be modified |
|
|
|
CompNode::sync_all(); |
|
|
|
auto&& op = static_cast<const CustomOpDef&>(def); |
|
|
|
op.compute(inputs, outputs); |
|
|
|
|
|
|
|
// for (auto &&output: (*outputs)) { |
|
|
|
// auto cn = output.comp_node(); |
|
|
|
// cn.sync(); // cannot sync ?????????? |
|
|
|
// } |
|
|
|
CompNode::sync_all(); |
|
|
|
} |
|
|
|
|
|
|
@@ -224,19 +222,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor( |
|
|
|
} |
|
|
|
|
|
|
|
VarNodeArray apply_on_var_node(const OpDef &def, const cg::VarNodeArray &inputs) { |
|
|
|
SymbolVarArray input_syms; |
|
|
|
for (auto &input_var: inputs) |
|
|
|
input_syms.emplace_back(input_var); |
|
|
|
|
|
|
|
auto&& op = static_cast<const CustomOpDef&>(def); |
|
|
|
OperatorNodeConfig config; |
|
|
|
SymbolVarArray output_syms = opr::CustomOpNode::make( |
|
|
|
op.impl(), input_syms, op.param(), config |
|
|
|
VarNodeArray outputs = opr::CustomOpNode::make( |
|
|
|
op.impl(), inputs, op.param(), config |
|
|
|
); |
|
|
|
|
|
|
|
VarNodeArray outputs; |
|
|
|
for (auto &output_sym: output_syms) |
|
|
|
outputs.push_back(output_sym.node()); |
|
|
|
return outputs; |
|
|
|
} |
|
|
|
|
|
|
@@ -273,6 +263,7 @@ bool is_same_st(const OpDef& lhs, const OpDef& rhs) { |
|
|
|
return a.param() == b.param() && a.runtime_id() == b.runtime_id(); |
|
|
|
} |
|
|
|
|
|
|
|
// [TODO] to be implemented |
|
|
|
std::vector<std::pair<const char*, std::string>> props(const OpDef& def) { |
|
|
|
mgb_assert(false, "Custom OpDef Props Function is not IMPLEMENTED now"); |
|
|
|
// can be implement with param schema |
|
|
|