Browse Source

fix(mgb/opr): move NVOF opr's shape inference to execute part

GitOrigin-RevId: 883c55e4a0
release-1.2
Megvii Engine Team 4 years ago
parent
commit
184e1311fa
1 changed files with 8 additions and 12 deletions
  1. +8
    -12
      src/opr/impl/misc.cpp

+ 8
- 12
src/opr/impl/misc.cpp View File

@@ -166,23 +166,11 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(NvOf);

NvOf::NvOf(VarNode* opr, const Param& param, const OperatorNodeConfig& config)
: Super{opr->owner_graph(), config, "NvOf", {opr}}, m_param{param} {
constexpr size_t NDIM = 5;
mgb_assert(opr->dtype() == dtype::Uint8());
add_input({opr});
//! NvOf hava only one output
add_output(None);

mgb_log_debug("init nvof engine with precision: %u", m_param.precision);
auto input_shape = this->input()[0]->shape();

//! nvof input format: nthwc4
mgb_assert(input_shape.ndim == NDIM);
//! now only support RGBA format channel data
mgb_assert(input_shape[4] == 4);

for (size_t i = 0; i < NDIM; i++) {
vshape.push_back(input_shape[i]);
}
}

void NvOf::init_output_dtype() {
@@ -195,6 +183,10 @@ SymbolVar NvOf::make(SymbolVar opr, const Param& param,
}

void NvOf::scn_do_execute() {
auto input_shape = this->input()[0]->shape();
for (size_t i = 0; i < 5; i++) {
vshape.push_back(input_shape[i]);
}
auto c = this->comp_node();
//! comp_node may init on CUDA or CPU, eg: lar with --cpu
//! if ON CUDA, need sync, caused by we use different stream
@@ -229,6 +221,10 @@ void NvOf::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto infer_shape = [](TensorShape& dest, const InpVal& iv) {
auto ishp = iv.val.at(0).shape();
//! nvof input format: nthwc4
mgb_assert(ishp.ndim == 5);
//! now only support RGBA format channel data
mgb_assert(ishp[4] == 4);
SmallVector<size_t> tv;
tv.push_back(ishp[0]);
tv.push_back(ishp[1] - 1);


Loading…
Cancel
Save