Browse Source

perf(mge): override grad of matmul

GitOrigin-RevId: d9d97e70fe
release-1.10
Megvii Engine Team 3 years ago
parent
commit
4aa79c453b
4 changed files with 229 additions and 89 deletions
  1. +21
    -11
      imperative/python/megengine/core/tensor/array_method.py
  2. +133
    -0
      imperative/python/src/grad_override.cpp
  3. +42
    -0
      imperative/python/test/unit/core/test_autodiff.py
  4. +33
    -78
      imperative/src/impl/ops/matmul.cpp

+ 21
- 11
imperative/python/megengine/core/tensor/array_method.py View File

@@ -13,6 +13,7 @@ from .._imperative_rt.core2 import (
astype_cpp,
batched_matmul_cpp,
broadcast_cpp,
expand_dims_cpp,
getitem_cpp,
matmul_cpp,
reshape_cpp,
@@ -62,7 +63,6 @@ def _matmul(
assert dim1 > 0 and dim2 > 0
maxdim = dim1 if dim1 > dim2 else dim2
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)

if dim1 == 1 and dim2 == 1: # dispatch to Dot
(result,) = apply(builtin.Dot(), inp1, inp2)
return result
@@ -72,34 +72,44 @@ def _matmul(
# 2x2
# nx1(transpose_a=False), n>=3
# nx2(transpose_a=False), n>=3
return matmul_cpp(
inp1,
inp2,
dim1,
dim2,
ret = matmul_cpp(
inp1 if dim1 > 1 else expand_dims_cpp(inp1, 0),
inp2 if dim2 > 1 else expand_dims_cpp(inp2, -1),
max(dim1, 2),
max(dim2, 2),
transpose_a,
transpose_b,
compute_mode,
_config._benchmark_kernel,
_config._deterministic_kernel,
)
if dim1 == 1:
ret = squeeze_cpp(ret, -2)
elif dim2 == 1:
ret = squeeze_cpp(ret, -1)
return ret
else: # dispath to BatchedMatrixMul
# nx1(transpose_a=True), n>=3
# nx2(transpose_a=True), n>=3
# nxm,n>=3,m>=3
# 1xm,m>=3
# 2xm,m>=3
return batched_matmul_cpp(
inp1,
inp2,
dim1,
dim2,
ret = batched_matmul_cpp(
inp1 if dim1 > 1 else expand_dims_cpp(inp1, 0),
inp2 if dim2 > 1 else expand_dims_cpp(inp2, -1),
max(dim1, 2),
max(dim2, 2),
transpose_a,
transpose_b,
compute_mode,
_config._benchmark_kernel,
_config._deterministic_kernel,
)
if dim1 == 1:
ret = squeeze_cpp(ret, -2)
elif dim2 == 1:
ret = squeeze_cpp(ret, -1)
return ret


def _unary_elwise(mode):


+ 133
- 0
imperative/python/src/grad_override.cpp View File

@@ -87,6 +87,136 @@ ValueRef make_empty_tensor(
return res;
}

std::optional<ValueRefList> matrix_mul_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto&& matmul = op.cast_final_safe<MatrixMul>();
size_t dimA = matmul.dimA;
size_t dimB = matmul.dimB;
auto&& param = matmul.param();
auto&& policy = matmul.policy();
mgb_assert(inputs.size() == 2);
std::array<ValueRef, 2> inps, input_shapes;
for (size_t i = 0; i < 2; ++i) {
if (inputs_require_grad[i ^ 1]) {
inps[i] = inputs[i];
input_shapes[i] = get_shape(inputs[i]);
}
}
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false);
maker.backward([inps_ = std::move(inps), input_shapes_ = std::move(input_shapes),
param, policy, dimA, dimB](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
SmallVector<ValueRef> ret(2);
if (!grad) {
return ret;
}
size_t dimG = std::max(dimA, dimB);
if (inps_[1]) {
if (param.transposeA) {
// A^T(2) @ B(2) = G(2), A'(2) = B'(2) @ G'^T(2) -> MatrixMul
auto&& grad_op = MatrixMul::make(
param.transposeB, true, param.compute_mode, param.format,
policy.strategy, policy.workspace_limit, dimB, dimG);
ret[0] = imperative::apply(*grad_op, inps_[1], grad)[0];
} else {
// A(>=2) @ B(2) = G(>=2), A'(>=2) = G'(>=2) @ B(2) -> MatrixMul
auto&& grad_op = MatrixMul::make(
false, !param.transposeB, param.compute_mode, param.format,
policy.strategy, policy.workspace_limit, dimG, dimB);
ret[0] = imperative::apply(*grad_op, grad, inps_[1])[0];
}
}
if (inps_[0]) {
if (param.transposeB) {
// A(>=2) @ B^T(2) = G(>=2), B'(2) = G'^T(>=2) @ A(>=2) -> MatrixMul
// (specialized)
auto&& grad_op = MatrixMul::make(
true, param.transposeA, param.compute_mode, param.format,
policy.strategy, policy.workspace_limit, dimG, dimA);
ret[1] = imperative::apply(*grad_op, grad, inps_[0])[0];
} else {
// A(>=2) @ B(2) = G(>=2), B'(2) = G'(>=2) @ A(>=2) -> MatrixMul
// (specialized)
auto&& grad_op = MatrixMul::make(
!param.transposeA, false, param.compute_mode, param.format,
policy.strategy, policy.workspace_limit, dimA, dimG);
ret[1] = imperative::apply(*grad_op, inps_[0], grad)[0];
}
}
return ret;
});
maker.finalize();
return imperative::apply(ApplyOp(op), inputs);
}

std::optional<ValueRefList> batched_matrix_mul_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto&& bmm = op.cast_final_safe<BatchedMatrixMul>();
size_t dimA = bmm.dimA;
size_t dimB = bmm.dimB;
auto&& param = bmm.param();
auto&& policy = bmm.policy();
mgb_assert(inputs.size() == 2);
std::array<ValueRef, 2> inps, input_shapes;
for (size_t i = 0; i < 2; ++i) {
if (inputs_require_grad[i ^ 1]) {
inps[i] = inputs[i];
input_shapes[i] = get_shape(inputs[i]);
}
}
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false);
maker.backward([inps_ = std::move(inps), input_shapes_ = std::move(input_shapes),
param, policy, dimA, dimB](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
SmallVector<ValueRef> ret(2);
if (!grad) {
return ret;
}
size_t dimG = std::max(dimA, dimB);
if (inps_[1]) {
if (param.transposeA) {
auto&& grad_op = BatchedMatrixMul::make(
param.transposeB, true, param.compute_mode, param.format,
policy.strategy, policy.workspace_limit, dimB, dimG);
ret[0] = imperative::apply(*grad_op, inps_[1], grad)[0];
} else {
auto&& grad_op = BatchedMatrixMul::make(
false, !param.transposeB, param.compute_mode, param.format,
policy.strategy, policy.workspace_limit, dimG, dimB);
ret[0] = imperative::apply(*grad_op, grad, inps_[1])[0];
}
if (dimG != dimA) {
ret[0] = reduce_to(ret[0], input_shapes_[0]);
}
}
if (inps_[0]) {
if (param.transposeB) {
auto&& grad_op = BatchedMatrixMul::make(
true, param.transposeA, param.compute_mode, param.format,
policy.strategy, policy.workspace_limit, dimG, dimA);
ret[1] = imperative::apply(*grad_op, grad, inps_[0])[0];
} else {
auto&& grad_op = BatchedMatrixMul::make(
!param.transposeA, false, param.compute_mode, param.format,
policy.strategy, policy.workspace_limit, dimA, dimG);
ret[1] = imperative::apply(*grad_op, inps_[0], grad)[0];
}
if (dimG != dimB) {
ret[1] = reduce_to(ret[1], input_shapes_[1]);
}
}
return ret;
});
maker.finalize();
return imperative::apply(ApplyOp(op), inputs);
}

