|
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.
-
import os

import tensorflow as tf
-
- pb_file_path = os.getcwd()
-
def generate_case_0():
    """Build a graph containing a single tf.raw_ops.BlockLSTM op and save it.

    Creates nine placeholders feeding one BlockLSTM node, then serializes
    the (unevaluated) graph as a binary GraphDef to ./blocklstm_case.pb.
    The op's outputs are never run — only the graph structure matters here.
    """
    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
        input_dtype = tf.float32
        # Placeholder shapes for the nine BlockLSTM inputs. Presumably
        # x is (max_time=202, batch=1, input_size=768) and w is
        # (input_size + cell_size, 4 * cell_size) with cell_size=1 —
        # TODO confirm against the BlockLSTM op specification.
        input_shape0 = [1, ]          # seq_len_max
        input_shape1 = [202, 1, 768]  # x
        input_shape2 = [1, 1]         # cs_prev
        input_shape3 = [1, 1]         # h_prev
        input_shape4 = [769, 4]       # w
        input_shape5 = [1, ]          # wci
        input_shape6 = [1, ]          # wcf
        input_shape7 = [1, ]          # wco
        input_shape8 = [4, ]          # b

        d0 = tf.compat.v1.placeholder(dtype=tf.int64, shape=input_shape0)
        d1 = tf.compat.v1.placeholder(dtype=input_dtype, shape=input_shape1)
        d2 = tf.compat.v1.placeholder(dtype=input_dtype, shape=input_shape2)
        d3 = tf.compat.v1.placeholder(dtype=input_dtype, shape=input_shape3)
        d4 = tf.compat.v1.placeholder(dtype=input_dtype, shape=input_shape4)
        d5 = tf.compat.v1.placeholder(dtype=input_dtype, shape=input_shape5)
        d6 = tf.compat.v1.placeholder(dtype=input_dtype, shape=input_shape6)
        d7 = tf.compat.v1.placeholder(dtype=input_dtype, shape=input_shape7)
        d8 = tf.compat.v1.placeholder(dtype=input_dtype, shape=input_shape8)

        # The op's seven outputs (i, cs, f, o, ci, co, h) are intentionally
        # discarded: building the node is enough to put it in the graph.
        tf.raw_ops.BlockLSTM(
            seq_len_max=d0, x=d1, cs_prev=d2, h_prev=d3, w=d4,
            wci=d5, wcf=d6, wco=d7, b=d8,
            forget_bias=1, cell_clip=3, use_peephole=False, name="blockLSTM")
        tf.io.write_graph(sess.graph, logdir="./", name="blocklstm_case.pb", as_text=False)
-
# Script entry point: regenerate the BlockLSTM test graph when run directly.
if __name__ == "__main__":
    generate_case_0()
|