批量顯示一些圖像（Fashion-MNIST 分類）

def show_imags(n_rows, n_cols, x_data, y_data, class_names):
    """Display the first ``n_rows * n_cols`` images of a dataset in a labelled grid.

    Images are laid out row by row, each drawn with a binary colormap and
    titled with the class name of its label, then the figure is shown.

    Args:
        n_rows: Number of grid rows.
        n_cols: Number of grid columns.
        x_data: Indexable collection of images; each element is passed to
            ``plt.imshow`` unchanged (presumably 2-D grayscale arrays such as
            Fashion-MNIST's — confirm against caller).
        y_data: Integer class labels, parallel to ``x_data``; each is used as
            an index into ``class_names``.
        class_names: Sequence mapping a label integer to its display name.

    Raises:
        ValueError: If ``x_data`` and ``y_data`` differ in length, or if the
            grid has more cells than there are images.
    """
    # Explicit exceptions instead of `assert`, which is stripped under `python -O`.
    if len(x_data) != len(y_data):
        raise ValueError(
            f"x_data and y_data must have the same length "
            f"({len(x_data)} != {len(y_data)})"
        )
    n_cells = n_rows * n_cols
    # A grid with exactly len(x_data) cells is valid; the previous strict `<`
    # check rejected that case (off-by-one).
    if n_cells > len(x_data):
        raise ValueError(
            f"grid of {n_rows}x{n_cols}={n_cells} cells needs at least "
            f"{n_cells} images, got {len(x_data)}"
        )

    # Figure size scales with the grid; per-row height (1.6) is a bit larger
    # than per-column width (1.4), presumably to leave room for each title.
    plt.figure(figsize=(n_cols * 1.4, n_rows * 1.6))
    for index in range(n_cells):
        # subplot positions are 1-based and fill row by row, so flat index
        # `index` corresponds to row index // n_cols, column index % n_cols.
        plt.subplot(n_rows, n_cols, index + 1)
        plt.imshow(x_data[index], cmap="binary", interpolation="nearest")
        plt.axis("off")  # hide axis ticks and labels
        plt.title(class_names[y_data[index]])
    plt.show()


class_names = ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt',
               'Sneaker', 'Bag', 'Ankle boot']
# x_train / y_train are expected to be defined earlier in the script
# (presumably the Fashion-MNIST training split — not visible here).
show_imags(3, 5, x_train, y_train, class_names)
相關文章
相關標籤/搜索