step 1
用高斯分佈生成兩類點
class Point3:
    """Sample point of the negative class: label -1, plotted red ('r').

    x and y are drawn independently from a Gaussian with mean 50 and
    standard deviation 10 (x first, then y).
    """

    def __init__(self):
        self.x, self.y = random.gauss(50, 10), random.gauss(50, 10)
        self.label, self.color = -1, 'r'


class Point4:
    """Sample point of the positive class: label 1, plotted blue ('b').

    x and y are drawn independently from a Gaussian with mean 90 and
    standard deviation 10 (x first, then y).
    """

    def __init__(self):
        self.x, self.y = random.gauss(90, 10), random.gauss(90, 10)
        self.label, self.color = 1, 'b'
step 2
畫一條初始直線:先定義兩個點(x1, 0)和(x2, 100),其中x1為[0, 50]內的隨機整數,x2為[50, 100]內的隨機整數;有了這兩個點以後,就能畫出一條直線。
class Line:
    """Initial line w1*x + w2*y + b = 0 through (x1, 0) and (x2, 100).

    x1 is a random integer in [MIN, MAX//2] and x2 one in [MAX//2, MAX]
    (random.randint includes both ends). MIN and MAX are module constants.
    """

    def __init__(self):
        self.x1 = random.randint(MIN, MAX // 2)  # with MIN=0, MAX=100: [0, 50]
        self.x2 = random.randint(MAX // 2, MAX)  # with MIN=0, MAX=100: [50, 100]
        # Both ranges contain MAX//2. If both draws land on it, the line is
        # vertical and the slope below would divide by zero, so redraw x2.
        while self.x2 == self.x1:
            self.x2 = random.randint(MAX // 2, MAX)
        self.y1 = 0
        self.y2 = 100

        self.x = [self.x1, self.x2]
        self.y = [self.y1, self.y2]

        # w2 is normalized to 1, so w1 is minus the slope of the line.
        self.w1 = -(self.y2 - self.y1) / (self.x2 - self.x1)
        self.w2 = 1
        # Choose b so that (x1, y1) satisfies w1*x + w2*y + b = 0.
        self.b = -(self.w1 * self.x1 + self.w2 * self.y1)
step 3
判斷誤分類點
正確分類1:w1*x+w2*y+b>0且label=1
正確分類2:w1*x+w2*y+b<0且label=-1。兩種情況可合併為 label*(w1*x+w2*y+b)>0,不滿足此條件(≤0)的點即為誤分類點。
def sign(self, point):
    """Return point.label * (w1*x + w2*y + b) for this line.

    Despite the name, this is the signed (functional) margin, not just its
    sign. A positive value means the point lies on the side its label asks
    for; zero or negative means it is misclassified.
    """
    activation = self.w1 * point.x + self.w2 * point.y + self.b
    return activation * point.label
step 4
有了更新後的w1、w2和b以後,更新一條新的直線。
首先,需要找到直線上的兩個點。此時y1=0、y2=100不變,因此只需求出對應的x1、x2即可:由w1*x+w2*y+b=0可得 x=(-b-w2*y)/w1(要求w1≠0)。
def update(self):
    """Recompute the plotted endpoints from the current w1, w2 and b.

    The line is w1*x + w2*y + b = 0. Its crossings with y = y1 and y = y2
    become the endpoints stored in self.x1/self.x2 and self.x/self.y.
    """
    if self.w1 == 0:
        # Horizontal line: it does not cross y = y1 / y = y2 at a single x.
        # Keep the previous x span and draw it at its constant height
        # (still raises ZeroDivisionError if w2 is also 0).
        y_const = -self.b / self.w2
        self.x = [self.x1, self.x2]
        self.y = [y_const, y_const]
        return
    # Solve w1*x + w2*y + b = 0 for x at each fixed y.
    self.x1 = (-self.b - self.w2 * self.y1) / self.w1
    self.x2 = (-self.b - self.w2 * self.y2) / self.w1
    self.x = [self.x1, self.x2]
    self.y = [self.y1, self.y2]
step 5
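w1、w2和b的更新規則即對軟間隔(hinge loss)目標做隨機梯度下降:對每個滿足 label*(w1*x+w2*y+b)<1 的點(包括誤分類點,以及分類正確但落在間隔內的點),令 w1←(1-step)*w1+step*C*label*x、w2←(1-step)*w2+step*C*label*y、b←b+step*C*label。詳見博文《支持向量機》: http://www.carefree0910.com/posts/d455305a/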
def preceptron_base_dis(all_points, max_epochs=100000):
    """Fit a line to all_points by SGD on the soft-margin (hinge-loss) objective and plot it.

    The initial random Line is drawn as a green dashed line, the fitted line
    with '.-', and then the figure is shown.

    Args:
        all_points: points with x, y and label (+1/-1) attributes. It is
            iterated once per epoch, so it must be a re-iterable sequence.
        max_epochs: upper bound on passes over all_points. This guards against
            looping forever when no pass leaves every margin >= 1 (e.g. when
            the two classes overlap).

    Returns:
        The fitted Line.
    """
    line = Line()
    plt.plot(line.x, line.y, 'g--', linewidth=1)
    for _ in range(max_epochs):
        converged = True
        for point in all_points:
            # Update on every point whose margin is below 1: misclassified
            # points AND correctly classified points inside the margin.
            if line.sign(point) < 1:
                line.w1 = (1 - step) * line.w1 + step * C * point.label * point.x
                line.w2 = (1 - step) * line.w2 + step * C * point.label * point.y
                line.b = line.b + step * C * point.label
                converged = False
        if converged:
            break
    else:
        print("warning: not converged after %d epochs" % max_epochs)
    line.update()
    plt.plot(line.x, line.y, '.-', linewidth=1)
    plt.show()
    return line
所有代碼彙總
import matplotlib.pyplot as plt
import numpy
import random
import sys

MAX = 100
MIN = 0
POINT_NUM = 20
step = 0.01  # learning rate
C = 0.1      # weight of the hinge-loss term in the soft-margin update


class Point:
    """Uniform random point in [MIN, MAX]^2, labelled by the diagonal y = x.

    label 1 / blue when x > y, otherwise label -1 / red.
    """

    def __init__(self):
        self.x = random.uniform(MIN, MAX)
        self.y = random.uniform(MIN, MAX)

        if self.x > self.y:
            self.label = 1
            self.color = 'b'
        else:
            self.label = -1
            self.color = 'r'


class Point2:
    """Random integer point placed away from the diagonal, labelled by y = x.

    For x > MAX//2, y is drawn from [0, MAX//4]; otherwise from
    [MAX*2//4, MAX]. label 1 / blue when x > y, otherwise -1 / red.
    """

    def __init__(self):
        self.x = random.randint(MIN, MAX)
        if self.x > MAX // 2:
            self.y = random.randint(0, MAX // 4)
        else:
            self.y = random.randint(MAX * 2 // 4, MAX)

        if self.x > self.y:
            self.label = 1
            self.color = 'b'
        else:
            self.label = -1
            self.color = 'r'


class Point3:
    """Negative-class point (label -1, red): x, y ~ Gaussian(mean 50, std 10)."""

    def __init__(self):
        self.x = random.gauss(50, 10)
        self.y = random.gauss(50, 10)

        self.label = -1
        self.color = 'r'


class Point4:
    """Positive-class point (label 1, blue): x, y ~ Gaussian(mean 90, std 10)."""

    def __init__(self):
        self.x = random.gauss(90, 10)
        self.y = random.gauss(90, 10)

        self.label = 1
        self.color = 'b'


class Line:
    """Separating line w1*x + w2*y + b = 0, drawn between y = y1 and y = y2.

    It starts through (x1, 0) and (x2, 100), with x1 a random integer in
    [MIN, MAX//2] and x2 one in [MAX//2, MAX] (randint includes both ends).
    """

    def __init__(self):
        self.x1 = random.randint(MIN, MAX // 2)  # MAX=100 MIN=0: [0, 50]
        self.x2 = random.randint(MAX // 2, MAX)  # MAX=100 MIN=0: [50, 100]
        # Both ranges contain MAX//2. If both draws land on it, the line is
        # vertical and the slope below would divide by zero, so redraw x2.
        while self.x2 == self.x1:
            self.x2 = random.randint(MAX // 2, MAX)
        self.y1 = 0
        self.y2 = 100

        self.x = [self.x1, self.x2]
        self.y = [self.y1, self.y2]

        # w2 is normalized to 1, so w1 is minus the slope of the line.
        self.w1 = -(self.y2 - self.y1) / (self.x2 - self.x1)
        self.w2 = 1
        # Choose b so that (x1, y1) satisfies w1*x + w2*y + b = 0.
        self.b = -(self.w1 * self.x1 + self.w2 * self.y1)

    def sign(self, point):
        """Return the functional margin point.label * (w1*x + w2*y + b).

        Positive means correctly classified; zero or negative means misclassified.
        """
        return point.label * (self.w1 * point.x + self.w2 * point.y + self.b)

    def update(self):
        """Recompute the plotted endpoints from the current w1, w2 and b."""
        if self.w1 == 0:
            # Horizontal line: keep the previous x span and draw it at its
            # constant height (raises ZeroDivisionError if w2 is also 0).
            y_const = -self.b / self.w2
            self.x = [self.x1, self.x2]
            self.y = [y_const, y_const]
            return
        # Solve w1*x + w2*y + b = 0 for x at each fixed y.
        self.x1 = (-self.b - self.w2 * self.y1) / self.w1
        self.x2 = (-self.b - self.w2 * self.y2) / self.w1
        self.x = [self.x1, self.x2]
        self.y = [self.y1, self.y2]


def initialPoint():
    """Open a figure, sample and plot POINT_NUM Point3 then POINT_NUM Point4 points.

    Returns:
        List of the 2 * POINT_NUM sampled points, Point3 instances first.
    """
    plt.figure()
    all_point = []
    for point_cls in (Point3, Point4):
        for _ in range(POINT_NUM):
            p = point_cls()
            plt.plot(p.x, p.y, p.color + 'o', label="point")
            all_point.append(p)
    return all_point


def preceptron_base_dis(all_points, max_epochs=100000):
    """Fit a line by SGD on the soft-margin (hinge-loss) objective, plot it, and return it.

    Each epoch updates on every point with margin < 1 (misclassified OR inside
    the margin) and stops after a pass with no update. max_epochs bounds the
    loop so overlapping classes cannot make it spin forever.
    """
    line = Line()
    plt.plot(line.x, line.y, 'g--', linewidth=1)
    for _ in range(max_epochs):
        converged = True
        for point in all_points:
            if line.sign(point) < 1:
                line.w1 = (1 - step) * line.w1 + step * C * point.label * point.x
                line.w2 = (1 - step) * line.w2 + step * C * point.label * point.y
                line.b = line.b + step * C * point.label
                converged = False
        if converged:
            break
    else:
        print("warning: not converged after %d epochs" % max_epochs, file=sys.stderr)
    line.update()
    plt.plot(line.x, line.y, '.-', linewidth=1)
    plt.show()
    return line


def preceptron(all_points, max_epochs=100000):
    """Fit a line with the classic perceptron rule, plot it, and return it.

    Each epoch updates on every misclassified point (margin <= 0) and stops
    after a pass with no update. max_epochs bounds the loop because the
    perceptron never converges on non-separable data.
    """
    line = Line()
    plt.plot(line.x, line.y, 'g--', linewidth=1)
    for _ in range(max_epochs):
        converged = True
        for point in all_points:
            if line.sign(point) <= 0:
                line.w1 += step * point.label * point.x
                line.w2 += step * point.label * point.y
                line.b += step * point.label
                converged = False
        if converged:
            break
    else:
        print("warning: not converged after %d epochs" % max_epochs, file=sys.stderr)
    line.update()
    plt.plot(line.x, line.y, 'o-', linewidth=1)
    plt.show()
    return line


if __name__ == "__main__":
    all_points = initialPoint()
    preceptron_base_dis(all_points)