|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
7227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436 |
- #pragma once
- #include "megdnn/internal/opr_header_prologue.h"
-
- namespace megdnn {
-
- class SeparableConvBase : public OperatorBase {
- DEF_OPR_IMPL_CTOR(SeparableConvBase, OperatorBase);
- DEF_OPR_PARAM(SeparableConv);
-
- public:
- using Mode = Param::Mode;
-
- protected:
- void deduce_layout_fwd(
- const TensorLayout& src, const TensorLayout& filter_x,
- const TensorLayout& filter_y, TensorLayout& dst);
- void check_layout_fwd(
- const TensorLayout& src, const TensorLayout& filter_x,
- const TensorLayout& filter_y, const TensorLayout& dst);
- };
-
- class SeparableConvForward : public SeparableConvBase {
- DEF_OPR_IMPL(SeparableConvForward, SeparableConvBase, 3, 1);
-
- public:
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_in filter_x,
- _megdnn_tensor_in filter_y, _megdnn_tensor_out dst,
- _megdnn_workspace workspace) = 0;
- void deduce_layout(
- const TensorLayout& src, const TensorLayout& filter_x,
- const TensorLayout& filter_y, TensorLayout& dst);
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& filter_x,
- const TensorLayout& filter_y, const TensorLayout& dst) = 0;
-
- protected:
- void check_exec(
- const TensorLayout& src, const TensorLayout& filter_x,
- const TensorLayout& filter_y, const TensorLayout& dst,
- size_t workspace_in_bytes);
- };
- using SeparableConv = SeparableConvForward;
-
- namespace detail {
-
- struct PreprocessedFilter {
- //! user data; its lifetime should be bound to MegDNN Convolution
- //! operator
- void* algorithm_id;
- TensorNDArray tensors;
- };
-
- } // namespace detail
-
- /**
- * \brief base class for convolution operation
- *
- * This operator is supposed to perform convolution on arbitrary input
- * dimensions. The input/output format is N, C, dims..., and kernel format can
- * take two forms:
- * 1. OC, IC, dims..., for conventional dense convolution
- * 2. GROUP, OC_PER_GRP, IC_PER_GRP, dims... for sparse group convolution
- *
- * Currently, only 2D images are supported.
- */
- template <typename Parameter>
- class ConvolutionBase : public OperatorBase {
- DEF_OPR_IMPL_CTOR(ConvolutionBase, OperatorBase);
- using Param = Parameter;
-
- public:
- Param& param() { return m_param; }
- const Param& param() const { return m_param; }
-
- protected:
- Param m_param;
-
- public:
- static constexpr size_t MAX_SPATIAL_DIM = 2;
- using Mode = typename Param::Mode;
- struct CanonizedFilterMeta {
- DType dtype;
- typename Param::Format format;
-
- uint32_t
- //! whether filter should be flipped (i.e. is CONVOLUTION)
- should_flip,
- group, //!< number of groups
- icpg, //!< input channels per group
- ocpg, //!< output channels per group
- spatial_ndim, stride[MAX_SPATIAL_DIM], padding[MAX_SPATIAL_DIM],
- //! spatial dim
- spatial[MAX_SPATIAL_DIM], dilation[MAX_SPATIAL_DIM],
- //! spatial dim with dilation applied
- dilated_spatial[MAX_SPATIAL_DIM];
-
- //! T should be a ConvolutionBase<Z>::CanonizedFilterMeta
- template <typename T>
- void copy_from(const T& b) {
- dtype = b.dtype;
- format = b.format;
- should_flip = b.should_flip;
- group = b.group;
- icpg = b.icpg;
- ocpg = b.ocpg;
- spatial_ndim = b.spatial_ndim;
- memcpy(stride, b.stride, sizeof(stride));
- memcpy(padding, b.padding, sizeof(padding));
- memcpy(spatial, b.spatial, sizeof(spatial));
- memcpy(dilation, b.dilation, sizeof(dilation));
- memcpy(dilated_spatial, b.dilated_spatial, sizeof(dilated_spatial));
- }
-
- bool operator==(const CanonizedFilterMeta& b) const {
- bool flag = true;
- flag = flag && (format == b.format);
- flag = flag && (dtype == b.dtype);
- flag = flag && (should_flip == b.should_flip);
- flag = flag && (group == b.group);
- flag = flag && (icpg == b.icpg);
- flag = flag && (ocpg == b.ocpg);
- flag = flag && (spatial_ndim == b.spatial_ndim);
- if (flag) {
- for (uint32_t i = 0; i < spatial_ndim; ++i) {
- flag = flag && (stride[i] == b.stride[i]);
- flag = flag && (padding[i] == b.padding[i]);
- flag = flag && (spatial[i] == b.spatial[i]);
- flag = flag && (dilation[i] == b.dilation[i]);
- flag = flag && (dilated_spatial[i] == b.dilated_spatial[i]);
- }
- }
- return flag;
- }
- };
- using PreprocessedFilter = detail::PreprocessedFilter;
-
- protected:
- // Check or deduce output DType
- void check_or_deduce_dtype_fwd(DType src, DType filter, DType& dst) const;
- CanonizedFilterMeta deduce_layout_fwd(
- const TensorLayout& src, const TensorLayout& filter,
- TensorLayout& dst) const;
- CanonizedFilterMeta check_layout_fwd(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& dst) const;
-
- CanonizedFilterMeta make_canonized_filter_meta(
- size_t src_ndim, const TensorLayout& filter) const;
- };
-
- class MaskPropagate : public OperatorBase {
- DEF_OPR_IMPL(MaskPropagate, OperatorBase, 1, 1);
- DEF_OPR_PARAM(MaskPropagate);
-
- public:
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_out dst,
- _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& dst) = 0;
-
- void deduce_layout(const TensorLayout& src, TensorLayout& dst);
- };
-
- /**
- * \brief ConvolutionForward Operator with 0/1 Mask matrix
- */
- class MaskConvForward : public ConvolutionBase<param::Convolution> {
- DEF_OPR_IMPL(MaskConvForward, ConvolutionBase, 3, 1);
-
- public:
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in mask,
- _megdnn_tensor_out dst, _megdnn_workspace worksapce) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& mask, const TensorLayout& dst) = 0;
-
- void deduce_dtype(DType src, DType filter, DType mask, DType& dst);
- void deduce_layout(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& mask, TensorLayout& dst);
-
- protected:
- CanonizedFilterMeta check_exec(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& mask, const TensorLayout& dst,
- size_t workspace_in_bytes);
- };
- using MaskConvolution = MaskConvForward;
-
- /**
- * \brief ConvolutionForward operator.
- */
- class ConvolutionForward : public ConvolutionBase<param::Convolution>,
- public detail::MultiAlgoOpr<ConvolutionForward, 3> {
- DEF_OPR_IMPL(ConvolutionForward, ConvolutionBase, 2, 1);
-
- public:
- /**
- * \param[in] src (n, ic, ih, iw)
- * \param[in] filter (oc, ic, fh, fw)
- * \param[in] preprocessed_filter if weight no preprocessed it will be
- * nullptr, else the preprocessed weights store in the tensors of
- * preprocessed_filter.
- * \param[in] workspace if weight no preprocessed
- * (preprocessed_filter == nullptr), The size of the workspace satisfies the
- * situation that weights is not processed, other wise the size of workspace
- * satisfies the situation that weights is preprocessed
- * \param[out] dst (n, oc, oh, ow)
- */
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
- const PreprocessedFilter* preprocessed_filter,
- _megdnn_workspace workspace) = 0;
-
- MGE_WIN_DECLSPEC_FUC void exec(
- _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
- _megdnn_workspace workspace) {
- exec(src, filter, dst, nullptr, workspace);
- }
- /**
- * \brief execute weight preprocessing, read weights form filter and write
- * to preprocessed_filter after preprocessed.
- *
- * \praram[in] workspace the needed tmp workspace when exec_preprocess
- */
- virtual void exec_preprocess(
- const TensorLayout& src_layout, _megdnn_tensor_in filter,
- const TensorLayout& dst_layout, PreprocessedFilter* preprocessed_filter,
- _megdnn_workspace workspace) = 0;
- MGE_WIN_DECLSPEC_FUC void deduce_dtype(DType src, DType filter, DType& dst);
-
- MGE_WIN_DECLSPEC_FUC void deduce_layout(
- const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
-
- /**
- * \brief query the workspace needed when executing the opr, if the weights
- * are preprocessed the preprocessed_filter will not be nullptr, else it
- * will be nullptr, the workspace size maybe different whether weights are
- * preprocessed
- *
- * \return the size of workspace needed when executing
- */
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& dst, const PreprocessedFilter* preprocessed_filter) = 0;
-
- /**
- * \brief deduce the preprocessed filter layouts according to the src,
- * filter and dst layout, the result may contain multi layouts when the
- * weights is not one
- *
- * \return SmallVector<TensorLayout> Derive the layouts of weight
- * preprocessing, return empty if preprocessing is not needed.
- */
- virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& dst) = 0;
-
- /**
- * \brief query the workspace needed when preprocessing the weights,
- * according to the return size, a _megdnn_workspace will be created and
- * passed through exec_preprocess
- *
- * \return the size of workspace needed when preprocessing
- */
- virtual size_t get_preprocess_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& dst) = 0;
-
- static Algorithm::OprType get_opr_type() {
- return Algorithm::OprType::CONVOLUTION_FORWARD;
- }
-
- protected:
- CanonizedFilterMeta check_exec(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& dst, size_t workspace_in_bytes,
- const PreprocessedFilter* preprocessed_filter);
- };
- using Convolution = ConvolutionForward;
-
- /**
- * \brief ConvolutionBackwardData operator.
- *
- * Calculating the gradient wrt. convolution input data.
- */
- class ConvolutionBackwardData
- : public ConvolutionBase<param::Convolution>,
- public detail::MultiAlgoOpr<ConvolutionBackwardData, 3> {
- DEF_OPR_IMPL(ConvolutionBackwardData, ConvolutionBase, 2, 1);
-
- public:
- /**
- * \param[in] filter (oc, ic, fh, fw)
- * \param[in] diff (n, oc, oh, ow)
- * \param[out] grad (n, ic, ih, iw)
- */
- virtual void exec(
- _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
- _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& filter, const TensorLayout& diff,
- const TensorLayout& grad) = 0;
-
- MGE_WIN_DECLSPEC_FUC void deduce_dtype(DType filter, DType diff, DType& grad);
- MGE_WIN_DECLSPEC_FUC void deduce_layout(
- const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad);
-
- static Algorithm::OprType get_opr_type() {
- return Algorithm::OprType::CONVOLUTION_BACKWARD_DATA;
- }
-
- protected:
- CanonizedFilterMeta check_exec(
- const TensorLayout& filter, const TensorLayout& diff,
- const TensorLayout& grad, size_t workspace_in_bytes);
- };
-
- /**
- * \brief ConvolutionBackwardFilter operator.
- *
- * Calculating the gradient wrt. convolution filter.
- */
- class ConvolutionBackwardFilter
- : public ConvolutionBase<param::Convolution>,
- public detail::MultiAlgoOpr<ConvolutionBackwardFilter, 3> {
- DEF_OPR_IMPL(ConvolutionBackwardFilter, ConvolutionBase, 2, 1);
-
- public:
- /**
- * \param[in] src (n, ic, ih, iw)
- * \param[in] diff (n, oc, oh, ow)
- * \param[out] grad (oc, ic, fh, fw)
- */
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
- _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& diff,
- const TensorLayout& grad) = 0;
-
- static Algorithm::OprType get_opr_type() {
- return Algorithm::OprType::CONVOLUTION_BACKWARD_FILTER;
- }
-
- protected:
- CanonizedFilterMeta check_exec(
- const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
- size_t workspace_in_bytes);
- };
-
- /**
- * \brief ConvolutionBias operator
- */
- class ConvBiasForward : public ConvolutionBase<param::ConvBias>,
- public detail::MultiAlgoOpr<ConvBiasForward, 5> {
- DEF_OPR_IMPL(ConvBiasForward, ConvolutionBase, 4, 1);
-
- public:
- /**
- * \param[in] src (n, ic, ih, iw) or (n, ih, iw, ic)
- * \param[in] filter (oc, ic, fh, fw) or (oc, fh, fw, ic) or (oc/4, fh, fw,
- * 4 * ic)
- * \param[in] bias (1, oc, 1, 1)
- * \param[in] z same as dst
- * \param[in] preprocessed_filter if weight no preprocessed it will be
- * nullptr, else the preprocessed weights store in the tensors of
- * preprocessed_filter.
- * \param[in] workspace if weight no preprocessed
- * (preprocessed_filter == nullptr), The size of the workspace satisfies the
- * situation that weights is not processed, other wise the size of workspace
- * satisfies the situation that weights is preprocessed
- * \param[out] dst (n, oc, oh, ow) or (n, oh, ow, oc)
- *
- * \note if the format is NCHW_WINOGRAD, the filter layout is (alphah,
- * alphaw, oc, ic)
- */
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
- _megdnn_tensor_in z, _megdnn_tensor_out dst,
- const PreprocessedFilter* preprocessed_filter,
- _megdnn_workspace workspace) = 0;
-
- MGE_WIN_DECLSPEC_FUC void exec(
- _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
- _megdnn_tensor_in z, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
- exec(src, filter, bias, z, dst, nullptr, workspace);
- }
-
- /**
- * \brief execute weight preprocessing, read weights form filter and bias,
- * write to preprocessed_filter after preprocessed.
- *
- * \praram[in] workspace the needed tmp workspace when exec_preprocess
- * running, the size is got by get_preprocess_workspace_in_bytes
- */
- virtual void exec_preprocess(
- const TensorLayout& src_layout, _megdnn_tensor_in filter,
- _megdnn_tensor_in bias, const TensorLayout& z_layout,
- const TensorLayout& dst_layout, PreprocessedFilter* preprocessed_filter,
- _megdnn_workspace workspace) = 0;
- MGE_WIN_DECLSPEC_FUC void deduce_dtype(
- DType src, DType filter, DType bias, DType z, DType& dst);
- MGE_WIN_DECLSPEC_FUC void deduce_layout(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& bias, const TensorLayout& z, TensorLayout& dst);
-
- /**
- * \brief query the workspace needed when executing the opr, if the weights
- * are preprocessed the preprocessed_filter will not be nullptr, else it
- * will be nullptr, the workspace size maybe different whether weights are
- * preprocessed
- *
- * \return the size of workspace needed when executing
- */
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst,
- const PreprocessedFilter* preprocessed_filter) = 0;
-
- /**
- * \brief query the workspace needed when pre-processing the weights,
- * according to the return size, a _megdnn_workspace will be created and
- * passed through exec_preprocess
- *
- * \return the size of workspace needed when pre-processing
- */
- virtual size_t get_preprocess_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& bias, const TensorLayout& z,
- const TensorLayout& dst) = 0;
-
- /**
- * \brief deduce the pre-processed filter layouts according to the src,
- * filter and dst layout, which may contain multi layouts when the weights
- * is not one
- *
- * \return SmallVector<TensorLayout> Derive the layouts of weight
- * preprocessing, return empty if preprocessing is not needed.
- */
- virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& bias, const TensorLayout& z,
- const TensorLayout& dst) = 0;
-
- enum class BiasMode : uint32_t {
- NO_BIAS = 0, //!< no bias
- BROADCAST_CHANNEL_BIAS, //!< broadcast channel bias, [1, c, 1, 1]
- BIAS //!< [N, C, H, W]
- };
-
- //! param for winograd algos.
-
- struct WinogradParam {
- uint32_t channel_block_size;
- uint32_t output_block_size;
- uint32_t tile_size;
- uint32_t filter_size;
- bool operator==(const WinogradParam& rhs) const {
- return channel_block_size == rhs.channel_block_size &&
- output_block_size == rhs.output_block_size &&
- tile_size == rhs.tile_size && filter_size == rhs.filter_size;
- }
-
- std::string to_string() const;
- };
- static constexpr WinogradParam INVALID_WINOGRAD_PARAM = {0, 0, 0, 0};
-
- struct DirectParam {
- std::string to_string() const { return ""; }
- };
-
- struct MatmulParam {
- std::string to_string() const { return ""; }
- };
-
- struct DefaultParam {
- std::string to_string() const { return ""; }
- };
-
- //! get algo name, the format is ParamTrait<T>::category:base:p.to_string()
- //! \warning: base must not contain :.
- template <typename T>
- static std::string algo_name(
- const std::string& base, const T& p,
- param::ConvBias::Format format = param::ConvBias::Format::NCHW);
- /*!
- * \brief parse algo_name and get WinogradParam from algo name.
- *
- * \param algo name string
- * \return WinogradParam parsed from algo name, use pattern
- * winograd:base:m:tile_size.
- *
- * \warning: INVALID_WINOGRAD_PARAM returns if the algo_name is not matched.
- */
- static WinogradParam parse_winograd_name(const std::string& algo_name);
-
- /**
- * @brief find if there is nchw_nchwxx conv kernel optimized for argment,
- * nchw44 used for arm, nchw88 used for x86
- *
- * @param src_dtype conv feature map data type
- * @param filter_dtype conv filter or weight data type
- * @param dst_dtype output data type
- * @param fm filter meta param
- * @param bias_mode bias mode, no_bias or broadcast or bias
- * @param nonline_mode identity or relu or h_swish or sigmoid
- * @return true, found a kernel
- * @return false, can`t found any kernel
- */
- static bool is_nchw_nchwxx_optimized(
- const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
- const DTypeEnum dst_dtype,
- const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
- const ConvBiasForward::BiasMode bias_mode,
- const param::ConvBias::NonlineMode nonline_mode);
-
- static Algorithm::OprType get_opr_type() {
- return Algorithm::OprType::CONVBIAS_FORWARD;
- }
-
- protected:
- CanonizedFilterMeta check_exec(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst,
- size_t workspace_in_bytes, const PreprocessedFilter* preprocessed_filter);
-
- CanonizedFilterMeta check_exec_allow_noncontiguous(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst,
- size_t workspace_in_bytes, const PreprocessedFilter* preprocessed_filter);
- };
- using ConvBias = ConvBiasForward;
-
- /**
- * \brief RegionRestrictedConvolutionForward operator.
- */
- class RegionRestrictedConvolutionForward : public ConvolutionBase<param::Convolution> {
- DEF_OPR_IMPL(RegionRestrictedConvolutionForward, ConvolutionBase, 4, 1);
-
- public:
- /**
- * \param[in] src (n, ic, ih, iw) or (n, g*icpg, ih, iw)
- * \param[in] filter (oc, ic, fh, fw) or (g, ocpg, icpg, fh, fw)
- * \param[in] rin (n, ih, iw)
- * \param[in] rout (n, oh, ow)
- * \param[out] dst (n, oc, oh, ow) or (n, g*ocpg, oh, ow)
- */
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in rin,
- _megdnn_tensor_in rout, _megdnn_tensor_out dst,
- _megdnn_workspace workspace) = 0;
-
- void deduce_dtype(DType src, DType filter, DType rin, DType rout, DType& dst);
-
- MGE_WIN_DECLSPEC_FUC void deduce_layout(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& rin, const TensorLayout& rout, TensorLayout& dst);
-
- /**
- * \brief query the workspace needed when executing the opr
- * \return the size of workspace needed when executing
- */
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& rin, const TensorLayout& rout,
- const TensorLayout& dst) = 0;
-
- static Algorithm::OprType get_opr_type() {
- return Algorithm::OprType::REGIONRESTRICTEDCONVOLUTION_FORWARD;
- }
-
- protected:
- CanonizedFilterMeta check_exec(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& rin, const TensorLayout& rout, const TensorLayout& dst,
- size_t workspace_in_bytes);
- };
- using RegionRestrictedConvolution = RegionRestrictedConvolutionForward;
-
- /**
- * \brief RegionRestrictedConvolutionBackwardData operator.
- *
- * Calculating the gradient wrt. convolution input data.
- */
- class RegionRestrictedConvolutionBackwardData
- : public ConvolutionBase<param::Convolution> {
- DEF_OPR_IMPL(RegionRestrictedConvolutionBackwardData, ConvolutionBase, 4, 1);
-
- public:
- /**
- * \param[in] filter (oc, ic, fh, fw) or (g, ocpg, icpg, fh, fw)
- * \param[in] diff (n, oc, oh, ow) or (n, g*ocpg, oh, ow)
- * \param[in] rin (n, ih, iw)
- * \param[in] rout (n, oh, ow)
- * \param[out] grad (n, ic, ih, iw) or (n, g*icpg, ih, iw)
- */
- virtual void exec(
- _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_in rin,
- _megdnn_tensor_in rout, _megdnn_tensor_out grad,
- _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& filter, const TensorLayout& diff,
- const TensorLayout& rin, const TensorLayout& rout,
- const TensorLayout& grad) = 0;
-
- MGE_WIN_DECLSPEC_FUC void deduce_dtype(
- DType filter, DType diff, DType rin, DType rout, DType& grad);
- MGE_WIN_DECLSPEC_FUC void deduce_layout(
- const TensorLayout& filter, const TensorLayout& diff,
- const TensorLayout& rin, const TensorLayout& rout, TensorLayout& grad);
-
- static Algorithm::OprType get_opr_type() {
- return Algorithm::OprType::REGIONRESTRICTEDCONVOLUTION_BACKWARD_DATA;
- }
-
- protected:
- CanonizedFilterMeta check_exec(
- const TensorLayout& filter, const TensorLayout& diff,
- const TensorLayout& rin, const TensorLayout& rout, const TensorLayout& grad,
- size_t workspace_in_bytes);
- };
-
- /**
- * \brief RegionRestrictedConvolutionBackwardFilter operator.
- *
- * Calculating the gradient wrt. convolution filter.
- */
- class RegionRestrictedConvolutionBackwardFilter
- : public ConvolutionBase<param::Convolution> {
- DEF_OPR_IMPL(RegionRestrictedConvolutionBackwardFilter, ConvolutionBase, 4, 1);
-
- public:
- /**
- * \param[in] src (n, ic, ih, iw) or (n, g*icpg, ih, iw)
- * \param[in] diff (n, oc, oh, ow) or (n, g*ocpg, oh, ow)
- * \param[in] rin (n, ih, iw)
- * \param[in] rout (n, oh, ow)
- * \param[out] grad (oc, ic, fh, fw) or (g, ocpg, icpg, fh, fw)
- */
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_in rin,
- _megdnn_tensor_in rout, _megdnn_tensor_out grad,
- _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& diff, const TensorLayout& rin,
- const TensorLayout& rout, const TensorLayout& grad) = 0;
-
- static Algorithm::OprType get_opr_type() {
- return Algorithm::OprType::REGIONRESTRICTEDCONVOLUTION_BACKWARD_FILTER;
- }
-
- protected:
- CanonizedFilterMeta check_exec(
- const TensorLayout& src, const TensorLayout& diff, const TensorLayout& rin,
- const TensorLayout& rout, const TensorLayout& grad,
- size_t workspace_in_bytes);
- };
-
- /**
- * \brief Base class for Conv - Nonlinearity - Pooling.
- *
- * Declares deduce_layout() and check_layout() as protected pure virtual hooks.
- */
- class ConvPoolingBase : public OperatorBase {
- DEF_OPR_IMPL_CTOR(ConvPoolingBase, OperatorBase);
-
- /**
- * Param::Method: two methods to fetch the input data.
- * The default method is WITH_TEXTURE_OBJ.
- * If you want to use WITH_SHARED_MEM mode,
- * please make sure that the size of
- * [ all of the filter kernels + a channel
- * of input data + a channel of output data]
- * is no larger than 38KB,
- * and that the pooling mode is not "MAX".
- */
- DEF_OPR_PARAM(ConvPooling);
-
- protected:
//! Takes src/filter/bias layouts as const references and dst as a mutable one.
- virtual void deduce_layout(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& bias, TensorLayout& dst) = 0;
//! Same layouts as deduce_layout() plus a workspace limit in bytes.
- virtual void check_layout(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& bias, TensorLayout& dst,
- size_t workspace_limit_in_bytes) = 0;
- };
-
/**
 * \brief Forward interface of the Conv - Pooling operator.
 *
 * Every member function declared here is pure virtual. deduce_layout() and
 * check_layout() repeat the ConvPoolingBase declarations; deduce_layout() is
 * listed under `public` in this class.
 */
