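In addition to OSS data sources, PAI-STUDIO now supports MaxCompute tables as a data source. You can read and write MaxCompute data directly with PAI-STUDIO's TensorFlow component, and this tutorial provides complete data and code for you to test with.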
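To help you get started quickly, this document uses training on the iris dataset as an example and shows how to run the experiment end to end.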
For convenience, we provide a publicly readable table for testing. Drag out a Read Table component and enter:
pai_online_project.iris_data
to get the data.
The data format is shown in the figure:
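The component has three input ports, which from left to right are the OSS input, the MaxCompute input, and the model input, and two output ports: the model output and the MaxCompute output. If both the input and the output are MaxCompute tables, connect the components as shown in the figure below.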
Reading and writing MaxCompute tables requires configuring the data source, the code file, the output model path, and table creation.
PAI command for the component:
PAI -name tensorflow180_ext -project algo_public -Doutputs="odps://${current project name}/tables/${output table name}" -DossHost="${OSS host}" -Dtables="odps://${current project name}/tables/${input table name}" -DgpuRequired="${number of GPUs}" -Darn="${OSS access RoleARN}" -Dscript="${code file to execute}";
Replace each ${...} placeholder in the command above with your actual values.
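The command above takes six -D parameters. As a minimal sketch, the hypothetical Python helper below assembles the same command string from individual values, which makes it easier to see which value goes into which placeholder. The function name and every example value (project, tables, host, RoleARN, code file) are illustrative assumptions; the helper only builds the string and does not submit anything to PAI.

```python
# Hypothetical helper: assembles the PAI command shown above from its parameters.
def build_pai_command(project, input_table, output_table, oss_host, role_arn, script, gpus=1):
    """Return the command string for the tensorflow180_ext component."""
    params = {
        "outputs": f"odps://{project}/tables/{output_table}",
        "ossHost": oss_host,
        "tables": f"odps://{project}/tables/{input_table}",
        "gpuRequired": str(gpus),
        "arn": role_arn,
        "script": script,
    }
    # Dicts keep insertion order, so the -D parameters appear in the same order as above.
    args = " ".join(f'-D{key}="{value}"' for key, value in params.items())
    return f"PAI -name tensorflow180_ext -project algo_public {args};"


# Example with placeholder values (all hypothetical).
print(build_pai_command(
    project="my_project",
    input_table="iris_data",
    output_table="iris_output",
    oss_host="<oss-host>",
    role_arn="<role-arn>",
    script="<code-file>",
))
```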
import tensorflow as tf

tf.app.flags.DEFINE_string("tables", "", "tables info")
FLAGS = tf.app.flags.FLAGS
print("tables:" + FLAGS.tables)

tables = [FLAGS.tables]
filename_queue = tf.train.string_input_producer(tables, num_epochs=1)
reader = tf.TableRecordReader()
key, value = reader.read(filename_queue)
record_defaults = [[1.0], [1.0], [1.0], [1.0], ["Iris-virginica"]]
col1, col2, col3, col4, col5 = tf.decode_csv(value, record_defaults=record_defaults)
# The two lines above can be shortened as follows, which helps when there are many columns:
# record_defaults = [[1.0]] * 4 + [["Iris-virginica"]]
# value_list = tf.decode_csv(value, record_defaults=record_defaults)

writer = tf.TableRecordWriter("odps://pai_bj_test2/tables/iris_output")
write_to_table = writer.write([0, 1, 2, 3, 4], [col1, col2, col3, col4, col5])
# The line above can be shortened as follows, which helps when there are many columns:
# write_to_table = writer.write(list(range(5)), value_list)
close_table = writer.close()

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    # num_epochs=1 creates a local epoch counter, so local variables must be initialized too.
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    step = 0
    try:
        while not coord.should_stop():
            sess.run(write_to_table)
            step += 1  # count only records that were actually written
    except tf.errors.OutOfRangeError:
        # Raised once the single epoch over the input table is exhausted.
        print('%d records copied' % step)
    finally:
        sess.run(close_table)
        coord.request_stop()
        coord.join(threads)
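TableRecordReader and TableRecordWriter appear to be PAI-specific extensions, so the script above only runs inside PAI. To illustrate what it does, the sketch below is a plain-Python analogue (an assumption for illustration, not PAI's implementation): each line is decoded with the same record_defaults, where each default fixes the column's type and fills an empty field, and the copy makes a single pass over the input, counting only rows that were actually written. The function names and sample lines are hypothetical.

```python
import csv

# Same defaults as the TensorFlow script: four float columns, then one string column.
RECORD_DEFAULTS = [[1.0]] * 4 + [["Iris-virginica"]]


def decode_row(line, record_defaults=RECORD_DEFAULTS):
    """Split one CSV line; each default sets the column type and fills empty fields."""
    fields = next(csv.reader([line]))
    if len(fields) != len(record_defaults):
        raise ValueError(f"expected {len(record_defaults)} fields, got {len(fields)}")
    row = []
    for field, (default,) in zip(fields, record_defaults):
        # Empty field -> default value; otherwise convert to the default's type.
        row.append(default if field == "" else type(default)(field))
    return row


def copy_records(lines, output_rows):
    """Make one pass over the input (like num_epochs=1) and return the number of rows written."""
    step = 0
    for line in lines:  # the loop ends where TensorFlow would raise OutOfRangeError
        output_rows.append(decode_row(line))
        step += 1  # incremented only after a successful write
    return step


rows = []
count = copy_records(["5.1,3.5,1.4,0.2,Iris-setosa", "6.3,3.3,6.0,2.5,"], rows)
print("%d records copied" % count)  # prints "2 records copied"
print(rows[1])  # the empty last field was filled with "Iris-virginica"
```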
The part of the script that reads the input table:
tables = [FLAGS.tables]
filename_queue = tf.train.string_input_producer(tables, num_epochs=1)
reader = tf.TableRecordReader()
key, value = reader.read(filename_queue)
record_defaults = [[1.0], [1.0], [1.0], [1.0], ["Iris-virginica"]]
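Here, FLAGS.tables is the parameter that receives the input table name configured in the frontend; it corresponds to the component's MaxCompute input port.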
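tf.app.flags turns DEFINE_string("tables", ...) into a --tables command-line flag. The sketch below uses argparse as a stand-in to show how a value such as odps://my_project/tables/iris_data would end up in FLAGS.tables. It assumes that PAI forwards the -Dtables parameter to the script as --tables; the function name and project name are hypothetical.

```python
import argparse


def parse_flags(argv):
    """Stand-in for tf.app.flags.DEFINE_string("tables", "", "tables info")."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--tables", default="", help="tables info")
    # parse_known_args ignores any other flags the platform might pass.
    flags, _unknown = parser.parse_known_args(argv)
    return flags


FLAGS = parse_flags(["--tables=odps://my_project/tables/iris_data"])
print("tables:" + FLAGS.tables)  # prints "tables:odps://my_project/tables/iris_data"
```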
The part of the script that writes the output table:
writer = tf.TableRecordWriter("odps://pai_bj_test2/tables/iris_output")
write_to_table = writer.write([0, 1, 2, 3, 4], [col1, col2, col3, col4, col5])
The path passed to TableRecordWriter has the format odps://<current project name>/tables/<output table name>.
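A minimal sketch of helpers for the odps://<project>/tables/<table> format described above. It assumes that a project.table name, as entered in the Read Table component (for example pai_online_project.iris_data), maps to odps://project/tables/table, and that project and table names consist only of letters, digits, and underscores; the helpers accept only that exact form. All function names are hypothetical.

```python
import re

# odps://<project>/tables/<table>; the allowed name characters are an assumption.
_ODPS_TABLE = re.compile(r"^odps://([A-Za-z0-9_]+)/tables/([A-Za-z0-9_]+)$")


def table_path(project, table):
    """Build a table path in the format used by TableRecordWriter, -Dtables, and -Doutputs."""
    return f"odps://{project}/tables/{table}"


def from_dotted(name):
    """Convert a 'project.table' name into an odps:// table path."""
    project, table = name.split(".", 1)
    return table_path(project, table)


def parse_table_path(path):
    """Split an odps:// table path into (project, table); raise ValueError otherwise."""
    match = _ODPS_TABLE.match(path)
    if match is None:
        raise ValueError(f"not an odps table path: {path!r}")
    return match.group(1), match.group(2)


print(from_dotted("pai_online_project.iris_data"))
print(parse_table_path("odps://pai_bj_test2/tables/iris_output"))
```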
Author: 傲海 (Ao Hai)
This article is original content from the Yunqi Community (雲棲社區) and may not be reproduced without permission.