Browse Source

fix(megdnn): add algo for matmul/batchedmatrixmul of naive and opencl

GitOrigin-RevId: 2409b6ba16
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
87ff58f7fc
7 changed files with 90 additions and 17 deletions
  1. +18
    -3
      dnn/src/naive/batched_matrix_mul/opr_impl.cpp
  2. +4
    -7
      dnn/src/naive/batched_matrix_mul/opr_impl.h
  3. +3
    -0
      dnn/src/naive/handle.cpp
  4. +12
    -0
      dnn/src/naive/handle.h
  5. +35
    -0
      dnn/src/naive/matrix_mul/algorithms.h
  6. +14
    -0
      dnn/src/naive/matrix_mul/opr_impl.cpp
  7. +4
    -7
      dnn/src/naive/matrix_mul/opr_impl.h

+ 18
- 3
dnn/src/naive/batched_matrix_mul/opr_impl.cpp View File

@@ -64,9 +64,24 @@ void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A,

}

} // namespace naive
} // namespace megdnn
std::vector<BatchedMatrixMulForward::Algorithm*>
BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/,
const TensorLayout& /*B*/,
const TensorLayout& /*C*/) {
return {static_cast<HandleImpl*>(handle())
->default_batched_matmul_fwd_algo()};
}

// vim: syntax=cpp.doxygen
BatchedMatrixMulForward::Algorithm*
BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
        const TensorLayout& /*A*/, const TensorLayout& /*B*/,
        const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/,
        bool /* reproducible */) {
    // Only one candidate exists on the naive backend, so the heuristic
    // trivially returns the handle-wide default algorithm regardless of
    // layouts, workspace limit, or the reproducibility request.
    auto* naive_handle = static_cast<HandleImpl*>(handle());
    return naive_handle->default_batched_matmul_fwd_algo();
}

} // namespace naive
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 4
- 7
dnn/src/naive/batched_matrix_mul/opr_impl.h View File

@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
@@ -25,17 +26,13 @@ public:

std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/) override {
return {};
}
const TensorLayout& /*C*/) override;

Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/,
const TensorLayout& /*B*/,
const TensorLayout& /*C*/,
size_t /*workspace_limit_in_bytes*/,
bool /* reproducible */) override {
return nullptr;
}
bool /* reproducible */) override;

const char* get_algorithm_set_name() const override { return "DEFAULT"; }



+ 3
- 0
dnn/src/naive/handle.cpp View File

@@ -106,6 +106,9 @@ DefaultLocalShareBackwardDataAlgorithm
DefaultLocalShareBackwardFilterAlgorithm
HandleImpl::m_default_local_share_bwd_filter_algo;

DefaultMatrixMulAlgorithm HandleImpl::m_default_matmul_fwd_algo;
DefaultBatchedMatrixMulAlgorithm HandleImpl::m_default_batched_matmul_fwd_algo;

HandleImpl::HandleImpl(megcoreComputingHandle_t computing_handle,
HandleType type)
: HandleImplHelper(computing_handle, type),


+ 12
- 0
dnn/src/naive/handle.h View File

@@ -13,6 +13,7 @@

#include "src/common/handle_impl.h"
#include "src/naive/convolution/algorithms.h"
#include "src/naive/matrix_mul/algorithms.h"
#include "src/naive/local_share/algorithms.h"
#include "src/naive/convolution3d/algorithms.h"

@@ -46,6 +47,9 @@ class HandleImpl : public HandleImplHelper {
static DefaultLocalShareBackwardFilterAlgorithm
m_default_local_share_bwd_filter_algo;

static DefaultMatrixMulAlgorithm m_default_matmul_fwd_algo;
static DefaultBatchedMatrixMulAlgorithm m_default_batched_matmul_fwd_algo;

//! move KernFunc to alloc_kern()->func, destruct func, and call dispatch
template <typename T>
void move_kern_func_to_new_kern_and_dispatch(T& func) {
@@ -109,6 +113,14 @@ public:
return &m_default_local_share_bwd_filter_algo;
}

MatrixMulForward::Algorithm* default_matmul_fwd_algo() {
return &m_default_matmul_fwd_algo;
}

BatchedMatrixMulForward::Algorithm* default_batched_matmul_fwd_algo() {
return &m_default_batched_matmul_fwd_algo;
}

Relayout* relayout_opr() override {
return get_helper_opr<Relayout, 2>(this);
}


+ 35
- 0
dnn/src/naive/matrix_mul/algorithms.h View File

@@ -0,0 +1,35 @@
/**
* \file dnn/src/naive/matrix_mul/algorithms.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs/linalg.h"

namespace megdnn {
namespace naive {

class DefaultMatrixMulAlgorithm final
: public megdnn::MatrixMulForward::Algorithm {
bool is_reproducible() const override { return true; }
const char* name() const override { return "DEFAULT"; }
uint32_t type() const override { return 0; }
};

class DefaultBatchedMatrixMulAlgorithm final
: public megdnn::BatchedMatrixMulForward::Algorithm {
bool is_reproducible() const override { return true; }
const char* name() const override { return "DEFAULT"; }
uint32_t type() const override { return 0; }
};

} // namespace naive
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 14
- 0
dnn/src/naive/matrix_mul/opr_impl.cpp View File

@@ -81,6 +81,20 @@ void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
MIDOUT_END();
}

std::vector<MatrixMulForward::Algorithm*>
MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/,
const TensorLayout& /*B*/,
const TensorLayout& /*C*/) {
return {static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo()};
}

MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic(
        const TensorLayout& /*A*/, const TensorLayout& /*B*/,
        const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/,
        bool /* reproducible */) {
    // Only one candidate exists on the naive backend, so the heuristic
    // trivially returns the handle-wide default algorithm regardless of
    // layouts, workspace limit, or the reproducibility request.
    auto* naive_handle = static_cast<HandleImpl*>(handle());
    return naive_handle->default_matmul_fwd_algo();
}

} // namespace naive
} // namespace megdnn



+ 4
- 7
dnn/src/naive/matrix_mul/opr_impl.h View File

@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
@@ -26,17 +27,13 @@ public:

std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/) override {
return {};
}
const TensorLayout& /*C*/) override;

Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/,
const TensorLayout& /*B*/,
const TensorLayout& /*C*/,
size_t /*workspace_limit_in_bytes*/,
bool /* reproducible */) override {
return nullptr;
}
bool /* reproducible */) override;

const char* get_algorithm_set_name() const override { return "DEFAULT"; }



Loading…
Cancel
Save