|
- /**
- * \file dnn/test/common/exec_proxy.h
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- */
- #pragma once
- #include "megdnn/basic_types.h"
-
- #include "test/common/workspace_wrapper.h"
-
- #include <cstddef>
- #include <utility>
- #include <vector>
-
- namespace megdnn {
- namespace test {
-
- /*!
-  * \brief Adapter that calls an operator's exec() with tensors unpacked, in
-  *        order, from a TensorNDArray.
-  *
-  * \tparam Opr           operator type; its exec() must accept \p Arity tensor
-  *                       arguments (followed by a workspace when
-  *                       \p has_workspace is true).
-  * \tparam Arity         number of leading entries of the TensorNDArray that
-  *                       are forwarded as tensor arguments.
-  * \tparam has_workspace whether exec() takes a trailing workspace whose size
-  *                       is obtained from get_workspace_in_bytes().
-  *
-  * The two partial specializations below handle every Arity. Earlier versions
-  * of this header spelled out each arity by hand and silently lacked some of
-  * them (e.g. 11 and 12). Any more specialized partial or explicit
-  * specialization declared elsewhere still takes precedence.
-  *
-  * NOTE: tensors.size() is not checked against Arity here; the caller must
-  * pass at least Arity tensors.
-  */
- template <typename Opr, size_t Arity, bool has_workspace>
- struct ExecProxy;
-
- //! Variant for operators whose exec() ends with a workspace argument.
- template <typename Opr, size_t Arity>
- struct ExecProxy<Opr, Arity, true> {
-     //! Workspace buffer. It is created on the first exec() call from
-     //! opr->handle() with size 0, then update()d before every exec() with the
-     //! byte count returned by get_workspace_in_bytes().
-     WorkspaceWrapper W;
-
-     void exec(Opr* opr, const TensorNDArray& tensors) {
-         exec_impl(opr, tensors, std::make_index_sequence<Arity>{});
-     }
-
- private:
-     //! Expands tensors[0], ..., tensors[Arity - 1] into the argument lists of
-     //! get_workspace_in_bytes() (as layouts) and exec() (as tensors).
-     template <size_t... I>
-     void exec_impl(
-             Opr* opr, const TensorNDArray& tensors, std::index_sequence<I...>) {
-         if (!W.valid()) {
-             W = WorkspaceWrapper(opr->handle(), 0);
-         }
-         W.update(opr->get_workspace_in_bytes(tensors[I].layout...));
-         opr->exec(tensors[I]..., W.workspace());
-     }
- };
-
- //! Variant for operators whose exec() takes only tensor arguments.
- template <typename Opr, size_t Arity>
- struct ExecProxy<Opr, Arity, false> {
-     void exec(Opr* opr, const TensorNDArray& tensors) {
-         exec_impl(opr, tensors, std::make_index_sequence<Arity>{});
-     }
-
- private:
-     //! Expands tensors[0], ..., tensors[Arity - 1] into exec()'s arguments.
-     template <size_t... I>
-     static void exec_impl(
-             Opr* opr, const TensorNDArray& tensors, std::index_sequence<I...>) {
-         opr->exec(tensors[I]...);
-     }
- };
-
- }  // namespace test
- }  // namespace megdnn
- // vim: syntax=cpp.doxygen
|