學習筆記TF061:分散式TensorFlow,分散式理、最佳實踐

NO IMAGE
1 Star2 Stars3 Stars4 Stars5 Stars 給文章打分!
Loading...

分散式TensorFlow由高效能gRPC庫底層技術支援。Martin Abadi、Ashish Agarwal、Paul Barham論文《TensorFlow:Large-Scale Machine Learning on Heterogeneous Distributed Systems》。

分散式原理。分散式叢集 由多個伺服器程序、客戶端程序組成。部署方式,單機多卡、分散式(多機多卡)。多機多卡TensorFlow分散式。

單機多卡,單臺伺服器多塊GPU。訓練過程:在單機單GPU訓練,資料一個批次(batch)一個批次訓練。單機多GPU,一次處理多個批次資料,每個GPU處理一個批次資料計算。變數引數儲存在CPU,資料由CPU分發給多個GPU,GPU計算每個批次更新梯度。CPU收集完多個GPU更新梯度,計算平均梯度,更新引數。繼續計算更新梯度。處理速度取決最慢GPU速度。

分散式,訓練在多個工作節點(worker)。工作節點,實現計算單元。計算伺服器單卡,指伺服器。計算伺服器多卡,多個GPU劃分多個工作節點。資料量大,超過一臺機器處理能力,須用分散式。

分散式TensorFlow底層通訊,gRPC(google remote procedure call)。gRPC,谷歌開源高效能、跨語言RPC框架。RPC協議,遠端過程呼叫協議,網路從遠端計算機程度請求服務。

分散式部署方式。分散式執行,多個計算單元(工作節點),後端伺服器部署單工作節點、多工作節點。

單工作節點部署。每臺伺服器執行一個工作節點,伺服器多個GPU,一個工作節點可以訪問多塊GPU卡。程式碼tf.device()指定執行操作裝置。優勢,單機多GPU間通訊,效率高。劣勢,手動程式碼指定裝置。

多工作節點部署。一臺伺服器執行多個工作節點。

設定CUDA_VISIBLE_DEVICES環境變數,限制各個工作節點只可見一個GPU,啟動程序新增環境變數。用tf.device()指定特定GPU。多工作節點部署優勢,程式碼簡單,提高GPU使用率。劣勢,工作節點通訊,需部署多個工作節點。https://github.com/tobegit3hu…

CUDA_VISIBLE_DEVICES='' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=ps --task_index=0
CUDA_VISIBLE_DEVICES='' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=ps --task_index=1
CUDA_VISIBLE_DEVICES='0' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=worker --task_index=0
CUDA_VISIBLE_DEVICES='1' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=worker --task_index=1

分散式架構。https://www.tensorflow.org/ex… 。客戶端(client)、服務端(server),服務端包括主節點(master)、工作節點(worker)組成。

客戶端、主節點、工作節點關係。TensorFlow,客戶端會話聯絡主節點,實際工作由工作節點實現,每個工作節點佔一臺裝置(TensorFlow具體計算硬體抽象,CPU或GPU)。單機模式,客戶端、主節點、工作節點在同一臺伺服器。分佈模式,可不同伺服器。客戶端->主節點->工作節點/job:worker/task:0->/job:ps/task:0。

客戶端。建立TensorFlow計算圖,建立與叢集互動會話層。程式碼包含Session()。一個客戶端可同時與多個服務端相連,一具服務端也可與多個客戶端相連。

服務端。執行tf.train.Server例項程序,TensroFlow執行任務叢集(cluster)一部分。有主節點服務(Master service)和工作節點服務(Worker service)。執行中,一個主節點程序和數個工作節點程序,主節點程序和工作接點程序通過介面通訊。單機多卡和分散式結構相同,只需要更改通訊介面實現切換。

主節點服務。實現tensorflow::Session介面。通過RPC服務程式連線工作節點,與工作節點服務程序工作任務通訊。TensorFlow服務端,task_index為0作業(job)。

工作節點服務。實現worker_service.proto介面,本地裝置計算部分圖。TensorFlow服務端,所有工作節點包含工作節點服務邏輯。每個工作節點負責管理一個或多個裝置。工作節點可以是本地不同埠不同程序,或多臺服務多個程序。執行TensorFlow分散式執行任務集,一個或多個作業(job)。每個作業,一個或多個相同目的任務(task)。每個任務,一個工作程序執行。作業是任務集合,叢集是作業集合。

