NO IMAGE

1、概述

在tensorflow中的輸入資料會有很多形式一般有一下幾種形式

  • 資料以tf.constant的實行直接嵌入到graph中。在這種情況下一般資料量不會很大,應用場景也比較單一
  • 以tf.placeholder與feed_dic的形式存在

       在這種場景下,往往也需要將資料全部讀入到記憶體,轉換成tf的張量集合然後再進行處理。在進行大量資料處理時顯得的力不從心。

  • 以pipeline的方式從檔案中直接讀取資料並且採用多執行緒非同步形式來解決IO的瓶頸。

本章內容主要討論第三種方式。

2、tf.data API

tensorflow的tf.data類可以讓我們使用簡單的可重用的程式碼來構建複雜的輸入管道。並將資料構建、打亂資料、生成批量資料的功能,整合其中。同時tf.data提供了文字檔案輸入模型與影象輸入模型用於處理不同形式的輸入資料。

tf.data引入了2個新的概念

     1、tf.data.Dataset

tf.data.Dataset表示一個元素序列,其中每個元素包含一個或多個張量,例如,在影象流水線中,元素可以是單個訓練示例,其中一對張量表示影象資料和標籤。建立資料集有兩種不同的方法:

  • 從資料集中直接建立一個dataset物件
  • 從一個已有的dataset物件轉換一個新的dataset物件

     2、tf.data.Iterator

使用該api可以構造一個迭代器從Dataset中提取資料。我們可以使用Iterator.get_next()產生Dataset執行時的下一個元素,並且通常充當輸入管道程式碼和模型之間的介面。最簡單的迭代器是一個“一次性迭代器”,它與一個特定的Dataset迭代器相關聯並迭代一次。如果需要構造更為複雜的迭代器可以使用Iterator.initializer傳遞不同的引數來進行構建。

3、基本原理

  1. 要啟動輸入管道,您必須定義源。
  2. 一旦有了Dataset物件,就可以通過對物件進行連結方法呼叫將其轉換為新物件。
  3. 建立一個迭代器從Dataset物件中獲取資料

4、Dataset的基本結構

Dataset必須是具有相同結構的元素集合。每個元素至少包含一個Tensor物件,每一個元素被稱之為“元件”。

每一個元件都包含下面兩個非常重要的屬性:

  • tf.Dtype:用來表示元件中每一個元素的資料型別
  • tf.TensorShape:用來表示每個元素的靜態形狀

而就資料集本身來講也有兩個非常重要的屬性,我們更多的情況下是關注整個資料集的情況,而不是單個元件的情況:

  • Dataset.output_types:整體資料資料集中所包含的資料型別
  • Dataset.output_shapes:資料集的整體形狀綜述

請參照如下程式碼:


input_data = tf.random_uniform([4,10])
dataset1 = tf.data.Dataset.from_tensor_slices(input_data)
print(dataset1.output_types)  # ==> "tf.float32"
print(dataset1.output_shapes)  # ==> "(10,)"
with tf.Session() as sess:
print(sess.run(input_data))

dataset2 = tf.data.Dataset.from_tensor_slices(
(tf.random_uniform([4]),
tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)))
print(dataset2.output_types)  # ==> "(tf.float32, tf.int32)"
print(dataset2.output_shapes)  # ==> "((), (100,))"

也可以對dataset進行任意組合:


dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
print(dataset3.output_types)  # ==> (tf.float32, (tf.float32, tf.int32))
print(dataset3.output_shapes)  # ==> "(10, ((), (100,)))"

5、建立一個迭代器

要從資料集裡讀取資料,就要使用迭代器,tensorflow提供了下面幾種迭代器:

  • one-shot
  • initializable
  • reinitializable
  • feedable

one-shot迭代器是所有的迭代器中最簡單的一個。他只支援一次性迭代,而且無需初始化操作。one-shot迭代器幾乎處理現有基於佇列的輸入管道支援的所有情況,但它們不支援引數化。請參照下面程式碼:


dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(100):
value = sess.run(next_element)
print(value)

initializable迭代器可以使用placeholder對迭代器進行初始化,請參照如下程式碼:


max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
# Initialize an iterator over a dataset with 10 elements.
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
value = sess.run(next_element)
assert i == value
# Initialize the same iterator over a dataset with 100 elements.
sess.run(iterator.initializer, feed_dict={max_value: 100})
for i in range(100):
value = sess.run(next_element)
assert i == value

reinitializable迭代器可以使用不同的已經被dataset來初始化迭代器。如果我們需要通過在訓練的時候,同時進行交叉驗證,那麼此時我們就會用到此類的迭代器,請參考一下程式碼:


training_dataset = tf.data.Dataset.range(100).map(
lambda x: x   tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(50)
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
training_dataset.output_shapes)
next_element = iterator.get_next()
training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)
sess = tf.Session()
for _ in range(20):
sess.run(training_init_op)
for _ in range(100):
sess.run(next_element)
sess.run(validation_init_op)
for _ in range(50):
print(sess.run(next_element))

feedable迭代器採用feed_dic機制在session.run時使用類似place_holder的機制來初始化不同的迭代器。他提供的功能與reinitializable迭代器類似,但並不需要在使用資料集之前就初始化迭代器。我們可以使用feedable迭代器實現上面類似的功能。


training_dataset = tf.data.Dataset.range(100).map(
lambda x: x   tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(50)
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())
sess = tf.Session()
for _ in range(200):
sess.run(next_element, feed_dict={handle: training_handle})
sess.run(validation_iterator.initializer)
for _ in range(50):
sess.run(next_element, feed_dict={handle: validation_handle})