std::optional<ValueRefList> elemwise_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
@@ -395,6 +525,9 @@ struct Init {
FastpathCopy::typeinfo(), fastpathcopy_grad_rule);
CustomBackward::register_grad_rule(
PixelShuffle::typeinfo(), pixelShuffle_grad_rule);
CustomBackward::register_grad_rule(MatrixMul::typeinfo(), matrix_mul_grad_rule);
CustomBackward::register_grad_rule(
BatchedMatrixMul::typeinfo(), batched_matrix_mul_grad_rule);
}
} _;



+ 42
- 0
imperative/python/test/unit/core/test_autodiff.py View File

@@ -511,3 +511,45 @@ def test_pixel_shuffle():
y = f(x)
grad(y, F.ones_like(y))
np.testing.assert_equal(2 * x.numpy(), x.grad.numpy())


def test_matmul():
def test_one(xdim, ydim, transposeA, transposeB):
xshape = (1, 4) if xdim == 1 else (2,) * (xdim - 2) + (3, 4)
yshape = (4, 1) if ydim == 1 else (2,) * (ydim - 2) + (4, 5)
x = np.random.rand(*xshape).astype("float32")
y = np.random.rand(*yshape).astype("float32")
gshape = (x @ y).shape
g = np.random.rand(*gshape).astype("float32")
dx = g @ np.swapaxes(y, -1, -2)
dy = np.swapaxes(x, -1, -2) @ g
while dx.shape != x.shape:
dx = dx.sum(0)
while dy.shape != y.shape:
dy = dy.sum(0)
if transposeA:
x = np.swapaxes(x, -1, -2)
dx = np.swapaxes(dx, -1, -2)
if transposeB:
y = np.swapaxes(y, -1, -2)
dy = np.swapaxes(dy, -1, -2)
x = mge.Tensor(x.squeeze())
y = mge.Tensor(y.squeeze())
g = mge.Tensor(g.squeeze())
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
grad.wrt(y, callback=save_to(y))
z = F.matmul(x, y, transpose_a=transposeA, transpose_b=transposeB)
grad(z, g)
np.testing.assert_almost_equal(dx.squeeze(), x.grad.numpy(), decimal=5)
np.testing.assert_almost_equal(dy.squeeze(), y.grad.numpy(), decimal=5)