分散式機器學習框架,作業分引數作業(parameter job)和工作節點作業(worker job)。引數作業執行伺服器為引數伺服器(parameter server,PS),管理引數儲存、更新。工作節點作業,管理無狀態主要從事計算任務。模型越大,引數越多,模型引數更新超過一臺機器效能,需要把引數分開到不同機器儲存更新。引數服務,多臺機器組成叢集,類似分散式儲存架構,涉及資料同步、一致性,引數儲存為鍵值對(key-value)。分散式鍵值記憶體資料庫,加引數更新操作。李沐《Parameter Server for Distributed Machine Learning》http://www.cs.cmu.edu/~muli/f…

引數儲存更新在引數作業進行,模型計算在工作節點作業進行。TensorFlow分散式實現作業間資料傳輸,引數作業到工作節點作業前向傳播,工作節點作業到引數作業反向傳播。

任務。特定TensorFlow伺服器獨立程序,在作業中擁有對應序號。一個任務對應一個工作節點。叢集->作業->任務->工作節點。

客戶端、主節點、工作節點互動過程。單機多卡互動,客戶端->會話執行->主節點->執行子圖->工作節點->GPU0、GPU1。分散式互動,客戶端->會話執行->主節點程序->執行子圖1->工作節點程序1->GPU0、GPU1。《TensorFlow:Large-Scale Machine Learning on Heterogeneous distributed Systems》https://arxiv.org/abs/1603.04…

分散式模式。

資料並行。https://www.tensorflow.org/tu… 。CPU負責梯度平均、引數更新,不同GPU訓練模型副本(model replica)。基於訓練樣例子集訓練,模型有獨立性。

步驟:不同GPU分別定義模型網路結構。單個GPU從資料管道讀取不同資料塊,前向傳播,計算損失,計算當前變數梯度。所有GPU輸出梯度資料轉移到CPU,梯度求平均操作,模型變數更新。重複,直到模型變數收斂。

資料並行,提高SGD效率。SGD mini-batch樣本,切成多份,模型複製多份,在多個模型上同時計算。多個模型計算速度不一致,CPU更新變數有同步、非同步兩個方案。

同步更新、非同步更新。分散式隨機梯度下降法,模型引數分散式儲存在不同引數服務上,工作節點並行訓練資料,和引數伺服器通訊獲取模型引數。

同步隨機梯度下降法(Sync-SGD,同步更新、同步訓練),訓練時,每個節點上工作任務讀入共享引數,執行並行梯度計算,同步需要等待所有工作節點把區域性梯度處好,將所有共享引數合併、累加,再一次性更新到模型引數,下一批次,所有工作節點用模型更新後引數訓練。優勢,每個訓練批次考慮所有工作節點訓練情部,損失下降穩定。劣勢,效能瓶頸在最慢工作節點。異楹裝置,工作節點效能不同,劣勢明顯。

非同步隨機梯度下降法(Async-SGD,非同步更新、非同步訓練),每個工作節點任務獨立計算區域性梯度,非同步更新到模型引數,不需執行協調、等待操作。優勢,效能不存在瓶頸。劣勢,每個工作節點計算梯度值發磅回引數伺服器有引數更新衝突,影響演算法收劍速度,損失下降過程抖動較大。

同步更新、非同步更新實現區別於更新引數伺服器引數策略。資料量小,各節點計算能力較均衡,用同步模型。資料量大,各機器計算效能參差不齊,用非同步模式。

帶備份的Sync-SGD(Sync-SDG with backup)。Jianmin Chen、Xinghao Pan、Rajat Monga、Aamy Bengio、Rafal Jozefowicz論文《Revisiting Distributed Synchronous SGD》https://arxiv.org/abs/1604.00981 。增加工作節點,解決部分工作節點計算慢問題。工作節點總數n n*5%,n為叢集工作節點數。非同步更新設定接受到n個工作節點引數直接更新引數伺服器模型引數,進入下一批次模型訓練。計算較慢節點訓練引數直接丟棄。

同步更新、非同步更新有圖內模式(in-graph pattern)和圖間模式(between-graph pattern),獨立於圖內(in-graph)、圖間(between-graph)概念。

圖內複製(in-grasph replication),所有操作(operation)在同一個圖中,用一個客戶端來生成圖,把所有操作分配到叢集所有引數伺服器和工作節點上。國內複製和單機多卡類似,擴充套件到多機多卡,資料分發還是在客戶端一個節點上。優勢,計算節點只需要呼叫join()函式等待任務,客戶端隨時提交資料就可以訓練。劣勢,訓練資料分發在一個節點上,要分發給不同工作節點,嚴重影響併發訓練速度。

