You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

network.h 3.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. #pragma once
  2. #include "megbrain/test/helper.h"
  3. #include "megbrain/gopt/framework.h"
  4. #include "megbrain/opr/basic_arith_wrapper.h"
  5. #include "megbrain/opr/blas.h"
  6. #include "megbrain/opr/dnn/convolution.h"
  7. #include "megbrain/opr/dnn/pooling.h"
  8. #include "megbrain/opr/imgproc.h"
  9. #include "megbrain/opr/nn_int.h"
  10. #include "megbrain/opr/tensor_gen.h"
  11. #include "megbrain/opr/tensor_manip.h"
  12. #include "megbrain/opr/utility.h"
  13. namespace mgb {
/*!
 * \brief test helper that builds small computing graphs operator by operator
 *
 * Owns a ComputingGraph and a fixed comp node; each add_* method creates an
 * operator on that graph and returns its output SymbolVar, so tests can chain
 * calls to assemble a model.
 */
class Network {
private:
    //! generates float32 host tensors with values drawn uniformly from
    //! [-0.01, 0.01]; used to fill inputs and weights
    HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> gen{-0.01, 0.01};
    //! comp node on which every created operator is placed
    CompNode cn;

public:
    //! graph that owns all vars created through this helper
    std::shared_ptr<ComputingGraph> graph = ComputingGraph::make();
    // NOTE(review): single-argument ctor is implicit; consider marking it
    // `explicit` — confirm no caller relies on CompNode -> Network conversion.
    Network(CompNode cn_) : cn{cn_} {}
    ~Network() noexcept = default;
    using KernSize = SmallVector<size_t, 2>;  //!< {kh, kw}
    using Stride = SmallVector<size_t, 2>;    //!< {sh, sw}
    using Padding = SmallVector<size_t, 2>;   //!< {ph, pw}
    //! add a graph input: a Host2DeviceCopy of freshly generated random data,
    //! renamed to \p name
    SymbolVar add_var(const char* name, const TensorShape& shp = {1}) {
        return opr::Host2DeviceCopy::make(*graph, gen(shp), cn).rename(name);
    }
    //! add a constant-like var: a SharedDeviceTensor holding freshly generated
    //! random data (e.g. weights), renamed to \p name
    SymbolVar add_cvar(const char* name, const TensorShape& shp = {1}) {
        return opr::SharedDeviceTensor::make(*graph, *gen(shp), cn).rename(name);
    }
    //! append a convolution on \p f; \p has_relu presumably fuses/appends a
    //! relu activation — defined out of line, verify in the .cpp
    SymbolVar add_conv(
            SymbolVar f, size_t output_channels, KernSize kern_size,
            DType out_dtype = dtype::Float32(), bool has_relu = true,
            Stride stride = {1, 1}, Padding padding = {0, 0});
    //! append a grouped convolution with \p groups groups
    SymbolVar add_group_conv(
            SymbolVar f, size_t output_channels, size_t groups, KernSize kern_size,
            DType out_dtype = dtype::Float32(), bool has_relu = true,
            Stride stride = {1, 1}, Padding padding = {0, 0});
    //! append a deconvolution; \p ratio presumably controls the upsampling
    //! factor — defined out of line, verify in the .cpp
    SymbolVar add_deconv(
            SymbolVar f, size_t ratio, size_t output_channels, DType out_dtype);
    // NOTE(review): `const SymbolVarArray` is passed by value (top-level const
    // only); switching to a const reference would change the mangled signature
    // of the out-of-line definition, so it is left as-is here.
    SymbolVar add_elemwise(
            const SymbolVarArray inps, DType out_dtype = dtype::Float32(),
            opr::Elemwise::Param::Mode mode = opr::Elemwise::Param::Mode::ADD);
    using Window = SmallVector<size_t, 2>;  //!< pooling window {wh, ww}
    //! append a pooling operator (MAX by default)
    SymbolVar add_pooling(
            SymbolVar f, Window window, Stride stride = {1, 1},
            Padding padding = {0, 0},
            opr::Pooling::Param::Mode mode = opr::Pooling::Param::Mode::MAX);
    //! append a TypeCvt of \p f to \p out_dtype
    SymbolVar add_type_cvt(SymbolVar f, DType out_dtype = dtype::Float32());
    //! append a concat of \p f and \p g along \p axis
    SymbolVar add_concat(SymbolVar f, SymbolVar g, int axis = 0);
    //! append a dimshuffle of \p f with the given axis \p pattern
    SymbolVar add_dimshuffle(SymbolVar f, std::vector<int> pattern);
    //! append an AxisAddRemove on \p f (exact axes chosen in the .cpp)
    SymbolVar add_axisaddremove(SymbolVar f);
    //! append a Subtensor on \p f (slice parameters chosen in the .cpp)
    SymbolVar add_subtensor(SymbolVar f);
    //! append a Reshape of \p f (target shape chosen in the .cpp)
    SymbolVar add_reshape(SymbolVar f);
    //! append a Broadcast of \p f (target shape chosen in the .cpp)
    SymbolVar add_broadcast(SymbolVar f);
    //! append a Copy of \p f
    SymbolVar add_copy(SymbolVar f);
};
//! build one residual-style block on \p network starting from \p f;
//! \p has_proj presumably enables a projection shortcut — verify in the .cpp
SymbolVar create_block(
        Network& network, SymbolVar f, size_t stride, size_t num_outputs1,
        bool has_proj = false, DType out_dtype = dtype::Float32());
//! assemble a ResNet-18-like test graph (per its name; topology defined in
//! the .cpp) and return the output var
SymbolVar make_resnet18(
        Network& network, size_t batch = 16, DType out_dtype = dtype::Float32());
//! assemble a detection-style test graph; returns multiple output vars
SymbolVarArray make_det(
        Network& network, size_t batch = 16, DType out_dtype = dtype::Float32());
//! build a single bottleneck unit (MobileNet-style, per the \p t expansion
//! parameter naming — verify in the .cpp) starting from \p f
SymbolVar bottleneck(
        Network& network, SymbolVar f, size_t input_channels, size_t channels, size_t t,
        size_t stride, DType out_dtype = dtype::Float32());
//! build a group of \p stages stacked bottleneck units; \p s presumably is
//! the first-stage stride and \p t the expansion ratio — verify in the .cpp
SymbolVar bottleneck_group(
        Network& network, SymbolVar f, size_t input_channels, size_t channels,
        size_t stages, size_t s, size_t t, DType out_dtype = dtype::Float32());
//! assemble a MobileNet-V2-like test graph (per its name; topology defined
//! in the .cpp) and return the output var
SymbolVar make_mobilenet_v2(
        Network& network, size_t batch = 1, DType out_dtype = dtype::Float32());
  73. } // namespace mgb
  74. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}