GitOrigin-RevId: 2409b6ba16
tags/v1.3.0
@@ -64,9 +64,24 @@ void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, | |||
} | |||
} // namespace naive | |||
} // namespace megdnn | |||
std::vector<BatchedMatrixMulForward::Algorithm*> | |||
BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/, | |||
const TensorLayout& /*B*/, | |||
const TensorLayout& /*C*/) { | |||
return {static_cast<HandleImpl*>(handle()) | |||
->default_batched_matmul_fwd_algo()}; | |||
} | |||
// vim: syntax=cpp.doxygen | |||
BatchedMatrixMulForward::Algorithm* | |||
BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | |||
const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | |||
bool /* reproducible */) { | |||
return static_cast<HandleImpl*>(handle()) | |||
->default_batched_matmul_fwd_algo(); | |||
} | |||
} // namespace naive | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -6,7 +6,8 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
@@ -25,17 +26,13 @@ public: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
const TensorLayout& /*C*/) override { | |||
return {}; | |||
} | |||
const TensorLayout& /*C*/) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, | |||
const TensorLayout& /*B*/, | |||
const TensorLayout& /*C*/, | |||
size_t /*workspace_limit_in_bytes*/, | |||
bool /* reproducible */) override { | |||
return nullptr; | |||
} | |||
bool /* reproducible */) override; | |||
const char* get_algorithm_set_name() const override { return "DEFAULT"; } | |||
@@ -106,6 +106,9 @@ DefaultLocalShareBackwardDataAlgorithm | |||
DefaultLocalShareBackwardFilterAlgorithm | |||
HandleImpl::m_default_local_share_bwd_filter_algo; | |||
DefaultMatrixMulAlgorithm HandleImpl::m_default_matmul_fwd_algo; | |||
DefaultBatchedMatrixMulAlgorithm HandleImpl::m_default_batched_matmul_fwd_algo; | |||
HandleImpl::HandleImpl(megcoreComputingHandle_t computing_handle, | |||
HandleType type) | |||
: HandleImplHelper(computing_handle, type), | |||
@@ -13,6 +13,7 @@ | |||
#include "src/common/handle_impl.h" | |||
#include "src/naive/convolution/algorithms.h" | |||
#include "src/naive/matrix_mul/algorithms.h" | |||
#include "src/naive/local_share/algorithms.h" | |||
#include "src/naive/convolution3d/algorithms.h" | |||
@@ -46,6 +47,9 @@ class HandleImpl : public HandleImplHelper { | |||
static DefaultLocalShareBackwardFilterAlgorithm | |||
m_default_local_share_bwd_filter_algo; | |||
static DefaultMatrixMulAlgorithm m_default_matmul_fwd_algo; | |||
static DefaultBatchedMatrixMulAlgorithm m_default_batched_matmul_fwd_algo; | |||
//! move KernFunc to alloc_kern()->func, destruct func, and call dispatch | |||
template <typename T> | |||
void move_kern_func_to_new_kern_and_dispatch(T& func) { | |||
@@ -109,6 +113,14 @@ public: | |||
return &m_default_local_share_bwd_filter_algo; | |||
} | |||
MatrixMulForward::Algorithm* default_matmul_fwd_algo() { | |||
return &m_default_matmul_fwd_algo; | |||
} | |||
BatchedMatrixMulForward::Algorithm* default_batched_matmul_fwd_algo() { | |||
return &m_default_batched_matmul_fwd_algo; | |||
} | |||
Relayout* relayout_opr() override { | |||
return get_helper_opr<Relayout, 2>(this); | |||
} | |||
@@ -0,0 +1,35 @@ | |||
/** | |||
* \file dnn/src/naive/matrix_mul/algorithms.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs/linalg.h" | |||
namespace megdnn { | |||
namespace naive { | |||
class DefaultMatrixMulAlgorithm final | |||
: public megdnn::MatrixMulForward::Algorithm { | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "DEFAULT"; } | |||
uint32_t type() const override { return 0; } | |||
}; | |||
class DefaultBatchedMatrixMulAlgorithm final | |||
: public megdnn::BatchedMatrixMulForward::Algorithm { | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "DEFAULT"; } | |||
uint32_t type() const override { return 0; } | |||
}; | |||
} // namespace naive | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -81,6 +81,20 @@ void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||
MIDOUT_END(); | |||
} | |||
std::vector<MatrixMulForward::Algorithm*> | |||
MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/, | |||
const TensorLayout& /*B*/, | |||
const TensorLayout& /*C*/) { | |||
return {static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo()}; | |||
} | |||
MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | |||
const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | |||
bool /* reproducible */) { | |||
return static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo(); | |||
} | |||
} // namespace naive | |||
} // namespace megdnn | |||
@@ -6,7 +6,8 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
@@ -26,17 +27,13 @@ public: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
const TensorLayout& /*C*/) override { | |||
return {}; | |||
} | |||
const TensorLayout& /*C*/) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, | |||
const TensorLayout& /*B*/, | |||
const TensorLayout& /*C*/, | |||
size_t /*workspace_limit_in_bytes*/, | |||
bool /* reproducible */) override { | |||
return nullptr; | |||
} | |||
bool /* reproducible */) override; | |||
const char* get_algorithm_set_name() const override { return "DEFAULT"; } | |||