圖間複製(between-graph replication),每一個工作節點建立一個圖,訓練引數儲存在引數伺服器,資料不分發,各個工作節點獨立計算,計算完成把要更新引數告訴引數伺服器,引數伺服器更新引數。優勢,不需要資料分發,各個工作節點都建立圖和讀取資料訓練。劣勢,工作節點既是圖建立者又是計算任務執行者,某個工作節點宕機影響叢集工作。大資料相關深度學習推薦使用圖間模式。

模型並行。切分模型,模型不同部分執行在不同裝置上,一個批次樣本可以在不同裝置同時執行。TensorFlow儘量讓相鄰計算在同一臺裝置上完成節省網路開銷。Martin Abadi、Ashish Agarwal、Paul Barham論文《TensorFlow:Large-Scale Machine Learning on Heterogeneous Distributed Systems》https://arxiv.org/abs/1603.04…

模型並行、資料並行,TensorFlow中,計算可以分離,引數可以分離。可以在每個裝置上分配計算節點,讓對應引數也在該裝置上,計算引數放一起。

分散式API。https://www.tensorflow.org/de…

建立叢集,每個任務(task)啟動一個服務(工作節點服務或主節點服務)。任務可以分佈不同機器,可以同一臺機器啟動多個任務,用不同GPU執行。每個任務完成工作:建立一個tf.train.ClusterSpec,對叢集所有任務進行描述,描述內容對所有任務相同。建立一個tf.train.Server,建立一個服務,執行相應作業計算任務。

TensorFlow分散式開發API。tf.train.ClusterSpec({“ps”:ps_hosts,”worker”:worke_hosts})。建立TensorFlow叢集描述資訊,ps、worker為作業名稱,ps_phsts、worker_hosts為作業任務所在節點地址資訊。tf.train.ClusterSpec傳入引數,作業和任務間關係對映,對映關係任務通過IP地址、埠號表示。

結構 tf.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})
可用任務 /job:local/task:0、/job:local/task:1。
結構 tf.train.ClusterSpec({"worker":["worker0.example.com:2222","worker1.example.com:2222","worker2.example.com:2222"],"ps":["ps0.example.com:2222","ps1.example.com:2222"]})
可用任務 /job:worker/task:0、 /job:worker/task:1、 /job:worker/task:2、 /job:ps/task:0、 /job:ps/task:1

tf.train.Server(cluster,job_name,task_index)。建立服務(主節點服務或工作節點服務),執行作業計算任務,執行任務在task_index指定機器啟動。

#任務0 
cluster = tr.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})
server  = tr.train.Server(cluster,job_name="local",task_index=0) 
#任務1 
cluster = tr.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})
server  = tr.train.Server(cluster,job_name="local",task_index=1)。

自動化管理節點、監控節點工具。叢集管理工具Kubernetes。
tf.device(device_name_or_function)。設定指定裝置執行張量運算,批定程式碼執行CPU、GPU。

#指定在task0所在機器執行Tensor操作運算 
with tf.device("/job:ps/task:0"):
weights_1 = tf.Variable(…)
biases_1 = tf.Variable(…)

分散式訓練程式碼框架。建立TensorFlow伺服器叢集,在該叢集分散式計算資料流圖。https://github.com/tensorflow…

