GitOrigin-RevId: 05291baf98
tags/v1.3.0
@@ -102,24 +102,24 @@ class DeformableConvBackwardDataImpl::AlgoMatmul final : public AlgoBase {
 private:
     static WorkspaceBundle get_bundle(const SizeArgs& args);
-    static void get_matmul_layout(const SizeArgs& args, TensorLayout& al,
-                                  TensorLayout& bl, TensorLayout& cl);
 
 public:
-    AlgoMatmul() {}
     bool is_available(const SizeArgs& args) const override;
     size_t get_workspace_in_bytes(const SizeArgs& args) const override;
     void exec(const ExecArgs& args) const override;
     bool is_reproducible() const override { return true; }
-    const char* name() const override { return "AlgoMatmul"; }
+    std::vector<SearchItem> get_subopr_list(
+            const TensorLayoutArray& layouts,
+            const OperatorBase* opr) const override;
+    const char* name() const override { return "MATMUL"; }
     MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
 };
 
 class DeformableConvBackwardDataImpl::AlgoPack : NonCopyableObj {
     AlgoBase::Mapper m_all_algos_map;
 
 public:
     AlgoPack();
     AlgoMatmul algo_matmul;
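
Reviewer note (not part of the diff): get_subopr_list() is the hook that exposes the algorithm's internal batched matmul to the dispatcher, so the fastrun profiler can search and cache the sub-operator together with the parent. A hedged sketch of how a caller might walk the returned items; the SearchItem field names opr_type/param/layouts are inferred from the brace-initializer in the .cpp hunks below, and profile_sub_operator() is hypothetical:

    // Minimal consumption sketch; field names inferred from the
    // `{OprType::BATCHED_MATRIX_MUL_FORWARD, param_str, config.first}`
    // initializer used in this change.
    for (const auto& item : algo->get_subopr_list(layouts, opr)) {
        // item.param holds the serialized BatchedMatrixMulForward::Param;
        // item.layouts are the {A, B, C} layouts the sub-opr will run with.
        profile_sub_operator(item.opr_type, item.param, item.layouts);  // hypothetical helper
    }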
@@ -57,24 +57,47 @@ deformable_conv::Param create_param(const Algo::SizeArgs& args,
     return p;
 }
-}; // anonymous namespace
-
-bool Algo::is_available(const SizeArgs&) const {
-    return true;
+std::pair<TensorLayoutArray, BatchedMatrixMulForward::Param> sub_opr_config(
+        const DeformableConvForwardImpl::CanonizedFilterMeta& fm,
+        const TensorLayout& im, const TensorLayout& out_grad) {
+    auto&& dt = im.dtype;
+    size_t batch_sz = im[0], OH = out_grad[2], OW = out_grad[3],
+           FH = fm.spatial[0], FW = fm.spatial[1];
+    size_t M = fm.icpg * FH * FW, K = fm.ocpg, N = batch_sz * OH * OW,
+           batch = fm.group;
+    TensorLayout al = {{batch, K, M}, dt};
+    TensorLayout bl = {{batch, K, N}, dt};
+    TensorLayout cl = {{batch, M, N}, dt};
+    BatchedMatrixMulForward::Param param;
+    param.compute_mode = param::MatrixMul::ComputeMode::DEFAULT;
+    param.transposeA = true;
+    return {{al, bl, cl}, param};
 }
-void Algo::get_matmul_layout(const SizeArgs& args, TensorLayout& al,
-                             TensorLayout& bl, TensorLayout& cl) {
-    auto&& dt = args.im_layout.dtype;
-    auto&& fm = args.filter_meta;
-    size_t batch_sz = args.im_layout[0], OH = args.out_grad_layout[2],
-           OW = args.out_grad_layout[3], FH = fm.spatial[0], FW = fm.spatial[1];
-    size_t M = fm.icpg * FH * FW, K = fm.ocpg, N = batch_sz * OH * OW,
-           batch = fm.group;
-    al = {{batch, K, M}, dt};
-    bl = {{batch, K, N}, dt};
-    cl = {{batch, M, N}, dt};
+}; // anonymous namespace
+
+std::vector<Algorithm::SearchItem> Algo::get_subopr_list(
+        const TensorLayoutArray& layouts, const OperatorBase* opr) const {
+    const DeformableConvBackwardDataImpl* deformable_conv =
+            static_cast<const DeformableConvBackwardDataImpl*>(opr);
+    CanonizedFilterMeta fm = deformable_conv->make_canonized_filter_meta(
+            layouts[0].ndim, layouts[1], layouts[2]);
+    auto&& config = sub_opr_config(fm, layouts[0], layouts[4]);
+
+    std::string param_str;
+    Algorithm::serialize_write_pod(config.second, param_str);
+    return {{Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD, param_str,
+             config.first}};
+}
+
+bool Algo::is_available(const SizeArgs&) const {
+    return true;
 }
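
Aside (not in the diff): with transposeA set, the sub-operator computes C = A^T * B within each of the `group` batches, turning the ocpg-by-spatial output gradient back into im2col columns. A self-contained sketch with made-up sizes, just to make the M/K/N roles above concrete:

    #include <cassert>
    #include <cstddef>

    int main() {
        // Illustrative sizes only: icpg=16, ocpg=32, 3x3 filter,
        // batch 8, 28x28 output gradient.
        std::size_t icpg = 16, ocpg = 32, FH = 3, FW = 3;
        std::size_t batch_sz = 8, OH = 28, OW = 28;
        std::size_t M = icpg * FH * FW;      // 144: rows of the column matrix
        std::size_t K = ocpg;                // 32: contracted dimension
        std::size_t N = batch_sz * OH * OW;  // 6272: one column per output pixel
        // A is {group, K, M}; transposeA makes the effective operand {M, K},
        // so the contraction dimension lines up with B = {group, K, N}.
        std::size_t a_rows = M, a_cols = K, b_rows = K, b_cols = N;
        assert(a_cols == b_rows);            // contraction dims must agree
        assert(a_rows == M && b_cols == N);  // C is {group, M, N}
        return 0;
    }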
@@ -83,14 +106,20 @@ WorkspaceBundle Algo::get_bundle(const SizeArgs& args) {
            OC = args.out_grad_layout[1], OH = args.out_grad_layout[2],
            OW = args.out_grad_layout[3], FH = fm.spatial[0], FW = fm.spatial[1];
-    auto&& bmm_opr = args.handle->create_operator<BatchedMatrixMulForward>();
-    TensorLayout al, bl, cl;
+    auto bmatmul_opr = args.handle->create_operator<BatchedMatrixMulForward>();
+    if (args.opr->execution_policy().algo.valid() &&
+        !args.opr->execution_policy().sub_policy.empty()) {
+        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
+        bmatmul_opr->execution_policy() =
+                args.opr->execution_policy().sub_policy[0];
+    }
-    get_matmul_layout(args, al, bl, cl);
-    bmm_opr->param().compute_mode = param::MatrixMul::ComputeMode::DEFAULT;
-    bmm_opr->param().transposeA = true;
+    auto&& config = sub_opr_config(args.filter_meta, args.im_layout,
+                                   args.out_grad_layout);
+    bmatmul_opr->param() = config.second;
-    size_t bmm_ws = bmm_opr->get_workspace_in_bytes(al, bl, cl);
+    size_t bmm_ws = bmatmul_opr->get_workspace_in_bytes(
+            config.first[0], config.first[1], config.first[2]);
     size_t result_ws = batch_sz * IC * FH * FW * OH * OW * sizeof(float);
     size_t relayout_ws1 = batch_sz * OC * OH * OW * sizeof(float);
     size_t relayout_ws2 = batch_sz * IC * FH * FW * OH * OW * sizeof(float);
@@ -154,21 +183,24 @@ void Algo::exec(const ExecArgs& args) const {
     // matmul [g, icpg, FH, FW, ocpg] * [g, ocpg, N, OH, OW] =>
     //     => [g, icpg, FH, FW, N, OH, OW]
     {
-        TensorLayout al, bl, cl;
-        get_matmul_layout(args, al, bl, cl);
-        TensorND A(static_cast<void*>(dev_filter), al),
-                B(static_cast<void*>(relayout_ws1), bl),
-                C(static_cast<void*>(result_ws), cl);
-        size_t bmm_ws_size = bundle.get_size(0);
-        auto&& bmm_opr =
+        auto bmatmul_opr =
                 args.handle->create_operator<BatchedMatrixMulForward>();
+        if (args.opr->execution_policy().algo.valid()) {
+            megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
+            bmatmul_opr->execution_policy() =
+                    args.opr->execution_policy().sub_policy[0];
+        }
-        bmm_opr->param().compute_mode = param::MatrixMul::ComputeMode::DEFAULT;
-        bmm_opr->param().transposeA = true;
+        auto&& config = sub_opr_config(args.filter_meta, args.im_layout,
+                                       args.out_grad_layout);
+        bmatmul_opr->param() = config.second;
+
+        TensorND A(static_cast<void*>(dev_filter), config.first[0]),
+                B(static_cast<void*>(relayout_ws1), config.first[1]),
+                C(static_cast<void*>(result_ws), config.first[2]);
+        size_t bmm_ws_size = bundle.get_size(0);
-        bmm_opr->exec(
+        bmatmul_opr->exec(
                 A, B, C,
                 Workspace(static_cast<megdnn::dt_byte*>(bmm_ws), bmm_ws_size));
     }
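
Note (reviewer commentary, not diff content): the same forwarding idiom recurs in every get_bundle()/exec() pair in this change; the only asymmetry is that get_bundle() additionally guards on !sub_policy.empty() while exec() asserts the sub-policy is already present. The idiom, extracted for reference (a sketch of the diff's own pattern; `handle` and `opr` stand in for the members of SizeArgs/ExecArgs):

    // Forward the parent's chosen sub-policy to the freshly created
    // BatchedMatrixMulForward, so the workspace query and exec() resolve
    // to the same sub-algorithm.
    auto bmatmul_opr = handle->create_operator<BatchedMatrixMulForward>();
    if (opr->execution_policy().algo.valid() &&
        !opr->execution_policy().sub_policy.empty()) {
        megdnn_assert(opr->execution_policy().sub_policy.size() == 1);
        bmatmul_opr->execution_policy() = opr->execution_policy().sub_policy[0];
    }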
@@ -92,20 +92,20 @@ public:
 class DeformableConvBackwardFilterImpl::AlgoMatmul final : public AlgoBase {
 private:
-    static void get_matmul_layout(const SizeArgs& args, TensorLayout& al,
-                                  TensorLayout& bl, TensorLayout& cl);
     static WorkspaceBundle get_bundle(const SizeArgs& args);
 
 public:
-    AlgoMatmul() {}
     bool is_available(const SizeArgs& args) const override;
     size_t get_workspace_in_bytes(const SizeArgs& args) const override;
     void exec(const ExecArgs& args) const override;
     bool is_reproducible() const override { return true; }
-    const char* name() const override { return "AlgoMatmul"; }
+    std::vector<SearchItem> get_subopr_list(
+            const TensorLayoutArray& layouts,
+            const OperatorBase* opr) const override;
+    const char* name() const override { return "MATMUL"; }
     MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
 };
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 
 #include "src/cuda/utils.h"
@@ -57,25 +58,46 @@ deformable_conv::Param create_param(const Algo::SizeArgs& args,
     return p;
 }
-}; // anonymous namespace
-
-bool Algo::is_available(const SizeArgs&) const {
-    return true;
-}
-
-void Algo::get_matmul_layout(const SizeArgs& args, TensorLayout& al,
-                             TensorLayout& bl, TensorLayout& cl) {
-    auto&& dt = args.im_layout.dtype;
-    auto&& fm = args.filter_grad_meta;
-    size_t batch_sz = args.im_layout[0], OH = args.out_grad_layout[2],
-           OW = args.out_grad_layout[3], FH = fm.spatial[0], FW = fm.spatial[1];
+std::pair<TensorLayoutArray, BatchedMatrixMulForward::Param> sub_opr_config(
+        const DeformableConvBackwardFilterImpl::CanonizedFilterMeta& fm,
+        const TensorLayout& im, const TensorLayout& out_grad) {
+    auto&& dt = im.dtype;
+    size_t batch_sz = im[0], OH = out_grad[2], OW = out_grad[3],
+           FH = fm.spatial[0], FW = fm.spatial[1];
     size_t M = fm.ocpg, K = OH * OW * batch_sz, N = fm.icpg * FH * FW,
            batch = fm.group;
-    al = {{batch, M, K}, dt};
-    bl = {{batch, N, K}, dt};
-    cl = {{batch, M, N}, dt};
+    TensorLayout al = {{batch, M, K}, dt};
+    TensorLayout bl = {{batch, N, K}, dt};
+    TensorLayout cl = {{batch, M, N}, dt};
+    BatchedMatrixMulForward::Param param;
+    param.compute_mode = param::MatrixMul::ComputeMode::DEFAULT;
+    param.transposeB = true;
+    return {{al, bl, cl}, param};
+}
+}; // anonymous namespace
+
+std::vector<Algorithm::SearchItem> Algo::get_subopr_list(
+        const TensorLayoutArray& layouts, const OperatorBase* opr) const {
+    const DeformableConvBackwardFilterImpl* deformable_conv =
+            static_cast<const DeformableConvBackwardFilterImpl*>(opr);
+    CanonizedFilterMeta fm = deformable_conv->make_canonized_filter_meta(
+            layouts[0].ndim, layouts[4], layouts[1]);
+    auto&& config = sub_opr_config(fm, layouts[0], layouts[3]);
+
+    std::string param_str;
+    Algorithm::serialize_write_pod(config.second, param_str);
+    return {{Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD, param_str,
+             config.first}};
+}
+
+bool Algo::is_available(const SizeArgs&) const {
+    return true;
 }
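
Aside (not in the diff): here the roles flip relative to backward data. The contracted dimension K is the whole spatial extent OH*OW*batch_sz, and transposeB makes C = A * B^T, i.e. filter_grad = out_grad * col^T per group. A small sketch under the same made-up sizes as before:

    #include <cassert>
    #include <cstddef>

    int main() {
        // Assumed sizes, for illustration only.
        std::size_t icpg = 16, ocpg = 32, FH = 3, FW = 3;
        std::size_t batch_sz = 8, OH = 28, OW = 28;
        std::size_t M = ocpg;                // filter-grad rows (per group)
        std::size_t K = OH * OW * batch_sz;  // contracted: every output position
        std::size_t N = icpg * FH * FW;      // filter-grad columns (per group)
        // A = {group, M, K} (reshaped out_grad), B = {group, N, K} (im2col);
        // transposeB turns B into {K, N}, so C = A * B^T is {group, M, N}.
        std::size_t bt_rows = K, bt_cols = N;
        assert(K == bt_rows);                // contraction dims agree
        assert(M > 0 && bt_cols == N);       // C is {group, M, N}
        return 0;
    }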
@@ -85,16 +107,22 @@ WorkspaceBundle Algo::get_bundle(const SizeArgs& args) {
     size_t IC = fm.group * fm.icpg, OC = args.out_grad_layout[1];
     auto batch_sz = args.im_layout[0];
-    auto&& bmm_opr = args.handle->create_operator<BatchedMatrixMulForward>();
-    TensorLayout al, bl, cl;
-    get_matmul_layout(args, al, bl, cl);
-    bmm_opr->param().compute_mode = param::MatrixMul::ComputeMode::DEFAULT;
-    bmm_opr->param().transposeB = true;
+    auto bmatmul_opr = args.handle->create_operator<BatchedMatrixMulForward>();
+    if (args.opr->execution_policy().algo.valid() &&
+        !args.opr->execution_policy().sub_policy.empty()) {
+        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
+        bmatmul_opr->execution_policy() =
+                args.opr->execution_policy().sub_policy[0];
+    }
+    auto&& config = sub_opr_config(args.filter_grad_meta, args.im_layout,
+                                   args.out_grad_layout);
+    bmatmul_opr->param() = config.second;
     size_t col_ws = batch_sz * IC * FH * FW * OH * OW * sizeof(float);
     size_t out_grad_ws = batch_sz * OC * OH * OW * sizeof(float);
-    size_t bmm_ws = bmm_opr->get_workspace_in_bytes(al, bl, cl);
+    size_t bmm_ws = bmatmul_opr->get_workspace_in_bytes(
+            config.first[0], config.first[1], config.first[2]);
     return {nullptr, {col_ws, out_grad_ws, bmm_ws}};
 }
@@ -138,20 +166,23 @@ void Algo::exec(const ExecArgs& args) const {
     args.handle->relayout_opr()->exec(C2, C3);
 
     // matmul
-    TensorLayout al, bl, cl;
-    get_matmul_layout(args, al, bl, cl);
-    TensorND A(static_cast<void*>(out_grad_ws), al),
-            B(static_cast<void*>(col_ws), bl),
-            C(static_cast<void*>(dev_filter_grad), cl);
-    size_t bmm_ws_size = bundle.get_size(2);
-    auto&& bmm_opr = args.handle->create_operator<BatchedMatrixMulForward>();
-    bmm_opr->param().compute_mode = param::MatrixMul::ComputeMode::DEFAULT;
-    bmm_opr->param().transposeB = true;
+    auto bmatmul_opr = args.handle->create_operator<BatchedMatrixMulForward>();
+    if (args.opr->execution_policy().algo.valid()) {
+        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
+        bmatmul_opr->execution_policy() =
+                args.opr->execution_policy().sub_policy[0];
+    }
+    auto&& config = sub_opr_config(args.filter_grad_meta, args.im_layout,
+                                   args.out_grad_layout);
+    bmatmul_opr->param() = config.second;
+
+    TensorND A(static_cast<void*>(out_grad_ws), config.first[0]),
+            B(static_cast<void*>(col_ws), config.first[1]),
+            C(static_cast<void*>(dev_filter_grad), config.first[2]);
+    size_t bmm_ws_size = bundle.get_size(2);
-    bmm_opr->exec(
+    bmatmul_opr->exec(
             A, B, C,
             Workspace(static_cast<megdnn::dt_byte*>(bmm_ws), bmm_ws_size));
 }
@@ -87,20 +87,20 @@ public:
 class DeformableConvForwardImpl::AlgoMatmul final : public AlgoBase {
 private:
-    static void get_matmul_layout(const SizeArgs& args, TensorLayout& al,
-                                  TensorLayout& bl, TensorLayout& cl);
     static WorkspaceBundle get_bundle(const SizeArgs& args);
 
 public:
-    AlgoMatmul(){};
     bool is_available(const SizeArgs& args) const override;
     size_t get_workspace_in_bytes(const SizeArgs& args) const override;
     void exec(const ExecArgs& args) const override;
     bool is_reproducible() const override { return true; }
-    const char* name() const override { return "AlgoMatmul"; }
+    std::vector<SearchItem> get_subopr_list(
+            const TensorLayoutArray& layouts,
+            const OperatorBase* opr) const override;
+    const char* name() const override { return "MATMUL"; }
     MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
 };
@@ -57,24 +57,47 @@ deformable_conv::Param create_param(const Algo::SizeArgs& args,
     return p;
 }
 
+std::pair<TensorLayoutArray, BatchedMatrixMulForward::Param> sub_opr_config(
+        const DeformableConvForwardImpl::CanonizedFilterMeta& fm,
+        const TensorLayout& im, const TensorLayout& dst) {
+    auto&& dt = im.dtype;
+    size_t batch_sz = im[0], OH = dst[2], OW = dst[3],
+           FH = fm.spatial[0], FW = fm.spatial[1];
+    size_t M = fm.ocpg, N = OH * OW * batch_sz, K = fm.icpg * FH * FW,
+           batch = fm.group;
+    TensorLayout al = {{batch, M, K}, dt};
+    TensorLayout bl = {{batch, K, N}, dt};
+    TensorLayout cl = {{batch, M, N}, dt};
+    BatchedMatrixMulForward::Param param;
+    param.compute_mode = param::MatrixMul::ComputeMode::DEFAULT;
+    return {{al, bl, cl}, param};
+}
 }; // anonymous namespace
 
-bool Algo::is_available(const SizeArgs&) const {
-    return true;
+std::vector<Algorithm::SearchItem> Algo::get_subopr_list(
+        const TensorLayoutArray& layouts, const OperatorBase* opr) const {
+    const DeformableConvForwardImpl* deformable_conv =
+            static_cast<const DeformableConvForwardImpl*>(opr);
+    CanonizedFilterMeta fm = deformable_conv->make_canonized_filter_meta(
+            layouts[0].ndim, layouts[1], layouts[2]);
+    auto&& config = sub_opr_config(fm, layouts[0], layouts[4]);
+
+    std::string param_str;
+    Algorithm::serialize_write_pod(config.second, param_str);
+    return {{Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD, param_str,
+             config.first}};
 }
-
-void Algo::get_matmul_layout(const SizeArgs& args, TensorLayout& al,
-                             TensorLayout& bl, TensorLayout& cl) {
-    auto&& dt = args.im_layout.dtype;
-    auto&& fm = args.filter_meta;
-    size_t batch_sz = args.im_layout[0], OH = args.dst_layout[2],
-           OW = args.dst_layout[3], FH = fm.spatial[0], FW = fm.spatial[1];
-    size_t M = fm.ocpg, N = OH * OW * batch_sz, K = fm.icpg * FH * FW,
-           batch = fm.group;
-    al = {{batch, M, K}, dt};
-    bl = {{batch, K, N}, dt};
-    cl = {{batch, M, N}, dt};
+
+bool Algo::is_available(const SizeArgs&) const {
+    return true;
 }
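
Aside (not in the diff): the three sub_opr_config() variants in this change describe one GEMM and its two gradients, all batched over `group`. A hedged, self-contained cross-check that the forward, backward-data, and backward-filter M/K/N choices above are mutually consistent for a single assumed problem size:

    #include <cassert>
    #include <cstddef>

    int main() {
        // Illustrative sizes only.
        std::size_t icpg = 16, ocpg = 32, FH = 3, FW = 3;
        std::size_t batch = 8, OH = 28, OW = 28;
        std::size_t spatial = batch * OH * OW, col_rows = icpg * FH * FW;

        // forward:        C = A * B
        std::size_t fM = ocpg, fK = col_rows, fN = spatial;
        // backward data:  C = A^T * B   (transposeA)
        std::size_t dM = col_rows, dK = ocpg, dN = spatial;
        // backward filter: C = A * B^T  (transposeB)
        std::size_t wM = ocpg, wK = spatial, wN = col_rows;

        // The filter block is M x K in fwd, appears as K x M in bwd-data
        // (hence transposeA), and is the M x N *output* of bwd-filter.
        assert(fM == dK && fK == dM);
        assert(wM == fM && wN == fK);
        (void)fN; (void)dN; (void)wK;
        return 0;
    }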
@@ -83,17 +106,24 @@ WorkspaceBundle Algo::get_bundle(const SizeArgs& args) {
            OC = args.dst_layout[1], OH = args.dst_layout[2],
            OW = args.dst_layout[3], FH = fm.spatial[0], FW = fm.spatial[1];
-    auto&& bmm_opr = args.handle->create_operator<BatchedMatrixMulForward>();
-    TensorLayout al, bl, cl;
-    get_matmul_layout(args, al, bl, cl);
-    bmm_opr->param().compute_mode = param::MatrixMul::ComputeMode::DEFAULT;
+    auto bmatmul_opr = args.handle->create_operator<BatchedMatrixMulForward>();
+    if (args.opr->execution_policy().algo.valid() &&
+        !args.opr->execution_policy().sub_policy.empty()) {
+        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
+        bmatmul_opr->execution_policy() =
+                args.opr->execution_policy().sub_policy[0];
+    }
+    auto&& config =
+            sub_opr_config(args.filter_meta, args.im_layout, args.dst_layout);
+    bmatmul_opr->param() = config.second;
     size_t col_ws = batch_sz * IC * FH * FW * OH * OW * sizeof(float);
-    size_t bmm_ws = bmm_opr->get_workspace_in_bytes(al, bl, cl);
+    size_t bmm_ws = bmatmul_opr->get_workspace_in_bytes(
+            config.first[0], config.first[1], config.first[2]);
     size_t result_ws = batch_sz * OC * OH * OW * sizeof(float);
-    return {nullptr, {col_ws, bmm_ws, result_ws}};
+    return WorkspaceBundle{nullptr, {col_ws, bmm_ws, result_ws}};
 }
 size_t Algo::get_workspace_in_bytes(const SizeArgs& args) const {
@@ -123,18 +153,25 @@ void Algo::exec(const ExecArgs& args) const {
     // im2col
     deformable_conv::im2col(dev_im, dev_offset, dev_mask,
                             static_cast<float*>(col_ws), p);
-    // matmul
-    TensorLayout al, bl, cl;
-    get_matmul_layout(args, al, bl, cl);
-    TensorND A(static_cast<void*>(dev_filter), al),
-            B(static_cast<void*>(col_ws), bl),
-            C(static_cast<void*>(result_ws), cl);
+    auto bmatmul_opr = args.handle->create_operator<BatchedMatrixMulForward>();
+    if (args.opr->execution_policy().algo.valid()) {
+        megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
+        bmatmul_opr->execution_policy() =
+                args.opr->execution_policy().sub_policy[0];
+    }
+    auto&& config =
+            sub_opr_config(args.filter_meta, args.im_layout, args.dst_layout);
+    bmatmul_opr->param() = config.second;
+
+    // matmul
+    TensorND A(static_cast<void*>(dev_filter), config.first[0]),
+            B(static_cast<void*>(col_ws), config.first[1]),
+            C(static_cast<void*>(result_ws), config.first[2]);
     size_t bmm_ws_size = bundle.get_size(1);
-    auto&& bmm_opr = args.handle->create_operator<BatchedMatrixMulForward>();
-    bmm_opr->param().compute_mode = param::MatrixMul::ComputeMode::DEFAULT;
-    bmm_opr->exec(
+    bmatmul_opr->exec(
             A, B, C,
             Workspace(static_cast<megdnn::dt_byte*>(bmm_ws), bmm_ws_size));
     // relayout
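
Closing note (commentary, not diff content): each get_subopr_list() serializes the matmul param to a string and pairs it with the A/B/C layouts; that tuple is presumably what the dispatcher uses to key profiling results for the BATCHED_MATRIX_MUL_FORWARD sub-operator. Algorithm::serialize_write_pod() itself is not shown in this diff; under the assumption that it does byte-wise POD serialization, a minimal equivalent would just append the raw bytes, so equal params map to equal keys:

    #include <string>
    #include <type_traits>

    // Hedged stand-in for Algorithm::serialize_write_pod (assumption: the
    // real helper serializes a trivially copyable param by its raw bytes).
    template <typename T>
    void write_pod(const T& val, std::string& out) {
        static_assert(std::is_trivially_copyable<T>::value,
                      "only trivially copyable params");
        out.append(reinterpret_cast<const char*>(&val), sizeof(T));
    }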