for xdim in [1, 2, 3, 4]:
for ydim in [1, 2, 3, 4]:
for transposeA in [False, True]:
if xdim == 1 and transposeA == True:
continue
for transposeB in [False, True]:
if ydim == 1 and transposeB == True:
continue
test_one(xdim, ydim, transposeA, transposeB)

+ 33
- 78
imperative/src/impl/ops/matmul.cpp View File

@@ -31,18 +31,6 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
DTypeScalar vi{-1};
auto graph = inputs[0]->owner_graph();

bool remove_row = false, remove_col = false;
if (dim1 == 1) {
dim1 = 2;
remove_row = true;
inp1 = inp1.add_axis(0);
}
if (dim2 == 1) {
dim2 = 2;
remove_col = true;
inp2 = inp2.add_axis(1);
}

SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail;
if (dim1 > 2) {
auto idx = opr::ImmutableTensor::make(*graph, vi, config);
@@ -91,17 +79,6 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
result = result.reshape(tshp);
}

auto maxdim = dim1 > dim2 ? dim1 : dim2;
if (remove_row) {
std::vector<Desc> remove_param;
remove_param.push_back(Desc::make_remove(maxdim - 2));
result = opr::AxisAddRemove::make(result, remove_param);
}
if (remove_col) {
std::vector<Desc> remove_param;
remove_param.push_back(Desc::make_remove(maxdim - 1));
result = opr::AxisAddRemove::make(result, remove_param);
}
return result;
}

@@ -113,6 +90,19 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
size_t dim1 = layout1.ndim, dim2 = layout2.ndim;

DType dst_dtype;
if (dim1 == dim2 && dim2 >= 3) { // only happens in backward
for (size_t i = 1; i + 1 < layout1.ndim; ++i) {
layout1[0] *= layout1[i];
layout2[0] *= layout2[i];
}
layout1[1] = layout1[layout1.ndim - 1];
layout1.ndim = 2;
layout1.init_contiguous_stride();
layout2[1] = layout2[layout2.ndim - 1];
layout2.ndim = 2;
layout2.init_contiguous_stride();
dim1 = dim2 = 2;
}

DnnOprCaller<megdnn::MatrixMul> dnn_opr(inputs[0].comp_node);
dnn_opr.op->param() = matmul.param();
@@ -156,6 +146,19 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
DnnOprCaller<megdnn::MatrixMul> dnn_opr(cn);
dnn_opr.op->param() = matmul.param();

