|
@@ -13,6 +13,7 @@ |
|
|
#include "./internal/megdnn_opr_wrapper.inl" |
|
|
#include "./internal/megdnn_opr_wrapper.inl" |
|
|
#include "megbrain/graph/grad_impl.h" |
|
|
#include "megbrain/graph/grad_impl.h" |
|
|
#include "megbrain/opr/imgproc.h" |
|
|
#include "megbrain/opr/imgproc.h" |
|
|
|
|
|
#include "megbrain/opr/io.h" |
|
|
#include "megbrain/opr/utility.h" |
|
|
#include "megbrain/opr/utility.h" |
|
|
|
|
|
|
|
|
using namespace mgb; |
|
|
using namespace mgb; |
|
@@ -486,6 +487,7 @@ struct MegDNNOprInitPostCtor<DctChannelSelectForward> { |
|
|
} // namespace intl |
|
|
} // namespace intl |
|
|
} // namespace opr |
|
|
} // namespace opr |
|
|
} // namespace mgb |
|
|
} // namespace mgb |
|
|
|
|
|
|
|
|
void DctChannelSelectForward::get_output_var_shape( |
|
|
void DctChannelSelectForward::get_output_var_shape( |
|
|
const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const { |
|
|
const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const { |
|
|
auto mo = megdnn_opr(); |
|
|
auto mo = megdnn_opr(); |
|
@@ -504,6 +506,7 @@ void DctChannelSelectForward::get_output_var_shape( |
|
|
} |
|
|
} |
|
|
out_shape[0] = dst; |
|
|
out_shape[0] = dst; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
size_t DctChannelSelectForward::get_workspace_size_bytes( |
|
|
size_t DctChannelSelectForward::get_workspace_size_bytes( |
|
|
const TensorShapeArray& input_shapes, |
|
|
const TensorShapeArray& input_shapes, |
|
|
const TensorShapeArray& output_shapes) const { |
|
|
const TensorShapeArray& output_shapes) const { |
|
@@ -513,6 +516,7 @@ size_t DctChannelSelectForward::get_workspace_size_bytes( |
|
|
{input_shapes[0], input(0)->dtype(), input(0)->format()}, {}, {}, |
|
|
{input_shapes[0], input(0)->dtype(), input(0)->format()}, {}, {}, |
|
|
{output_shapes[0], output(0)->dtype(), output(0)->format()}); |
|
|
{output_shapes[0], output(0)->dtype(), output(0)->format()}); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void DctChannelSelectForward::scn_do_execute() { |
|
|
void DctChannelSelectForward::scn_do_execute() { |
|
|
auto&& inp = input(); |
|
|
auto&& inp = input(); |
|
|
auto mo = megdnn_opr(); |
|
|
auto mo = megdnn_opr(); |
|
@@ -524,7 +528,6 @@ void DctChannelSelectForward::scn_do_execute() { |
|
|
} else { |
|
|
} else { |
|
|
mgb_assert(inp.size() == 3, "no support input tensor num %zu", |
|
|
mgb_assert(inp.size() == 3, "no support input tensor num %zu", |
|
|
inp.size()); |
|
|
inp.size()); |
|
|
|
|
|
|
|
|
mo->exec(inp[0]->dev_tensor().as_megdnn(), |
|
|
mo->exec(inp[0]->dev_tensor().as_megdnn(), |
|
|
inp[1]->dev_tensor().as_megdnn(), |
|
|
inp[1]->dev_tensor().as_megdnn(), |
|
|
inp[2]->dev_tensor().as_megdnn(), |
|
|
inp[2]->dev_tensor().as_megdnn(), |
|
@@ -533,7 +536,70 @@ void DctChannelSelectForward::scn_do_execute() { |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
MEGDNN_OPR_INIT3(DctChannelSelectForward, "dct_channel_select") |
|
|
|
|
|
|
|
|
void DctChannelSelectForward::valid_mask(const int* mask_offset, int mask_len, |
|
|
|
|
|
const int* mask_val, int mask_val_len, |
|
|
|
|
|
const Param& param) { |
|
|
|
|
|
if (mask_len <= 0) |
|
|
|
|
|
return; |
|
|
|
|
|
mgb_assert(mask_offset[0] == 0, |
|
|
|
|
|
"The first element of mask_offset must be zero, but got %d. For " |
|
|
|
|
|
"example mask offset [0, 15, 20] indicate there are 2 ic, and " |
|
|
|
|
|
"ic_0 will have (15 - 0) oc, ic_1 have (20 - 15) oc", |
|
|
|
|
|
mask_offset[0]); |
|
|
|
|
|
for (int i = 1; i < mask_len; ++i) { |
|
|
|
|
|
if (param.format == Param::Format::NCHW4) { |
|
|
|
|
|
mgb_assert(mask_offset[i] % 4 == 0, |
|
|
|
|
|
"Invalid mask offset %d at %d, it should be times of " |
|
|
|
|
|
"4 when using nchw4 format", |
|
|
|
|
|
mask_offset[i], i); |
|
|
|
|
|
} |
|
|
|
|
|
mgb_assert(mask_offset[i] >= mask_offset[i - 1], |
|
|
|
|
|
"The offset of mask must be increasing, but %d(%d) is less " |
|
|
|
|
|
"than %d(%d)", |
|
|
|
|
|
mask_offset[i], i, mask_offset[i - 1], i - 1); |
|
|
|
|
|
} |
|
|
|
|
|
const int max_mask = param.dct_block_size * param.dct_block_size; |
|
|
|
|
|
for (int i = 0; i < mask_val_len; ++i) { |
|
|
|
|
|
mgb_assert(0 <= mask_val[i] && mask_val[i] < max_mask, |
|
|
|
|
|
"Invalid mask_val, assert 0 <= mask_val[%d] < %d, aka 0 <= " |
|
|
|
|
|
"%d < %d", |
|
|
|
|
|
i, max_mask, mask_val[i], max_mask); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
/*!
 * \brief Build the operator from a source var plus the two mask vars.
 *
 * Steps, in order: initialise the megdnn operator with \p param, register
 * the three inputs, validate the mask contents when \p mask_offset is given,
 * then run the post-constructor hook.
 *
 * When \p mask_offset is non-null, \p mask_val must be non-null too, and the
 * owner operators of both vars are cast to opr::ImmutableTensor so their host
 * values can be inspected. NOTE(review): cast_final_safe presumably rejects
 * any other operator type, so non-constant masks would fail here; confirm.
 *
 * \param src         input var; its owner graph hosts this operator
 * \param mask_offset mask offset var, may be null
 * \param mask_val    mask value var
 * \param param       operator param forwarded to the megdnn operator
 * \param config      operator node config
 */
DctChannelSelectForward::DctChannelSelectForward(
        VarNode* src, VarNode* mask_offset, VarNode* mask_val,
        const Param& param, const OperatorNodeConfig& config)
        : Super(OperatorNodeBaseCtorParam{
                  src->owner_graph(), config, "dct_channel_select", {src}}) {
    init_megdnn_opr(*this, param);
    add_input({src, mask_offset, mask_val});

    if (mask_offset) {
        mgb_assert(mask_val,
                   "mask_val should not be null when mask_offset is not null");

        // Fetch the host-side contents of both mask tensors.
        auto offset_tensor = mask_offset->owner_opr()
                                     ->cast_final_safe<opr::ImmutableTensor>()
                                     .host_value();
        auto value_tensor = mask_val->owner_opr()
                                    ->cast_final_safe<opr::ImmutableTensor>()
                                    .host_value();

        // Name each argument before handing it to the validator.
        const int* offset_data = offset_tensor.ptr<int>();
        const auto offset_count = offset_tensor.layout().total_nr_elems();
        const int* value_data = value_tensor.ptr<int>();
        const auto value_count = value_tensor.layout().total_nr_elems();

        valid_mask(offset_data, offset_count, value_data, value_count, param);
    }

    intl::MegDNNOprInitPostCtor<DctChannelSelectForward>::apply(*this);
}
|
|
|
|
|
|
|
|
|
|
|
/*!
 * \brief Insert a three-input DctChannelSelectForward operator into the graph
 *      and return its single output.
 *
 * The inputs first go through MegDNNOprInitInputsModifier, which receives
 * pointers to the three SymbolVars; the operator is then inserted through
 * \p src.
 *
 * \param src         source var
 * \param mask_offset mask offset var
 * \param mask_val    mask value var
 * \param param       operator param
 * \param config      operator node config
 * \return the single output var of the inserted operator
 */
SymbolVar DctChannelSelectForward::make(SymbolVar src, SymbolVar mask_offset,
                                        SymbolVar mask_val, const Param& param,
                                        const OperatorNodeConfig& config) {
    using InputsModifier =
            intl::MegDNNOprInitInputsModifier<DctChannelSelectForward>;
    InputsModifier::apply(param, {&src, &mask_offset, &mask_val});

    // Read the nodes only after the modifier has run on the SymbolVars.
    auto* const src_node = src.node();
    auto* const offset_node = mask_offset.node();
    auto* const value_node = mask_val.node();

    return src.insert_single_output_opr<DctChannelSelectForward>(
            src_node, offset_node, value_node, param, config);
}
|
|
|
|
|
|
|
|
MEGDNN_OPR_INIT1(DctChannelSelectForward, "dct_channel_select") |
|
|
MEGDNN_OPR_INIT1(DctChannelSelectForward, "dct_channel_select") |
|
|
|
|
|
|
|
|
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
|
|
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |