| SHA1 | Message | Date |
|---|---|---|
| acc2472c41 | !1939 code_sync_0705_inc<br>Merge pull request !1939 from mindspore_ding/code_sync_0705 | 4 years ago |
| f49a21d293 | code_sync_0705_inc | 4 years ago |
| 92b997f9f9 | !1938 code_sync_0705<br>Merge pull request !1938 from mindspore_ding/code_sync_0705 | 4 years ago |
| bac5d1b678 | Merge remote-tracking branch 'origin/r1.3' into code_sync_0705 | 4 years ago |
| cc7175217c | !1928 cherry-pick fix for dynamic shape V1<br>Merge pull request !1928 from 张晓昆/r1.5.0 | 4 years ago |
| 230010b770 | !1933 parallel group<br>Merge pull request !1933 from 陈华/r1.5.0 | 4 years ago |
| 3929578dee | fix parallel group | 4 years ago |
| d0f986ea46 | !1903 fix sc<br>Merge pull request !1903 from 王笑天/r1.5.0 | 4 years ago |
| 0684bd48cf | !1927 fix safe<br>Merge pull request !1927 from wangjiming/r1.5.0 | 4 years ago |
| 2444f46b8d | !1926 set size for dynamic input<br>Merge pull request !1926 from wangzhengjun/set_size_r150 | 4 years ago |
| ca24b76141 | !1922 FindLastBpFromBpNode c78<br>Merge pull request !1922 from 梁昊/r1.5 | 4 years ago |
| 2785670745 | fix printf format | 4 years ago |
| 2d9e3da649 | IsEnterFeedNode | 4 years ago |
| ac1f4eb1c2 | DSP: Switch -> TransData -> Cast -> Exit | 4 years ago |
| 61c203619c | Remove UT dump env | 4 years ago |
| b23075b62f | UT for control flow group | 4 years ago |
| dbc989a4c3 | Clear UpdatePersistTensor Warning for first run | 4 years ago |
| 796513222a | UpdatePersistTensor from ExecutionEngine | 4 years ago |
| 002583e4ef | Fix Set Control flow group for -1 | 4 years ago |
| 65cafdd034 | Replace MemcpyAsyncNodeTask | 4 years ago |
| 64d312ab12 | UT for LaunchKernelCustAicpuSo | 4 years ago |
| ded54e73af | Fix Guard for variable release | 4 years ago |
| 7b1331770a | Fix multi control from one node | 4 years ago |
| c440980918 | Fix BuildPartitionFrame failed | 4 years ago |
| 5da987eb3a | set size for dynamic input | 4 years ago |
| 8c55572e12 | FindLastBpFromBpNode c78 | 4 years ago |
| 656cd3f3d5 | !1873 add copy graph<br>Merge pull request !1873 from HW_KK/add-copy-graph-1.5.0 | 4 years ago |
| e5457d5949 | fix sc + | 4 years ago |
| 062757756c | fix sc | 4 years ago |
| 333bbb700a | !1885 fix mem leak<br>Merge pull request !1885 from HW_KK/scope_guard | 4 years ago |
| 9088d5f696 | !1871 FillKernel c78<br>Merge pull request !1871 from 梁昊/r1.5 | 4 years ago |
| 96b0db9bd5 | !1877 step info<br>Merge pull request !1877 from zhupuxu/stepInfo_r1.5.0 | 4 years ago |
| 1e0a3c0bca | FillKernel c78 | 4 years ago |
| bb2c55fac8 | add copy graph | 4 years ago |
| e51ffe2f54 | fix mem leak | 4 years ago |
| b4b7c454e6 | !1882 code_sync_0626_inc<br>Merge pull request !1882 from mindspore_ding/code_sync_0626 | 4 years ago |
| a8420bd064 | code_sync_0626_inc | 4 years ago |
| f565c42679 | !1880 code_sync_0626<br>Merge pull request !1880 from mindspore_ding/code_sync_0626 | 4 years ago |
| 3a4112a159 | code_sync_0626 | 4 years ago |
| e0ee4afa7d | !1878 update submodule<br>From: @wqtshg Reviewed-by: @xchu42,@ji_chen Signed-off-by: @ji_chen | 4 years ago |
| be6b3a176f | step info<br>Signed-off-by: zhupuxu <zhupuxu@huawei.com> | 4 years ago |
| ee744e6bae | !1872 Create NodeExecute on-demand<br>From: @xchu42 Reviewed-by: @wqtshg,@ji_chen Signed-off-by: @ji_chen | 4 years ago |
| 17370647b4 | !1870 update protobuf to 3.13.0<br>From: @li-lei0106 Reviewed-by: @wqtshg,@ji_chen Signed-off-by: @ji_chen | 4 years ago |
| 0cf2a59463 | update version of protobuf to v3.13.0 | 4 years ago |
| e2cad9c2ec | fixed ad3e707 from https://gitee.com/mindspore/graphengine/pulls/1821 | 4 years ago |
| dd2ba23718 | !1867 add atc_params: check_report for ConvertModelToJson<br>From: @lichun30 Reviewed-by: @xchu42,@ji_chen Signed-off-by: @ji_chen | 4 years ago |
| 4d9fe5505a | !1862 add op_precision_mode option and support op_debug_level = 4<br>From: @lianghuikang Reviewed-by: @wqtshg,@ji_chen Signed-off-by: @ji_chen | 4 years ago |
| 5c51e07d61 | !1864 Fix mem leak and recursive depth protection.<br>From: @zhao_zhixuan Reviewed-by: @xchu42,@ji_chen Signed-off-by: @ji_chen | 4 years ago |
| 6fd9337505 | add atc_params: check_report for ConvertModelToJson | 4 years ago |
| 2a0a6eaf2c | Fix bug. | 4 years ago |
| cd9869c99d | Fix bug. | 4 years ago |
| 122a05cda3 | !1863 fix opt info<br>From: @chen-hua-baker Reviewed-by: @ji_chen,@wqtshg Signed-off-by: @ji_chen | 4 years ago |
| a562b4b6be | Fix mem leak and recursive depth protection. | 4 years ago |
| a431199716 | Fix mem leak and recursive depth protection. | 4 years ago |
| d3bda362d9 | fix opt info | 4 years ago |
| 83dffb39f6 | !1859 skip control flow op when replace node with empty tensor<br>From: @wangzhengjun3 Reviewed-by: @wangxiaotian22,@wqtshg Signed-off-by: @wqtshg | 4 years ago |
| 26fc202953 | !1848 add op_precision_mode option and support op_debug_level = 4<br>From: @lianghuikang Reviewed-by: Signed-off-by: | 4 years ago |
| 49c2eadb7c | skip control flow op when replace node with empty tensor | 4 years ago |
| 3aec2fa1c8 | !1850 fix cmetric<br>From: @wangxiaotian22 Reviewed-by: @xchu42,@ji_chen Signed-off-by: @ji_chen | 4 years ago |
| f410cde8f2 | !1835 move opt to ge_compile<br>From: @chen-hua-baker Reviewed-by: @ji_chen Signed-off-by: @ji_chen | 4 years ago |
| 310959e5d9 | move to ge_compile | 4 years ago |
| 802f671006 | !1826 opt info<br>From: @chen-hua-baker Reviewed-by: @ji_chen Signed-off-by: @ji_chen | 4 years ago |
| bba62ec5f3 | opt_info | 4 years ago |
| 0854139134 | update submodule | 4 years ago |
| 9084063b48 | !1818 Revert 'Pull Request !1806 : add header targets for link'<br>From: @wqtshg Reviewed-by: @xchu42,@zhangxiaokun9,@ji_chen Signed-off-by: @ji_chen | 4 years ago |
| 287d6db86f | !1817 update submodule file<br>From: @wqtshg Reviewed-by: @zhangxiaokun9,@xchu42 Signed-off-by: @ji_chen | 4 years ago |
| a70ab7c91a | Revert 'Pull Request !1806 : add header targets for link' | 4 years ago |
| 65ba89fb34 | update .gitmodules. | 4 years ago |
| 3e510ffc8d | !1815 Revert 'Pull Request !1784 : Create NodeExecute on-demand'<br>Merge pull request !1815 from 储星/revert-merge-1784-master | 4 years ago |
| c2a1076a87 | Revert 'Pull Request !1784 : Create NodeExecute on-demand' | 4 years ago |
| 3e57016ffa | !1806 add header targets for link<br>From: @kevin365 Reviewed-by: @zhangxiaokun9,@wqtshg Signed-off-by: @wqtshg | 4 years ago |
| 2d5ad3a72e | !1808 code_sync_0617_update<br>From: @ding_fei_fei Reviewed-by: @majorzhang,@zhoufeng54,@liucunwei,@liucunwei Signed-off-by: @majorzhang,@liucunwei,@liucunwei | 4 years ago |
| 076e4222ed | !1807 Revert 'Pull Request ls : Adaptation rectification of op_tiling.'<br>From: @zhao_zhixuan Reviewed-by: @xchu42,@wqtshg Signed-off-by: @wqtshg | 4 years ago |
| ec89f88f1f | code_sync_0617_update | 4 years ago |
| 676ce23b55 | Revert 'Pull Request ls : Adaptation rectification of op_tiling.' | 4 years ago |
| 6f130e2290 | add link header targets<br>Signed-off-by: wangkai <wang3113@126.com> | 4 years ago |
| 002f4da441 | !1803 code_sync_0617_inc<br>From: @ding_fei_fei Reviewed-by: @liucunwei,@ljl0711 Signed-off-by: @liucunwei | 4 years ago |
| ac2ce17a9a | code_sync_0617_inc | 4 years ago |
| f71c51e18c | !1798 code_sync_0617<br>Merge pull request !1798 from code_sync_0617 | 4 years ago |
| 7610fa5393 | !1712 Adaptation rectification of op_tiling.<br>From: @zhao_zhixuan Reviewed-by: @xchu42 Signed-off-by: | 4 years ago |
| da8b5170e7 | code_sync_0617_1 | 4 years ago |
| 1e9558e8dd | !1784 Create NodeExecute on-demand<br>From: @xchu42 Reviewed-by: @ji_chen,@wqtshg Signed-off-by: @ji_chen | 4 years ago |
| fd51637c46 | Fix zip bug. | 4 years ago |
| bd1beee90c | Fix zip bug. | 4 years ago |
| caba8dbf6b | !1796 bugfix for restore context<br>From: @yangyongqiang5033 Reviewed-by: @xchu42,@wqtshg Signed-off-by: @wqtshg | 4 years ago |
| b59fbaca6b | !1790 fix sc<br>From: @jiming6 Reviewed-by: @xchu42,@zhangxiaokun9 Signed-off-by: @wqtshg | 4 years ago |
| 9476853d22 | Adaptation rectification of op_tiling. | 4 years ago |
| b0a017d406 | !1785 Optimize performance of single_op executor.<br>From: @zhao_zhixuan Reviewed-by: Signed-off-by: | 4 years ago |
| 535b674d1a | !1791 modify dump config<br>From: @zhou_chao1993 Reviewed-by: @xchu42,@wqtshg Signed-off-by: @wqtshg | 4 years ago |
| c173f92bcc | !1797 Remove reduplicated useless proto<br>From: @zhangxiaokun9 Reviewed-by: @xchu42,@ji_chen Signed-off-by: @ji_chen | 4 years ago |
| 1bed26c72e | Remove reduplicated useless proto | 4 years ago |
| 82d489a5e6 | !1794 update submodule metadef<br>From: @wbigat Reviewed-by: @wqtshg,@ji_chen Signed-off-by: @ji_chen | 4 years ago |
| 246d7e4fd8 | bugfix for restore context | 4 years ago |
| 5bcb04dfb7 | update submodule | 4 years ago |
| 4bc0f6f2af | Fix bug. | 4 years ago |
| b17eafe3db | Fix bug. | 4 years ago |
| 23c8a0d581 | Fix ut. | 4 years ago |
| c4610153e6 | !1787 ge code for 1981<br>From: @zhou_lili Reviewed-by: Signed-off-by: | 4 years ago |
| 6927a8eef3 | modif dump config | 4 years ago |
| 116167dc88 | ge code for 1981 | 4 years ago |
| 58086ab187 | Release mem. | 4 years ago |
| 69da59b6b7 | Fix ut. | 4 years ago |
| d0e83c26a7 | Merge branch 'my_dev4' of https://gitee.com/zhao_zhixuan/graphengine into my_dev4 | 4 years ago |
| d1eba02e1e | Fix ut. | 4 years ago |
| b64048a39f | Init NodeExecutor on demand | 4 years ago |
| 24eedfa3b4 | Fix ut. | 4 years ago |
| 7ce31b2e0e | Fix ut. | 4 years ago |
| 0c2d07eb72 | Fix ut. | 4 years ago |
| e765135b86 | Merge branch 'my_dev4' of https://gitee.com/zhao_zhixuan/graphengine into my_dev4 | 4 years ago |
| 492d36b237 | Fix ut. | 4 years ago |
| 3a54dc6dd8 | Merge https://gitee.com/mindspore/graphengine into my_dev4 | 4 years ago |
| ab7334ed78 | Release context in execute end. | 4 years ago |
| f7370bc074 | !1788 enable optimization<br>From: @wangzhengjun3 Reviewed-by: @wqtshg,@ji_chen Signed-off-by: @ji_chen | 4 years ago |
| 78bbbda571 | !1783 Fix dynamic shape partition<br>From: @zhangxiaokun9 Reviewed-by: @ji_chen Signed-off-by: @ji_chen | 4 years ago |
| 181cd5891b | Release context in execute end. | 4 years ago |
| 367774c5b0 | enable optimization | 4 years ago |
| f578e8fff4 | Fix NodeState for UT | 4 years ago |
| ab65075326 | Add Init to NodeState | 4 years ago |
| 8852766766 | Fix hccl_node_executor_unittest | 4 years ago |
| e85bbe2181 | Fix dynamic shape partition | 4 years ago |
| 5393b67679 | !1774 ut fix<br>From: @zhengyuanhua Reviewed-by: Signed-off-by: | 4 years ago |
| b35412f5ea | Add ut. | 4 years ago |
| 1ab9ae32dc | Add ut. | 4 years ago |
| 13c98395e2 | Add ut. | 4 years ago |
| 4c3c819129 | Optimize performance of single_op executor. | 4 years ago |
| 83421b6c16 | !1775 run_flag<br>From: @dimitri_rose Reviewed-by: Signed-off-by: | 4 years ago |
| 1847efba53 | run_flag | 4 years ago |
| 38bea2e8a6 | !1780 set step profiling<br>From: @zhupuxu Reviewed-by: @ji_chen,@wqtshg Signed-off-by: @wqtshg | 4 years ago |
| 0be66d0dca | set step<br>Signed-off-by: zhupuxu <zhupuxu@huawei.com> | 4 years ago |
| d8ba1fb2c0 | remove graph ut form ge | 4 years ago |
| eee4d7b492 | fix safe | 4 years ago |
| 9e33784220 | Merge branch 'master' of gitee.com:jiming6/graphengine | 4 years ago |
| 2abf8be621 | fix sc | 4 years ago |
| f826880f75 | fix sc | 4 years ago |
| 3442c74c1e | !1748 code_sync_0604_rt_ffts.h<br>From: @ding_fei_fei Reviewed-by: @liucunwei,@majorzhang Signed-off-by: @liucunwei | 4 years ago |
| bc93b5b742 | code_sync_0604 | 4 years ago |
| 94b64c3628 | !1744 code_sync_0604<br>From: @ding_fei_fei Reviewed-by: @liucunwei,@ljl0711 Signed-off-by: @liucunwei | 4 years ago |
| ba20c7d22c | code_sync_0604_inc | 4 years ago |
| 61ef430599 | !1742 code_sync_0604<br>From: @ding_fei_fei Reviewed-by: @liucunwei,@ljl0711 Signed-off-by: @liucunwei | 4 years ago |
| 267878bf95 | code_sync_0604 | 4 years ago |
| c5be7f3198 | !1677 code_sync_0520_inc<br>From: @ding_fei_fei Reviewed-by: @xsmq,@liucunwei Signed-off-by: @liucunwei | 4 years ago |
| 2089d50d14 | run ge head | 4 years ago |
| 840fa69b81 | !1674 code_sync_0520<br>From: @ding_fei_fei Reviewed-by: @liucunwei,@lilongfei15 Signed-off-by: @liucunwei | 4 years ago |
| 55059dae11 | Merge remote-tracking branch 'origin/release' into code_sync_0520 | 4 years ago |
| f1f6a01ffc | !1603 code_sync_0428_inc<br>From: @ding_fei_fei Reviewed-by: @liucunwei,@ljl0711 Signed-off-by: @liucunwei | 4 years ago |
| 452ec62cb2 | code_sync_0428_inc | 4 years ago |
| fcf0e310ae | !1600 code_sync_0428<br>From: @ding_fei_fei Reviewed-by: @xsmq,@liucunwei Signed-off-by: @liucunwei | 4 years ago |
| 38621eb9ed | code_sync_0428 | 4 years ago |
| 600105f297 | Merge remote-tracking branch 'origin/release' into code_sync_0428 | 4 years ago |
| 708f523f05 | !1561 code_sync_0420_inc<br>From: @ding_fei_fei Reviewed-by: @xsmq,@liucunwei Signed-off-by: @liucunwei | 4 years ago |
| 5b9393e323 | code_sync_0420_inc | 4 years ago |
| 10b87bb2cf | !1556 code_sync_0420<br>From: @ding_fei_fei Reviewed-by: @liucunwei,@xsmq Signed-off-by: @liucunwei | 4 years ago |
| 4d0b6f9d9a | code_sync_0420 | 4 years ago |
| 8258636925 | !1555 code_sync_0420<br>From: @ding_fei_fei Reviewed-by: @xsmq,@liucunwei,@liucunwei Signed-off-by: @liucunwei,@liucunwei | 4 years ago |
| 55f1e60969 | code_sync_0420 | 4 years ago |
| 656d909b35 | Merge remote-tracking branch 'origin/release' into code_sync_0420 | 4 years ago |
| 40e5c42a12 | !1267 workspace of comm op can be reused<br>From: @zhoufeng54 Reviewed-by: @youui,@liujunzhu Signed-off-by: @youui | 4 years ago |
| d6308151e0 | reuse workspace memory of hccl op<br>Signed-off-by: zhoufeng <zhoufeng54@huawei.com> | 4 years ago |
| f65be61197 | !1250 update ge<br>From: @shenwei41 Reviewed-by: @lilongfei15,@xsmq Signed-off-by: @xsmq | 4 years ago |
| 6a82dd1947 | update ge | 4 years ago |
| 6ce82eff9b | !1242 update ge Compile Error<br>From: @shenwei41 Reviewed-by: @lilongfei15,@xsmq Signed-off-by: @xsmq | 4 years ago |
| 948000fe5b | fix ge Compile Error | 4 years ago |
| c0f3dcb4f4 | !1241 update include header files 0311<br>From: @shenwei41 Reviewed-by: @lilongfei15,@xsmq Signed-off-by: @xsmq | 4 years ago |
| 54a48678ae | update include headers 0311 | 4 years ago |
| 8737b1843d | !1234 update ge_0311<br>From: @shenwei41 Reviewed-by: @xsmq Signed-off-by: | 4 years ago |
| 10faa7b052 | change | 4 years ago |
| 5dd83a9abd | Merge remote-tracking branch 'origin/release' into code_sync_0311 | 4 years ago |
| 92286b21ec | !1155 update metadef<br>From: @shenwei41 Reviewed-by: @xsmq,@lilongfei15 Signed-off-by: @lilongfei15 | 4 years ago |
| a1795d6554 | update metdef | 4 years ago |
| f37d94ba21 | !1153 update_headerfiles_0225<br>From: @shenwei41 Reviewed-by: @lilongfei15,@xsmq Signed-off-by: @xsmq | 4 years ago |
| 3e5f9dac74 | update include headers_0225 | 4 years ago |
| d487b4781d | !1150 update graphengine_0225<br>From: @shenwei41 Reviewed-by: @lilongfei15,@xsmq Signed-off-by: @xsmq | 4 years ago |
| 0ebcb55f3f | Merge remote-tracking branch 'origin/release' into code_sync_0225 | 4 years ago |
| 210a007d8e | !1139 update_headerfiles_0222<br>From: @shenwei41 Reviewed-by: @xsmq,@lilongfei15 Signed-off-by: @lilongfei15 | 4 years ago |
| cc456d5803 | update header files 0222 | 4 years ago |
| b9050a8c51 | !1133 update graphengine_0222<br>From: @shenwei41 Reviewed-by: @lilongfei15,@ljl0711 Signed-off-by: @ljl0711 | 4 years ago |
| 028f47826e | Merge remote-tracking branch 'origin/release' into code_sync_0222 | 4 years ago |
| 6ebc4745de | !1104 modify pkg lib<br>From: @changzherui Reviewed-by: @ljl0711,@liujunzhu Signed-off-by: @liujunzhu | 4 years ago |
| aa4045b31d | modify pkg lib | 4 years ago |
| 8dc712ca01 | !1096 update<br>From: @shenwei41 Reviewed-by: @lilongfei15,@xsmq Signed-off-by: @xsmq | 4 years ago |
| de4224d3ba | update | 4 years ago |
| d28090d511 | !1095 update metadef<br>From: @shenwei41 Reviewed-by: @xsmq,@lilongfei15 Signed-off-by: @lilongfei15 | 4 years ago |
| 543fd0a0e8 | update metadef | 4 years ago |
| 06894dcd48 | !1093 update prebuild<br>From: @shenwei41 Reviewed-by: @xsmq,@lilongfei15 Signed-off-by: @lilongfei15 | 4 years ago |
| 29e380abf6 | update prebuild | 4 years ago |
| e7036bb984 | !1079 update include file<br>From: @shenwei41 Reviewed-by: @lilongfei15,@liucunwei Signed-off-by: @liucunwei | 4 years ago |
| cffc6b2e2e | update include file | 4 years ago |
| ea0e2eadad | !1071 update commit id<br>From: @shenwei41 Reviewed-by: @liujunzhu,@lilongfei15 Signed-off-by: @lilongfei15 | 4 years ago |
| c332519b2e | Merge remote-tracking branch 'origin/release' into code_sync_0203 | 4 years ago |
| 08e4e60663 | !1056 modify error_codes.h<br>From: @changzherui Reviewed-by: @lilongfei15,@liujunzhu Signed-off-by: @liujunzhu | 4 years ago |
| b4539d54cd | modify error_codes.h | 4 years ago |
| 2d95c17f11 | !1040 sync code 0126<br>From: @changzherui Reviewed-by: @liujunzhu,@lilongfei15 Signed-off-by: @liucunwei | 4 years ago |
| 46dbe7a2fc | Merge remote-tracking branch 'upstream/release' into code_sync_126 | 4 years ago |
| dfa4dd4acd | !1024 Synchronization code .h 0125<br>From: @changzherui Reviewed-by: @guoqi1024,@liujunzhu Signed-off-by: @guoqi1024 | 4 years ago |
| ed7e35f927 | sync code 0125 .h | 4 years ago |
| 7183c03452 | !1022 Synchronization code 0125<br>From: @changzherui Reviewed-by: @ljl0711,@guoqi1024 Signed-off-by: @guoqi1024 | 4 years ago |
| a8b7570e71 | Merge remote-tracking branch 'upstream/release' into code_sync_0125 | 4 years ago |
| 687d5f2ab1 | !966 sync code 0116 h<br>From: @changzherui Reviewed-by: @liujunzhu,@ljl0711 Signed-off-by: @ljl0711 | 4 years ago |
| e3b32cd2a0 | clang-format | 4 years ago |
| 411e71f1f3 | sync code h | 4 years ago |
| 50f17e37d9 | !963 code sync 0116<br>From: @changzherui Reviewed-by: @liujunzhu,@guoqi1024 Signed-off-by: @guoqi1024 | 4 years ago |
| 2b200b25ed | sync code 0116 | 4 years ago |
| 20a0326976 | !567 prioritize json downloading from gitee<br>From: @nicholas_yhr Reviewed-by: @liujunzhu,@youui Signed-off-by: @youui | 4 years ago |
| d77f36e017 | prioritize json downloading from gitee | 4 years ago |
| 75572bb987 | !562 Synchronize latest Ascend software suite 09 Dec 2020<br>From: @nicholas_yhr Reviewed-by: @liujunzhu,@ljl0711 Signed-off-by: @ljl0711 | 4 years ago |
| d5a82a7f98 | Synchronize latest Ascend software suite 09 Dec 2020 | 4 years ago |
| 24b53b9282 | !546 fix geruntime missing files and error codes<br>From: @nicholas_yhr Reviewed-by: @youui,@liujunzhu Signed-off-by: @liujunzhu | 4 years ago |
| d731918198 | fix geruntime missing files and error codes | 4 years ago |
| 8712387b3c | !533 update headers for release branch, for mindspore use<br>From: @nicholas_yhr Reviewed-by: @youui,@liujunzhu Signed-off-by: @liujunzhu | 4 years ago |
| 22ac2e8c14 | !537 find libraries from both atc and fwk paths<br>From: @nicholas_yhr Reviewed-by: @youui,@liujunzhu Signed-off-by: @liujunzhu | 4 years ago |
| fec2e70eda | find libraries from both atc and fwk paths | 4 years ago |
| 8011e1ea9a | update headers | 4 years ago |
```diff
@@ -1,8 +1,8 @@
 [submodule "parser"]
     path = parser
     url = https://gitee.com/ascend/parser.git
-    branch = master
+    branch = r1.5.0
 [submodule "metadef"]
     path = metadef
     url = https://gitee.com/ascend/metadef.git
-    branch = master
+    branch = r1.5.0
```
```diff
@@ -95,6 +95,7 @@ else ()
     #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH})
   else()
     find_module(slog libalog.so ${ASCEND_ATC_DIR})
+    find_module(opt_feature libopt_feature.so ${ASCEND_ATC_DIR})
     find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR})
     if(PLATFORM STREQUAL "train")
       find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR})
@@ -128,7 +129,6 @@ else ()
       message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!")
     endif()
   endif()
 set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef)
 set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/parser)
 set(GE_DEPEND_DIR ${CMAKE_CURRENT_LIST_DIR}/..)
@@ -161,6 +161,7 @@ else ()
   elseif(ENABLE_MS_TESTCASES)
     include(cmake/external_libs/protobuf_static.cmake)
     include(cmake/external_libs/protoc.cmake)
+    include(cmake/external_libs/json.cmake)
     include(cmake/external_libs/securec.cmake)
     include(cmake/FindModule.cmake)
     include(cmake/intf_pub_linux.cmake)
@@ -178,5 +179,4 @@ else ()
   endif()
   add_subdirectory(ge)
 endif ()
 endif ()
```
```diff
@@ -355,13 +355,13 @@ generate_package()
   if [ "x${PLATFORM}" = "xtrain" ]
   then
-    tar -cf graphengine_lib.tar fwkacllib
+    tar -zcf graphengine_lib.tar fwkacllib
   elif [ "x${PLATFORM}" = "xinference" ]
   then
-    tar -cf graphengine_lib.tar acllib atc
+    tar -zcf graphengine_lib.tar acllib atc
   elif [ "x${PLATFORM}" = "xall" ]
   then
-    tar -cf graphengine_lib.tar fwkacllib acllib atc
+    tar -zcf graphengine_lib.tar fwkacllib acllib atc
   fi
 }
@@ -371,6 +371,6 @@ elif [ "X$MINDSPORE_MODE" = "Xon" ]
 then
   cd "${OUTPUT_PATH}"
   find ./ -name graphengine_lib.tar -exec rm {} \;
-  tar -cf graphengine_lib.tar lib
+  tar -zcf graphengine_lib.tar lib
 fi
 echo "---------------- GraphEngine package archive generated ----------------"
```
```diff
@@ -9,10 +9,6 @@ if (GE_PB_PKG)
     set(REQ_URL "${GE_PB_PKG}/libs/ge_nlohmann_json/include.zip")
     set(MD5 "0dc903888211db3a0f170304cd9f3a89")
     set(JSON_INCLUDE_DIR ${JSON_SRC_DIR})
-#elseif (ENABLE_GITEE)
-#    set(REQ_URL "https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.6.1.zip")
-#    set(MD5 "5bda78ce308e6cfcf614dcf1d5ff27a7")
-#set(JSON_INCLUDE_DIR "${JSON_SRC_DIR}/include")
 else()
     set(REQ_URL "https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip")
     set(MD5 "0dc903888211db3a0f170304cd9f3a89")
```
```diff
@@ -11,14 +11,14 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
     message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.")
 endif()
 if (GE_PB_PKG)
-    set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz")
+    set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.13.0.tar.gz")
 else()
     if (ENABLE_GITEE)
-        set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz")
-        set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236")
+        set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz")
+        set(MD5 "f4489cb88922ad9c58cbe3308d59cee5")
     else()
-        set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
-        set(MD5 "3d9e32700639618a4d2d342c99d4507a")
+        set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz")
+        set(MD5 "1a6274bc4a65b55a6fa70e264d796490")
     endif ()
 endif()
@@ -58,7 +58,7 @@ target_include_directories(ascend_protobuf INTERFACE ${PROTOBUF_SHARED_PKG_DIR}/
 set(INSTALL_BASE_DIR "")
 set(INSTALL_LIBRARY_DIR lib)
-install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so.3.8.0.0 OPTIONAL
+install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so.3.13.0.0 OPTIONAL
         DESTINATION ${INSTALL_LIBRARY_DIR})
 install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so OPTIONAL
         DESTINATION ${INSTALL_LIBRARY_DIR})
```
```diff
@@ -13,14 +13,14 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
 endif()
 if(GE_PB_PKG)
-    set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz")
+    set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.13.0.tar.gz")
 else()
     if (ENABLE_GITEE)
-        set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz")
-        set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236")
+        set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz")
+        set(MD5 "f4489cb88922ad9c58cbe3308d59cee5")
     else()
-        set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
-        set(MD5 "3d9e32700639618a4d2d342c99d4507a")
+        set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz")
+        set(MD5 "1a6274bc4a65b55a6fa70e264d796490")
     endif ()
 endif()
@@ -29,8 +29,6 @@ set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
 set(PROTOBUF_STATIC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protobuf_static)
 ExternalProject_Add(protobuf_static_build
     URL ${REQ_URL}
-    #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
-    #SOURCE_DIR ${METADEF_DIR}/../../third_party/protobuf/src/protobuf-3.8.0
     TLS_VERIFY OFF
     CONFIGURE_COMMAND ${CMAKE_COMMAND}
     -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
```
```diff
@@ -13,14 +13,14 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
 endif()
 if(GE_PB_PKG)
-    set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz")
+    set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.13.0.tar.gz")
 else()
     if (ENABLE_GITEE)
-        set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz")
-        set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236")
+        set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz")
+        set(MD5 "f4489cb88922ad9c58cbe3308d59cee5")
     else()
-        set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
-        set(MD5 "3d9e32700639618a4d2d342c99d4507a")
+        set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz")
+        set(MD5 "1a6274bc4a65b55a6fa70e264d796490")
     endif ()
 endif()
@@ -28,8 +28,6 @@ set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fst
 set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
 ExternalProject_Add(protoc_build
     URL ${REQ_URL}
-    #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
-    #SOURCE_DIR ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0
     TLS_VERIFY OFF
     CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protoc <SOURCE_DIR>/cmake
     BUILD_COMMAND $(MAKE)
```
```diff
@@ -174,6 +174,7 @@ set(TRAIN_SRC_LIST
     "graph/load/model_manager/task_info/model_exit_task_info.cc"
     "graph/load/model_manager/task_info/event_record_task_info.cc"
     "graph/load/model_manager/task_info/event_wait_task_info.cc"
+    "graph/load/model_manager/task_info/ffts_task_info.cc"
     "graph/load/model_manager/task_info/fusion_start_task_info.cc"
     "graph/load/model_manager/task_info/fusion_stop_task_info.cc"
     "graph/load/model_manager/task_info/hccl_task_info.cc"
@@ -433,6 +434,7 @@ set(TRAIN_SRC_LIST
     "graph/build/memory/max_block_mem_assigner.cc"
     "graph/build/memory/var_mem_assign_util.cc"
     "graph/build/memory/buffer_pool_mem_assigner.cc"
+    "ge_opt_info/ge_opt_info.cc"
 )
 set(INFER_SRC_LIST
@@ -662,6 +664,7 @@ set(INFER_SRC_LIST
     "graph/load/model_manager/task_info/task_info.cc"
     "graph/load/model_manager/task_info/event_record_task_info.cc"
     "graph/load/model_manager/task_info/event_wait_task_info.cc"
+    "graph/load/model_manager/task_info/ffts_task_info.cc"
     "graph/load/model_manager/task_info/fusion_start_task_info.cc"
     "graph/load/model_manager/task_info/fusion_stop_task_info.cc"
     "graph/load/model_manager/task_info/kernel_ex_task_info.cc"
@@ -709,6 +712,7 @@ set(INFER_SRC_LIST
     "graph/build/memory/max_block_mem_assigner.cc"
     "graph/build/memory/var_mem_assign_util.cc"
     "graph/build/memory/buffer_pool_mem_assigner.cc"
+    "ge_opt_info/ge_opt_info.cc"
 )
 if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES)
@@ -770,11 +774,13 @@ target_include_directories(ge_runner SYSTEM PRIVATE
     ${GE_CODE_DIR}/../inc/cce
     ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external
     ${GE_CODE_DIR}/../abl/adump/external
+    ${GE_CODE_DIR}/../abl/licctrl
     #### blue zone
     ${ASCEND_DIR}/driver/include
     ${ASCEND_DIR}/fwkacllib/include
     ${GE_CODE_DIR}/third_party/fwkacllib/inc
     ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain
+    ${GE_CODE_DIR}/third_party/fwkacllib/inc/opt_info
 )
 target_link_options(ge_runner PRIVATE
@@ -797,6 +803,7 @@ target_link_libraries(ge_runner PRIVATE
     runtime
     error_manager
     ascend_hal_stub
+    opt_feature
     -Wl,--as-needed
     json
     -lrt
@@ -851,11 +858,13 @@ target_include_directories(ge_compiler SYSTEM PRIVATE
     ${GE_CODE_DIR}/../inc/cce
     ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external
     ${GE_CODE_DIR}/../abl/adump/external
+    ${GE_CODE_DIR}/../abl/licctrl
     #### blue zone ####
     ${ASCEND_DIR}/driver/include
     ${ASCEND_DIR}/fwkacllib/include
     ${GE_CODE_DIR}/third_party/fwkacllib/inc
     ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain
+    ${GE_CODE_DIR}/third_party/fwkacllib/inc/opt_info
 )
 target_link_options(ge_compiler PRIVATE
@@ -875,6 +884,7 @@ target_link_libraries(ge_compiler PRIVATE
     error_manager
     slog
     runtime_compile
+    opt_feature
     -Wl,--as-needed
     json
     -lrt
```
```diff
@@ -1 +0,0 @@
-../../proto/ge_api.proto
```
```diff
@@ -1,193 +0,0 @@
-syntax = "proto3";
-package ge.proto;
-enum DataType
-{
-  DT_UNDEFINED = 0;  // Used to indicate a DataType field has not been set.
-  DT_FLOAT = 1;   // float type
-  DT_FLOAT16 = 2; // fp16 type
-  DT_INT8 = 3;    // int8 type
-  DT_UINT8 = 4;   // uint8 type
-  DT_INT16 = 5;   // int16 type
-  DT_UINT16 = 6;  // uint16 type
-  DT_INT32 = 7;   //
-  DT_INT64 = 8;   // int64 type
-  DT_UINT32 = 9;  // unsigned int32
-  DT_UINT64 = 10; // unsigned int64
-  DT_BOOL = 11;   // bool type
-  DT_DOUBLE = 12; // double type
-  DT_STRING = 13;         // string type
-  DT_DUAL_SUB_INT8 = 14;  /**< dual output int8 type */
-  DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */
-  DT_COMPLEX64 = 16;      // complex64 type
-  DT_COMPLEX128 = 17;     // complex128 type
-  DT_QINT8 = 18;          // qint8 type
-  DT_QINT16 = 19;         // qint16 type
-  DT_QINT32 = 20;         // qint32 type
-  DT_QUINT8 = 21;         // quint8 type
-  DT_QUINT16 = 22;        // quint16 type
-  DT_RESOURCE = 23;       // resource type
-  DT_STRING_REF = 24;     // string_ref type
-  DT_DUAL = 25;           /**< dual output type */
-  DT_VARIANT = 26;        // variant type
-  DT_BF16 = 27;           // bf16 type
-  DT_INT4 = 28;           // int4 type
-}
-message AttrDef
-{
-  message ListValue
-  {
-    enum ListValueType{
-      VT_LIST_NONE = 0;
-      VT_LIST_STRING = 1;
-      VT_LIST_INT = 2;
-      VT_LIST_FLOAT = 3;
-      VT_LIST_BOOL = 4;
-      VT_LIST_BYTES = 5;
-      VT_LIST_TENSOR_DESC = 6;
-      VT_LIST_TENSOR = 7;
-      VT_LIST_GRAPH = 8;
-      VT_LIST_NAMED_ATTRS = 9;
-      VT_LIST_DATA_TYPE = 10;
-    }
-    repeated bytes s = 2;   // "list(string)"
-    repeated int64 i = 3;   // "list(int)"
-    repeated float f = 4;   // "list(float)"
-    repeated bool b = 5;    // "list(bool)"
-    repeated bytes bt = 7;
-    repeated TensorDescriptor td = 8;
-    repeated TensorDef t = 9;
-    repeated GraphDef g = 10;
-    repeated NamedAttrs na = 11;
-    repeated int64 dt = 12; // list ge::DataType
-    ListValueType val_type = 20;
-  }
-  message ListListInt{
-    message ListInt{
-      repeated int64 list_i = 1; // list int
-    }
-    repeated ListInt list_list_i = 1; // list list int
-  }
-  oneof value
-  {
-    bytes s = 2;  // "string"
-    int64 i = 3;  // "int"
-    float f = 4;  // "float"
-    bool b = 5;   // "bool"
-    bytes bt = 7;
-    ListValue list = 1;       // any "list(...)"
-    NamedAttrs func = 10;     // Used to support attr nesting
-    TensorDescriptor td = 11; // GeTensorDesc type
-    TensorDef t = 12;         // GeTensor type
-    GraphDef g = 13;          // Graph type
-    ListListInt list_list_int = 14; // List List Int type
-    int64 dt = 15;            // ge::DataType
-  }
-}
-// A list of attr names and their values. The whole list is attached
-// with a string name. E.g., MatMul[T=float].
-message NamedAttrs
-{
-  string name = 1;
-  map<string, AttrDef> attr = 2;
-}
-// Shape / dimension description, using row-major order
-message ShapeDef
-{
-  repeated int64 dim = 1;  // Size of each dimension
-}
-// Multidimensional data description
-message TensorDescriptor
-{
-  string name = 1;     // Optional parameter, tensor name
-  DataType dtype = 2;  // tensor datatype
-  ShapeDef shape = 3;  // Shape / dimension
-  string layout = 4;   // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND"
-  bool has_out_attr = 9;
-  int64 size = 10;
-  int64 weight_size = 11;
-  bool reuse_input = 12;
-  bool output_tensor = 13;
-  string device_type = 14;
-  bool input_tensor =15;
-  int64 real_dim_cnt = 16;
-  int64 reuse_input_index = 17;
-  int64 data_offset = 18;
-  int64 cmps_size = 19;
-  string cmps_tab = 20;
-  int64 cmps_tab_offset = 21;
-  map<string, AttrDef> attr = 5;  // Set of extra parameter fields
-}
-// GeTensor definition
-message TensorDef
-{
-  TensorDescriptor desc = 1;  // Tensor description
-  bytes data = 2;             // Tensor data
-}
-// Operator description
-message OpDef
-{
-  string name = 1;  // name
-  string type = 2;  // type
-  repeated string input = 5;       // input original op name + outgoing index. op_name:index
-  map<string, AttrDef> attr = 10;  // Set of operator parameter fields
-  bool has_out_attr = 20;
-  int64 id = 21;
-  int64 stream_id =22;
-  repeated string input_name = 23;
-  repeated string src_name = 24;
-  repeated int64 src_index = 25;
-  repeated string dst_name = 26;
-  repeated int64 dst_index = 27;
-  repeated int64 input_i = 28;
-  repeated int64 output_i = 29;
-  repeated int64 workspace = 30;
-  repeated int64 workspace_bytes = 31;
-  repeated bool is_input_const = 32;
-  repeated TensorDescriptor input_desc = 33;
-  repeated TensorDescriptor output_desc = 34;
-  repeated string subgraph_name = 35;
-}
-// Graph definition
-message GraphDef
-{
-  string name = 1;              // name
-  repeated string input = 4;    // Graph input
-  repeated string output = 5;   // Graph output
-  repeated OpDef op = 6;        // List of operators
-  map<string, AttrDef> attr = 11;  // Extended field
-}
-// model definition
-message ModelDef
-{
-  string name = 1;            // name
-  uint32 version = 2;         // IR Proto verion
-  string custom_version = 3;  // User model version number, passed in by user
-  repeated GraphDef graph = 7;  // Graph definition,graph[0] represents the main diagram in modeldef
-  map<string, AttrDef> attr = 11;  // Extended field
-}
```
```diff
@@ -1,140 +0,0 @@
-syntax = "proto3";
-package domi;
-message InsertNewOps {
-  repeated AippOpParams aipp_op = 1;
-  repeated MultiShapeOpParams multi_shape_op = 2;
-}
-message AippOpParams {
-  enum InputFormat {
-    UNDEFINED = 0;
-    YUV420SP_U8 = 1;
-    XRGB8888_U8 = 2;
-    RGB888_U8 = 3;
-    YUV400_U8 = 4;
-    NC1HWC0DI_FP16 = 5;
-    NC1HWC0DI_S8 = 6;
-    ARGB8888_U8 = 7;
-    YUYV_U8 = 8;
-    YUV422SP_U8 = 9;
-    AYUV444_U8 = 10;
-    RAW10 = 11;
-    RAW12 = 12;
-    RAW16 = 13;
-    RAW24 = 14;
-    RGB16 = 15;
-    RGB20 = 16;
-    RGB24 = 17;
-    RGB8_IR = 18;
-    RGB16_IR = 19;
-    RGB24_IR = 20;
-  }
-  enum AippMode {
-    undefined = 0;
-    static = 1;
-    dynamic = 2;
-  }
-  // AIPP mode: distinguishes static AIPP from dynamic AIPP
-  AippMode aipp_mode = 1;
-  // related_input_rank is required; integer, range >= 0 and <= the number of input Data operators, default 0.
-  // It identifies which model input gets AIPP processing; e.g. if the model has two inputs and the second
-  // needs AIPP, set related_input_rank to 1.
-  uint32 related_input_rank = 2;
-  // related_input_name is optional and the top name of data node which inserts aipp
-  string related_input_name = 6;
-  // input_edge_idx is optional; integer, range >= 0.
-  // It applies different AIPP processing to different outputs of the Data operator; if unset, AIPP is applied
-  // to all output edges of the model input specified by related_input_rank.
-  // The configured value must be <= the number of output edges of the Data operator.
-  repeated uint32 input_edge_idx = 3;
-  // [Begin] dynamic AIPP parameters, ignored when static AIPP is configured
-  uint32 max_src_image_size = 4;
-  // Whether rotation is supported. Off by default; enabling it costs extra memory and performance.
-  bool support_rotation = 5;
-  // [End] dynamic AIPP parameters
-  // [Begin] static AIPP parameters, ignored when dynamic AIPP is configured
-  InputFormat input_format = 51;
-  bool csc_switch = 52;
-  float cpadding_value = 53;
-  bool rbuv_swap_switch = 54;
-  bool ax_swap_switch = 55;
-  bool single_line_mode = 56;
-  int32 src_image_size_w = 57;
-  int32 src_image_size_h = 58;
-  bool crop = 59;
-  int32 load_start_pos_w = 60;
-  int32 load_start_pos_h = 61;
-  int32 crop_size_w = 62;
-  int32 crop_size_h = 63;
-  bool resize = 64;
-  int32 resize_output_w = 65;
-  int32 resize_output_h = 66;
-  bool padding = 67;
-  int32 left_padding_size = 68;
-  int32 right_padding_size = 69;
-  int32 top_padding_size = 70;
-  int32 bottom_padding_size = 71;
-  float padding_value = 72;
-  int32 mean_chn_0 = 10;
-  int32 mean_chn_1 = 11;
-  int32 mean_chn_2 = 12;
-  int32 mean_chn_3 = 19;
-  float min_chn_0 = 13;
-  float min_chn_1 = 14;
-  float min_chn_2 = 15;
-  float min_chn_3 = 20;
-  repeated float var_reci_chn_0 = 16;
-  repeated float var_reci_chn_1 = 17;
-  repeated float var_reci_chn_2 = 18;
-  repeated float var_reci_chn_3 = 21;
-  repeated int32 matrix_r0c0 = 30;
-  repeated int32 matrix_r0c1 = 31;
-  repeated int32 matrix_r0c2 = 32;
-  repeated int32 matrix_r1c0 = 33;
-  repeated int32 matrix_r1c1 = 34;
-  repeated int32 matrix_r1c2 = 35;
-  repeated int32 matrix_r2c0 = 36;
-  repeated int32 matrix_r2c1 = 37;
-  repeated int32 matrix_r2c2 = 38;
-  repeated int32 output_bias_0 = 39;
-  repeated int32 output_bias_1 = 40;
-  repeated int32 output_bias_2 = 41;
-  repeated int32 input_bias_0 = 42;
-  repeated int32 input_bias_1 = 43;
-  repeated int32 input_bias_2 = 44;
-  // [End] static AIPP parameters
-  // The n number that is used for raw/rgbir data into f16 transformation.
-  // The transformation equation is x/(2^n). If set to 0, no transform is performed.
-  uint32 raw_rgbir_to_f16_n = 45;
-}
-message MultiShapeOpParams {
-  enum MultiShapeMode {
-    batch = 0;       // dynamic batch
-    resolution = 1;  // dynamic resolution, reserved for extension
-  }
-  MultiShapeMode mode = 1;        // operator mode
-  uint32 related_input_rank = 2;  // which input the new operator is inserted at
-  repeated uint32 batch_list = 11;  // batch_list values; the count must be between 2 and 8
-}
```
```diff
@@ -1,396 +0,0 @@
-/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
- *
- * This program is free software; you can redistribute it and/or modify
- * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * Apache License for more details at
- * http://www.apache.org/licenses/LICENSE-2.0
- */
-syntax = "proto3";
-package domi;
-enum TargetType
-{
-  MINI = 0;
-  TINY = 1;
-  LITE = 2;
-}
-// offline model
-message ModelDef {
-  string name = 1;
-  uint32 version = 2;
-  uint64 memory_size = 10;
-  uint32 stream_num = 11;
-  uint32 event_num = 12;
-  uint64 weight_size = 13;
-  uint32 label_num = 15;
-  repeated OpDef op = 20;
-  TargetType target_type = 23;
-  map<string, AttrDef> attr = 30;
-};
-// operator define
-message OpDef {
-  string name = 1;
-  string type = 2;
-  uint32 id = 3;
-  uint32 stream_id = 4;
-  repeated string input_name = 5;
-  repeated string src_name = 8;
-  repeated int32 src_index = 9;
-  repeated int64 input = 10;
-  repeated int64 output = 11;
-  repeated TensorDescriptor input_desc = 12;
-  repeated TensorDescriptor output_desc = 13;
-  repeated WeightDef weights = 14;
-  repeated string dst_name = 15;
-  repeated int32 dst_index = 16;
-  repeated int64 workspace = 20;
-  repeated uint32 workspace_bytes = 21;
-  repeated string weight_name = 22;
-  repeated bool is_input_const = 23;
-  map<string, AttrDef> attr = 30;
-  QuantizeFactorParams quantize_factor = 31;
-  oneof op_params {
-    // start at 100 here
-    SendOpParams sender_param = 100;
-    RecvOpParams receiver_param = 200;
-    ConvolutionOpParams convolution_param = 300;
-    PoolingOpParams pooling_param = 400;
-    EltwiseOpParams eltwise_param = 500;
-    BatchNormOpParams batchnorm_param = 600;
-    ScaleOpParams scale_param = 700;
-    FullConnectionOpParams full_connection_param = 800;
-    SoftmaxOpParams softmax_param = 900;
-    ActivationOpParams activation_param = 1000;
-    ReshapeOpParams reshape_param = 1100;
-  }
-};
-message SendOpParams {
-  uint32 event_id = 1;
-};
-message RecvOpParams {
-  uint32 event_id = 1;
-};
-enum QuantizeScaleType
-{
-  VECTOR_SCALE = 0;
-  SCALAR_SCALE = 1;
-}
-enum QuantizeScaleMode
-{
-  NORMAL_MODE = 0;
-  SQRT_MODE = 1;
-}
-enum QuantizeAlgorithm
-{
-  NON_OFFSET_ALGO = 0;
-  HALF_OFFSET_ALGO = 1;
-  ALL_OFFSET_ALGO = 2;
-}
-message QuantizeFactor
-{
-  QuantizeScaleMode scale_mode = 1;
-  bytes scale_value = 2;
-  int64 scale_offset = 3;
-  bytes offset_data_value = 4;
-  int64 offset_data_offset = 5;
-  bytes offset_weight_value = 6;
-  int64 offset_weight_offset = 7;
-  bytes offset_pad_value = 8;
-  int64 offset_pad_offset = 9;
-};
-message QuantizeCalcFactor
-{
-  bytes offsetw = 1;
-  int64 offsetw_offset = 2;
-  bytes offsetd = 3;
-  int64 offsetd_offset = 4;
-  bytes scalereq = 5;
-  int64 scaledreq_offset = 6;
-  bytes offsetdnext = 7;
-  int64 offsetdnext_offset = 8;
-}
-message QuantizeFactorParams
-{
-  QuantizeAlgorithm quantize_algo = 1;
-  QuantizeScaleType scale_type = 2;
-  QuantizeFactor quantize_param = 3;
-  QuantizeFactor dequantize_param = 4;
-  QuantizeFactor requantize_param = 5;
-  QuantizeCalcFactor quantizecalc_param = 6;
-};
-message ConvolutionOpParams {
-  int32 mode = 1;
-  int32 algo = 2;
-  int32 pad_mode = 3;
-  uint32 group = 4;
-  uint32 num_output = 5;
-  repeated uint32 pad = 10;
-  repeated uint32 stride = 11;
-  repeated uint32 dilation = 12;
-  repeated uint32 kernel = 13;
-  float alpha = 20;
-  float beta = 21;
-  WeightDef filter = 40;
-  WeightDef bias = 41;
-  bool relu_flag = 62;
-  repeated uint32 adj = 70;
-  repeated uint32 target_shape = 71;
-  repeated uint32 before_pad = 72;
-};
-message PoolingOpParams {
-  int32 mode = 1;
-  int32 nan_opt = 2;
-  int32 pad_mode = 3;
-  bool global_pooling = 4;
-  repeated uint32 window = 10;
-  repeated uint32 pad = 11;
-  repeated uint32 stride = 12;
-  bool ceil_mode = 13;
-  int32 data_mode = 14;
-  float alpha = 20;
-  float beta = 21;
-  repeated uint32 before_pad = 22;
-};
-message EltwiseOpParams {
-  int32 mode = 1;
-  repeated float coeff = 2;
-  float alpha = 3;
-  float beta = 4;
-  repeated WeightDef weight = 5;
-  bool relu_flag = 6;
-};
-message ActivationOpParams {
-  int32 mode = 1;
-  float coef = 2;
-  float alpha = 3;
-  float beta = 4;
-};
-message BatchNormOpParams {
-  int32 mode = 1;
-  float alpha = 2;
-  float beta = 3;
-  double epsilon = 4;  //optinal,[default = 1e-5]
-  bool use_global_stats = 5;  //optinal,by default true,testing mode
-  float moving_average_fraction = 6;  //optinal,[default = .999];
-  WeightDef estimated_mean = 7;
-  WeightDef estimated_variance = 8;
-  WeightDef scale = 9;
-  WeightDef bias = 10;
-};
-message ScaleOpParams {
-  WeightDef scale = 1;
-  WeightDef bias = 2;
-};
-message ReshapeOpParams {
-  float alpha = 1;
-  float beta = 2;
-  ShapeDef shape = 3;
-  int32 axis = 4;
-  int32 num_axes = 5;
-  int32 format = 6;
-};
-message SoftmaxOpParams {
-  int32 algo = 1;
-  int32 mode = 2;
-  float alpha = 3;
-  float beta = 4;
-};
-message FullConnectionOpParams {
-  WeightDef filter = 1;
-  WeightDef bias = 2;
-  uint32 num_output = 3;
-  bool relu_flag = 12;
-};
-message FlattenOpParams {
-  float alpha = 1;
-  float beta = 2;
-  int32 start_axis = 3;
-  int32 end_axis = 4;
-}
-message AddLimitedOpParams {
-  float alpha = 1;
-  float beta = 2;
-  int32 axis = 3;
-  bool broadcast = 4;
-  repeated WeightDef weight = 10;
-};
-message MulLimitedOpParams {
-  float alpha = 1;
-  float beta = 2;
-  int32 axis = 3;
-  bool broadcast = 4;
-  repeated WeightDef weight = 10;
-};
-message AddOpParams {
-  float alpha = 1;
-  float beta = 2;
-  repeated WeightDef weight = 10;
-};
-message MulOpParams {
-  float alpha = 1;
-  float beta = 2;
-  repeated WeightDef weight = 10;
-};
-message SubOpParams {
-  float alpha = 1;
-  float beta = 2;
-  repeated WeightDef weight = 10;
-};
-message BiasAddOpParams {
-  float alpha = 1;
-  float beta = 2;
-  WeightDef bias = 10;
-};
-message MatMulOpParams {
-  float alpha = 1;
-  float beta = 2;
-  bool transposeX = 3;
-  bool transposeW = 4;
-  WeightDef filter = 10;
-  WeightDef bias = 12;
-};
-message RsqrtOpParams {
-  float alpha = 1;
-  float beta = 2;
-};
-message WeightDef {
-  int32 format = 1;
-  int32 data_type = 2;
-  ShapeDef shape = 3;
-  bytes data = 4;
-  int64 data_offset = 5;
-  uint32 cmps_size = 6;
-  bytes cmps_tab = 7;
-  int64 cmps_tab_offset = 10;
-  CompressInfo cmps_info = 8;
-  AllOffsetQuantizeInfo alloffset_quantize_info = 11;
-}
-message ShapeDef {
-  repeated int64 dim = 1;
-}
-enum DeviceType {
-  NPU = 0;  // In default, we will use NPU.
-  CPU = 1;  // CPU
-}
-message AllOffsetQuantizeInfo {
-  float scale = 1;
-  int32 offset = 2;
-}
-message TensorDescriptor {
-  int32 format = 1;
-  int32 data_type = 2;
-  repeated int64 dim = 3;
-  uint32 size = 4;
-  bool reuse_input = 5;
-  bool output_tensor = 7;
-  DeviceType device_type = 8;
-  bool input_tensor = 9;
-  uint32 real_dim_cnt = 10;
-  uint32 reuse_input_index = 11;
-  AllOffsetQuantizeInfo alloffset_quantize_info = 12;
-}
-message CompressInfo {
-  int32 blockRow = 1;       // block row
-  int32 blockCol = 2;       // block col
-  int32 fractalK = 3;       // fractal K
-  int32 fractalN = 4;       // fractal N
-  int32 lastFractalK = 5;   // K of last fractal
-  int32 lastFractalN = 6;   // N of last fractal
-  int32 cubeSize = 7;       // cube's length
-  int32 loadDir = 8;        // data load directtiono 0:col load 1:row load
-}
-message AttrDef {
-  message ListValue {
-    repeated string s = 2;                  // "list(string)"
-    repeated int64 i = 3 [packed = true];   // "list(int)"
-    repeated float f = 4 [packed = true];   // "list(float)"
-    repeated bool b = 5 [packed = true];    // "list(bool)"
-    repeated uint32 u = 6 [packed = true];  // "list(uint)"
-    repeated bytes bt = 7;
-  }
-  oneof value {
-    string s = 2;   // "string"
-    int64 i = 3;    // "int"
-    float f = 4;    // "float"
-    bool b = 5;     // "bool"
-    uint32 u = 6;   // "uint32"
-    bytes bt = 7;
-    ListValue list = 1;    // any "list(...)"
-    NamedAttrs func = 10;
-  }
-}
-// A list of attr names and their values. The whole list is attached
-// with a string name. E.g., MatMul[T=float].
-message NamedAttrs {
-  string name = 1;
-  map<string, AttrDef> attr = 2;
-}
```
```diff
@@ -1,179 +0,0 @@
-/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
- *
- * This program is free software; you can redistribute it and/or modify
- * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * Apache License for more details at
- * http://www.apache.org/licenses/LICENSE-2.0
- */
-syntax = "proto3";
-package domi;
-message ModelTaskDef {
-  string version = 1;
-  map<string, string> attr = 9;  // Extended field
-  repeated TaskDef task = 10;
-  uint64 memory_size = 11;
-  uint32 stream_num = 12;
-  uint32 event_num = 13;
-  uint64 weight_size = 14;
-  repeated bytes op = 15;  // input/output opdef in bytes
-  uint64 base_addr = 16;   // base addr
-  uint64 weight_addr = 17; // weight addr
-  uint32 batch_num = 18;
-}
-message TaskDef {
-  uint32 id = 1;
-  uint32 type = 2;
-  uint32 stream_id = 10;
-  uint32 event_id = 11;
-  KernelDef kernel = 20;
-  KernelExDef kernel_ex = 21;
-  KernelHcclDef kernel_hccl = 25;
-  EventExDef event_ex = 26;
-  LogTimeStampDef log_timestamp = 28;
-  uint32 label_id = 30;
-  MemcpyAsyncDef memcpy_async = 31;
-  StreamSwitchDef stream_switch = 32;
-  StreamActiveDef stream_active = 33;
-  bytes private_def = 34;
-  uint64 ops_kernel_store_ptr = 35;  // adjustments to other fields in the future
-  StreamSwitchNDef stream_switch_n = 36;
-  LabelSetDef label_set = 37;
-  LabelGotoExDef label_goto_ex = 38;
-  LabelSwitchByIndexDef label_switch_by_index = 39;
-  KernelDefWithHandle kernel_with_handle = 40;
-}
-message KernelDef {
-  KernelContext context = 1;
-  string stub_func = 10;
-  uint32 block_dim = 11;
-  uint32 args_size = 12;
-  bytes args = 13;
-  bytes sm_desc = 14;
-  bytes flowtable = 15;
-  string so_name = 16;
-  string kernel_name = 17;
-  bytes kernel_ext_info = 18;
-  uint32 kernel_ext_info_size = 19;
-}
-message KernelDefWithHandle {
-  KernelContext context = 1;
-  uint64 handle = 10;
-  string dev_func = 11;
-  uint32 block_dim = 12;
-  uint32 args_size = 13;
-  bytes args = 14;
-  bytes sm_desc = 15;
-  string original_kernel_key = 16;
-  string node_info = 17;
-}
-message KernelContext {
-  uint32 kernel_type = 1;
-  uint32 op_id = 2;         // OP type in CCE
-  uint32 kernel_func_id = 3;
-  uint32 op_index = 4;      // TE/Custom operator
-  bool is_flowtable = 5;    // Identify whether args is a flowtable structure
-  bytes args_offset = 6;    // args offset information
-  uint32 args_count = 7;    // args count
-  repeated uint32 origin_op_index = 8;
-}
-message KernelExDef {
-  uint32 flags = 1;
-  uint32 op_index = 4;
-  uint32 args_size = 12;
-  bytes args = 13;
-  bytes task_info = 14;  // serialized nodeDef, funcDef, inputoutput
-  uint32 task_info_size = 15;
-  bytes kernel_ext_info = 16;
-  uint32 kernel_ext_info_size = 17;
-}
-message KernelHcclDef {
-  uint32 op_index = 8;
-  string hccl_type = 9;
-}
-message EventExDef {
-  uint32 op_index = 1;
-  uint32 event_type = 2;
-}
-message LogTimeStampDef {
-  uint64 logid = 1;
-  bool notify = 2;
-  uint32 flat = 3;
-}
-message MemcpyAsyncDef {
-  uint64 dst = 1;
-  uint64 dst_max = 2;
-  uint64 src = 3;
-  uint64 count = 4;
-  uint32 kind = 5;
-  uint32 op_index = 6;
-}
-message StreamSwitchDef {
-  uint32 op_index = 1;
-  uint32 true_stream_id = 2;
-  int64 value = 3;
-  uint64 value_ptr = 4;
-  uint32 data_type = 5;
-}
-message StreamActiveDef {
-  uint32 op_index = 1;
-  uint32 active_stream_id = 2;
-}
-message StreamSwitchNDef {
-  uint32 op_index = 1;
-  uint32 size = 2;
-  repeated int64 target_value = 3;
-  repeated uint32 true_stream_id = 4;
-  uint32 element_size = 5;
-  uint32 data_type = 6;
-}
-message LabelSetDef {
-  uint32 op_index = 1;
-  uint32 label_id = 2;
-  uint32 model_id = 3;
-}
-message LabelGotoExDef {
-  uint32 op_index = 1;
-  uint32 label_id = 2;
-  uint32 model_id = 3;
-}
-message LabelSwitchByIndexDef {
-  uint32 op_index = 1;
-  uint32 label_max = 2;
-}
```
```diff
@@ -106,6 +106,7 @@ target_link_libraries(ge_common PRIVATE
     c_sec
     error_manager
     slog
+    opt_feature
    -Wl,--as-needed
     json
     $<$<NOT:$<STREQUAL:${TARGET_SYSTEM_NAME},Android>>:-lrt>
```
```diff
@@ -33,7 +33,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpManager &DumpManager::GetIn
 bool DumpManager::NeedDoDump(const DumpConfig &dump_config, DumpProperties &dump_properties) {
   if (dump_config.dump_status.empty() && dump_config.dump_debug.empty()) {
-    dump_properties_map_.emplace(kInferSessionId, dump_properties);
+    dump_properties_map_[kInferSessionId] = dump_properties;
     GELOGI("Dump does not open");
     return false;
   }
@@ -41,7 +41,7 @@ bool DumpManager::NeedDoDump(const DumpConfig &dump_config, DumpProperties &dump
   if ((dump_config.dump_status == kDumpoff || dump_config.dump_status == kDumpOFF) &&
       dump_config.dump_debug == kDumpoff) {
     dump_properties.ClearDumpPropertyValue();
-    dump_properties_map_.emplace(kInferSessionId, dump_properties);
+    dump_properties_map_[kInferSessionId] = dump_properties;
     return false;
   }
   if (dump_config.dump_status == kDumpOn && dump_config.dump_debug == kDumpOn) {
```
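For context on the two hunks above: `std::map::emplace` is a no-op when the key already exists, so a second dump configuration for the same session was silently ignored; assignment through `operator[]` always overwrites. A minimal standalone sketch of that difference (hypothetical example, not GE code):

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <string>

int main() {
  std::map<uint64_t, std::string> props;
  props.emplace(0, "dump_on");    // inserts: key 0 is new
  props.emplace(0, "dump_off");   // does nothing: key 0 already exists
  std::cout << props[0] << '\n';  // prints "dump_on"

  props[0] = "dump_off";          // operator[] assignment always overwrites
  std::cout << props[0] << '\n';  // prints "dump_off"
  return 0;
}
```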
```diff
@@ -161,6 +161,7 @@ Status ExceptionDumper::DumpExceptionInfo(const std::vector<rtExceptionInfo> &ex
   uint64_t proto_size = dump_data.ByteSizeLong();
   std::unique_ptr<char[]> proto_msg(new (std::nothrow) char[proto_size]);
+  GE_CHECK_NOTNULL(proto_msg);
   bool ret = dump_data.SerializeToArray(proto_msg.get(), proto_size);
   if (!ret || proto_size == 0) {
     REPORT_INNER_ERROR("E19999", "Serialize proto to string fail");
```
```diff
@@ -104,7 +104,15 @@ void TBEPluginManager::ProcessSoFullName(vector<string> &file_list, string &caff
   }
 }
-void TBEPluginManager::FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path) {
+void TBEPluginManager::FindParserSo(const string &path, vector<string> &file_list,
+                                    string &caffe_parser_path, int recursive_depth) {
+  static const int kMaxRecursiveDepth = 20;  // For recursive depth protection
+  if (recursive_depth >= kMaxRecursiveDepth) {
+    GELOGW("Recursive depth is become %d, Please check input!", recursive_depth);
+    return;
+  }
   // Path, change to absolute path
   string real_path = RealPath(path.c_str());
   // Plugin path does not exist
@@ -138,7 +146,7 @@ void TBEPluginManager::FindParserSo(const string &path, vector<string> &file_lis
       ProcessSoFullName(file_list, caffe_parser_path, full_name, caffe_parser_so_suff, aicpu_so_suff,
                         aicpu_host_so_suff);
     } else {
-      FindParserSo(full_name, file_list, caffe_parser_path);
+      FindParserSo(full_name, file_list, caffe_parser_path, recursive_depth + 1);
     }
   }
   mmScandirFree(entries, ret);
@@ -57,7 +57,8 @@ class TBEPluginManager {
   static void ProcessSoFullName(vector<string> &file_list, string &caffe_parser_path, string &full_name,
                                 const string &caffe_parser_so_suff, const string &aicpu_so_suff,
                                 const string &aicpu_host_so_suff);
-  static void FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path);
+  static void FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path,
+                           int recursive_depth = 0);
   static void GetPluginSoFileList(const string &path, vector<string> &file_list, string &caffe_parser_path);
   static void GetCustomOpPath(std::string &customop_path);
   void LoadCustomOpLib();
```
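The `FindParserSo` change above bounds directory recursion with a depth counter that defaults to 0, so existing callers stay source-compatible while pathological directory trees (for example, symlink cycles) can no longer overflow the stack. A minimal sketch of the same guard pattern (hypothetical example with invented names, not the GE implementation):

```cpp
#include <cstdio>
#include <string>
#include <vector>

static const int kMaxDepth = 20;  // recursion depth protection, as in the diff

// A toy directory tree; a malformed or cyclic structure flattened into one
// could otherwise drive the traversal arbitrarily deep.
struct Dir {
  std::string name;
  std::vector<Dir> children;
};

void Collect(const Dir &dir, std::vector<std::string> &names, int depth = 0) {
  if (depth >= kMaxDepth) {
    std::fprintf(stderr, "max depth %d reached at %s, pruning\n", depth, dir.name.c_str());
    return;  // refuse to descend further instead of recursing unboundedly
  }
  names.push_back(dir.name);
  for (const Dir &child : dir.children) {
    Collect(child, names, depth + 1);  // pass the incremented depth down
  }
}

int main() {
  Dir root{"plugin", {{"custom", {{"aicpu", {}}}}}};
  std::vector<std::string> names;
  Collect(root, names);  // callers never mention depth, as in the header change
  for (const auto &n : names) std::printf("%s\n", n.c_str());
  return 0;
}
```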
@@ -22,6 +22,7 @@
 #include "graph/load/graph_loader.h"
 #include "init/gelib.h"
 #include "framework/common/ge_inner_error_codes.h"
+#include "model/ge_model.h"
 
 namespace {
 const uint32_t kDeviceListIndex = 3;
@@ -42,6 +43,10 @@ const std::map<ProfCommandHandleType, std::string> kProfCommandTypeMap = {
     {kProfCommandhandleFinalize, kProfilingFinalize},
     {kProfCommandhandleModelSubscribe, kProfModelSubscribe},
     {kProfCommandhandleModelUnsubscribe, kProfModelUnsubscribe}};
+
+const uint64_t kModelId = ge::INVALID_MODEL_ID;
+const uint16_t kStepStart = 0;
+const uint16_t kStepEnd = 1;
 }  // namespace
 
 bool TransProfConfigToParam(const ProfCommandHandleData &profCommand, vector<string> &prof_config_params) {
@@ -216,3 +221,36 @@ ge::Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t le
   return ge::SUCCESS;
 }
+
+ge::Status ProfSetStepInfo(uint64_t index_id, uint16_t tag_id, rtStream_t stream) {
+  static bool is_first_run = true;
+  int32_t device_id = 0;
+  rtError_t rt_ret = rtGetDevice(&device_id);
+  if (rt_ret != RT_ERROR_NONE) {
+    GELOGE(rt_ret, "[Get][LogicDeviceId]Failed, ret 0x%X", rt_ret);
+    REPORT_CALL_ERROR("E19999", "Get logic device id failed, ret 0x%X", rt_ret);
+    return ge::FAILED;
+  }
+  if (is_first_run && tag_id == kStepStart) {
+    GE_CHK_STATUS_RET_NOLOG(ge::ProfilingManager::Instance().ProfileStepInfo(index_id,
+                                                                             kModelId,
+                                                                             tag_id,
+                                                                             stream,
+                                                                             device_id));
+    is_first_run = false;
+    return ge::SUCCESS;
+  }
+  if (!is_first_run && tag_id == kStepEnd) {
+    GE_CHK_STATUS_RET_NOLOG(ge::ProfilingManager::Instance().ProfileStepInfo(index_id,
+                                                                             kModelId,
+                                                                             tag_id,
+                                                                             stream,
+                                                                             device_id));
+    is_first_run = true;
+    return ge::SUCCESS;
+  }
+  GELOGE(ge::FAILED, "Param tag_id:%u invalid when is_first_run is %d", tag_id, is_first_run);
+  REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"value", "parameter", "reason"}),
+                     std::vector<std::string>({std::to_string(tag_id), "tag_id",
+                                               "tag_id must be 0 (step start) on the first call and 1 (step end) on the paired second call"}));
+  return ge::FAILED;
+}
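ProfSetStepInfo therefore expects strictly alternating calls: tag 0 opens a step, tag 1 closes it, and the static is_first_run flag enforces the pairing. A hedged usage sketch (stream acquisition elided; error handling reduced to a status check):

rtStream_t stream = nullptr;
// ... obtain a runtime stream, e.g. via rtStreamCreate, before profiling ...
const uint64_t index_id = 1;  // caller-chosen iteration index
if (ProfSetStepInfo(index_id, 0, stream) != ge::SUCCESS) {  // kStepStart
  // handle failure
}
// ... execute one training/inference step on the stream ...
if (ProfSetStepInfo(index_id, 1, stream) != ge::SUCCESS) {  // kStepEnd
  // handle failure
}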
@@ -1,193 +0,0 @@ | |||
syntax = "proto3"; | |||
package ge.proto; | |||
enum DataType | |||
{ | |||
DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. | |||
DT_FLOAT = 1; // float type | |||
DT_FLOAT16 = 2; // fp16 type | |||
DT_INT8 = 3; // int8 type | |||
DT_UINT8 = 4; // uint8 type | |||
DT_INT16 = 5; // int16 type | |||
DT_UINT16 = 6; // uint16 type | |||
    DT_INT32 = 7;           // int32 type
DT_INT64 = 8; // int64 type | |||
DT_UINT32 = 9; // unsigned int32 | |||
DT_UINT64 = 10; // unsigned int64 | |||
DT_BOOL = 11; // bool type | |||
DT_DOUBLE = 12; // double type | |||
DT_STRING = 13; // string type | |||
DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ | |||
DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ | |||
DT_COMPLEX64 = 16; // complex64 type | |||
DT_COMPLEX128 = 17; // complex128 type | |||
DT_QINT8 = 18; // qint8 type | |||
DT_QINT16 = 19; // qint16 type | |||
DT_QINT32 = 20; // qint32 type | |||
DT_QUINT8 = 21; // quint8 type | |||
DT_QUINT16 = 22; // quint16 type | |||
DT_RESOURCE = 23; // resource type | |||
DT_STRING_REF = 24; // string_ref type | |||
DT_DUAL = 25; /**< dual output type */ | |||
DT_VARIANT = 26; // variant type | |||
DT_BF16 = 27; // bf16 type | |||
DT_INT4 = 28; // int4 type | |||
} | |||
message AttrDef | |||
{ | |||
message ListValue | |||
{ | |||
enum ListValueType{ | |||
VT_LIST_NONE = 0; | |||
VT_LIST_STRING = 1; | |||
VT_LIST_INT = 2; | |||
VT_LIST_FLOAT = 3; | |||
VT_LIST_BOOL = 4; | |||
VT_LIST_BYTES = 5; | |||
VT_LIST_TENSOR_DESC = 6; | |||
VT_LIST_TENSOR = 7; | |||
VT_LIST_GRAPH = 8; | |||
VT_LIST_NAMED_ATTRS = 9; | |||
VT_LIST_DATA_TYPE = 10; | |||
} | |||
repeated bytes s = 2; // "list(string)" | |||
repeated int64 i = 3; // "list(int)" | |||
repeated float f = 4; // "list(float)" | |||
repeated bool b = 5; // "list(bool)" | |||
repeated bytes bt = 7; | |||
repeated TensorDescriptor td = 8; | |||
repeated TensorDef t = 9; | |||
repeated GraphDef g = 10; | |||
repeated NamedAttrs na = 11; | |||
repeated int64 dt = 12; // list ge::DataType | |||
ListValueType val_type = 20; | |||
} | |||
message ListListInt{ | |||
message ListInt{ | |||
repeated int64 list_i = 1; // list int | |||
} | |||
repeated ListInt list_list_i = 1; // list list int | |||
} | |||
oneof value | |||
{ | |||
bytes s = 2; // "string" | |||
int64 i = 3; // "int" | |||
float f = 4; // "float" | |||
bool b = 5; // "bool" | |||
bytes bt = 7; | |||
ListValue list = 1; // any "list(...)" | |||
NamedAttrs func = 10; // Used to support attr nesting | |||
TensorDescriptor td = 11; // GeTensorDesc type | |||
TensorDef t = 12; // GeTensor type | |||
GraphDef g = 13; // Graph type | |||
ListListInt list_list_int = 14; // List List Int type | |||
int64 dt = 15; // ge::DataType | |||
} | |||
} | |||
// A list of attr names and their values. The whole list is attached | |||
// with a string name. E.g., MatMul[T=float]. | |||
message NamedAttrs | |||
{ | |||
string name = 1; | |||
map<string, AttrDef> attr = 2; | |||
} | |||
// Shape / dimension description, using row-major order | |||
message ShapeDef | |||
{ | |||
repeated int64 dim = 1; // Size of each dimension | |||
} | |||
// Multidimensional data description | |||
message TensorDescriptor | |||
{ | |||
string name = 1; // Optional parameter, tensor name | |||
DataType dtype = 2; // tensor datatype | |||
ShapeDef shape = 3; // Shape / dimension | |||
string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" | |||
bool has_out_attr = 9; | |||
int64 size = 10; | |||
int64 weight_size = 11; | |||
bool reuse_input = 12; | |||
bool output_tensor = 13; | |||
string device_type = 14; | |||
bool input_tensor =15; | |||
int64 real_dim_cnt = 16; | |||
int64 reuse_input_index = 17; | |||
int64 data_offset = 18; | |||
int64 cmps_size = 19; | |||
string cmps_tab = 20; | |||
int64 cmps_tab_offset = 21; | |||
map<string, AttrDef> attr = 5; // Set of extra parameter fields | |||
} | |||
// GeTensor definition | |||
message TensorDef | |||
{ | |||
TensorDescriptor desc = 1; // Tensor description | |||
bytes data = 2; // Tensor data | |||
} | |||
// Operator description | |||
message OpDef | |||
{ | |||
string name = 1; // name | |||
string type = 2; // type | |||
repeated string input = 5; // input original op name + outgoing index. op_name:index | |||
map<string, AttrDef> attr = 10; // Set of operator parameter fields | |||
bool has_out_attr = 20; | |||
int64 id = 21; | |||
int64 stream_id =22; | |||
repeated string input_name = 23; | |||
repeated string src_name = 24; | |||
repeated int64 src_index = 25; | |||
repeated string dst_name = 26; | |||
repeated int64 dst_index = 27; | |||
repeated int64 input_i = 28; | |||
repeated int64 output_i = 29; | |||
repeated int64 workspace = 30; | |||
repeated int64 workspace_bytes = 31; | |||
repeated bool is_input_const = 32; | |||
repeated TensorDescriptor input_desc = 33; | |||
repeated TensorDescriptor output_desc = 34; | |||
repeated string subgraph_name = 35; | |||
} | |||
// Graph definition | |||
message GraphDef | |||
{ | |||
string name = 1; // name | |||
repeated string input = 4; // Graph input | |||
repeated string output = 5; // Graph output | |||
repeated OpDef op = 6; // List of operators | |||
map<string, AttrDef> attr = 11; // Extended field | |||
} | |||
// model definition | |||
message ModelDef | |||
{ | |||
string name = 1; // name | |||
 uint32 version = 2;         // IR proto version
string custom_version = 3; // User model version number, passed in by user | |||
 repeated GraphDef graph = 7;   // Graph definition; graph[0] is the main graph of the model
map<string, AttrDef> attr = 11; // Extended field | |||
} | |||
@@ -1,140 +0,0 @@ | |||
syntax = "proto3"; | |||
package domi; | |||
message InsertNewOps { | |||
repeated AippOpParams aipp_op = 1; | |||
repeated MultiShapeOpParams multi_shape_op = 2; | |||
} | |||
message AippOpParams { | |||
enum InputFormat { | |||
UNDEFINED = 0; | |||
YUV420SP_U8 = 1; | |||
XRGB8888_U8 = 2; | |||
RGB888_U8 = 3; | |||
YUV400_U8 = 4; | |||
NC1HWC0DI_FP16 = 5; | |||
NC1HWC0DI_S8 = 6; | |||
ARGB8888_U8 = 7; | |||
YUYV_U8 = 8; | |||
YUV422SP_U8 = 9; | |||
AYUV444_U8 = 10; | |||
RAW10 = 11; | |||
RAW12 = 12; | |||
RAW16 = 13; | |||
RAW24 = 14; | |||
RGB16 = 15; | |||
RGB20 = 16; | |||
RGB24 = 17; | |||
RGB8_IR = 18; | |||
RGB16_IR = 19; | |||
RGB24_IR = 20; | |||
} | |||
enum AippMode { | |||
undefined = 0; | |||
static = 1; | |||
dynamic = 2; | |||
} | |||
 // AIPP mode: distinguishes static AIPP from dynamic AIPP
 AippMode aipp_mode = 1;
 // related_input_rank is required; an integer >= 0 and <= the number of input Data operators, default 0.
 // It selects which model input AIPP is applied to. For example, if the model has two inputs and AIPP
 // should be applied to the second one, set related_input_rank to 1.
 uint32 related_input_rank = 2;
 // related_input_name is optional; the top name of the Data node where AIPP is inserted
 string related_input_name = 6;
 // input_edge_idx is optional; an integer >= 0.
 // It allows different AIPP processing for different outputs of the Data operator. If unset, AIPP is
 // applied to every output edge of the model input selected by related_input_rank.
 // Each configured value must be <= the number of output edges of the Data operator.
 repeated uint32 input_edge_idx = 3;
 // [Begin] Dynamic AIPP parameters; ignored when static AIPP is configured
 uint32 max_src_image_size = 4;
 // Whether rotation is supported. Disabled by default; enabling it costs extra memory and performance.
 bool support_rotation = 5;
 // [End] Dynamic AIPP parameters
 // [Begin] Static AIPP parameters; ignored when dynamic AIPP is configured
InputFormat input_format = 51; | |||
bool csc_switch = 52; | |||
float cpadding_value = 53; | |||
bool rbuv_swap_switch = 54; | |||
bool ax_swap_switch = 55; | |||
bool single_line_mode = 56; | |||
int32 src_image_size_w = 57; | |||
int32 src_image_size_h = 58; | |||
bool crop = 59; | |||
int32 load_start_pos_w = 60; | |||
int32 load_start_pos_h = 61; | |||
int32 crop_size_w = 62; | |||
int32 crop_size_h = 63; | |||
bool resize = 64; | |||
int32 resize_output_w = 65; | |||
int32 resize_output_h = 66; | |||
bool padding = 67; | |||
int32 left_padding_size = 68; | |||
int32 right_padding_size = 69; | |||
int32 top_padding_size = 70; | |||
int32 bottom_padding_size = 71; | |||
float padding_value = 72; | |||
int32 mean_chn_0 = 10; | |||
int32 mean_chn_1 = 11; | |||
int32 mean_chn_2 = 12; | |||
int32 mean_chn_3 = 19; | |||
float min_chn_0 = 13; | |||
float min_chn_1 = 14; | |||
float min_chn_2 = 15; | |||
float min_chn_3 = 20; | |||
repeated float var_reci_chn_0 = 16; | |||
repeated float var_reci_chn_1 = 17; | |||
repeated float var_reci_chn_2 = 18; | |||
repeated float var_reci_chn_3 = 21; | |||
repeated int32 matrix_r0c0 = 30; | |||
repeated int32 matrix_r0c1 = 31; | |||
repeated int32 matrix_r0c2 = 32; | |||
repeated int32 matrix_r1c0 = 33; | |||
repeated int32 matrix_r1c1 = 34; | |||
repeated int32 matrix_r1c2 = 35; | |||
repeated int32 matrix_r2c0 = 36; | |||
repeated int32 matrix_r2c1 = 37; | |||
repeated int32 matrix_r2c2 = 38; | |||
repeated int32 output_bias_0 = 39; | |||
repeated int32 output_bias_1 = 40; | |||
repeated int32 output_bias_2 = 41; | |||
repeated int32 input_bias_0 = 42; | |||
repeated int32 input_bias_1 = 43; | |||
repeated int32 input_bias_2 = 44; | |||
 // [End] Static AIPP parameters
// The n number that is used for raw/rgbir data into f16 transformation. | |||
// The transformation equation is x/(2^n). If set to 0, no transform is performed. | |||
uint32 raw_rgbir_to_f16_n = 45; | |||
} | |||
message MultiShapeOpParams { | |||
enum MultiShapeMode { | |||
     batch = 0;       // dynamic batch
     resolution = 1;  // dynamic resolution, reserved for extension
   }
   MultiShapeMode mode = 1;              // operator mode
   uint32 related_input_rank = 2;        // which model input the new operator is attached to
   repeated uint32 batch_list = 11;      // batch_list values; the count must be between 2 and 8
} |
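As a concrete illustration of the schema above, a hedged C++ sketch filling in a static AIPP configuration (assumes protoc output insert_op.pb.h; the cast sidesteps the `static` enum value name, which collides with a C++ keyword in generated code):

#include "insert_op.pb.h"  // hypothetical protoc output for the file above

int main() {
  domi::InsertNewOps ops;
  domi::AippOpParams *aipp = ops.add_aipp_op();
  aipp->set_aipp_mode(static_cast<domi::AippOpParams_AippMode>(1));  // 1 == static AIPP
  aipp->set_related_input_rank(0);  // apply AIPP to the first model input
  aipp->set_input_format(domi::AippOpParams_InputFormat_YUV420SP_U8);
  aipp->set_csc_switch(true);       // enable color-space conversion
  aipp->set_src_image_size_w(224);
  aipp->set_src_image_size_h(224);
  return 0;
}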
@@ -1,396 +0,0 @@ | |||
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. | |||
* | |||
* This program is free software; you can redistribute it and/or modify | |||
* it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. | |||
* | |||
* This program is distributed in the hope that it will be useful, | |||
* but WITHOUT ANY WARRANTY; without even the implied warranty of | |||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |||
* Apache License for more details at | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
*/ | |||
syntax = "proto3"; | |||
package domi; | |||
enum TargetType | |||
{ | |||
MINI = 0; | |||
TINY = 1; | |||
LITE = 2; | |||
} | |||
// offline model | |||
message ModelDef { | |||
string name = 1; | |||
uint32 version = 2; | |||
uint64 memory_size = 10; | |||
uint32 stream_num = 11; | |||
uint32 event_num = 12; | |||
uint64 weight_size = 13; | |||
uint32 label_num = 15; | |||
repeated OpDef op = 20; | |||
TargetType target_type = 23; | |||
map<string, AttrDef> attr = 30; | |||
}; | |||
// operator define | |||
message OpDef { | |||
string name = 1; | |||
string type = 2; | |||
uint32 id = 3; | |||
uint32 stream_id = 4; | |||
repeated string input_name = 5; | |||
repeated string src_name = 8; | |||
repeated int32 src_index = 9; | |||
repeated int64 input = 10; | |||
repeated int64 output = 11; | |||
repeated TensorDescriptor input_desc = 12; | |||
repeated TensorDescriptor output_desc = 13; | |||
repeated WeightDef weights = 14; | |||
repeated string dst_name = 15; | |||
repeated int32 dst_index = 16; | |||
repeated int64 workspace = 20; | |||
repeated uint32 workspace_bytes = 21; | |||
repeated string weight_name = 22; | |||
repeated bool is_input_const = 23; | |||
map<string, AttrDef> attr = 30; | |||
QuantizeFactorParams quantize_factor = 31; | |||
oneof op_params { | |||
// start at 100 here | |||
SendOpParams sender_param = 100; | |||
RecvOpParams receiver_param = 200; | |||
ConvolutionOpParams convolution_param = 300; | |||
PoolingOpParams pooling_param = 400; | |||
EltwiseOpParams eltwise_param = 500; | |||
BatchNormOpParams batchnorm_param = 600; | |||
ScaleOpParams scale_param = 700; | |||
FullConnectionOpParams full_connection_param = 800; | |||
SoftmaxOpParams softmax_param = 900; | |||
ActivationOpParams activation_param = 1000; | |||
ReshapeOpParams reshape_param = 1100; | |||
} | |||
}; | |||
message SendOpParams { | |||
uint32 event_id = 1; | |||
}; | |||
message RecvOpParams { | |||
uint32 event_id = 1; | |||
}; | |||
enum QuantizeScaleType | |||
{ | |||
VECTOR_SCALE = 0; | |||
SCALAR_SCALE = 1; | |||
} | |||
enum QuantizeScaleMode | |||
{ | |||
NORMAL_MODE = 0; | |||
SQRT_MODE = 1; | |||
} | |||
enum QuantizeAlgorithm | |||
{ | |||
NON_OFFSET_ALGO = 0; | |||
HALF_OFFSET_ALGO = 1; | |||
ALL_OFFSET_ALGO = 2; | |||
} | |||
message QuantizeFactor | |||
{ | |||
QuantizeScaleMode scale_mode = 1; | |||
bytes scale_value = 2; | |||
int64 scale_offset = 3; | |||
bytes offset_data_value = 4; | |||
int64 offset_data_offset = 5; | |||
bytes offset_weight_value = 6; | |||
int64 offset_weight_offset = 7; | |||
bytes offset_pad_value = 8; | |||
int64 offset_pad_offset = 9; | |||
}; | |||
message QuantizeCalcFactor | |||
{ | |||
bytes offsetw = 1; | |||
int64 offsetw_offset = 2; | |||
bytes offsetd = 3; | |||
int64 offsetd_offset = 4; | |||
bytes scalereq = 5; | |||
int64 scaledreq_offset = 6; | |||
bytes offsetdnext = 7; | |||
int64 offsetdnext_offset = 8; | |||
} | |||
message QuantizeFactorParams | |||
{ | |||
QuantizeAlgorithm quantize_algo = 1; | |||
QuantizeScaleType scale_type = 2; | |||
QuantizeFactor quantize_param = 3; | |||
QuantizeFactor dequantize_param = 4; | |||
QuantizeFactor requantize_param = 5; | |||
QuantizeCalcFactor quantizecalc_param = 6; | |||
}; | |||
message ConvolutionOpParams { | |||
int32 mode = 1; | |||
int32 algo = 2; | |||
int32 pad_mode = 3; | |||
uint32 group = 4; | |||
uint32 num_output = 5; | |||
repeated uint32 pad = 10; | |||
repeated uint32 stride = 11; | |||
repeated uint32 dilation = 12; | |||
repeated uint32 kernel = 13; | |||
float alpha = 20; | |||
float beta = 21; | |||
WeightDef filter = 40; | |||
WeightDef bias = 41; | |||
bool relu_flag = 62; | |||
repeated uint32 adj = 70; | |||
repeated uint32 target_shape = 71; | |||
repeated uint32 before_pad = 72; | |||
}; | |||
message PoolingOpParams { | |||
int32 mode = 1; | |||
int32 nan_opt = 2; | |||
int32 pad_mode = 3; | |||
bool global_pooling = 4; | |||
repeated uint32 window = 10; | |||
repeated uint32 pad = 11; | |||
repeated uint32 stride = 12; | |||
bool ceil_mode = 13; | |||
int32 data_mode = 14; | |||
float alpha = 20; | |||
float beta = 21; | |||
repeated uint32 before_pad = 22; | |||
}; | |||
message EltwiseOpParams { | |||
int32 mode = 1; | |||
repeated float coeff = 2; | |||
float alpha = 3; | |||
float beta = 4; | |||
repeated WeightDef weight = 5; | |||
bool relu_flag = 6; | |||
}; | |||
message ActivationOpParams { | |||
int32 mode = 1; | |||
float coef = 2; | |||
float alpha = 3; | |||
float beta = 4; | |||
}; | |||
message BatchNormOpParams { | |||
int32 mode = 1; | |||
float alpha = 2; | |||
float beta = 3; | |||
   double epsilon = 4;                 // optional, [default = 1e-5]
   bool use_global_stats = 5;          // optional, true by default (testing mode)
   float moving_average_fraction = 6;  // optional, [default = .999]
WeightDef estimated_mean = 7; | |||
WeightDef estimated_variance = 8; | |||
WeightDef scale = 9; | |||
WeightDef bias = 10; | |||
}; | |||
message ScaleOpParams { | |||
WeightDef scale = 1; | |||
WeightDef bias = 2; | |||
}; | |||
message ReshapeOpParams { | |||
float alpha = 1; | |||
float beta = 2; | |||
ShapeDef shape = 3; | |||
int32 axis = 4; | |||
int32 num_axes = 5; | |||
int32 format = 6; | |||
}; | |||
message SoftmaxOpParams { | |||
int32 algo = 1; | |||
int32 mode = 2; | |||
float alpha = 3; | |||
float beta = 4; | |||
}; | |||
message FullConnectionOpParams { | |||
WeightDef filter = 1; | |||
WeightDef bias = 2; | |||
uint32 num_output = 3; | |||
bool relu_flag = 12; | |||
}; | |||
message FlattenOpParams { | |||
float alpha = 1; | |||
float beta = 2; | |||
int32 start_axis = 3; | |||
int32 end_axis = 4; | |||
} | |||
message AddLimitedOpParams { | |||
float alpha = 1; | |||
float beta = 2; | |||
int32 axis = 3; | |||
bool broadcast = 4; | |||
repeated WeightDef weight = 10; | |||
}; | |||
message MulLimitedOpParams { | |||
float alpha = 1; | |||
float beta = 2; | |||
int32 axis = 3; | |||
bool broadcast = 4; | |||
repeated WeightDef weight = 10; | |||
}; | |||
message AddOpParams { | |||
float alpha = 1; | |||
float beta = 2; | |||
repeated WeightDef weight = 10; | |||
}; | |||
message MulOpParams { | |||
float alpha = 1; | |||
float beta = 2; | |||
repeated WeightDef weight = 10; | |||
}; | |||
message SubOpParams { | |||
float alpha = 1; | |||
float beta = 2; | |||
repeated WeightDef weight = 10; | |||
}; | |||
message BiasAddOpParams { | |||
float alpha = 1; | |||
float beta = 2; | |||
WeightDef bias = 10; | |||
}; | |||
message MatMulOpParams { | |||
float alpha = 1; | |||
float beta = 2; | |||
bool transposeX = 3; | |||
bool transposeW = 4; | |||
WeightDef filter = 10; | |||
WeightDef bias = 12; | |||
}; | |||
message RsqrtOpParams { | |||
float alpha = 1; | |||
float beta = 2; | |||
}; | |||
message WeightDef { | |||
int32 format = 1; | |||
int32 data_type = 2; | |||
ShapeDef shape = 3; | |||
bytes data = 4; | |||
int64 data_offset = 5; | |||
uint32 cmps_size = 6; | |||
bytes cmps_tab = 7; | |||
int64 cmps_tab_offset = 10; | |||
CompressInfo cmps_info = 8; | |||
AllOffsetQuantizeInfo alloffset_quantize_info = 11; | |||
} | |||
message ShapeDef { | |||
repeated int64 dim = 1; | |||
} | |||
enum DeviceType { | |||
NPU = 0; // In default, we will use NPU. | |||
CPU = 1; // CPU | |||
} | |||
message AllOffsetQuantizeInfo { | |||
float scale = 1; | |||
int32 offset = 2; | |||
} | |||
message TensorDescriptor { | |||
int32 format = 1; | |||
int32 data_type = 2; | |||
repeated int64 dim = 3; | |||
uint32 size = 4; | |||
bool reuse_input = 5; | |||
bool output_tensor = 7; | |||
DeviceType device_type = 8; | |||
bool input_tensor = 9; | |||
uint32 real_dim_cnt = 10; | |||
uint32 reuse_input_index = 11; | |||
AllOffsetQuantizeInfo alloffset_quantize_info = 12; | |||
} | |||
message CompressInfo { | |||
int32 blockRow = 1; // block row | |||
int32 blockCol = 2; // block col | |||
int32 fractalK = 3; // fractal K | |||
int32 fractalN = 4; // fractal N | |||
int32 lastFractalK = 5; // K of last fractal | |||
int32 lastFractalN = 6; // N of last fractal | |||
int32 cubeSize = 7; // cube's length | |||
   int32 loadDir = 8;       // data load direction: 0 = column load, 1 = row load
} | |||
message AttrDef { | |||
message ListValue { | |||
repeated string s = 2; // "list(string)" | |||
repeated int64 i = 3 [packed = true]; // "list(int)" | |||
repeated float f = 4 [packed = true]; // "list(float)" | |||
repeated bool b = 5 [packed = true]; // "list(bool)" | |||
repeated uint32 u = 6 [packed = true]; // "list(uint)" | |||
repeated bytes bt = 7; | |||
} | |||
oneof value { | |||
string s = 2; // "string" | |||
int64 i = 3; // "int" | |||
float f = 4; // "float" | |||
bool b = 5; // "bool" | |||
uint32 u = 6; // "uint32" | |||
bytes bt = 7; | |||
ListValue list = 1; // any "list(...)" | |||
NamedAttrs func = 10; | |||
} | |||
} | |||
// A list of attr names and their values. The whole list is attached | |||
// with a string name. E.g., MatMul[T=float]. | |||
message NamedAttrs { | |||
string name = 1; | |||
map<string, AttrDef> attr = 2; | |||
} | |||
@@ -1,75 +0,0 @@ | |||
syntax = "proto3"; | |||
package toolkit.aicpu.dump; | |||
message Shape { | |||
repeated uint64 dim = 1; | |||
} | |||
message Output { | |||
int32 data_type = 1; | |||
int32 format = 2; | |||
Shape shape = 3; | |||
uint64 address = 4; | |||
string original_name = 5; | |||
int32 original_output_index = 6; | |||
int32 original_output_data_type = 7; | |||
int32 original_output_format = 8; | |||
uint64 size = 9; | |||
Shape origin_shape = 10; | |||
} | |||
message Input { | |||
int32 data_type =1; | |||
int32 format = 2; | |||
Shape shape = 3; | |||
uint64 address = 4; | |||
uint64 size = 5; | |||
Shape origin_shape = 6; | |||
} | |||
enum BufferType { | |||
L1 = 0; | |||
} | |||
message OpBuffer { | |||
BufferType buffer_type = 1; | |||
uint64 address = 2; | |||
uint64 size = 3; | |||
} | |||
message Op { | |||
string op_name = 1; | |||
string op_type = 2; | |||
} | |||
message Task { | |||
uint32 task_id = 1; | |||
uint32 stream_id = 2; | |||
Op op = 3; | |||
repeated Output output = 4; | |||
bool end_graph = 5; | |||
repeated Input input = 6; | |||
repeated OpBuffer buffer = 7; | |||
} | |||
message OpMappingInfo { | |||
string dump_path = 1; | |||
oneof model_name_param { | |||
string model_name = 2; | |||
} | |||
oneof model_id_param { | |||
uint32 model_id = 3; | |||
} | |||
oneof step_id { | |||
uint64 step_id_addr = 4; | |||
} | |||
oneof iterations_per_loop { | |||
uint64 iterations_per_loop_addr = 5; | |||
} | |||
oneof loop_cond { | |||
uint64 loop_cond_addr = 6; | |||
} | |||
uint32 flag = 7; // 0x01 load, 0x00 unload | |||
repeated Task task = 8; | |||
string dump_step = 9; | |||
} |
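A brief hedged sketch of how a loader might populate OpMappingInfo when registering one task for dumping (assumes protoc output from the file above; names are illustrative):

#include "dump_task.pb.h"  // hypothetical protoc output for the file above

int main() {
  toolkit::aicpu::dump::OpMappingInfo info;
  info.set_dump_path("/var/log/npu/dump");
  info.set_model_name("demo_model");
  info.set_model_id(1);
  info.set_flag(0x01);  // 0x01 == load, 0x00 == unload (see the comment above)

  toolkit::aicpu::dump::Task *task = info.add_task();
  task->set_task_id(3);
  task->set_stream_id(0);
  task->mutable_op()->set_op_name("conv1");
  task->mutable_op()->set_op_type("Conv2D");
  return 0;
}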
@@ -1,179 +0,0 @@ | |||
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. | |||
* | |||
* This program is free software; you can redistribute it and/or modify | |||
* it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. | |||
* | |||
* This program is distributed in the hope that it will be useful, | |||
* but WITHOUT ANY WARRANTY; without even the implied warranty of | |||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |||
* Apache License for more details at | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
*/ | |||
syntax = "proto3"; | |||
package domi; | |||
message ModelTaskDef { | |||
string version = 1; | |||
map<string, string> attr = 9; // Extended field | |||
repeated TaskDef task = 10; | |||
uint64 memory_size = 11; | |||
uint32 stream_num = 12; | |||
uint32 event_num = 13; | |||
uint64 weight_size = 14; | |||
repeated bytes op = 15; // input/output opdef in bytes | |||
uint64 base_addr = 16; // base addr | |||
uint64 weight_addr = 17; // weight addr | |||
uint32 batch_num = 18; | |||
} | |||
message TaskDef { | |||
uint32 id = 1; | |||
uint32 type = 2; | |||
uint32 stream_id = 10; | |||
uint32 event_id = 11; | |||
KernelDef kernel = 20; | |||
KernelExDef kernel_ex = 21; | |||
KernelHcclDef kernel_hccl = 25; | |||
EventExDef event_ex = 26; | |||
LogTimeStampDef log_timestamp = 28; | |||
uint32 label_id = 30; | |||
MemcpyAsyncDef memcpy_async = 31; | |||
StreamSwitchDef stream_switch = 32; | |||
StreamActiveDef stream_active = 33; | |||
bytes private_def = 34; | |||
uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future | |||
StreamSwitchNDef stream_switch_n = 36; | |||
LabelSetDef label_set = 37; | |||
LabelGotoExDef label_goto_ex = 38; | |||
LabelSwitchByIndexDef label_switch_by_index = 39; | |||
KernelDefWithHandle kernel_with_handle = 40; | |||
} | |||
message KernelDef { | |||
KernelContext context = 1; | |||
string stub_func = 10; | |||
uint32 block_dim = 11; | |||
uint32 args_size = 12; | |||
bytes args = 13; | |||
bytes sm_desc = 14; | |||
bytes flowtable = 15; | |||
string so_name = 16; | |||
string kernel_name = 17; | |||
bytes kernel_ext_info = 18; | |||
uint32 kernel_ext_info_size = 19; | |||
} | |||
message KernelDefWithHandle { | |||
KernelContext context = 1; | |||
uint64 handle = 10; | |||
string dev_func = 11; | |||
uint32 block_dim = 12; | |||
uint32 args_size = 13; | |||
bytes args = 14; | |||
bytes sm_desc = 15; | |||
string original_kernel_key = 16; | |||
string node_info = 17; | |||
} | |||
message KernelContext { | |||
uint32 kernel_type = 1; | |||
uint32 op_id = 2; // OP type in CCE | |||
uint32 kernel_func_id = 3; | |||
uint32 op_index = 4; // TE/Custom operator | |||
bool is_flowtable = 5; // Identify whether args is a flowtable structure | |||
bytes args_offset = 6; // args offset information | |||
uint32 args_count = 7; // args count | |||
repeated uint32 origin_op_index = 8; | |||
} | |||
message KernelExDef { | |||
uint32 flags = 1; | |||
uint32 op_index = 4; | |||
uint32 args_size = 12; | |||
bytes args = 13; | |||
bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput | |||
uint32 task_info_size = 15; | |||
bytes kernel_ext_info = 16; | |||
uint32 kernel_ext_info_size = 17; | |||
} | |||
message KernelHcclDef { | |||
uint32 op_index = 8; | |||
string hccl_type = 9; | |||
} | |||
message EventExDef { | |||
uint32 op_index = 1; | |||
uint32 event_type = 2; | |||
} | |||
message LogTimeStampDef { | |||
uint64 logid = 1; | |||
bool notify = 2; | |||
uint32 flat = 3; | |||
} | |||
message MemcpyAsyncDef { | |||
uint64 dst = 1; | |||
uint64 dst_max = 2; | |||
uint64 src = 3; | |||
uint64 count = 4; | |||
uint32 kind = 5; | |||
uint32 op_index = 6; | |||
} | |||
message StreamSwitchDef { | |||
uint32 op_index = 1; | |||
uint32 true_stream_id = 2; | |||
int64 value = 3; | |||
uint64 value_ptr = 4; | |||
uint32 data_type = 5; | |||
} | |||
message StreamActiveDef { | |||
uint32 op_index = 1; | |||
uint32 active_stream_id = 2; | |||
} | |||
message StreamSwitchNDef { | |||
uint32 op_index = 1; | |||
uint32 size = 2; | |||
repeated int64 target_value = 3; | |||
repeated uint32 true_stream_id = 4; | |||
uint32 element_size = 5; | |||
uint32 data_type = 6; | |||
} | |||
message LabelSetDef { | |||
uint32 op_index = 1; | |||
uint32 label_id = 2; | |||
uint32 model_id = 3; | |||
} | |||
message LabelGotoExDef { | |||
uint32 op_index = 1; | |||
uint32 label_id = 2; | |||
uint32 model_id = 3; | |||
} | |||
message LabelSwitchByIndexDef { | |||
uint32 op_index = 1; | |||
uint32 label_max = 2; | |||
} |
@@ -1,70 +0,0 @@ | |||
/** | |||
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow | |||
* | |||
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. | |||
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). | |||
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. | |||
*/ | |||
syntax = "proto3"; | |||
package domi.tensorflow; | |||
option cc_enable_arenas = true; | |||
option java_outer_classname = "AttrValueProtos"; | |||
option java_multiple_files = true; | |||
option java_package = "org.tensorflow.framework"; | |||
import "tensor.proto"; | |||
import "tensor_shape.proto"; | |||
import "types.proto"; | |||
// Protocol buffer representing the value for an attr used to configure an Op. | |||
// Comment indicates the corresponding attr type. Only the field matching the | |||
// attr type may be filled. | |||
message AttrValue { | |||
// LINT.IfChange | |||
message ListValue { | |||
repeated bytes s = 2; // "list(string)" | |||
repeated int64 i = 3 [packed = true]; // "list(int)" | |||
repeated float f = 4 [packed = true]; // "list(float)" | |||
repeated bool b = 5 [packed = true]; // "list(bool)" | |||
repeated DataType type = 6 [packed = true]; // "list(type)" | |||
repeated TensorShapeProto shape = 7; // "list(shape)" | |||
repeated TensorProto tensor = 8; // "list(tensor)" | |||
repeated NameAttrList func = 9; // "list(attr)" | |||
} | |||
// LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) | |||
oneof value { | |||
bytes s = 2; // "string" | |||
int64 i = 3; // "int" | |||
float f = 4; // "float" | |||
bool b = 5; // "bool" | |||
DataType type = 6; // "type" | |||
TensorShapeProto shape = 7; // "shape" | |||
TensorProto tensor = 8; // "tensor" | |||
ListValue list = 1; // any "list(...)" | |||
// "func" represents a function. func.name is a function's name or | |||
// a primitive op's name. func.attr.first is the name of an attr | |||
// defined for that function. func.attr.second is the value for | |||
// that attr in the instantiation. | |||
NameAttrList func = 10; | |||
// This is a placeholder only used in nodes defined inside a | |||
// function. It indicates the attr value will be supplied when | |||
// the function is instantiated. For example, let us suppose a | |||
// node "N" in function "FN". "N" has an attr "A" with value | |||
// placeholder = "foo". When FN is instantiated with attr "foo" | |||
// set to "bar", the instantiated node N's attr A will have been | |||
// given the value "bar". | |||
string placeholder = 9; | |||
} | |||
} | |||
// A list of attr names and their values. The whole list is attached | |||
// with a string name. E.g., MatMul[T=float]. | |||
message NameAttrList { | |||
string name = 1; | |||
map<string, AttrValue> attr = 2; | |||
} |
@@ -1,108 +0,0 @@ | |||
/** | |||
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow | |||
* | |||
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. | |||
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). | |||
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. | |||
*/ | |||
syntax = "proto3"; | |||
package domi.tensorflow; | |||
option cc_enable_arenas = true; | |||
option java_outer_classname = "FunctionProtos"; | |||
option java_multiple_files = true; | |||
option java_package = "org.tensorflow.framework"; | |||
import "attr_value.proto"; | |||
import "node_def.proto"; | |||
import "op_def.proto"; | |||
// A library is a set of named functions. | |||
message FunctionDefLibrary { | |||
repeated FunctionDef function = 1; | |||
repeated GradientDef gradient = 2; | |||
} | |||
// A function can be instantiated when the runtime can bind every attr | |||
// with a value. When a GraphDef has a call to a function, it must | |||
// have binding for every attr defined in the signature. | |||
// * device spec, etc. | |||
message FunctionDef { | |||
// The definition of the function's name, arguments, return values, | |||
// attrs etc. | |||
OpDef signature = 1; | |||
// Attributes specific to this function definition. | |||
map<string, AttrValue> attr = 5; | |||
// NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21. | |||
reserved 2; | |||
// In both of the following fields, there is the need to specify an | |||
// output that is used as either the input to another node (in | |||
// `node_def`) or as a return value of the function (in `ret`). | |||
// Unlike the NodeDefs in GraphDef, we need to be able to specify a | |||
// list in some cases (instead of just single outputs). Also, we | |||
// need to be able to deal with lists of unknown length (so the | |||
// output index may not be known at function definition time). So | |||
// we use the following format instead: | |||
// * "fun_in" where "fun_in" is the name of a function input arg in | |||
// the `signature` field above. This represents that input, whether | |||
// it is a single tensor or a list. | |||
// * "fun_in:0" gives the first element of a function input arg (a | |||
// non-list input is considered a list of length 1 for these | |||
// purposes). | |||
// * "node:out" where "node" is the name of a node in `node_def` and | |||
// "out" is the name one of its op's output arguments (the name | |||
// comes from the OpDef of the node's op). This represents that | |||
// node's output, whether it is a single tensor or a list. | |||
// Note: We enforce that an op's output arguments are never | |||
// renamed in the backwards-compatibility test. | |||
// * "node:out:0" gives the first element of a node output arg (a | |||
// non-list output is considered a list of length 1 for these | |||
// purposes). | |||
// | |||
// NOT CURRENTLY SUPPORTED (but may be in the future): | |||
// * "node:out:-1" gives last element in a node output list | |||
// * "node:out:1:" gives a list with all but the first element in a | |||
// node output list | |||
// * "node:out::-1" gives a list with all but the last element in a | |||
// node output list | |||
// The body of the function. Unlike the NodeDefs in a GraphDef, attrs | |||
// may have values of type `placeholder` and the `input` field uses | |||
// the "output" format above. | |||
// By convention, "op" in node_def is resolved by consulting with a | |||
// user-defined library first. If not resolved, "func" is assumed to | |||
// be a builtin op. | |||
repeated NodeDef node_def = 3; | |||
// A mapping from the output arg names from `signature` to the | |||
// outputs from `node_def` that should be returned by the function. | |||
map<string, string> ret = 4; | |||
} | |||
// GradientDef defines the gradient function of a function defined in | |||
// a function library. | |||
// | |||
// A gradient function g (specified by gradient_func) for a function f | |||
// (specified by function_name) must follow the following: | |||
// | |||
// The function 'f' must be a numerical function which takes N inputs | |||
// and produces M outputs. Its gradient function 'g', which is a | |||
// function taking N + M inputs and produces N outputs. | |||
// | |||
// I.e. if we have | |||
// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), | |||
// then, g is | |||
// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, | |||
// dL/dy1, dL/dy2, ..., dL/dy_M), | |||
// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the | |||
// loss function). dL/dx_i is the partial derivative of L with respect | |||
// to x_i. | |||
message GradientDef { | |||
string function_name = 1; // The function name. | |||
string gradient_func = 2; // The gradient function's name. | |||
} |
@@ -1,64 +0,0 @@ | |||
/** | |||
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow | |||
* | |||
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. | |||
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). | |||
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. | |||
*/ | |||
syntax = "proto3"; | |||
package domi.tensorflow; | |||
option cc_enable_arenas = true; | |||
option java_outer_classname = "GraphProtos"; | |||
option java_multiple_files = true; | |||
option java_package = "org.tensorflow.framework"; | |||
import "node_def.proto"; | |||
import "function.proto"; | |||
import "versions.proto"; | |||
// Represents the graph of operations | |||
message GraphDef { | |||
repeated NodeDef node = 1; | |||
// Compatibility versions of the graph. See core/public/version.h for version | |||
// history. The GraphDef version is distinct from the TensorFlow version, and | |||
// each release of TensorFlow will support a range of GraphDef versions. | |||
VersionDef versions = 4; | |||
// Deprecated single version field; use versions above instead. Since all | |||
// GraphDef changes before "versions" was introduced were forward | |||
// compatible, this field is entirely ignored. | |||
int32 version = 3 [deprecated = true]; | |||
// EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. | |||
// | |||
// "library" provides user-defined functions. | |||
// | |||
// Naming: | |||
// * library.function.name are in a flat namespace. | |||
// NOTE: We may need to change it to be hierarchical to support | |||
// different orgs. E.g., | |||
// { "/google/nn", { ... }}, | |||
// { "/google/vision", { ... }} | |||
// { "/org_foo/module_bar", { ... }} | |||
// map<string, FunctionDefLib> named_lib; | |||
// * If node[i].op is the name of one function in "library", | |||
// node[i] is deemed as a function call. Otherwise, node[i].op | |||
// must be a primitive operation supported by the runtime. | |||
// | |||
// | |||
// Function call semantics: | |||
// | |||
// * The callee may start execution as soon as some of its inputs | |||
// are ready. The caller may want to use Tuple() mechanism to | |||
// ensure all inputs are ready in the same time. | |||
// | |||
// * The consumer of return values may start executing as soon as | |||
// the return values the consumer depends on are ready. The | |||
// consumer may want to use Tuple() mechanism to ensure the | |||
// consumer does not start until all return values of the callee | |||
// function are ready. | |||
FunctionDefLibrary library = 2; | |||
}; |
@@ -1,22 +0,0 @@ | |||
/** | |||
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow | |||
* | |||
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. | |||
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). | |||
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. | |||
*/ | |||
syntax = "proto3"; | |||
package domi.tensorflow; | |||
import "graph.proto"; | |||
message GeGraphDef { | |||
string name = 1; | |||
GraphDef graph = 2; | |||
} | |||
message GraphDefLibrary { | |||
repeated GeGraphDef graph_def = 1; | |||
}; |
@@ -1,71 +0,0 @@ | |||
/** | |||
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow | |||
* | |||
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. | |||
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). | |||
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. | |||
*/ | |||
syntax = "proto3"; | |||
package domi.tensorflow; | |||
option cc_enable_arenas = true; | |||
option java_outer_classname = "NodeProto"; | |||
option java_multiple_files = true; | |||
option java_package = "org.tensorflow.framework"; | |||
import "attr_value.proto"; | |||
message NodeDef { | |||
// The name given to this operator. Used for naming inputs, | |||
// logging, visualization, etc. Unique within a single GraphDef. | |||
// Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". | |||
string name = 1; | |||
// The operation name. There may be custom parameters in attrs. | |||
// Op names starting with an underscore are reserved for internal use. | |||
string op = 2; | |||
// Each input is "node:src_output" with "node" being a string name and | |||
// "src_output" indicating which output tensor to use from "node". If | |||
// "src_output" is 0 the ":0" suffix can be omitted. Regular inputs | |||
// may optionally be followed by control inputs that have the format | |||
// "^node". | |||
repeated string input = 3; | |||
// A (possibly partial) specification for the device on which this | |||
// node should be placed. | |||
// The expected syntax for this string is as follows: | |||
// | |||
// DEVICE_SPEC ::= PARTIAL_SPEC | |||
// | |||
// PARTIAL_SPEC ::= ("/" CONSTRAINT) * | |||
// CONSTRAINT ::= ("job:" JOB_NAME) | |||
// | ("replica:" [1-9][0-9]*) | |||
// | ("task:" [1-9][0-9]*) | |||
// | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") ) | |||
// | |||
// Valid values for this string include: | |||
// * "/job:worker/replica:0/task:1/device:GPU:3" (full specification) | |||
// * "/job:worker/device:GPU:3" (partial specification) | |||
// * "" (no specification) | |||
// | |||
// If the constraints do not resolve to a single device (or if this | |||
// field is empty or not present), the runtime will attempt to | |||
// choose a device automatically. | |||
string device = 4; | |||
// Operation-specific graph-construction-time configuration. | |||
// Note that this should include all attrs defined in the | |||
// corresponding OpDef, including those with a value matching | |||
// the default -- this allows the default to change and makes | |||
// NodeDefs easier to interpret on their own. However, if | |||
// an attr with a default is not specified in this list, the | |||
// default will be used. | |||
// The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and | |||
// one of the names from the corresponding OpDef's attr field). | |||
// The values must have a type matching the corresponding OpDef | |||
// attr's type field. | |||
// Add some examples here showing best practices. | |||
map<string, AttrValue> attr = 5; | |||
}; |
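Tying the comments above together, a hedged sketch of a NodeDef with a regular input, a control input, a partial device spec, and a typed attr (assumes protoc output from these TensorFlow protos):

#include "node_def.pb.h"  // hypothetical protoc output for the file above

int main() {
  domi::tensorflow::NodeDef node;
  node.set_name("dense/MatMul");
  node.set_op("MatMul");
  node.add_input("input:0");  // "node:src_output"; the ":0" suffix could be omitted
  node.add_input("^init");    // control input, "^node" form
  node.set_device("/job:worker/device:GPU:3");  // partial device specification
  (*node.mutable_attr())["T"].set_type(domi::tensorflow::DT_FLOAT);
  return 0;
}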
@@ -1,172 +0,0 @@ | |||
/** | |||
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow | |||
* | |||
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. | |||
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). | |||
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. | |||
*/ | |||
syntax = "proto3"; | |||
package domi.tensorflow; | |||
option cc_enable_arenas = true; | |||
option java_outer_classname = "OpDefProtos"; | |||
option java_multiple_files = true; | |||
option java_package = "org.tensorflow.framework"; | |||
import "attr_value.proto"; | |||
import "types.proto"; | |||
// Defines an operation. A NodeDef in a GraphDef specifies an Op by | |||
// using the "op" field which should match the name of a OpDef. | |||
// LINT.IfChange | |||
message OpDef { | |||
// Op names starting with an underscore are reserved for internal use. | |||
// Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". | |||
string name = 1; | |||
// For describing inputs and outputs. | |||
message ArgDef { | |||
// Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". | |||
string name = 1; | |||
// Human readable description. | |||
string description = 2; | |||
// Describes the type of one or more tensors that are accepted/produced | |||
// by this input/output arg. The only legal combinations are: | |||
// * For a single tensor: either the "type" field is set or the | |||
// "type_attr" field is set to the name of an attr with type "type". | |||
// * For a sequence of tensors with the same type: the "number_attr" | |||
// field will be set to the name of an attr with type "int", and | |||
// either the "type" or "type_attr" field will be set as for | |||
// single tensors. | |||
// * For a sequence of tensors, the "type_list_attr" field will be set | |||
// to the name of an attr with type "list(type)". | |||
DataType type = 3; | |||
string type_attr = 4; // if specified, attr must have type "type" | |||
string number_attr = 5; // if specified, attr must have type "int" | |||
// If specified, attr must have type "list(type)", and none of | |||
// type, type_attr, and number_attr may be specified. | |||
string type_list_attr = 6; | |||
// For inputs: if true, the inputs are required to be refs. | |||
// By default, inputs can be either refs or non-refs. | |||
// For outputs: if true, outputs are refs, otherwise they are not. | |||
bool is_ref = 16; | |||
}; | |||
// Description of the input(s). | |||
repeated ArgDef input_arg = 2; | |||
// Description of the output(s). | |||
repeated ArgDef output_arg = 3; | |||
// Description of the graph-construction-time configuration of this | |||
// Op. That is to say, this describes the attr fields that will | |||
// be specified in the NodeDef. | |||
message AttrDef { | |||
// A descriptive name for the argument. May be used, e.g. by the | |||
// Python client, as a keyword argument name, and so should match | |||
// the regexp "[a-z][a-z0-9_]+". | |||
string name = 1; | |||
// One of the type names from attr_value.proto ("string", "list(string)", | |||
// "int", etc.). | |||
string type = 2; | |||
// A reasonable default for this attribute if the user does not supply | |||
// a value. If not specified, the user must supply a value. | |||
AttrValue default_value = 3; | |||
// Human-readable description. | |||
string description = 4; | |||
// --- Constraints --- | |||
// These constraints are only in effect if specified. Default is no | |||
// constraints. | |||
// For type == "int", this is a minimum value. For "list(___)" | |||
// types, this is the minimum length. | |||
bool has_minimum = 5; | |||
int64 minimum = 6; | |||
// The set of allowed values. Has type that is the "list" version | |||
// of the "type" field above (uses the "list" field of AttrValue). | |||
// If type == "type" or "list(type)" above, then the "type" field | |||
// of "allowed_values.list" has the set of allowed DataTypes. | |||
// If type == "string" or "list(string)", then the "s" field of | |||
// "allowed_values.list" has the set of allowed strings. | |||
AttrValue allowed_values = 7; | |||
} | |||
repeated AttrDef attr = 4; | |||
// Optional deprecation based on GraphDef versions. | |||
OpDeprecation deprecation = 8; | |||
// One-line human-readable description of what the Op does. | |||
string summary = 5; | |||
// Additional, longer human-readable description of what the Op does. | |||
string description = 6; | |||
// ------------------------------------------------------------------------- | |||
// Which optimizations this operation can participate in. | |||
// True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) | |||
bool is_commutative = 18; | |||
// If is_aggregate is true, then this operation accepts N >= 2 | |||
// inputs and produces 1 output all of the same type. Should be | |||
// associative and commutative, and produce output with the same | |||
// shape as the input. The optimizer may replace an aggregate op | |||
// taking input from multiple devices with a tree of aggregate ops | |||
// that aggregate locally within each device (and possibly within | |||
// groups of nearby devices) before communicating. | |||
bool is_aggregate = 16; // for things like add | |||
// Other optimizations go here, like | |||
// can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. | |||
// ------------------------------------------------------------------------- | |||
// Optimization constraints. | |||
// Ops are marked as stateful if their behavior depends on some state beyond | |||
// their input tensors (e.g. variable reading op) or if they have | |||
// a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops | |||
// must always produce the same output for the same input and have | |||
// no side-effects. | |||
// | |||
// By default Ops may be moved between devices. Stateful ops should | |||
// either not be moved, or should only be moved if that state can also | |||
// be moved (e.g. via some sort of save / restore). | |||
// Stateful ops are guaranteed to never be optimized away by Common | |||
// Subexpression Elimination (CSE). | |||
bool is_stateful = 17; // for things like variables, queue | |||
// ------------------------------------------------------------------------- | |||
// Non-standard options. | |||
// By default, all inputs to an Op must be initialized Tensors. Ops | |||
// that may initialize tensors for the first time should set this | |||
// field to true, to allow the Op to take an uninitialized Tensor as | |||
// input. | |||
bool allows_uninitialized_input = 19; // for Assign, etc. | |||
}; | |||
// LINT.ThenChange( | |||
// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc) | |||
// Information about version-dependent deprecation of an op | |||
message OpDeprecation { | |||
// First GraphDef version at which the op is disallowed. | |||
int32 version = 1; | |||
// Explanation of why it was deprecated and what to use instead. | |||
string explanation = 2; | |||
}; | |||
// A collection of OpDefs | |||
message OpList { | |||
repeated OpDef op = 1; | |||
}; |
@@ -1,37 +0,0 @@ | |||
/** | |||
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow | |||
* | |||
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. | |||
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). | |||
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. | |||
*/ | |||
syntax = "proto3"; | |||
package domi.tensorflow; | |||
option cc_enable_arenas = true; | |||
option java_outer_classname = "ResourceHandle"; | |||
option java_multiple_files = true; | |||
option java_package = "org.tensorflow.framework"; | |||
// Protocol buffer representing a handle to a tensorflow resource. Handles are | |||
// not valid across executions, but can be serialized back and forth from within | |||
// a single run. | |||
message ResourceHandleProto { | |||
// Unique name for the device containing the resource. | |||
string device = 1; | |||
// Container in which this resource is placed. | |||
string container = 2; | |||
// Unique name of this resource. | |||
string name = 3; | |||
// Hash code for the type of the resource. Is only valid in the same device | |||
// and in the same execution. | |||
uint64 hash_code = 4; | |||
// For debug-only, the name of the type pointed to by this handle, if | |||
// available. | |||
string maybe_type_name = 5; | |||
}; |
@@ -1,102 +0,0 @@ | |||
/** | |||
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow | |||
* | |||
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. | |||
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). | |||
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. | |||
*/ | |||
syntax = "proto3"; | |||
package domi.tensorflow; | |||
option cc_enable_arenas = true; | |||
option java_outer_classname = "TensorProtos"; | |||
option java_multiple_files = true; | |||
option java_package = "org.tensorflow.framework"; | |||
import "resource_handle.proto"; | |||
import "tensor_shape.proto"; | |||
import "types.proto"; | |||
// Protocol buffer representing a tensor. | |||
message TensorProto { | |||
DataType dtype = 1; | |||
// Shape of the tensor. | |||
TensorShapeProto tensor_shape = 2; | |||
// Only one of the representations below is set, one of "tensor_contents" and | |||
// the "xxx_val" attributes. We are not using oneof because as oneofs cannot | |||
// contain repeated fields it would require another extra set of messages. | |||
// Version number. | |||
// | |||
// In version 0, if the "repeated xxx" representations contain only one | |||
// element, that element is repeated to fill the shape. This makes it easy | |||
// to represent a constant Tensor with a single value. | |||
int32 version_number = 3; | |||
// Serialized raw tensor content from either Tensor::AsProtoTensorContent or | |||
// memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation | |||
// can be used for all tensor types. The purpose of this representation is to | |||
// reduce serialization overhead during RPC call by avoiding serialization of | |||
// many repeated small items. | |||
bytes tensor_content = 4; | |||
// Type specific representations that make it easy to create tensor protos in | |||
// all languages. Only the representation corresponding to "dtype" can | |||
// be set. The values hold the flattened representation of the tensor in | |||
// row major order. | |||
// DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll | |||
// have some pointless zero padding for each value here. | |||
repeated int32 half_val = 13 [packed = true]; | |||
// DT_FLOAT. | |||
repeated float float_val = 5 [packed = true]; | |||
// DT_DOUBLE. | |||
repeated double double_val = 6 [packed = true]; | |||
// DT_INT32, DT_INT16, DT_INT8, DT_UINT8. | |||
repeated int32 int_val = 7 [packed = true]; | |||
// DT_STRING | |||
repeated bytes string_val = 8; | |||
// DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real | |||
// and imaginary parts of i-th single precision complex. | |||
repeated float scomplex_val = 9 [packed = true]; | |||
// DT_INT64 | |||
repeated int64 int64_val = 10 [packed = true]; | |||
// DT_BOOL | |||
repeated bool bool_val = 11 [packed = true]; | |||
// DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real | |||
// and imaginary parts of i-th double precision complex. | |||
repeated double dcomplex_val = 12 [packed = true]; | |||
// DT_RESOURCE | |||
repeated ResourceHandleProto resource_handle_val = 14; | |||
// DT_VARIANT | |||
repeated VariantTensorDataProto variant_val = 15; | |||
// DT_UINT32 | |||
repeated uint32 uint32_val = 16 [packed = true]; | |||
// DT_UINT64 | |||
repeated uint64 uint64_val = 17 [packed = true]; | |||
}; | |||
// Protocol buffer representing the serialization format of DT_VARIANT tensors. | |||
message VariantTensorDataProto { | |||
// Name of the type of objects being serialized. | |||
string type_name = 1; | |||
// Portions of the object that are not Tensors. | |||
bytes metadata = 2; | |||
// Tensors contained within objects being serialized. | |||
repeated TensorProto tensors = 3; | |||
} |
@@ -1,53 +0,0 @@ | |||
/** | |||
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow | |||
* | |||
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. | |||
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). | |||
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. | |||
*/ | |||
// Protocol buffer representing the shape of tensors. | |||
syntax = "proto3"; | |||
option cc_enable_arenas = true; | |||
option java_outer_classname = "TensorShapeProtos"; | |||
option java_multiple_files = true; | |||
option java_package = "org.tensorflow.framework"; | |||
package domi.tensorflow; | |||
// Dimensions of a tensor. | |||
message TensorShapeProto { | |||
// One dimension of the tensor. | |||
message Dim { | |||
// Size of the tensor in that dimension. | |||
// This value must be >= -1, but values of -1 are reserved for "unknown" | |||
// shapes (values of -1 mean "unknown" dimension). Certain wrappers | |||
// that work with TensorShapeProto may fail at runtime when deserializing | |||
// a TensorShapeProto containing a dim value of -1. | |||
int64 size = 1; | |||
// Optional name of the tensor dimension. | |||
string name = 2; | |||
}; | |||
// Dimensions of the tensor, such as {"input", 30}, {"output", 40} | |||
// for a 30 x 40 2D tensor. If an entry has size -1, this | |||
// corresponds to a dimension of unknown size. The names are | |||
// optional. | |||
// | |||
// The order of entries in "dim" matters: It indicates the layout of the | |||
// values in the tensor in-memory representation. | |||
// | |||
// The first entry in "dim" is the outermost dimension used to layout the | |||
// values, the last entry is the innermost dimension. This matches the | |||
// in-memory layout of RowMajor Eigen tensors. | |||
// | |||
// If "dim.size()" > 0, "unknown_rank" must be false. | |||
repeated Dim dim = 2; | |||
// If true, the number of dimensions in the shape is unknown. | |||
// | |||
// If true, "dim.size()" must be 0. | |||
bool unknown_rank = 3; | |||
}; |
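As a concrete illustration of the dim-ordering comments above, a minimal sketch (generated header name assumed) that builds the 30 x 40 example plus one trailing dimension of unknown size:

#include "tensor_shape.pb.h"  // hypothetical name of the generated header

domi::tensorflow::TensorShapeProto MakeExampleShape() {
  domi::tensorflow::TensorShapeProto shape;
  auto *d0 = shape.add_dim();      // outermost dimension first (row-major layout)
  d0->set_size(30);
  d0->set_name("input");
  auto *d1 = shape.add_dim();      // innermost dimension last
  d1->set_size(40);
  d1->set_name("output");
  shape.add_dim()->set_size(-1);   // -1 marks a dimension of unknown size
  // Since dim is non-empty, unknown_rank must remain false (the proto3 default).
  return shape;
}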
@@ -1,82 +0,0 @@ | |||
/** | |||
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow | |||
* | |||
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. | |||
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). | |||
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. | |||
*/ | |||
syntax = "proto3"; | |||
package domi.tensorflow; | |||
option cc_enable_arenas = true; | |||
option java_outer_classname = "TypesProtos"; | |||
option java_multiple_files = true; | |||
option java_package = "org.tensorflow.framework"; | |||
// LINT.IfChange | |||
enum DataType { | |||
// Not a legal value for DataType. Used to indicate a DataType field | |||
// has not been set. | |||
DT_INVALID = 0; | |||
// Data types that all computation devices are expected to be | |||
// capable to support. | |||
DT_FLOAT = 1; | |||
DT_DOUBLE = 2; | |||
DT_INT32 = 3; | |||
DT_UINT8 = 4; | |||
DT_INT16 = 5; | |||
DT_INT8 = 6; | |||
DT_STRING = 7; | |||
DT_COMPLEX64 = 8; // Single-precision complex | |||
DT_INT64 = 9; | |||
DT_BOOL = 10; | |||
DT_QINT8 = 11; // Quantized int8 | |||
DT_QUINT8 = 12; // Quantized uint8 | |||
DT_QINT32 = 13; // Quantized int32 | |||
DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. | |||
DT_QINT16 = 15; // Quantized int16 | |||
DT_QUINT16 = 16; // Quantized uint16 | |||
DT_UINT16 = 17; | |||
DT_COMPLEX128 = 18; // Double-precision complex | |||
DT_HALF = 19; | |||
DT_RESOURCE = 20; | |||
DT_VARIANT = 21; // Arbitrary C++ data types | |||
DT_UINT32 = 22; | |||
DT_UINT64 = 23; | |||
// Do not use! These are only for parameters. Every enum above | |||
// should have a corresponding value below (verified by types_test). | |||
DT_FLOAT_REF = 101; | |||
DT_DOUBLE_REF = 102; | |||
DT_INT32_REF = 103; | |||
DT_UINT8_REF = 104; | |||
DT_INT16_REF = 105; | |||
DT_INT8_REF = 106; | |||
DT_STRING_REF = 107; | |||
DT_COMPLEX64_REF = 108; | |||
DT_INT64_REF = 109; | |||
DT_BOOL_REF = 110; | |||
DT_QINT8_REF = 111; | |||
DT_QUINT8_REF = 112; | |||
DT_QINT32_REF = 113; | |||
DT_BFLOAT16_REF = 114; | |||
DT_QINT16_REF = 115; | |||
DT_QUINT16_REF = 116; | |||
DT_UINT16_REF = 117; | |||
DT_COMPLEX128_REF = 118; | |||
DT_HALF_REF = 119; | |||
DT_RESOURCE_REF = 120; | |||
DT_VARIANT_REF = 121; | |||
DT_UINT32_REF = 122; | |||
DT_UINT64_REF = 123; | |||
} | |||
// LINT.ThenChange( | |||
// https://www.tensorflow.org/code/tensorflow/c/c_api.h, | |||
// https://www.tensorflow.org/code/tensorflow/go/tensor.go, | |||
// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc, | |||
// https://www.tensorflow.org/code/tensorflow/core/framework/types.h, | |||
// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc, | |||
// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py, | |||
// https://www.tensorflow.org/code/tensorflow/python/framework/function.py) |
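The reference variants above follow a fixed offset: every DT_*_REF value is its base value plus 100. A small helper sketch (not part of GraphEngine, generated header name assumed) makes that relationship explicit:

#include "types.pb.h"  // hypothetical name of the generated header

// Valid only for the base types 1..23 listed above; the result is the
// corresponding DT_*_REF enumerator.
domi::tensorflow::DataType ToRefType(domi::tensorflow::DataType dt) {
  return static_cast<domi::tensorflow::DataType>(static_cast<int>(dt) + 100);
}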
@@ -1,39 +0,0 @@ | |||
/** | |||
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow | |||
* | |||
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. | |||
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). | |||
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. | |||
*/ | |||
syntax = "proto3"; | |||
package domi.tensorflow; | |||
option cc_enable_arenas = true; | |||
option java_outer_classname = "VersionsProtos"; | |||
option java_multiple_files = true; | |||
option java_package = "org.tensorflow.framework"; | |||
// Version information for a piece of serialized data | |||
// | |||
// There are different types of versions for each type of data | |||
// (GraphDef, etc.), but they all have the same common shape | |||
// described here. | |||
// | |||
// Each consumer has "consumer" and "min_producer" versions (specified | |||
// elsewhere). A consumer is allowed to consume this data if | |||
// | |||
// producer >= min_producer | |||
// consumer >= min_consumer | |||
// consumer not in bad_consumers | |||
// | |||
message VersionDef { | |||
// The version of the code that produced this data. | |||
int32 producer = 1; | |||
// Any consumer below this version is not allowed to consume this data. | |||
int32 min_consumer = 2; | |||
// Specific consumer versions which are disallowed (e.g. due to bugs). | |||
repeated int32 bad_consumers = 3; | |||
}; |
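A sketch of the consumer-side compatibility test spelled out in the comment above; the consumer and min_producer values are assumed to be supplied by the caller, and the generated header name is hypothetical:

#include <algorithm>
#include "versions.pb.h"  // hypothetical name of the generated header

bool CanConsume(const domi::tensorflow::VersionDef &v, int32_t consumer, int32_t min_producer) {
  if (v.producer() < min_producer) return false;  // data produced by code too old for us
  if (consumer < v.min_consumer()) return false;  // our code is too old for the data
  const auto &bad = v.bad_consumers();
  return std::find(bad.begin(), bad.end(), consumer) == bad.end();
}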
@@ -37,6 +37,7 @@ set(SRC_LIST | |||
"../graph/load/model_manager/task_info/task_info.cc" | |||
"../graph/load/model_manager/task_info/event_record_task_info.cc" | |||
"../graph/load/model_manager/task_info/event_wait_task_info.cc" | |||
"../graph/load/model_manager/task_info/ffts_task_info.cc" | |||
"../graph/load/model_manager/task_info/fusion_start_task_info.cc" | |||
"../graph/load/model_manager/task_info/fusion_stop_task_info.cc" | |||
"../graph/load/model_manager/task_info/kernel_ex_task_info.cc" | |||
@@ -1,113 +0,0 @@ | |||
syntax = "proto3"; | |||
package toolkit.dump; | |||
enum OutputDataType { | |||
DT_UNDEFINED = 0; | |||
DT_FLOAT = 1; | |||
DT_FLOAT16 = 2; | |||
DT_INT8 = 3; | |||
DT_UINT8 = 4; | |||
DT_INT16 = 5; | |||
DT_UINT16 = 6; | |||
DT_INT32 = 7; | |||
DT_INT64 = 8; | |||
DT_UINT32 = 9; | |||
DT_UINT64 = 10; | |||
DT_BOOL = 11; | |||
DT_DOUBLE = 12; | |||
DT_STRING = 13; | |||
DT_DUAL_SUB_INT8 = 14; | |||
DT_DUAL_SUB_UINT8 = 15; | |||
DT_COMPLEX64 = 16; | |||
DT_COMPLEX128 = 17; | |||
DT_QINT8 = 18; | |||
DT_QINT16 = 19; | |||
DT_QINT32 = 20; | |||
DT_QUINT8 = 21; | |||
DT_QUINT16 = 22; | |||
DT_RESOURCE = 23; | |||
DT_STRING_REF = 24; | |||
DT_DUAL = 25; | |||
DT_VARIANT = 26; | |||
} | |||
enum OutputFormat { | |||
FORMAT_NCHW = 0; | |||
FORMAT_NHWC = 1; | |||
FORMAT_ND = 2; | |||
FORMAT_NC1HWC0 = 3; | |||
FORMAT_FRACTAL_Z = 4; | |||
FORMAT_NC1C0HWPAD = 5; | |||
FORMAT_NHWC1C0 = 6; | |||
FORMAT_FSR_NCHW = 7; | |||
FORMAT_FRACTAL_DECONV = 8; | |||
FORMAT_C1HWNC0 = 9; | |||
FORMAT_FRACTAL_DECONV_TRANSPOSE = 10; | |||
FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11; | |||
FORMAT_NC1HWC0_C04 = 12; | |||
FORMAT_FRACTAL_Z_C04 = 13; | |||
FORMAT_CHWN = 14; | |||
FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15; | |||
FORMAT_HWCN = 16; | |||
FORMAT_NC1KHKWHWC0 = 17; | |||
FORMAT_BN_WEIGHT = 18; | |||
FORMAT_FILTER_HWCK = 19; | |||
FORMAT_HASHTABLE_LOOKUP_LOOKUPS = 20; | |||
FORMAT_HASHTABLE_LOOKUP_KEYS = 21; | |||
FORMAT_HASHTABLE_LOOKUP_VALUE = 22; | |||
FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23; | |||
FORMAT_HASHTABLE_LOOKUP_HITS = 24; | |||
FORMAT_C1HWNCoC0 = 25; | |||
FORMAT_MD = 26; | |||
FORMAT_NDHWC = 27; | |||
FORMAT_FRACTAL_ZZ = 28; | |||
FORMAT_FRACTAL_NZ = 29; | |||
FORMAT_RESERVED = 30; | |||
} | |||
message OriginalOp { | |||
string name = 1; | |||
uint32 output_index = 2; | |||
OutputDataType data_type = 3; | |||
OutputFormat format = 4; | |||
} | |||
message Shape { | |||
repeated uint64 dim = 1; | |||
} | |||
message OpOutput { | |||
OutputDataType data_type = 1; | |||
OutputFormat format = 2; | |||
Shape shape = 3; | |||
OriginalOp original_op = 4; // the original op corresponding to the output | |||
bytes data = 5; | |||
uint64 size = 6; | |||
} | |||
message OpInput { | |||
OutputDataType data_type = 1; | |||
OutputFormat format = 2; | |||
Shape shape = 3; | |||
bytes data = 4; | |||
uint64 size = 5; | |||
} | |||
enum BufferType { | |||
L1 = 0; | |||
} | |||
message OpBuffer { | |||
BufferType buffer_type = 1; | |||
bytes data = 2; | |||
uint64 size = 3; | |||
} | |||
message DumpData { | |||
string version = 1; | |||
uint64 dump_time = 2; | |||
repeated OpOutput output = 3; | |||
repeated OpInput input = 4; | |||
repeated OpBuffer buffer = 5; | |||
string op_name = 6; | |||
} |
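A minimal sketch of walking a parsed DumpData message and summarizing each dumped output; the generated header name is an assumption:

#include <cstdio>
#include "dump.pb.h"  // hypothetical name of the generated header

void PrintDump(const toolkit::dump::DumpData &dump) {
  std::printf("op %s, version %s, dumped at %llu\n", dump.op_name().c_str(),
              dump.version().c_str(), static_cast<unsigned long long>(dump.dump_time()));
  for (const auto &out : dump.output()) {
    std::printf("  output: dtype=%d format=%d bytes=%llu dims=%d\n",
                static_cast<int>(out.data_type()), static_cast<int>(out.format()),
                static_cast<unsigned long long>(out.size()), out.shape().dim_size());
  }
}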
@@ -1,193 +0,0 @@ | |||
syntax = "proto3"; | |||
package ge.proto; | |||
enum DataType | |||
{ | |||
DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. | |||
DT_FLOAT = 1; // float type | |||
DT_FLOAT16 = 2; // fp16 type | |||
DT_INT8 = 3; // int8 type | |||
DT_UINT8 = 4; // uint8 type | |||
DT_INT16 = 5; // int16 type | |||
DT_UINT16 = 6; // uint16 type | |||
DT_INT32 = 7; // | |||
DT_INT64 = 8; // int64 type | |||
DT_UINT32 = 9; // unsigned int32 | |||
DT_UINT64 = 10; // unsigned int64 | |||
DT_BOOL = 11; // bool type | |||
DT_DOUBLE = 12; // double type | |||
DT_STRING = 13; // string type | |||
DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ | |||
DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ | |||
DT_COMPLEX64 = 16; // complex64 type | |||
DT_COMPLEX128 = 17; // complex128 type | |||
DT_QINT8 = 18; // qint8 type | |||
DT_QINT16 = 19; // qint16 type | |||
DT_QINT32 = 20; // qint32 type | |||
DT_QUINT8 = 21; // quint8 type | |||
DT_QUINT16 = 22; // quint16 type | |||
DT_RESOURCE = 23; // resource type | |||
DT_STRING_REF = 24; // string_ref type | |||
DT_DUAL = 25; /**< dual output type */ | |||
DT_VARIANT = 26; // variant type | |||
DT_BF16 = 27; // bf16 type | |||
DT_INT4 = 28; // int4 type | |||
} | |||
message AttrDef | |||
{ | |||
message ListValue | |||
{ | |||
enum ListValueType{ | |||
VT_LIST_NONE = 0; | |||
VT_LIST_STRING = 1; | |||
VT_LIST_INT = 2; | |||
VT_LIST_FLOAT = 3; | |||
VT_LIST_BOOL = 4; | |||
VT_LIST_BYTES = 5; | |||
VT_LIST_TENSOR_DESC = 6; | |||
VT_LIST_TENSOR = 7; | |||
VT_LIST_GRAPH = 8; | |||
VT_LIST_NAMED_ATTRS = 9; | |||
VT_LIST_DATA_TYPE = 10; | |||
} | |||
repeated bytes s = 2; // "list(string)" | |||
repeated int64 i = 3; // "list(int)" | |||
repeated float f = 4; // "list(float)" | |||
repeated bool b = 5; // "list(bool)" | |||
repeated bytes bt = 7; | |||
repeated TensorDescriptor td = 8; | |||
repeated TensorDef t = 9; | |||
repeated GraphDef g = 10; | |||
repeated NamedAttrs na = 11; | |||
repeated int64 dt = 12; // list ge::DataType | |||
ListValueType val_type = 20; | |||
} | |||
message ListListInt{ | |||
message ListInt{ | |||
repeated int64 list_i = 1; // list int | |||
} | |||
repeated ListInt list_list_i = 1; // list list int | |||
} | |||
oneof value | |||
{ | |||
bytes s = 2; // "string" | |||
int64 i = 3; // "int" | |||
float f = 4; // "float" | |||
bool b = 5; // "bool" | |||
bytes bt = 7; | |||
ListValue list = 1; // any "list(...)" | |||
NamedAttrs func = 10; // Used to support attr nesting | |||
TensorDescriptor td = 11; // GeTensorDesc type | |||
TensorDef t = 12; // GeTensor type | |||
GraphDef g = 13; // Graph type | |||
ListListInt list_list_int = 14; // List List Int type | |||
int64 dt = 15; // ge::DataType | |||
} | |||
} | |||
// A list of attr names and their values. The whole list is attached | |||
// with a string name. E.g., MatMul[T=float]. | |||
message NamedAttrs | |||
{ | |||
string name = 1; | |||
map<string, AttrDef> attr = 2; | |||
} | |||
// Shape / dimension description, using row-major order | |||
message ShapeDef | |||
{ | |||
repeated int64 dim = 1; // Size of each dimension | |||
} | |||
// Multidimensional data description | |||
message TensorDescriptor | |||
{ | |||
string name = 1; // Optional parameter, tensor name | |||
DataType dtype = 2; // tensor datatype | |||
ShapeDef shape = 3; // Shape / dimension | |||
string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" | |||
bool has_out_attr = 9; | |||
int64 size = 10; | |||
int64 weight_size = 11; | |||
bool reuse_input = 12; | |||
bool output_tensor = 13; | |||
string device_type = 14; | |||
bool input_tensor = 15; | |||
int64 real_dim_cnt = 16; | |||
int64 reuse_input_index = 17; | |||
int64 data_offset = 18; | |||
int64 cmps_size = 19; | |||
string cmps_tab = 20; | |||
int64 cmps_tab_offset = 21; | |||
map<string, AttrDef> attr = 5; // Set of extra parameter fields | |||
} | |||
// GeTensor definition | |||
message TensorDef | |||
{ | |||
TensorDescriptor desc = 1; // Tensor description | |||
bytes data = 2; // Tensor data | |||
} | |||
// Operator description | |||
message OpDef | |||
{ | |||
string name = 1; // name | |||
string type = 2; // type | |||
repeated string input = 5; // inputs as original op name + output index, encoded as "op_name:index" | |||
map<string, AttrDef> attr = 10; // Set of operator parameter fields | |||
bool has_out_attr = 20; | |||
int64 id = 21; | |||
int64 stream_id = 22; | |||
repeated string input_name = 23; | |||
repeated string src_name = 24; | |||
repeated int64 src_index = 25; | |||
repeated string dst_name = 26; | |||
repeated int64 dst_index = 27; | |||
repeated int64 input_i = 28; | |||
repeated int64 output_i = 29; | |||
repeated int64 workspace = 30; | |||
repeated int64 workspace_bytes = 31; | |||
repeated bool is_input_const = 32; | |||
repeated TensorDescriptor input_desc = 33; | |||
repeated TensorDescriptor output_desc = 34; | |||
repeated string subgraph_name = 35; | |||
} | |||
// Graph definition | |||
message GraphDef | |||
{ | |||
string name = 1; // name | |||
repeated string input = 4; // Graph input | |||
repeated string output = 5; // Graph output | |||
repeated OpDef op = 6; // List of operators | |||
map<string, AttrDef> attr = 11; // Extended field | |||
} | |||
// model definition | |||
message ModelDef | |||
{ | |||
string name = 1; // name | |||
uint32 version = 2; // IR proto version | |||
string custom_version = 3; // User model version number, passed in by user | |||
repeated GraphDef graph = 7; // Graph definition; graph[0] represents the main graph of the model | |||
map<string, AttrDef> attr = 11; // Extended field | |||
} | |||
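Per the comment above, graph[0] is the main graph of a ModelDef. A sketch (generated header name assumed) assembling a one-op model, with the input encoded as "op_name:index" as the OpDef comment describes:

#include "ge_ir.pb.h"  // hypothetical name of the generated header

ge::proto::ModelDef MakeTinyModel() {
  ge::proto::ModelDef model;
  model.set_name("demo");
  ge::proto::GraphDef *graph = model.add_graph();  // graph[0]: the main graph
  graph->set_name("main");
  ge::proto::OpDef *op = graph->add_op();
  op->set_name("relu_0");
  op->set_type("Relu");
  op->add_input("data_0:0");  // source op name + output index
  return model;
}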
@@ -1,140 +0,0 @@ | |||
syntax = "proto3"; | |||
package domi; | |||
message InsertNewOps { | |||
repeated AippOpParams aipp_op = 1; | |||
repeated MultiShapeOpParams multi_shape_op = 2; | |||
} | |||
message AippOpParams { | |||
enum InputFormat { | |||
UNDEFINED = 0; | |||
YUV420SP_U8 = 1; | |||
XRGB8888_U8 = 2; | |||
RGB888_U8 = 3; | |||
YUV400_U8 = 4; | |||
NC1HWC0DI_FP16 = 5; | |||
NC1HWC0DI_S8 = 6; | |||
ARGB8888_U8 = 7; | |||
YUYV_U8 = 8; | |||
YUV422SP_U8 = 9; | |||
AYUV444_U8 = 10; | |||
RAW10 = 11; | |||
RAW12 = 12; | |||
RAW16 = 13; | |||
RAW24 = 14; | |||
RGB16 = 15; | |||
RGB20 = 16; | |||
RGB24 = 17; | |||
RGB8_IR = 18; | |||
RGB16_IR = 19; | |||
RGB24_IR = 20; | |||
} | |||
enum AippMode { | |||
undefined = 0; | |||
static = 1; | |||
dynamic = 2; | |||
} | |||
// AIPP mode: distinguishes static AIPP from dynamic AIPP | |||
AippMode aipp_mode = 1; | |||
// related_input_rank is required; it is an integer >= 0 and <= the number of input Data ops, default 0. | |||
// It identifies which model input gets AIPP processing; e.g., for a model with two inputs where the second needs AIPP, set related_input_rank to 1. | |||
uint32 related_input_rank = 2; | |||
// related_input_name is optional and the top name of data node which inserts aipp | |||
string related_input_name = 6; | |||
// input_edge_idx is optional; it is an integer >= 0. | |||
// It applies different AIPP processing to different outputs of the Data op; if unset, AIPP is applied to all output edges of the model input chosen by related_input_rank. | |||
// Configured values must be <= the number of output edges of the Data op. | |||
repeated uint32 input_edge_idx = 3; | |||
// [Begin] dynamic AIPP parameters; ignored when static AIPP is configured | |||
uint32 max_src_image_size = 4; | |||
// Whether rotation is supported. Off by default; enabling rotation incurs extra memory and performance cost | |||
bool support_rotation = 5; | |||
// [End] dynamic AIPP parameters | |||
// [Begin] static AIPP parameters; ignored when dynamic AIPP is configured | |||
InputFormat input_format = 51; | |||
bool csc_switch = 52; | |||
float cpadding_value = 53; | |||
bool rbuv_swap_switch = 54; | |||
bool ax_swap_switch = 55; | |||
bool single_line_mode = 56; | |||
int32 src_image_size_w = 57; | |||
int32 src_image_size_h = 58; | |||
bool crop = 59; | |||
int32 load_start_pos_w = 60; | |||
int32 load_start_pos_h = 61; | |||
int32 crop_size_w = 62; | |||
int32 crop_size_h = 63; | |||
bool resize = 64; | |||
int32 resize_output_w = 65; | |||
int32 resize_output_h = 66; | |||
bool padding = 67; | |||
int32 left_padding_size = 68; | |||
int32 right_padding_size = 69; | |||
int32 top_padding_size = 70; | |||
int32 bottom_padding_size = 71; | |||
float padding_value = 72; | |||
int32 mean_chn_0 = 10; | |||
int32 mean_chn_1 = 11; | |||
int32 mean_chn_2 = 12; | |||
int32 mean_chn_3 = 19; | |||
float min_chn_0 = 13; | |||
float min_chn_1 = 14; | |||
float min_chn_2 = 15; | |||
float min_chn_3 = 20; | |||
repeated float var_reci_chn_0 = 16; | |||
repeated float var_reci_chn_1 = 17; | |||
repeated float var_reci_chn_2 = 18; | |||
repeated float var_reci_chn_3 = 21; | |||
repeated int32 matrix_r0c0 = 30; | |||
repeated int32 matrix_r0c1 = 31; | |||
repeated int32 matrix_r0c2 = 32; | |||
repeated int32 matrix_r1c0 = 33; | |||
repeated int32 matrix_r1c1 = 34; | |||
repeated int32 matrix_r1c2 = 35; | |||
repeated int32 matrix_r2c0 = 36; | |||
repeated int32 matrix_r2c1 = 37; | |||
repeated int32 matrix_r2c2 = 38; | |||
repeated int32 output_bias_0 = 39; | |||
repeated int32 output_bias_1 = 40; | |||
repeated int32 output_bias_2 = 41; | |||
repeated int32 input_bias_0 = 42; | |||
repeated int32 input_bias_1 = 43; | |||
repeated int32 input_bias_2 = 44; | |||
// [End] static AIPP parameters | |||
// The n number that is used for raw/rgbir data into f16 transformation. | |||
// The transformation equation is x/(2^n). If set to 0, no transform is performed. | |||
uint32 raw_rgbir_to_f16_n = 45; | |||
} | |||
message MultiShapeOpParams { | |||
enum MultiShapeMode { | |||
batch = 0; // dynamic batch | |||
resolution = 1; // dynamic resolution, reserved for extension | |||
} | |||
MultiShapeMode mode = 1; // operator mode | |||
uint32 related_input_rank = 2; // which model input the new op is inserted at | |||
repeated uint32 batch_list = 11; // batch_list values; the entry count must be between 2 and 8 | |||
} |
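A sketch of filling a static AIPP configuration along the lines of the field comments above; the generated header name and all numeric values are placeholders:

#include "insert_op.pb.h"  // hypothetical name of the generated header

domi::AippOpParams MakeStaticAipp() {
  domi::AippOpParams aipp;
  // Process the model's second Data input (rank is 0-based, per the comment above).
  aipp.set_related_input_rank(1);
  // aipp_mode (static vs dynamic) would be selected here; it is omitted in this sketch.
  aipp.set_input_format(domi::AippOpParams::RGB888_U8);
  aipp.set_src_image_size_w(224);  // placeholder source image size
  aipp.set_src_image_size_h(224);
  aipp.set_csc_switch(true);       // enable color-space conversion
  aipp.set_mean_chn_0(104);        // placeholder per-channel means
  aipp.set_mean_chn_1(117);
  aipp.set_mean_chn_2(123);
  return aipp;
}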
@@ -1,396 +0,0 @@ | |||
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. | |||
* | |||
* This program is free software; you can redistribute it and/or modify | |||
* it under the terms of the Apache License Version 2.0. You may not use this file except in compliance with the License. | |||
* | |||
* This program is distributed in the hope that it will be useful, | |||
* but WITHOUT ANY WARRANTY; without even the implied warranty of | |||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |||
* Apache License for more details at | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
*/ | |||
syntax = "proto3"; | |||
package domi; | |||
enum TargetType | |||
{ | |||
MINI = 0; | |||
TINY = 1; | |||
LITE = 2; | |||
} | |||
// offline model | |||
message ModelDef { | |||
string name = 1; | |||
uint32 version = 2; | |||
uint64 memory_size = 10; | |||
uint32 stream_num = 11; | |||
uint32 event_num = 12; | |||
uint64 weight_size = 13; | |||
uint32 label_num = 15; | |||
repeated OpDef op = 20; | |||
TargetType target_type = 23; | |||
map<string, AttrDef> attr = 30; | |||
}; | |||
// operator define | |||
message OpDef { | |||
string name = 1; | |||
string type = 2; | |||
uint32 id = 3; | |||
uint32 stream_id = 4; | |||
repeated string input_name = 5; | |||
repeated string src_name = 8; | |||
repeated int32 src_index = 9; | |||
repeated int64 input = 10; | |||
repeated int64 output = 11; | |||
repeated TensorDescriptor input_desc = 12; | |||
repeated TensorDescriptor output_desc = 13; | |||
repeated WeightDef weights = 14; | |||
repeated string dst_name = 15; | |||
repeated int32 dst_index = 16; | |||
repeated int64 workspace = 20; | |||
repeated uint32 workspace_bytes = 21; | |||
repeated string weight_name = 22; | |||
repeated bool is_input_const = 23; | |||
map<string, AttrDef> attr = 30; | |||
QuantizeFactorParams quantize_factor = 31; | |||
oneof op_params { | |||
// start at 100 here | |||
SendOpParams sender_param = 100; | |||
RecvOpParams receiver_param = 200; | |||
ConvolutionOpParams convolution_param = 300; | |||
PoolingOpParams pooling_param = 400; | |||
EltwiseOpParams eltwise_param = 500; | |||
BatchNormOpParams batchnorm_param = 600; | |||
ScaleOpParams scale_param = 700; | |||
FullConnectionOpParams full_connection_param = 800; | |||
SoftmaxOpParams softmax_param = 900; | |||
ActivationOpParams activation_param = 1000; | |||
ReshapeOpParams reshape_param = 1100; | |||
} | |||
}; | |||
message SendOpParams { | |||
uint32 event_id = 1; | |||
}; | |||
message RecvOpParams { | |||
uint32 event_id = 1; | |||
}; | |||
enum QuantizeScaleType | |||
{ | |||
VECTOR_SCALE = 0; | |||
SCALAR_SCALE = 1; | |||
} | |||
enum QuantizeScaleMode | |||
{ | |||
NORMAL_MODE = 0; | |||
SQRT_MODE = 1; | |||
} | |||
enum QuantizeAlgorithm | |||
{ | |||
NON_OFFSET_ALGO = 0; | |||
HALF_OFFSET_ALGO = 1; | |||
ALL_OFFSET_ALGO = 2; | |||
} | |||
message QuantizeFactor | |||
{ | |||
QuantizeScaleMode scale_mode = 1; | |||
bytes scale_value = 2; | |||
int64 scale_offset = 3; | |||
bytes offset_data_value = 4; | |||
int64 offset_data_offset = 5; | |||
bytes offset_weight_value = 6; | |||
int64 offset_weight_offset = 7; | |||
bytes offset_pad_value = 8; | |||
int64 offset_pad_offset = 9; | |||
}; | |||
message QuantizeCalcFactor | |||
{ | |||
bytes offsetw = 1; | |||
int64 offsetw_offset = 2; | |||
bytes offsetd = 3; | |||
int64 offsetd_offset = 4; | |||
bytes scalereq = 5; | |||
int64 scaledreq_offset = 6; | |||
bytes offsetdnext = 7; | |||
int64 offsetdnext_offset = 8; | |||
} | |||
message QuantizeFactorParams | |||
{ | |||
QuantizeAlgorithm quantize_algo = 1; | |||
QuantizeScaleType scale_type = 2; | |||
QuantizeFactor quantize_param = 3; | |||
QuantizeFactor dequantize_param = 4; | |||
QuantizeFactor requantize_param = 5; | |||
QuantizeCalcFactor quantizecalc_param = 6; | |||
}; | |||
message ConvolutionOpParams { | |||
int32 mode = 1; | |||
int32 algo = 2; | |||
int32 pad_mode = 3; | |||
uint32 group = 4; | |||
uint32 num_output = 5; | |||
repeated uint32 pad = 10; | |||
repeated uint32 stride = 11; | |||
repeated uint32 dilation = 12; | |||
repeated uint32 kernel = 13; | |||
float alpha = 20; | |||
float beta = 21; | |||
WeightDef filter = 40; | |||
WeightDef bias = 41; | |||
bool relu_flag = 62; | |||
repeated uint32 adj = 70; | |||
repeated uint32 target_shape = 71; | |||
repeated uint32 before_pad = 72; | |||
}; | |||
message PoolingOpParams { | |||
int32 mode = 1; | |||
int32 nan_opt = 2; | |||
int32 pad_mode = 3; | |||
bool global_pooling = 4; | |||
repeated uint32 window = 10; | |||
repeated uint32 pad = 11; | |||
repeated uint32 stride = 12; | |||
bool ceil_mode = 13; | |||
int32 data_mode = 14; | |||
float alpha = 20; | |||
float beta = 21; | |||
repeated uint32 before_pad = 22; | |||
}; | |||
message EltwiseOpParams { | |||
int32 mode = 1; | |||
repeated float coeff = 2; | |||
float alpha = 3; | |||
float beta = 4; | |||
repeated WeightDef weight = 5; | |||
bool relu_flag = 6; | |||
}; | |||
message ActivationOpParams { | |||
int32 mode = 1; | |||
float coef = 2; | |||
float alpha = 3; | |||
float beta = 4; | |||
}; | |||
message BatchNormOpParams { | |||
int32 mode = 1; | |||
float alpha = 2; | |||
float beta = 3; | |||
double epsilon = 4; // optional, [default = 1e-5] | |||
bool use_global_stats = 5; // optional, default true (testing mode) | |||
float moving_average_fraction = 6; // optional, [default = .999] | |||
WeightDef estimated_mean = 7; | |||
WeightDef estimated_variance = 8; | |||
WeightDef scale = 9; | |||
WeightDef bias = 10; | |||
}; | |||
message ScaleOpParams { | |||
WeightDef scale = 1; | |||
WeightDef bias = 2; | |||
}; | |||
message ReshapeOpParams { | |||
float alpha = 1; | |||
float beta = 2; | |||
ShapeDef shape = 3; | |||
int32 axis = 4; | |||
int32 num_axes = 5; | |||
int32 format = 6; | |||
}; | |||
message SoftmaxOpParams { | |||
int32 algo = 1; | |||
int32 mode = 2; | |||
float alpha = 3; | |||
float beta = 4; | |||
}; | |||
message FullConnectionOpParams { | |||
WeightDef filter = 1; | |||
WeightDef bias = 2; | |||
uint32 num_output = 3; | |||
bool relu_flag = 12; | |||
}; | |||
message FlattenOpParams { | |||
float alpha = 1; | |||
float beta = 2; | |||
int32 start_axis = 3; | |||
int32 end_axis = 4; | |||
} | |||
message AddLimitedOpParams { | |||
float alpha = 1; | |||
float beta = 2; | |||
int32 axis = 3; | |||
bool broadcast = 4; | |||
repeated WeightDef weight = 10; | |||
}; | |||
message MulLimitedOpParams { | |||
float alpha = 1; | |||
float beta = 2; | |||
int32 axis = 3; | |||
bool broadcast = 4; | |||
repeated WeightDef weight = 10; | |||
}; | |||
message AddOpParams { | |||
float alpha = 1; | |||
float beta = 2; | |||
repeated WeightDef weight = 10; | |||
}; | |||
message MulOpParams { | |||
float alpha = 1; | |||
float beta = 2; | |||
repeated WeightDef weight = 10; | |||
}; | |||
message SubOpParams { | |||
float alpha = 1; | |||
float beta = 2; | |||
repeated WeightDef weight = 10; | |||
}; | |||
message BiasAddOpParams { | |||
float alpha = 1; | |||
float beta = 2; | |||
WeightDef bias = 10; | |||
}; | |||
message MatMulOpParams { | |||
float alpha = 1; | |||
float beta = 2; | |||
bool transposeX = 3; | |||
bool transposeW = 4; | |||
WeightDef filter = 10; | |||
WeightDef bias = 12; | |||
}; | |||
message RsqrtOpParams { | |||
float alpha = 1; | |||
float beta = 2; | |||
}; | |||
message WeightDef { | |||
int32 format = 1; | |||
int32 data_type = 2; | |||
ShapeDef shape = 3; | |||
bytes data = 4; | |||
int64 data_offset = 5; | |||
uint32 cmps_size = 6; | |||
bytes cmps_tab = 7; | |||
int64 cmps_tab_offset = 10; | |||
CompressInfo cmps_info = 8; | |||
AllOffsetQuantizeInfo alloffset_quantize_info = 11; | |||
} | |||
message ShapeDef { | |||
repeated int64 dim = 1; | |||
} | |||
enum DeviceType { | |||
NPU = 0; // NPU is used by default. | |||
CPU = 1; // CPU | |||
} | |||
message AllOffsetQuantizeInfo { | |||
float scale = 1; | |||
int32 offset = 2; | |||
} | |||
message TensorDescriptor { | |||
int32 format = 1; | |||
int32 data_type = 2; | |||
repeated int64 dim = 3; | |||
uint32 size = 4; | |||
bool reuse_input = 5; | |||
bool output_tensor = 7; | |||
DeviceType device_type = 8; | |||
bool input_tensor = 9; | |||
uint32 real_dim_cnt = 10; | |||
uint32 reuse_input_index = 11; | |||
AllOffsetQuantizeInfo alloffset_quantize_info = 12; | |||
} | |||
message CompressInfo { | |||
int32 blockRow = 1; // block row | |||
int32 blockCol = 2; // block col | |||
int32 fractalK = 3; // fractal K | |||
int32 fractalN = 4; // fractal N | |||
int32 lastFractalK = 5; // K of last fractal | |||
int32 lastFractalN = 6; // N of last fractal | |||
int32 cubeSize = 7; // cube's length | |||
int32 loadDir = 8; // data load direction: 0 = column load, 1 = row load | |||
} | |||
message AttrDef { | |||
message ListValue { | |||
repeated string s = 2; // "list(string)" | |||
repeated int64 i = 3 [packed = true]; // "list(int)" | |||
repeated float f = 4 [packed = true]; // "list(float)" | |||
repeated bool b = 5 [packed = true]; // "list(bool)" | |||
repeated uint32 u = 6 [packed = true]; // "list(uint)" | |||
repeated bytes bt = 7; | |||
} | |||
oneof value { | |||
string s = 2; // "string" | |||
int64 i = 3; // "int" | |||
float f = 4; // "float" | |||
bool b = 5; // "bool" | |||
uint32 u = 6; // "uint32" | |||
bytes bt = 7; | |||
ListValue list = 1; // any "list(...)" | |||
NamedAttrs func = 10; | |||
} | |||
} | |||
// A list of attr names and their values. The whole list is attached | |||
// with a string name. E.g., MatMul[T=float]. | |||
message NamedAttrs { | |||
string name = 1; | |||
map<string, AttrDef> attr = 2; | |||
} | |||
@@ -1,75 +0,0 @@ | |||
syntax = "proto3"; | |||
package toolkit.aicpu.dump; | |||
message Shape { | |||
repeated uint64 dim = 1; | |||
} | |||
message Output { | |||
int32 data_type = 1; | |||
int32 format = 2; | |||
Shape shape = 3; | |||
uint64 address = 4; | |||
string original_name = 5; | |||
int32 original_output_index = 6; | |||
int32 original_output_data_type = 7; | |||
int32 original_output_format = 8; | |||
uint64 size = 9; | |||
Shape origin_shape = 10; | |||
} | |||
message Input { | |||
int32 data_type = 1; | |||
int32 format = 2; | |||
Shape shape = 3; | |||
uint64 address = 4; | |||
uint64 size = 5; | |||
Shape origin_shape = 6; | |||
} | |||
enum BufferType { | |||
L1 = 0; | |||
} | |||
message OpBuffer { | |||
BufferType buffer_type = 1; | |||
uint64 address = 2; | |||
uint64 size = 3; | |||
} | |||
message Op { | |||
string op_name = 1; | |||
string op_type = 2; | |||
} | |||
message Task { | |||
uint32 task_id = 1; | |||
uint32 stream_id = 2; | |||
Op op = 3; | |||
repeated Output output = 4; | |||
bool end_graph = 5; | |||
repeated Input input = 6; | |||
repeated OpBuffer buffer = 7; | |||
} | |||
message OpMappingInfo { | |||
string dump_path = 1; | |||
oneof model_name_param { | |||
string model_name = 2; | |||
} | |||
oneof model_id_param { | |||
uint32 model_id = 3; | |||
} | |||
oneof step_id { | |||
uint64 step_id_addr = 4; | |||
} | |||
oneof iterations_per_loop { | |||
uint64 iterations_per_loop_addr = 5; | |||
} | |||
oneof loop_cond { | |||
uint64 loop_cond_addr = 6; | |||
} | |||
uint32 flag = 7; // 0x01 load, 0x00 unload | |||
repeated Task task = 8; | |||
string dump_step = 9; | |||
} |
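A sketch, assuming the generated header name, of describing one dumped task and flagging the message as a load operation (flag = 0x01 per the comment above):

#include "op_mapping_info.pb.h"  // hypothetical name of the generated header

toolkit::aicpu::dump::OpMappingInfo MakeLoadInfo(uint32_t task_id, uint32_t stream_id) {
  toolkit::aicpu::dump::OpMappingInfo info;
  info.set_dump_path("/tmp/dump");    // hypothetical dump directory
  info.set_model_name("demo_model");  // selects the model_name_param oneof branch
  info.set_flag(0x01);                // 0x01 load, 0x00 unload
  toolkit::aicpu::dump::Task *task = info.add_task();
  task->set_task_id(task_id);
  task->set_stream_id(stream_id);
  return info;
}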
@@ -1,179 +0,0 @@ | |||
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. | |||
* | |||
* This program is free software; you can redistribute it and/or modify | |||
* it under the terms of the Apache License Version 2.0. You may not use this file except in compliance with the License. | |||
* | |||
* This program is distributed in the hope that it will be useful, | |||
* but WITHOUT ANY WARRANTY; without even the implied warranty of | |||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |||
* Apache License for more details at | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
*/ | |||
syntax = "proto3"; | |||
package domi; | |||
message ModelTaskDef { | |||
string version = 1; | |||
map<string, string> attr = 9; // Extended field | |||
repeated TaskDef task = 10; | |||
uint64 memory_size = 11; | |||
uint32 stream_num = 12; | |||
uint32 event_num = 13; | |||
uint64 weight_size = 14; | |||
repeated bytes op = 15; // input/output opdef in bytes | |||
uint64 base_addr = 16; // base addr | |||
uint64 weight_addr = 17; // weight addr | |||
uint32 batch_num = 18; | |||
} | |||
message TaskDef { | |||
uint32 id = 1; | |||
uint32 type = 2; | |||
uint32 stream_id = 10; | |||
uint32 event_id = 11; | |||
KernelDef kernel = 20; | |||
KernelExDef kernel_ex = 21; | |||
KernelHcclDef kernel_hccl = 25; | |||
EventExDef event_ex = 26; | |||
LogTimeStampDef log_timestamp = 28; | |||
uint32 label_id = 30; | |||
MemcpyAsyncDef memcpy_async = 31; | |||
StreamSwitchDef stream_switch = 32; | |||
StreamActiveDef stream_active = 33; | |||
bytes private_def = 34; | |||
uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future | |||
StreamSwitchNDef stream_switch_n = 36; | |||
LabelSetDef label_set = 37; | |||
LabelGotoExDef label_goto_ex = 38; | |||
LabelSwitchByIndexDef label_switch_by_index = 39; | |||
KernelDefWithHandle kernel_with_handle = 40; | |||
} | |||
message KernelDef { | |||
KernelContext context = 1; | |||
string stub_func = 10; | |||
uint32 block_dim = 11; | |||
uint32 args_size = 12; | |||
bytes args = 13; | |||
bytes sm_desc = 14; | |||
bytes flowtable = 15; | |||
string so_name = 16; | |||
string kernel_name = 17; | |||
bytes kernel_ext_info = 18; | |||
uint32 kernel_ext_info_size = 19; | |||
} | |||
message KernelDefWithHandle { | |||
KernelContext context = 1; | |||
uint64 handle = 10; | |||
string dev_func = 11; | |||
uint32 block_dim = 12; | |||
uint32 args_size = 13; | |||
bytes args = 14; | |||
bytes sm_desc = 15; | |||
string original_kernel_key = 16; | |||
string node_info = 17; | |||
} | |||
message KernelContext { | |||
uint32 kernel_type = 1; | |||
uint32 op_id = 2; // OP type in CCE | |||
uint32 kernel_func_id = 3; | |||
uint32 op_index = 4; // TE/Custom operator | |||
bool is_flowtable = 5; // Identify whether args is a flowtable structure | |||
bytes args_offset = 6; // args offset information | |||
uint32 args_count = 7; // args count | |||
repeated uint32 origin_op_index = 8; | |||
} | |||
message KernelExDef { | |||
uint32 flags = 1; | |||
uint32 op_index = 4; | |||
uint32 args_size = 12; | |||
bytes args = 13; | |||
bytes task_info = 14; // serialized nodeDef, funcDef, input/output | |||
uint32 task_info_size = 15; | |||
bytes kernel_ext_info = 16; | |||
uint32 kernel_ext_info_size = 17; | |||
} | |||
message KernelHcclDef { | |||
uint32 op_index = 8; | |||
string hccl_type = 9; | |||
} | |||
message EventExDef { | |||
uint32 op_index = 1; | |||
uint32 event_type = 2; | |||
} | |||
message LogTimeStampDef { | |||
uint64 logid = 1; | |||
bool notify = 2; | |||
uint32 flat = 3; | |||
} | |||
message MemcpyAsyncDef { | |||
uint64 dst = 1; | |||
uint64 dst_max = 2; | |||
uint64 src = 3; | |||
uint64 count = 4; | |||
uint32 kind = 5; | |||
uint32 op_index = 6; | |||
} | |||
message StreamSwitchDef { | |||
uint32 op_index = 1; | |||
uint32 true_stream_id = 2; | |||
int64 value = 3; | |||
uint64 value_ptr = 4; | |||
uint32 data_type = 5; | |||
} | |||
message StreamActiveDef { | |||
uint32 op_index = 1; | |||
uint32 active_stream_id = 2; | |||
} | |||
message StreamSwitchNDef { | |||
uint32 op_index = 1; | |||
uint32 size = 2; | |||
repeated int64 target_value = 3; | |||
repeated uint32 true_stream_id = 4; | |||
uint32 element_size = 5; | |||
uint32 data_type = 6; | |||
} | |||
message LabelSetDef { | |||
uint32 op_index = 1; | |||
uint32 label_id = 2; | |||
uint32 model_id = 3; | |||
} | |||
message LabelGotoExDef { | |||
uint32 op_index = 1; | |||
uint32 label_id = 2; | |||
uint32 model_id = 3; | |||
} | |||
message LabelSwitchByIndexDef { | |||
uint32 op_index = 1; | |||
uint32 label_max = 2; | |||
} |
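A sketch (generated header name assumed) of populating the MemcpyAsyncDef branch of a TaskDef; the addresses and sizes are placeholders supplied by the caller:

#include <cstdint>
#include "task.pb.h"  // hypothetical name of the generated header

domi::TaskDef MakeMemcpyTask(uint64_t dst, uint64_t src, uint64_t bytes) {
  domi::TaskDef task;
  task.set_stream_id(0);  // placeholder stream
  domi::MemcpyAsyncDef *copy = task.mutable_memcpy_async();
  copy->set_dst(dst);
  copy->set_dst_max(bytes);  // capacity of the destination buffer
  copy->set_src(src);
  copy->set_count(bytes);    // number of bytes to copy
  return task;
}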
@@ -13,15 +13,15 @@ | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "host_cpu_engine.h" | |||
#include "graph/common/omg_util.h" | |||
#include "ge_local_engine/engine/host_cpu_engine.h" | |||
#include "graph/utils/op_desc_utils.h" | |||
#include "graph/utils/tensor_adapter.h" | |||
#include "graph/utils/node_utils.h" | |||
#include "graph/utils/type_utils.h" | |||
#include "register/op_kernel_registry.h" | |||
#include "register/host_cpu_context.h" | |||
#include "common/ge/ge_util.h" | |||
#include "common/ge/plugin_manager.h" | |||
#include "graph/utils/type_utils.h" | |||
#include "common/fp16_t.h" | |||
#include "common/math/math_util.h" | |||
@@ -123,10 +123,7 @@ bool HostCpuEngine::CheckSupported(const string &op_type) { | |||
} | |||
Status HostCpuEngine::FindOpKernel(const ge::NodePtr &node, std::unique_ptr<HostCpuOp> &op_kernel) { | |||
std::string op_type; | |||
auto status = GetOriginalType(node, op_type); | |||
GE_CHK_BOOL_EXEC_NOLOG(status == SUCCESS, return status); | |||
const std::string op_type = NodeUtils::GetNodeType(node); | |||
auto kernel = OpKernelRegistry::GetInstance().CreateHostCpuOp(op_type); | |||
if (kernel == nullptr) { | |||
GELOGD("Op of type %s is not supported by host cpu engine", op_type.c_str()); | |||
@@ -1,179 +0,0 @@ | |||
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. | |||
* | |||
* This program is free software; you can redistribute it and/or modify | |||
* it under the terms of the Apache License Version 2.0. You may not use this file except in compliance with the License. | |||
* | |||
* This program is distributed in the hope that it will be useful, | |||
* but WITHOUT ANY WARRANTY; without even the implied warranty of | |||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |||
* Apache License for more details at | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
*/ | |||
syntax = "proto3"; | |||
package domi; | |||
message ModelTaskDef { | |||
string version = 1; | |||
map<string, string> attr = 9; // Extended field | |||
repeated TaskDef task = 10; | |||
uint64 memory_size = 11; | |||
uint32 stream_num = 12; | |||
uint32 event_num = 13; | |||
uint64 weight_size = 14; | |||
repeated bytes op = 15; // input/output opdef in bytes | |||
uint64 base_addr = 16; // base addr | |||
uint64 weight_addr = 17; // weight addr | |||
uint32 batch_num = 18; | |||
} | |||
message TaskDef { | |||
uint32 id = 1; | |||
uint32 type = 2; | |||
uint32 stream_id = 10; | |||
uint32 event_id = 11; | |||
KernelDef kernel = 20; | |||
KernelExDef kernel_ex = 21; | |||
KernelHcclDef kernel_hccl = 25; | |||
EventExDef event_ex = 26; | |||
LogTimeStampDef log_timestamp = 28; | |||
uint32 label_id = 30; | |||
MemcpyAsyncDef memcpy_async = 31; | |||
StreamSwitchDef stream_switch = 32; | |||
StreamActiveDef stream_active = 33; | |||
bytes private_def = 34; | |||
uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future | |||
StreamSwitchNDef stream_switch_n = 36; | |||
LabelSetDef label_set = 37; | |||
LabelGotoExDef label_goto_ex = 38; | |||
LabelSwitchByIndexDef label_switch_by_index = 39; | |||
KernelDefWithHandle kernel_with_handle = 40; | |||
} | |||
message KernelDef { | |||
KernelContext context = 1; | |||
string stub_func = 10; | |||
uint32 block_dim = 11; | |||
uint32 args_size = 12; | |||
bytes args = 13; | |||
bytes sm_desc = 14; | |||
bytes flowtable = 15; | |||
string so_name = 16; | |||
string kernel_name = 17; | |||
bytes kernel_ext_info = 18; | |||
uint32 kernel_ext_info_size = 19; | |||
} | |||
message KernelDefWithHandle { | |||
KernelContext context = 1; | |||
uint64 handle = 10; | |||
string dev_func = 11; | |||
uint32 block_dim = 12; | |||
uint32 args_size = 13; | |||
bytes args = 14; | |||
bytes sm_desc = 15; | |||
string original_kernel_key = 16; | |||
string node_info = 17; | |||
} | |||
message KernelContext { | |||
uint32 kernel_type = 1; | |||
uint32 op_id = 2; // OP type in CCE | |||
uint32 kernel_func_id = 3; | |||
uint32 op_index = 4; // TE/Custom operator | |||
bool is_flowtable = 5; // Identify whether args is a flowtable structure | |||
bytes args_offset = 6; // args offset information | |||
uint32 args_count = 7; // args count | |||
repeated uint32 origin_op_index = 8; | |||
} | |||
message KernelExDef { | |||
uint32 flags = 1; | |||
uint32 op_index = 4; | |||
uint32 args_size = 12; | |||
bytes args = 13; | |||
bytes task_info = 14; // serialized nodeDef, funcDef, input/output | |||
uint32 task_info_size = 15; | |||
bytes kernel_ext_info = 16; | |||
uint32 kernel_ext_info_size = 17; | |||
} | |||
message KernelHcclDef { | |||
uint32 op_index = 8; | |||
string hccl_type = 9; | |||
} | |||
message EventExDef { | |||
uint32 op_index = 1; | |||
uint32 event_type = 2; | |||
} | |||
message LogTimeStampDef { | |||
uint64 logid = 1; | |||
bool notify = 2; | |||
uint32 flat = 3; | |||
} | |||
message MemcpyAsyncDef { | |||
uint64 dst = 1; | |||
uint64 dst_max = 2; | |||
uint64 src = 3; | |||
uint64 count = 4; | |||
uint32 kind = 5; | |||
uint32 op_index = 6; | |||
} | |||
message StreamSwitchDef { | |||
uint32 op_index = 1; | |||
uint32 true_stream_id = 2; | |||
int64 value = 3; | |||
uint64 value_ptr = 4; | |||
uint32 data_type = 5; | |||
} | |||
message StreamActiveDef { | |||
uint32 op_index = 1; | |||
uint32 active_stream_id = 2; | |||
} | |||
message StreamSwitchNDef { | |||
uint32 op_index = 1; | |||
uint32 size = 2; | |||
repeated int64 target_value = 3; | |||
repeated uint32 true_stream_id = 4; | |||
uint32 element_size = 5; | |||
uint32 data_type = 6; | |||
} | |||
message LabelSetDef { | |||
uint32 op_index = 1; | |||
uint32 label_id = 2; | |||
uint32 model_id = 3; | |||
} | |||
message LabelGotoExDef { | |||
uint32 op_index = 1; | |||
uint32 label_id = 2; | |||
uint32 model_id = 3; | |||
} | |||
message LabelSwitchByIndexDef { | |||
uint32 op_index = 1; | |||
uint32 label_max = 2; | |||
} |
@@ -0,0 +1,58 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "ge_opt_info/ge_opt_info.h" | |||
#include <string> | |||
#include <map> | |||
#include "graph/ge_local_context.h" | |||
#include "ge/ge_api_types.h" | |||
#include "common/debug/ge_log.h" | |||
#include "opt_info.h" | |||
namespace ge { | |||
Status GeOptInfo::SetOptInfo() { | |||
std::string soc_ver; | |||
graphStatus ret = GetThreadLocalContext().GetOption(SOC_VERSION, soc_ver); | |||
if (ret != GRAPH_SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Get soc version failed."); | |||
GELOGE(FAILED, "[Get][SocVersion]Get soc version failed."); | |||
return FAILED; | |||
} | |||
GELOGD("Soc version:%s.", soc_ver.c_str()); | |||
std::map<std::string, std::string> opt_info; | |||
// The first argument does not take effect at present. | |||
if (gelc::GetOptInfo(gelc::kOffline, soc_ver, opt_info) != gelc::SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Get optional information failed, is_offline:%d, soc version:%s", | |||
gelc::kOffline, soc_ver.c_str()); | |||
GELOGE(FAILED, "[Get][OptInfo]Get optional information failed, is_offline:%d, soc version:%s", | |||
gelc::kOffline, soc_ver.c_str()); | |||
return FAILED; | |||
} | |||
// do nothing if get empty information | |||
if (opt_info.empty()) { | |||
GELOGI("Optional information is empty."); | |||
return SUCCESS; | |||
} | |||
std::map<std::string, std::string> graph_options = GetThreadLocalContext().GetAllGraphOptions(); | |||
for (const auto &itr : opt_info) { | |||
graph_options.emplace(itr.first, itr.second); | |||
GELOGI("Get optional information success, key:%s, value:%s.", itr.first.c_str(), itr.second.c_str()); | |||
} | |||
GetThreadLocalContext().SetGraphOption(graph_options); | |||
return SUCCESS; | |||
} | |||
} // namespace ge |
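Note the design choice above: the queried opt_info entries are merged with emplace, so any option the user already set in the graph options wins over the queried value. A hypothetical usage sketch of the entry point:

#include "ge_opt_info/ge_opt_info.h"

// Assumed call site: invoked once after SOC_VERSION has been written into the
// thread-local context and before graph compilation begins.
ge::Status PrepareOptInfo() {
  const ge::Status ret = ge::GeOptInfo::SetOptInfo();
  if (ret != ge::SUCCESS) {
    return ret;  // SOC_VERSION was missing or the gelc query failed
  }
  return ge::SUCCESS;
}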
@@ -0,0 +1,31 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GE_OPT_INFO_GE_OPT_INFO_H_ | |||
#define GE_OPT_INFO_GE_OPT_INFO_H_ | |||
#include "ge/ge_api_error_codes.h" | |||
#include "register/register_types.h" | |||
namespace ge { | |||
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeOptInfo { | |||
public: | |||
GeOptInfo() = default; | |||
static Status SetOptInfo(); | |||
}; | |||
} // namespace ge | |||
#endif // GE_OPT_INFO_GE_OPT_INFO_H_ |
@@ -16,6 +16,7 @@ set(GE_SRC_LIST | |||
"task/label_goto_task.cc" | |||
"task/label_set_task.cc" | |||
"task/label_switch_task.cc" | |||
"task/label_manager.cc" | |||
) | |||
add_library(ge_runtime SHARED ${GE_SRC_LIST}) | |||
@@ -21,6 +21,7 @@ | |||
#include "common/ge_inner_error_codes.h" | |||
#include "common/types.h" | |||
#include "common/util.h" | |||
#include "common/math/math_util.h" | |||
#include "framework/common/debug/ge_log.h" | |||
#include "framework/common/op/op_parser_util.h" | |||
#include "graph/types.h" | |||
@@ -52,15 +52,7 @@ HcclTask::HcclTask(const ModelContext &model_context, const std::shared_ptr<Hccl | |||
} | |||
} | |||
HcclTask::~HcclTask() { | |||
if (workspace_mem_ != nullptr) { | |||
rtError_t rt_ret = rtFree(workspace_mem_); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "rtFree workspace_mem_ failed! ret: 0x%X.", rt_ret); | |||
} | |||
workspace_mem_ = nullptr; | |||
} | |||
} | |||
HcclTask::~HcclTask() {} | |||
bool HcclTask::Distribute() { | |||
// Ops kernel info store | |||
@@ -79,11 +71,7 @@ bool HcclTask::Distribute() { | |||
SetSecondaryStream(); | |||
if (task_info_->workspace_size() > 0) { | |||
rtError_t rt_ret = rtMalloc(&workspace_mem_, task_info_->workspace_size(), RT_MEMORYINFO_HBM); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
return false; | |||
} | |||
workspace_mem_ = task_info_->workspace_addr(); | |||
} | |||
GELOGI("HcclTaskInfo Distribute Start. begin to call function LoadTask in hccl."); | |||
@@ -16,33 +16,46 @@ | |||
#include "ge_runtime/task/label_goto_task.h" | |||
#include "ge_runtime/task/task_factory.h" | |||
#include "framework/common/util.h" | |||
namespace ge { | |||
namespace model_runner { | |||
LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info) | |||
: TaskRepeater<LabelGotoTaskInfo>(model_context, task_info), task_info_(task_info) { | |||
: TaskRepeater<LabelGotoTaskInfo>(model_context, task_info), | |||
task_info_(task_info), | |||
stream_(nullptr), | |||
index_value_(nullptr) { | |||
if (task_info_ == nullptr) { | |||
GELOGW("task_info_ is null!"); | |||
return; | |||
} | |||
auto stream_list = model_context.stream_list(); | |||
auto label_list = model_context.label_list(); | |||
rt_model_handle_ = model_context.rt_model_handle(); | |||
uint32_t stream_id = task_info->stream_id(); | |||
uint32_t label_id = task_info->label_id(); | |||
label_id_ = task_info->label_id(); | |||
GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||
GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); | |||
if (stream_id >= stream_list.size() || label_id >= label_list.size()) { | |||
GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id_); | |||
if (stream_id >= stream_list.size() || label_id_ >= label_list.size()) { | |||
GELOGW("Stream/Label id invalid."); | |||
return; | |||
} | |||
stream_ = stream_list[stream_id]; | |||
label_ = label_list[label_id]; | |||
label_manager_ = LabelManager::GetInstance(); | |||
if (label_manager_ == nullptr) { | |||
GELOGW("Get label manager instance failed."); | |||
return; | |||
} | |||
label_info_ = label_manager_->GetLabelInfo(rt_model_handle_, {label_id_}, label_list); | |||
} | |||
LabelGotoTask::~LabelGotoTask() { | |||
GE_FREE_RT_LOG(label_info_); | |||
GE_FREE_RT_LOG(index_value_); | |||
if (index_value_ != nullptr) { | |||
rtError_t rt_ret = rtFree(index_value_); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "rtFree index_value_ failed! ret: 0x%X.", rt_ret); | |||
} | |||
index_value_ = nullptr; | |||
} | |||
} | |||
bool LabelGotoTask::Distribute() { | |||
@@ -72,7 +85,7 @@ bool LabelGotoTask::Distribute() { | |||
return false; | |||
} | |||
rt_ret = rtLabelListCpy((void**)label_list.data(), label_list.size(), label_info_, label_info_size); | |||
rt_ret = rtLabelListCpy(reinterpret_cast<void**>(label_list.data()), label_list.size(), label_info_, label_info_size); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | |||
return false; | |||
@@ -94,21 +107,34 @@ bool LabelGotoTask::CheckParamValid() { | |||
return false; | |||
} | |||
if (label_ == nullptr) { | |||
GELOGE(PARAM_INVALID, "label is null!"); | |||
if (label_info_ == nullptr) { | |||
GELOGE(PARAM_INVALID, "label info is null!"); | |||
return false; | |||
} | |||
if (label_info_ != nullptr) { | |||
GELOGE(PARAM_INVALID, "label_info_ has dirty data."); | |||
return false; | |||
if (index_value_ == nullptr) { | |||
rtError_t rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), RT_MEMORY_HBM); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
return false; | |||
} | |||
uint64_t index = 0; | |||
rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &index, sizeof(index), RT_MEMCPY_HOST_TO_DEVICE); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
return false; | |||
} | |||
} | |||
if (index_value_ != nullptr) { | |||
GELOGE(PARAM_INVALID, "index_value_ has dirty data."); | |||
void *label_info = label_info_->GetLabelInfo(); | |||
rtError_t rt_ret = rtLabelSwitchByIndex(index_value_, 1, label_info, stream_); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
return false; | |||
} | |||
GELOGI("DistributeTask end."); | |||
return true; | |||
} | |||
@@ -18,7 +18,11 @@ | |||
#define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||
#include <memory> | |||
#include <vector> | |||
#include <map> | |||
#include <mutex> | |||
#include "ge_runtime/task/task.h" | |||
#include "ge_runtime/task/label_manager.h" | |||
namespace ge { | |||
namespace model_runner { | |||
@@ -31,13 +35,13 @@ class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> { | |||
bool Distribute() override; | |||
private: | |||
bool CheckParamValid(); | |||
std::shared_ptr<LabelGotoTaskInfo> task_info_; | |||
void *stream_{nullptr}; | |||
void *label_{nullptr}; | |||
void *label_info_{nullptr}; | |||
void *index_value_{nullptr}; | |||
void *stream_; | |||
std::shared_ptr<LabelGuard> label_info_; | |||
void *index_value_; | |||
uint32_t label_id_; | |||
rtModel_t rt_model_handle_; | |||
std::shared_ptr<LabelManager> label_manager_; | |||
}; | |||
} // namespace model_runner | |||
} // namespace ge | |||
@@ -0,0 +1,119 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "ge_runtime/task/label_manager.h" | |||
#include <algorithm> | |||
#include <string> | |||
#include "runtime/mem.h" | |||
#include "runtime/rt_model.h" | |||
#include "common/ge_inner_error_codes.h" | |||
#include "framework/common/debug/ge_log.h" | |||
namespace ge { | |||
namespace model_runner { | |||
std::weak_ptr<LabelManager> LabelManager::instance_; | |||
std::mutex LabelManager::instance_mutex_; | |||
template <class T> | |||
static std::string GetVectorString(const std::vector<T> &vec) { | |||
std::string ret; | |||
for (size_t i = 0; i < vec.size(); ++i) { | |||
if (i != 0) { | |||
ret.push_back(','); | |||
} | |||
ret += std::to_string(vec[i]); | |||
} | |||
return ret; | |||
} | |||
LabelGuard::~LabelGuard() { | |||
void *label_info = GetLabelInfo(); | |||
if (label_info != nullptr) { | |||
rtError_t rt_ret = rtFree(label_info); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "rtFree label_info failed! ret: 0x%X.", rt_ret); | |||
} | |||
} | |||
} | |||
std::shared_ptr<LabelManager> LabelManager::GetInstance() { | |||
std::lock_guard<std::mutex> lock(instance_mutex_); | |||
auto instance = instance_.lock(); | |||
if (instance != nullptr) { | |||
return instance; | |||
} | |||
instance = std::make_shared<LabelManager>(); | |||
instance_ = instance; | |||
return instance; | |||
} | |||
std::shared_ptr<LabelGuard> LabelManager::GetLabelInfo(rtModel_t model, const std::vector<uint32_t> &label_ids, | |||
const std::vector<void *> &all_label) { | |||
std::lock_guard<std::mutex> lock(model_info_mapping_mutex_); | |||
rtError_t rt_ret; | |||
auto model_iter = model_info_mapping_.find(model); | |||
if (model_iter == model_info_mapping_.end()) { | |||
model_info_mapping_.emplace(model, std::map<std::string, std::weak_ptr<LabelGuard>>()); | |||
model_iter = model_info_mapping_.find(model); | |||
} | |||
std::string label_id_str = GetVectorString(label_ids); | |||
auto &label_map = model_iter->second; | |||
auto label_iter = label_map.find(label_id_str); | |||
if (label_iter != label_map.end()) { | |||
auto label_guard = label_iter->second.lock(); | |||
if (label_guard != nullptr) { | |||
GELOGI("model %p find same label id %s.", model, label_id_str.c_str()); | |||
return label_guard; | |||
} | |||
} | |||
GELOGI("Alloc label id %s for model %p.", label_id_str.c_str(), model); | |||
void *label_info; | |||
std::vector<void *> label_list; | |||
bool status = true; | |||
std::transform(label_ids.begin(), label_ids.end(), std::back_inserter(label_list), | |||
[&all_label, &status](uint32_t idx) -> void * { | |||
if (idx >= all_label.size()) { | |||
GELOGE(PARAM_INVALID, "Invalid label id %u, all label list size %zu.", idx, all_label.size()); | |||
status = false; | |||
return nullptr; | |||
} | |||
return all_label[idx]; | |||
}); | |||
if (!status) { | |||
GELOGE(PARAM_INVALID, "Get label info failed."); | |||
return nullptr; | |||
} | |||
uint32_t label_info_size = sizeof(rtLabelDevInfo) * label_list.size(); | |||
rt_ret = rtMalloc(&label_info, label_info_size, RT_MEMORY_HBM); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
return nullptr; | |||
} | |||
rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info, label_info_size); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
return nullptr; | |||
} | |||
auto label_guard = std::make_shared<LabelGuard>(label_info); | |||
label_map.emplace(label_id_str, label_guard); | |||
return label_guard; | |||
} | |||
} // namespace model_runner | |||
} // namespace ge |
@@ -0,0 +1,54 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GE_GE_RUNTIME_TASK_LABEL_MANAGER_H_ | |||
#define GE_GE_RUNTIME_TASK_LABEL_MANAGER_H_ | |||
#include <vector> | |||
#include <memory> | |||
#include <mutex> | |||
#include <map> | |||
#include <runtime/base.h> | |||
namespace ge { | |||
namespace model_runner { | |||
class LabelGuard { | |||
public: | |||
explicit LabelGuard(void *label_info) : label_info_(reinterpret_cast<uintptr_t>(label_info)) {} | |||
~LabelGuard(); | |||
void *GetLabelInfo() { return reinterpret_cast<void *>(label_info_); } | |||
private: | |||
uintptr_t label_info_; | |||
}; | |||
class LabelManager { | |||
public: | |||
static std::shared_ptr<LabelManager> GetInstance(); | |||
std::shared_ptr<LabelGuard> GetLabelInfo(rtModel_t model, const std::vector<uint32_t> &label_ids, | |||
const std::vector<void *> &all_label); | |||
private: | |||
std::mutex model_info_mapping_mutex_; | |||
std::map<rtModel_t, std::map<std::string, std::weak_ptr<LabelGuard>>> model_info_mapping_; | |||
static std::weak_ptr<LabelManager> instance_; | |||
static std::mutex instance_mutex_; | |||
}; | |||
} // namespace model_runner | |||
} // namespace ge | |||
#endif // GE_GE_RUNTIME_TASK_LABEL_MANAGER_H_ |
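A usage sketch of the cache above: because model_info_mapping_ stores weak_ptr values keyed by the label-id string, guards requested for the same model and label-id list share one device buffer while any owner is alive. The inputs are assumed to come from a ModelContext, as in the tasks above:

#include <memory>
#include <vector>
#include "ge_runtime/task/label_manager.h"

std::shared_ptr<ge::model_runner::LabelGuard> AcquireLabelInfo(
    rtModel_t model, const std::vector<uint32_t> &ids, const std::vector<void *> &all_label) {
  auto manager = ge::model_runner::LabelManager::GetInstance();
  if (manager == nullptr) {
    return nullptr;
  }
  // The returned guard frees the rtMalloc'ed label buffer when the last
  // shared_ptr owner releases it.
  return manager->GetLabelInfo(model, ids, all_label);
}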
@@ -24,14 +24,14 @@ LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, | |||
: TaskRepeater<LabelSwitchTaskInfo>(model_context, task_info), | |||
task_info_(task_info), | |||
stream_(nullptr), | |||
all_label_resource_(), | |||
label_info_(nullptr) { | |||
if (task_info_ == nullptr) { | |||
GELOGW("task_info_ is null!"); | |||
return; | |||
} | |||
all_label_resource_ = model_context.label_list(); | |||
rt_model_handle_ = model_context.rt_model_handle(); | |||
auto all_label_resource = model_context.label_list(); | |||
auto stream_list = model_context.stream_list(); | |||
uint32_t stream_id = task_info->stream_id(); | |||
GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||
@@ -40,18 +40,16 @@ LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, | |||
return; | |||
} | |||
stream_ = stream_list[stream_id]; | |||
} | |||
LabelSwitchTask::~LabelSwitchTask() { | |||
if (label_info_ != nullptr) { | |||
rtError_t rt_ret = rtFree(label_info_); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! ret: 0x%X.", rt_ret); | |||
} | |||
label_info_ = nullptr; | |||
label_manager_ = LabelManager::GetInstance(); | |||
if (label_manager_ == nullptr) { | |||
GELOGW("Get label manager instance failed."); | |||
return; | |||
} | |||
label_info_ = label_manager_->GetLabelInfo(rt_model_handle_, task_info_->label_list(), all_label_resource); | |||
} | |||
LabelSwitchTask::~LabelSwitchTask() {} | |||
bool LabelSwitchTask::Distribute() { | |||
GELOGI("LabelSwitchTask Distribute start."); | |||
if (!CheckParamValid()) { | |||
@@ -117,8 +115,8 @@ bool LabelSwitchTask::CheckParamValid() { | |||
return false; | |||
} | |||
if (label_info_ != nullptr) { | |||
GELOGE(PARAM_INVALID, "label_info_ has dirty data."); | |||
if (label_info_ == nullptr) { | |||
GELOGE(PARAM_INVALID, "CopyLabelList failed, label info is null."); | |||
return false; | |||
} | |||
@@ -126,6 +124,5 @@ bool LabelSwitchTask::CheckParamValid() { | |||
} | |||
REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo); | |||
} // namespace model_runner | |||
} // namespace ge |
@@ -19,6 +19,7 @@ | |||
#include <memory> | |||
#include "ge_runtime/task/task.h" | |||
#include "ge_runtime/task/label_manager.h" | |||
namespace ge { | |||
namespace model_runner { | |||
@@ -35,8 +36,9 @@ class LabelSwitchTask : public TaskRepeater<LabelSwitchTaskInfo> { | |||
std::shared_ptr<LabelSwitchTaskInfo> task_info_; | |||
void *stream_; | |||
std::vector<void *> all_label_resource_; | |||
void *label_info_; | |||
rtModel_t rt_model_handle_; | |||
std::shared_ptr<LabelGuard> label_info_; | |||
std::shared_ptr<LabelManager> label_manager_; | |||
}; | |||
} // namespace model_runner | |||
} // namespace ge | |||
@@ -674,6 +674,12 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr | |||
GELOGD("Current ctx is null."); | |||
ctx = nullptr; | |||
} | |||
std::function<void()> callback = [&]() { | |||
if (ctx != nullptr) { | |||
(void)rtCtxSetCurrent(ctx); | |||
} | |||
}; | |||
GE_MAKE_GUARD(restore, callback); | |||
GeRootModelPtr ge_root_model = nullptr; | |||
GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID); | |||
@@ -712,11 +718,6 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr | |||
} | |||
return ret; | |||
} | |||
if (ctx != nullptr) { | |||
(void)rtCtxSetCurrent(ctx); | |||
} | |||
return SUCCESS; | |||
} | |||
@@ -86,6 +86,11 @@ bool LabelAllocator::CollectFunctionalNode(ComputeGraphPtr &graph, std::set<Node | |||
return false; | |||
} | |||
if (func_node->GetOpDesc() != nullptr && func_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH)) { | |||
GELOGD("Graph[%s] is ffts subgraph, skip label allocator.", graph->GetName().c_str()); | |||
return true; | |||
} | |||
ComputeGraphPtr owner_graph = func_node->GetOwnerComputeGraph(); | |||
if (owner_graph == nullptr) { | |||
REPORT_INNER_ERROR("E19999", "ComputeGraph owner not set in node:%s(%s), graph:%s", | |||
@@ -474,6 +474,11 @@ Status UpdateForSkippedEnginePass::Run(ComputeGraphPtr graph, const vector<Subgr | |||
for (ge::NodePtr &node : graph->GetDirectNode()) { | |||
auto op_desc = node->GetOpDesc(); | |||
GE_CHECK_NOTNULL(op_desc); | |||
if (op_desc->HasAttr(ATTR_NAME_THREAD_SCOPE_ID)) { | |||
op_desc->SetStreamId(kInvalidStream); | |||
GELOGI("Ffts node %s of type %s reassign to invalid stream.", node->GetName().c_str(), node->GetType().c_str()); | |||
continue; | |||
} | |||
int64_t stream_id = op_desc->GetStreamId(); | |||
if (ops_without_label.find(op_desc) != ops_without_label.end()) { | |||
if (AreAllPredStreamsInvalid(node) && op_desc->GetSubgraphInstanceNames().empty()) { | |||
@@ -707,7 +707,7 @@ Status ModelBuilder::SaveDataToModel(ge::Model &model, ge::GeModel &ge_model) { | |||
if (!kernel_name.empty() && (kernel_buffer.GetSize() > 0)) { | |||
GE_CHECK_NOTNULL(kernel_buffer.GetData()); | |||
std::vector<char> data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize()); | |||
tbe_kernel = std::make_shared<OpKernelBin>(kernel_name, std::move(data)); | |||
tbe_kernel = MakeShared<OpKernelBin>(kernel_name, std::move(data)); | |||
GE_CHECK_NOTNULL(tbe_kernel); | |||
GELOGI("Node [%s][%s] start recovery extra attr %s from %s", node_op_desc->GetName().c_str(), | |||
node_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str()); | |||
@@ -432,7 +432,11 @@ Status StreamAllocator::SetActiveStreamsForSubgraphs() { | |||
// Insert the send/recv event id to the graph | |||
Status StreamAllocator::InsertSyncEvents() { | |||
for (const auto &cur_node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { | |||
auto ffts_filter = [](const Node &node, const char *, const ComputeGraphPtr &) { | |||
return !node.GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH); | |||
}; | |||
for (const auto &cur_node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag(), nullptr, ffts_filter)) { | |||
// Take the adjacent points, then judge whether need to insert the event | |||
for (const OutDataAnchorPtr &anchor : cur_node->GetAllOutDataAnchors()) { | |||
for (const InDataAnchorPtr &peer_in_anchor : anchor->GetPeerInDataAnchors()) { | |||
@@ -531,6 +535,11 @@ Status StreamAllocator::InsertOneEventInTwoNodes(const NodePtr &cur_node, const | |||
Status StreamAllocator::InsertEventsForSubgraph() { | |||
for (const auto &subgraph : whole_graph_->GetAllSubgraphs()) { | |||
GE_CHECK_NOTNULL(subgraph); | |||
const auto parent_node = subgraph->GetParentNode(); | |||
if (parent_node != nullptr && parent_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH)) { | |||
GELOGD("Skip ffts subgraph, parent node is %s.", parent_node->GetName().c_str()); | |||
continue; | |||
} | |||
for (const auto &node : subgraph->GetDirectNode()) { | |||
auto op_desc = node->GetOpDesc(); | |||
GE_CHECK_NOTNULL(op_desc); | |||
@@ -354,7 +354,10 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra | |||
}; | |||
GE_MAKE_GUARD(release, callback); | |||
for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { | |||
auto ffts_filter = [](const Node &node, const char *, const ComputeGraphPtr &) { | |||
return !node.GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH); | |||
}; | |||
for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag(), nullptr, ffts_filter)) { | |||
OpDescPtr op_desc = node->GetOpDesc(); | |||
GE_CHECK_NOTNULL(op_desc); | |||
node_index++; | |||
@@ -380,10 +383,8 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra | |||
GELOGI("Fusion node[name:%s, type:%s] do not need generate task again.", name.c_str(), type.c_str()); | |||
continue; | |||
} | |||
if (op_kernel_lib_name.empty()) { | |||
GELOGI("Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str()); | |||
continue; | |||
} | |||
GE_CHK_BOOL_EXEC_INFO(!op_kernel_lib_name.empty(), continue, | |||
"Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str()); | |||
auto kernel_info_store = ops_kernel_manager.GetOpsKernelInfoStore(op_kernel_lib_name); | |||
if (kernel_info_store == nullptr) { | |||
REPORT_INNER_ERROR("E19999", "Get ops kernel info store failed for op:%s(%s), op_kernel_name:%s", | |||
@@ -394,6 +395,10 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra | |||
} | |||
GE_CHK_STATUS_RET(UpdateAnchorStatus(node), "[Call][UpdateAnchorStatus] node:%s(%s) failed", name.c_str(), | |||
type.c_str()); | |||
if (node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH)) { | |||
GE_CHK_STATUS_RET(UpdateAnchorStatusForFfts(node), "[Call][UpdateAnchorStatusForFfts] node:%s(%s) failed", | |||
name.c_str(), type.c_str()); | |||
} | |||
// Profiling task | |||
size_t task_list_size_before = task_def_list.size(); | |||
GE_CHK_STATUS_RET(InsertProfilingTaskBefore(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list)); | |||
@@ -571,7 +576,24 @@ Status TaskGenerator::GenerateTaskForFusionNode(FusionTaskInfo &fusion_task_info | |||
return ret; | |||
} | |||
Status TaskGenerator::UpdateAnchorStatusForFfts(const NodePtr &node) { | |||
GELOGD("Start UpdateAnchorStatusForFfts for %s.", node->GetName().c_str()); | |||
if (!node->GetOpDesc()->GetSubgraphInstanceNames().empty()) { | |||
for (size_t i = 0; i < node->GetOpDesc()->GetSubgraphInstanceNames().size(); ++i) { | |||
auto sub_graph = NodeUtils::GetSubgraph(*node, i); | |||
GE_CHECK_NOTNULL(sub_graph); | |||
GELOGD("Start update anchor status for %s.", sub_graph->GetName().c_str()); | |||
for (auto &ffts_node : sub_graph->GetDirectNode()) { | |||
GE_CHK_STATUS_RET(UpdateAnchorStatus(ffts_node), "[Call][UpdateAnchorStatus] node:%s(%s) failed", | |||
ffts_node->GetName().c_str(), ffts_node->GetType().c_str()); | |||
} | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
Status TaskGenerator::UpdateAnchorStatus(const NodePtr &node) { | |||
GELOGD("Start UpdateAnchorStatus for %s.", node->GetName().c_str()); | |||
if (NodeUtils::SetAllAnchorStatus(node) != GRAPH_SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "SetAllAnchorStatus fail for op:%s(%s)", | |||
node->GetName().c_str(), node->GetType().c_str()); | |||
@@ -771,7 +793,6 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP | |||
GELOGI("Start AutoFindBpOpIndex"); | |||
NodePtr bp_node = nullptr; | |||
uint32_t current_idx = 0; | |||
uint32_t netoutput_idx = 0; | |||
for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { | |||
OpDescPtr op_desc = node->GetOpDesc(); | |||
GE_CHECK_NOTNULL(op_desc); | |||
@@ -789,7 +810,6 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP | |||
if (op_desc->GetName() == NODE_NAME_NET_OUTPUT) { | |||
if (bp_node == nullptr) { | |||
bp_node = node; | |||
netoutput_idx = current_idx - 1; | |||
} | |||
} | |||
if (graph->GetNeedIteration()) { | |||
@@ -814,34 +834,30 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP | |||
if (bp_node == nullptr) { | |||
GELOGW("not find bp_node."); | |||
return SUCCESS; | |||
} else if (bp_node->GetName() == NODE_NAME_NET_OUTPUT) { | |||
profiling_point.bp_index = netoutput_idx; | |||
GELOGI("First bp name %s, idx %u", bp_node->GetName().c_str(), netoutput_idx); | |||
} else { | |||
profiling_point.bp_index = FindLastBpFromBpNode(graph, bp_node); | |||
} | |||
return SUCCESS; | |||
return FindLastBpFromBpNode(graph, bp_node, profiling_point.bp_index); | |||
} | |||
uint32_t TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node) const { | |||
uint32_t last_bp = 0; | |||
Status TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &target_node, | |||
uint32_t &bp_index) const { | |||
bp_index = 0; | |||
auto target_desc = target_node->GetOpDesc(); | |||
GE_CHECK_NOTNULL(target_desc); | |||
OpDescPtr bp_op_desc = nullptr; | |||
for (auto &in_anchor : bp_node->GetAllInDataAnchors()) { | |||
auto out_anchor = in_anchor->GetPeerOutAnchor(); | |||
if (out_anchor == nullptr || out_anchor->GetOwnerNode() == nullptr) { | |||
continue; | |||
for (auto &in_node : target_node->GetInAllNodes()) { | |||
GE_CHECK_NOTNULL(in_node); | |||
auto in_node_desc = in_node->GetOpDesc(); | |||
GE_CHECK_NOTNULL(in_node_desc); | |||
if ((bp_op_desc == nullptr || (in_node_desc->GetId() > bp_op_desc->GetId())) && | |||
(in_node_desc->GetStreamId() == target_desc->GetStreamId())){ | |||
bp_op_desc = in_node_desc; | |||
} | |||
auto out_node_desc = out_anchor->GetOwnerNode()->GetOpDesc(); | |||
GE_CHECK_NOTNULL(out_node_desc); | |||
if (bp_op_desc == nullptr || ((out_node_desc->GetId()) > (bp_op_desc->GetId()))) { | |||
bp_op_desc = out_node_desc; | |||
} | |||
GELOGI("bp_op_desc is %s, id is %ld", bp_op_desc->GetName().c_str(), bp_op_desc->GetId()); | |||
} | |||
if (bp_op_desc == nullptr) { | |||
return last_bp; | |||
GELOGI("Did not find bp node."); | |||
return SUCCESS; | |||
} | |||
uint32_t current_idx = 0; | |||
for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { | |||
@@ -849,12 +865,14 @@ uint32_t TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const | |||
GE_CHECK_NOTNULL(op_desc); | |||
current_idx++; | |||
if (op_desc->GetName() == bp_op_desc->GetName()) { | |||
last_bp = current_idx; | |||
GELOGI("First bp name %s, idx %u", op_desc->GetName().c_str(), last_bp); | |||
bp_index = current_idx; | |||
GELOGI("Find bp name %s, idx %u", op_desc->GetName().c_str(), bp_index); | |||
break; | |||
} | |||
} | |||
return last_bp; | |||
GELOGI("Last bp node[%s], type[%s], index[%u], stream id[%ld]", bp_op_desc->GetName().c_str(), | |||
bp_op_desc->GetType().c_str(), bp_index, bp_op_desc->GetStreamId()); | |||
return SUCCESS; | |||
} | |||
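The rewritten FindLastBpFromBpNode picks, among the target node's direct predecessors, the op with the largest topological id that shares the target's stream, then reports that op's 1-based position in the graph walk. A self-contained sketch of just the selection rule, with Op as a stand-in struct rather than GE's OpDesc:
#include <cstdint>
#include <vector>
struct Op { int64_t id; int64_t stream_id; };
// Returns the index of the chosen "last bp" op in preds, or -1 if no
// predecessor shares the target's stream.
int PickLastBp(const std::vector<Op> &preds, int64_t target_stream) {
  int chosen = -1;
  for (size_t i = 0; i < preds.size(); ++i) {
    if (preds[i].stream_id != target_stream) {
      continue;  // only ops on the target's stream qualify
    }
    if (chosen < 0 || preds[i].id > preds[chosen].id) {
      chosen = static_cast<int>(i);  // keep the latest (largest id) candidate
    }
  }
  return chosen;
}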
Status TaskGenerator::FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str, | |||
@@ -80,6 +80,7 @@ class TaskGenerator { | |||
Status FindProfilingNodeIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point, | |||
std::vector<uint32_t> &all_reduce_nodes); | |||
private: | |||
Status UpdateAnchorStatusForFfts(const NodePtr &node); | |||
Status UpdateAnchorStatus(const NodePtr &node); | |||
Status UpdateOpIsVarAttr(const OpDescPtr &op_desc, uint64_t session_id); | |||
@@ -115,7 +116,7 @@ class TaskGenerator { | |||
Status AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point) const; | |||
Status AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point, | |||
vector<uint32_t> &all_reduce_nodes) const; | |||
uint32_t FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node) const; | |||
Status FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node, uint32_t &bp_index) const; | |||
Status FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str, | |||
ProfilingPoint &profiling_point) const; | |||
@@ -275,21 +275,6 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc) { | |||
} | |||
/// | |||
/// @brief Set Op _force_unknown_shape flag | |||
/// @param [in] node | |||
/// @param [in] force_unknown, set attribute if true | |||
/// @param [in] group_index, condition group index of node. | |||
/// @return | |||
/// | |||
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index) { | |||
if (!force_unknown) { | |||
return; | |||
} | |||
SetControlFlowGroup(node, group_index); | |||
} | |||
/// | |||
/// @brief Set Op _control_flow_group flag | |||
/// @param [in] node | |||
/// @param [in] group, condition group index of node. | |||
@@ -126,15 +126,6 @@ Status GetMemorySize(const NodePtr &node, int64_t &output_size); | |||
bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc); | |||
/// | |||
/// @brief Set Op _force_unknown_shape flag | |||
/// @param [in] node | |||
/// @param [in] force_unknown, set attribute if true | |||
/// @param [in] group_index, condition group index of node. | |||
/// @return | |||
/// | |||
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index); | |||
/// | |||
/// @brief Set Op _control_flow_group flag | |||
/// @param [in] node | |||
/// @param [in] group, condition group index of node. | |||
@@ -99,6 +99,9 @@ const uint32_t kEndOfSequenceNew = 507005; | |||
const int32_t kModelAbortNormal = 0x0704000e; | |||
const int32_t kModelAbortNormalNew = 507024; | |||
const uint32_t kInteval = 2; | |||
const uint32_t kFftsTbeHandleElementSize = 2; | |||
const uint32_t kNonTailBlock = 0; | |||
const uint32_t kTailBlock = 1; | |||
const char *const kModelName = "model_name"; | |||
const char *const kModeleId = "model_id"; | |||
const char *const kLoadStartTime = "load_start_time"; | |||
@@ -116,14 +119,15 @@ const char *const kWorkSpaceSize = "workspace_size"; | |||
const char *const kTotalSize = "total_size"; | |||
const char *const kTaskCount = "task_count"; | |||
const char *const kTaskId = "task_id"; | |||
const char* const kRequestId = "request_id"; | |||
const char* const kThreadId = "thread_id"; | |||
const char* const kInputBeginTime = "input_begin_time"; | |||
const char* const kInputEndTime = "input_end_time"; | |||
const char* const kInferBeginTime = "infer_begin_time"; | |||
const char* const kInferEndTime = "infer_end_time"; | |||
const char* const kOutputBeginTime = "output_start_time"; | |||
const char* const kOutputEndTime = "output_end_time"; | |||
const char *const kRequestId = "request_id"; | |||
const char *const kThreadId = "thread_id"; | |||
const char *const kInputBeginTime = "input_begin_time"; | |||
const char *const kInputEndTime = "input_end_time"; | |||
const char *const kInferBeginTime = "infer_begin_time"; | |||
const char *const kInferEndTime = "infer_end_time"; | |||
const char *const kOutputBeginTime = "output_start_time"; | |||
const char *const kOutputEndTime = "output_end_time"; | |||
const char *const kStubFuncName = "_register_stub_func"; | |||
const uint32_t kStringHeadElems = 2; | |||
const uint32_t kPlacementHostData = 0; | |||
const size_t kAlignment = 64; | |||
@@ -902,10 +906,8 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { | |||
SetLabelForDynamic(node); | |||
auto it = op_desc_handle.find(op_desc->GetType()); | |||
if (it != op_desc_handle.end()) { | |||
if ((this->*it->second)(op_desc) != SUCCESS) { | |||
GELOGE(PARAM_INVALID, "[Init][Node] failed, Name:%s", op_desc->GetName().c_str()); | |||
return PARAM_INVALID; | |||
} | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((this->*it->second)(op_desc) != SUCCESS, return PARAM_INVALID, | |||
"[Init][Node] failed, Name:%s", op_desc->GetName().c_str()); | |||
continue; | |||
} | |||
@@ -935,7 +937,8 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { | |||
GE_TIMESTAMP_RESTART(InitTbeHandle); | |||
if (IsTbeTask(op_desc)) { | |||
Status status = InitTbeHandle(op_desc); | |||
Status status = | |||
op_desc->HasAttr(ATTR_NAME_THREAD_SCOPE_ID) ? InitTbeHandleWithFfts(op_desc) : InitTbeHandle(op_desc); | |||
if (status != SUCCESS) { | |||
GELOGE(status, "[Init][TbeHandle] failed. op:%s", op_desc->GetName().c_str()); | |||
return status; | |||
@@ -3700,6 +3703,7 @@ Status DavinciModel::InitConstant(const OpDescPtr &op_desc) { | |||
/// @return Status | |||
/// | |||
Status DavinciModel::InitTbeHandle(const OpDescPtr &op_desc) { | |||
string bin_file = op_desc->GetName(); | |||
auto kernel = ge_model_->GetTBEKernelStore().FindKernel(op_desc->GetName()); | |||
auto tbe_kernel = (kernel != nullptr) ? kernel : op_desc->TryGetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); | |||
if (tbe_kernel == nullptr) { | |||
@@ -3708,12 +3712,61 @@ Status DavinciModel::InitTbeHandle(const OpDescPtr &op_desc) { | |||
GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: %s can't find tvm bin file!", op_desc->GetName().c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
GE_CHK_STATUS_RET(FunctionRegister(op_desc, bin_file, tbe_kernel, false), "Function register of bin file: %s failed", | |||
bin_file.c_str()); | |||
return SUCCESS; | |||
} | |||
std::string session_graph_model_id; | |||
GetUniqueId(op_desc, session_graph_model_id); | |||
const char *bin_file_key = GetRegisterStub(op_desc->GetName(), session_graph_model_id); // from set, always valid. | |||
TBEHandleStore &kernel_store = TBEHandleStore::GetInstance(); | |||
Status DavinciModel::InitTbeHandleWithFfts(const OpDescPtr &op_desc) { | |||
std::vector<OpKernelBinPtr> tbe_kernel; | |||
tbe_kernel = op_desc->TryGetExtAttr(OP_EXTATTR_NAME_THREAD_TBE_KERNEL, tbe_kernel); | |||
GELOGD("Kernel bin ptr vec size is %zu.", tbe_kernel.size()); | |||
if (tbe_kernel.size() != kFftsTbeHandleElementSize) { | |||
REPORT_INNER_ERROR("E19999", "Get tbe_kernel for op:%s(%s) fail, model_id:%u", | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_); | |||
GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: %s can't find tvm bin file, size is %zu when ffts", | |||
op_desc->GetName().c_str(), tbe_kernel.size()); | |||
return INTERNAL_ERROR; | |||
} | |||
if (tbe_kernel[0] == nullptr || tbe_kernel[1] == nullptr) { | |||
REPORT_INNER_ERROR("E19999", "Tbe kernel for op:%s is nullptr.", op_desc->GetName().c_str()); | |||
GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: tvm bin file of %s is nullptr when ffts.", op_desc->GetName().c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
vector<string> bin_file_keys; | |||
(void)AttrUtils::GetListStr(op_desc, kStubFuncName, bin_file_keys); | |||
if (bin_file_keys.size() != kFftsTbeHandleElementSize) { | |||
REPORT_INNER_ERROR("E19999", "Get bin_file for op:%s(%s) fail.", op_desc->GetName().c_str(), | |||
op_desc->GetType().c_str()); | |||
GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: %s can't find bin file keys, size is %zu when ffts", | |||
op_desc->GetName().c_str(), bin_file_keys.size()); | |||
return INTERNAL_ERROR; | |||
} | |||
GE_CHK_STATUS_RET(FunctionRegister(op_desc, bin_file_keys[kNonTailBlock], tbe_kernel[kNonTailBlock], true, | |||
kNonTailBlock), | |||
"Function register of first bin file %s failed.", bin_file_keys[kNonTailBlock].c_str()); | |||
GE_CHK_STATUS_RET(FunctionRegister(op_desc, bin_file_keys[kTailBlock], tbe_kernel[kTailBlock], true, kTailBlock), | |||
"Function register of second bin file %s failed.", bin_file_keys[kTailBlock].c_str()); | |||
return SUCCESS; | |||
} | |||
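InitTbeHandleWithFfts relies on a fixed two-slot layout: element 0 of both the kernel vector and the stub-name list is the non-tail block and element 1 the tail block (kNonTailBlock/kTailBlock). A hypothetical helper that makes the convention explicit; the types and register hook are illustrative, not DavinciModel's API:
#include <array>
#include <cstddef>
#include <string>
enum FftsBlock : size_t { kNonTail = 0, kTail = 1, kBlockCount = 2 };
struct FftsKernelPair {
  std::array<std::string, kBlockCount> bin_file_keys;  // stub name per block
  std::array<const void *, kBlockCount> kernels;       // kernel binary per block
};
// Register both blocks in slot order, failing fast like GE_CHK_STATUS_RET.
template <typename RegisterFn>
bool RegisterFftsPair(const FftsKernelPair &pair, RegisterFn register_fn) {
  for (size_t slot = kNonTail; slot < kBlockCount; ++slot) {
    if (!register_fn(pair.bin_file_keys[slot], pair.kernels[slot], slot)) {
      return false;
    }
  }
  return true;
}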
Status DavinciModel::FunctionRegister(const OpDescPtr &op_desc, string &bin_file, OpKernelBinPtr &tbe_kernel, | |||
bool is_ffts, size_t thread_index) { | |||
if (thread_index > 1) { | |||
GELOGE(INTERNAL_ERROR, "[Check][Param] failed. Thread index: %zu should less than 1.", thread_index); | |||
return INTERNAL_ERROR; | |||
} | |||
const char *bin_file_key; | |||
if (is_ffts) { | |||
bin_file_key = GetRegisterStub(bin_file, ""); | |||
GELOGI("Node:%s inherit func name:%s directly.", op_desc->GetName().c_str(), bin_file_key); | |||
} else { | |||
std::string session_graph_model_id; | |||
GetUniqueId(op_desc, session_graph_model_id); | |||
bin_file_key = GetRegisterStub(bin_file, session_graph_model_id); // from set, always valid. | |||
} | |||
TBEHandleStore &kernel_store = TBEHandleStore::GetInstance(); | |||
std::lock_guard<std::mutex> lock(tvm_bin_mutex_); | |||
if (rtQueryFunctionRegistered(bin_file_key) != RT_ERROR_NONE) { | |||
void *bin_handle = nullptr; | |||
@@ -3721,59 +3774,115 @@ Status DavinciModel::InitTbeHandle(const OpDescPtr &op_desc) { | |||
GELOGD("TBE: can't find the kernel_name[%s] in HandleMap", bin_file_key); | |||
rtDevBinary_t binary; | |||
std::string json_string; | |||
GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc, TVM_ATTR_NAME_MAGIC, json_string), | |||
GELOGD("Get original type of session_graph_id.")); | |||
if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AICPU") { | |||
binary.magic = RT_DEV_BINARY_MAGIC_ELF_AICPU; | |||
} else if (json_string == "RT_DEV_BINARY_MAGIC_ELF") { | |||
binary.magic = RT_DEV_BINARY_MAGIC_ELF; | |||
} else if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AIVEC") { | |||
binary.magic = RT_DEV_BINARY_MAGIC_ELF_AIVEC; | |||
} else if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AICUBE") { | |||
binary.magic = RT_DEV_BINARY_MAGIC_ELF_AICUBE; | |||
} else { | |||
REPORT_INNER_ERROR("E19999", "Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid", | |||
TVM_ATTR_NAME_MAGIC.c_str(), json_string.c_str(), | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_); | |||
GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid", | |||
TVM_ATTR_NAME_MAGIC.c_str(), json_string.c_str(), | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_); | |||
return PARAM_INVALID; | |||
} | |||
GE_CHK_STATUS_RET(InitBinaryMagic(op_desc, is_ffts, thread_index, binary), "Init binary magic of %s failed.", | |||
op_desc->GetName().c_str()); | |||
binary.version = 0; | |||
binary.data = tbe_kernel->GetBinData(); | |||
binary.length = tbe_kernel->GetBinDataSize(); | |||
GELOGD("TBE: binary.length: %lu", binary.length); | |||
GE_CHK_RT_RET(rtDevBinaryRegister(&binary, &bin_handle)); | |||
std::string meta_data; | |||
GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc, TVM_ATTR_NAME_METADATA, meta_data), | |||
GELOGI("Get original type of json_string")); | |||
GELOGD("TBE: meta data: %s", meta_data.empty() ? "null" : meta_data.c_str()); | |||
GE_IF_BOOL_EXEC(!meta_data.empty(), GE_CHK_RT_RET(rtMetadataRegister(bin_handle, meta_data.c_str()))); | |||
GE_CHK_STATUS_RET(InitMetaData(op_desc, is_ffts, thread_index, bin_handle), "Init tvm meta data of %s failed.", | |||
op_desc->GetName().c_str()); | |||
kernel_store.StoreTBEHandle(bin_file_key, bin_handle, tbe_kernel); | |||
} else { | |||
GELOGI("TBE: find the kernel_name[%s] in HandleMap", bin_file_key); | |||
kernel_store.ReferTBEHandle(bin_file_key); | |||
} | |||
std::string kernel_name; | |||
GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name), | |||
GELOGD("Get original type of kernel_name")); | |||
GE_CHK_STATUS_RET(InitKernelName(op_desc, is_ffts, thread_index, kernel_name), "Init kernel name of %s failed.", | |||
op_desc->GetName().c_str()); | |||
GE_CHK_RT_RET(rtFunctionRegister(bin_handle, bin_file_key, bin_file_key, kernel_name.c_str(), 0)); | |||
used_tbe_handle_map_[bin_file_key] = 1; // Init used num to 1. | |||
return SUCCESS; | |||
} | |||
// Kernel registered, increase used count in store. | |||
StoreTbeHandle(bin_file_key); | |||
return SUCCESS; | |||
} | |||
Status DavinciModel::InitBinaryMagic(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, | |||
rtDevBinary_t &binary) { | |||
string json_string; | |||
const string &tvm_magic = is_ffts ? TVM_ATTR_NAME_THREAD_MAGIC : TVM_ATTR_NAME_MAGIC; | |||
const static std::map<std::string, uint32_t> binary_magics = { | |||
{"RT_DEV_BINARY_MAGIC_ELF_AICPU", RT_DEV_BINARY_MAGIC_ELF_AICPU}, | |||
{"RT_DEV_BINARY_MAGIC_ELF", RT_DEV_BINARY_MAGIC_ELF}, | |||
{"RT_DEV_BINARY_MAGIC_ELF_AIVEC", RT_DEV_BINARY_MAGIC_ELF_AIVEC}, | |||
{"RT_DEV_BINARY_MAGIC_ELF_AICUBE", RT_DEV_BINARY_MAGIC_ELF_AICUBE} | |||
}; | |||
if (is_ffts) { | |||
vector<string> json_list; | |||
(void)AttrUtils::GetListStr(op_desc, tvm_magic, json_list); | |||
if (json_list.size() != kFftsTbeHandleElementSize) { | |||
GELOGE(INTERNAL_ERROR, "[Check][Param] failed. Attr is %s, thread index is %zu, json list size is %zu.", | |||
tvm_magic.c_str(), thread_index, json_list.size()); | |||
return INTERNAL_ERROR; | |||
} | |||
json_string = json_list[thread_index]; | |||
} else { | |||
(void)AttrUtils::GetStr(op_desc, tvm_magic, json_string); | |||
} | |||
auto iter = binary_magics.find(json_string); | |||
if (iter == binary_magics.end()) { | |||
REPORT_INNER_ERROR("E19999", "Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid", | |||
tvm_magic.c_str(), json_string.c_str(), op_desc->GetName().c_str(), | |||
op_desc->GetType().c_str(), model_id_); | |||
GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid", | |||
tvm_magic.c_str(), json_string.c_str(), | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_); | |||
return PARAM_INVALID; | |||
} | |||
binary.magic = iter->second; | |||
return SUCCESS; | |||
} | |||
Status DavinciModel::InitMetaData(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, void *bin_handle) { | |||
string meta_data; | |||
const string &tvm_metadata = is_ffts ? TVM_ATTR_NAME_THREAD_METADATA : TVM_ATTR_NAME_METADATA; | |||
if (is_ffts) { | |||
vector<string> meta_data_list; | |||
(void)AttrUtils::GetListStr(op_desc, tvm_metadata, meta_data_list); | |||
if (meta_data_list.size() != kFftsTbeHandleElementSize) { | |||
GELOGE(INTERNAL_ERROR, "[Check][Param] failed, attr is %s, thread index is %zu, meta data list size is %zu.", | |||
tvm_metadata.c_str(), thread_index, meta_data_list.size()); | |||
return INTERNAL_ERROR; | |||
} | |||
meta_data = meta_data_list[thread_index]; | |||
} else { | |||
(void)AttrUtils::GetStr(op_desc, tvm_metadata, meta_data); | |||
} | |||
GELOGD("TBE: meta data: %s", meta_data.empty() ? "null" : meta_data.c_str()); | |||
if (!meta_data.empty()) { | |||
GE_CHK_RT_RET(rtMetadataRegister(bin_handle, meta_data.c_str())); | |||
} | |||
return SUCCESS; | |||
} | |||
Status DavinciModel::InitKernelName(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, string &kernel_name) { | |||
if (is_ffts) { | |||
// delete prefix, e.g.: *sgt_graph_nodes*/loss_scale/gradient/fp32_vals/Mean_grad/Tile | |||
vector<string> kernel_name_list; | |||
auto pos = op_desc->GetName().find("/"); | |||
if (pos == std::string::npos) { | |||
GELOGE(INTERNAL_ERROR, "[Check][Param] failed, subgraph node name: %s.", op_desc->GetName().c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
string attr_kernel_name = op_desc->GetName().substr(pos + 1) + "_thread_kernelname"; | |||
(void)AttrUtils::GetListStr(op_desc, attr_kernel_name, kernel_name_list); | |||
if (kernel_name_list.size() != kFftsTbeHandleElementSize) { | |||
GELOGE(INTERNAL_ERROR, "[Check][Param] failed, attr is %s, thread index is %zu, kernel name list size is %zu.", | |||
attr_kernel_name.c_str(), thread_index, kernel_name_list.size()); | |||
return INTERNAL_ERROR; | |||
} | |||
kernel_name = kernel_name_list[thread_index]; | |||
} else { | |||
string attr_kernel_name = op_desc->GetName() + "_kernelname"; | |||
(void)AttrUtils::GetStr(op_desc, attr_kernel_name, kernel_name); | |||
} | |||
return SUCCESS; | |||
} | |||
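For FFTS nodes, InitKernelName derives the attribute key from the op name with the subgraph prefix stripped, as the in-code comment notes. A runnable sketch of the derivation; the sample op name follows the comment's example and is illustrative only:
#include <iostream>
#include <string>
std::string ThreadKernelNameKey(const std::string &op_name) {
  auto pos = op_name.find('/');
  if (pos == std::string::npos) {
    return "";  // caller treats this as an error, as InitKernelName does
  }
  // Drop everything up to and including the first '/', append the suffix.
  return op_name.substr(pos + 1) + "_thread_kernelname";
}
int main() {
  // Prints: loss_scale/gradient/fp32_vals/Mean_grad/Tile_thread_kernelname
  std::cout << ThreadKernelNameKey("*sgt_graph_nodes*/loss_scale/gradient/fp32_vals/Mean_grad/Tile")
            << std::endl;
  return 0;
}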
void DavinciModel::StoreTbeHandle(const std::string &handle_key) { | |||
// Online mode FE may call rtFunctionRegister. | |||
TBEHandleStore &kernel_store = TBEHandleStore::GetInstance(); | |||
@@ -771,6 +771,12 @@ class DavinciModel { | |||
/// @return Status | |||
/// | |||
Status InitTbeHandle(const OpDescPtr &op_desc); | |||
Status InitTbeHandleWithFfts(const OpDescPtr &op_desc); | |||
Status FunctionRegister(const OpDescPtr &op_desc, string &bin_file, OpKernelBinPtr &tbe_kernel, bool is_ffts, | |||
size_t thread_index = 0); | |||
Status InitBinaryMagic(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, rtDevBinary_t &binary); | |||
Status InitMetaData(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, void *bin_handle); | |||
Status InitKernelName(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, string &kernel_name); | |||
void StoreTbeHandle(const string &handle_key); | |||
void CleanTbeHandle(); | |||
@@ -1378,7 +1378,9 @@ Status ModelManager::LoadCustAicpuSo(const OpDescPtr &op_desc, const string &so_ | |||
Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) { | |||
GELOGD("Aicpu kernel launch task in, kernel name %s.", kernel_name.c_str()); | |||
std::lock_guard<std::mutex> lock(cust_aicpu_mutex_); | |||
if (cust_aicpu_so_.size() == 0) return SUCCESS; | |||
if (cust_aicpu_so_.empty()) { | |||
return SUCCESS; | |||
} | |||
// get current context | |||
rtContext_t rt_cur_ctx = nullptr; | |||
auto rt_error = rtCtxGetCurrent(&rt_cur_ctx); | |||
@@ -1394,9 +1396,19 @@ Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) { | |||
return SUCCESS; | |||
} | |||
rtStream_t stream = nullptr; | |||
vector<void *> allocated_mem; | |||
std::function<void()> callback = [&]() { | |||
for (auto mem : allocated_mem) { | |||
GE_CHK_RT(rtFree(mem)); | |||
} | |||
if (stream != nullptr) { | |||
GE_CHK_RT(rtStreamDestroy(stream)); | |||
} | |||
}; | |||
GE_MAKE_GUARD(release, callback); | |||
rtError_t status; | |||
rtStream_t stream = nullptr; | |||
vector<CustAicpuSoBuf> v_cust_so; | |||
void *args = nullptr; | |||
@@ -1471,13 +1483,6 @@ Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) { | |||
GELOGE(RT_FAILED, "[Call][RtStreamSynchronize] fail, ret = 0x%X", status); | |||
return RT_ERROR_TO_GE_STATUS(status); | |||
} | |||
std::function<void()> callback = [&]() { | |||
for (auto mem : allocated_mem) { | |||
GE_CHK_RT(rtFree(mem)); | |||
} | |||
GE_CHK_RT(rtStreamDestroy(stream)); | |||
}; | |||
GE_MAKE_GUARD(release, callback); | |||
GELOGI("Cpu kernel launch task success."); | |||
return SUCCESS; | |||
} | |||
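The change above moves the rtFree/rtStreamDestroy guard ahead of the allocations it protects, so cleanup also runs on the early error returns between allocation and stream synchronization. A minimal scope-guard sketch of that ordering, assuming GE_MAKE_GUARD behaves like an RAII callback holder:
#include <functional>
#include <utility>
class ScopeGuard {  // simplified stand-in for GE_MAKE_GUARD
 public:
  explicit ScopeGuard(std::function<void()> fn) : fn_(std::move(fn)) {}
  ~ScopeGuard() {
    if (fn_) {
      fn_();  // runs on every exit path, normal or early return
    }
  }
  ScopeGuard(const ScopeGuard &) = delete;
  ScopeGuard &operator=(const ScopeGuard &) = delete;
 private:
  std::function<void()> fn_;
};
bool DoLaunch() { return false; }  // stands in for any fallible launch step
int Run() {
  char *mem = nullptr;
  ScopeGuard release([&]() { delete[] mem; });  // registered before allocation
  mem = new char[16];
  if (!DoLaunch()) {
    return -1;  // early return: the guard still frees mem
  }
  return 0;
}
Registering the guard first is what makes the early returns safe; the previous code only installed it after the last failure point, leaking on every earlier return.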
@@ -0,0 +1,393 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "graph/load/model_manager/task_info/ffts_task_info.h" | |||
#include <vector> | |||
#include "graph/load/model_manager/davinci_model.h" | |||
namespace { | |||
constexpr uint32_t kAddrLen = sizeof(void *); | |||
} | |||
namespace ge { | |||
FftsTaskInfo::~FftsTaskInfo() { | |||
GE_FREE_RT_LOG(args_); | |||
} | |||
Status FftsTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { | |||
GELOGI("FftsTaskInfo Init Start."); | |||
GE_CHECK_NOTNULL(davinci_model); | |||
davinci_model_ = davinci_model; | |||
GE_CHK_STATUS_RET_NOLOG(SetStream(task_def.stream_id(), davinci_model_->GetStreamList())); | |||
const domi::FftsTaskDef &ffts_task_def = task_def.ffts_task(); | |||
OpDescPtr op_desc = davinci_model_->GetOpByIndex(ffts_task_def.op_index()); | |||
GE_CHECK_NOTNULL(op_desc); | |||
if ((ffts_task_def.sub_task_size() > static_cast<int>(RT_FFTS_MAX_SUB_TASK_NUM)) || | |||
(ffts_task_def.ticket_cache_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_NUM))) { | |||
GELOGE(INTERNAL_ERROR, "[Check][Param] failed. Node: %s, sub task desc size: %d, ticket cache size: %d", | |||
op_desc->GetName().c_str(), ffts_task_def.sub_task_size(), ffts_task_def.ticket_cache_size()); | |||
return INTERNAL_ERROR; | |||
} | |||
args_size_ = kAddrLen * ffts_task_def.addr_size(); | |||
GE_CHK_RT_RET(rtMalloc(&args_, args_size_, RT_MEMORY_HBM)); | |||
InitFftsDescInfo(ffts_task_def.ffts_desc(), sub_task_info_.fftsDesc); | |||
sub_task_info_.fftsType = static_cast<rtFftsType_t>(ffts_task_def.ffts_type()); | |||
sub_task_info_.subTaskNum = ffts_task_def.sub_task_size(); | |||
for (int idx = 0; idx < ffts_task_def.sub_task_size(); ++idx) { | |||
GE_CHK_STATUS_RET_NOLOG(InitSubTaskInfo(ffts_task_def.sub_task(idx), sub_task_info_.subTask[idx])); | |||
} | |||
sub_task_info_.tickCacheNum = ffts_task_def.ticket_cache_size(); | |||
for (int idx = 0; idx < ffts_task_def.ticket_cache_size(); ++idx) { | |||
GE_CHK_STATUS_RET_NOLOG(InitTicketCache(ffts_task_def.ticket_cache(idx), sub_task_info_.ticketCache[idx])); | |||
} | |||
size_t data_size = kAddrLen * io_addrs_.size(); | |||
GE_CHK_RT_RET(rtMemcpy(args_, args_size_, io_addrs_.data(), data_size, RT_MEMCPY_HOST_TO_DEVICE)); | |||
GELOGI("FftsTaskInfo::Init Success. Node: %s, input/output size: %zu", op_desc->GetName().c_str(), io_addrs_.size()); | |||
return SUCCESS; | |||
} | |||
void FftsTaskInfo::InitFftsDescInfo(const domi::FftsDescInfoDef &ffts_desc_def, rtFftsDescInfo_t &ffts_desc) { | |||
ffts_desc.tm = static_cast<uint8_t>(ffts_desc_def.tm()); | |||
ffts_desc.di = static_cast<uint8_t>(ffts_desc_def.di()); | |||
ffts_desc.dw = static_cast<uint8_t>(ffts_desc_def.dw()); | |||
ffts_desc.df = static_cast<uint8_t>(ffts_desc_def.df()); | |||
ffts_desc.dataSplitUnit = static_cast<uint8_t>(ffts_desc_def.data_split_unit()); | |||
ffts_desc.prefetchOstNum = static_cast<uint8_t>(ffts_desc_def.prefetch_ost_num()); | |||
ffts_desc.cacheMaintainOstNum = static_cast<uint8_t>(ffts_desc_def.cache_maintain_ost_num()); | |||
ffts_desc.aicPrefetchUpper = static_cast<uint8_t>(ffts_desc_def.aic_prefetch_upper()); | |||
ffts_desc.aicPrefetchLower = static_cast<uint8_t>(ffts_desc_def.aic_prefetch_lower()); | |||
ffts_desc.aivPrefetchUpper = static_cast<uint8_t>(ffts_desc_def.aiv_prefetch_upper()); | |||
ffts_desc.aivPrefetchLower = static_cast<uint8_t>(ffts_desc_def.aiv_prefetch_lower()); | |||
} | |||
Status FftsTaskInfo::InitSubTaskInfo(const domi::FftsSubTaskDef &sub_task_def, rtFftsSubTaskInfo_t &sub_task_desc) { | |||
if ((sub_task_def.dst_tick_cache_id_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK)) || | |||
(sub_task_def.src_tick_cache_id_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK))) { | |||
GELOGE(FAILED, "[Check][Param] Invalid FftsSubTaskInfo, dst tick cache id size: %d, src tick cache id size: %d", | |||
sub_task_def.dst_tick_cache_id_size(), sub_task_def.src_tick_cache_id_size()); | |||
return FAILED; | |||
} | |||
if (sub_task_def.has_auto_thread_aic_aiv() == sub_task_def.has_manual_thread_aic_aiv()) { | |||
GELOGE(FAILED, "[Check][Param] Invalid FftsSubTaskInfo, auto thread aic/aiv: %d, manual thread aic/aiv: %d", | |||
sub_task_def.has_auto_thread_aic_aiv(), sub_task_def.has_manual_thread_aic_aiv()); | |||
return FAILED; | |||
} | |||
thread_dim_ = sub_task_def.thread_dim(); | |||
GE_CHK_BOOL_RET_STATUS(thread_dim_ != 0, FAILED, "[Get][thread_dim] failed, Invalid thread dim: %u!", thread_dim_); | |||
sub_task_desc.subTaskType = static_cast<rtFftsSubTaskType_t>(sub_task_def.sub_task_type()); | |||
sub_task_desc.threadDim = sub_task_def.thread_dim(); | |||
sub_task_desc.dstTickCacheVldBitmap = sub_task_def.dst_tick_cache_vld_bitmap(); | |||
sub_task_desc.srcTickCacheVldBitmap = sub_task_def.src_tick_cache_vld_bitmap(); | |||
sub_task_desc.srcDataOutOfSubGraphBitmap = sub_task_def.src_data_out_of_subgraph_bitmap(); | |||
for (int idx = 0; idx < sub_task_def.dst_tick_cache_id_size(); ++idx) { | |||
sub_task_desc.dstTickCacheID[idx] = sub_task_def.dst_tick_cache_id(idx); | |||
} | |||
for (int idx = 0; idx < sub_task_def.src_tick_cache_id_size(); ++idx) { | |||
sub_task_desc.srcTickCacheID[idx] = sub_task_def.src_tick_cache_id(idx); | |||
} | |||
if (sub_task_def.has_auto_thread_aic_aiv()) { | |||
GE_CHK_STATUS_RET_NOLOG(InitAutoAicAiv(sub_task_def.auto_thread_aic_aiv(), sub_task_desc.custom.autoThreadAicAiv)); | |||
} | |||
if (sub_task_def.has_manual_thread_aic_aiv()) { | |||
GE_CHK_STATUS_RET_NOLOG( | |||
InitManualAicAiv(sub_task_def.manual_thread_aic_aiv(), sub_task_desc.custom.manualThreadAicAiv)); | |||
} | |||
if (sub_task_def.has_manual_thread_nop()) { | |||
GE_CHK_STATUS_RET_NOLOG(InitManualNop(sub_task_def.manual_thread_nop(), sub_task_desc.custom.manualThreadNop)); | |||
} | |||
return SUCCESS; | |||
} | |||
Status FftsTaskInfo::InitTicketCache(const domi::TicketCacheDef &ticket_cache_def, rtTicketCache_t &ticket_cache) { | |||
if (ticket_cache_def.has_auto_thread_cache() == ticket_cache_def.has_manual_thread_cache()) { | |||
GELOGE(FAILED, "[Check][Param] Invalid TicketCacheDef, has auto thread cache: %d, has manual thread cache: %d", | |||
ticket_cache_def.has_auto_thread_cache(), ticket_cache_def.has_manual_thread_cache()); | |||
return FAILED; | |||
} | |||
ticket_cache.cacheOption = static_cast<rtCacheOp_t>(ticket_cache_def.cache_option()); | |||
ticket_cache.ticketCacheWindow = ticket_cache_def.ticket_cache_window(); | |||
if (ticket_cache_def.has_auto_thread_cache()) { | |||
InitAutoCacheInfo(ticket_cache_def.auto_thread_cache(), ticket_cache.custom.autoThreadCache); | |||
} | |||
if (ticket_cache_def.has_manual_thread_cache()) { | |||
GE_CHK_STATUS_RET_NOLOG( | |||
InitManualCacheInfo(ticket_cache_def.manual_thread_cache(), ticket_cache.custom.manualThreadCache)); | |||
} | |||
return SUCCESS; | |||
} | |||
// task_addr = {0, 200, 700, 1000, 2000, 3500} | |||
// task_addr_offset = {20, 40, 2, 100, 200} | |||
template <typename T> | |||
Status FftsTaskInfo::InitIoAddrs(const RuntimeParam &rts_param, const T &aic_aiv_def, uint32_t thread_dim, | |||
uint32_t addr_count) { | |||
for (uint32_t i = 0; i < addr_count; ++i) { | |||
uintptr_t logic_addr = aic_aiv_def.task_addr(i) + thread_dim * aic_aiv_def.task_addr_offset(i); | |||
uint8_t *io_addr = nullptr; | |||
if (ModelUtils::GetRtAddress(rts_param, logic_addr, io_addr) != SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "[Check][GetRtAddress]GetRtAddress failed."); | |||
return INTERNAL_ERROR; | |||
} | |||
GELOGD("aic_aiv_def task base addr is %ld, offset is %ld, thread is %d, logic addrs is 0x%lx, io addr is %p", | |||
aic_aiv_def.task_addr(i), aic_aiv_def.task_addr_offset(i), thread_dim, logic_addr, io_addr); | |||
io_addrs_.emplace_back(io_addr); | |||
} | |||
return SUCCESS; | |||
} | |||
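The task_addr/task_addr_offset comment above encodes the per-thread layout InitIoAddrs walks: for thread index t, the i-th logical address is task_addr[i] + t * task_addr_offset[i], and trailing task_addr entries past task_addr_offset's length are the last thread's workspace addresses. A worked example with the comment's sample values; the numbers are illustrative, not real device addresses:
#include <cstdint>
#include <cstdio>
#include <vector>
int main() {
  std::vector<uint64_t> task_addr = {0, 200, 700, 1000, 2000, 3500};
  std::vector<uint64_t> task_addr_offset = {20, 40, 2, 100, 200};
  uint32_t thread_idx = 3;  // assumed thread index for the example
  for (size_t i = 0; i < task_addr_offset.size(); ++i) {
    uint64_t logic_addr = task_addr[i] + thread_idx * task_addr_offset[i];
    std::printf("thread %u, logic_addr[%zu] = %llu\n", thread_idx, i,
                static_cast<unsigned long long>(logic_addr));
  }
  // The extra trailing entry (3500) has no offset: it is consumed by the
  // last-thread workspace loop in InitAutoAicAiv/InitManualAicAiv.
  return 0;
}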
Status FftsTaskInfo::InitAutoAicAiv(const domi::AutoThreadAicAivDef &aic_aiv_def, rtAutoThreadAicAivInfo_t &aic_aiv) { | |||
if (aic_aiv_def.src_prefetch_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK)) { | |||
GELOGE(FAILED, "[Check][Param] Invalid AutoThreadAicAivInfo, prefetch size: %d", aic_aiv_def.src_prefetch_size()); | |||
return FAILED; | |||
} | |||
aic_aiv.taskParamAddr = reinterpret_cast<uintptr_t>(args_) + kAddrLen * io_addrs_.size(); | |||
GELOGD("AutoThreadAicAivDef: task param addr is %lu.", aic_aiv.taskParamAddr); | |||
const auto &rts_param = davinci_model_->GetRuntimeParam(); | |||
for (uint32_t i = 0; i < thread_dim_ - 1; ++i) { | |||
GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, i, | |||
static_cast<uint32_t>(aic_aiv_def.task_addr_offset_size()))); | |||
} | |||
GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, thread_dim_ - 1, aic_aiv_def.input_output_count())); | |||
int last_thread_workspace_size = aic_aiv_def.task_addr_size() - aic_aiv_def.task_addr_offset_size(); | |||
for (int k = 0; k < last_thread_workspace_size; ++k) { | |||
uintptr_t logic_addr = aic_aiv_def.task_addr(aic_aiv_def.task_addr_offset_size() + k); | |||
uint8_t *io_addr = nullptr; | |||
GE_CHK_STATUS_RET_NOLOG(ModelUtils::GetRtAddress(rts_param, logic_addr, io_addr)); | |||
GELOGD("logic addr is 0x%lx, io addr is %p.", logic_addr, io_addr); | |||
io_addrs_.emplace_back(io_addr); | |||
} | |||
aic_aiv.taskParamOffset = aic_aiv_def.task_param_offset(); | |||
GELOGD("args_: %p, io_addrs size: %zu, task param offset: %u.", args_, io_addrs_.size(), aic_aiv.taskParamOffset); | |||
aic_aiv.satMode = aic_aiv_def.sat_mode(); | |||
aic_aiv.scheduleMode = aic_aiv_def.schedule_mode(); | |||
aic_aiv.iCachePrefetchCnt = aic_aiv_def.cache_prefetch_cnt(); | |||
aic_aiv.prefetchEnableBitmap = aic_aiv_def.prefetch_enable_bitmap(); | |||
aic_aiv.prefetchOnceBitmap = aic_aiv_def.prefetch_once_bitmap(); | |||
aic_aiv.tailBlkDim = aic_aiv_def.tail_blk_dim(); | |||
aic_aiv.nonTailBlkDim = aic_aiv_def.non_tail_blk_dim(); | |||
aic_aiv.nonTailTaskFuncStub = davinci_model_->GetRegisterStub(aic_aiv_def.non_tail_task_func_stub(), ""); | |||
aic_aiv.tailTaskFuncStub = davinci_model_->GetRegisterStub(aic_aiv_def.tail_task_func_stub(), ""); | |||
GELOGI("Set func name[%s][%s] succ.", aic_aiv.nonTailTaskFuncStub, aic_aiv.tailTaskFuncStub); | |||
for (int idx = 0; idx < aic_aiv_def.src_prefetch_size(); ++idx) { | |||
InitAutoPrefetch(aic_aiv_def.src_prefetch(idx), aic_aiv.srcPrefetch[idx]); | |||
} | |||
return SUCCESS; | |||
} | |||
void FftsTaskInfo::InitAutoCacheInfo(const domi::AutoThreadCacheDef &cache_def, rtAutoThreadCacheInfo_t &cache) { | |||
cache.dataAddr = cache_def.data_addr(); | |||
cache.dataAddrOffset = cache_def.data_addr_offset(); | |||
cache.nonTailDataLen = cache_def.non_tail_data_len(); | |||
cache.tailDataLen = cache_def.tail_data_len(); | |||
cache.ticketCacheRefCnt = cache_def.ticket_cache_ref_cnt(); | |||
} | |||
void FftsTaskInfo::InitAutoPrefetch(const domi::AutoThreadPrefetchDef &prefetch_def, rtAutoThreadPrefetch_t &prefetch) { | |||
prefetch.dataAddr = prefetch_def.data_addr(); | |||
prefetch.dataAddrOffset = prefetch_def.data_addr_offset(); | |||
prefetch.nonTailDataLen = prefetch_def.non_tail_data_len(); | |||
prefetch.tailDataLen = prefetch_def.tail_data_len(); | |||
} | |||
Status FftsTaskInfo::InitManualAicAiv(const domi::ManualThreadAicAivDef &aic_aiv_def, | |||
rtManualThreadAicAivInfo_t &aic_aiv) { | |||
if ((aic_aiv_def.thread_prefetch_dmu_idx_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM)) || | |||
(aic_aiv_def.thread_blk_dim_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM)) || | |||
(aic_aiv_def.thread_task_func_stub_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM)) || | |||
(aic_aiv_def.src_dep_tbl_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK))) { | |||
GELOGE(FAILED, "[Check][Param] Invalid ManualThreadAicAivInfo, thread prefetch dmu desc size: %d, " | |||
"thread blk dim size: %d, thread task func stub size: %d, src dep tbl size: %d", | |||
aic_aiv_def.thread_prefetch_dmu_idx_size(), aic_aiv_def.thread_blk_dim_size(), | |||
aic_aiv_def.thread_task_func_stub_size(), aic_aiv_def.src_dep_tbl_size()); | |||
return FAILED; | |||
} | |||
aic_aiv.taskParamAddr = reinterpret_cast<uintptr_t>(args_) + kAddrLen * io_addrs_.size(); | |||
GELOGD("ManualThreadAicAivDef: task param addr is %lu.", aic_aiv.taskParamAddr); | |||
const auto &rts_param = davinci_model_->GetRuntimeParam(); | |||
for (uint32_t i = 0; i < thread_dim_ - 1; ++i) { | |||
GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, i, | |||
static_cast<uint32_t>(aic_aiv_def.task_addr_offset_size()))); | |||
} | |||
GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, thread_dim_ - 1, aic_aiv_def.input_output_count())); | |||
int last_thread_workspace_size = aic_aiv_def.task_addr_size() - aic_aiv_def.task_addr_offset_size(); | |||
for (int k = 0; k < last_thread_workspace_size; ++k) { | |||
uintptr_t logic_addr = aic_aiv_def.task_addr(aic_aiv_def.task_addr_offset_size() + k); | |||
uint8_t *io_addr = nullptr; | |||
GE_CHK_STATUS_RET_NOLOG(ModelUtils::GetRtAddress(rts_param, logic_addr, io_addr)); | |||
io_addrs_.emplace_back(io_addr); | |||
} | |||
aic_aiv.taskParamOffset = aic_aiv_def.task_param_offset(); | |||
aic_aiv.satMode = aic_aiv_def.sat_mode(); | |||
aic_aiv.scheduleMode = aic_aiv_def.schedule_mode(); | |||
aic_aiv.iCachePrefetchCnt = aic_aiv_def.cache_prefetch_cnt(); | |||
aic_aiv.prefetchEnableBitmap = aic_aiv_def.prefetch_enable_bitmap(); // 8-bit bitmap, e.g. 1 0 1 0 | |||
aic_aiv.prefetchOnceBitmap = aic_aiv_def.prefetch_once_bitmap(); // 8-bit bitmap, e.g. 1 0 1 0 | |||
aic_aiv.prefetchOnceDmuNum = aic_aiv_def.prefetch_once_dmu_num(); | |||
for (int idx = 0; idx < aic_aiv_def.thread_prefetch_dmu_idx_size(); ++idx) { | |||
aic_aiv.threadPrefetchDmuIdx[idx] = aic_aiv_def.thread_prefetch_dmu_idx(idx); | |||
} | |||
for (int idx = 0; idx < aic_aiv_def.thread_blk_dim_size(); ++idx) { | |||
aic_aiv.threadBlkDim[idx] = aic_aiv_def.thread_blk_dim(idx); | |||
} | |||
for (int idx = 0; idx < aic_aiv_def.thread_task_func_stub_size(); ++idx) { | |||
aic_aiv.threadTaskFuncStub[idx] = aic_aiv_def.thread_task_func_stub(idx).c_str(); | |||
} | |||
InitManualDmuInfo(aic_aiv_def, aic_aiv.prefetchList); | |||
for (int idx = 0; idx < aic_aiv_def.src_dep_tbl_size(); ++idx) { | |||
GE_CHK_STATUS_RET_NOLOG(InitManualDependency(aic_aiv_def.src_dep_tbl(idx), aic_aiv.srcDepTbl[idx])); | |||
} | |||
return SUCCESS; | |||
} | |||
Status FftsTaskInfo::InitManualCacheInfo(const domi::ManualThreadCacheDef &cache_def, | |||
rtManualThreadCacheInfo_t &cache_info) { | |||
if ((cache_def.slice_dmu_idx_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM)) || | |||
(cache_def.ticket_cache_ref_cnt_tbl_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM))) { | |||
GELOGE(FAILED, "[Check][Param] Invalid ManualThreadCacheInfo slice dum desc index %d, ticket cache ref cnt %d", | |||
cache_def.slice_dmu_idx_size(), cache_def.ticket_cache_ref_cnt_tbl_size()); | |||
return FAILED; | |||
} | |||
InitManualDmuInfo(cache_def, cache_info.dmuList); | |||
for (int idx = 0; idx < cache_def.slice_dmu_idx_size(); ++idx) { | |||
cache_info.sliceDmuIdx[idx] = cache_def.slice_dmu_idx(idx); | |||
} | |||
for (int idx = 0; idx < cache_def.ticket_cache_ref_cnt_tbl_size(); ++idx) { | |||
cache_info.ticketCacheRefCntTbl[idx] = cache_def.ticket_cache_ref_cnt_tbl(idx); | |||
} | |||
return SUCCESS; | |||
} | |||
Status FftsTaskInfo::InitManualDependency(const domi::ManualThreadDependencyDef &dependency_def, | |||
rtManualThreadDependency_t &dependency) { | |||
if (dependency_def.dependency_size() > static_cast<int>(RT_FFTS_MANUAL_SRC_DEPEND_TBL_LEN)) { | |||
GELOGE(FAILED, "[Check][Param] Invalid ManualThreadDependency size: %d", dependency_def.dependency_size()); | |||
return FAILED; | |||
} | |||
for (int idx = 0; idx < dependency_def.dependency_size(); ++idx) { | |||
dependency.dependency[idx] = dependency_def.dependency(idx); | |||
} | |||
return SUCCESS; | |||
} | |||
Status FftsTaskInfo::InitManualNop(const domi::ManualThreadNopDef &nop_def, rtManualThreadNopInfo_t &nop_info) { | |||
if (nop_def.src_dep_tbl_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK)) { | |||
GELOGE(FAILED, "[Check][Param] Invalid ManualThreadNopInfo, src dep tbl size: %d", nop_def.src_dep_tbl_size()); | |||
return FAILED; | |||
} | |||
for (int idx = 0; idx < nop_def.src_dep_tbl_size(); ++idx) { | |||
GE_CHK_STATUS_RET_NOLOG(InitManualDependency(nop_def.src_dep_tbl(idx), nop_info.srcDepTbl[idx])); | |||
} | |||
return SUCCESS; | |||
} | |||
void FftsTaskInfo::InitManualDmuInfo(const domi::ManualThreadAicAivDef &aic_aiv_def, rtManualThreadDmuInfo_t *&dmu) { | |||
if (aic_aiv_def.prefetch_list().empty()) { | |||
return; | |||
} | |||
std::vector<uint8_t> buffer(sizeof(rtManualThreadDmuInfo_t) * aic_aiv_def.prefetch_list_size()); | |||
dmu = reinterpret_cast<rtManualThreadDmuInfo_t *>(buffer.data()); | |||
for (int idx = 0; idx < aic_aiv_def.prefetch_list_size(); ++idx) { | |||
InitManualDmuInfo(aic_aiv_def.prefetch_list(idx), dmu[idx]); | |||
} | |||
} | |||
void FftsTaskInfo::InitManualDmuInfo(const domi::ManualThreadCacheDef &cache_def, rtManualThreadDmuInfo_t *&dmu) { | |||
if (cache_def.dmu_list().empty()) { | |||
return; | |||
} | |||
std::vector<uint8_t> buffer(sizeof(rtManualThreadDmuInfo_t) * cache_def.dmu_list_size()); | |||
dmu = reinterpret_cast<rtManualThreadDmuInfo_t *>(buffer.data()); | |||
for (int idx = 0; idx < cache_def.dmu_list_size(); ++idx) { | |||
InitManualDmuInfo(cache_def.dmu_list(idx), dmu[idx]); | |||
} | |||
} | |||
void FftsTaskInfo::InitManualDmuInfo(const domi::ManualThreadDmuDef &dmu_def, rtManualThreadDmuInfo_t &dmu) { | |||
dmu.dataAddr = dmu_def.data_addr(); | |||
dmu.numOuter = dmu_def.num_outer(); | |||
dmu.numInner = dmu_def.num_inner(); | |||
dmu.strideOuter = dmu_def.stride_outer(); | |||
dmu.lenInner = dmu_def.len_inner(); | |||
dmu.strideInner = dmu_def.stride_inner(); | |||
} | |||
Status FftsTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { | |||
return SUCCESS; | |||
} | |||
Status FftsTaskInfo::UpdateArgs() { | |||
GE_CHECK_NOTNULL(davinci_model_); | |||
std::vector<void *> io_addrs = io_addrs_; | |||
davinci_model_->UpdateKnownZeroCopyAddr(io_addrs); | |||
auto addr_size = kAddrLen * io_addrs.size(); | |||
GE_CHK_RT_RET(rtMemcpy(args_, args_size_, io_addrs.data(), addr_size, RT_MEMCPY_HOST_TO_DEVICE)); | |||
return SUCCESS; | |||
} | |||
Status FftsTaskInfo::Distribute() { | |||
GELOGI("FftsTaskInfo Distribute Start."); | |||
rtError_t rt_ret = rtFftsTaskLaunch(&sub_task_info_, stream_); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "[Check][RT_ret] Call rtFftsTaskLaunch failed, ret: 0x%X", rt_ret); | |||
return RT_ERROR_TO_GE_STATUS(rt_ret); | |||
} | |||
GELOGI("FftsTaskInfo Distribute Success."); | |||
return SUCCESS; | |||
} | |||
REGISTER_TASK_INFO(RT_MODEL_TASK_FFTS_TASK, FftsTaskInfo); | |||
} // namespace ge |
@@ -0,0 +1,66 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_FFTS_TASK_INFO_H_ | |||
#define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_FFTS_TASK_INFO_H_ | |||
#include "graph/load/model_manager/task_info/task_info.h" | |||
#include "graph/op_desc.h" | |||
namespace ge { | |||
class FftsTaskInfo : public TaskInfo { | |||
public: | |||
FftsTaskInfo() = default; | |||
~FftsTaskInfo() override; | |||
Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; | |||
Status Distribute() override; | |||
Status UpdateArgs() override; | |||
Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; | |||
private: | |||
void InitFftsDescInfo(const domi::FftsDescInfoDef &ffts_desc_def, rtFftsDescInfo_t &ffts_desc); | |||
Status InitSubTaskInfo(const domi::FftsSubTaskDef &task_def, rtFftsSubTaskInfo_t &task); | |||
Status InitTicketCache(const domi::TicketCacheDef &cache_def, rtTicketCache_t &cache); | |||
Status InitAutoAicAiv(const domi::AutoThreadAicAivDef &aic_aiv_def, rtAutoThreadAicAivInfo_t &aic_aiv); | |||
void InitAutoCacheInfo(const domi::AutoThreadCacheDef &cache_def, rtAutoThreadCacheInfo_t &cache); | |||
void InitAutoPrefetch(const domi::AutoThreadPrefetchDef &prefetch_def, rtAutoThreadPrefetch_t &prefetch); | |||
Status InitManualAicAiv(const domi::ManualThreadAicAivDef &aic_aiv_def, rtManualThreadAicAivInfo_t &aic_aiv); | |||
Status InitManualCacheInfo(const domi::ManualThreadCacheDef &cache_def, rtManualThreadCacheInfo_t &cache); | |||
Status InitManualDependency(const domi::ManualThreadDependencyDef &depend_def, rtManualThreadDependency_t &depend); | |||
Status InitManualNop(const domi::ManualThreadNopDef &nop_def, rtManualThreadNopInfo_t &nop); | |||
void InitManualDmuInfo(const domi::ManualThreadDmuDef &dmu_def, rtManualThreadDmuInfo_t &dmu); | |||
void InitManualDmuInfo(const domi::ManualThreadCacheDef &cache_def, rtManualThreadDmuInfo_t *&dmu); | |||
void InitManualDmuInfo(const domi::ManualThreadAicAivDef &aic_aiv_def, rtManualThreadDmuInfo_t *&dmu); | |||
template<typename T> | |||
Status InitIoAddrs(const RuntimeParam &rts_param, const T &aic_aiv_def, uint32_t thread_dim, uint32_t addr_count); | |||
DavinciModel *davinci_model_{nullptr}; | |||
rtFftsTaskInfo_t sub_task_info_; | |||
std::vector<void *> io_addrs_; | |||
uint32_t thread_dim_{0}; | |||
void *args_{nullptr}; // runtime args memory | |||
uint32_t args_size_{0}; // runtime args memory length | |||
}; | |||
} // namespace ge | |||
#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_FFTS_TASK_INFO_H_ |
@@ -645,6 +645,7 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne | |||
GE_CHECK_NOTNULL(op_desc); | |||
args_addr = std::unique_ptr<uint8_t[]>(new (std::nothrow) uint8_t[args_size_]); | |||
GE_CHECK_NOTNULL(args_addr); | |||
errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_); | |||
if (sec_ret != EOK) { | |||
REPORT_CALL_ERROR("E19999", "Call memcpy_s fail, size:%u, ret:0x%X", args_size_, sec_ret); | |||
@@ -1000,6 +1001,7 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k | |||
// copy args to new host memory | |||
args_addr = std::unique_ptr<uint8_t[]>(new (std::nothrow) uint8_t[args_size_]); | |||
GE_CHECK_NOTNULL(args_addr); | |||
GE_PRINT_DYNAMIC_MEMORY(new, "cce task physical memory.", sizeof(uint8_t) * args_size_) | |||
errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_); | |||
if (sec_ret != EOK) { | |||
@@ -27,6 +27,7 @@ | |||
#include "common/math/math_util.h" | |||
#include "common/thread_pool.h" | |||
#include "common/dump/dump_manager.h" | |||
#include "ge_opt_info/ge_opt_info.h" | |||
#include "analyzer/analyzer.h" | |||
#include "graph/common/ge_call_wrapper.h" | |||
#include "graph/common/local_context.h" | |||
@@ -120,7 +121,6 @@ const char *const kCheckPointForGetVar = "CheckPointGraphForGetVar"; | |||
const char *const kCheckPointGraph = "checkpoint_graph"; | |||
const char *const kVectorEngine = "VectorEngine"; | |||
const char *const kAIcoreEngine = "AIcoreEngine"; | |||
const char *const kRunFlagOffline = "0"; | |||
const int32_t kDynamicDimsTypeIsGetNext = 0; | |||
const int32_t kDynamicDimsTypeIsData = 1; | |||
const char *const kGetNextName = "IteratorV2"; | |||
@@ -950,7 +950,7 @@ Status GraphManager::SetRtContext(rtContext_t rt_context, rtCtxMode_t mode, uint | |||
rtError_t rt_ret = rtCtxCreate(&rt_context, mode, ge::GetContext().DeviceId()); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
REPORT_CALL_ERROR("E19999", "Call rtCtxCreate faileded, session_id:%lu, graph_id:%u, mode:%d", | |||
REPORT_CALL_ERROR("E19999", "Call rtCtxCreate failed, session_id:%lu, graph_id:%u, mode:%d", | |||
session_id, graph_id, mode); | |||
GELOGE(FAILED, "[Call][RtCtxCreate] faileded, session_id:%lu, graph_id:%u, mode:%d", session_id, graph_id, mode); | |||
return FAILED; | |||
@@ -1002,6 +1002,12 @@ Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vector<Ge | |||
return ret; | |||
} | |||
ret = GeOptInfo::SetOptInfo(); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "[Set][OptInfo] Set optional information failed."); | |||
return ret; | |||
} | |||
/// 1. BUILD_MODE_TUNING with BUILD_STEP_AFTER_UB_MATCH no need PreRunOptimizeOriginalGraph; | |||
/// 2. BUILD_MODE_TUNING with BUILD_STEP_AFTER_MERGE no need PreRunOptimizeOriginalGraph. | |||
/// 3. BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER_SUB no need PreRunOptimizeOriginalGraph. | |||
@@ -1789,8 +1795,7 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti | |||
return GE_GRAPH_OPTIONS_INVALID); | |||
// ge.graphType | |||
ret = ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag);
GE_IF_BOOL_EXEC(ret != SUCCESS, | |||
GELOGE(GE_GRAPH_OPTIONS_INVALID, "[Parse][TrainGraphFlag] Key:ge.runFlag value is invalid"); | |||
return GE_GRAPH_OPTIONS_INVALID); | |||
@@ -2436,6 +2441,8 @@ Status GraphManager::RemoveIsolatedConstInThisGraph(ge::ComputeGraphPtr &compute | |||
continue; | |||
} | |||
if (n->GetOpDesc()->GetType() == CONSTANT || n->GetOpDesc()->GetType() == CONSTANTOP) { | |||
// reset const type depend on train_flag | |||
options_.train_graph_flag ? n->GetOpDesc()->SetType(CONSTANTOP) : n->GetOpDesc()->SetType(CONSTANT); | |||
if (n->GetOutAllNodes().empty() && n->GetInAllNodes().empty()) { | |||
// it is an isolated constant, just remove it | |||
if (GraphUtils::RemoveJustNode(compute_graph, n) != GRAPH_SUCCESS) { | |||
@@ -2762,35 +2769,22 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) { | |||
"Please pay attention to it."); | |||
} | |||
GE_CHK_STATUS_RET(ChangeConstType(compute_graph)); | |||
ChangeConstTypeWhenTraining(compute_graph); | |||
GELOGI("End optimize after merge sub graph."); | |||
return SUCCESS; | |||
} | |||
Status GraphManager::ChangeConstType(const ComputeGraphPtr &compute_graph) { | |||
// run_flag off means offline, on means online | |||
string run_flag; | |||
(void)ge::GetContext().GetOption(ge::RUN_FLAG, run_flag); | |||
// The constant for online is CONSTANTOP, and is CONSTANT for offline. They will be unified in the future.
if (run_flag == kRunFlagOffline) { | |||
GELOGI("Offline mode, change all Constant to Const."); | |||
} else { | |||
GELOGI("Online mode, change all Const to Constant."); | |||
} | |||
for (NodePtr &n : compute_graph->GetAllNodes()) { | |||
GE_CHECK_NOTNULL(n); | |||
if (n->GetType() == CONSTANT || n->GetType() == CONSTANTOP) { | |||
auto op_desc = n->GetOpDesc(); | |||
GE_CHECK_NOTNULL(op_desc); | |||
if (run_flag == kRunFlagOffline) { | |||
op_desc->SetType(CONSTANT); | |||
} else { | |||
op_desc->SetType(CONSTANTOP); | |||
void GraphManager::ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_graph) { | |||
// The constant for train is CONSTANTOP, and is CONSTANT for inference. They will be unified in the future.
if (options_.train_graph_flag) { | |||
for (NodePtr &n : compute_graph->GetAllNodes()) { | |||
// This can ensure that n is not a null pointer | |||
if (n->GetOpDesc()->GetType() == CONSTANT) { | |||
n->GetOpDesc()->SetType(CONSTANTOP); | |||
} | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
Status GraphManager::LoadGraphAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) { | |||
@@ -3145,10 +3139,10 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { | |||
} | |||
// Avoid repeated prerun for graphs that own the same graph_id in online inference concurrency
if (count > 1 && graph_node->GetBuildFlag()) {
GELOGD("Avoid repeated prerun, graph_id:%u.", args.graph_id);
// In online inference concurrency scenario, graph_node is allowed to be locked for 'count' times
graph_node->SetSemSize(count);
graph_node->Lock();
graph_manager->run_args_q_.Push(RunArgs( { graph_node, args.graph_id, args.session_id, args.error_context, | |||
args.input_tensor, graph_node->GetGeRootModel(), GetThreadLocalContext(), args.callback })); | |||
GELOGI("[PreRunThread] Loop end. Start to run with cached build model."); | |||
@@ -375,7 +375,7 @@ class GraphManager { | |||
static void ReturnError(GraphManager *graph_manager, GraphNodePtr &graph_node, RunAsyncCallback callback, | |||
Status ret, const string &log); | |||
Status ChangeConstType(const ComputeGraphPtr &compute_graph); | |||
void ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_graph); | |||
Status PreRunOptimizeOriginalGraph(const GraphNodePtr &graph_node, const std::vector<GeTensor> &inputs, | |||
ge::ComputeGraphPtr &compute_graph, uint64_t session_id); | |||
@@ -336,10 +336,8 @@ Status GraphOptimize::OptimizeAfterStage1(ComputeGraphPtr &compute_graph) { | |||
GELOGI("[OptimizeAfterStage1]: engine type will exclude:%s.", exclude_core_type.c_str()); | |||
continue; | |||
} | |||
#ifndef ONLY_COMPILE_OPEN_SRC | |||
GELOGI("Begin to optimize graph after stage1 by engine %s.", iter->first.c_str()); | |||
ret = (iter->second)->OptimizeAfterStage1(*compute_graph); | |||
#endif | |||
if (ret != SUCCESS) { | |||
REPORT_INNER_ERROR("E19999", "Call OptimizeAfterStage1 failed, ret:%d, engine_name:%s, " | |||
"graph_name:%s.", ret, iter->first.c_str(), compute_graph->GetName().c_str()); | |||
@@ -284,9 +284,6 @@ Status DynamicShapePartitioner::InitClusters() { | |||
auto cluster = MakeShared<Cluster>(rank++, type, node, this); | |||
REQUIRE_NOT_NULL(cluster, "[New][Memory] for cluster failed."); | |||
node_2_cluster_[node] = cluster; | |||
if (cluster->IsUnknownShape()) { | |||
ordered_cluster_.push_back(cluster); | |||
} | |||
int64_t group_index = -1; | |||
if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { | |||
@@ -306,7 +303,7 @@ Status DynamicShapePartitioner::InitClusters() { | |||
return SUCCESS; | |||
} | |||
Status DynamicShapePartitioner::TopologicalSortClusters() { | |||
Status DynamicShapePartitioner::TopologicalSortClusters(const OrderedFilter &ordered_filter) { | |||
ordered_cluster_.clear(); | |||
// BFS topological sort clusters for known shape cluster | |||
std::queue<ClusterPtr> ready_clusters; | |||
@@ -331,7 +328,7 @@ Status DynamicShapePartitioner::TopologicalSortClusters() { | |||
auto cluster = ready_clusters.front(); | |||
ready_clusters.pop(); | |||
cluster->UpdateRank(rank++); | |||
if (cluster->IsKnownShape() || cluster->IsInputNode()) { | |||
if (ordered_filter == nullptr || ordered_filter(cluster)) { | |||
ordered_cluster_.push_back(cluster); | |||
} | |||
for (const auto &out_cluster : cluster->Outputs()) { | |||
@@ -364,6 +361,7 @@ static std::string ToString(const std::vector<ClusterPtr> &clusters) { | |||
} | |||
void DynamicShapePartitioner::MergeClustersControlFlow() { | |||
std::unordered_set<ClusterPtr> all_merged_clusters; | |||
for (const auto &item : control_clusters_) { | |||
const auto &control_cluster = item.second; | |||
auto rit = control_cluster.rbegin(); | |||
@@ -373,12 +371,21 @@ void DynamicShapePartitioner::MergeClustersControlFlow() { | |||
} | |||
const auto &cluster = *rit; | |||
if (all_merged_clusters.count(cluster) > 0) { | |||
continue; | |||
} | |||
for (++rit; rit != control_cluster.rend(); ++rit) { | |||
const auto &cluster_from = *rit; | |||
if (all_merged_clusters.count(cluster_from) > 0) { | |||
continue; | |||
} | |||
auto merged_clusters = cluster->MergeAllPathFrom(cluster_from); | |||
GELOGD("Merge all path cluster from %lu to %lu %s.", cluster_from->Id(), cluster->Id(), | |||
ToString(merged_clusters).c_str()); | |||
for (const auto &merged_cluster : merged_clusters) { | |||
all_merged_clusters.emplace(merged_cluster); | |||
for (const auto &node : merged_cluster->Nodes()) { | |||
node_2_cluster_[node] = cluster; | |||
} | |||
@@ -459,9 +466,19 @@ void DynamicShapePartitioner::MergeClustersInputData() { | |||
} | |||
Status DynamicShapePartitioner::MergeClusters() { | |||
const auto filter_known = [](const ClusterPtr &cluster) { | |||
return cluster->IsKnownShape() || cluster->IsInputNode(); | |||
}; | |||
const auto filter_unknown = [](const ClusterPtr &cluster) { | |||
return cluster->IsUnknownShape(); | |||
}; | |||
MergeClustersControlFlow(); | |||
REQUIRE_SUCCESS(TopologicalSortClusters(filter_unknown), | |||
"[TopologicalSort][Clusters] after merge control flow clusters failed."); | |||
MergeClustersUnknownShape(); | |||
REQUIRE_SUCCESS(TopologicalSortClusters(), "[TopologicalSort][Clusters] after merge unknown shape clusters failed."); | |||
REQUIRE_SUCCESS(TopologicalSortClusters(filter_known), | |||
"[TopologicalSort][Clusters] after merge unknown shape clusters failed."); | |||
MergeClustersKnownShape(); | |||
MergeClustersInputData(); | |||
return SUCCESS; | |||
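TopologicalSortClusters now takes the ordering predicate as a parameter instead of hard-coding IsKnownShape/IsInputNode, so the same BFS sort can feed both merge phases (filter_unknown before MergeClustersUnknownShape, filter_known before MergeClustersKnownShape). A minimal sketch of the filtered-collect idea on a generic DAG (the types here are illustrative, not the real Cluster API):

#include <functional>
#include <queue>
#include <vector>

struct Node {
  std::vector<Node *> outputs;
  int pending_inputs = 0;  // precomputed in-degree
};

using OrderedFilter = std::function<bool(const Node *)>;

// Kahn-style BFS over `roots` (the zero in-degree nodes); only nodes accepted
// by `filter` (or all nodes, if the filter is empty) are collected in order.
std::vector<Node *> TopoSortFiltered(std::vector<Node *> roots, const OrderedFilter &filter) {
  std::vector<Node *> ordered;
  std::queue<Node *> ready;
  for (auto *n : roots) ready.push(n);
  while (!ready.empty()) {
    Node *n = ready.front();
    ready.pop();
    if (filter == nullptr || filter(n)) {
      ordered.push_back(n);
    }
    for (auto *out : n->outputs) {
      if (--out->pending_inputs == 0) ready.push(out);
    }
  }
  return ordered;
}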
@@ -703,7 +720,12 @@ void Cluster::Merge(ClusterPtr other) { | |||
if (other->min_ < min_) { | |||
min_ = other->min_; | |||
} | |||
}; | |||
if (!IsUnknownShape() && other->IsUnknownShape()) { | |||
type_ = UNKNOWN_SHAPE; | |||
} | |||
} | |||
bool Cluster::TryMerge(ClusterPtr other) { | |||
std::queue<ClusterPtr> forward_reached; | |||
forward_reached.push(other); | |||
@@ -111,6 +111,8 @@ class DynamicShapePartitioner { | |||
Status Partition(); | |||
using OrderedFilter = std::function<bool(const std::shared_ptr<Cluster> &cluster)>; | |||
private: | |||
Status PartitionImpl(); | |||
// Collect nodes that satisfy the unknown-shape rules:
@@ -138,7 +140,7 @@ class DynamicShapePartitioner { | |||
// Merge clusters step3 | |||
void MergeClustersInputData(); | |||
// Topological sort clusters after merge unknown shape clusters. | |||
Status TopologicalSortClusters(); | |||
Status TopologicalSortClusters(const OrderedFilter &ordered_filter); | |||
// Deduplicate merged clusters | |||
void PruneUniqueClusters(); | |||
// Establish the input-output anchors for each partition of the cluster and record links to other clusters | |||
@@ -161,7 +163,7 @@ class DynamicShapePartitioner { | |||
ge::ComputeGraphPtr root_graph_; // The original graph to partition | |||
std::unordered_map<NodePtr, std::shared_ptr<Cluster>> node_2_cluster_; // Record nodes and the cluster it belongs to | |||
// V1 control flow cluster, need merge to one Graph. | |||
std::unordered_map<int64_t, std::vector<std::shared_ptr<Cluster>>> control_clusters_; | |||
std::map<int64_t, std::vector<std::shared_ptr<Cluster>>> control_clusters_; | |||
// topological sorted clusters, this field will change with the splitting. | |||
// When partitioning UNKNOWN_SHAPE cluster, it is a collection of all topological sorted UNKNOWN_SHAPE clusters | |||
// When partitioning KNOWN_SHAPE cluster, it is a collection of all topological sorted KNOWN_SHAPE clusters | |||
@@ -179,6 +179,7 @@ Status ge::GraphPartitioner::MergeAfterSubGraphOptimization(ge::ComputeGraphPtr | |||
GELOGE(ret, "[Merge][SubGraph] Failed, ret:%d", ret); | |||
} | |||
GE_CHECK_NOTNULL(original_compute_graph); | |||
output_merged_compute_graph->SetName(original_compute_graph->GetName()); | |||
// partition sub graph | |||
for (const auto &sub_graph : original_compute_graph->GetAllSubgraphs()) { | |||
ComputeGraphPtr merged_sub_graph = nullptr; | |||
@@ -188,8 +189,16 @@ Status ge::GraphPartitioner::MergeAfterSubGraphOptimization(ge::ComputeGraphPtr | |||
GELOGE(ret, "[Merge][SubGraph] Failed, ret:%d", ret); | |||
continue; | |||
} | |||
// this means the subgraph was added during subgraph optimization and has no partitions, so just add it to the root graph
if (merged_sub_graph == sub_graph) { | |||
GELOGI("Just add subgraph %s (parent node is %s) to root graph %s.", sub_graph->GetName().c_str(), | |||
sub_graph->GetParentNode()->GetName().c_str(), output_merged_compute_graph->GetName().c_str()); | |||
sub_graph->SetParentGraph(sub_graph->GetParentNode()->GetOwnerComputeGraph()); | |||
GE_IF_BOOL_EXEC(output_merged_compute_graph->AddSubgraph(sub_graph->GetName(), merged_sub_graph) != SUCCESS, | |||
return FAILED;) | |||
continue; | |||
} | |||
// add sub graph | |||
output_merged_compute_graph->SetName(original_compute_graph->GetName()); | |||
merged_sub_graph->SetName(sub_graph->GetName()); | |||
merged_sub_graph->SetInputSize(sub_graph->GetInputSize()); | |||
merged_sub_graph->SetOutputSize(sub_graph->GetOutputSize()); | |||
@@ -245,12 +254,9 @@ Status ge::GraphPartitioner::MergeSubGraph(ge::ComputeGraphPtr &output_merged_co | |||
} | |||
if ((graph_2_graph_partition_info_.find(original_compute_graph) == graph_2_graph_partition_info_.end()) || | |||
(graph_2_subgraph_list_.find(original_compute_graph) == graph_2_subgraph_list_.end())) { | |||
REPORT_INNER_ERROR("E19999", "original_compute_graph:%s is not find in graph_2_graph_partition_info_.", | |||
original_compute_graph->GetName().c_str()); | |||
GELOGE(GE_GRAPH_NULL_INPUT, | |||
"[Check][Param] original_compute_graph:%s is not find in graph_2_graph_partition_info_.", | |||
original_compute_graph->GetName().c_str()); | |||
return FAILED; | |||
GELOGW("[GraphPartition]: compute_graph has not found, just return original."); | |||
output_merged_compute_graph = original_compute_graph; | |||
return SUCCESS; | |||
} | |||
GraphPartitionInfo &subgraph_info = graph_2_graph_partition_info_[original_compute_graph]; | |||
const auto &sub_graph_list = graph_2_subgraph_list_[original_compute_graph]; | |||
@@ -708,6 +714,7 @@ Status ge::GraphPartitioner::AddPartitionsToGraphNode(vector<ge::SubGraphInfoPtr | |||
} | |||
auto &engine_name = graph_info_.partitions_.at(sub_graph); | |||
(void)AttrUtils::SetStr(sub_graph, ATTR_NAME_PARENT_GRAPH_NAME, compute_graph->GetName()); | |||
(void)sub_graph->SetExtAttr("part_src_graph", compute_graph); | |||
GELOGD("set attr success. subgraph(%s) with parent graph(%s)", sub_graph->GetName().c_str(), | |||
compute_graph->GetName().c_str()); | |||
GE_DUMP(sub_graph, sub_graph->GetName() + "_" + mode_2_str_[graph_info_.mode_]); | |||
@@ -16,8 +16,6 @@ | |||
#include "mark_force_unknown_for_cond_pass.h" | |||
#include <queue> | |||
#include "graph/utils/node_utils.h" | |||
#include "graph/common/omg_util.h" | |||
@@ -26,17 +24,7 @@ namespace { | |||
inline bool IsMergeInLoop(const NodePtr &node) { | |||
const static std::set<std::string> kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; | |||
std::string node_type; | |||
(void)GetOriginalType(node, node_type); | |||
return kLoopMergeInputs.count(node_type) > 0; | |||
} | |||
inline bool IsSwitchInLoop(const NodePtr &node) { | |||
const static std::set<std::string> kLoopSwitchInputs{ MERGE, REFMERGE, LOOPCOND }; | |||
std::string node_type; | |||
(void)GetOriginalType(node, node_type); | |||
return kLoopSwitchInputs.count(node_type) > 0; | |||
return kLoopMergeInputs.count(NodeUtils::GetNodeType(node)) > 0; | |||
} | |||
} | |||
@@ -44,10 +32,7 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { | |||
GELOGD("MarkForceUnknownForCondPass Enter"); | |||
std::map<NodePtr, std::vector<NodePtr>> switch_groups; | |||
for (const auto &node : graph->GetDirectNode()) { | |||
std::string node_type; | |||
GE_CHK_STATUS_RET(GetOriginalType(node, node_type), | |||
"[Get][OriginalType] of node in graph:%s failed.", graph->GetName().c_str()); | |||
if (kMergeOpTypes.count(node_type) == 0) { | |||
if (kMergeOpTypes.count(NodeUtils::GetNodeType(node)) == 0) { | |||
continue; | |||
} | |||
@@ -65,6 +50,51 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { | |||
} | |||
/// | |||
/// @brief Deal with Switch node for LoopCond | |||
/// @param [in] Switch node | |||
/// @param [in] dest span | |||
/// @param [out] Search queue | |||
/// @return true: Switch in while loop / false: not in while loop.
/// | |||
bool MarkForceUnknownForCondPass::DealAsLoopSwitch(const NodePtr &node, uint32_t dst_span, | |||
std::queue<std::pair<NodePtr, uint32_t>> &search_queue) { | |||
/// LoopCond --->\. | |||
/// \. | |||
/// Enter-----------+ \. | |||
/// +--> Merge --> Switch --> Exit | |||
/// NextIteration---+ | |||
const auto is_loop_op = [](const NodePtr &n) { | |||
return NodeUtils::GetNodeType(n) == LOOPCOND; | |||
}; | |||
const auto is_exit_op = [](const NodePtr &n) { | |||
return kExitOpTypes.count(NodeUtils::GetNodeType(n)) > 0; | |||
}; | |||
const auto src_nodes = node->GetInAllNodes(); | |||
const auto dst_nodes = node->GetOutAllNodes(); | |||
if (std::none_of(src_nodes.begin(), src_nodes.end(), is_loop_op) && | |||
std::none_of(dst_nodes.begin(), dst_nodes.end(), is_exit_op)) { | |||
return false; | |||
} | |||
for (const auto &m : src_nodes) { | |||
if (kMergeOpTypes.count(NodeUtils::GetNodeType(m)) > 0) { | |||
for (const auto &n : m->GetInAllNodes()) { | |||
if (kNextIterationOpTypes.count(NodeUtils::GetNodeType(n)) > 0) { | |||
continue; | |||
} | |||
search_queue.push({n, dst_span}); | |||
GELOGD("Travel in Loop: %s <-- %s <-- %s, span is: %u", node->GetName().c_str(), m->GetName().c_str(), | |||
n->GetName().c_str(), dst_span); | |||
} | |||
} | |||
} | |||
return true; | |||
} | |||
/// | |||
/// @brief Mark force unknown shape for Switch node | |||
/// @param [in] merge node | |||
/// @param [out] switch group | |||
@@ -72,6 +102,7 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { | |||
/// | |||
void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std::vector<NodePtr> &switch_group) { | |||
// Switch --> {Switch --> Merge} --> Merge | |||
GELOGD("Search Switch node for Merge: %s", node->GetName().c_str()); | |||
std::unordered_set<NodePtr> nodes_seen; | |||
std::queue<std::pair<NodePtr, uint32_t>> search_queue({{node, 0}}); | |||
while (!search_queue.empty()) { | |||
@@ -79,43 +110,25 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: | |||
const auto dst_span = search_queue.front().second; | |||
search_queue.pop(); | |||
// Switch --> Identity --> Constant | |||
for (const auto &in_node : dst_node->GetInControlNodes()) { | |||
if (nodes_seen.count(in_node) > 0) { | |||
GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str()); | |||
continue; | |||
} | |||
nodes_seen.insert(in_node); | |||
if (in_node->GetType() == IDENTITY) { | |||
GELOGD("Travel node: %s, In control: %s, span is: %u", dst_node->GetName().c_str(), | |||
in_node->GetName().c_str(), dst_span); | |||
search_queue.push({in_node, dst_span}); | |||
} | |||
} | |||
for (const auto &in_node : dst_node->GetInDataNodes()) { | |||
for (const auto &in_node : dst_node->GetInAllNodes()) { | |||
if (nodes_seen.count(in_node) > 0) { | |||
GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str()); | |||
continue; | |||
} | |||
nodes_seen.insert(in_node); | |||
std::string node_type; | |||
(void)GetOriginalType(in_node, node_type); | |||
const std::string node_type = NodeUtils::GetNodeType(in_node); | |||
GELOGD("Travel node: %s, %s node: %s, span is: %u", dst_node->GetName().c_str(), node_type.c_str(), | |||
in_node->GetName().c_str(), dst_span); | |||
if (kSwitchOpTypes.count(node_type) > 0) { // Switch input node. | |||
if (DealAsLoopSwitch(in_node, dst_span, search_queue)) { | |||
continue; | |||
} | |||
if (dst_span > 0) { | |||
search_queue.push({in_node, dst_span - 1}); | |||
} else { | |||
const auto &all_in_nodes = in_node->GetInDataNodes(); | |||
if (std::any_of(all_in_nodes.begin(), all_in_nodes.end(), IsSwitchInLoop)) { | |||
GELOGW("Travel node: %s, %s node: %s, Skip LoopCond switch", dst_node->GetName().c_str(), node_type.c_str(), | |||
in_node->GetName().c_str()); | |||
} else { | |||
switch_group.emplace_back(in_node); | |||
} | |||
switch_group.emplace_back(in_node); | |||
} | |||
} else if (kMergeOpTypes.count(node_type) > 0) { // Merge input node. | |||
search_queue.push({in_node, dst_span + 1}); | |||
@@ -132,39 +145,17 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: | |||
/// @return | |||
/// | |||
void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups) { | |||
std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) { | |||
return n->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP); | |||
}; | |||
for (auto it1 = switch_groups.begin(); it1 != switch_groups.end(); ++it1) { | |||
const auto &op_node1 = it1->first; | |||
const auto &op_desc1 = op_node1->GetOpDesc(); | |||
if (op_desc1->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||
for (auto it = switch_groups.begin(); it != switch_groups.end(); ++it) { | |||
const auto &op_node = it->first; | |||
const auto &op_desc = op_node->GetOpDesc(); | |||
if (op_desc->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||
continue; | |||
} | |||
if (IsUnknownShapeTensor(op_desc1->GetOutputDesc(0))) { | |||
int64_t group_index = op_desc1->GetId(); | |||
GELOGI("Mark %s as unknown shape control flow, group index: %ld", op_desc1->GetName().c_str(), group_index); | |||
MarkForceUnknownShape(op_node1, true, group_index); | |||
for (const auto &n : it1->second) { | |||
MarkForceUnknownShape(n, true, group_index); | |||
} | |||
for (auto it2 = switch_groups.begin(); it2 != switch_groups.end(); ++it2) { | |||
const auto &op_node2 = it2->first; | |||
const auto &op_desc2 = op_node2->GetOpDesc(); | |||
if (op_desc2->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||
continue; | |||
} | |||
if (std::any_of(it2->second.begin(), it2->second.end(), callback)) { | |||
MarkForceUnknownShape(op_node2, true, group_index); | |||
for (const auto &n : it2->second) { | |||
MarkForceUnknownShape(n, true, group_index); | |||
} | |||
} | |||
} | |||
int64_t group_index = op_desc->GetId(); | |||
SetControlFlowGroup(op_node, group_index); | |||
for (const auto &n : it->second) { | |||
SetControlFlowGroup(n, group_index); | |||
} | |||
} | |||
} | |||
@@ -19,6 +19,8 @@ | |||
#include "inc/graph_pass.h" | |||
#include <queue> | |||
namespace ge { | |||
class MarkForceUnknownForCondPass : public GraphPass { | |||
public: | |||
@@ -26,6 +28,15 @@ class MarkForceUnknownForCondPass : public GraphPass { | |||
private: | |||
/// | |||
/// @brief Deal with Switch node for LoopCond | |||
/// @param [in] Switch node | |||
/// @param [in] dest span | |||
/// @param [out] Search queue | |||
/// @return true: Switch in while loop / false: not in while loop.
/// | |||
bool DealAsLoopSwitch(const NodePtr &node, uint32_t dst_span, std::queue<std::pair<NodePtr, uint32_t>> &search_queue); | |||
/// | |||
/// @brief Mark force unknown shape for Switch node | |||
/// @param [in] merge node | |||
/// @param [out] switch group | |||
@@ -40,6 +40,12 @@ Status MarkGraphUnknownStatusPass::Run(ComputeGraphPtr graph) { | |||
} | |||
} | |||
const auto &node = graph->GetParentNode(); | |||
if (!is_unknown_shape && node != nullptr && node->GetType() == PARTITIONEDCALL) { | |||
GE_CHK_GRAPH_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape), | |||
"[Get][ShapeStatus] of node[%s] failed!", node->GetName().c_str()); | |||
} | |||
for (const auto &node : graph->GetDirectNode()) { | |||
GELOGD("Set OwnerGraphIsUnknown attr to node[%s]", node->GetName().c_str()); | |||
(void)AttrUtils::SetBool(node->GetOpDesc(), kOwnerGraphIsUnknown, is_unknown_shape); | |||
@@ -89,8 +89,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons | |||
REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid"); | |||
return FAILED, "[Check][Param] Param of pre node is nullptr."); | |||
int64_t group_index = -1; | |||
bool force_unknown = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||
MarkForceUnknownShape(node, force_unknown, group_index); | |||
(void)AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||
for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { | |||
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); | |||
@@ -109,7 +108,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons | |||
GELOGE(FAILED, "[Set][ActiveLabelList] for node %s failed.", active_node->GetName().c_str()); | |||
return FAILED; | |||
} | |||
MarkForceUnknownShape(active_node, force_unknown, group_index); | |||
SetControlFlowGroup(active_node, group_index); | |||
} | |||
return SUCCESS; | |||
@@ -24,7 +24,9 @@ using std::string; | |||
namespace ge { | |||
namespace { | |||
const int64_t kLoopType = 1; | |||
constexpr int64_t kLoopType = 1; | |||
constexpr uint8_t kMaxTransOp = 3; | |||
constexpr uint8_t kTransOpIoSize = 1; | |||
} | |||
Status NextIterationPass::Run(ComputeGraphPtr graph) { | |||
@@ -284,13 +286,28 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||
/// @return void | |||
/// | |||
void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index) { | |||
std::string node_type; | |||
for (const auto &switch_node : loop_group.switch_nodes) { | |||
SetControlFlowGroup(switch_node, group_index); | |||
for (const auto &node : switch_node->GetOutDataNodes()) { | |||
std::string node_type; | |||
(void)GetOriginalType(node, node_type); | |||
if (kExitOpTypes.count(node_type) > 0) { | |||
SetControlFlowGroup(node, group_index); | |||
for (auto node : switch_node->GetOutDataNodes()) { | |||
// Switch --> Exit | |||
// Switch --> Cast --> Exit | |||
// Switch --> TransData --> Cast --> Exit | |||
for (uint8_t i = 0; i < kMaxTransOp; ++i) { | |||
if (node->GetInDataNodes().size() != kTransOpIoSize || node->GetAllOutDataAnchorsSize() != kTransOpIoSize) { | |||
break; | |||
} | |||
if (kExitOpTypes.count(NodeUtils::GetNodeType(node)) > 0) { | |||
SetControlFlowGroup(node, group_index); | |||
break; | |||
} | |||
const auto &all_nodes = node->GetOutAllNodes(); | |||
if (all_nodes.size() != kTransOpIoSize) { | |||
break; | |||
} | |||
node = all_nodes.at(0); | |||
} | |||
} | |||
} | |||
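The rewritten HandleSwitchExitNodes no longer requires Exit to be a direct successor of Switch: it walks through at most kMaxTransOp single-input/single-output nodes (e.g. TransData, Cast) looking for an Exit, which matches the DSP pattern Switch -> TransData -> Cast -> Exit. A hedged standalone sketch of that bounded walk (generic node type, not the real GE Node API):

// Follow a chain of at most max_hops single-in/single-out nodes starting from
// `node`, returning the first node accepted by `is_exit`, or nullptr.
template <typename NodeT, typename Pred>
NodeT *FindExitThroughTransOps(NodeT *node, const Pred &is_exit, int max_hops = 3) {
  for (int i = 0; i < max_hops && node != nullptr; ++i) {
    if (is_exit(node)) {
      return node;
    }
    if (node->inputs.size() != 1 || node->outputs.size() != 1) {
      break;  // not a pure trans op; stop searching this chain
    }
    node = node->outputs.front();
  }
  return nullptr;
}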
@@ -15,7 +15,7 @@ | |||
*/ | |||
#include "graph/passes/parallel_group_pass.h" | |||
#include <queue> | |||
#include "framework/common/debug/ge_log.h" | |||
#include "common/ge/ge_util.h" | |||
#include "framework/common/ge_inner_error_codes.h" | |||
@@ -299,24 +299,19 @@ Status ParallelGroupPass::ReplaceWithSwitchAndMerge(NodePtr pre_node, NodePtr cu | |||
for (const auto &switch_node : cur_itr->second.first) { | |||
int64_t pre_id = pre_node->GetOpDesc()->GetId(); | |||
int64_t switch_id = switch_node->GetOpDesc()->GetId(); | |||
// avoid ring | |||
if (pre_id > switch_id) { | |||
auto merge_node = cur_itr->second.second; | |||
if (AddCtrlEdge(merge_node, pre_node) != SUCCESS) { | |||
GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.", | |||
pre_node->GetName().c_str(), switch_node->GetName().c_str()); | |||
REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.", | |||
pre_node->GetName().c_str(), switch_node->GetName().c_str()); | |||
return FAILED; | |||
} | |||
} else { | |||
if (AddCtrlEdge(pre_node, switch_node) != SUCCESS) { | |||
GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.", | |||
pre_node->GetName().c_str(), switch_node->GetName().c_str()); | |||
REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.", | |||
pre_node->GetName().c_str(), switch_node->GetName().c_str()); | |||
return FAILED; | |||
} | |||
NodePtr first_node = pre_node; | |||
NodePtr second_node = switch_node; | |||
if (pre_id > switch_id && IsIndirectConnect(switch_node, pre_node)) { | |||
// avoid ring, merge->pre_node | |||
first_node = cur_itr->second.second; | |||
second_node = pre_node; | |||
} | |||
if (AddCtrlEdge(first_node, second_node) != SUCCESS) { | |||
GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.", | |||
first_node->GetName().c_str(), second_node->GetName().c_str()); | |||
REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.", | |||
first_node->GetName().c_str(), second_node->GetName().c_str()); | |||
return FAILED; | |||
} | |||
} | |||
} else { | |||
@@ -345,4 +340,29 @@ bool ParallelGroupPass::IsWhileStreamSwitch(OpDescPtr switch_op_desc) { | |||
return (AttrUtils::GetInt(switch_op_desc, ATTR_NAME_STREAM_SWITCH_TYPE, stream_switch_type) && | |||
stream_switch_type == kLoopType); | |||
} | |||
bool ParallelGroupPass::IsIndirectConnect(const NodePtr &node_a, const NodePtr &node_b) { | |||
if (node_a == nullptr || node_b == nullptr) { | |||
GELOGW("node_a or node_b is nullptr."); | |||
return false; | |||
} | |||
int64_t end_id = node_b->GetOpDesc()->GetId(); | |||
std::queue<NodePtr> nodes; | |||
nodes.push(node_a); | |||
while (!nodes.empty()) { | |||
NodePtr tmp_node = nodes.front(); | |||
nodes.pop(); | |||
if (tmp_node == nullptr || tmp_node->GetOpDesc() == nullptr || | |||
tmp_node->GetOpDesc()->GetId() > end_id) { | |||
continue; | |||
} | |||
if (tmp_node == node_b) { | |||
return true; | |||
} | |||
for (const auto &out_node : tmp_node->GetOutAllNodes()) { | |||
nodes.push(out_node); | |||
} | |||
} | |||
return false; | |||
} | |||
} // namespace ge |
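IsIndirectConnect above is a plain forward BFS, but it prunes with the topological id: op ids are assigned in topological order, so any node whose id already exceeds node_b's can never reach node_b and is skipped. Note the original also has no visited set; the id bound is what keeps the search finite. A minimal sketch of that id-bounded reachability check (illustrative types):

#include <queue>
#include <vector>

struct GNode {
  long id = 0;                   // assigned in topological order
  std::vector<GNode *> outputs;
};

// True if `from` can reach `to` going forward; nodes past `to` in topological
// order (id > to->id) are pruned instead of being tracked in a visited set.
bool Reaches(GNode *from, GNode *to) {
  std::queue<GNode *> work;
  work.push(from);
  while (!work.empty()) {
    GNode *n = work.front();
    work.pop();
    if (n == nullptr || n->id > to->id) continue;
    if (n == to) return true;
    for (auto *out : n->outputs) work.push(out);
  }
  return false;
}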
@@ -48,6 +48,7 @@ class ParallelGroupPass : public GraphPass { | |||
bool IsBigSmallLoopStreamSwitch(OpDescPtr switch_op_desc); | |||
bool IsWhileStreamSwitch(OpDescPtr switch_op_desc); | |||
bool IsIndirectConnect(const NodePtr &node_a, const NodePtr &node_b); | |||
}; | |||
} // namespace ge | |||
#endif // GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H |
@@ -21,7 +21,23 @@ | |||
#include "framework/common/debug/ge_log.h" | |||
#include "framework/common/ge_inner_error_codes.h" | |||
#include "graph/utils/graph_utils.h" | |||
#include "graph/utils/node_utils.h" | |||
namespace { | |||
const std::unordered_set<std::string> kControlFlowOps = { | |||
ge::SWITCH, | |||
ge::REFSWITCH, | |||
ge::MERGE, | |||
ge::REFMERGE, | |||
ge::ENTER, | |||
ge::REFENTER, | |||
ge::NEXTITERATION, | |||
ge::REFNEXTITERATION, | |||
ge::EXIT, | |||
ge::REFEXIT, | |||
ge::LOOPCOND | |||
}; | |||
} | |||
namespace ge { | |||
Status ReplaceWithEmptyConstPass::Run(NodePtr &node) { | |||
GELOGD("ReplaceWithEmptyConstPass in."); | |||
@@ -39,6 +55,10 @@ Status ReplaceWithEmptyConstPass::Run(NodePtr &node) { | |||
GELOGI("Node %s is const. Ignore current pass.", node->GetName().c_str()); | |||
return SUCCESS; | |||
} | |||
if (kControlFlowOps.count(NodeUtils::GetNodeType(node)) != 0) { | |||
GELOGI("Node %s is control flow op. Ignore current pass.", node->GetName().c_str()); | |||
return SUCCESS; | |||
} | |||
// Node like no op, it has no output | |||
if (node->GetOpDesc()->GetAllOutputsDescPtr().empty()) { | |||
GELOGI("Node %s has no output desc. Ignore current pass.", node->GetName().c_str()); | |||
@@ -395,8 +395,9 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr & | |||
peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str()); | |||
int64_t group_index = -1; | |||
bool force_unknown = AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||
MarkForceUnknownShape(stream_switch, force_unknown, group_index); | |||
if (AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { | |||
SetControlFlowGroup(stream_switch, group_index); | |||
} | |||
return stream_switch; | |||
} | |||
@@ -491,8 +492,8 @@ int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) { | |||
Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) { | |||
for (auto iter = cond_node_map_.begin(); iter != cond_node_map_.end(); ++iter) { | |||
for (auto group_iter = iter->second.begin(); group_iter != iter->second.end(); ++group_iter) { | |||
std::list<NodePtr> false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT]; | |||
std::list<NodePtr> true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT]; | |||
const std::list<NodePtr> &false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT]; | |||
const std::list<NodePtr> &true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT]; | |||
std::set<NodePtr> same_cond_switch; | |||
same_cond_switch.insert(false_switch_list.begin(), false_switch_list.end()); | |||
same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); | |||
@@ -524,13 +525,13 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) | |||
std::function<bool(const NodePtr &)> callback = [&group_index](const NodePtr &n) { | |||
return AttrUtils::GetInt(n->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||
}; | |||
bool is_unknown_shape = std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback); | |||
MarkForceUnknownShape(active_node, is_unknown_shape, group_index); | |||
(void)std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback); | |||
SetControlFlowGroup(active_node, group_index); | |||
const std::string &cond_group = cond_node->GetName(); | |||
for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { | |||
bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT); | |||
std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list); | |||
const std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list); | |||
GE_IF_BOOL_EXEC(switch_list.empty(), continue); | |||
// select first stream_switch | |||
@@ -559,7 +560,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) | |||
"[Add][Edge] between %s and %s failed.", | |||
cast_node->GetName().c_str(), stream_switch->GetName().c_str()); | |||
MarkForceUnknownShape(stream_switch, is_unknown_shape, group_index); | |||
SetControlFlowGroup(stream_switch, group_index); | |||
for (const NodePtr &node : switch_list) { | |||
GE_IF_BOOL_EXEC(node != stream_switch, { | |||
GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)), | |||
@@ -1420,9 +1420,10 @@ Status GraphPrepare::AdjustDataOpOutput(const NodePtr &node) { | |||
return SUCCESS; | |||
} | |||
Status GraphPrepare::CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc, bool tune_flag) { | |||
Status GraphPrepare::CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc) { | |||
auto format = desc.GetFormat(); | |||
auto origin_format = desc.GetOriginFormat(); | |||
auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER); | |||
bool need_check_internal_format = (!IsTansDataOpData(input_node)) && (!options_.is_single_op) && (!tune_flag); | |||
if (need_check_internal_format) { | |||
bool is_internal = TypeUtils::IsInternalFormat(format) || TypeUtils::IsInternalFormat(origin_format); | |||
@@ -1439,6 +1440,63 @@ Status GraphPrepare::CheckInternalFormat(const NodePtr &input_node, const GeTens | |||
return SUCCESS; | |||
} | |||
Status GraphPrepare::UpdateDataInputOutputDesc(GeAttrValue::INT index, OpDescPtr &op, GeTensorDesc &desc) { | |||
auto data_type = desc.GetDataType(); | |||
uint32_t length = 1; | |||
bool type_ret = TypeUtils::GetDataTypeLength(data_type, length); | |||
if (!type_ret) { | |||
std::string reason = "Input datatype[" + TypeUtils::DataTypeToSerialString(data_type) + "] of index:" + | |||
std::to_string(index) + " input tensor is not support"; | |||
REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"reason"}), std::vector<std::string>({reason})); | |||
GELOGE(PARAM_INVALID, "[Check][Param] Input datatype %s is not support.", | |||
TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||
return FAILED; | |||
} | |||
int64_t desc_shape = desc.GetShape().GetShapeSize(); | |||
FMK_INT64_UINT32_MULCHECK(desc_shape, length); | |||
int64_t shape_size = desc_shape * length; | |||
GE_IF_BOOL_EXEC(shape_size == 0 && desc.GetShape().GetDimNum() == 0, shape_size = static_cast<int64_t>(length)); | |||
int64_t size = 0; | |||
GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc, size) != GRAPH_SUCCESS, | |||
REPORT_CALL_ERROR("E19999", "Get size of user input tensor failed, index:%ld", index); | |||
GELOGE(INTERNAL_ERROR, "[Get][Size] of user input tensor failed, index:%ld", index); return FAILED); | |||
bool size_check = (size != 0 && shape_size != size); | |||
if (size_check) { | |||
std::string reason = "input tensor[index:" + std::to_string(index) + "]'s data size[" + std::to_string(size) + | |||
"] != shape_size[" + std::to_string(size) + "], check invalid"; | |||
REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"reason"}), std::vector<std::string>({reason})); | |||
GELOGE(PARAM_INVALID, "[Check][Param] input data size = %ld, shape_size = %ld.", size, shape_size); | |||
return FAILED; | |||
} | |||
ge::TensorUtils::SetSize(desc, shape_size); | |||
auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER); | |||
if (!tune_flag) { | |||
graphStatus graph_ret = op->UpdateInputDesc(0, desc); | |||
if (graph_ret != GRAPH_SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Update input desc of op:%s(%s) failed, index:0", | |||
op->GetName().c_str(), op->GetType().c_str()); | |||
GELOGE(graph_ret, "[Update][InputDesc] of op:%s(%s) failed, index:0", | |||
op->GetName().c_str(), op->GetType().c_str()); | |||
return graph_ret; | |||
} | |||
// Size will be recalculated in the build stage | |||
ge::TensorUtils::SetSize(desc, 0); | |||
graph_ret = op->UpdateOutputDesc(0, desc); | |||
if (graph_ret != GRAPH_SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Update output desc of op:%s(%s) failed, index:0", | |||
op->GetName().c_str(), op->GetType().c_str()); | |||
GELOGE(graph_ret, "[Update][OutputDesc] of op:%s(%s) failed, index:0", | |||
op->GetName().c_str(), op->GetType().c_str()); | |||
return graph_ret; | |||
} | |||
} else { | |||
GELOGI("data %s skip update info in tune mode", op->GetName().c_str()); | |||
} | |||
return SUCCESS; | |||
} | |||
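The size check in UpdateDataInputOutputDesc multiplies the shape's element count by the datatype length and compares the result with the byte size recorded on the tensor. For example, a FLOAT (4-byte) tensor of shape [2, 3] must carry 2 * 3 * 4 = 24 bytes, while a rank-0 scalar (element count 0, dim num 0) falls back to a single element's length. A hedged arithmetic sketch of that rule (helper name is illustrative):

#include <cstddef>
#include <cstdint>

// Hedged sketch of the expected-byte-size rule used above; the real code
// guards the multiplication with FMK_INT64_UINT32_MULCHECK.
int64_t ExpectedByteSize(int64_t shape_elements, uint32_t dtype_len, std::size_t dim_num) {
  int64_t byte_size = shape_elements * static_cast<int64_t>(dtype_len);
  if (byte_size == 0 && dim_num == 0) {
    byte_size = static_cast<int64_t>(dtype_len);  // rank-0 scalar still holds one element
  }
  return byte_size;
}
// ExpectedByteSize(6, 4, 2) == 24; ExpectedByteSize(0, 4, 0) == 4 (scalar).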
Status GraphPrepare::UpdateInput(const std::vector<GeTensor> &user_input, | |||
const std::map<string, string> &graph_option) { | |||
// Get shape range of input in dynamic_execute mode | |||
@@ -1471,63 +1529,18 @@ Status GraphPrepare::UpdateInput(const std::vector<GeTensor> &user_input, | |||
} | |||
GeTensorDesc desc(user_input[index].GetTensorDesc()); | |||
// data maybe internal format [FRACTAL_NZ] at singleop process such as GEMM. | |||
auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER); | |||
ret = CheckInternalFormat(input_node, desc, tune_flag); | |||
ret = CheckInternalFormat(input_node, desc); | |||
if (ret != SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "[Check][InternalFormat] on %s failed", op->GetName().c_str()); | |||
return ret; | |||
} | |||
auto data_type = desc.GetDataType(); | |||
uint32_t length = 1; | |||
bool type_ret = TypeUtils::GetDataTypeLength(data_type, length); | |||
if (!type_ret) { | |||
std::string reason = "Input datatype[" + TypeUtils::DataTypeToSerialString(data_type) + "] of index:" + | |||
std::to_string(index) + " input tensor is not support"; | |||
REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"reason"}), std::vector<std::string>({reason})); | |||
GELOGE(PARAM_INVALID, "[Check][Param] Input datatype %s is not support.", | |||
TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||
return FAILED; | |||
} | |||
int64_t desc_shape = desc.GetShape().GetShapeSize(); | |||
FMK_INT64_UINT32_MULCHECK(desc_shape, length); | |||
int64_t shape_size = desc_shape * length; | |||
GE_IF_BOOL_EXEC(shape_size == 0 && desc.GetShape().GetDimNum() == 0, shape_size = static_cast<int64_t>(length)); | |||
int64_t size = 0; | |||
GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc, size) != GRAPH_SUCCESS, | |||
REPORT_CALL_ERROR("E19999", "Get size of user input tensor failed, index:%ld", index); | |||
GELOGE(INTERNAL_ERROR, "[Get][Size] of user input tensor failed, index:%ld", index); | |||
return FAILED); | |||
bool size_check = (size != 0 && shape_size != size); | |||
if (size_check) { | |||
std::string reason = "input tensor[index:" + std::to_string(index) + "]'s data size[" + std::to_string(size) + | |||
"] != shape_size[" + std::to_string(size) + "], check invalid"; | |||
REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"reason"}), std::vector<std::string>({reason})); | |||
GELOGE(PARAM_INVALID, "[Check][Param] input data size = %ld, shape_size = %ld.", size, shape_size); | |||
return FAILED; | |||
} | |||
ge::TensorUtils::SetSize(desc, shape_size); | |||
if (!tune_flag) { | |||
graphStatus graph_ret = op->UpdateInputDesc(0, desc); | |||
if (graph_ret != GRAPH_SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Update input desc of op:%s(%s) failed, index:0", | |||
op->GetName().c_str(), op->GetType().c_str()); | |||
GELOGE(graph_ret, "[Update][InputDesc] of op:%s(%s) failed, index:0", | |||
op->GetName().c_str(), op->GetType().c_str()); | |||
return graph_ret; | |||
} | |||
// Size will be recalculated in the build stage | |||
ge::TensorUtils::SetSize(desc, 0); | |||
graph_ret = op->UpdateOutputDesc(0, desc); | |||
if (graph_ret != GRAPH_SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Update output desc of op:%s(%s) failed, index:0", | |||
op->GetName().c_str(), op->GetType().c_str()); | |||
GELOGE(graph_ret, "[Update][OutputDesc] of op:%s(%s) failed, index:0", | |||
op->GetName().c_str(), op->GetType().c_str()); | |||
return graph_ret; | |||
} | |||
} else { | |||
GELOGI("data %s skip update info in tune mode", op->GetName().c_str()); | |||
ret = UpdateDataInputOutputDesc(index, op, desc); | |||
if (ret != SUCCESS) { | |||
GELOGE(FAILED, "[Update][DataInputOutputDesc] on %s failed", op->GetName().c_str()); | |||
return ret; | |||
} | |||
if (!dynamic_shape_range_vec.empty()) { | |||
ret = UpdateDynamicInputShapeRange(index, dynamic_shape_range_vec, op, desc); | |||
GE_CHK_STATUS_RET(ret, "[Update][DynamicInputShapeRange] on %s failed.", op->GetName().c_str()); | |||
@@ -63,7 +63,8 @@ class GraphPrepare { | |||
Status CheckRefOp(); | |||
Status SetRtContext(rtContext_t rt_context, rtCtxMode_t mode); | |||
Status AdjustDataOpOutput(const NodePtr &node); | |||
Status CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc, bool tune_flag); | |||
Status CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc); | |||
Status UpdateDataInputOutputDesc(GeAttrValue::INT index, OpDescPtr &op, GeTensorDesc &desc); | |||
Status UpdateInput(const std::vector<GeTensor> &user_input, const std::map<string, string> &graph_option); | |||
Status CheckAndUpdateInput(const std::vector<GeTensor> &user_input, const std::map<string, string> &graph_option); | |||
Status CheckConstOp(); | |||
@@ -568,6 +568,7 @@ Status InsertNewOpUtil::GetDataRelatedNode(NodePtr &node, std::map<NodePtr, std: | |||
} | |||
std::unique_ptr<domi::AippOpParams> aipp_params(new (std::nothrow) domi::AippOpParams()); | |||
GE_CHECK_NOTNULL(aipp_params); | |||
ge::GeAttrValue::NAMED_ATTRS aipp_attr; | |||
GE_CHK_BOOL_RET_STATUS(AttrUtils::GetNamedAttrs(data_op, ATTR_NAME_AIPP, aipp_attr), ACL_ERROR_GE_AIPP_NOT_EXIST, | |||
"[Get][Attr] %s from op:%s failed", ATTR_NAME_AIPP.c_str(), data_op->GetName().c_str()); | |||
@@ -1206,7 +1206,7 @@ Status MultiBatchGraphCopyer::CheckCopyResult(const std::vector<NodePtr> &start_ | |||
auto dims = NodeUtils::GetOutputDesc(*node, kDataOutIndex).GetShape().GetDims(); | |||
if (!IsAllDimsPositive(dims)) { | |||
REPORT_CALL_ERROR("E19999", "Failed to copy multi batch graph, the node %s still has unknown shape %s", | |||
node->GetName().c_str(), formats::ShapeToString(dims).c_str()); | |||
node->GetName().c_str(), formats::ShapeToString(dims).c_str()); | |||
GELOGE(INTERNAL_ERROR, "[Check][Param] Failed to copy multi batch graph, the node %s still has unknown shape %s", | |||
node->GetName().c_str(), formats::ShapeToString(dims).c_str()); | |||
return INTERNAL_ERROR; | |||
@@ -45,6 +45,7 @@ Status FillKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge | |||
GELOGE(PARAM_INVALID, "Parameter's invalid, Input opDescPtr is nullptr."); | |||
return PARAM_INVALID; | |||
} | |||
GELOGD("FillKernel in, name: %s.", op_desc_ptr->GetName().c_str()); | |||
GE_CHECK_NOTNULL(input.at(kFillDimsInputIndex)); | |||
GE_CHECK_NOTNULL(input.at(kFillDataInputIndex)); | |||
@@ -57,6 +58,13 @@ Status FillKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge | |||
return NOT_CHANGED; | |||
} | |||
auto output_desc = op_desc_ptr->GetOutputDescPtr(0); | |||
GE_CHECK_NOTNULL(output_desc); | |||
if (output_desc->GetShape().IsUnknownShape()) { | |||
GELOGD("Output is unknown shape, [%s] skip FillKernel.", op_desc_ptr->GetName().c_str()); | |||
return NOT_CHANGED; | |||
} | |||
GeTensorPtr output_ptr; | |||
output_ptr = MakeShared<GeTensor>(op_desc_ptr->GetOutputDesc(0)); | |||
if (output_ptr == nullptr) { | |||
@@ -297,13 +297,15 @@ Status HybridModelAsyncExecutor::PrepareInputs(const InputData ¤t_data, Hy | |||
} | |||
} | |||
tensor_desc->SetShape(shape); | |||
args.input_desc[input_index] = tensor_desc; | |||
GELOGD("Update shape of input[%zu] to [%s]", input_index, tensor_desc->MutableShape().ToString().c_str()); | |||
GELOGD("Update shape[%s] of input[%zu] to [%s]", | |||
shape.ToString().c_str(), input_index, tensor_desc->MutableShape().ToString().c_str()); | |||
GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorMemorySizeInBytes(*tensor_desc, tensor_size), | |||
"[Invoke][GetTensorMemorySizeInBytes]Failed to calc tensor size," | |||
"index = %zu, shape = [%s], model_id = %u.", | |||
input_index, tensor_desc->GetShape().ToString().c_str(), model_id_); | |||
GELOGD("Input tensor[%zu] size = %zu", input_index, tensor_size); | |||
GELOGD("Input tensor[%zu] size = %ld", input_index, tensor_size); | |||
TensorUtils::SetSize(*tensor_desc, tensor_size); | |||
args.input_desc[input_index] = tensor_desc; | |||
} | |||
GE_CHECK_GE(tensor_size, 0); | |||
@@ -41,6 +41,8 @@ HybridModelExecutor::~HybridModelExecutor() { | |||
Status HybridModelExecutor::Init() { | |||
GELOGD("Start to init HybridGraphEngine."); | |||
GE_CHK_STATUS_RET_NOLOG(InitExecutionContext()); | |||
root_graph_executor_.reset(new (std::nothrow) SubgraphExecutor(model_->GetRootGraphItem(), &context_)); | |||
GE_CHECK_NOTNULL(root_graph_executor_); | |||
GELOGD("HybridGraphEngine initialized successfully."); | |||
return SUCCESS; | |||
} | |||
@@ -60,8 +62,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { | |||
GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration, | |||
sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream)); | |||
} | |||
SubgraphExecutor executor(model_->GetRootGraphItem(), &context_); | |||
auto ret = ExecuteGraphInternal(executor, args); | |||
auto ret = ExecuteGraphInternal(args); | |||
Cleanup(); | |||
RECORD_MODEL_EXECUTION_EVENT(&context_, "[Cleanup] End"); | |||
GELOGD("Model executed successfully."); | |||
@@ -69,6 +70,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { | |||
context_.profiler->Dump(std::cout); | |||
context_.profiler->Reset(); | |||
} | |||
root_graph_executor_->ReleaseContext(); | |||
context_.iteration += 1; | |||
if (ret == END_OF_SEQUENCE) { | |||
@@ -79,8 +81,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { | |||
return SUCCESS; | |||
} | |||
Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||
HybridModelExecutor::ExecuteArgs &args) { | |||
Status HybridModelExecutor::ExecuteGraphInternal(HybridModelExecutor::ExecuteArgs &args) { | |||
RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] Start"); | |||
GE_CHK_STATUS_RET_NOLOG(ResetExecutionContext(context_)); | |||
RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] End"); | |||
@@ -94,7 +95,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||
GE_CHK_STATUS_RET_NOLOG(prof_mgr.ProfileStepInfo(index_id, model_id, 0, stream_, device_id)); | |||
} | |||
HYBRID_CHK_STATUS_RET(executor.ExecuteAsync(args.inputs, args.input_desc, args.outputs), | |||
HYBRID_CHK_STATUS_RET(root_graph_executor_->ExecuteAsync(args.inputs, args.input_desc, args.outputs), | |||
"Failed to execute partitioned call."); | |||
RECORD_MODEL_EXECUTION_EVENT(&context_, "[ExecuteAsync] End"); | |||
@@ -103,7 +104,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||
} | |||
if (!model_->IsSingleOp()) { | |||
Status ret = executor.Synchronize(); | |||
Status ret = root_graph_executor_->Synchronize(); | |||
if (ret != ge::SUCCESS) { | |||
auto model_manager = ModelManager::GetInstance(); | |||
GE_CHECK_NOTNULL(model_manager); | |||
@@ -123,7 +124,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||
} | |||
args.outputs.clear(); | |||
HYBRID_CHK_STATUS_RET(executor.GetOutputs(args.outputs, args.output_desc), "Failed to get outputs"); | |||
HYBRID_CHK_STATUS_RET(root_graph_executor_->GetOutputs(args.outputs, args.output_desc), "Failed to get outputs"); | |||
RECORD_MODEL_EXECUTION_EVENT(&context_, "[GetOutput] End"); | |||
return SUCCESS; | |||
} | |||
@@ -48,7 +48,7 @@ class HybridModelExecutor { | |||
Status Execute(ExecuteArgs &args); | |||
private: | |||
Status ExecuteGraphInternal(SubgraphExecutor &executor, ExecuteArgs &args); | |||
Status ExecuteGraphInternal(ExecuteArgs &args); | |||
Status Cleanup(); | |||
Status InitExecutionContext(); | |||
static Status ResetExecutionContext(GraphExecutionContext &context); | |||
@@ -58,6 +58,7 @@ class HybridModelExecutor { | |||
uint32_t device_id_; | |||
rtStream_t stream_; | |||
GraphExecutionContext context_; | |||
std::unique_ptr<SubgraphExecutor> root_graph_executor_; | |||
}; | |||
} // namespace hybrid | |||
} // namespace ge | |||
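Functionally, this hunk replaces the stack-local SubgraphExecutor built on every Execute() call with a root_graph_executor_ owned by the class and created once in Init(); per-iteration state is dropped via ReleaseContext() instead of destroying and rebuilding the whole executor. A hedged sketch of the reuse pattern (SubgraphExecutor's real interface is much wider than shown):

#include <memory>

struct SubgraphExecutorSketch {
  void ReleaseContext() { /* drop per-iteration state only */ }
};

class ExecutorOwner {
 public:
  bool Init() {
    exec_.reset(new (std::nothrow) SubgraphExecutorSketch());
    return exec_ != nullptr;   // GE_CHECK_NOTNULL equivalent
  }
  void Execute() {
    // ... run the graph with *exec_ ...
    exec_->ReleaseContext();   // reuse the executor instead of per-call construction
  }
 private:
  std::unique_ptr<SubgraphExecutorSketch> exec_;
};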
@@ -19,8 +19,9 @@ | |||
#include "framework/common/debug/log.h" | |||
#include "graph/compute_graph.h" | |||
#include "graph/utils/tensor_utils.h" | |||
#include "hybrid_execution_context.h" | |||
#include "subgraph_context.h" | |||
#include "hybrid/executor/hybrid_execution_context.h" | |||
#include "hybrid/executor/subgraph_context.h" | |||
#include "hybrid/node_executor/task_context.h" | |||
#define INC_ITERATION_COUNT(iteration) \ | |||
do { \ | |||
@@ -260,6 +261,16 @@ NodeState::NodeState(const NodeItem &node_item, SubgraphContext *subgraph_contex | |||
this->op_desc_ = node_item.node->GetOpDesc(); | |||
} | |||
Status NodeState::Init(int group, const shared_ptr<FrameState> &frame_state) { | |||
GE_CHECK_NOTNULL(frame_state); | |||
group_ = group; | |||
frame_state_ = frame_state; | |||
auto unique_task_context = TaskContext::Create(this, subgraph_context_); | |||
GE_CHECK_NOTNULL(unique_task_context); | |||
task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
return SUCCESS; | |||
} | |||
Status NodeState::AwaitInputTensors(GraphExecutionContext &context) const { | |||
if (node_item_->IsMergeOp()) { | |||
GELOGD("[%s] merge index %d, input nodes: %zu", GetName().c_str(), merge_index_, node_item_->data_recv_.size()); | |||
@@ -314,15 +325,75 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() { | |||
return task_context_; | |||
} | |||
void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) { | |||
const auto is_persist_tensor = [](const std::map<const NodeItem *, std::set<int>> &items, int idx) { | |||
const auto is_exist = [&idx](const std::pair<const NodeItem *, std::set<int>> &items) { | |||
return items.second.count(idx) > 0; | |||
}; | |||
return std::any_of(items.begin(), items.end(), is_exist); | |||
}; | |||
if (root_tensor_values_.count(input_idx) > 0) { | |||
return; | |||
} | |||
if (is_persist_tensor(node_item_->root_data_, input_idx)) { | |||
GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx); | |||
root_tensor_values_[input_idx] = tensor; | |||
} else if (is_persist_tensor(node_item_->enter_data_, input_idx)) { | |||
GELOGD("[%s] Save Enter input tensor: %d", GetName().c_str(), input_idx); | |||
root_tensor_values_[input_idx] = tensor; | |||
} | |||
} | |||
void NodeState::UpdatePersistTensor() { | |||
const auto update_tensor = [&](const std::map<const NodeItem *, std::set<int>> &items) { | |||
for (const auto &item : items) { | |||
for (const auto idx : item.second) { | |||
UpdatePersistTensor(idx); | |||
} | |||
} | |||
}; | |||
if (root_tensor_values_.empty()) { | |||
return; | |||
} | |||
update_tensor(node_item_->root_data_); | |||
if (iteration_count_ > 0) { | |||
update_tensor(node_item_->enter_data_); | |||
} | |||
} | |||
void NodeState::UpdatePersistTensor(int input_idx) { | |||
const auto it = root_tensor_values_.find(input_idx); | |||
if (it == root_tensor_values_.end()) { | |||
GELOGW("[%s] Not found saved tensor: %d", GetName().c_str(), input_idx); | |||
return; | |||
} | |||
auto tensor = task_context_->MutableInput(input_idx); | |||
if (tensor == nullptr) { | |||
GELOGW("[%s] Not found input tensor: %d", GetName().c_str(), input_idx); | |||
return; | |||
} | |||
*tensor = it->second; | |||
GELOGD("[%s] Update input tensor: %d", GetName().c_str(), input_idx); | |||
} | |||
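SavePersistTensor records an input only the first time it is seen (the early return on root_tensor_values_.count), and UpdatePersistTensor writes root inputs back on every iteration but Enter inputs only from the second iteration onward (iteration_count_ > 0). A condensed sketch of that save-once/replay policy, with a plain int standing in for TensorValue:

#include <map>

// Hedged sketch of the persist-tensor bookkeeping above.
struct PersistInputs {
  std::map<int, int> saved;  // input index -> first value seen

  void Save(int idx, int value) {
    if (saved.count(idx) > 0) return;  // only the first occurrence persists
    saved[idx] = value;
  }
  bool Replay(int idx, int &slot) const {
    auto it = saved.find(idx);
    if (it == saved.end()) return false;
    slot = it->second;                 // overwrite the live input with the saved one
    return true;
  }
};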
void NodeState::ResetContext(uint64_t iteration) { | |||
switch_index_ = -1; | |||
subgraph_context_->ResetContext(node_item_->node); | |||
if (iteration == 0) { | |||
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | |||
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | |||
} else { | |||
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size() + node_item_->enter_data_.size()); | |||
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size() + node_item_->enter_ctrl_.size()); | |||
auto unique_task_context = TaskContext::Create(this, subgraph_context_); | |||
GE_CHECK_NOTNULL_JUST_RETURN(unique_task_context); | |||
task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | |||
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | |||
if (iteration > 0) { | |||
data_scheduled_ += static_cast<uint32_t>(node_item_->enter_data_.size()); | |||
ctrl_scheduled_ += static_cast<uint32_t>(node_item_->enter_ctrl_.size()); | |||
} | |||
iteration_count_ = iteration; | |||
@@ -100,6 +100,8 @@ struct NodeState { | |||
NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context); | |||
~NodeState() = default; | |||
Status Init(int group, const shared_ptr<FrameState> &frame_state); | |||
OpDesc *GetOpDesc() const { | |||
return op_desc_.get(); | |||
} | |||
@@ -129,6 +131,9 @@ struct NodeState { | |||
void RunStreamActive(); | |||
void RunNextIteration(); | |||
void SavePersistTensor(int input_idx, const TensorValue &tensor); | |||
void UpdatePersistTensor(); | |||
Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | |||
void SetScheduleFuture(std::future<Status> &&future); | |||
@@ -150,18 +155,10 @@ struct NodeState { | |||
return merge_index_; | |||
} | |||
void SetGroup(int group) { | |||
group_ = group; | |||
} | |||
int GetGroup() const { | |||
return group_; | |||
} | |||
void SetFrameState(const shared_ptr<FrameState> &frame_state) { | |||
frame_state_ = frame_state; | |||
} | |||
const shared_ptr<NodeTask> &GetKernelTask() const { | |||
return kernel_task_; | |||
} | |||
@@ -181,12 +178,17 @@ struct NodeState { | |||
void SetTaskContext(std::shared_ptr<TaskContext> &task_context); | |||
std::shared_ptr<TaskContext> GetTaskContext(); | |||
void SetSkipInferShape(bool skip_infershape) { skip_infershape_ = skip_infershape; } | |||
bool MaySkipShapeInference() const { return skip_infershape_; } | |||
private: | |||
bool IsScheduleReady() const; | |||
void SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | |||
void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | |||
void ResetContext(uint64_t iteration); | |||
void ScheduleContext(const NodeState &node_state); | |||
void UpdatePersistTensor(int input_idx); | |||
const NodeItem *node_item_ = nullptr; | |||
std::shared_ptr<NodeTask> kernel_task_ = nullptr; | |||
@@ -199,6 +201,7 @@ struct NodeState { | |||
std::future<Status> schedule_future_; | |||
std::shared_ptr<FrameState> frame_state_; | |||
std::map<int, TensorValue> root_tensor_values_; | |||
uint64_t active_count_ = 0; | |||
uint64_t iteration_count_ = 0; | |||
uint32_t ctrl_scheduled_ = 0; | |||
@@ -206,6 +209,7 @@ struct NodeState {
   int merge_index_ = -1;   // Use for Execute (Reset after Executed).
   int switch_index_ = -1;  // Use for Schedule (Reset after Prepared).
   int group_ = -1;
+  bool skip_infershape_ = false;
 };
 }  // namespace hybrid
 }  // namespace ge
@@ -19,7 +19,7 @@
 namespace ge {
 namespace hybrid {
-SubgraphContext::SubgraphContext(const GraphItem *graph_item, const GraphExecutionContext *execution_context)
+SubgraphContext::SubgraphContext(const GraphItem *graph_item, GraphExecutionContext *execution_context)
     : graph_item_(graph_item), execution_context_(execution_context) {
 }
@@ -79,20 +79,31 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) {
     return nullptr;
   }
+  return CreateNodeState(node_item);
+}
+
+NodeStatePtr SubgraphContext::CreateNodeState(const NodeItem *node_item) {
   GELOGD("[%s] lock for write", node_item->NodeName().c_str());
   if (mmRWLockWRLock(&rw_lock_) != EN_OK) {
     REPORT_CALL_ERROR("E19999", "[Node:%s] Lock for write failed", node_item->NodeName().c_str());
     GELOGE(INTERNAL_ERROR, "[RWLock][Lock][Node:%s] Lock for write failed", node_item->NodeName().c_str());
     return nullptr;
   }

   auto &node_state = node_states_[node_item];
-  if (node_state == nullptr) {
-    const auto &guard = node_item->MutexGuard("GetOrCreateNodeState");
-    node_state.reset(new(std::nothrow)NodeState(*node_item, this));
-    node_state->SetFrameState(GetOrCreateFrameState(*node_item));
-    node_state->SetGroup(group_);
-    (void)guard;
-  }
+  do {
+    if (node_state == nullptr) {
+      const auto &guard = node_item->MutexGuard("GetOrCreateNodeState");
+      node_state.reset(new(std::nothrow)NodeState(*node_item, this));
+      if (node_state == nullptr || node_state->Init(group_, GetOrCreateFrameState(*node_item)) != SUCCESS) {
+        GELOGE(INTERNAL_ERROR, "[Create][NodeState] failed for[%s].", node_item->NodeName().c_str());
+        REPORT_CALL_ERROR("E19999", "Create NodeState failed for %s.", node_item->NodeName().c_str());
+        break;
+      }
+      (void)guard;
+    }
+  } while (0);

   GELOGD("[%s] unlock for write", node_item->NodeName().c_str());
   if (mmWRLockUnLock(&rw_lock_) != EN_OK) {
     REPORT_CALL_ERROR("E19999", "[Node:%s] Unlock for write failed", node_item->NodeName().c_str());
@@ -30,7 +30,7 @@ namespace ge {
 namespace hybrid {
 class SubgraphContext {
  public:
-  explicit SubgraphContext(const GraphItem *graph_item, const GraphExecutionContext *execution_context);
+  explicit SubgraphContext(const GraphItem *graph_item, GraphExecutionContext *execution_context);
   ~SubgraphContext();

   Status Init();
@@ -51,10 +51,11 @@ class SubgraphContext {
   void NodeDone(const NodePtr &node);

 private:
+  NodeStatePtr CreateNodeState(const NodeItem *node_item);
   FrameStatePtr GetOrCreateFrameState(const NodeItem &node_item);  // no lock
   friend class TaskContext;
   const GraphItem *graph_item_;
-  const GraphExecutionContext *execution_context_;
+  GraphExecutionContext *execution_context_;
   mmRWLock_t rw_lock_;
   std::vector<TensorValue> all_inputs_;
   std::vector<TensorValue> all_outputs_;
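Note: dropping `const` here is what allows `TaskContext::Create(this, subgraph_context_)` in `NodeState::ResetContext` above to hand executors a mutable execution context. A minimal sketch of why the stored pointer type must change, with toy `Context`/`Runner` types (not the GE ones):

```cpp
struct Context {
  int step = 0;
};

void Advance(Context *ctx) {  // a mutating API needs a non-const pointer
  ++ctx->step;
}

class Runner {
 public:
  explicit Runner(Context *ctx) : ctx_(ctx) {}
  void Run() {
    // With `const Context *ctx_;` this call would not compile:
    // a `const Context *` cannot be passed where `Context *` is required.
    Advance(ctx_);
  }

 private:
  Context *ctx_;  // was `const Context *` before the change
};

int main() {
  Context ctx;
  Runner runner(&ctx);
  runner.Run();
  return ctx.step == 1 ? 0 : 1;
}
```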
@@ -103,6 +103,14 @@ Status SubgraphExecutor::InitInputsForUnknownShape(const std::vector<TensorValue
     auto node_state = subgraph_context_->GetOrCreateNodeState(input_node);
     GE_CHECK_NOTNULL(node_state);
     node_state->GetShapeInferenceState().UpdateInputShape(0, *tensor_desc);
+    auto op_desc = input_node->GetOpDesc();
+    GE_CHECK_NOTNULL(op_desc);
+    auto output_desc = op_desc->MutableOutputDesc(kDataInputIndex);
+    GE_CHECK_NOTNULL(output_desc);
+    output_desc->SetShape(tensor_desc->GetShape());
+    output_desc->SetOriginShape(tensor_desc->GetOriginShape());
+    output_desc->SetDataType(tensor_desc->GetDataType());
+    node_state->SetSkipInferShape(true);
   }
 }
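Note: a Data node only forwards its input tensor, so once the caller-provided shape and dtype have been copied onto the node's output desc, re-running shape inference for it is redundant; `SetSkipInferShape(true)` records exactly that. A minimal sketch of the forwarding step with toy descriptor types (not the real GE `GeTensorDesc`):

```cpp
#include <cstdint>
#include <vector>

// Toy stand-ins for tensor descriptors.
struct TensorDesc {
  std::vector<int64_t> shape;
  int data_type = 0;
};

struct DataNode {
  TensorDesc output_desc;      // what downstream nodes read
  bool skip_infershape = false;
};

// Copy the fed-in descriptor onto the Data node's output and
// mark inference as skippable: the node has nothing left to infer.
void InitInput(DataNode &node, const TensorDesc &input_desc) {
  node.output_desc.shape = input_desc.shape;
  node.output_desc.data_type = input_desc.data_type;
  node.skip_infershape = true;
}

int main() {
  DataNode node;
  TensorDesc fed{{8, 3, 224, 224}, /*data_type=*/1};
  InitInput(node, fed);
  return (node.skip_infershape && node.output_desc.shape.size() == 4) ? 0 : 1;
}
```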
@@ -175,16 +183,12 @@ Status SubgraphExecutor::ExecuteAsyncForKnownShape(const std::vector<TensorValue
   GE_CHECK_NOTNULL(node_state);
   node_state->SetKernelTask(node_item->kernel_task);

-  known_shape_task_context_ = TaskContext::Create(node_state.get(), context_, subgraph_context_.get());
-  GE_CHECK_NOTNULL(known_shape_task_context_);
-  node_state->SetTaskContext(known_shape_task_context_);

   std::function<void()> callback;
   GE_CHK_STATUS_RET_NOLOG(InitCallback(node_state.get(), callback));
-  HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, known_shape_task_context_, *context_, callback),
+  HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, node_state->GetTaskContext(), *context_, callback),
                         "[%s] Failed to execute node [%s] for known subgraph.",
                         graph_item_->GetName().c_str(),
-                        known_shape_task_context_->GetNodeName());
+                        node_state->GetName().c_str());

   GELOGD("[%s] Done execute non-dynamic subgraph successfully.", graph_item_->GetName().c_str());
   return SUCCESS;
@@ -271,16 +275,12 @@ Status SubgraphExecutor::PrepareNode(const NodeItem &node_item, int group) {
   } else {
     node_state->SetKernelTask(node_item.kernel_task);
   }
-  auto unique_task_context = TaskContext::Create(node_state.get(), context_, subgraph_context_.get());
-  GE_CHECK_NOTNULL(unique_task_context);
   const auto &task = node_state->GetKernelTask();
   if (task == nullptr) {
     GELOGE(INTERNAL_ERROR, "[Get][KernelTask] failed for[%s], NodeTask is null.", node_state->GetName().c_str());
     REPORT_CALL_ERROR("E19999", "GetKernelTask failed for %s, nodetask is null.", node_state->GetName().c_str());
     return INTERNAL_ERROR;
   }
-  auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
-  node_state->SetTaskContext(shared_task_context);
   GE_CHK_STATUS_RET_NOLOG(NodeEnqueue(p_node_state));
   return AfterPrepared(p_node_state);
 }
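Note: after this change a node's `TaskContext` is no longer built ad hoc at each call site; `NodeState::ResetContext` (first hunk above) creates it once per iteration and every other path reads it through `GetTaskContext()`. A minimal sketch of centralizing a lazily owned context in the state object, with toy types (not the GE API):

```cpp
#include <memory>

struct TaskContext {
  int iteration = 0;
};

class NodeState {
 public:
  // Single creation point: rebuild the context when the iteration resets.
  void ResetContext(int iteration) {
    task_context_ = std::make_shared<TaskContext>();
    task_context_->iteration = iteration;
  }
  // All call sites share the one owned instance.
  const std::shared_ptr<TaskContext> &GetTaskContext() const { return task_context_; }

 private:
  std::shared_ptr<TaskContext> task_context_;
};

int main() {
  NodeState state;
  state.ResetContext(3);
  auto ctx = state.GetTaskContext();       // prepare path
  auto same_ctx = state.GetTaskContext();  // execute path sees the same object
  return (ctx == same_ctx && ctx->iteration == 3) ? 0 : 1;
}
```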
@@ -480,19 +480,15 @@ Status SubgraphExecutor::PrepareForExecution(GraphExecutionContext *ctx, NodeSta
   } else {
     node_state.SetKernelTask(node_item.kernel_task);
   }
-  auto unique_task_context = TaskContext::Create(&node_state, context_, subgraph_context_.get());
-  GE_CHECK_NOTNULL(unique_task_context);
   const auto &task = node_state.GetKernelTask();
   if (task == nullptr) {
     GELOGE(INTERNAL_ERROR, "[Invoke][GetKernelTask] failed for[%s], NodeTask is null.", node_state.GetName().c_str());
     REPORT_CALL_ERROR("E19999", "invoke GetKernelTask failed for %s, NodeTask is null.", node_state.GetName().c_str());
     return INTERNAL_ERROR;
   }
-  auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
-  node_state.SetTaskContext(shared_task_context);
   GE_CHK_RT_RET(rtCtxSetCurrent(ctx->rt_context));

   RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] start");
-  GE_CHK_STATUS_RET_NOLOG(task->UpdateTilingData(*shared_task_context));  // update op_desc before alloc ws
+  GE_CHK_STATUS_RET_NOLOG(task->UpdateTilingData(*node_state.GetTaskContext()));  // update op_desc before alloc ws
   RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] end");
   return SUCCESS;
 }
@@ -41,6 +41,8 @@ class SubgraphExecutor {
+  Status PartialExecuteAsync(int task_group);
+  void ReleaseContext() { subgraph_context_.reset(nullptr); }

   /**
    * Execute subgraph async, output tensor address(not data) and output tensor descriptions are
    * valid after this method returned
@@ -125,7 +127,6 @@ class SubgraphExecutor {
   ThreadPool pre_run_pool_;
   BlockingQueue<NodeState *> ready_queue_;
   std::unique_ptr<ShapeInferenceEngine> shape_inference_engine_;
-  std::shared_ptr<TaskContext> known_shape_task_context_;

   std::mutex mu_;  // Guard for prepare_queues_.
   std::map<int, BlockingQueue<const NodeItem *>> prepare_queues_;
@@ -373,6 +373,7 @@ Status ExecutionEngine::DoExecuteAsync(NodeState &node_state,
   auto executor = node_item.node_executor;
   GE_CHECK_NOTNULL(executor);
   RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] Start");
+  node_state.UpdatePersistTensor();
   GE_CHK_STATUS_RET(executor->PrepareTask(*task, task_context), "[Prepare][Task] for [%s] failed.",
                     node_state.GetName().c_str());
   RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] End");