博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
tensorflow lstm 实现 RNN / LSTM 的关键几个步骤 多层通俗易懂
阅读量:4145 次
发布时间:2019-05-25

本文共 2914 字,大约阅读时间需要 9 分钟。

 

其中该博客的展开的RNN图片有错误,左侧的原始状态输入应该没有纵向的数据传递连线

关于初始状态的输入数据结构重新分析

init_state = mlstm_cell.zero_state(batch_size, dtype=tf.float32)

outputs, state = tf.nn.dynamic_rnn(mlstm_cell, inputs=X, initial_state=init_state, time_major=False),

 

注意上面这个init_state,若是单层返回的为 [batch_size,num_units],若是多层则返回的是列表[],见tensorflow的API,为什么zero_state该方法只需传递一个参数batch_size的讨论,注意init_state的数据结构不简单!

zero_state(    batch_size, dtype)

Return zero-filled state tensor(s).

Args:

  • batch_size: int, float, or unit Tensor representing the batch size.
  • dtype: the data type to use for the state.

Returns:

If state_size is an int or TensorShape, then the return value is a N-D tensor of shape [batch_size, state_size] filled with zeros.

If state_size is a nested list or tuple, then the return value is a nested list or tuple (of the same structure) of 2-D tensors with the shapes [batch_size, s] for each s in state_size.

 

mlstm_cell是个MultiRNNCell类,它内部的hidden_size和layer_num已经定义好了,所以只需要告诉它batch是多少就可以了,所以这里只传入了batch_size

返回的是[batck_size,state_size],其实也就是[batck_size,num_units],初始数据形状与layer_num有关系吗,没有,只与batch_size*hidden_size也即batch*num_units有关系吧?是的

 

 

tensorflow API:

 

# 把784个点的字符信息还原成 28 * 28 的图片

# 下面几个步骤是实现 RNN / LSTM 的关键
####################################################################
# **步骤1:RNN 的输入shape = (batch_size, timestep_size, input_size) 
X = tf.reshape(_X, [-1, 28, 28])

# **步骤2:定义一层 LSTM_cell,只需要说明 hidden_size, 它会自动匹配输入的 X 的维度

lstm_cell = rnn.BasicLSTMCell(num_units=hidden_size, forget_bias=1.0, state_is_tuple=True)

# **步骤3:添加 dropout layer, 一般只设置 output_keep_prob

lstm_cell = rnn.DropoutWrapper(cell=lstm_cell, input_keep_prob=1.0, output_keep_prob=keep_prob)

# **步骤4:调用 MultiRNNCell 来实现多层 LSTM

mlstm_cell = rnn.MultiRNNCell([lstm_cell] * layer_num, state_is_tuple=True)

# **步骤5:用全零来初始化state

init_state = mlstm_cell.zero_state(batch_size, dtype=tf.float32)

# **步骤6:方法一,调用 dynamic_rnn() 来让我们构建好的网络运行起来

# ** 当 time_major==False 时, outputs.shape = [batch_size, timestep_size, hidden_size] 
# ** 所以,可以取 h_state = outputs[:, -1, :] 作为最后输出
# ** state.shape = [layer_num, 2, batch_size, hidden_size], 
# ** 或者,可以取 h_state = state[-1][1] 作为最后输出
# ** 最后输出维度是 [batch_size, hidden_size]
# outputs, state = tf.nn.dynamic_rnn(mlstm_cell, inputs=X, initial_state=init_state, time_major=False)
# h_state = outputs[:, -1, :]  # 或者 h_state = state[-1][1]

# *************** 为了更好的理解 LSTM 工作原理,我们把上面 步骤6 中的函数自己来实现 ***************

# 通过查看文档你会发现, RNNCell 都提供了一个 __call__()函数(见最后附),我们可以用它来展开实现LSTM按时间步迭代。
# **步骤6:方法二,按时间步展开计算
outputs = list()
state = init_state
with tf.variable_scope('RNN'):
    for timestep in range(timestep_size):
        if timestep > 0:
            tf.get_variable_scope().reuse_variables()
        # 这里的state保存了每一层 LSTM 的状态
        (cell_output, state) = mlstm_cell(X[:, timestep, :], state)
        outputs.append(cell_output)
h_state = outputs[-1]
————————————————
版权声明:本文为CSDN博主「永永夜」的原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/Jerr__y/article/details/61195257

你可能感兴趣的文章
组队总结
查看>>
TitledBorder 设置JPanel边框
查看>>
DBCP——开源组件 的使用
查看>>
抓包工具
查看>>
海量数据相似度计算之simhash和海明距离
查看>>
DeepLearning tutorial(5)CNN卷积神经网络应用于人脸识别(详细流程+代码实现)
查看>>
DeepLearning tutorial(6)易用的深度学习框架Keras简介
查看>>
DeepLearning tutorial(7)深度学习框架Keras的使用-进阶
查看>>
流形学习-高维数据的降维与可视化
查看>>
Python-OpenCV人脸检测(代码)
查看>>
python+opencv之视频人脸识别
查看>>
人脸识别(OpenCV+Python)
查看>>
6个强大的AngularJS扩展应用
查看>>
网站用户登录系统设计——jsGen实现版
查看>>
第三方SDK:讯飞语音听写
查看>>
第三方SDK:JPush SDK Eclipse
查看>>
第三方开源库:imageLoader的使用
查看>>
自定义控件:飞入飞出的效果
查看>>
自定义控件:动态获取控件的高
查看>>
第三方开源库:nineoldandroid:ValueAnimator 动态设置textview的高
查看>>