|
|
@@ -13,6 +13,8 @@ |
|
|
|
|
|
|
|
#include "megbrain/imperative/op_def.h" |
|
|
|
|
|
|
|
#include "megbrain/utils/hash.h" |
|
|
|
|
|
|
|
namespace mgb::imperative { |
|
|
|
|
|
|
|
class GetVarShape : public OpDefImplBase<GetVarShape> { |
|
|
@@ -41,6 +43,33 @@ public: |
|
|
|
|
|
|
|
std::vector<dt_int32> offsets; |
|
|
|
std::vector<std::vector<size_t>> shapes; |
|
|
|
|
|
|
|
size_t hash() const override { |
|
|
|
XXHash builder; |
|
|
|
for (auto&& offset : offsets) { |
|
|
|
builder.update(&offset, sizeof(offset)); |
|
|
|
} |
|
|
|
auto&& offset_cnt = offsets.size(); |
|
|
|
builder.update(&offset_cnt, sizeof(offset_cnt)); |
|
|
|
for (auto&& shape : shapes) { |
|
|
|
for (auto&& dim_len : shape) { |
|
|
|
builder.update(&dim_len, sizeof(dim_len)); |
|
|
|
} |
|
|
|
auto&& dim_cnt = shape.size(); |
|
|
|
builder.update(&dim_cnt, sizeof(dim_cnt)); |
|
|
|
} |
|
|
|
auto&& shape_cnt = shapes.size(); |
|
|
|
builder.update(&shape_cnt, sizeof(shape_cnt)); |
|
|
|
return builder.digest(); |
|
|
|
} |
|
|
|
|
|
|
|
bool is_same_st(const Hashable& rhs) const override { |
|
|
|
auto* pps = rhs.try_cast_final<ParamPackSplit>(); |
|
|
|
if(pps == nullptr){ |
|
|
|
return false; |
|
|
|
} |
|
|
|
return offsets == pps->offsets && shapes == pps->shapes; |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
class ParamPackConcat : public OpDefImplBase<ParamPackConcat> { |
|
|
@@ -53,6 +82,24 @@ public: |
|
|
|
: offsets(offsets_) {} |
|
|
|
|
|
|
|
std::vector<dt_int32> offsets; |
|
|
|
|
|
|
|
size_t hash() const override { |
|
|
|
XXHash builder; |
|
|
|
for (auto&& offset : offsets) { |
|
|
|
builder.update(&offset, sizeof(offset)); |
|
|
|
} |
|
|
|
auto&& offset_cnt = offsets.size(); |
|
|
|
builder.update(&offset_cnt, sizeof(offset_cnt)); |
|
|
|
return builder.digest(); |
|
|
|
} |
|
|
|
|
|
|
|
bool is_same_st(const Hashable& rhs) const override { |
|
|
|
auto* ppc = rhs.try_cast_final<ParamPackConcat>(); |
|
|
|
if(ppc == nullptr){ |
|
|
|
return false; |
|
|
|
} |
|
|
|
return offsets == ppc->offsets; |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
} // namespace mgb::imperative |