Browse Source

feat(dnn): add indexing_one_hot and indexing_set_one_hot opr

GitOrigin-RevId: c5406c71ff
tags/v1.3.1
Megvii Engine Team 4 years ago
parent
commit
a49f4a66b7
4 changed files with 10 additions and 8 deletions
  1. +5
    -3
      dnn/src/common/opr_trait.h
  2. +1
    -1
      dnn/test/common/opr_algo_proxy.h
  3. +2
    -1
      dnn/test/common/opr_proxy.h
  4. +2
    -3
      dnn/test/common/powc.h

dnn/test/common/opr_trait.h → dnn/src/common/opr_trait.h View File

@@ -1,5 +1,5 @@
/** /**
* \file dnn/test/common/opr_trait.h
* \file dnn/src/common/opr_trait.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
* *
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -14,7 +14,6 @@
#include <cstddef> #include <cstddef>


namespace megdnn { namespace megdnn {
namespace test {


template <typename Opr> template <typename Opr>
struct OprTrait {}; struct OprTrait {};
@@ -114,7 +113,10 @@ DEF(FakeQuantForward, 4, true, true);
DEF(FakeQuantBackward, 5, true, false); DEF(FakeQuantBackward, 5, true, false);
DEF(TQTForward, 3, true, true); DEF(TQTForward, 3, true, true);
DEF(TQTBackward, 5, true, false); DEF(TQTBackward, 5, true, false);
} // namespace test
DEF(PowC, 2, false, true);
DEF(UniformRNG, 1, true, true);
DEF(GaussianRNG, 1, true, true);

} // namespace megdnn } // namespace megdnn


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 1
- 1
dnn/test/common/opr_algo_proxy.h View File

@@ -12,7 +12,7 @@
#pragma once #pragma once


#include "megdnn/basic_types.h" #include "megdnn/basic_types.h"
#include "test/common/opr_trait.h"
#include "src/common/opr_trait.h"
#include "test/common/utils.h" #include "test/common/utils.h"


namespace megdnn { namespace megdnn {


+ 2
- 1
dnn/test/common/opr_proxy.h View File

@@ -11,12 +11,13 @@
*/ */
#pragma once #pragma once


#include "src/common/opr_trait.h"

#include "test/common/deduce_layout_proxy.h" #include "test/common/deduce_layout_proxy.h"
#include "test/common/exec_proxy.h" #include "test/common/exec_proxy.h"
#include "test/common/fast_run_cache.h" #include "test/common/fast_run_cache.h"
#include "test/common/inspect_type.h" #include "test/common/inspect_type.h"
#include "test/common/opr_algo_proxy.h" #include "test/common/opr_algo_proxy.h"
#include "test/common/opr_trait.h"
#include "test/common/timer.h" #include "test/common/timer.h"
#include "test/common/workspace_wrapper.h" #include "test/common/workspace_wrapper.h"




+ 2
- 3
dnn/test/common/powc.h View File

@@ -12,13 +12,12 @@


#include "megdnn/handle.h" #include "megdnn/handle.h"
#include "megdnn/oprs/general.h" #include "megdnn/oprs/general.h"
#include "test/common/opr_proxy.h"

#include "src/common/opr_trait.h"


namespace megdnn { namespace megdnn {
namespace test { namespace test {


DEF(PowC, 2, false, true);

void run_powc_test(Handle* handle, DType dtype); void run_powc_test(Handle* handle, DType dtype);


} // namespace test } // namespace test


Loading…
Cancel
Save