You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

misc.cpp 2.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. #include "../dnn_op_helper.h"
  2. #include "../op_trait.h"
  3. #include "megbrain/imperative/ops/autogen.h"
  4. #include "megbrain/opr/misc.h"
  5. namespace mgb {
  6. namespace imperative {
  7. namespace check_non_finite {
  8. SymbolVarArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  9. auto&& op = def.cast_final_safe<CheckNonFinite>();
  10. OperatorNodeConfig config{op.make_name()};
  11. return opr::CheckNonFinite::make(inputs, op.param(), config);
  12. }
  13. SmallVector<TensorPtr> apply_on_physical_tensor(
  14. const OpDef& def, const SmallVector<TensorPtr>& inputs,
  15. SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
  16. auto&& op = def.cast_final_safe<CheckNonFinite>();
  17. auto comp_node = inputs[0]->comp_node();
  18. auto dest = Tensor::make(TensorLayout({1}, dtype::Int32()), comp_node);
  19. SmallVector<TensorPtr> outputs;
  20. outputs.reserve(inputs.size() + 1);
  21. for (auto&& input : inputs) {
  22. outputs.push_back(Tensor::make(input->layout(), comp_node));
  23. outputs.back()->dev_tensor().copy_from_fixlayout(input->dev_tensor());
  24. }
  25. DnnOprCaller<megdnn::CheckNonFinite> dnn_opr(comp_node, {op.scale});
  26. dnn_opr.exec_with_ws(outputs, dest);
  27. outputs.push_back(dest);
  28. return outputs;
  29. }
  30. std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
  31. const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
  32. size_t size = inputs.size();
  33. SmallVector<LogicalTensorDesc> dests(size + 1);
  34. bool validated = true;
  35. for (size_t i = 0; i < size; ++i) {
  36. dests[i].comp_node = inputs[i].comp_node;
  37. dests[i].layout = inputs[i].layout;
  38. validated &= bool(dests[i].layout.ndim);
  39. }
  40. dests[size].comp_node = inputs[0].comp_node;
  41. dests[size].layout = TensorLayout({1}, dtype::Int32());
  42. return {dests, validated};
  43. }
  // Register the three CheckNonFinite implementations defined in this
  // namespace with the op-trait registry:
  // - graph lowering (apply_on_var_node);
  // - eager execution (apply_on_physical_tensor);
  // - descriptor inference (infer_output_attrs_fallible).
  // NOTE(review): `.fallback()` presumably fills in default implementations
  // for the traits not set here — confirm against op_trait.h. It is kept as
  // the last call in the chain.
  44. OP_TRAIT_REG(CheckNonFinite, CheckNonFinite)
  45. .apply_on_var_node(apply_on_var_node)
  46. .apply_on_physical_tensor(apply_on_physical_tensor)
  47. .infer_output_attrs_fallible(infer_output_attrs_fallible)
  48. .fallback();
  49. } // namespace check_non_finite
  50. } // namespace imperative
  51. } // namespace mgb
  52. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}