You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_blocklstm_pb.gen.py 1.7 kB

1234567891011121314151617181920212223242526272829303132333435363738
  1. #!/usr/bin/env python3
  2. # -*- coding utf-8 -*-
  3. # Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.
  4. import tensorflow as tf
  5. import os
  6. pb_file_path = os.getcwd()
  7. def generate_case_0():
  8. with tf.compat.v1.Session(graph=tf.Graph()) as sess:
  9. input_dtype = tf.float32
  10. input_shape0 = [1, ]
  11. input_shape1 = [202, 1, 768]
  12. input_shape2 = [1, 1]
  13. input_shape3 = [1, 1]
  14. input_shape4 = [769, 4]
  15. input_shape5 = [1, ]
  16. input_shape6 = [1, ]
  17. input_shape7 = [1, ]
  18. input_shape8 = [4, ]
  19. d0 = tf.compat.v1.placeholder(dtype=tf.int64, shape=input_shape0)
  20. d1 = tf.compat.v1.placeholder(dtype=input_dtype, shape=input_shape1)
  21. d2 = tf.compat.v1.placeholder(dtype=input_dtype, shape=input_shape2)
  22. d3 = tf.compat.v1.placeholder(dtype=input_dtype, shape=input_shape3)
  23. d4 = tf.compat.v1.placeholder(dtype=input_dtype, shape=input_shape4)
  24. d5 = tf.compat.v1.placeholder(dtype=input_dtype, shape=input_shape5)
  25. d6 = tf.compat.v1.placeholder(dtype=input_dtype, shape=input_shape6)
  26. d7 = tf.compat.v1.placeholder(dtype=input_dtype, shape=input_shape7)
  27. d8 = tf.compat.v1.placeholder(dtype=input_dtype, shape=input_shape8)
  28. i1, cs1, f1, o1, ci1, co1, h1 = tf.raw_ops.BlockLSTM(seq_len_max=d0, x=d1, cs_prev=d2, h_prev=d3, w=d4, wci=d5, wcf=d6, wco=d7, b=d8,
  29. forget_bias=1, cell_clip=3, use_peephole=False, name="blockLSTM")
  30. tf.io.write_graph(sess.graph, logdir="./", name="blocklstm_case.pb", as_text=False)
  31. if __name__=='__main__':
  32. generate_case_0()