From 20719502978dc8b15028ce8ba9c5b65d9e02afb9 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 2 Nov 2022 13:59:25 +0800
Subject: [PATCH] fix(imperative): fix matmul deduce dtype

GitOrigin-RevId: 24f4e1f9fc1fb58b3443d04c13e95e7493d119d1
---
 imperative/python/test/unit/quantization/test_op.py | 16 ++++++++++++++++
 imperative/src/impl/ops/matmul.cpp                  |  8 ++++----
 2 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/imperative/python/test/unit/quantization/test_op.py b/imperative/python/test/unit/quantization/test_op.py
index 31103716..51a6a06b 100644
--- a/imperative/python/test/unit/quantization/test_op.py
+++ b/imperative/python/test/unit/quantization/test_op.py
@@ -340,3 +340,19 @@ def test_conv_transpose2d():
     test_func(2, 4, 3, 1, 8, 1, 1, 1, 1, 0, 0, 1, 1, 1, False)
     test_func(4, 4, 16, 16, 8, 3, 3, 1, 1, 1, 1, 1, 1, 1, False)
     test_func(32, 64, 36, 28, 16, 3, 2, 1, 3, 1, 0, 1, 1, 1, False)
+
+
+def test_matmul():
+    inp_scale = np.float32(np.random.rand())
+    weight_scale = np.float32(np.random.rand())
+    inp_dtype = dtype.qint8(inp_scale)
+    weight_dtype = dtype.qint8(weight_scale)
+
+    inp_data = np.random.random((3, 12))
+    weight_data = np.random.random((5, 12))
+    inp_int8 = mge.tensor(dtype.convert_to_qint8(inp_data, inp_dtype))
+    weight_int8 = mge.tensor(dtype.convert_to_qint8(weight_data, weight_dtype))
+
+    res = F.matmul(inp_int8, weight_int8, transpose_b=True)
+    res_scale = dtype.get_scale(res.dtype)
+    np.testing.assert_allclose(inp_scale * weight_scale, res_scale)
diff --git a/imperative/src/impl/ops/matmul.cpp b/imperative/src/impl/ops/matmul.cpp
index 5fb4d199..a5b28b31 100644
--- a/imperative/src/impl/ops/matmul.cpp
+++ b/imperative/src/impl/ops/matmul.cpp
@@ -104,7 +104,7 @@ std::tuple, bool> infer_output_attrs_fallible(
     }
 
     DnnOprHelper dnn_opr(matmul.param());
-    dnn_opr.opr().deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype);
+    dnn_opr.opr().deduce_dtype(layout1.dtype, layout2.dtype, dst_dtype);
 
     if (dim1 == 0 || dim2 == 0) {
         return {{{TensorLayout(dst_dtype), inputs[0].comp_node}}, false};
@@ -157,7 +157,7 @@ SmallVector apply_on_physical_tensor(
     }
 
     DType dst_dtype;
-    dnn_opr.op()->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype);
+    dnn_opr.op()->deduce_dtype(layout1.dtype, layout2.dtype, dst_dtype);
 
     // only matters when layout1 has dim 2
     if (matmul.transposeA)
@@ -335,7 +335,7 @@ std::tuple, bool> infer_output_attrs_fallible(
 
     DType dst_dtype;
     DnnOprHelper dnn_opr(matmul.param());
-    dnn_opr.opr().deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype);
+    dnn_opr.opr().deduce_dtype(layout1.dtype, layout2.dtype, dst_dtype);
 
     if (dim1 == 0 || dim2 == 0) {
         return {{{TensorLayout(dst_dtype), inputs[0].comp_node}}, false};
@@ -378,7 +378,7 @@ SmallVector apply_on_physical_tensor(
     DnnOprCaller dnn_opr(cn, matmul.param(), matmul.policy());
 
     DType dst_dtype;
-    dnn_opr.op()->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype);
+    dnn_opr.op()->deduce_dtype(layout1.dtype, layout2.dtype, dst_dtype);
 
     TensorShape tshp, batch_shp;
     size_t j = 0;