GitOrigin-RevId: 5b036c2c5a
tags/v1.9.0
@@ -111,7 +111,6 @@ def test_xornet_trace_dump():
_, loss = val_fun(data, label)
loss = loss.numpy()
val_loss.append((step, loss))
print("Step: {} loss={}".format(step, loss))
opt.step()
test_data = np.array(
@@ -89,8 +89,7 @@ def test_subgraph(device, batch_size, channels, use_trace, symbolic, gopt_level,
return megengine.tensor(np.random.random(shape), dtype=dtype, device=device)
# skip this test because could not do several reduce sequentially with opr cache
if device == "cpux":
return
return
# test shape change
for image_shape in [(223, 223), (10, 20)]:
@@ -718,7 +718,6 @@ def test_assert_equal():
inp2 = g.make_h2d(dtype=np.float32, device="xpux")
op = builtin.AssertEqual(maxerr=1e-5)
out = G.apply_normal_varnode(op, inp1._node, inp2._node)[0]
print(out)
g.compile(out)
file = io.BytesIO()
out_model = G.dump_graph([out])
@@ -51,7 +51,6 @@ def test_profiler(format, trace_mode):
with Profiler(profile_prefix, format=format):
infer()
print(profile_path)
assert os.path.exists(profile_path), "profiling results not found"
if format == "chrome_timeline.json":
@@ -49,6 +49,7 @@ struct ApplyOp {
std::shared_ptr<OpDef> op;
SmallVector<TensorInfo*> inputs;
SmallVector<TensorInfo*> outputs;
bool validated = false;
template <typename TFunctor>
void get_props(TFunctor&& functor) const {
@@ -280,7 +280,8 @@ void ChannelImpl::dispatch_default_cpu(
input_tensors.push_back(Tensor::make(
input_tensornd, HostTensorND::make_proxy(input_tensornd)));
}
auto output_tensors = OpDef::apply_on_physical_tensor(*op, input_tensors);
auto output_tensors = OpDef::apply_on_physical_tensor(
*op, input_tensors, output_descs, validated);
for (size_t i = 0; i < output_tensors.size(); ++i) {
output_tensornds[i].copy_from_fixlayout(output_tensors[i]->dev_tensor());
}
@@ -324,6 +325,7 @@ void ChannelImpl::dispatch_kernel(
MGB_RECORD_EVENT(ShapeInferEvent, validated);
ApplyOp cmd{Profiler::next_id(), std::move(op)};
cmd.validated = validated;
cmd.inputs = std::move(input_infos);
for (int i = 0; i < output_descs.size(); ++i) {
auto&& desc = output_descs[i];
@@ -703,14 +705,16 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
auto_evict(0);
}
auto apply_on_physical_tensor =
[&](auto&& self, const OpDef& def,
SmallVector<TensorPtr> inputs) -> SmallVector<TensorPtr> {
[&](auto&& self, const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<LogicalTensorDesc>& output_descs,
const bool& validated) -> SmallVector<TensorPtr> {
auto apply_functor = [&](std::shared_ptr<OpDef> op,
SmallVector<TensorPtr> inputs,
size_t nr_outputs) -> SmallVector<TensorPtr> {
auto opname = op->trait()->make_name(*op);
imperative_log_profile_begin(opname.c_str());
auto outputs = self(self, *op, inputs);
// do not use inferred output_desc in subgraph
auto outputs = self(self, *op, inputs, output_descs, false);
imperative_log_profile_end(opname.c_str());
return outputs;
};
@@ -726,7 +730,7 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
inputs, apply_functor, const_functor);
return outputs;
}
return OpDef::apply_on_physical_tensor(def, inputs);
return OpDef::apply_on_physical_tensor(def, inputs, output_descs, validated);
};
MGB_RECORD_EVENT(OpExecuteEvent, apply_id, {}, reason);
// Begin profiling operator
@@ -757,8 +761,13 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
Timer::record_device(device));
}
// Apply op
SmallVector<LogicalTensorDesc> output_descs;
for (auto i : cmd.outputs) {
output_descs.push_back(i->desc);
}
// Here std::move is REQUIRED for removing duplicated references.
auto outputs = apply_on_physical_tensor(apply_on_physical_tensor, *cmd.op, inputs);
auto outputs = apply_on_physical_tensor(
apply_on_physical_tensor, *cmd.op, inputs, output_descs, cmd.validated);
// After execute
for (auto&& [device, kernel_id] : kernels) {
MGB_RECORD_EVENT_IF(
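Taken together, the ChannelImpl hunks above thread the shape-inference result through the interpreter: dispatch_kernel records the validated flag on the ApplyOp command, and do_apply_op later rebuilds the output descriptors from the command's output TensorInfos before invoking the op. A condensed recap of that flow, using only members shown in these hunks (a sketch, not new code in the commit):

// at dispatch time: remember whether fallible shape inference fully succeeded
ApplyOp cmd{Profiler::next_id(), std::move(op)};
cmd.validated = validated;
cmd.inputs = std::move(input_infos);

// at execution time: recover the descriptors recorded on the command's outputs
SmallVector<LogicalTensorDesc> output_descs;
for (auto i : cmd.outputs) {
    output_descs.push_back(i->desc);
}
auto outputs = apply_on_physical_tensor(
        apply_on_physical_tensor, *cmd.op, inputs, output_descs, cmd.validated);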
@@ -39,8 +39,10 @@ DispatchMode OpDef::decide_dispatch_mode(
}
SmallVector<TensorPtr> OpDef::apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs) {
return def.trait()->apply_on_physical_tensor(def, std::move(inputs));
const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
return def.trait()->apply_on_physical_tensor(
def, std::move(inputs), output_descs, validated);
}
void OpDef::apply_on_device_tensornd(
const OpDef& def, const SmallVector<DeviceTensorND>& inputs,
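The hunk above is the core interface change: OpDef::apply_on_physical_tensor now receives the output descriptors and the validated flag produced by shape inference instead of re-deriving them per op. A minimal caller-side sketch of the new convention, mirroring the updated test helpers later in this diff (def and inputs are placeholder names):

SmallVector<LogicalTensorDesc> input_descs;
for (auto&& i : inputs) {
    // describe each input by layout and device; nothing is executed here
    input_descs.push_back({i->layout(), i->comp_node()});
}
// fallible inference: validated is true only when every output layout is fully known
auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*def, input_descs);
auto outputs = OpDef::apply_on_physical_tensor(*def, inputs, output_descs, validated);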
@@ -51,7 +51,6 @@ bool valid_broadcast(const TensorShape& src_shape, const TensorShape& tar_shape)
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
def.cast_final_safe<Broadcast>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
auto&& src = inputs[0];
@@ -82,11 +81,16 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto& input = inputs[0];
TensorShape target_shape;
cg::copy_tensor_value_to_shape(
target_shape, inputs[1]->get_value().proxy_to_default_cpu());
if (validated) {
target_shape = output_descs[0].layout;
} else {
cg::copy_tensor_value_to_shape(
target_shape, inputs[1]->get_value().proxy_to_default_cpu());
}
TensorPtr output = Tensor::make(
TensorLayout(target_shape, input->dtype()), input->comp_node());
if (output->layout().is_empty()) {
@@ -171,7 +175,8 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& op_def = def.cast_final_safe<Reshape>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp);
@@ -179,6 +184,10 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
auto&& tshp_nd = inputs[1];
auto slayout = src->layout();
if (validated) {
return {Tensor::make(src->blob(), 0, output_descs[0].layout)};
}
TensorShape tshp;
cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu());
if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) {
@@ -186,9 +195,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
tshp[op_def.axis] = 1;
tshp[op_def.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems();
}
TensorLayout tlayout = slayout.reshape(tshp);
// memory forward
return {Tensor::make(src->blob(), 0, tlayout)};
return {Tensor::make(src->blob(), 0, slayout.reshape(tshp))};
}
OP_TRAIT_REG(Reshape, Reshape)
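Broadcast and Reshape above show the pattern the per-op implementations follow: when validated is true, the layout already computed by infer_output_attrs_fallible is reused (and the input blob is memory-forwarded where possible); otherwise the shape is still derived from the runtime value of the shape tensor. A condensed, hypothetical op body written against the new trait signature (illustrative only, not part of this commit):

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& src = inputs[0];
    TensorLayout out_layout;
    if (validated) {
        // shape inference already succeeded; trust the precomputed descriptor
        out_layout = output_descs[0].layout;
    } else {
        // fall back to reading the shape tensor on the host
        TensorShape tshp;
        cg::copy_tensor_value_to_shape(
                tshp, inputs[1]->get_value().proxy_to_default_cpu());
        out_layout = src->layout().reshape(tshp);
    }
    // memory forward: reuse the input blob instead of allocating a new output
    return {Tensor::make(src->blob(), 0, out_layout)};
}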
@@ -33,9 +33,8 @@ cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& in
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
auto&& opr = def.cast_final_safe<CondTake>();
mgb_assert(opr.same_type<CondTake>());
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
mgb_assert(inputs.size() == 2, "CondTake take 2 inputs, got %lu", inputs.size());
auto&& inp = inputs[0];
@@ -196,16 +196,14 @@ void apply_on_device_tensornd(
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
auto&& op = static_cast<const CustomOpDef&>(def);
auto [output_descs, success] = op.infer_output_attrs(inputs);
mgb_assert(success == true, "infer output attributes fall\n");
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
mgb_assert(validated == true, "infer output attributes fall\n");
SmallVector<TensorPtr> outputs(output_descs.size());
for (size_t i = 0; i < outputs.size(); ++i) {
auto& output = outputs[i];
auto& output_desc = output_descs[i];
output = Tensor::make(output_desc.layout, output_desc.comp_node);
output = Tensor::make(output_descs[i].layout, output_descs[i].comp_node);
}
SmallVector<DeviceTensorND> inp_tensornds(inputs.size());
@@ -112,17 +112,14 @@ void apply_on_device_tensornd(
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
auto&& op_def = def.cast_final_safe<Elemwise>();
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
SmallVector<DeviceTensorND> inp_tensornds(inputs.size());
TensorShapeArray inp_shapes(inputs.size());
for (unsigned i = 0; i < inputs.size(); ++i) {
inp_tensornds[i] = inputs[i]->dev_tensor();
inp_shapes[i] = inputs[i]->layout();
}
TensorShape shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag(
inp_tensornds[0].comp_node(), {shape, inp_tensornds[0].layout().dtype});
inp_tensornds[0].comp_node(), output_descs[0].layout);
SmallVector<DeviceTensorND> oup_tensornds = {out};
apply_on_device_tensornd(def, inp_tensornds, &oup_tensornds);
return {Tensor::make(oup_tensornds[0])};
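CustomOpDef and Elemwise above show the other recurring adaptation: outputs are allocated straight from the descriptors handed in by the dispatcher instead of from a locally recomputed shape (for Elemwise, the workspace layout is now output_descs[0].layout). The allocation step, condensed from the CustomOpDef hunk:

SmallVector<TensorPtr> outputs(output_descs.size());
for (size_t i = 0; i < outputs.size(); ++i) {
    // layout and device were fixed during shape inference
    outputs[i] = Tensor::make(output_descs[i].layout, output_descs[i].comp_node);
}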
@@ -221,7 +218,8 @@ cg::OperatorNodeBase* apply_inplace_add_on_var_node(
}
SmallVector<TensorPtr> apply_inplace_add_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
mgb_assert(
inputs[0]->blob().use_count() == 1 && inputs[0]->blob()->storage().unique(),
"This inplace modification may change the elements of other tensors. "
@@ -24,7 +24,8 @@ SymbolVarArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
size_t size = inputs.size();
auto&& op = def.cast_final_safe<CheckNonFinite>();
SmallVector<TensorPtr> outputs(size + 1);
@@ -63,18 +64,6 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
dests[size].layout = TensorLayout(TensorShape({1}), dtype::Int32());
return {dests, true};
}
SmallVector<LogicalTensorDesc> infer_output_attrs(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
size_t size = inputs.size();
SmallVector<LogicalTensorDesc> dests(size + 1);
for (size_t i = 0; i < size; ++i) {
dests[i].comp_node = inputs[i]->comp_node();
dests[i].layout = inputs[i]->layout();
}
dests[size].comp_node = inputs[0]->comp_node();
dests[size].layout = TensorLayout(TensorShape({1}), dtype::Int32());
return dests;
}
OP_TRAIT_REG(CheckNonFinite, CheckNonFinite)
.apply_on_var_node(apply_on_var_node)
@@ -51,11 +51,13 @@ bool memory_forward_success(const OpDef& def, SmallVector<TensorPtr> inputs) {
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
if (memory_forward_success(def, inputs)) {
return {Tensor::make(inputs[0]->blob(), 0, inputs[0]->layout())};
}
return proxy_graph_detail::apply_on_physical_tensor(def, inputs);
return proxy_graph_detail::apply_on_physical_tensor(
def, inputs, output_descs, validated);
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
@@ -419,8 +419,7 @@ _INST_RNG_MAKER(2)
template <typename Op>
void exec(
const OpDef& op, const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs,
const SmallVector<TensorPtr>& workspace) {
const SmallVector<TensorPtr>& outputs) {
auto&& rng = op.cast_final_safe<Op>();
auto dest = outputs[0];
@@ -451,82 +450,68 @@ void exec(
}
template <typename Op>
SmallVector<LogicalTensorDesc> infer_output_attrs(
SmallVector<CompNode> infer_output_cns(
const OpDef& op, const SmallVector<TensorPtr>& inputs) {
LogicalTensorDesc dest;
CompNode cn;
auto&& rng = op.cast_final_safe<Op>();
auto handle = rng.handle;
if (handle) {
dest.comp_node = RNGDnnOpManager::get_comp_node(handle);
cn = RNGDnnOpManager::get_comp_node(handle);
} else {
dest.comp_node = inputs[0]->comp_node();
cn = inputs[0]->comp_node();
}
constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0;
if (!rng_with_shape) {
for (int i = 0; i < inputs.size(); ++i) {
mgb_assert(
inputs[i]->comp_node() == dest.comp_node,
inputs[i]->comp_node() == cn,
"%s expects the device of inputs[%d] to be same as the device of "
"handle; "
"got %s and %s actually",
rng.dyn_typeinfo()->name, i,
inputs[i]->comp_node().to_string().c_str(),
dest.comp_node.to_string().c_str());
inputs[i]->comp_node().to_string().c_str(), cn.to_string().c_str());
}
}
dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], rng);
return {dest};
return {cn};
}
template <>
SmallVector<LogicalTensorDesc> infer_output_attrs<ShuffleRNG>(
SmallVector<CompNode> infer_output_cns<ShuffleRNG>(
const OpDef& op, const SmallVector<TensorPtr>& inputs) {
SmallVector<LogicalTensorDesc> dests(2);
SmallVector<CompNode> cns(2);
auto&& rng = op.cast_final_safe<ShuffleRNG>();
auto handle = rng.handle;
if (handle) {
dests[0].comp_node = RNGDnnOpManager::get_comp_node(handle);
dests[1].comp_node = RNGDnnOpManager::get_comp_node(handle);
cns[0] = RNGDnnOpManager::get_comp_node(handle);
cns[1] = RNGDnnOpManager::get_comp_node(handle);
} else {
dests[0].comp_node = inputs[0]->comp_node();
dests[1].comp_node = inputs[0]->comp_node();
cns[0] = inputs[0]->comp_node();
cns[1] = inputs[0]->comp_node();
}
dests[0].layout = TensorLayout(inputs[0]->layout());
dests[0].layout.dtype = inputs[0]->layout().dtype;
dests[1].layout =
TensorLayout(TensorShape({inputs[0]->layout()[0]}), dtype::Int32());
return dests;
return cns;
}
template <>
SmallVector<LogicalTensorDesc> infer_output_attrs<Dropout>(
SmallVector<CompNode> infer_output_cns<Dropout>(
const OpDef& op, const SmallVector<TensorPtr>& inputs) {
SmallVector<LogicalTensorDesc> dests(2);
SmallVector<CompNode> cns(2);
auto&& cn = inputs[0]->comp_node();
dests[0].comp_node = cn;
dests[0].layout = TensorLayout(inputs[0]->layout());
dests[0].layout.dtype = inputs[0]->layout().dtype;
auto get_mask_size = [&]() -> size_t {
auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
return dnn_handle->create_operator<megdnn::Dropout>()->get_mask_size_in_bytes(
inputs[0]->layout());
};
dests[1].comp_node = cn;
dests[1].layout = TensorLayout(TensorShape({get_mask_size()}), dtype::Byte());
return dests;
cns[0] = cn;
cns[1] = cn;
return cns;
}
template <typename Op>
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
SmallVector<TensorPtr> outputs;
SmallVector<LogicalTensorDesc> desc = infer_output_attrs<Op>(def, inputs);
for (auto&& i : desc) {
outputs.push_back(Tensor::make(i.layout, i.comp_node));
SmallVector<CompNode> cns = infer_output_cns<Op>(def, inputs);
for (size_t i = 0; i < cns.size(); i++) {
outputs.push_back(Tensor::make(output_descs[i].layout, cns[i]));
}
exec<Op>(def, inputs, outputs, {});
exec<Op>(def, inputs, outputs);
return outputs;
}
@@ -99,7 +99,8 @@ HostTensorND get_var_shape_host_tensor(
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
return {Tensor::make(std::move(get_var_shape_host_tensor(def, inputs)))};
}
@@ -180,7 +181,8 @@ cg::OperatorNodeBase* param_pack_split_apply_on_var_node(
}
SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& param = def.cast_final_safe<ParamPackSplit>();
mgb_assert(
inputs.size() == 1, "ParamPackSplit take 1 input, got %lu", inputs.size());
@@ -217,7 +219,8 @@ cg::OperatorNodeBase* param_pack_concat_apply_on_var_node(
}
SmallVector<TensorPtr> param_pack_concat_apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
def.cast_final_safe<ParamPackConcat>();
mgb_assert(inputs.size() > 1, "param_pack should have at least one input");
auto comp_node = inputs.front()->comp_node();
@@ -62,25 +62,10 @@ OP_TRAIT_REG(FastpathCopy, FastpathCopy)
namespace {
namespace shape_infer {
auto apply_on_physical_tensor(const OpDef& def, const SmallVector<TensorPtr>& inputs) {
auto& op = def.cast_final_safe<ShapeInfer>();
size_t nr_inputs = inputs.size();
mgb_assert(nr_inputs > 0, "no inputs for ShapeInfer");
SmallVector<LogicalTensorDesc> input_descs;
for (size_t i = 0; i < nr_inputs; ++i) {
auto input = inputs[i]->get_value();
TensorLayout layout;
layout.ndim = input.shape(0);
for (size_t i = 0; i < layout.ndim; ++i) {
layout[i] = input.ptr<int32_t>()[i];
}
layout.dtype = op.dtypes[i];
layout.init_contiguous_stride();
input_descs.push_back({layout, op.devices[i]});
}
auto [output_descs, valid] =
OpDef::infer_output_attrs_fallible(*op.op, input_descs);
mgb_assert(valid, "shape inference incomplete");
auto apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
mgb_assert(validated, "shape inference incomplete");
SmallVector<TensorPtr> outputs;
for (auto&& output_desc : output_descs) {
HostTensorND shape_tensor{
@@ -189,7 +174,9 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
return opr::Identity::make(inputs[0], config);
}
auto apply_on_physical_tensor(const OpDef& def, const SmallVector<TensorPtr>& inputs) {
auto apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
return SmallVector<TensorPtr>{inputs[0]};
}
OP_TRAIT_REG(Identity, Identity)
@@ -588,7 +575,9 @@ ComputingGraphHolder<Kind>& get_computing_graph(
return *cg_holder_queue.back();
}
auto apply_on_physical_tensor(const OpDef& def, const SmallVector<TensorPtr>& inputs) {
auto apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
SmallVector<LogicalTensorDesc> input_descs;
for (auto&& input : inputs) {
input_descs.push_back({input->layout(), input->comp_node()});
@@ -451,7 +451,14 @@ public:
}
} else {
if (dep.type == cg::static_infer::DepType::SHAPE) {
if (auto* val = infer(output_data[dep.idx].shape_infer, sync)) {
// use opr->output()->shape() when it is already available;
// otherwise infer it
if (!owner.m_opr->output(dep.idx)->shape().is_empty()) {
target.inp_val.val[i].m_shape =
&owner.m_opr->output(dep.idx)->shape();
} else if (
auto* val =
infer(output_data[dep.idx].shape_infer, sync)) {
target.inp_val.val[i].m_shape = val;
} else
return false;
@@ -798,7 +805,8 @@ public:
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs) {
const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<LogicalTensorDesc>& desc, const bool& validated) {
auto raw_inputs = to_raw_ptr_array(inputs);
auto& minigraph = get_cached_minigraph(def, raw_inputs);
auto _ = scoped_attach(&minigraph);
@@ -811,10 +819,12 @@ public:
// LogicalTensorDesc for minigraph.opr()->usable_output()
SmallVector<LogicalTensorDesc> output_descs;
for (size_t i = 0; i < minigraph.opr()->output().size(); ++i) {
auto* var = minigraph.opr()->output()[i];
auto* shape = sess.infer(sess.output_data[i].shape_infer, true);
mgb_assert(shape);
minigraph.opr()->output()[i]->shape(*shape);
var->shape(*shape);
}
for (size_t i = 0; i < minigraph.output_size(); ++i) {
auto* ovar = minigraph.output_var(i);
mgb_assert(ovar->dtype().valid() && ovar->comp_node().valid());
@@ -829,6 +839,7 @@ public:
outputs[i] =
Tensor::make(output_descs[i].layout, output_descs[i].comp_node);
}
auto raw_outputs = to_raw_ptr_array(outputs);
CompNode::UnorderedSet used_cns;
for (auto&& out : raw_outputs) {
@@ -843,6 +854,7 @@ public:
}
}
}
// some oprs (e.g. Subtensor) may invoke infer_value during execution,
// so we need to create the inference session here
minigraph.execute(raw_inputs, raw_outputs, m_env);
@@ -853,6 +865,7 @@ public:
}
}
}
return outputs;
}
};
@@ -27,9 +27,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs) {
auto ret =
proxy_graph::ProxyGraphTypeI::inst().apply_on_physical_tensor(def, inputs);
const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto ret = proxy_graph::ProxyGraphTypeI::inst().apply_on_physical_tensor(
def, inputs, output_descs, validated);
return ret;
}
@@ -62,15 +62,19 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs) {
const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
SmallVector<LogicalTensorDesc> input_descs;
for (auto&& input : inputs) {
input_descs.push_back({input->layout(), input->comp_node()});
}
auto subgraph = def.trait()->make_forward_graph(def, input_descs);
auto apply_functor = [](const std::shared_ptr<OpDef>& op,
const SmallVector<TensorPtr>& inputs, size_t nr_outputs) {
return OpDef::apply_on_physical_tensor(*op, inputs);
auto apply_functor = [&output_descs](
const std::shared_ptr<OpDef>& op,
const SmallVector<TensorPtr>& inputs,
size_t nr_outputs) {
// do not use inferred output_desc in subgraph
return OpDef::apply_on_physical_tensor(*op, inputs, output_descs, false);
};
auto const_functor = [&](const TensorPtr& value) { return value; };
auto outputs = subgraph.apply<TensorPtr>(inputs, apply_functor, const_functor);
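Note that both do_apply_op and this subgraph path pass validated = false when recursing into an expanded subgraph: the precomputed output_descs describe the outer op's outputs, not those of the inner ops, so every inner op must derive its own layouts. A sketch of a recursive apply functor obeying that rule (names are illustrative, not part of this commit):

auto apply_functor = [&](const std::shared_ptr<OpDef>& op,
                         const SmallVector<TensorPtr>& inputs,
                         size_t nr_outputs) {
    // output_descs belong to the enclosing op; never treat them as
    // validated layouts for the ops inside the subgraph
    return OpDef::apply_on_physical_tensor(*op, inputs, output_descs, false);
};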
@@ -77,7 +77,9 @@ void TensorSanityCheck::enable() {
std::move(trait.apply_on_physical_tensor));
trait.apply_on_physical_tensor = ApplyOnPhysicalTensor(
[this, backup = backup.get()](
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs,
const bool& validated) {
for (auto&& i : inputs) {
if (!m_checker->check(i)) {
mgb_throw(
@@ -86,7 +88,7 @@ void TensorSanityCheck::enable() {
print_op(def).c_str());
}
}
auto output = (*backup)(def, inputs);
auto output = (*backup)(def, inputs, output_descs, validated);
for (auto&& i : output) {
mgb_assert(m_checker->check(i));
}
@@ -51,7 +51,8 @@ public:
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs);
static SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs);
const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated);
/*!
* \brief Call the corresponding dnn op to calculate results. Output
@@ -18,7 +18,8 @@ namespace imperative {
namespace proxy_graph_detail {
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs);
const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated);
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs);
@@ -18,7 +18,8 @@ namespace imperative {
namespace subgraph_detail {
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs);
const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated);
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs);
@@ -81,7 +81,13 @@ T prepare_optimized_backward_inputs(
SmallVector<TensorPtr> apply_shared_on_physical_tensor(
std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs, size_t nr_outputs) {
return OpDef::apply_on_physical_tensor(*def, inputs);
SmallVector<LogicalTensorDesc> input_descs;
for (auto&& i : inputs) {
input_descs.push_back({i->layout(), i->comp_node()});
}
auto [output_descs, validated] =
OpDef::infer_output_attrs_fallible(*def, input_descs);
return OpDef::apply_on_physical_tensor(*def, inputs, output_descs, validated);
}
TEST(TestImperative, BackwardGraphBasic) {
@@ -106,7 +112,13 @@ TEST(TestImperative, BackwardGraphBasic) {
auto&& save_for_backward = result.input_mask;
auto&& input_has_grad = result.output_mask;
auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs);
for (size_t i = 0; i < inputs.size(); i++) {
input_descs[i].value = inputs[i]->dev_tensor();
}
auto [output_descs, validated] =
OpDef::infer_output_attrs_fallible(*attr, input_descs);
auto outputs =
OpDef::apply_on_physical_tensor(*attr, inputs, output_descs, validated);
inputs.push_back(outputs[0]);
hvs.push_back(*gen({42}));
inputs.push_back(Tensor::make(hvs.back()));
@@ -161,7 +173,10 @@ TEST(TestImperative, BackwardGraphIdentity) {
auto&& save_for_backward = result.input_mask;
auto&& input_has_grad = result.output_mask;
auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs);
auto [output_descs, validated] =
OpDef::infer_output_attrs_fallible(*attr, input_descs);
auto outputs =
OpDef::apply_on_physical_tensor(*attr, inputs, output_descs, validated);
inputs.push_back(outputs[0]);
inputs.push_back(dc);
mgb_assert(save_for_backward.size() == inputs.size());
@@ -238,7 +253,13 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
auto a_tn = Tensor::make(*a_hv);
auto b_tn = Tensor::make(*b_hv);
auto dc_tn = Tensor::make(*dc_hv);
auto c_tn = OpDef::apply_on_physical_tensor(*op, {a_tn, b_tn})[0];
SmallVector<LogicalTensorDesc> input_descs;
input_descs.push_back({a_tn->layout(), a_tn->comp_node(), a_tn->dev_tensor()});
input_descs.push_back({b_tn->layout(), b_tn->comp_node(), b_tn->dev_tensor()});
auto [output_descs, validated] =
OpDef::infer_output_attrs_fallible(*op, input_descs);
auto c_tn = OpDef::apply_on_physical_tensor(
*op, {a_tn, b_tn}, output_descs, validated)[0];
auto backward_graph_inputs = prepare_backward_graph_inputs<SmallVector<TensorPtr>>(
bg, {a_tn, b_tn}, {c_tn}, {dc_tn});
@@ -35,7 +35,8 @@ TEST(TestImperative, AllReduceBasic) {
megdnn::param::CollectiveComm::Mode::ALL_REDUCE_SUM, "all_reduce", 2,
idx, idx == 0, false, server_addr, port, dtype::Float32(), "nccl", "");
auto inp = Tensor::make(*hnd);
auto oup = OpDef::apply_on_physical_tensor(*def, {inp});
SmallVector<LogicalTensorDesc> output_descs;
auto oup = OpDef::apply_on_physical_tensor(*def, {inp}, output_descs, false);
HostTensorND host_v;
host_v.copy_from(oup[0]->dev_tensor()).sync();
MGB_ASSERT_TENSOR_NEAR(*expect, host_v, 1e-6);
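The AllReduce test above, and the OprChecker, defrag, and IORemote tests that follow, take a shortcut: where the op under test does not consume the descriptors (these calls end up in graph-backed execution paths that recompute layouts themselves), they pass an empty descriptor list and validated = false. This is an observation about the tests in this diff, not a documented contract:

SmallVector<LogicalTensorDesc> output_descs;  // intentionally left empty
auto oup = OpDef::apply_on_physical_tensor(*def, {inp}, output_descs, false);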
@@ -135,7 +135,9 @@ void OprChecker::run(std::vector<InputSpec> inp_keys, std::set<size_t> bypass) {
imp_physical_inp[i] = Tensor::make(host_inp[i]);
}
auto imp_oup = OpDef::apply_on_physical_tensor(*m_op, imp_physical_inp);
SmallVector<LogicalTensorDesc> output_descs;
auto imp_oup = OpDef::apply_on_physical_tensor(
*m_op, imp_physical_inp, output_descs, false);
mgb_assert(imp_oup.size() == nr_oups);
// check input not modified
@@ -122,7 +122,10 @@ void run_graph(size_t mem_reserved) {
Param param{Param::Mode::MUL};
attr.param.write_pod(param);
auto out = OpDef::apply_on_physical_tensor(*op, {ptr_a[1], ptr_a[99]}).at(0);
SmallVector<LogicalTensorDesc> output_descs;
auto out = OpDef::apply_on_physical_tensor(
*op, {ptr_a[1], ptr_a[99]}, output_descs, false)
.at(0);
// value before defrag
HostTensorND host_out_before;
@@ -36,7 +36,8 @@ TEST(TestImperative, IORemote) {
auto def = imperative::RemoteSend::make(
"io_remote_test", server_addr, port, 1, "nccl");
auto inp = Tensor::make(*hnd);
auto oup = OpDef::apply_on_physical_tensor(*def, {inp});
SmallVector<LogicalTensorDesc> output_descs;
auto oup = OpDef::apply_on_physical_tensor(*def, {inp}, output_descs, false);
};
auto run_recv = [&](std::shared_ptr<HostTensorND> hnd) {
@@ -44,7 +45,8 @@ TEST(TestImperative, IORemote) {
"io_remote_test", server_addr, port, 0, CompNode::load("gpu1"),
std::vector<int32_t>{(int32_t)vector_size}, dtype::Float32(), "nccl");
auto inp = Tensor::make(*hnd);
auto oup = OpDef::apply_on_physical_tensor(*def, {inp});
SmallVector<LogicalTensorDesc> output_descs;
auto oup = OpDef::apply_on_physical_tensor(*def, {inp}, output_descs, false);
HostTensorND host_v;
host_v.copy_from(oup[0]->dev_tensor()).sync();
MGB_ASSERT_TENSOR_NEAR(*expect, host_v, 1e-6);
@@ -25,7 +25,14 @@ void check_rng_basic(Args&&... args) {
DeviceTensorND tshape_dev;
cg::copy_shape_to_tensor_value(tshape_dev, tshape);
SmallVector<TensorPtr> inputs = {Tensor::make(tshape_dev)};
auto outputs = OpDef::apply_on_physical_tensor(*op, inputs);
SmallVector<LogicalTensorDesc> input_descs;
input_descs.push_back(
{inputs[0]->layout(), inputs[0]->comp_node(),
inputs[0]->dev_tensor()});
auto [output_descs, validated] =
OpDef::infer_output_attrs_fallible(*op, input_descs);
auto outputs = OpDef::apply_on_physical_tensor(
*op, inputs, output_descs, validated);
ASSERT_TRUE(outputs[0]->layout().eq_shape(tshape));
ASSERT_TRUE(cn == outputs[0]->comp_node());
// sync before delete handle
@@ -41,7 +48,14 @@ void check_rng_with_input_basic(
const CompNode& cn, const SmallVector<TensorPtr>& inputs, Args&&... args) {
Handle h = new_handle(cn, 123);
auto op = Op::make(std::forward<Args>(args)..., h);
auto outputs = OpDef::apply_on_physical_tensor(*op, inputs);
SmallVector<LogicalTensorDesc> input_descs;
for (auto&& i : inputs) {
input_descs.push_back({i->layout(), i->comp_node(), i->dev_tensor()});
}
auto [output_descs, validated] =
OpDef::infer_output_attrs_fallible(*op, input_descs);
auto outputs =
OpDef::apply_on_physical_tensor(*op, inputs, output_descs, validated);
ASSERT_TRUE(outputs[0]->layout().eq_shape(inputs[0]->shape()));
ASSERT_TRUE(cn == outputs[0]->comp_node());
// sync before delete handle
@@ -142,7 +142,8 @@ public:
const TensorLayout& layout() const { return m_layout; }
MemAllocPlan& layout(const TensorLayout& dest, bool allow_shape_change = false);
MGE_WIN_DECLSPEC_FUC MemAllocPlan& layout(
const TensorLayout& dest, bool allow_shape_change = false);
#if MGB_ENABLE_JSON
MGE_WIN_DECLSPEC_FUC std::shared_ptr<json::Value> to_json() const override;