|
@@ -559,25 +559,33 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( |
|
|
} |
|
|
} |
|
|
dest.comp_node = inputs[0].comp_node; |
|
|
dest.comp_node = inputs[0].comp_node; |
|
|
dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def); |
|
|
dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def); |
|
|
return {{dest}, true}; |
|
|
|
|
|
|
|
|
return {{dest}, inputs[0].layout.ndim != 0}; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
template <> |
|
|
template <> |
|
|
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible< |
|
|
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible< |
|
|
ShuffleRNG>(const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { |
|
|
ShuffleRNG>(const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { |
|
|
|
|
|
bool success = inputs[0].layout.ndim != 0; |
|
|
|
|
|
|
|
|
SmallVector<LogicalTensorDesc> dests(2); |
|
|
SmallVector<LogicalTensorDesc> dests(2); |
|
|
dests[0].comp_node = inputs[0].comp_node; |
|
|
dests[0].comp_node = inputs[0].comp_node; |
|
|
dests[0].layout = TensorLayout(inputs[0].layout); |
|
|
dests[0].layout = TensorLayout(inputs[0].layout); |
|
|
dests[0].layout.dtype = inputs[0].layout.dtype; |
|
|
dests[0].layout.dtype = inputs[0].layout.dtype; |
|
|
dests[1].comp_node = inputs[0].comp_node; |
|
|
dests[1].comp_node = inputs[0].comp_node; |
|
|
dests[1].layout = |
|
|
|
|
|
TensorLayout(TensorShape({inputs[0].layout.shape[0]}), dtype::Int32()); |
|
|
|
|
|
return {dests, true}; |
|
|
|
|
|
|
|
|
if (success) { |
|
|
|
|
|
dests[1].layout = |
|
|
|
|
|
TensorLayout(TensorShape({inputs[0].layout.shape[0]}), dtype::Int32()); |
|
|
|
|
|
} else { |
|
|
|
|
|
dests[1].layout = TensorLayout(dtype::Int32()); |
|
|
|
|
|
} |
|
|
|
|
|
return {dests, success}; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
template <> |
|
|
template <> |
|
|
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<Dropout>( |
|
|
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<Dropout>( |
|
|
const OpDef& op, const SmallVector<LogicalTensorDesc>& inputs) { |
|
|
const OpDef& op, const SmallVector<LogicalTensorDesc>& inputs) { |
|
|
|
|
|
bool success = inputs[0].layout.ndim != 0; |
|
|
|
|
|
|
|
|
SmallVector<LogicalTensorDesc> dests(2); |
|
|
SmallVector<LogicalTensorDesc> dests(2); |
|
|
auto cn = inputs[0].comp_node; |
|
|
auto cn = inputs[0].comp_node; |
|
|
dests[0].comp_node = cn; |
|
|
dests[0].comp_node = cn; |
|
@@ -590,8 +598,13 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<Dro |
|
|
inputs[0].layout); |
|
|
inputs[0].layout); |
|
|
}; |
|
|
}; |
|
|
dests[1].comp_node = cn; |
|
|
dests[1].comp_node = cn; |
|
|
dests[1].layout = TensorLayout(TensorShape({get_mask_size()}), dtype::Byte()); |
|
|
|
|
|
return {dests, true}; |
|
|
|
|
|
|
|
|
if (success) { |
|
|
|
|
|
dests[1].layout = TensorLayout(TensorShape({get_mask_size()}), dtype::Byte()); |
|
|
|
|
|
} else { |
|
|
|
|
|
dests[1].layout = TensorLayout(dtype::Byte()); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
return {dests, success}; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
template <typename Op> |
|
|
template <typename Op> |
|
|