|
@@ -586,7 +586,7 @@ bool ConvolutionBackwardDataImpl::AlgoMatrixMul::is_preferred( |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
/* ===================== Matrix mul nchw44 algo ===================== */ |
|
|
/* ===================== Matrix mul nchw44 algo ===================== */ |
|
|
namespace{ |
|
|
|
|
|
|
|
|
namespace { |
|
|
void kern_matmul_nchw44(const NCBKernParam& param) { |
|
|
void kern_matmul_nchw44(const NCBKernParam& param) { |
|
|
bool is_xcorr = !param.filter_meta.should_flip; |
|
|
bool is_xcorr = !param.filter_meta.should_flip; |
|
|
UNPACK_CONV_F32_NCB_KERN_SIZES(param); |
|
|
UNPACK_CONV_F32_NCB_KERN_SIZES(param); |
|
@@ -628,7 +628,7 @@ void kern_matmul_nchw44(const NCBKernParam& param) { |
|
|
} |
|
|
} |
|
|
TensorND B_, C_; |
|
|
TensorND B_, C_; |
|
|
for (size_t n = 0; n < N; ++n) { |
|
|
for (size_t n = 0; n < N; ++n) { |
|
|
float*C_src, *C_dst; |
|
|
|
|
|
|
|
|
float *C_src, *C_dst; |
|
|
float* diff = const_cast<float*>(param.diff<float>() + n * param.inp_bs); |
|
|
float* diff = const_cast<float*>(param.diff<float>() + n * param.inp_bs); |
|
|
float* grad = param.grad<float>() + n * param.out_bs; |
|
|
float* grad = param.grad<float>() + n * param.out_bs; |
|
|
if (is1X1) { |
|
|
if (is1X1) { |
|
@@ -637,13 +637,13 @@ void kern_matmul_nchw44(const NCBKernParam& param) { |
|
|
C_src = static_cast<float*>(bundle.get(0)); |
|
|
C_src = static_cast<float*>(bundle.get(0)); |
|
|
} |
|
|
} |
|
|
{ |
|
|
{ |
|
|
B_.layout = TensorLayout({OC/4, IH * IW, 4}, param.diff_type); |
|
|
|
|
|
|
|
|
B_.layout = TensorLayout({OC / 4, IH * IW, 4}, param.diff_type); |
|
|
B_.reset_ptr(static_cast<void*>(diff)); |
|
|
B_.reset_ptr(static_cast<void*>(diff)); |
|
|
C_.layout = TensorLayout({IC / 4 * FH * FW, IH * IW, 4}, param.grad_type); |
|
|
C_.layout = TensorLayout({IC / 4 * FH * FW, IH * IW, 4}, param.grad_type); |
|
|
C_.reset_ptr(C_src); |
|
|
C_.reset_ptr(C_src); |
|
|
Workspace workspace( |
|
|
Workspace workspace( |
|
|
static_cast<dt_byte*>(bundle.get(1)), bundle.get_size(1)); |
|
|
static_cast<dt_byte*>(bundle.get(1)), bundle.get_size(1)); |
|
|
auto matmul_opr =get_matmul_opr(param); |
|
|
|
|
|
|
|
|
auto matmul_opr = get_matmul_opr(param); |
|
|
matmul_opr->exec(A_dst, B_, C_, workspace); |
|
|
matmul_opr->exec(A_dst, B_, C_, workspace); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|