
fix(imperative/amp): fix format transformation for symbol trans

GitOrigin-RevId: 96cc237c67
Branch: release-1.10
Author: Megvii Engine Team, 3 years ago
Parent commit: d313f92610
7 changed files with 47 additions and 24 deletions:

  1. imperative/python/megengine/autodiff/grad_manager.py (+0 / -2)
  2. imperative/python/megengine/functional/tensor.py (+3 / -1)
  3. imperative/python/src/transformation.h (+1 / -1)
  4. imperative/python/test/unit/amp/test_convert_format.py (+19 / -15)
  5. imperative/python/test/unit/core/test_formatted_tensor.py (+5 / -1)
  6. imperative/src/impl/transformations/format.cpp (+18 / -4)
  7. imperative/src/include/megbrain/imperative/transformations/format.h (+1 / -0)
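
Note: "symbol trans" refers to the symbol transformation that carries symbolic
shapes while tracing. A minimal sketch of the scenario this commit fixes is
below; it is illustrative only (the shapes and the `format=` constructor
argument follow this commit's tests, not a documented repro):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor
    from megengine.jit import trace

    # An NHWC-formatted tensor routed through a symbolic trace: x.shape is
    # then a symbolic value (SymbolValue), not a plain tuple, and it reaches
    # FormatTransformation, which previously assumed every input was a
    # FormattedTensorValue.
    x = tensor(np.random.rand(1, 2, 3, 4).astype("float32"), format="nhwc")

    @trace(symbolic=True)
    def func(x):
        return F.broadcast_to(tensor(1.0, dtype="float32"), x.shape)

    out = func(x)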

imperative/python/megengine/autodiff/grad_manager.py (+0 / -2)

@@ -260,7 +260,6 @@ class GradManager:
         push_scope("backward")
         set_option("record_computing_path", 0)
         _origin_auto_format = get_auto_format_convert()
-        set_auto_format_convert(False)
         from ..functional import ones_like

         global backwarding_grad_manager
@@ -304,7 +303,6 @@ class GradManager:
         self.release()
         backwarding_grad_manager = cache
         set_option("record_computing_path", 1)
-        set_auto_format_convert(_origin_auto_format)
         pop_scope("backward")

     def record(self):
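
Note: with the SymbolValue bypass in format.cpp (below), backward no longer
needs auto format conversion switched off around it. A hedged usage sketch,
assuming the `format=` constructor argument and the Tensor.grad attribute
behave as in this commit's tests:

    import numpy as np
    from megengine import tensor
    from megengine.autodiff import GradManager

    x = tensor(np.random.rand(1, 2, 3, 4).astype("float32"), format="nhwc")
    gm = GradManager().attach([x])
    with gm.record():
        y = (x * 2.0).sum()
        gm.backward(y)  # runs with auto format conversion left enabled
    print(x.grad)  # gradient computed through FormatTransformation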


imperative/python/megengine/functional/tensor.py (+3 / -1)

@@ -274,7 +274,9 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
         return x

     # set x's format to use FormatTransformation rule for Broadcast.
-    return broadcast_to(x, inp.shape)
+    rst = broadcast_to(x, inp.shape)
+    rst.format = inp.format
+    return rst


 def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
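
Note: previously full_like returned broadcast_to(x, inp.shape) directly, so
the result could lose the input's format. A short sketch of the intended
behavior after the fix (illustrative; the NHWC input construction is assumed
from this commit's tests):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    inp = tensor(np.zeros((1, 2, 3, 4), dtype="float32"), format="nhwc")
    out = F.full_like(inp, 0.5)
    # rst.format = inp.format stamps the broadcast result, so:
    assert out.format == "nhwc"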


imperative/python/src/transformation.h (+1 / -1)

@@ -26,7 +26,7 @@ public:
         Eval,
     };

-    std::array<std::vector<std::shared_ptr<Transformation>>, 8> segments;
+    std::array<std::vector<std::shared_ptr<Transformation>>, 9> segments;

 private:
     template <Segment segment>


imperative/python/test/unit/amp/test_convert_format.py (+19 / -15)

@@ -12,6 +12,7 @@ import megengine.functional as F
 import megengine.module as M
 from megengine import Parameter, Tensor, amp
 from megengine.core._config import set_auto_format_convert
+from megengine.core._trace_option import use_symbolic_shape


 class MyModule(M.Module):
@@ -41,22 +42,25 @@ class MyModule(M.Module):
 def test_convert_module(is_inplace):
     m = MyModule()
     expected_shape = {
-        "i.bn.weight": (1, 1, 1, 4),
-        "i.bn.bias": (1, 1, 1, 4),
-        "i.bn.running_mean": (1, 1, 1, 4),
-        "i.bn.running_var": (1, 1, 1, 4),
-        "conv.weight": (2, 2, 4, 4, 2),
-        "conv.bias": (1, 1, 1, 4),
-        "bn.weight": (1, 1, 1, 4),
-        "bn.bias": (1, 1, 1, 4),
-        "bn.running_mean": (1, 1, 1, 4),
-        "bn.running_var": (1, 1, 1, 4),
-        "param": (1, 1, 1, 3),
-        "buff": (1, 1, 1, 3),
+        "i.bn.weight": (1, 4, 1, 1),
+        "i.bn.bias": (1, 4, 1, 1),
+        "i.bn.running_mean": (1, 4, 1, 1),
+        "i.bn.running_var": (1, 4, 1, 1),
+        "conv.weight": (2, 2, 2, 4, 4),
+        "conv.bias": (1, 4, 1, 1),
+        "bn.weight": (1, 4, 1, 1),
+        "bn.bias": (1, 4, 1, 1),
+        "bn.running_mean": (1, 4, 1, 1),
+        "bn.running_var": (1, 4, 1, 1),
+        "param": (1, 3, 1, 1),
+        "buff": (1, 3, 1, 1),
     }
     m = amp.convert_module_format(m, is_inplace)
     for name, param in m.named_tensors():
         assert param.format == "nhwc"
-        set_auto_format_convert(False)
-        assert param.shape == expected_shape[name], name
-        set_auto_format_convert(True)
+        if use_symbolic_shape():
+            np.testing.assert_array_equal(
+                param.shape.numpy(), expected_shape[name], name
+            )
+        else:
+            assert param.shape == expected_shape[name], name
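
Note: the expected shapes change from the physical NHWC layout (read with auto
format conversion disabled) to the logical NCHW view reported with conversion
left on. Under use_symbolic_shape(), Tensor.shape is itself a tensor, so the
test materializes it with .numpy() before comparing. A minimal sketch of the
two modes (assuming set_symbolic_shape() returns the previous setting, as in
megengine.core._trace_option):

    import numpy as np
    from megengine import tensor
    from megengine.core._trace_option import set_symbolic_shape

    x = tensor(np.zeros((1, 4, 1, 1), dtype="float32"))
    old = set_symbolic_shape(True)
    np.testing.assert_array_equal(x.shape.numpy(), (1, 4, 1, 1))
    set_symbolic_shape(old)
    assert x.shape == (1, 4, 1, 1)  # eager shape is a plain tuple again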

imperative/python/test/unit/core/test_formatted_tensor.py (+5 / -1)

@@ -6,6 +6,7 @@ import megengine.functional as F
 import megengine.module as M
 from megengine import tensor
 from megengine.autodiff import GradManager
+from megengine.core._trace_option import use_symbolic_shape
 from megengine.jit import trace


@@ -121,7 +122,10 @@ def test_repeat(is_symbolic):
 @pytest.mark.parametrize("is_symbolic", [None])
 def test_getshape(is_symbolic):
     def func(x):
-        return x.shape
+        if use_symbolic_shape():
+            return x.shape.numpy()
+        else:
+            return x.shape

     data = np.arange(0, 24).reshape((1, 2, 3, 4))
     _compare_nchw_nhwc(data, func, is_symbolic)


imperative/src/impl/transformations/format.cpp (+18 / -4)

@@ -1,5 +1,6 @@
 #include "megbrain/imperative/transformations/format.h"
 #include "megbrain/imperative/transformations/grad.h"
+#include "megbrain/imperative/transformations/symbol.h"

 #include "megbrain/imperative/ops/autogen.h"
 #include "megbrain/imperative/ops/utility.h"
@@ -75,6 +76,17 @@ inline ValueRefList FormatTransformation::wrap_outputs(
     }
     return wrapped_outputs;
 }
+
+inline bool FormatTransformation::check_all_format_value(
+        const Span<ValueRef>& inputs) const {
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        if (!inputs[i].as_ref(m_value_type)) {
+            return false;
+        }
+    }
+    return true;
+}
+
 namespace {

 ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) {
@@ -369,7 +381,8 @@ inline ValueRefList unify_inputs_format(
     for (size_t i = 0; i < inputs.size(); ++i) {
         auto&& inp = inputs[i].cast(t.value_type());
         if (inp.format() != dst_fmt &&
-            inp.value().shape().cast<ShapeValue>().ndim == 4) {
+            (inp.value().shape().cast<ShapeValue>().ndim == 4 ||
+             inp.value().shape().cast<ShapeValue>().ndim == 5)) {
             unified_inputs[i] = t.to(inp, dst_fmt, scope);
         } else {
             unified_inputs[i] = inputs[i];
@@ -568,6 +581,10 @@ struct FormatRuleRegistry {
 ValueRefList FormatTransformation::apply_transformation(
         const Operator& op, Span<ValueRef> inputs) {
     if (auto* apply_op = op.as<ApplyOp>()) {
+        // bypass SymbolValue
+        if (!check_all_format_value(inputs)) {
+            return imperative::apply(op, unwrap_inputs(inputs));
+        }
         // all inputs should be FormattedTensorValue
         auto iter = format_rules.find(apply_op->op().dyn_typeinfo());
         if (iter != format_rules.end()) {
@@ -628,9 +645,6 @@ ValueRefList FormatTransformation::apply_transformation(
         auto&& format = inp_ref->format();
         return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format);
     } else {
-        mgb_log_warn(
-                "Not FormattedTensorValue input for IdentityLike op: %s, %s",
-                op.to_string().c_str(), inputs[0].to_string().c_str());
         return imperative::apply(op, inputs);
     }
 } else if (op.is<AttachGrad>()) {
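
Note: the new check_all_format_value() guard is the core of the fix: ApplyOp
inputs that are not FormattedTensorValue (e.g. a SymbolValue carrying a traced
shape) now bypass the format rules instead of failing the cast. The ndim
check in unify_inputs_format also admits 5-d inputs, presumably grouped-conv
weights (cf. the 5-d conv.weight shape in the updated test). A hedged Python
model of the guard (the real implementation is the C++ above; all names here
are loose stand-ins):

    class FormattedTensorValue:
        def __init__(self, value, format):
            self.value, self.format = value, format

    class SymbolValue:  # stands in for a symbolic shape produced while tracing
        def __init__(self, value):
            self.value = value

    def unwrap(v):
        return v.value if isinstance(v, FormattedTensorValue) else v

    def apply_format_transformation(op, inputs, fallthrough):
        # bypass SymbolValue: only run the format rules when *every* input
        # is a FormattedTensorValue, mirroring check_all_format_value()
        if not all(isinstance(i, FormattedTensorValue) for i in inputs):
            return fallthrough(op, [unwrap(i) for i in inputs])
        # ... format rules / unify_inputs_format would run here ...

    # A Reshape whose target shape is symbolic now skips the rules:
    apply_format_transformation(
        "Reshape",
        [FormattedTensorValue([0.0] * 24, "nhwc"), SymbolValue((4, 6))],
        fallthrough=lambda op, ins: (op, ins),
    )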


imperative/src/include/megbrain/imperative/transformations/format.h (+1 / -0)

@@ -70,6 +70,7 @@ public:
             const ValueRef& output, Format format = Format::Type::DEFAULT) const;
     inline ValueRefList wrap_outputs(
             const ValueRefList& outputs, Format format = Format::Type::DEFAULT) const;
+    inline bool check_all_format_value(const Span<ValueRef>& inputs) const;

     TypedValueRef<FormattedTensorValue> as(
             const FormattedTensorValue&, const Format::Type& target) const;

