GitOrigin-RevId: 96cc237c67
release-1.10
@@ -260,7 +260,6 @@ class GradManager:
         push_scope("backward")
         set_option("record_computing_path", 0)
         _origin_auto_format = get_auto_format_convert()
-        set_auto_format_convert(False)
         from ..functional import ones_like

         global backwarding_grad_manager
@@ -304,7 +303,6 @@ class GradManager:
            self.release()
            backwarding_grad_manager = cache
            set_option("record_computing_path", 1)
-           set_auto_format_convert(_origin_auto_format)
            pop_scope("backward")

    def record(self):
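The two hunks above drop the save/disable/restore dance around automatic format conversion in `GradManager.backward`, so gradient computation now flows through the same format-transformation machinery as the forward pass. For reference, a minimal sketch of the pattern being removed, using the `megengine.core._config` helpers the surrounding code already imports:

```python
from megengine.core._config import (
    get_auto_format_convert,
    set_auto_format_convert,
)

# The removed pattern: snapshot the flag, force conversion off for the
# duration of backward, then restore the original setting.
_origin_auto_format = get_auto_format_convert()
set_auto_format_convert(False)
try:
    pass  # ... run backward with format conversion disabled ...
finally:
    set_auto_format_convert(_origin_auto_format)
```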
@@ -274,7 +274,9 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
         return x
     # set x's format to use FormatTransformation rule for Broadcast.
-    return broadcast_to(x, inp.shape)
+    rst = broadcast_to(x, inp.shape)
+    rst.format = inp.format
+    return rst


 def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
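`full_like` now stamps the input's format onto the broadcast result instead of returning a default-format tensor. A hedged usage sketch (the `format` keyword on the `Tensor` constructor is assumed from the accompanying format tests, not shown in this diff):

```python
import numpy as np

import megengine.functional as F
from megengine import Tensor

x = Tensor(np.random.rand(1, 2, 3, 4), dtype="float32", format="nhwc")
y = F.full_like(x, 0.5)
# With this patch the filled tensor inherits the input's format.
assert y.format == x.format
```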
@@ -26,7 +26,7 @@ public:
         Eval,
     };

-    std::array<std::vector<std::shared_ptr<Transformation>>, 8> segments;
+    std::array<std::vector<std::shared_ptr<Transformation>>, 9> segments;

 private:
     template <Segment segment>
@@ -12,6 +12,7 @@ import megengine.functional as F
 import megengine.module as M
 from megengine import Parameter, Tensor, amp
 from megengine.core._config import set_auto_format_convert
+from megengine.core._trace_option import use_symbolic_shape


 class MyModule(M.Module):
@@ -41,22 +42,25 @@ class MyModule(M.Module):
 def test_convert_module(is_inplace):
     m = MyModule()
     expected_shape = {
-        "i.bn.weight": (1, 1, 1, 4),
-        "i.bn.bias": (1, 1, 1, 4),
-        "i.bn.running_mean": (1, 1, 1, 4),
-        "i.bn.running_var": (1, 1, 1, 4),
-        "conv.weight": (2, 2, 4, 4, 2),
-        "conv.bias": (1, 1, 1, 4),
-        "bn.weight": (1, 1, 1, 4),
-        "bn.bias": (1, 1, 1, 4),
-        "bn.running_mean": (1, 1, 1, 4),
-        "bn.running_var": (1, 1, 1, 4),
-        "param": (1, 1, 1, 3),
-        "buff": (1, 1, 1, 3),
+        "i.bn.weight": (1, 4, 1, 1),
+        "i.bn.bias": (1, 4, 1, 1),
+        "i.bn.running_mean": (1, 4, 1, 1),
+        "i.bn.running_var": (1, 4, 1, 1),
+        "conv.weight": (2, 2, 2, 4, 4),
+        "conv.bias": (1, 4, 1, 1),
+        "bn.weight": (1, 4, 1, 1),
+        "bn.bias": (1, 4, 1, 1),
+        "bn.running_mean": (1, 4, 1, 1),
+        "bn.running_var": (1, 4, 1, 1),
+        "param": (1, 3, 1, 1),
+        "buff": (1, 3, 1, 1),
     }
     m = amp.convert_module_format(m, is_inplace)
     for name, param in m.named_tensors():
         assert param.format == "nhwc"
-        set_auto_format_convert(False)
-        assert param.shape == expected_shape[name], name
-        set_auto_format_convert(True)
+        if use_symbolic_shape():
+            np.testing.assert_array_equal(
+                param.shape.numpy(), expected_shape[name], name
+            )
+        else:
+            assert param.shape == expected_shape[name], name
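Note the expected shapes flip from the physical NHWC layout (e.g. `(1, 1, 1, 4)`) to the logical NCHW view (e.g. `(1, 4, 1, 1)`): the test no longer wraps the assertion in `set_auto_format_convert(False)` / `set_auto_format_convert(True)`, so `param.shape` reports the logical shape while storage stays NHWC. Under symbolic shape tracing, `.shape` is itself a `Tensor`, hence the `.numpy()` comparison. A sketch of the assumed shape-reporting behavior:

```python
import numpy as np

from megengine import Tensor
from megengine.core._config import set_auto_format_convert

# Assumption for illustration: data is stored physically as NHWC, while the
# reported shape follows the logical NCHW view when conversion is active.
x = Tensor(np.zeros((1, 1, 1, 4)), dtype="float32", format="nhwc")
print(x.shape)  # logical view, e.g. (1, 4, 1, 1)

set_auto_format_convert(False)
print(x.shape)  # physical NHWC layout, (1, 1, 1, 4)
set_auto_format_convert(True)
```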
@@ -6,6 +6,7 @@ import megengine.functional as F
 import megengine.module as M
 from megengine import tensor
 from megengine.autodiff import GradManager
+from megengine.core._trace_option import use_symbolic_shape
 from megengine.jit import trace
@@ -121,7 +122,10 @@ def test_repeat(is_symbolic):
 @pytest.mark.parametrize("is_symbolic", [None])
 def test_getshape(is_symbolic):
     def func(x):
-        return x.shape
+        if use_symbolic_shape():
+            return x.shape.numpy()
+        else:
+            return x.shape

     data = np.arange(0, 24).reshape((1, 2, 3, 4))
     _compare_nchw_nhwc(data, func, is_symbolic)
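`_compare_nchw_nhwc` is not part of this diff; a plausible reconstruction of what such a helper checks (signature and body are assumptions inferred from the call site):

```python
import numpy as np

from megengine import tensor

def _compare_nchw_nhwc(data, func, is_symbolic=None):
    # Hypothetical sketch: applying func to a plain NCHW tensor and to its
    # NHWC-formatted counterpart should give identical results. (Handling
    # of is_symbolic via jit.trace is omitted here.)
    x_nchw = tensor(data)
    x_nhwc = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
    np.testing.assert_almost_equal(func(x_nchw), func(x_nhwc), decimal=5)
```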
@@ -1,5 +1,6 @@
 #include "megbrain/imperative/transformations/format.h"
 #include "megbrain/imperative/transformations/grad.h"
+#include "megbrain/imperative/transformations/symbol.h"

 #include "megbrain/imperative/ops/autogen.h"
 #include "megbrain/imperative/ops/utility.h"
@@ -75,6 +76,17 @@ inline ValueRefList FormatTransformation::wrap_outputs(
     }
     return wrapped_outputs;
 }
+
+inline bool FormatTransformation::check_all_format_value(
+        const Span<ValueRef>& inputs) const {
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        if (!inputs[i].as_ref(m_value_type)) {
+            return false;
+        }
+    }
+    return true;
+}
+
 namespace {

 ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) {
@@ -369,7 +381,8 @@ inline ValueRefList unify_inputs_format(
     for (size_t i = 0; i < inputs.size(); ++i) {
         auto&& inp = inputs[i].cast(t.value_type());
         if (inp.format() != dst_fmt &&
-            inp.value().shape().cast<ShapeValue>().ndim == 4) {
+            (inp.value().shape().cast<ShapeValue>().ndim == 4 ||
+             inp.value().shape().cast<ShapeValue>().ndim == 5)) {
             unified_inputs[i] = t.to(inp, dst_fmt, scope);
         } else {
             unified_inputs[i] = inputs[i];
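Widening the guard from `ndim == 4` to `ndim == 4 || ndim == 5` lets input unification also convert 5-D tensors, which is exactly the grouped-convolution weight case covered by the `"conv.weight": (2, 2, 2, 4, 4)` expectation in the test above. Where such a 5-D parameter comes from:

```python
import megengine.module as M

# A grouped Conv2d stores its weight as (groups, out_ch // g, in_ch // g, kh, kw).
conv = M.Conv2d(4, 4, kernel_size=4, groups=2)
print(conv.weight.shape)  # (2, 2, 2, 4, 4)
```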
@@ -568,6 +581,10 @@ struct FormatRuleRegistry {
 ValueRefList FormatTransformation::apply_transformation(
         const Operator& op, Span<ValueRef> inputs) {
     if (auto* apply_op = op.as<ApplyOp>()) {
+        // bypass SymbolValue
+        if (!check_all_format_value(inputs)) {
+            return imperative::apply(op, unwrap_inputs(inputs));
+        }
         // all inputs should be FormattedTensorValue
         auto iter = format_rules.find(apply_op->op().dyn_typeinfo());
         if (iter != format_rules.end()) {
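The early-out keeps `FormatTransformation` from applying format rules to inputs it does not own: during symbolic-shape tracing some inputs arrive as `SymbolValue` rather than `FormattedTensorValue`. A Python-level sketch of the same guard pattern (all names here are illustrative stand-ins, not the C++ API):

```python
from dataclasses import dataclass
from typing import Any, Callable, Sequence

@dataclass
class FormattedTensorValue:  # stand-in for the C++ value wrapper
    payload: Any
    format: str = "nhwc"

def unwrap(x: Any) -> Any:
    return x.payload if isinstance(x, FormattedTensorValue) else x

def apply_transformation(apply_op: Callable, inputs: Sequence[Any]) -> Any:
    # Mirror of check_all_format_value(): only run the format rules when
    # every input is a FormattedTensorValue; anything else (e.g. a symbolic
    # shape value) bypasses them and is applied on unwrapped inputs.
    if not all(isinstance(x, FormattedTensorValue) for x in inputs):
        return apply_op([unwrap(x) for x in inputs])
    # ... format-specific rules would run here ...
    return apply_op([unwrap(x) for x in inputs])
```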
@@ -628,9 +645,6 @@ ValueRefList FormatTransformation::apply_transformation(
         auto&& format = inp_ref->format();
         return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format);
     } else {
-        mgb_log_warn(
-                "Not FormattedTensorValue input for IdentityLike op: %s, %s",
-                op.to_string().c_str(), inputs[0].to_string().c_str());
         return imperative::apply(op, inputs);
     }
 } else if (op.is<AttachGrad>()) {
@@ -70,6 +70,7 @@ public:
             const ValueRef& output, Format format = Format::Type::DEFAULT) const;
     inline ValueRefList wrap_outputs(
             const ValueRefList& outputs, Format format = Format::Type::DEFAULT) const;
+    inline bool check_all_format_value(const Span<ValueRef>& inputs) const;

     TypedValueRef<FormattedTensorValue> as(
             const FormattedTensorValue&, const Format::Type& target) const;