Python基於dtw實現股票預測【多線程】

# -*- coding: utf-8 -*-
"""
Created on Tue Dec 4 08:53:08 2018

@author: zhen
"""
import threading
import time
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from dtw import fastdtw

 15 def normalization(x): # np.std:計算矩陣的標準差(方差的算術平方根)
 16     return (x - np.mean(x)) / np.std(x)  17 
 18 def corrcoef(a,b):  19     corrc = np.corrcoef(a,b) # 計算皮爾遜相關係數,用於度量兩個變量之間的相關性,其值介於-1到1之間
 20     corrc = corrc[0,1]  21     return (16 * ((1 - corrc) / (1 + corrc)) ** 1) # ** 表示乘方
 22         
 23 print("begin Main Thread")  24 startTimeStamp = datetime.now() # 獲取當前時間
 25 # 加載數據
 26 filename = 'C:/Users/zhen/.spyder-py3/sh000300_2017.csv'
 27 # 獲取第一,二列的數據
 28 all_date = pd.read_csv(filename,usecols=[0, 1], dtype = 'str')  29 all_date = np.array(all_date)  30 data = all_date[:, 0]  31 times = all_date[:, 1]  32 
 33 data_points = pd.read_csv(filename,usecols=[3])  34 data_points = np.array(data_points)  35 data_points = data_points[:,0] #數據
 36 
 37 topk = 10 #只顯示top-10
 38 baselen = 100 # 假設在50到150之間變化
 39 basebegin = 361
 40 basedata = data[basebegin]+' '+times[basebegin]+'~'+data[basebegin+baselen-1]+' '+times[basebegin+baselen-1]  41 length = len(data_points) #數據長度
 42 
 43 # 定義自定義線程類
 44 class Thread_Local(threading.Thread):  45     def __init__(self, thread_id, name, counter):  46         threading.Thread.__init__(self)  47         self.thread_id = thread_id  48         self.name = name  49         self.counter = counter  50         self.__running = threading.Event() # 標識中止線程
 51         self.__running.set() # 設置爲True
 52         
 53     def run(self):  54         print("starting %s" % self.name)  55         split_data(self, self.counter) # 執行代碼邏輯
 56         
 57     def stop(self):  58         self.__running.clear()  59         
 60 # 分割片斷並執行匹配,多線程
 61 def split_data(self, split_len):  62     base = data_points[basebegin:basebegin+baselen]  # 獲取初始要匹配的數據
 63     subseries = []  64     dateseries = []  65     for j in range(0, length):  66         if (j < (basebegin - split_len) or j > (basebegin + split_len - 1)) and j <length - split_len:  67             subseries.append(data_points[j:j+split_len])  68             dateseries.append(j) #開始位置
 69     search(self, subseries, base, dateseries)  # 調用模式匹配
 70 
 71 # 定義結果變量
 72 result = []  73 base_list = []  74 date_list = []  75 def search(self, subseries, base, dateseries):  76      # 片斷搜索
 77     listdistance = []  78     for i in range(0, len(subseries)):  79         tt = np.array(subseries[i])  80         dist, cost, acc, path = fastdtw(base, tt, dist='euclidean')  81  listdistance.append(dist)  82         # distance = corrcoef(base, tt)
 83         # listdistance.append(distance)
 84     # 排序
 85     index = np.argsort(listdistance, kind='quicksort') #排序,返回排序後的索引序列
 86  result.append(subseries[index[0]])  87     print("result length is %d" % len(result))  88  base_list.append(base)  89  date_list.append(dateseries[index[0]])  90     # 關閉線程
 91  self.stop()  92     
 93 # 變換數據(收縮或擴展),生成50到150之間的數據,間隔爲10
 94 loc = 0  95 for split_len in range(round(0.5 * baselen), round(1.5 * baselen), 10):  96     # 執行匹配
 97    thread = Thread_Local(1, "Thread" + str(loc), split_len)  98    loc += 1
 99    # 開啓線程
100  thread.start() 101 
102 boo = 1
103 
104 while(boo > 0): 105     if(len(result) < 10): 106         if(boo % 100 == 0): 107             print("has running %d s" % boo) 108         boo += 1
109         time.sleep(1) 110     else: 111         boo = 0 112         
# Re-score each worker's best match against its query window.
listdistance = []
for query, candidate in zip(base_list, result):
    tt = np.array(candidate)
    dist, cost, acc, path = fastdtw(query, tt, dist='euclidean')
    listdistance.append(dist)

