@@ -29,10 +29,13 @@ def convert_tensor_format(x: Tensor, inplace: bool = True):
     # TODO: use initialization from tensor after fixing format setting
     if x.format != "nhwc":
         if inplace:
+            # reset will destroy backward grad
             data = x.numpy().transpose(*pattern)
             x[...] = Tensor(data, format="nhwc")
         else:
-            x = Tensor(x.numpy().transpose(*pattern), format="nhwc")
+            # use mge interface to maintain grad
+            x = F.transpose(x, pattern)
+            x.format = "nhwc"
     return x
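Reviewer note: the two branches now differ in autodiff behavior, as the new comments say. A minimal sketch of the non-inplace path, assuming a build with this patch applied (the `GradManager` usage here is illustrative, not part of the change):

```python
import numpy as np
import megengine as mge

x = mge.Tensor(np.random.rand(1, 2, 3, 3).astype("float32"))
gm = mge.autodiff.GradManager().attach([x])
with gm:
    # the non-inplace branch goes through F.transpose, so the op stays
    # on the autodiff tape and gradients can flow back to x
    y = mge.amp.convert_tensor_format(x, inplace=False)
    gm.backward(y.sum())
print(x.grad)  # populated; the inplace branch would have destroyed it
```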
@@ -245,6 +245,8 @@ def conv2d(
     sparse_type = "dense" if groups == 1 else "group"
     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
+    with _config._override(auto_format_convert=False):
+        print(compute_mode, inp.shape, inp.format, weight.shape, weight.format)
     op = builtin.Convolution(
         stride_h=stride_h,
         stride_w=stride_w,
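The override guards the debug print: with `auto_format_convert` enabled, reading format-sensitive attributes like `.shape` and `.format` would go through the automatic NCHW/NHWC conversion logic. A hedged sketch of the same pattern from user code (import path as used in this codebase; the tensor is illustrative):

```python
import numpy as np
import megengine as mge
from megengine.core import _config

x = mge.Tensor(np.ones((1, 2, 2, 2), dtype="float32"), format="nhwc")
# inspect format-sensitive attributes without implicit layout conversion
with _config._override(auto_format_convert=False):
    print(x.format, x.shape)
```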
@@ -1,5 +1,6 @@
 import numpy as np
+import megengine as mge
 import megengine.functional as F
 from megengine import Parameter
@@ -34,6 +35,7 @@ class GroupNorm(Module):
     def forward(self, x):
         N, C, H, W = x.shape
+        format = x.format
         assert C == self.num_channels
         x = x.reshape(N, self.num_groups, -1)
@@ -44,7 +46,9 @@ class GroupNorm(Module):
         x = x.reshape(N, C, H, W)
         if self.affine:
             x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape(1, -1, 1, 1)
+        # FIXME(czh): remove this after making it a builtin op.
+        if format == "nhwc":
+            x = mge.amp.convert_tensor_format(x, inplace=False)
         return x

     def _module_info_string(self) -> str:
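Both norm modules use the same workaround: the reshape-based statistics lose the NHWC tag, so the format captured on entry is restored before returning. A quick sketch of the intended round trip, assuming this patch (shapes illustrative):

```python
import numpy as np
import megengine as mge
import megengine.module as M

gn = M.GroupNorm(num_groups=2, num_channels=4)
x = mge.Tensor(np.random.rand(1, 4, 3, 3).astype("float32"), format="nhwc")
y = gn(x)
print(y.format)  # "nhwc", restored by convert_tensor_format at the end
```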
@@ -81,6 +85,7 @@ class InstanceNorm(Module):
     def forward(self, x):
         N, C, H, W = x.shape
+        format = x.format
         assert C == self.num_channels
         x = x.reshape(N, C, -1)
         mean = x.mean(axis=2, keepdims=True)
@@ -90,7 +95,9 @@ class InstanceNorm(Module):
         x = x.reshape(N, C, H, W)
         if self.affine:
             x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape(1, -1, 1, 1)
+        # FIXME(czh): remove this after making it a builtin op.
+        if format == "nhwc":
+            x = mge.amp.convert_tensor_format(x, inplace=False)
         return x

     def _module_info_string(self) -> str:
@@ -122,7 +122,11 @@ class Tensor(_Tensor, ArrayMethodMixin):
     @property
     def format(self) -> str:
-        return super().format
+        return super().format()
+
+    @format.setter
+    def format(self, format):
+        super()._set_format(format)

     @property
     def qparams(self):
@@ -584,6 +584,12 @@ void TensorWrapper::set_module_trace_info(PyObject* obj) {
     module_trace_info_map[m_tensor->data()] = py::reinterpret_borrow<py::object>(obj);
 }

+void TensorWrapper::_set_format(PyObject* dest) {
+    auto py_dest = py::reinterpret_borrow<py::object>(dest);
+    auto format = py_dest.cast<std::string>();
+    m_tensor->set_format(format);
+}
+
 void TensorWrapper::_set_name(PyObject* dest) {
     auto py_dest = py::reinterpret_borrow<py::object>(dest);
     auto name = py_dest.cast<std::string>();
@@ -812,7 +818,7 @@ void init_tensor(py::module m) {
             .def_getset<&TensorWrapper::shape>("shape")
             .def_getset<&TensorWrapper::dtype>("dtype")
             .def_getset<&TensorWrapper::device>("device")
-            .def_getset<&TensorWrapper::format>("format")
+            .def<&TensorWrapper::format>("format")
             .def<&TensorWrapper::reset>("_reset")
             .def<&TensorWrapper::isscalar>("_isscalar")
             .def<&TensorWrapper::detach>("detach")
@@ -820,6 +826,7 @@ void init_tensor(py::module m) {
             .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
             .def<&TensorWrapper::_drop>("_drop")
             .def<&TensorWrapper::_detail>("_detail")
+            .def<&TensorWrapper::_set_format>("_set_format")
             .def<&TensorWrapper::_set_name>("_set_name")
             .def<&TensorWrapper::_watch>("_watch")
             .def<&TensorWrapper::_var>("var")
@@ -59,6 +59,11 @@ public:
         return *shape;
     }
     inline Format format() { return *data().format(); }
+    inline void set_format(std::string format) {
+        if (!format.empty()) {
+            m_data = imperative::apply(SetFormat(format), m_data)[0];
+        }
+    }
     inline HostValue::ref_t numpy() { return data().numpy(); }
     inline void reset(ValueRef value) {
         m_data = value;
@@ -130,6 +135,7 @@ public:
     PyObject* copied();
     PyObject* module_trace_info();
     void set_module_trace_info(PyObject*);
+    void _set_format(PyObject*);
     void _set_name(PyObject*);
     PyObject* _detail();
     PyObject* _var();
@@ -31,6 +31,9 @@ def test_basic():
     b[...] = tensor(data, format="nchw")
     assert b.format == "nchw"
+    # set tensor's format
+    b.format = "nhwc"
+    assert b.format == "nhwc"

 def _compare_nchw_nhwc(data, func, is_symbolic=None):
     x1 = tensor(data)
@@ -105,9 +105,16 @@ std::string IsScalar::to_string() const {
     return "IsScalar";
 }

+std::string GetFormat::to_string() const {
+    return "GetFormat{}";
+}
+
+std::string SetFormat::to_string() const {
+    return ssprintf("SetFormat{format=%s}", m_format.to_string().c_str());
+}

 std::string GetVarVal::to_string() const {
     return "GetVarVal";
 }

 }  // namespace imperative
 }  // namespace mgb
@@ -57,15 +57,15 @@ inline ValueRefList FormatTransformation::unwrap_inputs(
 }

 inline ValueRef FormatTransformation::wrap_output(
-        const ValueRef& output, FT type) const {
-    return m_value_type.make(output, type);
+        const ValueRef& output, Format format) const {
+    return m_value_type.make(output, format);
 }

 inline ValueRefList FormatTransformation::wrap_outputs(
-        const ValueRefList& outputs, FT type) const {
+        const ValueRefList& outputs, Format format) const {
     ValueRefList wrapped_outputs(outputs.size());
     for (size_t i = 0; i < outputs.size(); ++i) {
-        wrapped_outputs[i] = wrap_output(outputs[i], type);
+        wrapped_outputs[i] = wrap_output(outputs[i], format);
     }
     return wrapped_outputs;
 }
@@ -241,7 +241,7 @@ ValueRefList subtensor_rule(
     if (!(auto_convert && src.format() == FT::NHWC)) {
         return {t.wrap_output(
                 imperative::apply(op, t.unwrap_inputs(inputs))[0],
-                src.format().type())};
+                src.format())};
     }
     auto nhwc_items = convert_nchw2nhwc_idx_items(op.items);
     auto outputs = imperative::apply(
@@ -264,7 +264,7 @@ ValueRefList setsubtensor_rule(
     if (!(auto_convert && src.format() == FT::NHWC)) {
         return {t.wrap_output(
                 imperative::apply(op, t.unwrap_inputs(inputs))[0],
-                src.format().type())};
+                src.format())};
     }
     // value has been broadcasted to src's fake NCHW shape.
     auto& value = inputs[1].cast(t.value_type());
@@ -330,7 +330,7 @@ ValueRefList identity_rule_helper(
     // mgb_assert(inputs.size() == 1);
     auto& src = inputs[0].cast(t.value_type());
     return t.wrap_outputs(
-            imperative::apply(op, t.unwrap_inputs(inputs)), src.format().type());
+            imperative::apply(op, t.unwrap_inputs(inputs)), src.format());
 }

 ValueRefList batchnorm_rule(
@@ -467,7 +467,7 @@ ValueRefList FormatTransformation::apply_transformation(
         }
     } else if (auto* create_tensor = op.as<CreateTensor>()) {
         auto format = create_tensor->format();
-        return {wrap_output(imperative::apply(op, inputs)[0], format.type())};
+        return {wrap_output(imperative::apply(op, inputs)[0], format)};
     } else if (auto* get_attr = op.as<GetAttr>()) {
         auto&& input = inputs.item();
         if (!input.is(m_value_type)) {
@@ -500,12 +500,16 @@ ValueRefList FormatTransformation::apply_transformation(
                     op.to_string().c_str(), inputs[0].to_string().c_str());
             return {FormatValue::make(FT::DEFAULT)};
         }
+    } else if (auto* _op = op.as<SetFormat>()) {
+        auto&& inp_ref = inputs[0].as_ref(m_value_type);
+        mgb_assert(inp_ref, "Cannot set format for non-format Tensor.");
+        return {m_value_type.make(inp_ref->value(), _op->format())};
     } else if (op.is<Operator::IdentityLike>()) {
         auto&& inp_ref = inputs[0].as_ref(m_value_type);
         if (inp_ref) {
             auto&& format = inp_ref->format();
             return wrap_outputs(
-                    imperative::apply(op, unwrap_inputs(inputs)), format.type());
+                    imperative::apply(op, unwrap_inputs(inputs)), format);
         } else {
             mgb_log_warn(
                     "Not FormattedTensorValue input for IdentityLike op: %s, %s",
@@ -521,13 +525,13 @@ ValueRefList FormatTransformation::apply_transformation(
             GenericFunction new_callback =
                     [this, callback, format](Span<ValueRef> inputs_) -> ValueRefList {
                 auto wrapped_inputs = SmallVector<ValueRef>{
-                        this->value_type().make(inputs_.item(), format.type())};
+                        this->value_type().make(inputs_.item(), format)};
                 auto ret = callback(wrapped_inputs);
                 return ret;
             };
             auto&& outputs = imperative::apply(
                     op, inp_ref->value(), FunctionValue::make(new_callback));
-            return wrap_outputs(outputs, format.type());
+            return wrap_outputs(outputs, format);
         } else {
             mgb_log_warn(
                     "Not FormattedTensorValue input for AttachGrad op: %s, %s",
@@ -549,7 +553,7 @@ ValueRefList FormatTransformation::apply_transformation(
         for (size_t i = 0; i < nr_outputs; ++i) {
             if (auto output_ref = outputs_[i].as_ref(m_value_type)) {
                 wrapped_outputs[i] =
-                        m_value_type.make(outputs[i], output_ref->format().type());
+                        m_value_type.make(outputs[i], output_ref->format());
             } else {
                 mgb_log_warn(
                         "Not FormattedTensorValue outputs for SetGrad op: %s, %s",
@@ -164,7 +164,19 @@ public:
 class GetFormat final : public OperatorImpl<GetFormat, Operator::GetAttrLike> {
 public:
-    std::string to_string() const override { return "GetFormat{}"; }
+    std::string to_string() const override;
+};
+
+class SetFormat final : public OperatorImpl<SetFormat, Operator::IdentityLike> {
+private:
+    Format m_format;
+
+public:
+    SetFormat(std::string format) : m_format(format) {}
+
+    Format format() const { return m_format; }
+
+    std::string to_string() const override;
 };

 class GetVarVal final : public OperatorImpl<GetVarVal, Operator::GetAttrLike> {
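Putting the pieces together, a hedged sketch of the new dispatch chain introduced by this patch (names from the diff above; the tensor is illustrative):

```python
import numpy as np
import megengine as mge

# Tensor.format = ...  (Python property setter)
#   -> TensorWrapper::_set_format       (pybind binding)
#     -> ValueRef set_format            -> imperative::apply(SetFormat(fmt))
#       -> FormatTransformation re-wraps the value with the new Format.
# SetFormat is IdentityLike: it retags metadata rather than moving data.
x = mge.Tensor(np.ones((1, 2, 2, 2), dtype="float32"))
x.format = "nhwc"
assert x.format == "nhwc"
```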
@@ -26,6 +26,8 @@ public:
     const Format& format() const { return m_format; }

+    void set_format(Format format) { m_format = format; }
+
     void clear() override {
         m_value = {};
         m_format = {};
@@ -65,10 +67,10 @@ public:
     inline ValueRef unwrap_input(const ValueRef& input) const;
     inline ValueRefList unwrap_inputs(const Span<ValueRef>& inputs) const;
     inline ValueRef wrap_output(
-            const ValueRef& output, Format::Type type = Format::Type::DEFAULT) const;
+            const ValueRef& output, Format format = Format::Type::DEFAULT) const;
     inline ValueRefList wrap_outputs(
             const ValueRefList& outputs,
-            Format::Type type = Format::Type::DEFAULT) const;
+            Format format = Format::Type::DEFAULT) const;

     TypedValueRef<FormattedTensorValue> as(
             const FormattedTensorValue&, const Format::Type& target) const;