|
|
@@ -6,59 +6,87 @@ |
|
|
|
* |
|
|
|
* Unless required by applicable law or agreed to in writing, |
|
|
|
* software distributed under the License is distributed on an |
|
|
|
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or |
|
|
|
* implied. |
|
|
|
*/ |
|
|
|
|
|
|
|
#include "src/cuda/handle.h" |
|
|
|
#include "src/cuda/matrix_mul/algos.h" |
|
|
|
#include "src/cuda/utils.h" |
|
|
|
#include "src/common/algo_chooser.h" |
|
|
|
|
|
|
|
using namespace megdnn; |
|
|
|
using namespace cuda; |
|
|
|
|
|
|
|
MatrixMulForwardImpl::AlgoBFloat16::AlgoBFloat16( |
|
|
|
MatrixMulForwardImpl::AlgoBase* algorithm) |
|
|
|
: m_algorithm(algorithm) { |
|
|
|
megdnn_assert_internal(algorithm); |
|
|
|
m_name = ssprintf("MATMUL_BFLOAT16:%s", m_algorithm->name()); |
|
|
|
} |
|
|
|
|
|
|
|
MatrixMulForwardImpl::AlgoBase::SizeArgs |
|
|
|
MatrixMulForwardImpl::AlgoBFloat16::float_args(const SizeArgs& args) const { |
|
|
|
auto new_args = args; |
|
|
|
namespace { |
|
|
|
std::pair<TensorLayoutArray, MatrixMulForwardImpl::Param> sub_opr_config( |
|
|
|
const TensorLayoutArray& layouts, const MatrixMulForwardImpl* opr) { |
|
|
|
megdnn_assert(layouts.size() == 3); |
|
|
|
std::pair<TensorLayoutArray, MatrixMulForwardImpl::Param> ret; |
|
|
|
ret.first = layouts; |
|
|
|
auto change_dtype = [](TensorLayout& layout) { |
|
|
|
if (layout.dtype == dtype::BFloat16()) { |
|
|
|
layout.dtype = dtype::Float32(); |
|
|
|
} |
|
|
|
}; |
|
|
|
change_dtype(new_args.layout_a); |
|
|
|
change_dtype(new_args.layout_b); |
|
|
|
change_dtype(new_args.layout_c); |
|
|
|
return new_args; |
|
|
|
change_dtype(ret.first[0]); |
|
|
|
change_dtype(ret.first[1]); |
|
|
|
change_dtype(ret.first[2]); |
|
|
|
|
|
|
|
ret.second = opr->param(); |
|
|
|
ret.second.compute_mode = MatrixMulForwardImpl::Param::ComputeMode::DEFAULT; |
|
|
|
return ret; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
std::vector<Algorithm::SearchItem> |
|
|
|
MatrixMulForwardImpl::AlgoBFloat16::get_subopr_list( |
|
|
|
const TensorLayoutArray& layouts, const OperatorBase* opr) const { |
|
|
|
auto&& config = sub_opr_config( |
|
|
|
layouts, static_cast<const MatrixMulForwardImpl*>(opr)); |
|
|
|
|
|
|
|
std::string param_str; |
|
|
|
Algorithm::serialize_write_pod(config.second, param_str); |
|
|
|
return {{Algorithm::OprType::MATRIX_MUL_FORWARD, param_str, config.first}}; |
|
|
|
} |
|
|
|
|
|
|
|
bool MatrixMulForwardImpl::AlgoBFloat16::is_available( |
|
|
|
const SizeArgs& args) const { |
|
|
|
auto fargs = float_args(args); |
|
|
|
auto&& config = sub_opr_config( |
|
|
|
{args.layout_a, args.layout_b, args.layout_c}, args.opr); |
|
|
|
auto matmul_opr = args.opr->handle()->create_operator<MatrixMulForward>(); |
|
|
|
matmul_opr->param() = config.second; |
|
|
|
|
|
|
|
return args.layout_a.dtype == dtype::BFloat16() && |
|
|
|
m_algorithm->is_available(fargs); |
|
|
|
get_algorithm(static_cast<MatrixMulForwardImpl*>(matmul_opr.get()), |
|
|
|
config.first[0], config.first[1], config.first[2]); |
|
|
|
} |
|
|
|
|
|
|
|
WorkspaceBundle MatrixMulForwardImpl::AlgoBFloat16::get_workspace_bundle( |
|
|
|
void* ptr, const SizeArgs& args) const { |
|
|
|
auto fargs = float_args(args); |
|
|
|
auto matmul_opr = args.opr->handle()->create_operator<MatrixMulForward>(); |
|
|
|
if (args.opr->execution_policy().algo.valid()) { |
|
|
|
megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1); |
|
|
|
matmul_opr->execution_policy() = |
|
|
|
args.opr->execution_policy().sub_policy[0]; |
|
|
|
} |
|
|
|
auto&& config = sub_opr_config( |
|
|
|
{args.layout_a, args.layout_b, args.layout_c}, args.opr); |
|
|
|
matmul_opr->param() = config.second; |
|
|
|
|
|
|
|
SmallVector<size_t> sizes; |
|
|
|
auto get_workspace = [&sizes](const TensorLayout& src) { |
|
|
|
TensorLayout dst = src; |
|
|
|
if (dst.dtype == dtype::BFloat16()) { |
|
|
|
dst.dtype = dtype::Float32(); |
|
|
|
auto get_workspace = [&sizes](const TensorLayout& src, |
|
|
|
const TensorLayout& dst) { |
|
|
|
if (src.dtype != dst.dtype) { |
|
|
|
sizes.push_back(dst.span().dist_byte()); |
|
|
|
} |
|
|
|
}; |
|
|
|
get_workspace(args.layout_a); |
|
|
|
get_workspace(args.layout_b); |
|
|
|
get_workspace(args.layout_c); |
|
|
|
sizes.push_back(m_algorithm->get_workspace_in_bytes(fargs)); |
|
|
|
|
|
|
|
get_workspace(args.layout_a, config.first[0]); |
|
|
|
get_workspace(args.layout_b, config.first[1]); |
|
|
|
get_workspace(args.layout_c, config.first[2]); |
|
|
|
sizes.push_back(matmul_opr->get_workspace_in_bytes( |
|
|
|
config.first[0], config.first[1], config.first[2])); |
|
|
|
return {ptr, std::move(sizes)}; |
|
|
|
} |
|
|
|
|
|
|
@@ -82,7 +110,12 @@ void MatrixMulForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const { |
|
|
|
args.opr->handle()->create_operator<MatrixMulForward>(); |
|
|
|
matmul_opr->param() = args.opr->param(); |
|
|
|
matmul_opr->param().compute_mode = Param::ComputeMode::DEFAULT; |
|
|
|
matmul_opr->execution_policy() = {m_algorithm->desc(), {}}; |
|
|
|
if (args.opr->execution_policy().algo.valid()) { |
|
|
|
megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1); |
|
|
|
matmul_opr->execution_policy() = |
|
|
|
args.opr->execution_policy().sub_policy[0]; |
|
|
|
} |
|
|
|
|
|
|
|
matmul_opr->exec(a, b, c, ctypecvt.workspace()); |
|
|
|
} |
|
|
|
ctypecvt.comp_to_dst_type(c, args.tensor_c); |
|
|
|