- class ConvPoolingForward : public ConvPoolingBase {
- DEF_OPR_IMPL(ConvPoolingForward, ConvPoolingBase, 2, 1);
-
- public:
- /**
- * \param[in] src input tensor
- * \param[in] filter filter tensor
- * \param[in] bias bias tensor
- * \param[out] dst output tensor
- */
- virtual void exec(
- const _megdnn_in TensorND src, const _megdnn_in TensorND filter,
- const _megdnn_in TensorND bias, _megdnn_out TensorND dst,
- _megdnn_out Workspace workspace) = 0;
- virtual void deduce_layout(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& bias, TensorLayout& dst) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& bias, const TensorLayout& dst) = 0;
-
- protected:
- virtual void check_layout(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& bias, TensorLayout& dst,
- size_t workspace_limit_in_bytes) = 0;
- };
//! Short alias for the forward operator.
- using ConvPooling = ConvPoolingForward;
-
/**
 * \brief Common base of the GroupLocal operators.
 *
 * Uses the Convolution parameter (DEF_OPR_PARAM(Convolution)), re-exports its
 * Mode type, and declares the forward layout helpers used by derived classes.
 */
- class GroupLocalBase : public OperatorBase {
- DEF_OPR_IMPL_CTOR(GroupLocalBase, OperatorBase);
- DEF_OPR_PARAM(Convolution);
-
- public:
- using Mode = Param::Mode;
-
- protected:
//! Writes into dst; src and filter are read-only.
- void deduce_layout_fwd(
- const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
//! All three layouts are read-only here.
- void check_layout_fwd(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& dst);
- };
-
/**
 * \brief Forward interface of the GroupLocal operator.
 *
 * exec() and get_workspace_in_bytes() are pure virtual; deduce_layout() is
 * implemented inline by forwarding to GroupLocalBase::deduce_layout_fwd().
 */
- class GroupLocalForward : public GroupLocalBase {
- DEF_OPR_IMPL(GroupLocalForward, GroupLocalBase, 2, 1);
-
- public:
- /**
- * \param[in] src (N, IC, IH, IW)
- * \param[in] filter (G, OH, OW, IC/G, FH, FW, OC/G)
- * \param[out] dst (N, OC, OH, OW)
- **/
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
- _megdnn_workspace workspace) = 0;
//! Thin wrapper: passes its arguments unchanged to deduce_layout_fwd().
- void deduce_layout(
- const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst) {
- deduce_layout_fwd(src, filter, dst);
- }
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& dst) = 0;
-
- protected:
// NOTE(review): defined outside this view; presumably validates the layouts
// and the workspace size before exec() runs -- confirm.
- void check_exec(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& dst, size_t workspace_in_bytes);
- };
//! Short alias for the forward operator.
- using GroupLocal = GroupLocalForward;
-
/**
 * \brief Interface of the GroupLocal backward-data operator.
 *
 * exec() and get_workspace_in_bytes() are pure virtual.
 * NOTE(review): tensor shapes are not documented here; presumably filter
 * matches GroupLocalForward's filter, diff matches its dst and grad matches
 * its src -- confirm.
 */
