
fix(imperative): fix matmul deduce dtype

GitOrigin-RevId: 24f4e1f9fc
master
Megvii Engine Team, 2 years ago
parent commit 2071950297
2 changed files with 20 additions and 4 deletions
  1. imperative/python/test/unit/quantization/test_op.py  (+16, -0)
  2. imperative/src/impl/ops/matmul.cpp  (+4, -4)
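
Note (not part of the commit): all four changed call sites passed layout1.dtype as both operand dtypes to deduce_dtype, so the second input's dtype was ignored when inferring the output dtype. For quantized qint8 inputs this makes the deduced result scale depend only on the first input's scale; the new unit test below checks that the result scale of a qint8 matmul equals inp_scale * weight_scale.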

imperative/python/test/unit/quantization/test_op.py  (+16, -0)

@@ -340,3 +340,19 @@ def test_conv_transpose2d():
     test_func(2, 4, 3, 1, 8, 1, 1, 1, 1, 0, 0, 1, 1, 1, False)
     test_func(4, 4, 16, 16, 8, 3, 3, 1, 1, 1, 1, 1, 1, 1, False)
     test_func(32, 64, 36, 28, 16, 3, 2, 1, 3, 1, 0, 1, 1, 1, False)
+
+
+def test_matmul():
+    inp_scale = np.float32(np.random.rand())
+    weight_scale = np.float32(np.random.rand())
+    inp_dtype = dtype.qint8(inp_scale)
+    weight_dtype = dtype.qint8(weight_scale)
+
+    inp_data = np.random.random((3, 12))
+    weight_data = np.random.random((5, 12))
+    inp_int8 = mge.tensor(dtype.convert_to_qint8(inp_data, inp_dtype))
+    weight_int8 = mge.tensor(dtype.convert_to_qint8(weight_data, weight_dtype))
+
+    res = F.matmul(inp_int8, weight_int8, transpose_b=True)
+    res_scale = dtype.get_scale(res.dtype)
+    np.testing.assert_allclose(inp_scale * weight_scale, res_scale)
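
Not part of the commit: a minimal NumPy sketch of the arithmetic the assertion above relies on. Dequantizing an int8 x int8 accumulation with the product of the two input scales recovers the real-valued product, which is why the deduced result scale must be inp_scale * weight_scale rather than a value derived from the first input's dtype twice. The scales, shapes, and data below are arbitrary stand-ins for the test's random values.

    import numpy as np

    # Arbitrary example scales and int8-range data.
    inp_scale, weight_scale = 0.03, 0.07
    inp_q = np.random.randint(-128, 128, size=(3, 12)).astype(np.int32)
    weight_q = np.random.randint(-128, 128, size=(5, 12)).astype(np.int32)

    # Integer accumulation, as a quantized matmul kernel would produce (transpose_b=True).
    acc = inp_q @ weight_q.T

    # Real-valued product of the dequantized inputs.
    real = (inp_q * inp_scale) @ (weight_q * weight_scale).T

    # Dequantizing the accumulator with inp_scale * weight_scale matches the real result,
    # so that product is the correct scale for the deduced output dtype.
    np.testing.assert_allclose(acc * (inp_scale * weight_scale), real, rtol=1e-6)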

imperative/src/impl/ops/matmul.cpp  (+4, -4)

@@ -104,7 +104,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     }
 
     DnnOprHelper<megdnn::MatrixMul> dnn_opr(matmul.param());
-    dnn_opr.opr().deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype);
+    dnn_opr.opr().deduce_dtype(layout1.dtype, layout2.dtype, dst_dtype);
 
     if (dim1 == 0 || dim2 == 0) {
         return {{{TensorLayout(dst_dtype), inputs[0].comp_node}}, false};
@@ -157,7 +157,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     }
 
     DType dst_dtype;
-    dnn_opr.op()->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype);
+    dnn_opr.op()->deduce_dtype(layout1.dtype, layout2.dtype, dst_dtype);
 
     // only matters when layout1 has dim 2
     if (matmul.transposeA)
@@ -335,7 +335,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     DType dst_dtype;
 
     DnnOprHelper<megdnn::MatrixMul> dnn_opr(matmul.param());
-    dnn_opr.opr().deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype);
+    dnn_opr.opr().deduce_dtype(layout1.dtype, layout2.dtype, dst_dtype);
 
     if (dim1 == 0 || dim2 == 0) {
         return {{{TensorLayout(dst_dtype), inputs[0].comp_node}}, false};
@@ -378,7 +378,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
 
     DnnOprCaller<megdnn::BatchedMatrixMul> dnn_opr(cn, matmul.param(), matmul.policy());
     DType dst_dtype;
-    dnn_opr.op()->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype);
+    dnn_opr.op()->deduce_dtype(layout1.dtype, layout2.dtype, dst_dtype);
 
     TensorShape tshp, batch_shp;
     size_t j = 0;
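
The same one-token fix (passing layout2.dtype instead of layout1.dtype as the second operand dtype) is applied at all four call sites, two in infer_output_attrs_fallible and two in apply_on_physical_tensor, covering both the plain and the batched matmul paths, so the dtype deduced during inference matches the one used when the op actually executes.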

