|
|
@@ -20,22 +20,30 @@ namespace { |
|
|
|
cg::OperatorNodeBase* apply_on_var_node( |
|
|
|
const OpDef& def, |
|
|
|
const VarNodeArray& inputs) { |
|
|
|
def.cast_final_safe<GetVarShape>(); |
|
|
|
return opr::GetVarShape::make(inputs).node()->owner_opr(); |
|
|
|
auto&& op_def = def.cast_final_safe<GetVarShape>(); |
|
|
|
return opr::GetVarShape::make(inputs, op_def.param()).node()->owner_opr(); |
|
|
|
} |
|
|
|
|
|
|
|
SmallVector<TensorPtr> apply_on_physical_tensor( |
|
|
|
const OpDef& def, |
|
|
|
const SmallVector<TensorPtr>& inputs) { |
|
|
|
def.cast_final_safe<GetVarShape>(); |
|
|
|
auto&& op_def = def.cast_final_safe<GetVarShape>(); |
|
|
|
mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size()); |
|
|
|
auto&& inp = inputs[0]; |
|
|
|
auto&& shp = inp->layout(); |
|
|
|
mgb_assert(shp.ndim != 0, "input shape invalid"); |
|
|
|
HostTensorND hv(inp->comp_node(), {shp.ndim}, dtype::Int32()); |
|
|
|
auto* ptr = hv.ptr<dt_int32>(); |
|
|
|
for (size_t i = 0; i < shp.ndim; ++i) { |
|
|
|
ptr[i] = shp.shape[i]; |
|
|
|
HostTensorND hv; |
|
|
|
if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){ |
|
|
|
hv = HostTensorND(inp->comp_node(), {shp.ndim}, dtype::Int32()); |
|
|
|
auto* ptr = hv.ptr<dt_int32>(); |
|
|
|
for (size_t i = 0; i < shp.ndim; ++i) { |
|
|
|
ptr[i] = shp.shape[i]; |
|
|
|
} |
|
|
|
}else{ |
|
|
|
mgb_assert(op_def.axis < shp.ndim); |
|
|
|
hv = HostTensorND(inp->comp_node(), {1}, dtype::Int32()); |
|
|
|
auto* ptr = hv.ptr<dt_int32>(); |
|
|
|
ptr[0] = shp.shape[op_def.axis]; |
|
|
|
} |
|
|
|
return {Tensor::make(std::move(hv))}; |
|
|
|
} |
|
|
@@ -43,29 +51,31 @@ SmallVector<TensorPtr> apply_on_physical_tensor( |
|
|
|
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( |
|
|
|
const OpDef& def, |
|
|
|
const SmallVector<LogicalTensorDesc>& inputs) { |
|
|
|
def.cast_final_safe<GetVarShape>(); |
|
|
|
auto&& op_def = def.cast_final_safe<GetVarShape>(); |
|
|
|
mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size()); |
|
|
|
auto&& desc = inputs[0]; |
|
|
|
if (!desc.layout.ndim) { |
|
|
|
return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, true}; |
|
|
|
} |
|
|
|
DeviceTensorND value(CompNode::default_cpu(), {desc.layout.ndim}, dtype::Int32()); |
|
|
|
auto* ptr = value.ptr<dt_int32>(); |
|
|
|
for (size_t i = 0; i < desc.layout.ndim; ++i) { |
|
|
|
ptr[i] = desc.layout[i]; |
|
|
|
DeviceTensorND value; |
|
|
|
if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){ |
|
|
|
value = DeviceTensorND(CompNode::default_cpu(), {desc.layout.ndim}, dtype::Int32()); |
|
|
|
auto* ptr = value.ptr<dt_int32>(); |
|
|
|
for (size_t i = 0; i < desc.layout.ndim; ++i) { |
|
|
|
ptr[i] = desc.layout[i]; |
|
|
|
} |
|
|
|
}else{ |
|
|
|
mgb_assert(op_def.axis < desc.layout.ndim); |
|
|
|
value = DeviceTensorND(CompNode::default_cpu(), {1}, dtype::Int32()); |
|
|
|
auto* ptr = value.ptr<dt_int32>(); |
|
|
|
ptr[0] = desc.layout[op_def.axis]; |
|
|
|
} |
|
|
|
return {{{value.layout(), desc.comp_node, std::move(value)}}, true}; |
|
|
|
} |
|
|
|
|
|
|
|
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { |
|
|
|
auto* node = &node_->cast_final_safe<opr::GetVarShape>(); |
|
|
|
if (node->config().comp_node().size() || |
|
|
|
node->config().output_dtype().valid() || |
|
|
|
node->param().axis != opr::GetVarShape::Param::INVALID_AXIS) { |
|
|
|
mgb_log_debug("weird GetVarShape"); |
|
|
|
return OpTrait::find_by_typeinfo(OprAttr::typeinfo())->make_from_op_node(node); |
|
|
|
} |
|
|
|
return GetVarShape::make(); |
|
|
|
return GetVarShape::make(node->param()); |
|
|
|
} |
|
|
|
|
|
|
|
OP_TRAIT_REG(GetVarShape, GetVarShape, opr::GetVarShape) |
|
|
|