- class GroupLocalBackwardData : public GroupLocalBase {
- DEF_OPR_IMPL(GroupLocalBackwardData, GroupLocalBase, 2, 1);
-
- public:
//! filter and diff are inputs, grad is the output.
- virtual void exec(
- _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
- _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& filter, const TensorLayout& diff,
- const TensorLayout& grad) = 0;
-
- protected:
// NOTE(review): defined outside this view; presumably validates the layouts
// and the workspace size before exec() runs -- confirm.
- void check_exec(
- const TensorLayout& filter, const TensorLayout& diff,
- const TensorLayout& grad, size_t workspace_in_bytes);
- };
-
/**
 * \brief Interface of the GroupLocal backward-filter operator.
 *
 * exec() and get_workspace_in_bytes() are pure virtual. All three entry points
 * take their arguments in the order (src, diff, grad).
 * NOTE(review): presumably src is the forward input, diff the gradient wrt.
 * the forward output and grad the gradient wrt. the filter -- confirm.
 */
- class GroupLocalBackwardFilter : public GroupLocalBase {
- DEF_OPR_IMPL(GroupLocalBackwardFilter, GroupLocalBase, 2, 1);
-
- public:
//! src and diff are inputs, grad is the output.
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
- _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& diff,
- const TensorLayout& grad) = 0;
-
- protected:
// NOTE(review): defined outside this view; presumably validates the layouts
// and the workspace size before exec() runs -- confirm.
- void check_exec(
- const TensorLayout& src, const TensorLayout& diff,
- const TensorLayout& grad, size_t workspace_in_bytes);
- };
-
/**
 * \brief Common base of the Images2Neibs operators.
 *
 * Uses the Images2Neibs parameter and declares the forward layout helpers,
 * both of which work on a (src, dst) layout pair.
 */
- class Images2NeibsBase : public OperatorBase {
- DEF_OPR_IMPL_CTOR(Images2NeibsBase, OperatorBase);
- DEF_OPR_PARAM(Images2Neibs);
-
- protected:
//! Writes into dst; src is read-only.
- void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
//! Both layouts are read-only here.
- void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
- };
-
/**
 * \brief Forward interface of the Images2Neibs operator.
 *
 * exec() and get_workspace_in_bytes() are pure virtual; deduce_layout() is
 * declared here and defined outside this view.
 */
- class Images2NeibsForward : public Images2NeibsBase {
- DEF_OPR_IMPL(Images2NeibsForward, Images2NeibsBase, 1, 1);
-
- public:
- /**
- * \param[in] src (N, C, IH, IW)
- * \param[out] dst (N, C, OH, OW, window_h, window_w)
- *
- * \see
- * http://deeplearning.net/software/theano/library/tensor/nnet/neighbours.html
- *
- * \f$ dst_{n, c, oh, ow, wh, ww} = src_{n, c, ih+wh, iw+ww}\f$,
- * where \f$ ih=-pad_h+oh*stride_h+(wh-1)*(dilation_h-1),
- * iw=-pad_w+ow*stride_w+(ww-1)*(dilation_w-1)\f$.
- */
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_out dst,
- _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& dst) = 0;
- void deduce_layout(const TensorLayout& src, TensorLayout& dst);
-
- protected:
// NOTE(review): defined outside this view; presumably validates the layouts
// and the workspace size before exec() runs -- confirm.
- void check_exec(
- const TensorLayout& src, const TensorLayout& dst,
- size_t workspace_in_bytes);
- };
//! Short alias for the forward operator.
- using Images2Neibs = Images2NeibsForward;
-
/**
 * \brief Backward interface of the Images2Neibs operator.
 *
 * exec() and get_workspace_in_bytes() are pure virtual.
 */
- class Images2NeibsBackward : public Images2NeibsBase {
- DEF_OPR_IMPL(Images2NeibsBackward, Images2NeibsBase, 1, 1);
-
- public:
- /**
- * \param[in] diff the backpropagated gradient wrt. dst
- * \param[out] grad the backpropagated gradient wrt. src
- */
- virtual void exec(
- _megdnn_tensor_in diff, _megdnn_tensor_out grad,
- _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& diff, const TensorLayout& grad) = 0;
-
- protected:
// NOTE(review): defined outside this view; presumably validates the layouts
// and the workspace size before exec() runs -- confirm.
- void check_exec(
- const TensorLayout& diff, const TensorLayout& grad,
- size_t workspace_in_bytes);
- };
-
/**
 * \brief Common base of the SlidingWindowTranspose operators.
 *
 * Uses the SlidingWindowTranspose parameter and declares the forward layout
 * helpers, both of which work on a (src, dst) layout pair.
 */
- class SlidingWindowTransposeBase : public OperatorBase {
- DEF_OPR_IMPL_CTOR(SlidingWindowTransposeBase, OperatorBase);
- DEF_OPR_PARAM(SlidingWindowTranspose);
-
- protected:
//! Writes into dst; src is read-only.
- void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
//! Both layouts are read-only here.
- void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
- };
-
/**
 * \brief Forward interface of the SlidingWindowTranspose operator.
 *
 * exec() and get_workspace_in_bytes() are pure virtual; deduce_layout() is
 * declared here and defined outside this view.
 */
- class SlidingWindowTransposeForward : public SlidingWindowTransposeBase {
- DEF_OPR_IMPL(SlidingWindowTransposeForward, SlidingWindowTransposeBase, 1, 1);
-
- public:
- /**
- * \param[in] src (N, C, IH, IW, window_h, window_w)
- * \param[out] dst (N, C, OH, OW)
- */
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_out dst,
- _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& dst) = 0;
- void deduce_layout(const TensorLayout& src, TensorLayout& dst);
-
- protected:
// NOTE(review): defined outside this view; presumably validates the layouts
// and the workspace size before exec() runs -- confirm.
- void check_exec(
- const TensorLayout& src, const TensorLayout& dst,
- size_t workspace_in_bytes);
- };
//! Short alias for the forward operator.
- using SlidingWindowTranspose = SlidingWindowTransposeForward;
-
/**
 * \brief Backward interface of the SlidingWindowTranspose operator.
 *
 * exec() and get_workspace_in_bytes() are pure virtual.
 */
- class SlidingWindowTransposeBackward : public SlidingWindowTransposeBase {
- DEF_OPR_IMPL(SlidingWindowTransposeBackward, SlidingWindowTransposeBase, 1, 1);
-
- public:
- /**
- * \param[in] diff the backpropagated gradient wrt. dst
- * \param[out] grad the backpropagated gradient wrt. src
- */
- virtual void exec(
- _megdnn_tensor_in diff, _megdnn_tensor_out grad,
- _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& diff, const TensorLayout& grad) = 0;
-
- protected:
// NOTE(review): defined outside this view; presumably validates the layouts
// and the workspace size before exec() runs -- confirm.
- void check_exec(
- const TensorLayout& diff, const TensorLayout& grad,
- size_t workspace_in_bytes);
- };
-
- /**
- * \brief Base class for Pooling.
- *
- * Uses the Pooling parameter, re-exports its Mode type and declares the
- * layout helpers shared by the forward and backward operators.
- */
- class PoolingBase : public OperatorBase {
- DEF_OPR_IMPL_CTOR(PoolingBase, OperatorBase);
- DEF_OPR_PARAM(Pooling);
-
- public:
- using Mode = Param::Mode;
-
- protected:
//! Writes into dst; src is read-only.
- void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
//! Both layouts are read-only here.
- void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
-
- public:
//! Static, so the Param is passed in explicitly alongside the layouts.
- static void deduce_layout_impl(
- const TensorLayout& src, const Param& param, TensorLayout& dst);
- };
-
/**
 * \brief Forward interface of the Pooling operator.
 *
 * exec() and get_workspace_in_bytes() are pure virtual. Also derives from
 * detail::MultiAlgoOpr<PoolingForward, 2>.
 * NOTE(review): the template argument 2 presumably counts the tensors handled
 * by the algorithm interface (src + dst) -- confirm against MultiAlgoOpr.
 */
- class PoolingForward : public PoolingBase,
- public detail::MultiAlgoOpr<PoolingForward, 2> {
- DEF_OPR_IMPL(PoolingForward, PoolingBase, 1, 1);
-
- public:
- /**
- * \param[in] src input tensor
- * \param[out] dst output tensor
- */
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_out dst,
- _megdnn_workspace workspace) = 0;
- MGE_WIN_DECLSPEC_FUC void deduce_layout(const TensorLayout& src, TensorLayout& dst);
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& dst) = 0;
-
//! Tags this operator as POOLING_FORWARD.
- static Algorithm::OprType get_opr_type() {
- return Algorithm::OprType::POOLING_FORWARD;
- }
-
- protected:
// NOTE(review): defined outside this view; presumably validates the layouts
// and the workspace size before exec() runs -- confirm.
- void check_exec(
- const TensorLayout& src, const TensorLayout& dst,
- size_t workspace_in_bytes);
- };
-
//! Short alias for the forward operator.
- using Pooling = PoolingForward;
-
/**
 * \brief Backward interface of the Pooling operator.
 *
 * exec() and get_workspace_in_bytes() are pure virtual. Also derives from
 * detail::MultiAlgoOpr<PoolingBackward, 4>.
 */
