TensorFlow

 


Google has released TensorFlow, an open-source deep learning tool.

 

I tried it out by following the official tutorial: http://tensorflow.org/tutorials/mnist/beginners/index.md

 

The operating system is 64-bit Ubuntu 14.04 with Python 2.7, and the necessary Python packages are already installed.

 

 

 

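1. Installation

    1.1 Reference documentation: http://tensorflow.org/get_started/os_setup.md#binary_installation
    
    1.2 Install with pip. A proxy is required, otherwise the download cannot connect; the proxy used here is a local SSH tunnel out through a VPS.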

    sudo pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.5.0-cp27-none-linux_x86_64.whl --proxy http://127.0.0.1:3128

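    1.3 Note: my Python 2.7 setup already had the required packages installed, such as python-dev, numpy and swig. If you hit errors about missing packages, install the required ones first.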
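
As a rough way to spot missing prerequisites before running pip, the sketch below tries to import a list of modules and reports the ones that fail. The helper name missing_modules is made up for illustration, and only importable Python modules such as numpy can be checked this way; python-dev and swig are system packages and need a separate check.

```python
import importlib


def missing_modules(names):
    """Return the names from `names` that cannot be imported (hypothetical helper)."""
    missing = []
    for name in names:
        try:
            importlib.import_module(name)
        except ImportError:
            missing.append(name)
    return missing


if __name__ == "__main__":
    # An empty list means every listed module was importable.
    print(missing_modules(["numpy"]))
```

Any name that comes back in the list can then be installed before retrying the TensorFlow install.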

2. First demo, test.py
------------------------------
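import tensorflow as tf

# Build a constant string op and evaluate it in a session
hello = tf.constant('Hello, TensorFlow!')
sess = tf.Session()
print(sess.run(hello))

# Add two constant integers; prints 42
a = tf.constant(10)
b = tf.constant(32)
print(sess.run(a + b))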

------------------------------


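3. MNIST handwritten digit recognition
    3.1 Download the dataset
    Download the four .gz files listed at http://yann.lecun.com/exdb/mnist/ and put them in a local directory such as /tmp/mnist

    3.2 Download input_data.py and put it in /home/tim/test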
    https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/tutorials/mnist/input_data.py
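
    Before moving on, it can help to confirm that the four archives are really in place. The sketch below lists any that are missing from a directory. The helper name missing_files is made up for illustration, and the file names are the standard MNIST archive names, assumed to match what was downloaded in 3.1.

```python
import os

# Standard names of the four MNIST archives (assumed to match the downloaded files).
MNIST_FILES = [
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
]


def missing_files(directory, names=MNIST_FILES):
    """Return the file names from `names` that are not present in `directory`."""
    return [n for n in names
            if not os.path.isfile(os.path.join(directory, n))]


if __name__ == "__main__":
    # An empty list means all four archives were found.
    print(missing_files("/tmp/mnist"))
```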

    3.3 In /home/tim/test, create the file test_tensor_flow_mnist.py with the following content
-----------------------
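#!/usr/bin/env python

import input_data
import tensorflow as tf

# Load MNIST from /tmp/mnist with one-hot encoded labels
mnist = input_data.read_data_sets("/tmp/mnist", one_hot=True)

# Input: flattened 28x28 images (784 pixels); None allows any batch size
x = tf.placeholder("float", [None, 784])
# Model parameters: weights and biases, initialised to zero
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
# Softmax regression: predicted probabilities for the 10 digit classes
y = tf.nn.softmax(tf.matmul(x, W) + b)
# Placeholder for the true one-hot labels
y_ = tf.placeholder("float", [None, 10])
# Cross-entropy loss, summed over the batch
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

# Train for 1000 steps on mini-batches of 100 examples
for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

# Accuracy: fraction of test images whose most probable class matches the label
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))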
-----------------------

3.4 Run it. It takes only a few seconds, and the reported accuracy is about 91%.
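
To make the quantities in the script more concrete, here is a minimal pure-Python sketch of softmax, the cross-entropy term -sum(y_ * log(y)), and the argmax comparison used for accuracy, all for a single example. It is an illustration only, not TensorFlow's implementation. The function names are made up, and terms with a zero label are skipped, which gives the same sum whenever the predicted probabilities are non-zero.

```python
import math


def softmax(logits):
    """Turn a list of scores into probabilities that sum to 1."""
    shift = max(logits)  # subtracting the max avoids overflow in exp
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(y_true, y_pred):
    """-sum(y_ * log(y)) for one example; zero-label terms are skipped."""
    return -sum(t * math.log(p) for t, p in zip(y_true, y_pred) if t > 0)


def is_correct(y_true, y_pred):
    """True when the most probable class matches the one-hot label."""
    return y_pred.index(max(y_pred)) == y_true.index(max(y_true))


# With W and b at zero every score is 0, so each of the 10 classes gets 0.1.
probs = softmax([0.0] * 10)
label = [0.0] * 10
label[3] = 1.0
loss = cross_entropy(label, probs)  # equals log(10), about 2.30
print(loss)
```

This is the loss a single example contributes before any training step; gradient descent then adjusts W and b to lower it.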

 

 

4. About versions

4.1 pip version


pip 1.5.4 from /usr/lib/python2.7/dist-packages (python 2.7)


4.2 Installed Python packages

    Some were installed with easy_install; most were installed with pip.

pip freeze


Jinja2==2.7.2
MarkupSafe==0.18
MySQL-python==1.2.3
PAM==0.4.2
Pillow==2.3.0
Twisted-Core==13.2.0
Twisted-Web==13.2.0
adium-theme-ubuntu==0.3.4
apt-xapian-index==0.45
argparse==1.2.1
beautifulsoup4==4.2.1
chardet==2.0.1
colorama==0.2.5
command-not-found==0.3
cvxopt==1.1.4
debtagshw==0.1
decorator==3.4.0
defer==1.0.6
dirspec==13.10
duplicity==0.6.23
fp-growth==0.1.2
html5lib==0.999
httplib2==0.8
ipython==1.2.1
joblib==0.7.1
lockfile==0.8
lxml==3.3.3
matplotlib==1.4.3
nose==1.3.1
numexpr==2.2.2
numpy==1.9.2
oauthlib==0.6.1
oneconf==0.3.7
openpyxl==1.7.0
pandas==0.13.1
patsy==0.2.1
pexpect==3.1
piston-mini-client==0.7.5
pyOpenSSL==0.13
pycrypto==2.6.1
pycups==1.9.66
pycurl==7.19.3
pygobject==3.12.0
pygraphviz==1.2
pyparsing==2.0.3
pyserial==2.6
pysmbc==1.0.14.1
python-apt==0.9.3.5
python-dateutil==2.4.2
python-debian==0.1.21-nmu2ubuntu2
pytz==2012c
pyxdg==0.25
pyzmq==14.0.1
reportlab==3.0
requests==2.2.1
scipy==0.13.3
sessioninstaller==0.0.0
simplegeneric==0.8.1
simplejson==3.3.1
six==1.10.0
software-center-aptd-plugins==0.0.0
ssh-import-id==3.21
statsmodels==0.5.0
sympy==0.7.4.1
system-service==0.1.6
tables==3.1.1
tensorflow==0.5.0
tornado==3.1.1
unity-lens-photos==1.0
urllib3==1.7.1
vboxapi==1.0
wheel==0.24.0
wsgiref==0.1.2
xdiagnose==3.6.3build2
xlrd==0.9.2
xlwt==0.7.5
zope.interface==4.0.5
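
To compare another machine against this list, the pip freeze output can be parsed into a dictionary. The sketch below is one possible way to do that. The helper name parse_freeze is made up for illustration, and the sample text is just three lines taken from the list above.

```python
def parse_freeze(text):
    """Map package name -> version for each 'name==version' line."""
    versions = {}
    for line in text.splitlines():
        line = line.strip()
        if "==" not in line:
            continue  # skip blank lines and anything that is not name==version
        name, version = line.split("==", 1)
        versions[name] = version
    return versions


sample = """numpy==1.9.2
six==1.10.0
tensorflow==0.5.0
"""
installed = parse_freeze(sample)
print(installed["tensorflow"])
```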