import argparse
import sys
import tensorflow as tf
FLAGS = None
def main(_):
# 第1步:命令列引數解析,獲取叢集資訊ps_hosts、worker_hosts
# 當前節點角色資訊job_name、task_index
ps_hosts = FLAGS.ps_hosts.split(",")
worker_hosts = FLAGS.worker_hosts.split(",")
# 第2步:建立當前任務節點伺服器
# Create a cluster from the parameter server and worker hosts.
cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
# Create and start a server for the local task.
server = tf.train.Server(cluster,
job_name=FLAGS.job_name,
task_index=FLAGS.task_index)
# 第3步:如果當前節點是引數伺服器,呼叫server.join()無休止等待;如果是工作節點,執行第4步
if FLAGS.job_name == "ps":
server.join()
# 第4步:構建要訓練模型,構建計算圖
elif FLAGS.job_name == "worker":
# Assigns ops to the local worker by default.
with tf.device(tf.train.replica_device_setter(
worker_device="/job:worker/task:%d" % FLAGS.task_index,
cluster=cluster)):
# Build model...
loss = ...
global_step = tf.contrib.framework.get_or_create_global_step()
train_op = tf.train.AdagradOptimizer(0.01).minimize(
loss, global_step=global_step)
# The StopAtStepHook handles stopping after running given steps.
# 第5步管理模型訓練過程
hooks=[tf.train.StopAtStepHook(last_step=1000000)]
# The MonitoredTrainingSession takes care of session initialization,
# restoring from a checkpoint, saving to a checkpoint, and closing when done
# or an error occurs.
with tf.train.MonitoredTrainingSession(master=server.target,
is_chief=(FLAGS.task_index == 0),
checkpoint_dir="/tmp/train_logs",
hooks=hooks) as mon_sess:
while not mon_sess.should_stop():
# Run a training step asynchronously.
# See `tf.train.SyncReplicasOptimizer` for additional details on how to
# perform *synchronous* training.
# mon_sess.run handles AbortedError in case of preempted PS.
# 訓練模型
mon_sess.run(train_op)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
# Flags for defining the tf.train.ClusterSpec
parser.add_argument(
"--ps_hosts",
type=str,
default="",
help="Comma-separated list of hostname:port pairs"
)
parser.add_argument(
"--worker_hosts",
type=str,
default="",
help="Comma-separated list of hostname:port pairs"
)
parser.add_argument(
"--job_name",
type=str,
default="",
help="One of 'ps', 'worker'"
)
# Flags for defining the tf.train.Server
parser.add_argument(
"--task_index",
type=int,
default=0,
help="Index of task within the job"
)
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]]   unparsed)

分散式最佳實踐。
https://github.com/tensorflow…

MNIST資料集分散式訓練。開設3個埠作分散式工作節點部署,2222埠引數伺服器,2223埠工作節點0,2224埠工作節點1。引數伺服器執行引數更新任務,工作節點0、工作節點1執行圖模型訓練計算任務。引數伺服器/job:ps/task:0 cocalhost:2222,工作節點/job:worker/task:0 cocalhost:2223,工作節點/job:worker/task:1 cocalhost:2224。

執行程式碼。