- class PoolingBackward : public PoolingBase,
- public detail::MultiAlgoOpr<PoolingBackward, 4> {
- DEF_OPR_IMPL(PoolingBackward, PoolingBase, 3, 1);
-
- public:
- /**
- * \param[in] src the `src' parameter in PoolingForward::exec
- * \param[in] dst the `dst' parameter in PoolingForward::exec
- * \param[in] diff the backpropagated gradient wrt. dst
- * \param[out] grad the backpropagated gradient wrt. src
- */
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_tensor_in diff,
- _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff,
- const TensorLayout& grad) = 0;
-
//! Tags this operator as POOLING_BACKWARD.
- static Algorithm::OprType get_opr_type() {
- return Algorithm::OprType::POOLING_BACKWARD;
- }
-
- protected:
// NOTE(review): defined outside this view; presumably validates the layouts
// and the workspace size before exec() runs -- confirm.
- void check_exec(
- const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff,
- const TensorLayout& grad, size_t workspace_in_bytes);
- };
-
- /**
- * \brief Base class for AdaptivePooling.
- *
- * Uses the AdaptivePooling parameter and offers one protected helper that
- * returns a param::Pooling for a given pair of layouts.
- */
- class AdaptivePoolingBase : public OperatorBase {
- DEF_OPR_IMPL_CTOR(AdaptivePoolingBase, OperatorBase);
- DEF_OPR_PARAM(AdaptivePooling);
-
- protected:
//! Takes the src and dst layouts and returns a param::Pooling (defined
//! outside this view).
- param::Pooling deduce_pooling_param(
- const TensorLayout& src, const TensorLayout& dst);
- };
-
/**
 * \brief Forward interface of the AdaptivePooling operator.
 *
 * Both member functions are pure virtual; unlike PoolingForward, no
 * deduce_layout() or check_exec() is declared in this class.
 */
- class AdaptivePoolingForward : public AdaptivePoolingBase {
- DEF_OPR_IMPL(AdaptivePoolingForward, AdaptivePoolingBase, 1, 1);
-
- public:
- /**
- * \param[in] src input tensor
- * \param[out] dst output tensor
- */
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_out dst,
- _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& dst) = 0;
- };
-
//! Short alias for the forward operator.
- using AdaptivePooling = AdaptivePoolingForward;
-
/**
 * \brief Backward interface of the AdaptivePooling operator.
 *
 * Both member functions are pure virtual.
 */
- class AdaptivePoolingBackward : public AdaptivePoolingBase {
- DEF_OPR_IMPL(AdaptivePoolingBackward, AdaptivePoolingBase, 3, 1);
-
- public:
- /**
- * \param[in] src the `src' parameter in AdaptivePoolingForward::exec
- * \param[in] dst the `dst' parameter in AdaptivePoolingForward::exec
- * \param[in] diff the backpropagated gradient wrt. dst
- * \param[out] grad the backpropagated gradient wrt. src
- */
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_tensor_in diff,
- _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff,
- const TensorLayout& grad) = 0;
- };
-
- /**
- * \brief Base class for Local.
- *
- * Uses the Convolution parameter, re-exports its Mode type and declares the
- * forward layout helpers used by the derived operators.
- */
- class LocalBase : public OperatorBase {
- DEF_OPR_IMPL_CTOR(LocalBase, OperatorBase);
- DEF_OPR_PARAM(Convolution);
-
- public:
- using Mode = Param::Mode;
-
- protected:
//! Writes into dst; src and filter are read-only.
- void deduce_layout_fwd(
- const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
//! All three layouts are read-only here.
- void check_layout_fwd(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& dst);
- };
-
/**
 * \brief Forward interface of the Local operator.
 *
 * exec() and get_workspace_in_bytes() are pure virtual; deduce_layout() is
 * declared here and defined outside this view.
 */
- class LocalForward : public LocalBase {
- DEF_OPR_IMPL(LocalForward, LocalBase, 2, 1);
-
- public:
- /**
- * \param[in] src (n, ic, ih, iw)
- * \param[in] filter (oh, ow, ic, fh, fw, oc)
- * \param[out] dst (n, oc, oh, ow)
- */
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
- _megdnn_workspace workspace) = 0;
- /**
- * \brief Deduces the output tensor layout from the input tensor layouts.
- *
- * Be aware that the first and second dimensions of `filter' are ignored
- * when deducing the `dst' layout.
- */
- void deduce_layout(
- const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& dst) = 0;
-
- protected:
// NOTE(review): defined outside this view; presumably validates the layouts
// and the workspace size before exec() runs -- confirm.
- void check_exec(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& dst, size_t workspace_in_bytes);
- };
//! Short alias for the forward operator.
- using Local = LocalForward;
-
/**
 * \brief Interface of the Local backward-data operator.
 *
 * exec() and get_workspace_in_bytes() are pure virtual.
 */
- class LocalBackwardData : public LocalBase {
- DEF_OPR_IMPL(LocalBackwardData, LocalBase, 2, 1);
-
- public:
- /**
- * \param[in] filter (oh, ow, ic, fh, fw, oc)
- * \param[in] diff (n, oc, oh, ow)
- * \param[out] grad (n, ic, ih, iw)
- */
- virtual void exec(
- _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
- _megdnn_workspace workspace) = 0;
-
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& filter, const TensorLayout& diff,
- const TensorLayout& grad) = 0;
-
- protected:
// NOTE(review): defined outside this view; presumably validates the layouts
// and the workspace size before exec() runs -- confirm.
- void check_exec(
- const TensorLayout& filter, const TensorLayout& diff,
- const TensorLayout& grad, size_t workspace_in_bytes);
- };
-
/**
 * \brief Interface of the Local backward-filter operator.
 *
 * exec() and get_workspace_in_bytes() are pure virtual.
 */
- class LocalBackwardFilter : public LocalBase {
- DEF_OPR_IMPL(LocalBackwardFilter, LocalBase, 2, 1);
-
- public:
- /**
- * \param[in] src (n, ic, ih, iw)
- * \param[in] diff (n, oc, oh, ow)
- * \param[out] grad (oh, ow, ic, fh, fw, oc)
- */
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
- _megdnn_workspace workspace) = 0;
-
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& diff,
- const TensorLayout& grad) = 0;
-
- protected:
// NOTE(review): defined outside this view; presumably validates the layouts
// and the workspace size before exec() runs -- confirm.
- void check_exec(
- const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
- size_t workspace_in_bytes);
- };
-
/**
 * \brief Common base of the BN operators.
 *
 * Uses the BN parameter and declares a protected, argument-less check_param()
 * (defined outside this view).
 */
- class BNBase : public OperatorBase {
- DEF_OPR_IMPL_CTOR(BNBase, OperatorBase);
- DEF_OPR_PARAM(BN);
-
- protected:
// NOTE(review): presumably validates the operator's own parameter -- confirm.
- void check_param();
- };
-
/**
 * \brief Forward interface of the BN operator.
 *
 * exec(), get_workspace_in_bytes() and get_reserve_in_bytes() are pure
 * virtual; deduce_layout() is declared here and defined outside this view.
 */
- class BNForward : public BNBase {
- DEF_OPR_IMPL(BNForward, BNBase, 6, 6);
-
- public:
- /**
- * dst[i] = gamma
- * * (x[i] - estimatedMean[k]) / sqrt(epsilon + estimatedVariance[k]) + beta,
- * where epsilon is a very small value to avoid a "divide by zero" error.
- * NOTE(review): gamma / beta presumably correspond to bn_scale / bn_bias,
- * and x to src -- confirm.
- * \param[in] src (n, c, h, w)
- * \param[in] bn_scale scale tensor
- * \param[in] bn_bias bias tensor
- * \param[in,out] mean (see m_param.ParamDim) Global mean.
- * \param[in,out] variance (see m_param.ParamDim) Global variance.
- * \param[out] batch_mean (see m_param.ParamDim)
- * Optionally cached intermediate mean from forward pass
- * \param[out] batch_inv_variance (see m_param.ParamDim)
- * Optionally cached intermediate variance from forward pass
- * \param[out] reserve (see cudnnBatchNormalizationForwardTrainingEx)
- * \param[out] dst (n, c, h, w)
- *
- * src and dst must have the same shape.
- * src and dst must be contiguous.
- */
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
- _megdnn_tensor_in bn_bias, _megdnn_tensor_inout mean,
- _megdnn_tensor_inout variance, _megdnn_tensor_out batch_mean,
- _megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out reserve,
- _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
//! Inputs are const; every layout from mean onwards is writable.
- void deduce_layout(
- const TensorLayout& src, const TensorLayout& bn_scale,
- const TensorLayout& bn_bias, TensorLayout& mean, TensorLayout& variance,
- TensorLayout& batch_mean, TensorLayout& batch_inv_variance,
- TensorLayout& reserve, TensorLayout& dst);
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& bn_scale,
- const TensorLayout& bn_bias, const TensorLayout& mean,
- const TensorLayout& variance, const TensorLayout& batch_mean,
- const TensorLayout& batch_inv_variance, const TensorLayout& reserve,
- const TensorLayout& dst) = 0;
//! Size query that depends on the src layout only.
- virtual size_t get_reserve_in_bytes(const TensorLayout& src) = 0;
-
- protected:
// Unlike exec(), this takes no `reserve' layout; a reserve size in bytes is
// passed instead and defaults to 0.
- void check_exec(
- const TensorLayout& src, const TensorLayout& bn_scale,
- const TensorLayout& bn_bias, const TensorLayout& mean,
- const TensorLayout& variance, const TensorLayout& batch_mean,
- const TensorLayout& batch_inv_variance, const TensorLayout& dst,
- size_t workspace_in_bytes, size_t reserve_in_bytes = 0);
- };
//! Short alias for the forward operator.
- using BN = BNForward;
-
/**
 * \brief Backward interface of the BN operator.
 *
 * exec(), get_workspace_in_bytes() and get_reserve_in_bytes() are pure virtual.
 */
- class BNBackward : public BNBase {
- DEF_OPR_IMPL(BNBackward, BNBase, 6, 3);
-
- public:
- /**
- * \param[in] x input data of the forward pass.
- * \param[in] dy the backpropagated gradient of y.
- * \param[in] saved_batch_mean mean of the input batch.
- * Calculated in the forward propagation.
- * \param[in] saved_batch_variance variance of the input batch.
- * Calculated in the forward propagation.
- * \param[in] bn_scale scale tensor (presumably the one given to
- * BNForward::exec -- confirm).
- * \param[in] reserve (see cudnnBatchNormalizationBackwardEx)
- * \param[out] d_bn_scale the backpropagated gradient of bn_scale.
- * \param[out] d_bn_bias the backpropagated gradient of bn_bias.
- * \param[out] dx the backpropagated gradient of x.
- *
- * saved_batch_mean and saved_batch_variance are optionally cached
- * intermediate results from the forward pass.
- */
- virtual void exec(
- _megdnn_tensor_in x, _megdnn_tensor_in dy,
- _megdnn_tensor_in saved_batch_mean, _megdnn_tensor_in saved_batch_variance,
- _megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve,
- _megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_bias,
- _megdnn_tensor_out dx, _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& x, const TensorLayout& dy,
- const TensorLayout& saved_batch_mean,
- const TensorLayout& saved_batch_variance, const TensorLayout& bn_scale,
- const TensorLayout& reserve, const TensorLayout& d_bn_scale,
- const TensorLayout& d_bn_bias, const TensorLayout& dx) = 0;
//! Size query that depends on a single layout only.
- virtual size_t get_reserve_in_bytes(const TensorLayout& src) = 0;
-
- protected:
// Unlike exec(), this takes no `reserve' layout; a reserve size in bytes is
// passed instead and defaults to 0.
- void check_exec(
- const TensorLayout& x, const TensorLayout& dy,
- const TensorLayout& saved_batch_mean,
- const TensorLayout& saved_batch_variance, const TensorLayout& bn_scale,
- const TensorLayout& d_bn_scale, const TensorLayout& d_bn_bias,
- const TensorLayout& dx, size_t workspace_in_bytes,
- size_t reserve_in_bytes = 0);
- };
-
/**
 * \brief Common base of the LRN operators.
 *
 * Uses the LRN parameter and declares a protected, argument-less check_param()
 * (defined outside this view).
 */
- class LRNBase : public OperatorBase {
- DEF_OPR_IMPL_CTOR(LRNBase, OperatorBase);
- DEF_OPR_PARAM(LRN);
-
- protected:
// NOTE(review): presumably validates the operator's own parameter -- confirm.
- void check_param();
- };
-
/**
 * \brief Forward interface of the LRN operator.
 *
 * exec() and get_workspace_in_bytes() are pure virtual; deduce_layout() is
 * declared here and defined outside this view.
 */
