Browse Source

fix(imperative): fix error message when applying custom function with non-tensor arguments

GitOrigin-RevId: 387d6fda4a
tags/v1.6.0-rc1
Megvii Engine Team 3 years ago
parent
commit
18274e023c
3 changed files with 12 additions and 0 deletions
  1. +5
    -0
      imperative/python/megengine/core/autodiff/grad.py
  2. +6
    -0
      imperative/src/impl/op_trait.cpp
  3. +1
    -0
      imperative/src/impl/ops/utility.cpp

+ 5
- 0
imperative/python/megengine/core/autodiff/grad.py View File

@@ -123,6 +123,11 @@ class Function(ops.PyOpBase):


This method should return a tuple of Tensor or a single Tensor representing the output This method should return a tuple of Tensor or a single Tensor representing the output
of the function. of the function.

.. note::

positional arguments should all be Tensor

""" """
raise NotImplementedError raise NotImplementedError




+ 6
- 0
imperative/src/impl/op_trait.cpp View File

@@ -98,6 +98,12 @@ OpTraitRegistry& OpTraitRegistry::fallback() {
if (!trait->decide_dispatch_mode) { if (!trait->decide_dispatch_mode) {
trait->decide_dispatch_mode = fallback_decide_dispatch_mode; trait->decide_dispatch_mode = fallback_decide_dispatch_mode;
} }
if (!trait->make_name) {
static auto make_name = [](const OpDef& def) -> std::string {
return def.trait()->name;
};
trait->make_name = make_name;
}
return *this; return *this;
} }




+ 1
- 0
imperative/src/impl/ops/utility.cpp View File

@@ -18,6 +18,7 @@
namespace mgb::imperative { namespace mgb::imperative {


MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp); MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp);
OP_TRAIT_REG(GenericPyOp, GenericPyOp).fallback();


namespace { namespace fastpathcopy { namespace { namespace fastpathcopy {
auto apply_on_var_node( auto apply_on_var_node(


Loading…
Cancel
Save