You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

mgb_opr_param_defs.fbs 7.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. // generated by gen_param_defs.py for c23d51f3c4f33119fd74f58f04d112ccea8f64f1249ab372300975ab7e710e9a
  2. include "dtype.fbs";
  3. namespace mgb.serialization.fbs.param;
/// mode of collective communication
enum CollectiveCommMode : uint {
    /// reduce by sum to output computing node
    REDUCE_SUM = 0,
    /// copy input value to each output computing node
    BROADCAST = 1,
    /// each output comp node gets the concatenated value of all inputs
    ALL_GATHER = 2,
    /// reduce inputs by sum and each output gets one part of it
    REDUCE_SCATTER_SUM = 3,
    /// every output gets the sum of all inputs
    ALL_REDUCE_SUM = 4,
    /// every output gets the max of all inputs
    ALL_REDUCE_MAX = 5,
    /// every output gets the min of all inputs
    ALL_REDUCE_MIN = 6,
    /// every output gets the prod of all inputs
    ALL_REDUCE_PROD = 7,
    /// concat inputs to one node
    GATHER = 8,
    /// scatter input to each output computing node
    SCATTER = 9,
    /// scatter inputs and gather them on each computing node
    ALL_TO_ALL = 10,
}
/// mode for computing the gradient
enum CondExecMarkGradMode : uint {
    /// normal gradient mode: sum all the activated components
    SUM = 0,
    /// use :attr:`CondExecMerge.SUM_COND_OUT` mode so oprs that depend on the
    /// gradient opr would not be executed if the forward var is not used.
    SUM_COND_OUT = 1,
}
/// static inference option. **Note:** This is a workaround: since
/// currently static inference in MegBrain does not take conditional
/// execution into account, this option can be used to bypass static
/// inference errors. This is currently only used by automatically
/// generated gradient oprs.
enum CondExecMarkStaticInfer : uint {
    /// enable both shape and value inference
    SHAPE_VALUE = 0,
    /// only enable shape inference (disable value inference)
    SHAPE_ONLY = 1,
    /// disable both shape and value inference
    NONE = 2,
}
enum CondExecMergeMode : uint {
    /// copy the var whose mask is activated to the output, requiring that
    /// exactly one branch is active
    EXACT_ONE = 0,
    /// like :attr:`EXACT_ONE` with the requirement that all branches have the
    /// same shape, so shape inference can be easier
    EXACT_ONE_SAME_SHAPE = 1,
    /// sum all the active branches into output var; require all branches to
    /// have the same shape. Extra shape vars are needed in this mode, so the
    /// outputs can be initialized to zero when no input is active (and their
    /// shapes are probably unknown).
    SUM = 2,
    /// like :attr:`SUM` but also add an ExecutionMask to the readers of output
    /// vars, so they would be skipped if no branch is taken
    SUM_COND_OUT = 3,
}
/// how to compare predicate var with branch keys
enum CondExecPredMode : uint {
    /// The outputs correspond to branch keys, and the one which equals
    /// predicate would be activated. This behaves like a case-statement in many
    /// languages.
    CASE = 0,
    /// like :attr:`CASE`, but add an extra output that would be activated if no
    /// branch is matched
    CASE_FALLBACK = 1,
    /// One more output would be produced than the number of branch keys,
    /// representing the interval in which the predicate var fits in. The
    /// intervals are defined as :math:`(-\infty, k_0), [k_0, k_1), \ldots,
    /// [k_{n-2}, k_{n-1}), [k_{n-1}, \infty)`. The keys must be given in
    /// ascending order.
    PIECEWISE = 2,
}
enum CondExecPredLogicalMode : uint {
    /// logical or
    OR = 0,
    /// logical and
    AND = 1,
    /// exclusive-or
    XOR = 2,
    /// not or(inputs)
    NOR = 3,
    /// not and(inputs)
    NAND = 4,
    /// not xor(inputs)
    XNOR = 5,
}
/// NOTE: declared with ``bit_flags``, so each member denotes a bit position
/// and members may be combined into a bit mask.
enum ExecutionPolicyStrategy : uint (bit_flags) {
    /// use heuristic to choose the fastest algorithm
    HEURISTIC = 0,
    /// run possible algorithms on real device to find the best
    PROFILE = 1,
    /// when profile or heuristic algo selection is used, require the algos to
    /// be reproducible
    REPRODUCIBLE = 2,
    /// profile requires algos that are optimized to achieve fast-profile
    OPTIMIZED = 3,
}
enum ExecutionPolicyV0Strategy : uint {
    /// use heuristic to choose the fastest algorithm
    HEURISTIC = 0,
    /// use heuristic to choose the fastest algorithm, and the chosen algorithm
    /// is reproducible
    HEURISTIC_REPRODUCIBLE = 1,
    /// run possible algorithms on real device to find the best
    PROFILE = 2,
    /// the fastest of profile result that is also reproducible
    PROFILE_REPRODUCIBLE = 3,
    /// use profile result and heuristic to choose the fastest algorithm
    PROFILE_HEURISTIC = 4,
}
table DType {
    /// dtype selector; ``DTypeEnum`` is declared in the included ``dtype.fbs``
    dtype:DTypeEnum = Byte;
}
table PersistentOutputStorage {
    /// This is used for controlling memory sharing. Multiple
    /// ``PersistentOutputStorage`` oprs with the same ``share_key`` would share
    /// underlying tensor storage. Note that the value ``-1`` is treated
    /// specially: storage of oprs with this key would be private and would not
    /// be shared with any other opr.
    share_key:int = -1;
}
/// optional axis: axis == -1 means no axis
table OptionalAxis {
    axis:int = -1;
}
/// optional axis: axis == MAX_NDIM means no axis
table OptionalAxisV1 {
    /// default 7 is MAX_NDIM, i.e. "no axis" (per the table comment above)
    axis:int = 7;
}
table ExecutionPolicyV0 {
    strategy:ExecutionPolicyV0Strategy = HEURISTIC;
    /// workspace limit in bytes; defaults to UINT64_MAX (2^64 - 1), i.e.
    /// effectively no limit
    workspace_limit:ulong = 18446744073709551615;
}
/// specify how to select an algorithm for an operator
table ExecutionPolicy {
    /// numeric default 1: since ``ExecutionPolicyStrategy`` is a bit_flags
    /// enum, 1 is the bit mask for HEURISTIC (bit position 0)
    strategy:ExecutionPolicyStrategy = 1;
    /// workspace limit in bytes; defaults to UINT64_MAX (2^64 - 1), i.e.
    /// effectively no limit
    workspace_limit:ulong = 18446744073709551615;
}
table AssertEqual {
    /// max allowed error; error is defined as the minimal of absolute and
    /// relative error
    maxerr:float = 0.0001;
    /// whether to print maxerr to stdout during opr exec
    verbose:bool = false;
}
/// parameters for an FPGA convolution opr; field semantics are defined by the
/// opr implementation (not documented in this generated schema)
table FpgaConv {
    need_output_quantize:bool = false;
    need_output_threshold:bool = false;
    stride:int = 1;
    input_bit_width:int = 2;
    output_bit_width:int = 2;
    weight_bit_width:int = 2;
    thres0:int = 0;
    thres1:int = 1;
    unpool_size:uint = 4;
    direct_size:uint = 4;
}
/// collective communication between multiple computing nodes on localhost
table CollectiveComm {
    /// mode of collective communication
    mode:CollectiveCommMode = REDUCE_SUM;
}
/// HACK: The tag of this param def is actually used for another non-generated
/// param def SerializedDType, the sole purpose of this param def is to provide
/// a spare tag. Do not use.
table FakeSerializedDType {
}
/// evaluate a predicate and branch keys to setup ExecutionMask objects with
/// associated predicate proxy vars (PPVs)
table CondExecPred {
    /// how to compare predicate var with branch keys
    mode:CondExecPredMode = CASE;
    /// threshold for checking equality of float point values
    eps:float = 0.0001;
}
/// compute a logical function over a set of PPVs
table CondExecPredLogical {
    mode:CondExecPredLogicalMode = OR;
}
/// add ExecutionMask of the input PPV to this opr and readers of the outputs of
/// this opr
table CondExecMark {
    /// mode for computing the gradient
    grad_mode:CondExecMarkGradMode = SUM;
    /// static inference option. **Note:** This is a workaround: since
    /// currently static inference in MegBrain does not take conditional
    /// execution into account, this option can be used to bypass static
    /// inference errors. This is currently only used by automatically
    /// generated gradient oprs.
    static_infer:CondExecMarkStaticInfer = SHAPE_VALUE;
}
/// merge multiple conditional execution branches
table CondExecMerge {
    /// number of output vars (i.e. vars per branch)
    nr_output:uint = 1;
    mode:CondExecMergeMode = EXACT_ONE;
}
/// opr implementing the NVIDIA Optical Flow SDK
table NvOf {
    /// precision level passed to the SDK — exact semantics defined by the opr
    /// implementation; TODO confirm against the NvOf opr source
    precision:uint = 1;
}