- class LRNForward : public LRNBase {
- DEF_OPR_IMPL(LRNForward, LRNBase, 1, 1);
-
- public:
- /**
- * \see ImageNet Classification with Deep Convolutional Neural Networks
- * \param[in] src (n, c, h, w)
- * \param[out] dst (n, c, h, w)
- *
- * src and dst must have the same shape.
- * src and dst must be contiguous.
- */
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_out dst,
- _megdnn_workspace workspace) = 0;
- void deduce_layout(const TensorLayout& src, TensorLayout& dst);
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& dst) = 0;
-
- protected:
// NOTE(review): defined outside this view; presumably validates the layouts
// and the workspace size before exec() runs -- confirm.
- void check_exec(
- const TensorLayout& src, const TensorLayout& dst,
- size_t workspace_in_bytes);
- };
//! Short alias for the forward operator.
- using LRN = LRNForward;
-
/**
 * \brief Backward interface of the LRN operator.
 *
 * exec() and get_workspace_in_bytes() are pure virtual.
 */
- class LRNBackward : public LRNBase {
- DEF_OPR_IMPL(LRNBackward, LRNBase, 3, 1);
-
- public:
- /**
- * \param[in] src the `src' parameter in LRNForward::exec
- * \param[in] dst the `dst' parameter in LRNForward::exec
- * \param[in] diff the backpropagated gradient wrt. dst
- * \param[out] grad the backpropagated gradient wrt. src
- *
- * All tensors should be contiguous and of the same shape.
- */
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_tensor_in diff,
- _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff,
- const TensorLayout& grad) = 0;
-
- protected:
// NOTE(review): defined outside this view; presumably validates the layouts
// and the workspace size before exec() runs -- confirm.
- void check_exec(
- const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff,
- const TensorLayout& grad, size_t workspace_in_bytes);
- };
-
/**
 * \brief Common base of the ROIPooling operators.
 *
 * Uses the ROIPooling parameter and declares a forward layout check over
 * (src, rois, dst, index). No layout-deduction helper is declared here.
 */
- class ROIPoolingBase : public OperatorBase {
- DEF_OPR_IMPL_CTOR(ROIPoolingBase, OperatorBase);
- DEF_OPR_PARAM(ROIPooling);
-
- protected:
//! All four layouts are read-only here.
- void check_layout_fwd(
- const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst,
- const TensorLayout& index);
- };
-
/**
 * \brief Forward interface of the ROIPooling operator.
 *
 * exec() and get_workspace_in_bytes() are pure virtual.
 */
- class ROIPoolingForward : public ROIPoolingBase {
- DEF_OPR_IMPL(ROIPoolingForward, ROIPoolingBase, 2, 2);
-
- public:
- /**
- * \param[in] src (n, c, ih, iw)
- * \param[in] rois (m, 5)
- * \param[out] dst (m, c, oh, ow)
- * \param[out] index (m, c, oh, ow) if mode is MAX, (0) if mode is AVERAGE
- *
- * The internal implementation is akin to
- * https://github.com/rbgirshick/caffe-fast-rcnn .
- * Note that rois(, 0) denotes the input image index. We store it as
- * a float, but it should be an integer instead.
- *
- * index is a temporary tensor to facilitate its backward operator.
- * It is used to store argmax indices in MAX mode, and it is not used
- * in AVERAGE mode.
- */
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_in rois, _megdnn_tensor_out dst,
- _megdnn_tensor_out index, _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst,
- const TensorLayout& index) = 0;
-
- protected:
// NOTE(review): defined outside this view; presumably validates the layouts
// and the workspace size before exec() runs -- confirm.
- void check_exec(
- const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst,
- const TensorLayout& index, size_t workspace_in_bytes);
- };
//! Short alias for the forward operator.
- using ROIPooling = ROIPoolingForward;
-
/**
 * \brief Backward interface of the ROIPooling operator.
 *
 * exec() and get_workspace_in_bytes() are pure virtual.
 */
- class ROIPoolingBackward : public ROIPoolingBase {
- DEF_OPR_IMPL(ROIPoolingBackward, ROIPoolingBase, 4, 1);
-
- public:
- /**
- * \param[in] diff the backpropagated gradient wrt. dst
- * \param[in] src the `src' parameter in ROIPoolingForward::exec
- * \param[in] rois the `rois' parameter in ROIPoolingForward::exec
- * \param[in] index the `index' parameter in ROIPoolingForward::exec
- * \param[out] grad the backpropagated gradient wrt. src
- */
- virtual void exec(
- _megdnn_tensor_in diff, _megdnn_tensor_in src, _megdnn_tensor_in rois,
- _megdnn_tensor_in index, _megdnn_tensor_out grad,
- _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& diff, const TensorLayout& src, const TensorLayout& rois,
- const TensorLayout& index, const TensorLayout& grad) = 0;
-
- protected:
// NOTE(review): defined outside this view; presumably validates the layouts
// and the workspace size before exec() runs -- confirm.
- void check_exec(
- const TensorLayout& diff, const TensorLayout& src, const TensorLayout& rois,
- const TensorLayout& index, const TensorLayout& grad,
- size_t workspace_in_bytes);
- };
-
/**
 * \brief Common base of the Convolution3D operators.
 *
 * Uses the Convolution3D parameter, defines the CanonizedFilterMeta record
 * (arrays sized by MAX_SPATIAL_DIM = 3) and declares the layout helpers that
 * return it.
 */
- class Convolution3DBase : public OperatorBase {
- DEF_OPR_IMPL_CTOR(Convolution3DBase, OperatorBase);
- DEF_OPR_PARAM(Convolution3D);
-
- public:
- static constexpr size_t MAX_SPATIAL_DIM = 3;
- using Mode = Param::Mode;
//! Filter description: dtype and format followed by a run of uint32_t fields.
- struct CanonizedFilterMeta {
- DTypeEnum dtype_enum;
- Param::Format format;
- uint32_t
- //! whether filter should be flipped (i.e. is CONVOLUTION)
- should_flip,
- group, //!< number of groups
- icpg, //!< input channels per group
- ocpg, //!< output channels per group
- spatial_ndim, stride[MAX_SPATIAL_DIM], padding[MAX_SPATIAL_DIM],
- //! spatial dim
- spatial[MAX_SPATIAL_DIM], dilation[MAX_SPATIAL_DIM],
- //! spatial dim with dilation applied
- dilated_spatial[MAX_SPATIAL_DIM];
- } MEGDNN_PACKED;
-
- protected:
//! Writes into dst and returns the filter meta; const member.
- CanonizedFilterMeta deduce_layout_fwd(
- const TensorLayout& src, const TensorLayout& filter,
- TensorLayout& dst) const;
//! Read-only counterpart taking a const dst; const member.
- CanonizedFilterMeta check_layout_fwd(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& dst) const;
-
//! Static form: the Param is passed in explicitly.
- static CanonizedFilterMeta make_canonized_filter_meta_impl(
- size_t src_ndim, const TensorLayout& filter, const Param& param);
//! Member form without a Param argument.
//! NOTE(review): presumably uses this operator's own param -- confirm.
- CanonizedFilterMeta make_canonized_filter_meta(
- size_t src_ndim, const TensorLayout& filter) const;
- };
-
/**
 * \brief Forward interface of the Convolution3D operator.
 *
 * exec() and get_workspace_in_bytes() are pure virtual; deduce_layout() is
 * declared here and defined outside this view. Also derives from
 * detail::MultiAlgoOpr<Convolution3DForward, 3>.
 */
- class Convolution3DForward : public Convolution3DBase,
- public detail::MultiAlgoOpr<Convolution3DForward, 3> {
- DEF_OPR_IMPL(Convolution3DForward, Convolution3DBase, 2, 1);
-
- public:
- /**
- * \param[in] src (n, ic, id, ih, iw)
- * \param[in] filter (oc, ic, fd, fh, fw)
- * \param[out] dst (n, oc, od, oh, ow)
- */
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
- _megdnn_workspace workspace) = 0;
- MGE_WIN_DECLSPEC_FUC void deduce_layout(
- const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& dst) = 0;
-
//! Tags this operator as CONVOLUTION3D_FORWARD.
- static Algorithm::OprType get_opr_type() {
- return Algorithm::OprType::CONVOLUTION3D_FORWARD;
- }
-
- protected:
// NOTE(review): defined outside this view; presumably validates the layouts
// and the workspace size before exec() runs -- confirm.
- CanonizedFilterMeta check_exec(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& dst, size_t workspace_in_bytes);
- };
//! Short alias for the forward operator.
- using Convolution3D = Convolution3DForward;
-
/**
 * \brief Interface of the Convolution3D backward-data operator.
 *
 * exec() and get_workspace_in_bytes() are pure virtual; deduce_layout() and
 * the static deduce_layout_impl() are declared here and defined outside this
 * view. Also derives from detail::MultiAlgoOpr<Convolution3DBackwardData, 3>.
 */
- class Convolution3DBackwardData
- : public Convolution3DBase,
- public detail::MultiAlgoOpr<Convolution3DBackwardData, 3> {
- DEF_OPR_IMPL(Convolution3DBackwardData, Convolution3DBase, 2, 1);
-
- public:
- /**
- * Tensor shapes for exec(); the layout functions in this class use the
- * same filter / diff / grad argument names.
- *
- * \param[in] filter (oc, ic, fd, fh, fw)
- * \param[in] diff (n, oc, od, oh, ow)
- * \param[out] grad (n, ic, id, ih, iw)
- */
- static void deduce_layout_impl(
- const TensorLayout& filter, const TensorLayout& diff, const Param& param,
- TensorLayout& grad);
- virtual void exec(
- _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
- _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& filter, const TensorLayout& diff,
- const TensorLayout& grad) = 0;
- MGE_WIN_DECLSPEC_FUC void deduce_layout(
- const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad);
-
//! Tags this operator as CONVOLUTION3D_BACKWARD_DATA.
- static Algorithm::OprType get_opr_type() {
- return Algorithm::OprType::CONVOLUTION3D_BACKWARD_DATA;
- }
-
- protected:
// NOTE(review): defined outside this view; presumably validates the layouts
// and the workspace size before exec() runs -- confirm.
- CanonizedFilterMeta check_exec(
- const TensorLayout& filter, const TensorLayout& diff,
- const TensorLayout& grad, size_t workspace_in_bytes);
- };
-
/**
 * \brief Interface of the Convolution3D backward-filter operator.
 *
 * exec() and get_workspace_in_bytes() are pure virtual. Also derives from
 * detail::MultiAlgoOpr<Convolution3DBackwardFilter, 3>.
 */
- class Convolution3DBackwardFilter
- : public Convolution3DBase,
- public detail::MultiAlgoOpr<Convolution3DBackwardFilter, 3> {
- DEF_OPR_IMPL(Convolution3DBackwardFilter, Convolution3DBase, 2, 1);
-
- public:
- /**
- * \param[in] src (n, ic, id, ih, iw)
- * \param[in] diff (n, oc, od, oh, ow)
- * \param[out] grad (oc, ic, fd, fh, fw)
- */
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
- _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& diff,
- const TensorLayout& grad) = 0;
-
//! Tags this operator as CONVOLUTION3D_BACKWARD_FILTER.
- static Algorithm::OprType get_opr_type() {
- return Algorithm::OprType::CONVOLUTION3D_BACKWARD_FILTER;
- }
-
- protected:
// NOTE(review): defined outside this view; presumably validates the layouts
// and the workspace size before exec() runs -- confirm.
- CanonizedFilterMeta check_exec(
- const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
- size_t workspace_in_bytes);
- };
-
/**
 * \brief Common base of the LocalShare operators.
 *
 * Uses the LocalShare parameter and declares the forward layout helpers used
 * by the derived operators.
 */
- class LocalShareBase : public OperatorBase {
- DEF_OPR_IMPL_CTOR(LocalShareBase, OperatorBase);
- DEF_OPR_PARAM(LocalShare);
-
- protected:
//! Writes into dst; src and filter are read-only.
- void deduce_layout_fwd(
- const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
//! All three layouts are read-only here.
- void check_layout_fwd(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& dst);
- };
-
/**
 * \brief Forward interface of the LocalShare operator.
 *
 * exec() and get_workspace_in_bytes() are pure virtual; deduce_layout() is
 * declared here and defined outside this view. Also derives from
 * detail::MultiAlgoOpr<LocalShareForward, 3>.
 */
- class LocalShareForward : public LocalShareBase,
- public detail::MultiAlgoOpr<LocalShareForward, 3> {
- DEF_OPR_IMPL(LocalShareForward, LocalShareBase, 2, 1);
-
- public:
- /**
- * \param[in] src (N, IC, IH, IW)
- * \param[in] filter (G, spatial_groups_h, spatial_groups_w, IC / G,
- * FH, FW, OC / G)
- * \param[out] dst (N, OC, OH, OW)
- */
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
- _megdnn_workspace workspace) = 0;
- /**
- * \brief deduce layout of the output tensor
- */
- void deduce_layout(
- const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& dst) = 0;
-
//! Tags this operator as LOCAL_SHARE_FORWARD.
- static Algorithm::OprType get_opr_type() {
- return Algorithm::OprType::LOCAL_SHARE_FORWARD;
- }
-
- protected:
// NOTE(review): defined outside this view; presumably validates the layouts
// and the workspace size before exec() runs -- confirm.
- void check_exec(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& dst, size_t workspace_in_bytes);
- };
//! Short alias for the forward operator.
- using LocalShare = LocalShareForward;
-
/**
 * \brief Interface of the LocalShare backward-data operator.
 *
 * exec() and get_workspace_in_bytes() are pure virtual; deduce_layout() is
 * declared here and defined outside this view. Also derives from
 * detail::MultiAlgoOpr<LocalShareBackwardData, 3>.
 */
