GitOrigin-RevId: d14a69424d
tags/v1.9.0
@@ -41,7 +41,6 @@ from ..distributed import WORLD, is_distributed
 from ..jit import exclude_from_trace
 from ..tensor import Tensor
 from ..utils.deprecation import deprecated_func
-from ..utils.tuple_function import _pair, _pair_nonzero, _triple, _triple_nonzero
 from .debug_param import get_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import _elwise, exp, log, log1p, maximum, minimum
@@ -94,14 +93,15 @@ __all__ = [
 def expand_hw(x):
-    # NOTE: >1d array is accepted, as long as 1 <= size <= 2
-    try:
-        x = int(x)
-        return [x, x]
-    except (TypeError, ValueError):
-        pass
-    h, w = x
-    return int(h), int(w)
+    if isinstance(x, Sequence):
+        return int(x[0]), int(x[1])
+    return int(x), int(x)
+
+
+def expand_dhw(x):
+    if isinstance(x, Sequence):
+        return int(x[0]), int(x[1]), int(x[2])
+    return int(x), int(x), int(x)
 def linear(
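For reference, the behaviour of the new helpers in isolation. This is a minimal standalone sketch (not the library file itself), assuming only the Sequence check shown above:

from collections.abc import Sequence

def expand_hw(x):
    # accept an int or a 2-sequence and always return an (h, w) pair
    if isinstance(x, Sequence):
        return int(x[0]), int(x[1])
    return int(x), int(x)

def expand_dhw(x):
    # triple variant used by the 3d ops
    if isinstance(x, Sequence):
        return int(x[0]), int(x[1]), int(x[2])
    return int(x), int(x), int(x)

assert expand_hw(3) == (3, 3)
assert expand_hw((2, 5)) == (2, 5)
assert expand_dhw(1) == (1, 1, 1)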
@@ -177,11 +177,8 @@ def conv1d(
     if weight.dtype != dtype:
         weight = weight.astype(dtype)
-    inp = expand_dims(inp, 3)
-    weight = expand_dims(weight, 3)
     if bias is not None:
         assert bias.ndim == 3, "the bias dimension of conv1d should be 3"
-        bias = expand_dims(bias, 3)
     stride_h = stride
     pad_h = padding
@@ -206,7 +203,6 @@ def conv1d(
     (output,) = apply(op, inp, weight)
     if bias is not None:
         output += bias
-    output = squeeze(output, 3)
     return output
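The Python-level expand_dims/squeeze pair is dropped because the 1d-to-2d expansion now happens in the C++ DimExpansionTransformation (new files below). A hedged usage sketch, assuming the public conv1d signature is otherwise unchanged:

import numpy as np
import megengine as mge
import megengine.functional as F

inp = mge.tensor(np.random.randn(2, 4, 16).astype("float32"))    # (N, C_in, L)
weight = mge.tensor(np.random.randn(8, 4, 3).astype("float32"))  # (C_out, C_in, K)
bias = mge.tensor(np.zeros((1, 8, 1), dtype="float32"))          # ndim == 3, as asserted above

out = F.conv1d(inp, weight, bias, stride=1, padding=1)
print(out.shape)  # expected: (2, 8, 16)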
@@ -314,9 +310,9 @@ def conv3d( | |||||
D, H, W = 0, 1, 2 | D, H, W = 0, 1, 2 | ||||
pad = _triple(padding) | |||||
stride = _triple_nonzero(stride) | |||||
dilate = _triple_nonzero(dilation) | |||||
pad = expand_dhw(padding) | |||||
stride = expand_dhw(stride) | |||||
dilate = expand_dhw(dilation) | |||||
sparse_type = "dense" if groups == 1 else "group" | sparse_type = "dense" if groups == 1 else "group" | ||||
op = builtin.Convolution3D( | op = builtin.Convolution3D( | ||||
@@ -572,9 +568,9 @@ def conv_transpose3d(
         output tensor.
     """
     D, H, W = 0, 1, 2
-    pad = _triple(padding)
-    stride = _triple_nonzero(stride)
-    dilate = _triple_nonzero(dilation)
+    pad = expand_dhw(padding)
+    stride = expand_dhw(stride)
+    dilate = expand_dhw(dilation)
     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.Convolution3DBackwardData(
@@ -618,9 +614,9 @@ def max_pool2d(
     """
     if stride is None:
         stride = kernel_size
-    window_h, window_w = _pair_nonzero(kernel_size)
-    stride_h, stride_w = _pair_nonzero(stride)
-    padding_h, padding_w = _pair(padding)
+    window_h, window_w = expand_hw(kernel_size)
+    stride_h, stride_w = expand_hw(stride)
+    padding_h, padding_w = expand_hw(padding)
     conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
     op = builtin.Pooling(
@@ -662,9 +658,9 @@ def avg_pool2d( | |||||
""" | """ | ||||
if stride is None: | if stride is None: | ||||
stride = kernel_size | stride = kernel_size | ||||
window_h, window_w = _pair_nonzero(kernel_size) | |||||
stride_h, stride_w = _pair_nonzero(stride) | |||||
padding_h, padding_w = _pair(padding) | |||||
window_h, window_w = expand_hw(kernel_size) | |||||
stride_h, stride_w = expand_hw(stride) | |||||
padding_h, padding_w = expand_hw(padding) | |||||
conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format) | conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format) | ||||
op = builtin.Pooling( | op = builtin.Pooling( | ||||
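Since kernel_size, stride and padding all go through expand_hw now, both the scalar and the (h, w) forms keep working. A small sketch, assuming the public pooling signatures are unchanged:

import numpy as np
import megengine as mge
import megengine.functional as F

x = mge.tensor(np.random.randn(1, 3, 8, 12).astype("float32"))
y1 = F.max_pool2d(x, kernel_size=2)                      # scalar -> (2, 2) window
y2 = F.avg_pool2d(x, kernel_size=(2, 3), stride=(2, 3))  # explicit (h, w) pair
print(y1.shape, y2.shape)  # expected: (1, 3, 4, 6) (1, 3, 4, 4)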
@@ -1779,10 +1775,10 @@ def sliding_window(
         stride: stride of the window. Default: 1
         dilation: dilation of the window. Default: 1
     """
-    padding_h, padding_w = _pair(padding)
-    stride_h, stride_w = _pair_nonzero(stride)
-    dilation_h, dilation_w = _pair_nonzero(dilation)
-    window_h, window_w = _pair_nonzero(kernel_size)
+    padding_h, padding_w = expand_hw(padding)
+    stride_h, stride_w = expand_hw(stride)
+    dilation_h, dilation_w = expand_hw(dilation)
+    window_h, window_w = expand_hw(kernel_size)
     op = builtin.Images2Neibs(
         pad_h=padding_h,
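The same scalar-or-pair handling applies to sliding_window; a hedged example, assuming the Images2Neibs output layout (N, C, OH, OW, window_h, window_w):

import numpy as np
import megengine as mge
import megengine.functional as F

x = mge.tensor(np.arange(16, dtype="float32").reshape(1, 1, 4, 4))
patches = F.sliding_window(x, kernel_size=2, stride=1)  # scalar kernel -> (2, 2) window
print(patches.shape)  # expected: (1, 1, 3, 3, 2, 2)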
@@ -1818,11 +1814,11 @@ def sliding_window_transpose(
         stride: stride of the window. Default: 1
         dilation: dilation of the window. Default: 1
     """
-    output_h, output_w = _pair_nonzero(output_size)
-    padding_h, padding_w = _pair(padding)
-    stride_h, stride_w = _pair_nonzero(stride)
-    dilation_h, dilation_w = _pair_nonzero(dilation)
-    window_h, window_w = _pair_nonzero(kernel_size)
+    output_h, output_w = expand_hw(output_size)
+    padding_h, padding_w = expand_hw(padding)
+    stride_h, stride_w = expand_hw(stride)
+    dilation_h, dilation_w = expand_hw(dilation)
+    window_h, window_w = expand_hw(kernel_size)
     expected_h = (
         output_h + 2 * padding_h - dilation_h * (window_h - 1) - 1
@@ -80,19 +80,6 @@ class _BatchNorm(Module):
                 self.track_running_stats == False
             ), "track_running_stats can not be initilized to False and changed to True later"
-        inp_shape = inp.shape
-        _ndims = len(inp_shape)
-        if _ndims != 4:
-            origin_shape = inp_shape
-            if _ndims == 2:
-                n, c = inp_shape[0], inp_shape[1]
-                new_shape = (n, c, 1, 1)
-            elif _ndims == 3:
-                n, c, h = inp_shape[0], inp_shape[1], inp_shape[2]
-                new_shape = (n, c, h, 1)
-            inp = inp.reshape(new_shape)
         _weight = self.weight
         _bias = self.bias
@@ -130,9 +117,6 @@ class _BatchNorm(Module):
                 param_dim=self.param_dim,
             )
-        if _ndims != 4:
-            output = output.reshape(origin_shape)
         return output

     def _module_info_string(self) -> str:
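The reshape-to-4d path removed here is now handled by the C++ bn1d_rule added below, so 2-d and 3-d inputs to BatchNorm1d still work end to end. A hedged check, assuming the module API is unchanged:

import numpy as np
import megengine as mge
import megengine.module as M

bn = M.BatchNorm1d(16)
x2d = mge.tensor(np.random.randn(8, 16).astype("float32"))      # (N, C)
x3d = mge.tensor(np.random.randn(8, 16, 32).astype("float32"))  # (N, C, L)
print(bn(x2d).shape, bn(x3d).shape)  # expected: (8, 16) (8, 16, 32)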
@@ -15,6 +15,7 @@
 #include "megbrain/imperative/ops/backward_graph.h"
 #include "megbrain/imperative/ops/utility.h"
 #include "megbrain/imperative/profiler.h"
+#include "megbrain/imperative/transformations/dim_expansion.h"
 #include "megbrain/imperative/transformations/dtype_promote.h"
 #include "megbrain/imperative/transformations/eval.h"
 #include "megbrain/imperative/transformations/lazy.h"
@@ -61,11 +62,13 @@ struct SymbolVarContext {
     std::shared_ptr<SymbolTransformation> symbol_tsf;
     std::shared_ptr<ScalarTransformation> scalar_tsf;
     std::shared_ptr<DTypePromoteTransformation> dtype_promote_tsf;
+    std::shared_ptr<DimExpansionTransformation> dim_expansion_tsf;
     SymbolVarContext(cg::ComputingGraph* graph) {
         symbol_tsf = std::make_shared<SymbolTransformation>(graph);
         scalar_tsf = std::make_shared<ScalarTransformation>();
         dtype_promote_tsf = std::make_shared<DTypePromoteTransformation>();
+        dim_expansion_tsf = std::make_shared<DimExpansionTransformation>();
         Transformation::swap_context(context);
     }
@@ -73,6 +76,7 @@ struct SymbolVarContext {
         symbol_tsf->register_at(Transformation::top());
         scalar_tsf->register_at(Transformation::top());
         dtype_promote_tsf->register_at(Transformation::top());
+        dim_expansion_tsf->register_at(Transformation::top());
     }

     ValueRef symvar2val(py::handle py_symbol_var) {
@@ -452,6 +456,8 @@ void init_tensor(py::module m) {
             std::make_shared<ScalarTransformation>());
     transformations.register_at<Segment::DTypePromote>(
             std::make_shared<DTypePromoteTransformation>());
+    transformations.register_at<Segment::DimExpansion>(
+            std::make_shared<DimExpansionTransformation>());
     static py::exception<interpreter::AsyncError> py_async_error(
             m, "AsyncError", PyExc_RuntimeError);
@@ -26,13 +26,14 @@ struct TransformationManager {
     enum Segment {
         ModuleTrace,
         DTypePromote,
+        DimExpansion,
         Grad,
         Scalar,
        Trace,
         Eval,
     };
-    std::array<std::vector<std::shared_ptr<Transformation>>, 6> segments;
+    std::array<std::vector<std::shared_ptr<Transformation>>, 7> segments;

     template <Segment segment>
     void register_at(std::shared_ptr<Transformation> transformation) {
@@ -91,7 +91,7 @@ class ResNet(M.Module):
 def run_dtr_resnet1202():
-    batch_size = 8
+    batch_size = 7
     resnet1202 = ResNet(BasicBlock, [200, 200, 200])
     opt = optim.SGD(resnet1202.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4)
     gm = GradManager().attach(resnet1202.parameters())
@@ -0,0 +1,95 @@
+#include "megbrain/imperative/transformations/dim_expansion.h"
+#include "megbrain/imperative/ops/autogen.h"
+
+namespace mgb::imperative {
+namespace {
+
+using DimExpansionRule = std::function<ValueRefList(const OpDef&, Span<ValueRef>)>;
+static std::unordered_map<Typeinfo*, DimExpansionRule> dim_expansion_rules;
+
+template <typename T>
+void register_dim_expansion_rules(const DimExpansionRule& rule) {
+    dim_expansion_rules[T::typeinfo()] = [rule](const OpDef& def,
+                                                Span<ValueRef> inputs) {
+        return rule(def.cast_final_safe<T>(), inputs);
+    };
+}
+
+ValueRefList conv1d_rule(const OpDef& op, Span<ValueRef> inputs) {
+    bool need_expand = inputs.at(0).shape()->ndim == 3;
+    if (!need_expand)
+        return imperative::apply(op, inputs);
+
+    ValueRefList converted(inputs.size());
+    std::vector<int32_t> axis = {(int32_t)3};
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        converted[i] = imperative::apply(ApplyOp(*AddAxis::make(axis)), inputs[i])[0];
+    }
+
+    auto outputs = imperative::apply(op, converted);
+    outputs[0] = imperative::apply(ApplyOp(*RemoveAxis::make(axis)), outputs[0])[0];
+    return outputs;
+}
+
+ValueRefList bn1d_rule(const OpDef& op, Span<ValueRef> inputs) {
+    size_t ndim = inputs.at(0).shape()->ndim;
+    bool need_expand = (ndim == 2 || ndim == 3);
+    if (!need_expand)
+        return imperative::apply(op, inputs);
+
+    ValueRefList converted(inputs.size());
+    std::vector<int32_t> axis = {(int32_t)3};
+    if (ndim == 2) {
+        axis.insert(axis.begin(), (int32_t)2);
+    }
+    converted[0] = imperative::apply(ApplyOp(*AddAxis::make(axis)), inputs[0])[0];
+    for (size_t i = 1; i < inputs.size(); ++i) {
+        converted[i] = inputs[i];
+    }
+    std::reverse(std::begin(axis), std::end(axis));
+
+    auto outputs = imperative::apply(op, converted);
+    size_t idx = outputs.size() - 1;
+    outputs[idx] = imperative::apply(ApplyOp(*RemoveAxis::make(axis)), outputs[idx])[0];
+    return outputs;
+}
+
+struct DimExpansionRuleRegistry {
+    DimExpansionRuleRegistry() {
+        register_dim_expansion_rules<Convolution>(conv1d_rule);
+        register_dim_expansion_rules<BatchNorm>(bn1d_rule);
+    }
+} register_helper;
+
+}  // namespace
+
+ValueRefList DimExpansionTransformation::apply_transformation(
+        const Operator& op, Span<ValueRef> inputs) {
+    if (auto apply_op = op.as<ApplyOp>()) {
+        auto iter = dim_expansion_rules.find(apply_op->op().dyn_typeinfo());
+        if (iter != dim_expansion_rules.end()) {
+            return iter->second(apply_op->op(), inputs);
+        } else {
+            return imperative::apply(op, inputs);
+        }
+    }
+    return imperative::apply(op, inputs);
+}
+
+ValueRef DimExpansionTransformation::unwrap(ValueRef value) {
+    return value;
+}
+
+std::string DimExpansionTransformation::name() const {
+    return "DimExpansionTransformation";
+}
+
+void DimExpansionTransformation::on_register() {
+    // printf("DimExpansionTransformation has been registered\n");
+}
+
+void DimExpansionTransformation::on_unregister() noexcept {
+    // printf("DimExpansionTransformation has been unregistered\n");
+}
+
+} // namespace mgb::imperative
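In Python terms, the conv1d rule above is roughly the following (illustration only; conv1d_via_conv2d is a hypothetical helper, and the real expansion runs inside the dispatch transformation, not in user code):

import megengine.functional as F

def conv1d_via_conv2d(inp, weight, bias=None, stride=1, padding=0, dilation=1):
    # add a trailing axis, run the 2d convolution, then drop the axis again
    inp = F.expand_dims(inp, 3)        # (N, C, L) -> (N, C, L, 1)
    weight = F.expand_dims(weight, 3)  # (C_out, C_in, K) -> (C_out, C_in, K, 1)
    out = F.conv2d(inp, weight, None,
                   stride=(stride, 1), padding=(padding, 0), dilation=(dilation, 1))
    out = F.squeeze(out, 3)            # back to (N, C_out, L_out)
    if bias is not None:
        out = out + bias               # 3-d bias broadcasts over N and L
    return out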
@@ -0,0 +1,19 @@
+#pragma once
+
+#include "megbrain/imperative/dispatch.h"
+#include "megbrain/imperative/value.h"
+
+namespace mgb::imperative {
+
+class DimExpansionTransformation final : public Transformation {
+private:
+public:
+    ValueRefList apply_transformation(
+            const Operator& op, Span<ValueRef> inputs) override;
+    ValueRef unwrap(ValueRef value) override;
+    std::string name() const override;
+    void on_register() override;
+    void on_unregister() noexcept override;
+};
+
+} // namespace mgb::imperative