使用 TensorFlow 搭建神經網絡

最簡單的神經網絡

 1 import tensorflow as tf
 2 import numpy as np
 3 import matplotlib.pyplot as plt
 4 
 5 date = np.linspace(1, 15, 15)# d定義日期
 6 endPrice = np.array([2511.90, 2538.26, 2510.68, 2591.66, 2732.98, 2701.69, 2701.29, 2678.67, 2726.50, 2681.50, 2739.17, 2715.07, 2823.58, 2864.90, 2919.08])
 7 beginPrice = np.array([2438.71, 2500.88, 2534.95, 2512.52, 2594.04, 2743.26, 2697.47, 2695.24, 2678.23, 2722.13, 2674.93, 2744.13, 2717.46, 2832.73, 2877.40])
 8 # print(date)
 9 plt.figure()
10 
11 for i in range(0, 15):
12     dataOne = np.zeros([2])
13     dataOne[0] = i
14     dataOne[1] = i
15     priceOne = np.zeros([2])
16     priceOne[0] = beginPrice[i]
17     priceOne[1] = endPrice[i]
18     if endPrice[i] > beginPrice[i]:
19         plt.plot(dataOne, priceOne, 'r', lw=8)
20     else:
21         plt.plot(dataOne, priceOne, 'g', lw=8)    
22 # plt.show()
23 # 歸一化處理
24 dateNormal = np.zeros([15,1])
25 PriceNormal  = np.zeros([15,1])
26 for i in range(0,15):
27     dateNormal[i] = i/14.0
28     PriceNormal[i] = endPrice[i]/3000.0
29 # print(dateNormal)
30 # print('\n')
31 # print(PriceNormal)
32 x = tf.placeholder(tf.float32, [None, 1])
33 y = tf.placeholder(tf.float32, [None, 1])
34 
35  # B 第一層
36 w1 = tf.Variable(tf.random_uniform([1, 10], 0, 1))
37 b1 = tf.Variable(tf.zeros([1, 10]))
38 wb1 = tf.matmul(x, w1) + b1
39 layer1 = tf.nn.relu(wb1)# 激勵函數
40 
41 # 第二層
42 w2 = tf.Variable(tf.random_uniform([10,1], 0, 1))
43 b2 = tf.Variable(tf.zeros([15, 1]))
44 wb2 = tf.matmul(layer1, w2) + b2
45 layer2 = tf.nn.relu(wb2)# 激勵函數
46 
47 # loss
48 loss = tf.reduce_mean(tf.square(y-layer2))
49 train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
50 with tf.Session() as sess:
51     sess.run(tf.global_variables_initializer())
52     for i in range(0, 100000):
53         sess.run(train_step, feed_dict={x:dateNormal, y:PriceNormal})
54     pred = sess.run(layer2, feed_dict={x:dateNormal})
55     predPrice = np.zeros([15, 1])
56     for i in range(0, 15):
57         predPrice[i] = (pred * 3000)[i]
58     plt.plot(date, predPrice, 'b', lw=2)
59 plt.show()

 

相關文章
相關標籤/搜索