- class LocalShareBackwardData : public LocalShareBase,
- public detail::MultiAlgoOpr<LocalShareBackwardData, 3> {
- DEF_OPR_IMPL(LocalShareBackwardData, LocalShareBase, 2, 1);
-
- public:
- /**
- * \param[in] filter (G, spatial_groups_h, spatial_groups_w, IC / G,
- * FH, FW, OC / G)
- * \param[in] diff (N, OC, OH, OW)
- * \param[out] grad (N, IC, IH, IW)
- */
- virtual void exec(
- _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
- _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& filter, const TensorLayout& diff,
- const TensorLayout& grad) = 0;
//! Writes into grad; filter and diff are read-only.
- void deduce_layout(
- const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad);
-
//! Tags this operator as LOCAL_SHARE_BACKWARD_DATA.
- static Algorithm::OprType get_opr_type() {
- return Algorithm::OprType::LOCAL_SHARE_BACKWARD_DATA;
- }
-
- protected:
// NOTE(review): defined outside this view; presumably validates the layouts
// and the workspace size before exec() runs -- confirm.
- void check_exec(
- const TensorLayout& filter, const TensorLayout& diff,
- const TensorLayout& grad, size_t workspace_in_bytes);
- };
-
/**
 * \brief Interface of the LocalShare backward-filter operator.
 *
 * exec() and get_workspace_in_bytes() are pure virtual. Also derives from
 * detail::MultiAlgoOpr<LocalShareBackwardFilter, 3>.
 */
- class LocalShareBackwardFilter
- : public LocalShareBase,
- public detail::MultiAlgoOpr<LocalShareBackwardFilter, 3> {
- DEF_OPR_IMPL(LocalShareBackwardFilter, LocalShareBase, 2, 1);
-
- public:
- /**
- * \param[in] src (N, IC, IH, IW)
- * \param[in] diff (N, OC, OH, OW)
- * \param[out] grad (G, spatial_groups_h, spatial_groups_w, IC / G,
- * FH, FW, OC / G)
- */
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
- _megdnn_workspace workspace) = 0;
-
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& diff,
- const TensorLayout& grad) = 0;
-
//! Tags this operator as LOCAL_SHARE_BACKWARD_FILTER.
- static Algorithm::OprType get_opr_type() {
- return Algorithm::OprType::LOCAL_SHARE_BACKWARD_FILTER;
- }
-
- protected:
// NOTE(review): defined outside this view; presumably validates the layouts
// and the workspace size before exec() runs -- confirm.
- void check_exec(
- const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
- size_t workspace_in_bytes);
- };
-
/**
 * \brief Common base of the ROIAlign operators.
 *
 * Uses the ROIAlign parameter and declares the forward layout helpers over
 * (src, rois, dst, index).
 */
- class ROIAlignBase : public OperatorBase {
- DEF_OPR_IMPL_CTOR(ROIAlignBase, OperatorBase);
- DEF_OPR_PARAM(ROIAlign);
-
- protected:
//! Writes into dst and index; src and rois are read-only.
- void deduce_layout_fwd(
- const TensorLayout& src, const TensorLayout& rois, TensorLayout& dst,
- TensorLayout& index);
//! All four layouts are read-only here.
- void check_layout_fwd(
- const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst,
- const TensorLayout& index);
- };
-
/**
 * \brief Forward interface of the ROIAlign operator.
 *
 * exec() and get_workspace_in_bytes() are pure virtual; deduce_layout() is
 * declared here and defined outside this view.
 */
- class ROIAlignForward : public ROIAlignBase {
- DEF_OPR_IMPL(ROIAlignForward, ROIAlignBase, 2, 2);
-
- public:
- /**
- * \param[in] src (n, c, ih, iw)
- * \param[in] rois (m, 5)
- * \param[out] dst (m, c, oh, ow)
- * \param[out] index (m, c, oh, ow) if mode is MAX, (0) if mode is AVERAGE
- *
- * Note that rois(, 0) denotes the input image index. We store it as
- * a float, but it should be an integer instead.
- *
- * index is a temporary tensor to facilitate its backward operator.
- * It is used to store argmax indices in MAX mode, and it is not used
- * in AVERAGE mode.
- */
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_in rois, _megdnn_tensor_out dst,
- _megdnn_tensor_out index, _megdnn_workspace workspace) = 0;
- MGE_WIN_DECLSPEC_FUC void deduce_layout(
- const TensorLayout& src, const TensorLayout& rois, TensorLayout& dst,
- TensorLayout& index);
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst,
- const TensorLayout& index) = 0;
-
- protected:
// NOTE(review): defined outside this view; presumably validates the layouts
// and the workspace size before exec() runs -- confirm.
- void check_exec(
- const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst,
- const TensorLayout& index, size_t workspace_in_bytes);
- };
//! Short alias for the forward operator.
- using ROIAlign = ROIAlignForward;
-
/**
 * \brief Backward interface of the ROIAlign operator.
 *
 * exec() and get_workspace_in_bytes() are pure virtual. Unlike
 * ROIPoolingBackward, the forward `src' tensor is not among the inputs.
 */
