|
@@ -166,23 +166,11 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(NvOf); |
|
|
|
|
|
|
|
|
NvOf::NvOf(VarNode* opr, const Param& param, const OperatorNodeConfig& config) |
|
|
NvOf::NvOf(VarNode* opr, const Param& param, const OperatorNodeConfig& config) |
|
|
: Super{opr->owner_graph(), config, "NvOf", {opr}}, m_param{param} { |
|
|
: Super{opr->owner_graph(), config, "NvOf", {opr}}, m_param{param} { |
|
|
constexpr size_t NDIM = 5; |
|
|
|
|
|
mgb_assert(opr->dtype() == dtype::Uint8()); |
|
|
mgb_assert(opr->dtype() == dtype::Uint8()); |
|
|
add_input({opr}); |
|
|
add_input({opr}); |
|
|
//! NvOf hava only one output |
|
|
//! NvOf hava only one output |
|
|
add_output(None); |
|
|
add_output(None); |
|
|
|
|
|
|
|
|
mgb_log_debug("init nvof engine with precision: %u", m_param.precision); |
|
|
mgb_log_debug("init nvof engine with precision: %u", m_param.precision); |
|
|
auto input_shape = this->input()[0]->shape(); |
|
|
|
|
|
|
|
|
|
|
|
//! nvof input format: nthwc4 |
|
|
|
|
|
mgb_assert(input_shape.ndim == NDIM); |
|
|
|
|
|
//! now only support RGBA format channel data |
|
|
|
|
|
mgb_assert(input_shape[4] == 4); |
|
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < NDIM; i++) { |
|
|
|
|
|
vshape.push_back(input_shape[i]); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void NvOf::init_output_dtype() { |
|
|
void NvOf::init_output_dtype() { |
|
@@ -195,6 +183,10 @@ SymbolVar NvOf::make(SymbolVar opr, const Param& param, |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void NvOf::scn_do_execute() { |
|
|
void NvOf::scn_do_execute() { |
|
|
|
|
|
auto input_shape = this->input()[0]->shape(); |
|
|
|
|
|
for (size_t i = 0; i < 5; i++) { |
|
|
|
|
|
vshape.push_back(input_shape[i]); |
|
|
|
|
|
} |
|
|
auto c = this->comp_node(); |
|
|
auto c = this->comp_node(); |
|
|
//! comp_node may init on CUDA or CPU, eg: lar with --cpu |
|
|
//! comp_node may init on CUDA or CPU, eg: lar with --cpu |
|
|
//! if ON CUDA, need sync, caused by we use different stream |
|
|
//! if ON CUDA, need sync, caused by we use different stream |
|
@@ -229,6 +221,10 @@ void NvOf::init_output_static_infer_desc() { |
|
|
using namespace cg::static_infer; |
|
|
using namespace cg::static_infer; |
|
|
auto infer_shape = [](TensorShape& dest, const InpVal& iv) { |
|
|
auto infer_shape = [](TensorShape& dest, const InpVal& iv) { |
|
|
auto ishp = iv.val.at(0).shape(); |
|
|
auto ishp = iv.val.at(0).shape(); |
|
|
|
|
|
//! nvof input format: nthwc4 |
|
|
|
|
|
mgb_assert(ishp.ndim == 5); |
|
|
|
|
|
//! now only support RGBA format channel data |
|
|
|
|
|
mgb_assert(ishp[4] == 4); |
|
|
SmallVector<size_t> tv; |
|
|
SmallVector<size_t> tv; |
|
|
tv.push_back(ishp[0]); |
|
|
tv.push_back(ishp[0]); |
|
|
tv.push_back(ishp[1] - 1); |
|
|
tv.push_back(ishp[1] - 1); |
|
|