tensorflow加載embedding模型進行可視化

1.功能

採用python的gensim模塊訓練word2vec模型,而後採用tensorflow讀取模型,可視化embedding向量。

ps:採用C++版本訓練的w2v模型,python的gensim模塊讀不了。

2.python訓練word2vec模型代碼

import multiprocessing

from gensim.models.word2vec import LineSentence, Word2Vec

# Train a word2vec model with gensim and save it to disk.
# The keyword names `size` and `iter` are kept exactly as in the original
# call; NOTE(review): newer gensim releases may spell these differently --
# confirm against the installed version.
print('開始訓練')

train_file = "/tmp/train_data"
# Use one worker thread per CPU reported by the OS.
worker_count = multiprocessing.cpu_count()

model = Word2Vec(
    LineSentence(train_file),
    size=128,
    workers=worker_count,
    iter=10,
)
print('結束')

# replace=True is passed through unchanged before the model is saved.
model.init_sims(replace=True)
model.save('/tmp/emb.bin')

3.tensorflow讀取模型可視化

import os

import numpy as np
import tensorflow as tf
from gensim.models.word2vec import Word2Vec
from tensorflow.contrib.tensorboard.plugins import projector

# Export a saved gensim word2vec model into a TensorBoard log directory:
# a checkpoint holding the embedding matrix, a metadata.tsv file with one
# word label per row, and the projector configuration tying them together.

log_dir = '/tmp/embedding_log'
# makedirs(exist_ok=True) also creates missing parents and avoids the
# check-then-create race of `if not exists: mkdir`.
os.makedirs(log_dir, exist_ok=True)

# load model
model_file = '/tmp/emb.bin'
word2vec = Word2Vec.load(model_file)

# Capture the vocabulary once so that row i of the matrix and line i of the
# metadata file are guaranteed to refer to the same word.
# NOTE(review): `word2vec.vocab` / `word2vec[word]` are kept as in the
# original; newer gensim releases may expose these under `wv` -- confirm.
words = list(word2vec.vocab.keys())

# create a matrix of vectors, one row per vocabulary word
embedding = np.empty((len(words), word2vec.vector_size), dtype=np.float32)
for i, word in enumerate(words):
    embedding[i] = word2vec[word]

# setup a TensorFlow session
tf.reset_default_graph()
sess = tf.InteractiveSession()
# The variable starts as a one-element placeholder value and is then
# overwritten with the real matrix; validate_shape=False permits the
# shape change on assignment.
X = tf.Variable([0.0], name='embedding')
place = tf.placeholder(tf.float32, shape=embedding.shape)
set_x = tf.assign(X, place, validate_shape=False)
sess.run(tf.global_variables_initializer())
sess.run(set_x, feed_dict={place: embedding})

# write labels
# Explicit UTF-8 so non-ASCII words are written correctly regardless of the
# platform's default locale encoding.
metadata_path = os.path.join(log_dir, 'metadata.tsv')
with open(metadata_path, 'w', encoding='utf-8') as f:
    f.writelines(word + '\n' for word in words)

# create a TensorFlow summary writer
summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
config = projector.ProjectorConfig()
embedding_conf = config.embeddings.add()
embedding_conf.tensor_name = 'embedding:0'
embedding_conf.metadata_path = metadata_path
projector.visualize_embeddings(summary_writer, config)

# save the model
saver = tf.train.Saver()
saver.save(sess, os.path.join(log_dir, "model.ckpt"))

# Flush pending events to disk and release the session's resources.
summary_writer.close()
sess.close()

print("完成!")
相關文章
相關標籤/搜索