|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
- #pragma once
-
- #include "test/common/opr_proxy.h"
-
- namespace megdnn {
- namespace test {
-
- struct OprProxyIndexingMultiAxisVecHelper {
- size_t axes[TensorLayout::MAX_NDIM];
-
- /*!
- * \brief OprProxy for indexing multi-vec family oprs
- *
- * \param init_axes axes that are indexed
- */
- OprProxyIndexingMultiAxisVecHelper(std::initializer_list<size_t> init_axes = {}) {
- size_t i = 0;
- for (auto ax : init_axes)
- axes[i++] = ax;
- }
-
- OprProxyIndexingMultiAxisVecHelper(SmallVector<size_t> init_axes) {
- size_t i = 0;
- for (auto ax : init_axes)
- axes[i++] = ax;
- }
-
- IndexingMultiAxisVec::IndexDesc make_index_desc(
- const TensorNDArray& tensors) const {
- megdnn_assert(tensors.size() >= 3);
- IndexingMultiAxisVec::IndexDesc ret;
- ret.resize(tensors.size() - 2);
- for (size_t i = 2; i < tensors.size(); ++i) {
- ret[i - 2] = {axes[i - 2], tensors[i]};
- }
- return ret;
- }
-
- size_t get_index_ndim(const TensorNDArray& tensors) const {
- megdnn_assert(tensors.size() >= 3);
- size_t ndim = 0;
- for (size_t i = 2; i < tensors.size(); ++i) {
- ndim = std::max(tensors[i].layout.ndim, ndim);
- }
- return ndim;
- }
-
- IndexingMultiAxisVec::IndexDescLayoutOnly make_index_layout(
- const TensorLayoutArray& layouts) const {
- megdnn_assert(layouts.size() >= 3);
- IndexingMultiAxisVec::IndexDescLayoutOnly ret;
- ret.resize(layouts.size() - 2);
- for (size_t i = 2; i < layouts.size(); ++i) {
- ret[i - 2] = {axes[i - 2], layouts[i]};
- }
- return ret;
- }
- };
-
- template <>
- struct OprProxy<IndexingMultiAxisVec> : public OprProxyIndexingMultiAxisVecHelper {
- using OprProxyIndexingMultiAxisVecHelper::OprProxyIndexingMultiAxisVecHelper;
-
- void exec(IndexingMultiAxisVec* opr, const TensorNDArray& tensors) const {
- WorkspaceWrapper W(
- opr->handle(), opr->get_workspace_in_bytes(
- tensors[1].layout, axes, tensors.size() - 2,
- get_index_ndim(tensors)));
- opr->exec(tensors[0], make_index_desc(tensors), tensors[1], W.workspace());
- }
-
- void deduce_layout(IndexingMultiAxisVec* opr, TensorLayoutArray& layouts) {
- opr->deduce_layout(layouts[0], make_index_layout(layouts), layouts[1]);
- }
- };
-
- template <>
- struct OprProxy<IndexingIncrMultiAxisVec> : public OprProxyIndexingMultiAxisVecHelper {
- using OprProxyIndexingMultiAxisVecHelper::OprProxyIndexingMultiAxisVecHelper;
-
- void exec(IndexingIncrMultiAxisVec* opr, const TensorNDArray& tensors) const {
- WorkspaceWrapper W(
- opr->handle(), opr->get_workspace_in_bytes(
- tensors[1].layout, axes, tensors.size() - 2,
- get_index_ndim(tensors)));
- opr->exec(tensors[0], tensors[1], make_index_desc(tensors), W.workspace());
- }
-
- void deduce_layout(IndexingIncrMultiAxisVec*, TensorLayoutArray&) {}
- };
-
- template <>
- struct OprProxy<IndexingSetMultiAxisVec> : public OprProxyIndexingMultiAxisVecHelper {
- using OprProxyIndexingMultiAxisVecHelper::OprProxyIndexingMultiAxisVecHelper;
-
- void exec(IndexingSetMultiAxisVec* opr, const TensorNDArray& tensors) const {
- WorkspaceWrapper W(
- opr->handle(), opr->get_workspace_in_bytes(
- tensors[1].layout, axes, tensors.size() - 2,
- get_index_ndim(tensors)));
- opr->exec(tensors[0], tensors[1], make_index_desc(tensors), W.workspace());
- }
-
- void deduce_layout(IndexingSetMultiAxisVec*, TensorLayoutArray&) {}
- };
-
- } // namespace test
- } // namespace megdnn
-
- // vim: syntax=cpp.doxygen
|