爲了直觀地驗證網絡對特徵提取的性能,能夠使用可視化技術來可視化通過網絡以後的特徵分佈狀況。這也是目前softmax-base的人臉識別論文的常見作法。
首先就是訓練好一個網絡而後進行測試,爲了充分利用GPU的性能,測試過程當中使用mini-batch的數據進行前向傳播,並記錄特徵。html
def vis(): ''' 對模型結果進行可視化 ''' device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = Net().to(device) #model.load_state_dict(torch.load('mnist_cnn.pt')) model.eval() print('加載模型完畢') test_batch_size=64 kwargs = {'num_workers': 0, 'pin_memory': True} if torch.cuda.is_available() else {} test_loader = torch.utils.data.DataLoader( datasets.MNIST('../data', train=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])), batch_size=test_batch_size, shuffle=True, **kwargs) data_iter=iter(test_loader) cnt=0 fn='./res_random.csv' with torch.no_grad(): while cnt<100: cnt+=1 data, target = data_iter.__next__() data, target = data.to(device), target.to(device) output, features = model(data) pred = output.argmax(dim=1, keepdim=True) if test_batch_size==1: target= target.cpu().numpy()[0] features=features.cpu().numpy()[0] else: target= target.cpu().numpy() features=features.cpu().numpy() print(features.shape,',',target.shape) #轉換爲1維的向量,方便後面解碼 features=np.reshape(features,-1) target=np.reshape(target,-1) #features=np.array2string(features) #target=np.array2string(target) #features=features.tostring() #target=target.tostring() #後面就是存儲到文件中
在獲取到特徵數據後,將特徵可視化部分,以前有使用過pandas,發現挺好用的,爲了避免重複造輪子,直接在pandas的基礎上操做數據。python
fn='./res_random.csv' df=pd.read_csv(fn,sep='\t',names=['feat','id']) dic=df.set_index('feat')['id'].to_dict() keys=list(dic.keys()) k=keys[0] k=k.strip('][') karr=np.fromstring(k,float,sep=' ') karr=np.reshape(karr,(64,2)).tolist() label=list(dic.values()) l=label[0] l=l.strip('][') larr=np.fromstring(l,int,sep=' ').tolist() x, y=zip(*karr) x=np.array(x) y=np.array(y) group=np.array(larr) cdict={0:'b',1:'g',2:'r',3:'c',4:'m',5:'y',6:'k',7:'wheat',8:'tan',9:'orchid'} fig, ax= plt.subplots() total_x=[] total_y=[] total_l=[] for k,l in zip(keys,label): k=k.strip('][') karr=np.fromstring(k,float,sep=' ') karr=np.reshape(karr,(64,2)).tolist() l=l.strip('][') larr=np.fromstring(l,int,sep=' ').tolist() x, y=zip(*karr) x=list(x) y=list(y) total_x=total_x+x total_y=total_y+y total_l=total_l+larr group=np.array(total_l) x=np.array(total_x) y=np.array(total_y) for g in np.unique(group): ix= np.where(group==g) ax.scatter(x[ix],y[ix],c=cdict[g],label=int(g)) ax.legend() plt.show()
有空補一補註釋網絡