- class ROIAlignBackward : public ROIAlignBase {
- DEF_OPR_IMPL(ROIAlignBackward, ROIAlignBase, 3, 1);
-
- public:
- /**
- * \param[in] diff the backpropagated gradient wrt. dst
- * \param[in] rois the `rois' parameter in ROIAlignForward::exec
- * \param[in] index the `index' parameter in ROIAlignForward::exec
- * \param[out] grad the backpropagated gradient wrt. src
- */
- virtual void exec(
- _megdnn_tensor_in diff, _megdnn_tensor_in rois, _megdnn_tensor_in index,
- _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& diff, const TensorLayout& rois,
- const TensorLayout& index, const TensorLayout& grad) = 0;
-
- protected:
// NOTE(review): defined outside this view; presumably validates the layouts
// and the workspace size before exec() runs -- confirm.
- void check_exec(
- const TensorLayout& diff, const TensorLayout& rois,
- const TensorLayout& index, const TensorLayout& grad,
- size_t workspace_in_bytes);
- };
-
- class DeformableConvBase : public OperatorBase {
- DEF_OPR_IMPL_CTOR(DeformableConvBase, OperatorBase);
- DEF_OPR_PARAM(Convolution);
-
- public:
- static constexpr size_t MAX_SPATIAL_DIM = 2;
- struct CanonizedFilterMeta : Convolution::CanonizedFilterMeta {
- uint32_t deformable_group;
- };
-
- protected:
- CanonizedFilterMeta make_canonized_filter_meta(
- size_t src_ndim, const TensorLayout& filter,
- const TensorLayout& offset) const;
- void deduce_layout_fwd(
- const TensorLayout& im, const TensorLayout& filter,
- const TensorLayout& mask, const TensorLayout& offset, TensorLayout& dst);
- void check_layout_fwd(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& mask, const TensorLayout& offset,
- const TensorLayout& dst);
- };
-
- class DeformableConvForward : public DeformableConvBase,
- public detail::MultiAlgoOpr<DeformableConvForward, 5> {
- DEF_OPR_IMPL(DeformableConvForward, DeformableConvBase, 4, 1);
-
- public:
- /**
- * \param[in] im (n, ic, ih, iw)
- * \param[in] filter (oc, ic, fh, fw)
- * \param[in] offset (dg, 2, fh, fw, oh, ow)
- * \param[in] mask (dg, fh, fw, oh, ow)
- * \param[out] dst (n, oc, oh, ow)
- */
- virtual void exec(
- _megdnn_tensor_in im, _megdnn_tensor_in filter, _megdnn_tensor_in offset,
- _megdnn_tensor_in mask, _megdnn_tensor_out dst,
- _megdnn_workspace workspace) = 0;
- void deduce_layout(
- const TensorLayout& im, const TensorLayout& filter,
- const TensorLayout& offset, const TensorLayout& mask, TensorLayout& dst);
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& im, const TensorLayout& filter,
- const TensorLayout& offset, const TensorLayout& mask,
- const TensorLayout& dst) = 0;
-
- static Algorithm::OprType get_opr_type() {
- return Algorithm::OprType::DEFORMABLE_CONV_FORWARD;
- }
-
- protected:
- CanonizedFilterMeta check_exec(
- const TensorLayout& im, const TensorLayout& filter,
- const TensorLayout& offset, const TensorLayout& mask,
- const TensorLayout& dst, size_t workspace_in_bytes);
- };
- using DeformableConv = DeformableConvForward;
-
- /**
- * \brief DeformableConvBackwardFilter operator.
- *
- * Calculating the gradient wrt. convolution filter.
- */
- class DeformableConvBackwardFilter
- : public DeformableConvBase,
- public detail::MultiAlgoOpr<DeformableConvBackwardFilter, 5> {
- DEF_OPR_IMPL(DeformableConvBackwardFilter, DeformableConvBase, 4, 1);
-
- public:
- /**
- * \param[in] im (oc, ic, fh, fw)
- * \param[in] offset (dg, 2, fh, fw, oh, ow)
- * \param[in] mask (dg, fh, fw, oh, ow)
- * \param[in] out_grad (n, oc, oh, ow)
- * \param[out] filter_grad (oc, ic, ih, iw)
- */
- virtual void exec(
- _megdnn_tensor_in im, _megdnn_tensor_in offset, _megdnn_tensor_in mask,
- _megdnn_tensor_in out_grad, _megdnn_tensor_out filter_grad,
- _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& im, const TensorLayout& offset,
- const TensorLayout& mask, const TensorLayout& out_grad,
- const TensorLayout& filter_grad) = 0;
- void deduce_layout(
- const TensorLayout& im, const TensorLayout& offset,
- const TensorLayout& mask, const TensorLayout& out_grad,
- TensorLayout& filter_grad);
-
- static Algorithm::OprType get_opr_type() {
- return Algorithm::OprType::DEFORMABLE_CONV_BACKWARD_FILTER;
- }
-
- protected:
- CanonizedFilterMeta check_exec(
- const TensorLayout& im, const TensorLayout& offset,
- const TensorLayout& mask, const TensorLayout& out_grad,
- const TensorLayout& filter_grad, size_t workspace_in_bytes);
- };
-
- /**
- * \brief DeformableConvBackwardData operator.
- *
- * Calculating the gradient wrt. convolution input data, offset and mask.
- */
- class DeformableConvBackwardData
- : public DeformableConvBase,
- public detail::MultiAlgoOpr<DeformableConvBackwardData, 8> {
- DEF_OPR_IMPL(DeformableConvBackwardData, DeformableConvBase, 5, 3);
-
- public:
- /**
- * \param[in] im (oc, ic, fh, fw)
- * \param[in] filter (oc, ic, fh, fw)
- * \param[in] offset (dg, 2, fh, fw, oh, ow)
- * \param[in] mask (dg, fh, fw, oh, ow)
- * \param[in] out_grad (n, oc, oh, ow)
- * \param[out] im_grad (n, ic, ih, iw)
- * \param[out] offset_grad (dg, 2, fh, fw, oh, ow)
- * \param[out] mask_grad (dg, fh, fw, oh, ow)
- */
- virtual void exec(
- _megdnn_tensor_in im, _megdnn_tensor_in filter, _megdnn_tensor_in offset,
- _megdnn_tensor_in mask, _megdnn_tensor_in out_grad,
- _megdnn_tensor_out im_grad, _megdnn_tensor_out offset_grad,
- _megdnn_tensor_out mask_grad, _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& im, const TensorLayout& filter,
- const TensorLayout& offset, const TensorLayout& mask,
- const TensorLayout& out_grad, const TensorLayout& im_grad,
- const TensorLayout& offset_grad, const TensorLayout& mask_grad) = 0;
- void deduce_layout(
- const TensorLayout& im, const TensorLayout& filter,
- const TensorLayout& offset, const TensorLayout& mask,
- const TensorLayout& out_grad, TensorLayout& im_grad,
- TensorLayout& offset_grad, TensorLayout& mask_grad);
-
- static Algorithm::OprType get_opr_type() {
- return Algorithm::OprType::DEFORMABLE_CONV_BACKWARD_DATA;
- }
-
- protected:
- CanonizedFilterMeta check_exec(
- const TensorLayout& im, const TensorLayout& filter,
- const TensorLayout& offset, const TensorLayout& mask,
- const TensorLayout& out_grad, const TensorLayout& im_grad,
- const TensorLayout& offset_grad, const TensorLayout& mask_grad,
- size_t workspace_in_bytes);
- };
-
- class DeformablePSROIPoolingBase : public OperatorBase {
- DEF_OPR_IMPL_CTOR(DeformablePSROIPoolingBase, OperatorBase);
- DEF_OPR_PARAM(DeformablePSROIPooling);
-
- protected:
- void deduce_layout_fwd(
- const TensorLayout& data, const TensorLayout& trans,
- const TensorLayout& rois, TensorLayout& out_data, TensorLayout& out_count);
-
- void check_layout_fwd(
- const TensorLayout& data, const TensorLayout& trans,
- const TensorLayout& rois, const TensorLayout& out_data,
- const TensorLayout& out_count, size_t workspace_in_bytes);
- };
-
- class DeformablePSROIPoolingForward : public DeformablePSROIPoolingBase {
- DEF_OPR_IMPL(DeformablePSROIPoolingForward, DeformablePSROIPoolingBase, 3, 2);
-
- public:
- /**
- * \param[in] data (oc, ic, ih, iw)
- * \param[in] rois (xx, xx, xx, xx)
- * \param[in] trans (oc, ic, fh, fw)
- * \param[out] out_data ( n, ic, ih, iw)
- * \param[out] out_count ( n, ic, ih, iw)
- */
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& data, const TensorLayout& rois,
- const TensorLayout& trans, const TensorLayout& out_data,
- const TensorLayout& out_count) = 0;
- virtual void exec(
- _megdnn_tensor_in data, _megdnn_tensor_in rois, _megdnn_tensor_in trans,
- _megdnn_tensor_out out_data, _megdnn_tensor_out out_count,
- _megdnn_workspace workspace) = 0;
- void deduce_layout(
- const TensorLayout& data, const TensorLayout& rois,
- const TensorLayout& trans, TensorLayout& out_data, TensorLayout& out_count);
- void check_exec(
- const TensorLayout& data, const TensorLayout& rois,
- const TensorLayout& trans, const TensorLayout& out_data,
- const TensorLayout& out_count, size_t workspace_in_bytes);
- };
-
- using DeformablePSROIPooling = DeformablePSROIPoolingForward;
-
- class DeformablePSROIPoolingBackward : public DeformablePSROIPoolingBase {
- DEF_OPR_IMPL(DeformablePSROIPoolingBackward, DeformablePSROIPoolingBase, 5, 2);
-
- public:
- /**
- * \param[in] data (oc, ic, ih, iw)
- * \param[in] rois (xx, xx, xx, xx)
- * \param[in] trans (oc, ic, fh, fw)
- * \param[in] out_diff (xx, xx, xx, xx)
- * \param[in] out_count (xx, xx, xx, xx)
- * \param[out] data_diff ( n, ic, ih, iw)
- * \param[out] trans_diff ( n, ic, ih, iw)
- */
- virtual void exec(
- _megdnn_tensor_in data, _megdnn_tensor_in rois, _megdnn_tensor_in trans,
- _megdnn_tensor_in out_diff, _megdnn_tensor_in out_count,
- _megdnn_tensor_out data_diff, _megdnn_tensor_out trans_diff,
- _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& data, const TensorLayout& rois,
- const TensorLayout& trans, const TensorLayout& out_diff,
- const TensorLayout& out_count, const TensorLayout& data_diff,
- const TensorLayout& trans_diff) = 0;
-
- void check_exec(
- const TensorLayout& data, const TensorLayout& rois,
- const TensorLayout& trans, const TensorLayout& out_diff,
- const TensorLayout& out_count, const TensorLayout& data_diff,
- const TensorLayout& trans_diff, size_t workspace_in_bytes);
- };
-
- class BatchConvBiasForward : public ConvolutionBase<param::BatchConvBias>,
- public detail::MultiAlgoOpr<BatchConvBiasForward, 5> {
- DEF_OPR_IMPL(BatchConvBiasForward, ConvolutionBase, 4, 1);
-
- public:
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
- _megdnn_tensor_in z, _megdnn_tensor_out dst,
- _megdnn_workspace workspace) = 0;
-
- void deduce_dtype(DType src, DType filter, DType bias, DType z, DType& dst);
- void deduce_layout(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& bias, const TensorLayout& z, TensorLayout& dst);
-
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& bias, const TensorLayout& z,
- const TensorLayout& dst) = 0;
-
- static Algorithm::OprType get_opr_type() {
- return Algorithm::OprType::BATCH_CONV_FORWARD;
- }
-
- protected:
- CanonizedFilterMeta check_exec(
- const TensorLayout& src, const TensorLayout& filter,
- const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst,
- size_t workspace_in_bytes);
- };
- using BatchConvBias = BatchConvBiasForward;
-
- class FakeQuantBase : public OperatorBase {
- DEF_OPR_IMPL_CTOR(FakeQuantBase, OperatorBase);
- DEF_OPR_PARAM(FakeQuant);
-
- protected:
- void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output);
- void check_layout_fwd(
- const TensorLayout& input, const TensorLayout& scale,
- const TensorLayout& zero_point, const TensorLayout& output);
- };
-
- class FakeQuantForward : public FakeQuantBase {
- DEF_OPR_IMPL(FakeQuantForward, FakeQuantBase, 3, 1);
-
- public:
- virtual void exec(
- _megdnn_tensor_in input, _megdnn_tensor_in scale,
- _megdnn_tensor_in zero_point, _megdnn_tensor_out output,
- _megdnn_workspace workspace) = 0;
- void deduce_layout(
- const TensorLayout& input, const TensorLayout& scale,
- const TensorLayout& zero_point, TensorLayout& output);
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& input, const TensorLayout& scale,
- const TensorLayout& zero_point, const TensorLayout& output) = 0;
-
- protected:
- void check_exec(
- const TensorLayout& input, const TensorLayout& scale,
- const TensorLayout& zero_point, const TensorLayout& output,
- size_t workspace_in_bytes);
- };
-
- using FakeQuant = FakeQuantForward;
-
- class FakeQuantBackward : public FakeQuantBase {
- DEF_OPR_IMPL(FakeQuantBackward, FakeQuantBase, 4, 1);
-
- public:
- virtual void exec(
- _megdnn_tensor_in diff, _megdnn_tensor_in input, _megdnn_tensor_in scale,
- _megdnn_tensor_in zero_point, _megdnn_tensor_out grad,
- _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& diff, const TensorLayout& input,
- const TensorLayout& scale, const TensorLayout& zero_point,
- const TensorLayout& grad) = 0;
-
- protected:
- void check_exec(
- const TensorLayout& diff, const TensorLayout& input,
- const TensorLayout& scale, const TensorLayout& zero_point,
- const TensorLayout& grad, size_t workspace_in_bytes);
- };
-
- class TQTBase : public OperatorBase {
- DEF_OPR_IMPL_CTOR(TQTBase, OperatorBase);
- DEF_OPR_PARAM(TQT);
-
- protected:
- void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output);
- void check_layout_fwd(
- const TensorLayout& input, const TensorLayout& scale,
- const TensorLayout& output);
- };
-
- class TQTForward : public TQTBase {
- DEF_OPR_IMPL(TQTForward, TQTBase, 2, 1);
-
- public:
- virtual void exec(
- _megdnn_tensor_in input, _megdnn_tensor_in scale, _megdnn_tensor_out output,
- _megdnn_workspace workspace) = 0;
- void deduce_layout(
- const TensorLayout& input, const TensorLayout& scale, TensorLayout& output);
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& input, const TensorLayout& scale,
- const TensorLayout& output) = 0;
-
- protected:
- void check_exec(
- const TensorLayout& input, const TensorLayout& scale,
- const TensorLayout& output, size_t workspace_in_bytes);
- };
- using TQT = TQTForward;
-
- class TQTBackward : public TQTBase {
- DEF_OPR_IMPL(TQTBackward, TQTBase, 3, 2);
-
- public:
- virtual void exec(
- _megdnn_tensor_in diff, _megdnn_tensor_in input, _megdnn_tensor_in scale,
- _megdnn_tensor_out grad_x, _megdnn_tensor_out grad_s,
- _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& diff, const TensorLayout& input,
- const TensorLayout& scale, const TensorLayout& grad_x,
- const TensorLayout& grad_s) = 0;
-
- protected:
- void check_exec(
- const TensorLayout& diff, const TensorLayout& input,
- const TensorLayout& scale, const TensorLayout& grad_x,
- const TensorLayout& grad_s, size_t workspace_in_bytes);
- };
-
- class LSQBase : public OperatorBase {
- DEF_OPR_IMPL_CTOR(LSQBase, OperatorBase);
- DEF_OPR_PARAM(LSQ);
-
- protected:
- void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output);
- void check_layout_fwd(
- const TensorLayout& input, const TensorLayout& scale,
- const TensorLayout& zero_point, const TensorLayout& grad_scale,
- const TensorLayout& output);
- };
-
- class LSQForward : public LSQBase {
- DEF_OPR_IMPL(LSQForward, LSQBase, 4, 1);
-
- public:
- virtual void exec(
- _megdnn_tensor_in input, _megdnn_tensor_in scale,
- _megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale,
- _megdnn_tensor_out output, _megdnn_workspace workspace) = 0;
- void deduce_layout(
- const TensorLayout& input, const TensorLayout& scale,
- const TensorLayout& zero_point, const TensorLayout& grad_scale,
- TensorLayout& output);
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& input, const TensorLayout& scale,
- const TensorLayout& zero_point, const TensorLayout& grad_scale,
- const TensorLayout& output) = 0;
-
- protected:
- void check_exec(
- const TensorLayout& input, const TensorLayout& scale,
- const TensorLayout& zero_point, const TensorLayout& grad_scale,
- const TensorLayout& output, size_t workspace_in_bytes);
- };
- using LSQ = LSQForward;
-
- class LSQBackward : public LSQBase {
- DEF_OPR_IMPL(LSQBackward, LSQBase, 5, 2);
-
- public:
- virtual void exec(
- _megdnn_tensor_in diff, _megdnn_tensor_in input, _megdnn_tensor_in scale,
- _megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale,
- _megdnn_tensor_out grad_x, _megdnn_tensor_out grad_s,
- _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& diff, const TensorLayout& input,
- const TensorLayout& scale, const TensorLayout& zero_point,
- const TensorLayout& grad_scale, const TensorLayout& grad_x,
- const TensorLayout& grad_s) = 0;
-
- protected:
- void check_exec(
- const TensorLayout& diff, const TensorLayout& input,
- const TensorLayout& scale, const TensorLayout& zero_point,
- const TensorLayout& grad_scale, const TensorLayout& grad_x,
- const TensorLayout& grad_s, size_t workspace_in_bytes);
- };
-
- class LayerNormBase : public OperatorBase {
- DEF_OPR_IMPL_CTOR(LayerNormBase, OperatorBase);
- DEF_OPR_PARAM(LayerNorm);
-
- public:
- MGE_WIN_DECLSPEC_FUC static void deduce_layout_fwd_impl(
- const TensorLayout& data, const Param& p, TensorLayout& dst,
- TensorLayout& mean, TensorLayout& rstd);
-
- protected:
- void deduce_layout_fwd(
- const TensorLayout& data, const TensorLayout& weight,
- const TensorLayout& bias, TensorLayout& dst, TensorLayout& mean,
- TensorLayout& rstd);
- void check_layout_fwd(
- const TensorLayout& data, const TensorLayout& weight,
- const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean,
- const TensorLayout& rstd);
- };
-
- class LayerNormForward : public LayerNormBase {
- DEF_OPR_IMPL(LayerNormForward, LayerNormBase, 3, 3);
-
- public:
- virtual void exec(
- _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias,
- _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd,
- _megdnn_workspace workspace) = 0;
- MGE_WIN_DECLSPEC_FUC void deduce_layout(
- const TensorLayout& data, const TensorLayout& weight,
- const TensorLayout& bias, TensorLayout& dst, TensorLayout& mean,
- TensorLayout& rstd);
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& data, const TensorLayout& weight,
- const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean,
- const TensorLayout& rstd) = 0;
-
- protected:
- void check_exec(
- const TensorLayout& data, const TensorLayout& weight,
- const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean,
- const TensorLayout& rstd, size_t workspace_in_bytes);
- };
- using LayerNorm = LayerNormForward;
-
- class LayerNormBackward : public LayerNormBase {
- DEF_OPR_IMPL(LayerNormBackward, LayerNormBase, 5, 3);
-
- public:
- virtual void exec(
- _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight,
- _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata,
- _megdnn_tensor_out dweight, _megdnn_tensor_out dbias,
- _megdnn_workspace workspace) = 0;
- void deduce_layout(
- const TensorLayout& diff, const TensorLayout& data,
- const TensorLayout& weight, const TensorLayout& mean,
- const TensorLayout& rstd, TensorLayout& ddata, TensorLayout& dweight,
- TensorLayout& dbias);
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& diff, const TensorLayout& data,
- const TensorLayout& weight, const TensorLayout& mean,
- const TensorLayout& rstd, const TensorLayout& ddata,
- const TensorLayout& dweight, const TensorLayout& dbias) = 0;
-
- protected:
- void check_exec(
- const TensorLayout& diff, const TensorLayout& data,
- const TensorLayout& weight, const TensorLayout& mean,
- const TensorLayout& rstd, const TensorLayout& ddata,
- const TensorLayout& dweight, const TensorLayout& dbias,
- size_t workspace_in_bytes);
- };
-
- class DropoutBase : public OperatorBase {
- DEF_OPR_IMPL_CTOR(DropoutBase, OperatorBase);
- DEF_OPR_PARAM(Dropout);
- };
-
- class DropoutForward : public DropoutBase {
- DEF_OPR_IMPL(DropoutForward, DropoutBase, 1, 2);
-
- public:
- void deduce_layout(const TensorLayout& inp, TensorLayout& oup, TensorLayout& mask);
- virtual void exec(
- _megdnn_tensor_in inp, _megdnn_tensor_out oup, _megdnn_tensor_out mask,
- _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& inp, const TensorLayout& oup,
- const TensorLayout& mask) = 0;
- virtual size_t get_mask_size_in_bytes(const TensorLayout& inp) = 0;
-
- protected:
- void check_exec(
- const TensorLayout& inp, const TensorLayout& oup, const TensorLayout& mask,
- size_t workspace_in_bytes);
- };
- using Dropout = DropoutForward;
-
- class DropoutBackward : public DropoutBase {
- DEF_OPR_IMPL(DropoutBackward, DropoutBase, 2, 1);
-
- public:
- void deduce_layout(
- const TensorLayout& doup, const TensorLayout& mask, TensorLayout& dinp);
- virtual void exec(
- _megdnn_tensor_in doup, _megdnn_tensor_in mask, _megdnn_tensor_out dinp,
- _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& doup, const TensorLayout& mask,
- const TensorLayout& dinp) = 0;
-
- protected:
- void check_exec(
- const TensorLayout& doup, const TensorLayout& mask,
- const TensorLayout& dinp, size_t workspace_in_bytes);
- };
- class SoftmaxBase : public OperatorBase {
- DEF_OPR_IMPL_CTOR(SoftmaxBase, OperatorBase);
- DEF_OPR_PARAM(Softmax);
-
- protected:
- void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output);
- void check_layout_fwd(const TensorLayout& input, const TensorLayout& output);
- };
-
- class SoftmaxForward : public SoftmaxBase {
- DEF_OPR_IMPL(SoftmaxForward, SoftmaxBase, 1, 1);
-
- public:
- /**
- * \param[in] input input tensor
- * \param[out] output output tensor
- */
- virtual void exec(
- _megdnn_tensor_in input, _megdnn_tensor_out output,
- _megdnn_workspace workspace) = 0;
- void deduce_layout(const TensorLayout& input, TensorLayout& output);
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& input, const TensorLayout& output) = 0;
-
- protected:
- void check_exec(
- const TensorLayout& input, const TensorLayout& output,
- size_t workspace_in_bytes);
- };
- using Softmax = SoftmaxForward;
-
- class SoftmaxBackward : public SoftmaxBase {
- DEF_OPR_IMPL(SoftmaxBackward, SoftmaxBase, 2, 1);
-
- public:
- virtual void exec(
- _megdnn_tensor_in input, _megdnn_tensor_in diff, _megdnn_tensor_out grad_x,
- _megdnn_workspace workspace) = 0;
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& input, const TensorLayout& diff,
- const TensorLayout& grad_x) = 0;
-
- protected:
- void check_exec(
- const TensorLayout& input, const TensorLayout& diff,
- const TensorLayout& grad_x, size_t workspace_in_bytes);
- };
-
- class RNNCellForward : public OperatorBase {
- DEF_OPR_PARAM(RNNCell);
- DEF_OPR_IMPL(RNNCellForward, OperatorBase, 6, 1);
-
- public:
- virtual void exec(
- _megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
- _megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
- _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
- _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
- void deduce_layout(
- const TensorLayout& input, const TensorLayout& weight_ih,
- const TensorLayout& bias_ih, const TensorLayout& hx,
- const TensorLayout& weight_hh, const TensorLayout& bias_hh,
- TensorLayout& dst);
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& input, const TensorLayout& weight_ih,
- const TensorLayout& bias_ih, const TensorLayout& hx,
- const TensorLayout& weight_hh, const TensorLayout& bias_hh,
- const TensorLayout& dst) = 0;
-
- protected:
- void check_exec(
- const TensorLayout& input, const TensorLayout& weight_ih,
- const TensorLayout& bias_ih, const TensorLayout& hx,
- const TensorLayout& weight_hh, const TensorLayout& bias_hh,
- const TensorLayout& dst, size_t workspace_in_bytes);
- };
- using RNNCell = RNNCellForward;
-
- class LSTMCellForward : public OperatorBase {
- // DEF_OPR_PARAM(LSTMCell);
- DEF_OPR_PARAM(Empty);
- DEF_OPR_IMPL(LSTMCellForward, OperatorBase, 7, 3);
-
- public:
- virtual void exec(
- _megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
- _megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
- _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
- _megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
- _megdnn_tensor_out gates, _megdnn_workspace workspace) = 0;
- void deduce_layout(
- const TensorLayout& input, const TensorLayout& weight_ih,
- const TensorLayout& bias_ih, const TensorLayout& hx,
- const TensorLayout& weight_hh, const TensorLayout& bias_hh,
- const TensorLayout& cx, TensorLayout& h_new, TensorLayout& c_new,
- TensorLayout& gates);
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& input, const TensorLayout& weight_ih,
- const TensorLayout& bias_ih, const TensorLayout& hx,
- const TensorLayout& weight_hh, const TensorLayout& bias_hh,
- const TensorLayout& cx, const TensorLayout& h_new,
- const TensorLayout& c_new, const TensorLayout& gates) = 0;
-
- protected:
- void check_exec(
- const TensorLayout& input, const TensorLayout& weight_ih,
- const TensorLayout& bias_ih, const TensorLayout& hx,
- const TensorLayout& weight_hh, const TensorLayout& bias_hh,
- const TensorLayout& cx, const TensorLayout& h_new,
- const TensorLayout& c_new, const TensorLayout& gates,
- size_t workspace_in_bytes);
- };
- using LSTMCell = LSTMCellForward;
-
- class RNNForward : public OperatorBase {
- DEF_OPR_PARAM(RNN);
- DEF_OPR_IMPL(RNNForward, OperatorBase, 3, 3);
-
- public:
- virtual void exec(
- _megdnn_tensor_in input, _megdnn_tensor_in hx,
- _megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
- _megdnn_tensor_out hy, _megdnn_tensor_out reserve_space,
- _megdnn_workspace workspace) = 0;
- void deduce_layout(
- const TensorLayout& input, const TensorLayout& hx,
- const TensorLayout& flatten_weights, TensorLayout& output, TensorLayout& hy,
- TensorLayout& reserve_space);
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& input, const TensorLayout& hx,
- const TensorLayout& flatten_weights, const TensorLayout& output,
- const TensorLayout& hy, const TensorLayout& reserve_space) = 0;
- virtual size_t get_reserve_size_in_bytes(const TensorLayout& input) = 0;
-
- protected:
- void check_exec(
- const TensorLayout& input, const TensorLayout& hx,
- const TensorLayout& flatten_weights, const TensorLayout& output,
- const TensorLayout& hy, const TensorLayout& reserve_space,
- size_t workspace_in_bytes);
- };
- using RNN = RNNForward;
-
- class RNNBackward : public OperatorBase {
- DEF_OPR_PARAM(RNN);
- DEF_OPR_IMPL(RNNBackward, OperatorBase, 7, 3);
-
- public:
- virtual void exec(
- _megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx,
- _megdnn_tensor_in dy, _megdnn_tensor_in dhy,
- _megdnn_tensor_in flatten_weights, _megdnn_tensor_in reserve_space,
- _megdnn_tensor_out dx, _megdnn_tensor_out dhx, _megdnn_tensor_out dw,
- _megdnn_workspace workspace) = 0;
- void deduce_layout(
- const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
- const TensorLayout& dy, const TensorLayout& dhy,
- const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
- TensorLayout& dx, TensorLayout& dhx, TensorLayout& dw);
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
- const TensorLayout& dy, const TensorLayout& dhy,
- const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
- const TensorLayout& dx, const TensorLayout& dhx,
- const TensorLayout& dw) = 0;
-
- protected:
- void check_exec(
- const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
- const TensorLayout& dy, const TensorLayout& dhy,
- const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
- const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw,
- size_t workspace_in_bytes);
- };
-
- class LSTMForward : public OperatorBase {
- DEF_OPR_PARAM(LSTM);
- DEF_OPR_IMPL(LSTMForward, OperatorBase, 4, 4);
-
- public:
- virtual void exec(
- _megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx,
- _megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
- _megdnn_tensor_out hy, _megdnn_tensor_out cy,
- _megdnn_tensor_out reserve_space, _megdnn_workspace workspace) = 0;
- void deduce_layout(
- const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
- const TensorLayout& flatten_weights, TensorLayout& output, TensorLayout& hy,
- TensorLayout& cy, TensorLayout& reserve_space);
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
- const TensorLayout& flatten_weights, const TensorLayout& output,
- const TensorLayout& hy, const TensorLayout& cy,
- const TensorLayout& reserve_space) = 0;
- virtual size_t get_reserve_size_in_bytes(const TensorLayout& input) = 0;
-
- protected:
- void check_exec(
- const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
- const TensorLayout& flatten_weights, const TensorLayout& output,
- const TensorLayout& hy, const TensorLayout& cy,
- const TensorLayout& reserve_space, size_t workspace_in_bytes);
- };
- using LSTM = LSTMForward;
-
- class LSTMBackward : public OperatorBase {
- DEF_OPR_PARAM(LSTM);
- DEF_OPR_IMPL(LSTMBackward, OperatorBase, 9, 4);
-
- public:
- virtual void exec(
- _megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx,
- _megdnn_tensor_in cx, _megdnn_tensor_in dy, _megdnn_tensor_in dhy,
- _megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights,
- _megdnn_tensor_in reserve_space, _megdnn_tensor_out dx,
- _megdnn_tensor_out dhx, _megdnn_tensor_out dcx, _megdnn_tensor_out dw,
- _megdnn_workspace workspace) = 0;
- void deduce_layout(
- const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
- const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
- const TensorLayout& dcy, const TensorLayout& flatten_weights,
- const TensorLayout& reserve_space, TensorLayout& dx, TensorLayout& dhx,
- TensorLayout& dcx, TensorLayout& dw);
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
- const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
- const TensorLayout& dcy, const TensorLayout& flatten_weights,
- const TensorLayout& reserve_space, const TensorLayout& dx,
- const TensorLayout& dhx, const TensorLayout& dcx,
- const TensorLayout& dw) = 0;
-
- protected:
- void check_exec(
- const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
- const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
- const TensorLayout& dcy, const TensorLayout& flatten_weights,
- const TensorLayout& reserve_space, const TensorLayout& dx,
- const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw,
- size_t workspace_in_bytes);
- };
- } // namespace megdnn
- #include "megdnn/internal/opr_header_epilogue.h"
-
- // vim: syntax=cpp.doxygen
|