|
|
@@ -33,7 +33,7 @@ DispatchMode decide_dispatch_mode( |
|
|
|
const SmallVector<LogicalTensorDesc>& inputs) { |
|
|
|
bool host_computable = true; |
|
|
|
for (auto&& inp : inputs) { |
|
|
|
// FIXME(czh): remove value chech after proxy graph's |
|
|
|
// FIXME(czh): remove value check after proxy graph's |
|
|
|
// apply_on_device_tensornd is supported and output Tensor |
|
|
|
// is made before add_task. |
|
|
|
// then if layout is valid, ptr->layout must be ready |
|
|
@@ -50,9 +50,18 @@ void apply_on_device_tensornd( |
|
|
|
const SmallVector<DeviceTensorND>& inputs, |
|
|
|
SmallVector<DeviceTensorND>* outputs) { |
|
|
|
auto&& op_def = def.cast_final_safe<GetVarShape>(); |
|
|
|
mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size()); |
|
|
|
auto&& inp = inputs[0]; |
|
|
|
auto&& shp = inp.layout(); |
|
|
|
|
|
|
|
TensorShape shp; |
|
|
|
if (inputs.size() == 1) { |
|
|
|
shp = inputs[0].layout(); |
|
|
|
} else { |
|
|
|
TensorShapeArray src(inputs.size()); |
|
|
|
for (size_t i = 0; i < inputs.size(); ++i) { |
|
|
|
src[i] = inputs[i].layout(); |
|
|
|
} |
|
|
|
megdnn::Elemwise::deduce_shape(src, shp); |
|
|
|
} |
|
|
|
|
|
|
|
mgb_assert(shp.ndim != 0, "input shape invalid"); |
|
|
|
mgb_assert((*outputs)[0].comp_node() == CompNode::default_cpu(), |
|
|
|
"GetVarShape's apply_on_device_tensornd should receive default_cpu outputs."); |
|
|
@@ -99,27 +108,36 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( |
|
|
|
const OpDef& def, |
|
|
|
const SmallVector<LogicalTensorDesc>& inputs) { |
|
|
|
auto&& op_def = def.cast_final_safe<GetVarShape>(); |
|
|
|
mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size()); |
|
|
|
auto&& desc = inputs[0]; |
|
|
|
if (!desc.layout.ndim) { |
|
|
|
TensorShape shp; |
|
|
|
if (inputs.size() == 1) { |
|
|
|
shp = desc.layout; |
|
|
|
} else { |
|
|
|
TensorShapeArray src(inputs.size()); |
|
|
|
for (size_t i = 0; i < inputs.size(); ++i) { |
|
|
|
src[i] = inputs[i].layout; |
|
|
|
} |
|
|
|
megdnn::Elemwise::deduce_shape(src, shp); |
|
|
|
} |
|
|
|
if (!shp.ndim) { |
|
|
|
return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false}; |
|
|
|
} |
|
|
|
DeviceTensorND value; |
|
|
|
if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) { |
|
|
|
value = DeviceTensorND(CompNode::default_cpu(), {desc.layout.ndim}, dtype::Int32()); |
|
|
|
value = DeviceTensorND(CompNode::default_cpu(), {shp.ndim}, dtype::Int32()); |
|
|
|
auto* ptr = value.ptr<dt_int32>(); |
|
|
|
for (size_t i = 0; i < desc.layout.ndim; ++i) { |
|
|
|
ptr[i] = desc.layout[i]; |
|
|
|
for (size_t i = 0; i < shp.ndim; ++i) { |
|
|
|
ptr[i] = shp[i]; |
|
|
|
} |
|
|
|
}else{ |
|
|
|
int32_t axis = op_def.axis; |
|
|
|
if (axis < 0) { |
|
|
|
axis += desc.layout.ndim; |
|
|
|
axis += shp.ndim; |
|
|
|
} |
|
|
|
mgb_assert(axis >= 0 && axis < (int32_t)desc.layout.ndim); |
|
|
|
mgb_assert(axis >= 0 && axis < (int32_t)shp.ndim); |
|
|
|
value = DeviceTensorND(CompNode::default_cpu(), {1}, dtype::Int32()); |
|
|
|
auto* ptr = value.ptr<dt_int32>(); |
|
|
|
ptr[0] = desc.layout[axis]; |
|
|
|
ptr[0] = shp[axis]; |
|
|
|
} |
|
|
|
return {{{value.layout(), desc.comp_node, std::move(value)}}, true}; |
|
|
|
} |
|
|
|