二維特徵空間帶標籤可視化

爲了直觀地驗證網絡對特徵提取的性能,能夠使用可視化技術來可視化通過網絡以後的特徵分佈狀況。這也是目前softmax-base的人臉識別論文的常見作法。
首先就是訓練好一個網絡而後進行測試,爲了充分利用GPU的性能,測試過程當中使用mini-batch的數據進行前向傳播,並記錄特徵。html

def vis():
    '''
    對模型結果進行可視化
    '''
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Net().to(device)
    #model.load_state_dict(torch.load('mnist_cnn.pt'))
    model.eval()
    print('加載模型完畢')

    test_batch_size=64
    kwargs = {'num_workers': 0, 'pin_memory': True} if torch.cuda.is_available() else {}
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=test_batch_size, shuffle=True, **kwargs)

    data_iter=iter(test_loader)
    cnt=0
    
    fn='./res_random.csv'
    with torch.no_grad():
        while cnt<100:
            cnt+=1
            data, target = data_iter.__next__()
            data, target = data.to(device), target.to(device)
            output, features = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            if test_batch_size==1:
                target= target.cpu().numpy()[0]
                features=features.cpu().numpy()[0]
            else:
                target= target.cpu().numpy()
                features=features.cpu().numpy()
                print(features.shape,',',target.shape)
                #轉換爲1維的向量,方便後面解碼
                features=np.reshape(features,-1)
                target=np.reshape(target,-1)
            #features=np.array2string(features)
            #target=np.array2string(target)
            #features=features.tostring()
            #target=target.tostring()
            #後面就是存儲到文件中

在獲取到特徵數據後,將特徵可視化部分,以前有使用過pandas,發現挺好用的,爲了避免重複造輪子,直接在pandas的基礎上操做數據。python

fn='./res_random.csv'
    df=pd.read_csv(fn,sep='\t',names=['feat','id'])
    dic=df.set_index('feat')['id'].to_dict()
    keys=list(dic.keys())
    k=keys[0]
    k=k.strip('][')
    karr=np.fromstring(k,float,sep=' ')
    karr=np.reshape(karr,(64,2)).tolist()
    
    label=list(dic.values())
    l=label[0]
    l=l.strip('][')
    larr=np.fromstring(l,int,sep=' ').tolist()
    x, y=zip(*karr)
    x=np.array(x)
    y=np.array(y)
    group=np.array(larr)
    cdict={0:'b',1:'g',2:'r',3:'c',4:'m',5:'y',6:'k',7:'wheat',8:'tan',9:'orchid'}
    fig, ax= plt.subplots()
    total_x=[]
    total_y=[]
    total_l=[]
    for k,l in zip(keys,label):
        k=k.strip('][')
        karr=np.fromstring(k,float,sep=' ')
        karr=np.reshape(karr,(64,2)).tolist()
        l=l.strip('][')
        larr=np.fromstring(l,int,sep=' ').tolist()
        
        x, y=zip(*karr)
        x=list(x)
        y=list(y)
        total_x=total_x+x
        total_y=total_y+y
        total_l=total_l+larr
    group=np.array(total_l)

    x=np.array(total_x)
    y=np.array(total_y)
    for g in np.unique(group):
        ix= np.where(group==g)
        ax.scatter(x[ix],y[ix],c=cdict[g],label=int(g))
    ax.legend()
    plt.show()

這樣就能獲得

參考網站
不一樣標籤繪製不一樣顏色
將列表形式的xy座標分離
matplotlib顏色表

有空補一補註釋網絡

相關文章
相關標籤/搜索