Python 多線程稀疏矩陣乘法

 1 import threading, time
 2 import numpy as np
 3 res = []
 4 class MyThread(threading.Thread):
 5     def __init__(self,i,j,m1,m2):
 6         threading.Thread.__init__(self)
 7         self.x, self.y = i,j
 8         self.m1, self.m2 = m1, m2
 9     def run(self):
10         global res, lock
11         if lock.acquire():
12             m1 = self.m1[self.m1[:,0]==self.x]
13             m2 = self.m2[self.m2[:,1]==self.y]
14             value = 0.
15             for item1 in m1:
16                 for item2 in m2:
17                     if item1[1] == item2[0]:
18                         value += item1[2]*item2[2]
19             res.append([self.x,self.y,value])
20             lock.release()
if __name__ == "__main__":
    # Sparse matrices in a simple list format: the first row is the shape
    # [n_rows, n_cols]; every following row is a nonzero entry [row, col, value].
    m1 = [[2, 2], [0, 0, 1], [0, 1, 2], [1, 0, 3], [1, 1, 4]]
    m2 = [[2, 3], [0, 0, 2], [0, 2, 1], [1, 2, 3], [1, 1, 4]]
    s1, s2 = m1[0], m2[0]
    # Inner dimensions must agree. Raise instead of ``assert`` so the check
    # still runs under ``python -O``.
    if s1[1] != s2[0]:
        raise ValueError('mismatch')
    # reshape(-1, 3) keeps a matrix with no entries 2-D (shape (0, 3)) so the
    # workers' ``[:, 0]`` / ``[:, 1]`` column selection still works on it.
    m1_value = np.array(m1[1:]).reshape(-1, 3)
    m2_value = np.array(m2[1:]).reshape(-1, 3)
    rows, cols = s1[0], s2[1]
    res.append([rows, cols])
    # The worker threads look this lock up as a global when updating ``res``.
    lock = threading.Lock()
    # One worker thread per output cell (i, j).
    threads = [MyThread(i, j, m1_value, m2_value)
               for i in range(rows) for j in range(cols)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    # Threads update ``res`` in arbitrary order; sort the entries after the
    # header by (row, col) so the printed result is deterministic.
    res[1:] = sorted(res[1:], key=lambda entry: (entry[0], entry[1]))
    print(res)
相關文章
相關標籤/搜索