@@ -13,6 +13,7 @@ from .._imperative_rt.core2 import ( | |||
astype_cpp, | |||
batched_matmul_cpp, | |||
broadcast_cpp, | |||
expand_dims_cpp, | |||
getitem_cpp, | |||
matmul_cpp, | |||
reshape_cpp, | |||
@@ -62,7 +63,6 @@ def _matmul( | |||
assert dim1 > 0 and dim2 > 0 | |||
maxdim = dim1 if dim1 > dim2 else dim2 | |||
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) | |||
if dim1 == 1 and dim2 == 1: # dispatch to Dot | |||
(result,) = apply(builtin.Dot(), inp1, inp2) | |||
return result | |||
@@ -72,34 +72,44 @@ def _matmul( | |||
# 2x2 | |||
# nx1(transpose_a=False), n>=3 | |||
# nx2(transpose_a=False), n>=3 | |||
return matmul_cpp( | |||
inp1, | |||
inp2, | |||
dim1, | |||
dim2, | |||
ret = matmul_cpp( | |||
inp1 if dim1 > 1 else expand_dims_cpp(inp1, 0), | |||
inp2 if dim2 > 1 else expand_dims_cpp(inp2, -1), | |||
max(dim1, 2), | |||
max(dim2, 2), | |||
transpose_a, | |||
transpose_b, | |||
compute_mode, | |||
_config._benchmark_kernel, | |||
_config._deterministic_kernel, | |||
) | |||
if dim1 == 1: | |||
ret = squeeze_cpp(ret, -2) | |||
elif dim2 == 1: | |||
ret = squeeze_cpp(ret, -1) | |||
return ret | |||
else: # dispath to BatchedMatrixMul | |||
# nx1(transpose_a=True), n>=3 | |||
# nx2(transpose_a=True), n>=3 | |||
# nxm,n>=3,m>=3 | |||
# 1xm,m>=3 | |||
# 2xm,m>=3 | |||
return batched_matmul_cpp( | |||
inp1, | |||
inp2, | |||
dim1, | |||
dim2, | |||
ret = batched_matmul_cpp( | |||
inp1 if dim1 > 1 else expand_dims_cpp(inp1, 0), | |||
inp2 if dim2 > 1 else expand_dims_cpp(inp2, -1), | |||
max(dim1, 2), | |||
max(dim2, 2), | |||
transpose_a, | |||
transpose_b, | |||
compute_mode, | |||
_config._benchmark_kernel, | |||
_config._deterministic_kernel, | |||
) | |||
if dim1 == 1: | |||
ret = squeeze_cpp(ret, -2) | |||
elif dim2 == 1: | |||
ret = squeeze_cpp(ret, -1) | |||
return ret | |||
def _unary_elwise(mode): | |||
@@ -87,6 +87,136 @@ ValueRef make_empty_tensor( | |||
return res; | |||
} | |||
std::optional<ValueRefList> matrix_mul_grad_rule( | |||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
CustomBackward& backward) { | |||
auto&& matmul = op.cast_final_safe<MatrixMul>(); | |||
size_t dimA = matmul.dimA; | |||
size_t dimB = matmul.dimB; | |||
auto&& param = matmul.param(); | |||
auto&& policy = matmul.policy(); | |||
mgb_assert(inputs.size() == 2); | |||
std::array<ValueRef, 2> inps, input_shapes; | |||
for (size_t i = 0; i < 2; ++i) { | |||
if (inputs_require_grad[i ^ 1]) { | |||
inps[i] = inputs[i]; | |||
input_shapes[i] = get_shape(inputs[i]); | |||
} | |||
} | |||
auto maker = CustomGradMaker(backward, inputs.size()); | |||
maker.output_size(1).output_captured(0, false); | |||
maker.backward([inps_ = std::move(inps), input_shapes_ = std::move(input_shapes), | |||
param, policy, dimA, dimB](Span<ValueRef> grads) { | |||
mgb_assert(grads.size() == 1); | |||
ValueRef grad = grads[0]; | |||
SmallVector<ValueRef> ret(2); | |||
if (!grad) { | |||
return ret; | |||
} | |||
size_t dimG = std::max(dimA, dimB); | |||
if (inps_[1]) { | |||
if (param.transposeA) { | |||
// A^T(2) @ B(2) = G(2), A'(2) = B'(2) @ G'^T(2) -> MatrixMul | |||
auto&& grad_op = MatrixMul::make( | |||
param.transposeB, true, param.compute_mode, param.format, | |||
policy.strategy, policy.workspace_limit, dimB, dimG); | |||
ret[0] = imperative::apply(*grad_op, inps_[1], grad)[0]; | |||
} else { | |||
// A(>=2) @ B(2) = G(>=2), A'(>=2) = G'(>=2) @ B(2) -> MatrixMul | |||
auto&& grad_op = MatrixMul::make( | |||
false, !param.transposeB, param.compute_mode, param.format, | |||
policy.strategy, policy.workspace_limit, dimG, dimB); | |||
ret[0] = imperative::apply(*grad_op, grad, inps_[1])[0]; | |||
} | |||
} | |||
if (inps_[0]) { | |||
if (param.transposeB) { | |||
// A(>=2) @ B^T(2) = G(>=2), B'(2) = G'^T(>=2) @ A(>=2) -> MatrixMul | |||
// (specialized) | |||
auto&& grad_op = MatrixMul::make( | |||
true, param.transposeA, param.compute_mode, param.format, | |||
policy.strategy, policy.workspace_limit, dimG, dimA); | |||
ret[1] = imperative::apply(*grad_op, grad, inps_[0])[0]; | |||
} else { | |||
// A(>=2) @ B(2) = G(>=2), B'(2) = G'(>=2) @ A(>=2) -> MatrixMul | |||
// (specialized) | |||
auto&& grad_op = MatrixMul::make( | |||
!param.transposeA, false, param.compute_mode, param.format, | |||
policy.strategy, policy.workspace_limit, dimA, dimG); | |||
ret[1] = imperative::apply(*grad_op, inps_[0], grad)[0]; | |||
} | |||
} | |||
return ret; | |||
}); | |||
maker.finalize(); | |||
return imperative::apply(ApplyOp(op), inputs); | |||
} | |||
std::optional<ValueRefList> batched_matrix_mul_grad_rule( | |||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
CustomBackward& backward) { | |||
auto&& bmm = op.cast_final_safe<BatchedMatrixMul>(); | |||
size_t dimA = bmm.dimA; | |||
size_t dimB = bmm.dimB; | |||
auto&& param = bmm.param(); | |||
auto&& policy = bmm.policy(); | |||
mgb_assert(inputs.size() == 2); | |||
std::array<ValueRef, 2> inps, input_shapes; | |||
for (size_t i = 0; i < 2; ++i) { | |||
if (inputs_require_grad[i ^ 1]) { | |||
inps[i] = inputs[i]; | |||
input_shapes[i] = get_shape(inputs[i]); | |||
} | |||
} | |||
auto maker = CustomGradMaker(backward, inputs.size()); | |||
maker.output_size(1).output_captured(0, false); | |||
maker.backward([inps_ = std::move(inps), input_shapes_ = std::move(input_shapes), | |||
param, policy, dimA, dimB](Span<ValueRef> grads) { | |||
mgb_assert(grads.size() == 1); | |||
ValueRef grad = grads[0]; | |||
SmallVector<ValueRef> ret(2); | |||
if (!grad) { | |||
return ret; | |||
} | |||
size_t dimG = std::max(dimA, dimB); | |||
if (inps_[1]) { | |||
if (param.transposeA) { | |||
auto&& grad_op = BatchedMatrixMul::make( | |||
param.transposeB, true, param.compute_mode, param.format, | |||
policy.strategy, policy.workspace_limit, dimB, dimG); | |||
ret[0] = imperative::apply(*grad_op, inps_[1], grad)[0]; | |||
} else { | |||
auto&& grad_op = BatchedMatrixMul::make( | |||
false, !param.transposeB, param.compute_mode, param.format, | |||
policy.strategy, policy.workspace_limit, dimG, dimB); | |||
ret[0] = imperative::apply(*grad_op, grad, inps_[1])[0]; | |||
} | |||
if (dimG != dimA) { | |||
ret[0] = reduce_to(ret[0], input_shapes_[0]); | |||
} | |||
} | |||
if (inps_[0]) { | |||
if (param.transposeB) { | |||
auto&& grad_op = BatchedMatrixMul::make( | |||
true, param.transposeA, param.compute_mode, param.format, | |||
policy.strategy, policy.workspace_limit, dimG, dimA); | |||
ret[1] = imperative::apply(*grad_op, grad, inps_[0])[0]; | |||
} else { | |||
auto&& grad_op = BatchedMatrixMul::make( | |||
!param.transposeA, false, param.compute_mode, param.format, | |||
policy.strategy, policy.workspace_limit, dimA, dimG); | |||
ret[1] = imperative::apply(*grad_op, inps_[0], grad)[0]; | |||
} | |||
if (dimG != dimB) { | |||
ret[1] = reduce_to(ret[1], input_shapes_[1]); | |||
} | |||
} | |||
return ret; | |||
}); | |||
maker.finalize(); | |||
return imperative::apply(ApplyOp(op), inputs); | |||
} | |||
std::optional<ValueRefList> elemwise_grad_rule( | |||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
CustomBackward& backward) { | |||
@@ -395,6 +525,9 @@ struct Init { | |||
FastpathCopy::typeinfo(), fastpathcopy_grad_rule); | |||
CustomBackward::register_grad_rule( | |||
PixelShuffle::typeinfo(), pixelShuffle_grad_rule); | |||
CustomBackward::register_grad_rule(MatrixMul::typeinfo(), matrix_mul_grad_rule); | |||
CustomBackward::register_grad_rule( | |||
BatchedMatrixMul::typeinfo(), batched_matrix_mul_grad_rule); | |||
} | |||
} _; | |||
@@ -511,3 +511,45 @@ def test_pixel_shuffle(): | |||
y = f(x) | |||
grad(y, F.ones_like(y)) | |||
np.testing.assert_equal(2 * x.numpy(), x.grad.numpy()) | |||
def test_matmul(): | |||
def test_one(xdim, ydim, transposeA, transposeB): | |||
xshape = (1, 4) if xdim == 1 else (2,) * (xdim - 2) + (3, 4) | |||
yshape = (4, 1) if ydim == 1 else (2,) * (ydim - 2) + (4, 5) | |||
x = np.random.rand(*xshape).astype("float32") | |||
y = np.random.rand(*yshape).astype("float32") | |||
gshape = (x @ y).shape | |||
g = np.random.rand(*gshape).astype("float32") | |||
dx = g @ np.swapaxes(y, -1, -2) | |||
dy = np.swapaxes(x, -1, -2) @ g | |||
while dx.shape != x.shape: | |||
dx = dx.sum(0) | |||
while dy.shape != y.shape: | |||
dy = dy.sum(0) | |||
if transposeA: | |||
x = np.swapaxes(x, -1, -2) | |||
dx = np.swapaxes(dx, -1, -2) | |||
if transposeB: | |||
y = np.swapaxes(y, -1, -2) | |||
dy = np.swapaxes(dy, -1, -2) | |||
x = mge.Tensor(x.squeeze()) | |||
y = mge.Tensor(y.squeeze()) | |||
g = mge.Tensor(g.squeeze()) | |||
with Grad() as grad: | |||
grad.wrt(x, callback=save_to(x)) | |||
grad.wrt(y, callback=save_to(y)) | |||
z = F.matmul(x, y, transpose_a=transposeA, transpose_b=transposeB) | |||
grad(z, g) | |||
np.testing.assert_almost_equal(dx.squeeze(), x.grad.numpy(), decimal=5) | |||
np.testing.assert_almost_equal(dy.squeeze(), y.grad.numpy(), decimal=5) | |||
for xdim in [1, 2, 3, 4]: | |||
for ydim in [1, 2, 3, 4]: | |||
for transposeA in [False, True]: | |||
if xdim == 1 and transposeA == True: | |||
continue | |||
for transposeB in [False, True]: | |||
if ydim == 1 and transposeB == True: | |||
continue | |||
test_one(xdim, ydim, transposeA, transposeB) |
@@ -31,18 +31,6 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
DTypeScalar vi{-1}; | |||
auto graph = inputs[0]->owner_graph(); | |||
bool remove_row = false, remove_col = false; | |||
if (dim1 == 1) { | |||
dim1 = 2; | |||
remove_row = true; | |||
inp1 = inp1.add_axis(0); | |||
} | |||
if (dim2 == 1) { | |||
dim2 = 2; | |||
remove_col = true; | |||
inp2 = inp2.add_axis(1); | |||
} | |||
SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail; | |||
if (dim1 > 2) { | |||
auto idx = opr::ImmutableTensor::make(*graph, vi, config); | |||
@@ -91,17 +79,6 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
result = result.reshape(tshp); | |||
} | |||
auto maxdim = dim1 > dim2 ? dim1 : dim2; | |||
if (remove_row) { | |||
std::vector<Desc> remove_param; | |||
remove_param.push_back(Desc::make_remove(maxdim - 2)); | |||
result = opr::AxisAddRemove::make(result, remove_param); | |||
} | |||
if (remove_col) { | |||
std::vector<Desc> remove_param; | |||
remove_param.push_back(Desc::make_remove(maxdim - 1)); | |||
result = opr::AxisAddRemove::make(result, remove_param); | |||
} | |||
return result; | |||
} | |||
@@ -113,6 +90,19 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
size_t dim1 = layout1.ndim, dim2 = layout2.ndim; | |||
DType dst_dtype; | |||
if (dim1 == dim2 && dim2 >= 3) { // only happens in backward | |||
for (size_t i = 1; i + 1 < layout1.ndim; ++i) { | |||
layout1[0] *= layout1[i]; | |||
layout2[0] *= layout2[i]; | |||
} | |||
layout1[1] = layout1[layout1.ndim - 1]; | |||
layout1.ndim = 2; | |||
layout1.init_contiguous_stride(); | |||
layout2[1] = layout2[layout2.ndim - 1]; | |||
layout2.ndim = 2; | |||
layout2.init_contiguous_stride(); | |||
dim1 = dim2 = 2; | |||
} | |||
DnnOprCaller<megdnn::MatrixMul> dnn_opr(inputs[0].comp_node); | |||
dnn_opr.op->param() = matmul.param(); | |||
@@ -156,6 +146,19 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
DnnOprCaller<megdnn::MatrixMul> dnn_opr(cn); | |||
dnn_opr.op->param() = matmul.param(); | |||
if (matmul.dimA == matmul.dimB && matmul.dimB >= 3) { // only happens in backward | |||
for (size_t i = 1; i + 1 < layout1.ndim; ++i) { | |||
layout1[0] *= layout1[i]; | |||
layout2[0] *= layout2[i]; | |||
} | |||
layout1[1] = layout1[layout1.ndim - 1]; | |||
layout1.ndim = 2; | |||
layout1.init_contiguous_stride(); | |||
layout2[1] = layout2[layout2.ndim - 1]; | |||
layout2.ndim = 2; | |||
layout2.init_contiguous_stride(); | |||
} | |||
DType dst_dtype; | |||
dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype); | |||
@@ -191,11 +194,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
} | |||
TensorLayout layout_a = layout1, layout_b = layout2; | |||
if (dim1 == 1) { | |||
layout_a.add_axis_cont_inplace(0); | |||
inp_tensornds[0] = inputs[0]->dnn_tensor(); | |||
inp_tensornds[0].layout = layout_a; | |||
} else if (dim1 > 2) { | |||
if (dim1 > 2) { | |||
size_t batch = std::accumulate( | |||
layout1.shape, layout1.shape + dim1 - 1, (size_t)1, | |||
std::multiplies<size_t>()); | |||
@@ -216,13 +215,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
inp_tensornds[0] = inputs[0]->dnn_tensor(); | |||
} | |||
if (dim2 == 1) { | |||
layout_b.add_axis_inplace(1, 1, 1); | |||
inp_tensornds[1] = inputs[1]->dnn_tensor(); | |||
inp_tensornds[1].layout = layout_b; | |||
} else { | |||
inp_tensornds[1] = inputs[1]->dnn_tensor(); | |||
} | |||
inp_tensornds[1] = inputs[1]->dnn_tensor(); | |||
TensorLayout dst_layout = TensorLayout({layout_a[0], layout_b[1]}, dst_dtype); | |||
dst_layout.init_contiguous_stride(); | |||
@@ -232,6 +225,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
if (matmul.transposeB) | |||
std::swap(layout_b.shape[0], layout_b.shape[1]); | |||
if (matmul.dimA == matmul.dimB && matmul.dimB >= 3) { // only happens in backward | |||
inp_tensornds[0].layout = layout_a; | |||
inp_tensornds[1].layout = layout_b; | |||
} | |||
DeviceTensorND out = | |||
BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout); | |||
size_t sz = setup_algo<megdnn::MatrixMul>( | |||
@@ -279,18 +277,6 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto graph = inputs[0]->owner_graph(); | |||
auto idx = opr::ImmutableTensor::make(*graph, vi, config); | |||
bool remove_row = false, remove_col = false; | |||
if (dim1 == 1) { | |||
dim1 = 2; | |||
remove_row = true; | |||
inp1 = inp1.add_axis(0); | |||
} | |||
if (dim2 == 1) { | |||
dim2 = 2; | |||
remove_col = true; | |||
inp2 = inp2.add_axis(1); | |||
} | |||
auto shp1 = inp1.symshape(); | |||
auto shp2 = inp2.symshape(); | |||
SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail; | |||
@@ -349,16 +335,6 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
result_shp = opr::Concat::make({batch_shape, shp_tail}, 0, cn); | |||
result = result.reshape(result_shp); | |||
} | |||
if (remove_row) { | |||
std::vector<Desc> remove_param; | |||
remove_param.push_back(Desc::make_remove(maxdim - 2)); | |||
result = opr::AxisAddRemove::make(result, remove_param); | |||
} | |||
if (remove_col) { | |||
std::vector<Desc> remove_param; | |||
remove_param.push_back(Desc::make_remove(maxdim - 1)); | |||
result = opr::AxisAddRemove::make(result, remove_param); | |||
} | |||
return result; | |||
} | |||
@@ -418,21 +394,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
DType dst_dtype; | |||
dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype); | |||
bool remove_row = false, remove_col = false; | |||
if (dim1 == 1) { | |||
dim1 = 2; | |||
remove_row = true; | |||
} | |||
if (dim2 == 1) { | |||
dim2 = 2; | |||
remove_col = true; | |||
} | |||
if (remove_row) | |||
layout1.add_axis_cont_inplace(0); | |||
if (remove_col) | |||
layout2.add_axis_inplace(1, 1, 1); | |||
TensorShape tshp, batch_shp; | |||
size_t j = 0; | |||
auto inp1 = inputs[0], inp2 = inputs[1]; | |||
@@ -530,12 +491,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
if (maxdim > 3) { | |||
dst_layout = dst_layout.reshape(shp1); | |||
} | |||
if (remove_row) { | |||
dst_layout = dst_layout.remove_axis(maxdim - 2); | |||
} | |||
if (remove_col) { | |||
dst_layout = dst_layout.remove_axis(maxdim - 1); | |||
} | |||
return {Tensor::make(out.sub(SubTensorSpec::make_from_layout(dst_layout)))}; | |||
} | |||