NumPy ndarray 的維度和取值方式

運行下面的 Python 代碼，自行分析輸出結果（此處不作解釋）。

import numpy as np

def func(img, label, flag_multi_class=True, num_class=3, verbose=True):
    """Scale ``img`` by 1/255 and convert ``label`` into a training target.

    The previous version hard-coded ``if (1):``, which made the binary
    branch unreachable and fixed the class count at 3. Both are now
    parameters whose defaults reproduce the old behaviour (including the
    demo prints) exactly.

    Args:
        img: numeric ndarray; divided by 255 in whichever branch runs.
        label: ndarray. In multi-class mode, index 0 of its last axis is read
            as integer class ids: ``label[:, :, :, 0]`` for rank 4, otherwise
            ``label[:, :, 0]`` (so a label of rank < 3 raises ``IndexError``).
            In binary mode it is scaled by 1/255 and thresholded.
        flag_multi_class: True -> one-hot encode ``label``;
            False -> use the binary thresholding branch.
        num_class: number of one-hot channels appended as a new last axis.
        verbose: print the per-class intermediate arrays (multi-class mode
            only), as the original demo always did.

    Returns:
        Tuple ``(img, label)``. In multi-class mode ``label`` is a float array
        of shape ``selected.shape + (num_class,)``; class ids outside
        ``range(num_class)`` yield an all-zero vector. In binary mode
        ``label`` is a float 0/1 array (threshold 0.5 after scaling). If
        binary mode is requested but ``np.max(img) <= 1``, both inputs are
        returned unchanged.
    """
    if flag_multi_class:
        img = img / 255.
        # Keep only index 0 of the last axis to get a map of class ids.
        label = label[:, :, :, 0] if label.ndim == 4 else label[:, :, 0]
        new_label = np.zeros(label.shape + (num_class,))
        for i in range(num_class):
            # Boolean-mask indexing: set channel i wherever the class id is i.
            new_label[label == i, i] = 1
            if verbose:
                print('\n i = \n ', i, '\n label = \n', label, '\n label==i: \n', label == i)
                print('\n\n after-img:\n', img, '\n\n after-label:\n', new_label, '\n')
        label = new_label

    elif np.max(img) > 1:
        # Binary mask path: scale both arrays to [0, 1], then binarize at 0.5.
        img = img / 255.
        label = label / 255.
        label[label > 0.5] = 1
        label[label <= 0.5] = 0
    return (img, label)
    
# Demo image: a single 3x3 integer array; func() divides it by 255.
img = np.array(
    [
        [129, 255, 30],
        [30, 30, 99],
        [90, 123, 49],
    ]
)

# Demo label with shape (3, 3, 3). func() keeps index 0 of the LAST axis
# (label[:, :, 0]), so the class-id map it one-hot encodes is
# [[1, 0, 2], [1, 2, 2], [0, 0, 2]], not the first 3x3 slice.
label = np.array(
    [
        [[1, 2, 0], [0, 1, 2], [2, 1, 0]],
        [[1, 2, 0], [2, 1, 2], [2, 1, 0]],
        [[0, 2, 0], [0, 1, 2], [2, 1, 0]],
    ]
)

# Show the inputs, then let func() print its per-class intermediate results.
# The returned (img, label) tuple is not used by this demo.
print('\n\n before-img:\n', img, '\n\n before-label:\n', label, '\n')
func(img, label)
相關文章
相關標籤/搜索