|
|
@@ -21,11 +21,11 @@ using namespace cuda; |
|
|
|
|
|
|
|
namespace { |
|
|
|
std::pair<TensorLayoutArray, MatrixMulForward::Param> sub_opr_config( |
|
|
|
const ConvolutionBackwardDataImpl::CanonizedFilterMeta& fm, |
|
|
|
const ConvolutionBackwardFilterImpl::CanonizedFilterMeta& fm, |
|
|
|
const TensorLayout& src_layout, const TensorLayout& diff_layout, |
|
|
|
const TensorLayout& grad_layout, |
|
|
|
const ConvolutionBackwardFilterImpl* opr) { |
|
|
|
size_t N = grad_layout.shape[0], IC = fm.icpg, |
|
|
|
size_t N = src_layout.shape[0], IC = fm.icpg, |
|
|
|
OC = fm.ocpg, OH = diff_layout.shape[2], |
|
|
|
OW = diff_layout.shape[3], FH = fm.spatial[0], |
|
|
|
FW = fm.spatial[1]; |
|
|
|