GitOrigin-RevId: 2409b6ba16
@@ -64,9 +64,24 @@ void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A,
 }
 
+std::vector<BatchedMatrixMulForward::Algorithm*>
+BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/,
+                                                const TensorLayout& /*B*/,
+                                                const TensorLayout& /*C*/) {
+    return {static_cast<HandleImpl*>(handle())
+                    ->default_batched_matmul_fwd_algo()};
+}
+
+BatchedMatrixMulForward::Algorithm*
+BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
+        const TensorLayout& /*A*/, const TensorLayout& /*B*/,
+        const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/,
+        bool /* reproducible */) {
+    return static_cast<HandleImpl*>(handle())
+            ->default_batched_matmul_fwd_algo();
+}
+
 } // namespace naive
 } // namespace megdnn
 
 // vim: syntax=cpp.doxygen
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #pragma once
 #include "megdnn/oprs.h"
@@ -25,17 +26,13 @@ public:
     std::vector<Algorithm*> get_all_algorithms(
             const TensorLayout& /*A*/, const TensorLayout& /*B*/,
-            const TensorLayout& /*C*/) override {
-        return {};
-    }
+            const TensorLayout& /*C*/) override;
     Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/,
                                        const TensorLayout& /*B*/,
                                        const TensorLayout& /*C*/,
                                        size_t /*workspace_limit_in_bytes*/,
-                                       bool /* reproducible */) override {
-        return nullptr;
-    }
+                                       bool /* reproducible */) override;
 
     const char* get_algorithm_set_name() const override { return "DEFAULT"; }
@@ -106,6 +106,9 @@ DefaultLocalShareBackwardDataAlgorithm
 DefaultLocalShareBackwardFilterAlgorithm
         HandleImpl::m_default_local_share_bwd_filter_algo;
+DefaultMatrixMulAlgorithm HandleImpl::m_default_matmul_fwd_algo;
+DefaultBatchedMatrixMulAlgorithm HandleImpl::m_default_batched_matmul_fwd_algo;
 
 HandleImpl::HandleImpl(megcoreComputingHandle_t computing_handle,
                        HandleType type)
         : HandleImplHelper(computing_handle, type),
@@ -13,6 +13,7 @@
 #include "src/common/handle_impl.h"
 #include "src/naive/convolution/algorithms.h"
+#include "src/naive/matrix_mul/algorithms.h"
 #include "src/naive/local_share/algorithms.h"
 #include "src/naive/convolution3d/algorithms.h"
@@ -46,6 +47,9 @@ class HandleImpl : public HandleImplHelper {
     static DefaultLocalShareBackwardFilterAlgorithm
             m_default_local_share_bwd_filter_algo;
+    static DefaultMatrixMulAlgorithm m_default_matmul_fwd_algo;
+    static DefaultBatchedMatrixMulAlgorithm m_default_batched_matmul_fwd_algo;
 
     //! move KernFunc to alloc_kern()->func, destruct func, and call dispatch
     template <typename T>
     void move_kern_func_to_new_kern_and_dispatch(T& func) {
@@ -109,6 +113,14 @@ public:
         return &m_default_local_share_bwd_filter_algo;
     }
+    MatrixMulForward::Algorithm* default_matmul_fwd_algo() {
+        return &m_default_matmul_fwd_algo;
+    }
+    BatchedMatrixMulForward::Algorithm* default_batched_matmul_fwd_algo() {
+        return &m_default_batched_matmul_fwd_algo;
+    }
+
     Relayout* relayout_opr() override {
         return get_helper_opr<Relayout, 2>(this);
     }
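Taken together, the handle changes above follow one pattern: the default algorithm objects carry no state (see the new algorithms.h below), so a single static instance per operator type is shared by every HandleImpl and exposed through a getter that the naive operator implementations call. Below is a condensed, self-contained sketch of that pattern; the class and member names are illustrative stand-ins, not megdnn's real hierarchy.

// Illustrative sketch only: mirrors the shape of the change, not megdnn's
// actual classes.
#include <cstdint>

struct Algorithm {
    virtual ~Algorithm() = default;
    virtual bool is_reproducible() const = 0;
    virtual const char* name() const = 0;
    virtual uint32_t type() const = 0;
};

// Stateless "DEFAULT" algorithm: one instance can serve every handle.
struct DefaultMatMulAlgo final : Algorithm {
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "DEFAULT"; }
    uint32_t type() const override { return 0; }
};

class Handle {
public:
    // Operator implementations downcast their handle and call this getter.
    Algorithm* default_matmul_fwd_algo() { return &m_default_matmul_fwd_algo; }

private:
    // One shared instance for all Handle objects, like HandleImpl's statics.
    static DefaultMatMulAlgo m_default_matmul_fwd_algo;
};

DefaultMatMulAlgo Handle::m_default_matmul_fwd_algo;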
@@ -0,0 +1,35 @@
+/**
+ * \file dnn/src/naive/matrix_mul/algorithms.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#pragma once
+#include "megdnn/oprs/linalg.h"
+
+namespace megdnn {
+namespace naive {
+
+class DefaultMatrixMulAlgorithm final
+        : public megdnn::MatrixMulForward::Algorithm {
+    bool is_reproducible() const override { return true; }
+    const char* name() const override { return "DEFAULT"; }
+    uint32_t type() const override { return 0; }
+};
+
+class DefaultBatchedMatrixMulAlgorithm final
+        : public megdnn::BatchedMatrixMulForward::Algorithm {
+    bool is_reproducible() const override { return true; }
+    const char* name() const override { return "DEFAULT"; }
+    uint32_t type() const override { return 0; }
+};
+
+} // namespace naive
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
@@ -81,6 +81,20 @@ void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
     MIDOUT_END();
 }
 
+std::vector<MatrixMulForward::Algorithm*>
+MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/,
+                                         const TensorLayout& /*B*/,
+                                         const TensorLayout& /*C*/) {
+    return {static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo()};
+}
+
+MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic(
+        const TensorLayout& /*A*/, const TensorLayout& /*B*/,
+        const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/,
+        bool /* reproducible */) {
+    return static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo();
+}
+
 } // namespace naive
 } // namespace megdnn
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #pragma once
 #include "megdnn/oprs.h"
@@ -26,17 +27,13 @@ public:
     std::vector<Algorithm*> get_all_algorithms(
             const TensorLayout& /*A*/, const TensorLayout& /*B*/,
-            const TensorLayout& /*C*/) override {
-        return {};
-    }
+            const TensorLayout& /*C*/) override;
     Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/,
                                        const TensorLayout& /*B*/,
                                        const TensorLayout& /*C*/,
                                        size_t /*workspace_limit_in_bytes*/,
-                                       bool /* reproducible */) override {
-        return nullptr;
-    }
+                                       bool /* reproducible */) override;
 
     const char* get_algorithm_set_name() const override { return "DEFAULT"; }
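The user-visible effect of the change is that the naive backend now reports a usable "DEFAULT" algorithm where get_all_algorithms() previously returned an empty vector and get_algorithm_heuristic() returned nullptr. Below is a minimal sketch of how this could be observed through the operator interface shown in the diff; the handle and layout setup are assumed to exist elsewhere, and this is not test code from the change itself.

// Hypothetical usage sketch (not part of the change); assumes the algorithm
// query methods are callable from user code at this version.
#include <cstdint>
#include "megdnn/oprs.h"

void check_default_matmul_algo(megdnn::Handle* handle,
                               const megdnn::TensorLayout& A,
                               const megdnn::TensorLayout& B,
                               const megdnn::TensorLayout& C) {
    auto opr = handle->create_operator<megdnn::MatrixMulForward>();
    // With this change the naive backend hands back the shared DEFAULT
    // algorithm from both queries instead of {} / nullptr.
    auto algos = opr->get_all_algorithms(A, B, C);
    auto* algo = opr->get_algorithm_heuristic(
            A, B, C, /*workspace_limit_in_bytes=*/SIZE_MAX,
            /*reproducible=*/false);
    // algo->name() is "DEFAULT" and algo->is_reproducible() is true.
    (void)algos;
    (void)algo;
}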