if (matmul.dimA == matmul.dimB && matmul.dimB >= 3) { // only happens in backward
for (size_t i = 1; i + 1 < layout1.ndim; ++i) {
layout1[0] *= layout1[i];
layout2[0] *= layout2[i];
}
layout1[1] = layout1[layout1.ndim - 1];
layout1.ndim = 2;
layout1.init_contiguous_stride();
layout2[1] = layout2[layout2.ndim - 1];
layout2.ndim = 2;
layout2.init_contiguous_stride();
}

DType dst_dtype;
dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype);

@@ -191,11 +194,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
}

TensorLayout layout_a = layout1, layout_b = layout2;
if (dim1 == 1) {
layout_a.add_axis_cont_inplace(0);
inp_tensornds[0] = inputs[0]->dnn_tensor();
inp_tensornds[0].layout = layout_a;
} else if (dim1 > 2) {
if (dim1 > 2) {
size_t batch = std::accumulate(
layout1.shape, layout1.shape + dim1 - 1, (size_t)1,
std::multiplies<size_t>());
@@ -216,13 +215,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
inp_tensornds[0] = inputs[0]->dnn_tensor();
}

if (dim2 == 1) {
layout_b.add_axis_inplace(1, 1, 1);
inp_tensornds[1] = inputs[1]->dnn_tensor();
inp_tensornds[1].layout = layout_b;
} else {
inp_tensornds[1] = inputs[1]->dnn_tensor();
}
inp_tensornds[1] = inputs[1]->dnn_tensor();

TensorLayout dst_layout = TensorLayout({layout_a[0], layout_b[1]}, dst_dtype);
dst_layout.init_contiguous_stride();
@@ -232,6 +225,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
if (matmul.transposeB)
std::swap(layout_b.shape[0], layout_b.shape[1]);

if (matmul.dimA == matmul.dimB && matmul.dimB >= 3) { // only happens in backward
inp_tensornds[0].layout = layout_a;
inp_tensornds[1].layout = layout_b;
}

DeviceTensorND out =
BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout);
size_t sz = setup_algo<megdnn::MatrixMul>(
@@ -279,18 +277,6 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto graph = inputs[0]->owner_graph();
auto idx = opr::ImmutableTensor::make(*graph, vi, config);

bool remove_row = false, remove_col = false;
if (dim1 == 1) {
dim1 = 2;
remove_row = true;
inp1 = inp1.add_axis(0);
}
if (dim2 == 1) {
dim2 = 2;
remove_col = true;
inp2 = inp2.add_axis(1);
}

auto shp1 = inp1.symshape();
auto shp2 = inp2.symshape();
SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail;
@@ -349,16 +335,6 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
result_shp = opr::Concat::make({batch_shape, shp_tail}, 0, cn);
result = result.reshape(result_shp);
}
if (remove_row) {
std::vector<Desc> remove_param;
remove_param.push_back(Desc::make_remove(maxdim - 2));
result = opr::AxisAddRemove::make(result, remove_param);
}
if (remove_col) {
std::vector<Desc> remove_param;
remove_param.push_back(Desc::make_remove(maxdim - 1));
result = opr::AxisAddRemove::make(result, remove_param);
}
return result;
}

@@ -418,21 +394,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
DType dst_dtype;
dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype);

bool remove_row = false, remove_col = false;
if (dim1 == 1) {
dim1 = 2;
remove_row = true;
}
if (dim2 == 1) {
dim2 = 2;
remove_col = true;
}

if (remove_row)
layout1.add_axis_cont_inplace(0);
if (remove_col)
layout2.add_axis_inplace(1, 1, 1);

TensorShape tshp, batch_shp;
size_t j = 0;
auto inp1 = inputs[0], inp2 = inputs[1];
@@ -530,12 +491,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
if (maxdim > 3) {
dst_layout = dst_layout.reshape(shp1);
}
if (remove_row) {
dst_layout = dst_layout.remove_axis(maxdim - 2);
}
if (remove_col) {
dst_layout = dst_layout.remove_axis(maxdim - 1);
}
return {Tensor::make(out.sub(SubTensorSpec::make_from_layout(dst_layout)))};
}



Loading…
Cancel
Save