|
|
@@ -5,6 +5,8 @@ |
|
|
|
#include "megbrain/imperative/ops/autogen.h" |
|
|
|
#include "megbrain/imperative/ops/utility.h" |
|
|
|
|
|
|
|
#include "megbrain/imperative/utils/helper.h" |
|
|
|
|
|
|
|
namespace mgb { |
|
|
|
namespace imperative { |
|
|
|
|
|
|
@@ -664,9 +666,9 @@ ValueRefList FormatTransformation::apply_transformation( |
|
|
|
if (inp_ref) { |
|
|
|
return {FormatValue::make(inp_ref->format())}; |
|
|
|
} else { |
|
|
|
mgb_log_warn( |
|
|
|
MGE_CALL_ONCE(mgb_log_warn( |
|
|
|
"Not FormattedTensorValue input for GetFormat op: %s, %s", |
|
|
|
op.to_string().c_str(), inputs[0].to_string().c_str()); |
|
|
|
op.to_string().c_str(), inputs[0].to_string().c_str())); |
|
|
|
return {FormatValue::make(FT::DEFAULT)}; |
|
|
|
} |
|
|
|
} else if (auto* _op = op.as<SetFormat>()) { |
|
|
@@ -700,9 +702,9 @@ ValueRefList FormatTransformation::apply_transformation( |
|
|
|
// make params(GradValue) as FormattedTensor |
|
|
|
return wrap_outputs(outputs, format); |
|
|
|
} else { |
|
|
|
mgb_log_warn( |
|
|
|
MGE_CALL_ONCE(mgb_log_warn( |
|
|
|
"Not FormattedTensorValue input for AttachGrad op: %s, %s", |
|
|
|
op.to_string().c_str(), inputs[0].to_string().c_str()); |
|
|
|
op.to_string().c_str(), inputs[0].to_string().c_str())); |
|
|
|
return imperative::apply(op, inputs); |
|
|
|
} |
|
|
|
} else if (auto* set_grad = op.as<SetGrad>()) { |
|
|
@@ -723,9 +725,9 @@ ValueRefList FormatTransformation::apply_transformation( |
|
|
|
wrapped_outputs[i] = |
|
|
|
m_value_type.make(outputs[i], output_ref->format()); |
|
|
|
} else { |
|
|
|
mgb_log_warn( |
|
|
|
MGE_CALL_ONCE(mgb_log_warn( |
|
|
|
"Not FormattedTensorValue outputs for SetGrad op: %s, %s", |
|
|
|
op.to_string().c_str(), inputs_[i].to_string().c_str()); |
|
|
|
op.to_string().c_str(), inputs_[i].to_string().c_str())); |
|
|
|
wrapped_outputs[i] = m_value_type.make(outputs[i], FT::DEFAULT); |
|
|
|
} |
|
|
|
} |
|
|
|