python mnist_replica.py --job_name="ps" --task_index=0
python mnist_replica.py --job_name="worker" --task_index=0
python mnist_replica.py --job_name="worker" --task_index=1
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import sys
import tempfile
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# 定義常量,用於建立資料流圖
flags = tf.app.flags
flags.DEFINE_string("data_dir", "/tmp/mnist-data",
"Directory for storing mnist data")
# 只下載資料,不做其他操作
flags.DEFINE_boolean("download_only", False,
"Only perform downloading of data; Do not proceed to "
"session preparation, model definition or training")
# task_index從0開始。0代表用來初始化變數的第一個任務
flags.DEFINE_integer("task_index", None,
"Worker task index, should be >= 0. task_index=0 is "
"the master worker task the performs the variable "
"initialization ")
# 每臺機器GPU個數,機器沒有GPU為0
flags.DEFINE_integer("num_gpus", 1,
"Total number of gpus for each machine."
"If you don't use GPU, please set it to '0'")
# 同步訓練模型下,設定收集工作節點數量。預設工作節點總數
flags.DEFINE_integer("replicas_to_aggregate", None,
"Number of replicas to aggregate before parameter update"
"is applied (For sync_replicas mode only; default: "
"num_workers)")
flags.DEFINE_integer("hidden_units", 100,
"Number of units in the hidden layer of the NN")
# 訓練次數
flags.DEFINE_integer("train_steps", 200,
"Number of (global) training steps to perform")
flags.DEFINE_integer("batch_size", 100, "Training batch size")
flags.DEFINE_float("learning_rate", 0.01, "Learning rate")
# 使用同步訓練、非同步訓練
flags.DEFINE_boolean("sync_replicas", False,
"Use the sync_replicas (synchronized replicas) mode, "
"wherein the parameter updates from workers are aggregated "
"before applied to avoid stale gradients")
# 如果伺服器已經存在,採用gRPC協議通訊;如果不存在,採用程序間通訊
flags.DEFINE_boolean(
"existing_servers", False, "Whether servers already exists. If True, "
"will use the worker hosts via their GRPC URLs (one client process "
"per worker host). Otherwise, will create an in-process TensorFlow "
"server.")
# 引數伺服器主機
flags.DEFINE_string("ps_hosts","localhost:2222",
"Comma-separated list of hostname:port pairs")
# 工作節點主機
flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224",
"Comma-separated list of hostname:port pairs")
# 本作業是工作節點還是引數伺服器
flags.DEFINE_string("job_name", None,"job name: worker or ps")
FLAGS = flags.FLAGS
IMAGE_PIXELS = 28
def main(unused_argv):
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
if FLAGS.download_only:
sys.exit(0)
if FLAGS.job_name is None or FLAGS.job_name == "":
raise ValueError("Must specify an explicit `job_name`")
if FLAGS.task_index is None or FLAGS.task_index =="":
raise ValueError("Must specify an explicit `task_index`")
print("job name = %s" % FLAGS.job_name)
print("task index = %d" % FLAGS.task_index)
#Construct the cluster and start the server
# 讀取叢集描述資訊
ps_spec = FLAGS.ps_hosts.split(",")
worker_spec = FLAGS.worker_hosts.split(",")
# Get the number of workers.
num_workers = len(worker_spec)
# 建立TensorFlow叢集描述物件
cluster = tf.train.ClusterSpec({
"ps": ps_spec,
"worker": worker_spec})
# 為本地執行任務建立TensorFlow Server物件。
if not FLAGS.existing_servers:
# Not using existing servers. Create an in-process server.
# 建立本地Sever物件,從tf.train.Server這個定義開始,每個節點開始不同
# 根據執行的命令的引數(作業名字)不同,決定這個任務是哪個任務
# 如果作業名字是ps,程序就加入這裡,作為引數更新的服務,等待其他工作節點給它提交引數更新的資料
# 如果作業名字是worker,就執行後面的計算任務
server = tf.train.Server(
cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
# 如果是引數伺服器,直接啟動即可。這裡,程序就會阻塞在這裡
# 下面的tf.train.replica_device_setter程式碼會將引數批定給ps_server保管
if FLAGS.job_name == "ps":
server.join()
# 處理工作節點
# 找出worker的主節點,即task_index為0的點
is_chief = (FLAGS.task_index == 0)
# 如果使用gpu
if FLAGS.num_gpus > 0:
# Avoid gpu allocation conflict: now allocate task_num -> #gpu
# for each worker in the corresponding machine
gpu = (FLAGS.task_index % FLAGS.num_gpus)
# 分配worker到指定gpu上執行
worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu)
# 如果使用cpu
elif FLAGS.num_gpus == 0:
# Just allocate the CPU to worker server
# 把cpu分配給worker
cpu = 0
worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)
# The device setter will automatically place Variables ops on separate
# parameter servers (ps). The non-Variable ops will be placed on the workers.
# The ps use CPU and workers use corresponding GPU
# 用tf.train.replica_device_setter將涉及變數操作分配到引數伺服器上,使用CPU。將涉及非變數操作分配到工作節點上,使用上一步worker_device值。
# 在這個with語句之下定義的引數,會自動分配到引數伺服器上去定義。如果有多個引數伺服器,就輪流迴圈分配
with tf.device(
tf.train.replica_device_setter(
worker_device=worker_device,
ps_device="/job:ps/cpu:0",
cluster=cluster)):
# 定義全域性步長,預設值為0
global_step = tf.Variable(0, name="global_step", trainable=False)
# Variables of the hidden layer
# 定義隱藏層引數變數,這裡是全連線神經網路隱藏層
hid_w = tf.Variable(
tf.truncated_normal(
[IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
stddev=1.0 / IMAGE_PIXELS),
name="hid_w")
hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b")
# Variables of the softmax layer
# 定義Softmax 迴歸層引數變數
sm_w = tf.Variable(
tf.truncated_normal(
[FLAGS.hidden_units, 10],
stddev=1.0 / math.sqrt(FLAGS.hidden_units)),
name="sm_w")
sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
# Ops: located on the worker specified with FLAGS.task_index
# 定義模型輸入資料變數
x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])
y_ = tf.placeholder(tf.float32, [None, 10])
# 構建隱藏層
hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
hid = tf.nn.relu(hid_lin)
# 構建損失函式和優化器
y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
# 非同步訓練模式:自己計算完成梯度就去更新引數,不同副本之間不會去協調進度
opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
# 同步訓練模式
if FLAGS.sync_replicas:
if FLAGS.replicas_to_aggregate is None:
replicas_to_aggregate = num_workers
else:
replicas_to_aggregate = FLAGS.replicas_to_aggregate
# 使用SyncReplicasOptimizer作優化器,並且是在圖間複製情況下
# 在圖內複製情況下將所有梯度平均
opt = tf.train.SyncReplicasOptimizer(
opt,
replicas_to_aggregate=replicas_to_aggregate,
total_num_replicas=num_workers,
name="mnist_sync_replicas")
train_step = opt.minimize(cross_entropy, global_step=global_step)
if FLAGS.sync_replicas:
local_init_op = opt.local_step_init_op
if is_chief:
# 所有進行計算工作節點裡一個主工作節點(chief)
# 主節點負責初始化引數、模型儲存、概要儲存
local_init_op = opt.chief_init_op
ready_for_local_init_op = opt.ready_for_local_init_op
# Initial token and chief queue runners required by the sync_replicas mode
# 同步訓練模式所需初始令牌、主佇列
chief_queue_runner = opt.get_chief_queue_runner()
sync_init_op = opt.get_init_tokens_op()
init_op = tf.global_variables_initializer()
train_dir = tempfile.mkdtemp()
if FLAGS.sync_replicas:
# 建立一個監管程式,用於統計訓練模型過程中的資訊
# lodger 是儲存和載入模型路徑
# 啟動就會去這個logdir目錄看是否有檢查點檔案,有的話就自動載入
# 沒有就用init_op指定初始化引數
# 主工作節點(chief)負責模型引數初始化工作
# 過程中,其他工作節點等待主節眯完成初始化工作,初始化完成後,一起開始訓練資料
# global_step值是所有計算節點共享的
# 在執行損失函式最小值時自動加1,通過global_step知道所有計算節點一共計算多少步
sv = tf.train.Supervisor(
is_chief=is_chief,
logdir=train_dir,
init_op=init_op,
local_init_op=local_init_op,
ready_for_local_init_op=ready_for_local_init_op,
recovery_wait_secs=1,
global_step=global_step)
else:
sv = tf.train.Supervisor(
is_chief=is_chief,
logdir=train_dir,
init_op=init_op,
recovery_wait_secs=1,
global_step=global_step)
# 建立會話,設定屬性allow_soft_placement為True
# 所有操作預設使用被指定設定,如GPU
# 如果該操作函式沒有GPU實現,自動使用CPU裝置
sess_config = tf.ConfigProto(
allow_soft_placement=True,
log_device_placement=False,
device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index])
# The chief worker (task_index==0) session will prepare the session,
# while the remaining workers will wait for the preparation to complete.
# 主工作節點(chief),task_index為0節點初始化會話
# 其餘工作節點等待會話被初始化後進行計算
if is_chief:
print("Worker %d: Initializing session..." % FLAGS.task_index)
else:
print("Worker %d: Waiting for session to be initialized..." %
FLAGS.task_index)
if FLAGS.existing_servers:
server_grpc_url = "grpc://"   worker_spec[FLAGS.task_index]
print("Using existing server at: %s" % server_grpc_url)
# 建立TensorFlow會話物件,用於執行TensorFlow圖計算
# prepare_or_wait_for_session需要引數初始化完成且主節點準備好後,才開始訓練
sess = sv.prepare_or_wait_for_session(server_grpc_url,
config=sess_config)
else:
sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)
print("Worker %d: Session initialization complete." % FLAGS.task_index)
if FLAGS.sync_replicas and is_chief:
# Chief worker will start the chief queue runner and call the init op.
sess.run(sync_init_op)
sv.start_queue_runners(sess, [chief_queue_runner])
# Perform training
# 執行分散式模型訓練
time_begin = time.time()
print("Training begins @ %f" % time_begin)
local_step = 0
while True:
# Training feed
# 讀入MNIST訓練資料,預設每批次100張圖片
batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
train_feed = {x: batch_xs, y_: batch_ys}
_, step = sess.run([train_step, global_step], feed_dict=train_feed)
local_step  = 1
now = time.time()
print("%f: Worker %d: training step %d done (global step: %d)" %
(now, FLAGS.task_index, local_step, step))
if step >= FLAGS.train_steps:
break
time_end = time.time()
print("Training ends @ %f" % time_end)
training_time = time_end - time_begin
print("Training elapsed time: %f s" % training_time)
# Validation feed
# 讀入MNIST驗證資料,計算驗證的交叉熵
val_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
val_xent = sess.run(cross_entropy, feed_dict=val_feed)
print("After %d training step(s), validation cross entropy = %g" %
(FLAGS.train_steps, val_xent))
if __name__ == "__main__":
tf.app.run()

參考資料:
《TensorFlow技術解析與實戰》

歡迎推薦上海機器學習工作機會,我的微信:qingxingfengzi

相關文章

人工智慧 最新文章