
fix(imperative): add param(axis) for GetVarShape

GitOrigin-RevId: 0b8f821929
Branch: release-1.2
Megvii Engine Team, 4 years ago
Parent commit: a78c11099b
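Before this change the imperative GetVarShape op carried no parameters: apply_on_var_node always built the graph operator with the default param, and an opr::GetVarShape node with a non-default axis could not be converted back to an imperative OpDef (make_from_op_node fell through to the generic OprAttr path). This commit attaches OptionalAxisV1Param to the op definition and threads it through apply_on_var_node, apply_on_physical_tensor, infer_output_attrs_fallible, and make_from_op_node: when axis is INVALID_AXIS the old full-shape behavior is kept, otherwise only the extent of the requested dimension is returned as a single-element Int32 tensor.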
2 changed files with 30 additions and 20 deletions:
  1. imperative/src/impl/ops/tensor_manip.cpp (+29, −19)
  2. src/core/include/megbrain/ir/ops.td (+1, −1)

imperative/src/impl/ops/tensor_manip.cpp (+29, −19)

@@ -20,22 +20,30 @@ namespace {
 cg::OperatorNodeBase* apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
-    def.cast_final_safe<GetVarShape>();
-    return opr::GetVarShape::make(inputs).node()->owner_opr();
+    auto&& op_def = def.cast_final_safe<GetVarShape>();
+    return opr::GetVarShape::make(inputs, op_def.param()).node()->owner_opr();
 }
 
 SmallVector<TensorPtr> apply_on_physical_tensor(
         const OpDef& def,
         const SmallVector<TensorPtr>& inputs) {
-    def.cast_final_safe<GetVarShape>();
+    auto&& op_def = def.cast_final_safe<GetVarShape>();
     mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size());
     auto&& inp = inputs[0];
     auto&& shp = inp->layout();
     mgb_assert(shp.ndim != 0, "input shape invalid");
-    HostTensorND hv(inp->comp_node(), {shp.ndim}, dtype::Int32());
-    auto* ptr = hv.ptr<dt_int32>();
-    for (size_t i = 0; i < shp.ndim; ++i) {
-        ptr[i] = shp.shape[i];
+    HostTensorND hv;
+    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){
+        hv = HostTensorND(inp->comp_node(), {shp.ndim}, dtype::Int32());
+        auto* ptr = hv.ptr<dt_int32>();
+        for (size_t i = 0; i < shp.ndim; ++i) {
+            ptr[i] = shp.shape[i];
+        }
+    }else{
+        mgb_assert(op_def.axis < shp.ndim);
+        hv = HostTensorND(inp->comp_node(), {1}, dtype::Int32());
+        auto* ptr = hv.ptr<dt_int32>();
+        ptr[0] = shp.shape[op_def.axis];
     }
     return {Tensor::make(std::move(hv))};
 }
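To make the new dispatch concrete, here is a minimal standalone sketch of the same logic that assumes nothing from MegEngine: get_var_shape, kInvalidAxis, and the use of std::vector in place of HostTensorND are all hypothetical stand-ins for illustration only.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Hypothetical stand-in for opr::GetVarShape::Param::INVALID_AXIS.
    constexpr int32_t kInvalidAxis = INT32_MAX;

    // Mimics the new branch in apply_on_physical_tensor: with no axis set,
    // return the whole shape; with a valid axis, return just that extent.
    std::vector<int32_t> get_var_shape(const std::vector<size_t>& shape,
                                       int32_t axis = kInvalidAxis) {
        if (axis == kInvalidAxis) {
            return std::vector<int32_t>(shape.begin(), shape.end());
        }
        assert(axis >= 0 && static_cast<size_t>(axis) < shape.size());
        return {static_cast<int32_t>(shape[axis])};
    }

    int main() {
        std::vector<size_t> shp{2, 3, 4};
        auto full = get_var_shape(shp);     // {2, 3, 4}: old full-shape behavior
        auto dim1 = get_var_shape(shp, 1);  // {3}: single-dimension query
        return (full.size() == 3 && dim1[0] == 3) ? 0 : 1;
    }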
@@ -43,29 +51,31 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
 std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs) {
-    def.cast_final_safe<GetVarShape>();
+    auto&& op_def = def.cast_final_safe<GetVarShape>();
     mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size());
     auto&& desc = inputs[0];
     if (!desc.layout.ndim) {
         return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, true};
     }
-    DeviceTensorND value(CompNode::default_cpu(), {desc.layout.ndim}, dtype::Int32());
-    auto* ptr = value.ptr<dt_int32>();
-    for (size_t i = 0; i < desc.layout.ndim; ++i) {
-        ptr[i] = desc.layout[i];
+    DeviceTensorND value;
+    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){
+        value = DeviceTensorND(CompNode::default_cpu(), {desc.layout.ndim}, dtype::Int32());
+        auto* ptr = value.ptr<dt_int32>();
+        for (size_t i = 0; i < desc.layout.ndim; ++i) {
+            ptr[i] = desc.layout[i];
+        }
+    }else{
+        mgb_assert(op_def.axis < desc.layout.ndim);
+        value = DeviceTensorND(CompNode::default_cpu(), {1}, dtype::Int32());
+        auto* ptr = value.ptr<dt_int32>();
+        ptr[0] = desc.layout[op_def.axis];
     }
     return {{{value.layout(), desc.comp_node, std::move(value)}}, true};
 }
 
 std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
     auto* node = &node_->cast_final_safe<opr::GetVarShape>();
-    if (node->config().comp_node().size() ||
-            node->config().output_dtype().valid() ||
-            node->param().axis != opr::GetVarShape::Param::INVALID_AXIS) {
-        mgb_log_debug("weird GetVarShape");
-        return OpTrait::find_by_typeinfo(OprAttr::typeinfo())->make_from_op_node(node);
-    }
-    return GetVarShape::make();
+    return GetVarShape::make(node->param());
 }
 
 OP_TRAIT_REG(GetVarShape, GetVarShape, opr::GetVarShape)
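Note the design consequence in make_from_op_node: since the imperative op can now represent the axis, the graph node's param is simply round-tripped via GetVarShape::make(node->param()), and the "weird GetVarShape" escape hatch to the generic OprAttr wrapper (previously taken whenever the node had a non-default axis or a non-trivial config) is dropped.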


src/core/include/megbrain/ir/ops.td (+1, −1)

@@ -122,7 +122,7 @@ def Eye: MgbHashableOp<"Eye", [EyeParam]> {
   );
 }
 
-def GetVarShape : MgbHashableOp<"GetVarShape">;
+def GetVarShape : MgbHashableOp<"GetVarShape", [OptionalAxisV1Param]>;
 
 def Concat: MgbHashableOp<"Concat", [AxisParam]> {
   let extraArguments = (ins
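This one-line ops.td change is what drives the C++ side: listing OptionalAxisV1Param (whose single axis field defaults to INVALID_AXIS) in the MgbHashableOp declaration makes the generated GetVarShape OpDef carry the op_def.axis member and the op_def.param() accessor that tensor_manip.cpp reads above.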