# Final ranking: indices of the matches ordered by increasing distance.
index = np.argsort(listdistance, kind='quicksort')
print("closed Main Thread")
endTimeStamp = datetime.now()

# Compare the query with the overall best match on one figure.
plt.figure(0)
plt.plot(normalization(base_list[index[0]]), label=basedata, linewidth=2)
# NOTE: this rebinds the module-level `length` to the best match's length.
length = len(result[index[0]])
begin = data[date_list[index[0]]] + ' ' + times[date_list[index[0]]]
end = (
    data[date_list[index[0]] + length - 1]
    + ' '
    + times[date_list[index[0]] + length - 1]
)
label = begin + '~' + end
plt.plot(normalization(result[index[0]]), label=label, linewidth=2)
plt.legend(loc='lower right')
plt.title('normal similarity search')
plt.show()
# total_seconds() includes whole days; timedelta.seconds alone drops them.
print('run time', int((endTimeStamp - startTimeStamp).total_seconds()), "s")

結果:多線程

has running 100 s has running 200 s has running 300 s has running 400 s has running 500 s has running 600 s has running 700 s has running 800 s has running 900 s has running 1000 s has running 1100 s has running 1200 s has running 1300 s has running 1400 s has running 1500 s has running 1600 s has running 1700 s has running 1800 s has running 1900 s has running 2000 s has running 2100 s has running 2200 s has running 2300 s has running 2400 s has running 2500 s has running 2600 s has running 2700 s has running 2800 s has running 2900 s has running 3000 s has running 3100 s has running 3200 s has running 3300 s has running 3400 s has running 3500 s has running 3600 s has running 3700 s has running 3800 s has running 3900 s has running 4000 s has running 4100 s has running 4200 s has running 4300 s has running 4400 s has running 4500 s has running 4600 s has running 4700 s has running 4800 s has running 4900 s has running 5000 s has running 5100 s has running 5200 s has running 5300 s has running 5400 s has running 5500 s has running 5600 s has running 5700 s has running 5800 s has running 5900 s has running 6000 s has running 6100 s has running 6200 s has running 6300 s has running 6400 s has running 6500 s has running 6600 s has running 6700 s has running 6800 s has running 6900 s has running 7000 s has running 7100 s has running 7200 s has running 7300 s has running 7400 s has running 7500 s has running 7600 s has running 7700 s has running 7800 s has running 7900 s has running 8000 s has running 8100 s has running 8200 s has running 8300 s has running 8400 s has running 8500 s has running 8600 s has running 8700 s has running 8800 s has running 8900 s has running 9000 s has running 9100 s has running 9200 s has running 9300 s has running 9400 s has running 9500 s has running 9600 s has running 9700 s has running 9800 s has running 9900 s has running 10000 s has running 10100 s has running 10200 s has running 10300 s has running 10400 s has running 10500 s has 
running 10600 s has running 10700 s has running 10800 s has running 10900 s has running 11000 s has running 11100 s has running 11200 s has running 11300 s has running 11400 s has running 11500 s has running 11600 s has running 11700 s has running 11800 s has running 11900 s has running 12000 s has running 12100 s has running 12200 s has running 12300 s has running 12400 s has running 12500 s has running 12600 s has running 12700 s has running 12800 s has running 12900 s has running 13000 s has running 13100 s has running 13200 s has running 13300 s has running 13400 s has running 13500 s has running 13600 s has running 13700 s has running 13800 s has running 13900 s has running 14000 s has running 14100 s has running 14200 s has running 14300 s has running 14400 s result length is 1 result length is 2 has running 14500 s has running 14600 s has running 14700 s has running 14800 s result length is 3 has running 14900 s has running 15000 s result length is 4 has running 15100 s has running 15200 s has running 15300 s has running 15400 s result length is 5 has running 15500 s has running 15600 s has running 15700 s has running 15800 s has running 15900 s has running 16000 s has running 16100 s has running 16200 s result length is 6 has running 16300 s has running 16400 s has running 16500 s has running 16600 s result length is 7 result length is 8 has running 16700 s result length is 9 result length is 10 closed Main Thread
View Code

相關文章
相關標籤/搜索