循环神经网络系列(三)Tensorflow中MultiRNNCell

循环神经网络系列(一) Tensorflow中BasicRNNCell
循环神经网络系列(二)Tensorflow中dynamic_rnn

经过前面两篇博文,我们介绍了如何定义一个RNN单元,以及用dynamic_rnn来对其在时间维度(横轴)上展开。我们今天要介绍的就是如何叠加多层RNN单元(如双向LSTM),同时对其按时间维度展开。具体多层RNN展开长什么样呢?还是用最直观的图来展示,如下所示:

其中A,B分别表示两个RNN单元,然后再分别对其按时间维度time_step=3进行展开,最终形成了两层,包含两个状态和3个输出。要完成这样一个例子,在Tensorflow中该如何来实现呢?

1. 先定义两个RNN单元

def get_a_cell(output_size):
    return tf.nn.rnn_cell.BasicRNNCell(num_units=output_size)
    
output_size = 5
batch_size = 4
time_step = 3
dim = 3
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell(output_size) for _ in range(2)])

经过上面的8行代码,我们就定义好了两个堆叠在一起的RNN单元A和B,如下图所示:

2. 利用dynamic_rnn进行展开

import tensorflow as tf


def get_a_cell(output_size):
    return tf.nn.rnn_cell.BasicRNNCell(num_units=output_size)


output_size = 5
batch_size = 4
time_step = 3
dim = 3
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell(output_size) for _ in range(2)])

inputs = tf.placeholder(dtype=tf.float32, shape=[time_step, batch_size, dim])
h0 = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=h0, time_major=True)
print(outputs)
print(final_state)

>>

Tensor("rnn/TensorArrayStack/TensorArrayGatherV3:0", shape=(3, 4, 5), dtype=float32)
(<tf.Tensor 'rnn/while/Exit_2:0' shape=(4, 5) dtype=float32>, <tf.Tensor 'rnn/while/Exit_3:0' shape=(4, 5) dtype=float32>)

从第23行结果可知,输出的最后状态有两个,形状分别都是shape=(4,5),这也符合我们的预期;而第22行的输出结果shape=(3,4,5)有表示什么意思呢?这里的3就不表示维度了,而表示输出结果有3部分,每个部分的大小都是shape=(4,5),这也是我们所预期的。并且B层的final_state应该使等于第三个输出的。

3. 喂个实例跑跑

import tensorflow as tf
import numpy as np


def get_a_cell(output_size):
    return tf.nn.rnn_cell.BasicRNNCell(num_units=output_size)


output_size = 5
batch_size = 4
time_step = 3
dim = 3
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell(output_size) for _ in range(2)])

inputs = tf.placeholder(dtype=tf.float32, shape=[time_step, batch_size, dim])
h0 = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=h0, time_major=True)
print(outputs)
print(final_state)

X = np.array([[[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]],  # x1
              [[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]],  # x2
              [[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]]])  # x3
sess = tf.Session()
sess.run(tf.global_variables_initializer())
a, b = sess.run([outputs, final_state], feed_dict={inputs: X})
print('outputs:')
print(a)
print('final_state:')
print(b)




>>
Tensor("rnn/TensorArrayStack/TensorArrayGatherV3:0", shape=(3, 4, 5), dtype=float32)
(<tf.Tensor 'rnn/while/Exit_2:0' shape=(4, 5) dtype=float32>, <tf.Tensor 'rnn/while/Exit_3:0' shape=(4, 5) dtype=float32>)


outputs:
[[[-0.6958626  -0.6776572   0.15731043 -0.6311886   0.20267256]
  [ 0.07732188  0.09182965 -0.49770945  0.0051106   0.23445603]
  [-0.304461   -0.2706095  -0.4083268  -0.3364025   0.26729658]
  [-0.38100582 -0.35050285 -0.2153194  -0.3686508   0.21973696]]

 [[-0.38028494 -0.39984316  0.5924934  -0.7433707   0.45858386]
  [ 0.15477817  0.06120307 -0.23038468 -0.2532196   0.19319542]
  [-0.09605556 -0.23243633  0.18608333 -0.6444844   0.34893066]
  [-0.15772797 -0.2529126   0.32016686 -0.6125384   0.33331177]]

 [[-0.45718285 -0.20688602  0.66812176 -0.81284994 -0.03955056]
  [ 0.16529301  0.2245452  -0.45850635 -0.36383444  0.18540041]
  [-0.0918629   0.11388774  0.01027385 -0.7402484   0.06189062]
  [-0.21528585  0.00840321  0.20390712 -0.71303254  0.04809263]]]
final_state:
(array([[ 0.01885682,  0.79334605, -0.99330646, -0.19715786,  0.8772415 ],
       [-0.43402836, -0.2537776 , -0.52755517,  0.5360404 , -0.38291538],
       [-0.49418357,  0.28655267, -0.91146743,  0.4856847 ,  0.22705963],
       [-0.3087254 ,  0.42241457, -0.8743213 ,  0.26078507,  0.3464944 ]],
      dtype=float32), 
array([[-0.45718285, -0.20688602,  0.66812176, -0.81284994, -0.03955056],
       [ 0.16529301,  0.2245452 , -0.45850635, -0.36383444,  0.18540041],
       [-0.0918629 ,  0.11388774,  0.01027385, -0.7402484 ,  0.06189062],
       [-0.21528585,  0.00840321,  0.20390712, -0.71303254,  0.04809263]],
      dtype=float32))

可以看到output有3个部分,final_state有2个部分,且output的第三个结果和final_state的第二个结果相同,符合我们上面的猜想。

注意:

如果每层的输出大小要不同的话,直接在定义多层单元的时候填上不同的参数即可!

output_size = [5, 6]
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell(size) for size in output_size])

更多内容欢迎扫码关注公众号月来客栈!
在这里插入图片描述

已标记关键词 清除标记
©️2020 CSDN 皮肤主题: 编程工作室 设计师:CSDN官方博客 返回首页