Kaggle machine learning competitions are among the world's best-known AI contests, and each competition attracts large numbers of AI enthusiasts.
Here we use the TGS Salt Identification Challenge, held in late 2018, as the example: https://www.kaggle.com/c/tgs-salt-identification-challenge
1. The data
The data can be downloaded from the Kaggle site, but registration is required first and the download may be slow. Alternatively, download it directly from my Baidu Netdisk:
Link: https://pan.baidu.com/s/1htvnrwQagOXHXfjpaGedPQ
Extraction code: a0zx
2. UNet++ open-source model code
UNet++ is a network architecture proposed in 2018. It refines U-Net and performs very well on image segmentation tasks. The source code used here is from: https://github.com/MrGiovanni/UNetPlusPlus
3. Data processing and preparation
Import the packages:
import os
import random
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from skimage.transform import resize
from UNetPlusPlus_master.segmentation_models import Xnet
from keras.preprocessing.image import load_img
from keras.optimizers import *
from keras.callbacks import EarlyStopping, ModelCheckpoint
Data paths and image sizes:
root = r'E:\Kaggle\salt\competition_data'
model_path = root + '/model'
imgs_path = root + r'\train'
test_imgs_path = root + r'\test'
train_csv = root + r'\train.csv'
depths_csv = root + r'\depths.csv'

# Images come as 101x101; the network is trained at 224x224.
orig_img_w = 101
orig_img_h = 101
train_img_w = 224
train_img_h = 224
Conversion between the original size and the training size:
def orig2tain(img):
    return resize(img, (train_img_w, train_img_h), mode='constant', preserve_range=True)

def train2orig(img):
    return resize(img, (orig_img_w, orig_img_h), mode='constant', preserve_range=True)
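As a sanity check on what `preserve_range=True` does here, the following small sketch (using a synthetic constant image rather than the competition data) upsamples a 101×101 array to 224×224 and confirms the 0–255 pixel scale is kept, instead of being rescaled to [0, 1] as `resize` would otherwise do for integer input:

```python
import numpy as np
from skimage.transform import resize

# A fake 101x101 "image" with 8-bit values, like one channel of a raw PNG.
orig = np.full((101, 101), 200, dtype=np.uint8)

# Same call as orig2tain above: preserve_range=True keeps the 0..255 scale.
up = resize(orig, (224, 224), mode='constant', preserve_range=True)

print(up.shape)  # (224, 224)
print(up.max())  # still on the 0..255 scale, not rescaled to [0, 1]
```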
Load the data:
train_df = pd.read_csv(train_csv, usecols=[0], index_col='id')
train_df["images"] = [np.array(load_img("{}/images/{}.png".format(imgs_path, idx), grayscale=False)) / 255
                      for idx in train_df.index]
train_df["masks"] = [np.array(load_img("{}/masks/{}.png".format(imgs_path, idx), grayscale=True)) / 255
                     for idx in train_df.index]
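Dividing by 255 maps the 8-bit pixel values into [0, 1], so the masks become essentially binary (0 = background, 1 = salt), which is what the `binary_crossentropy` loss used later expects. A tiny sketch with a made-up mask patch:

```python
import numpy as np

# A fake 8-bit mask patch: 0 = background, 255 = salt.
raw_mask = np.array([[0, 255], [255, 0]], dtype=np.uint8)

# Same normalisation as above: scale 0..255 into 0..1 floats.
norm = raw_mask / 255

print(norm)  # values are exactly 0.0 and 1.0
```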
Display the loaded images and masks:
max_images = 10
grid_width = 10
grid_height = int(max_images / grid_width) + 1
fig, axs = plt.subplots(grid_height, grid_width, figsize=(20, 4))
for i, idx in enumerate(train_df.index[:max_images]):
    img = train_df.loc[idx].images
    mask = train_df.loc[idx].masks
    ax = axs[int(i / grid_width), i % grid_width]
    ax.imshow(img, cmap="Greys")
    ax = axs[int(i / grid_width) + 1, i % grid_width]
    ax.imshow(mask, cmap="Greens")
    ax.set_yticklabels([])
    ax.set_xticklabels([])
plt.show()
Randomly split the data 80/20 into training and validation sets:
train_ids, valid_ids, train_x, valid_x, train_y, valid_y = train_test_split(
    train_df.index.values,
    np.array(train_df.images.map(orig2tain).tolist()).reshape(-1, train_img_w, train_img_h, 3),
    np.array(train_df.masks.map(orig2tain).tolist()).reshape(-1, train_img_w, train_img_h, 1),
    test_size=0.2, random_state=123)
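`train_test_split` returns the pieces pairwise, in the order the arrays are passed, and `test_size=0.2` sends 20% of the samples to the validation side. A minimal sketch with fake data (ten synthetic 4×4 single-channel "images", not the competition set):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Ten fake ids and ten fake 4x4 single-channel "images".
ids = np.arange(10)
imgs = np.zeros((10, 4, 4, 1))

# Pairs come back in input order: ids split first, then the image array.
train_ids, valid_ids, train_x, valid_x = train_test_split(
    ids, imgs, test_size=0.2, random_state=123)

print(len(train_ids), len(valid_ids))  # 8 2
```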
4. Training
input_size = (train_img_w, train_img_h, 3)
model = Xnet(input_shape=input_size, backbone_name='resnet50',
             encoder_weights='imagenet', decoder_block_type='transpose')
model.compile(optimizer=Adam(lr=1e-4), loss='binary_crossentropy', metrics=['accuracy'])

model_name = 'Kaggle_Salt_{epoch:02d}-{val_acc:.3f}.hdf5'
abs_model_name = os.path.join(model_path, model_name)
model_checkpoint = ModelCheckpoint(abs_model_name, monitor='val_loss', verbose=2, save_best_only=True)
early_stop = EarlyStopping(monitor='val_loss', patience=6)
callbacks = [early_stop, model_checkpoint]
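The checkpoint filename is a `str.format` template that Keras fills in with the epoch number and the logged metrics. The sketch below, with made-up values, shows what a saved file ends up being called:

```python
# Same template as the ModelCheckpoint filename above.
model_name = 'Kaggle_Salt_{epoch:02d}-{val_acc:.3f}.hdf5'

# Keras substitutes the current epoch and logged metrics into the template
# (epoch and val_acc here are invented values for illustration).
example = model_name.format(epoch=7, val_acc=0.89123)
print(example)  # Kaggle_Salt_07-0.891.hdf5
```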
history = model.fit(train_x, train_y,
                    validation_data=(valid_x, valid_y),
                    epochs=100, batch_size=4, callbacks=callbacks)
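With `EarlyStopping(patience=6)`, training halts once `val_loss` has gone six consecutive epochs without improving, so `epochs=100` is only an upper bound. A minimal, framework-free sketch of that patience logic (not Keras's actual implementation):

```python
# Sketch of the patience logic: stop once the monitored value has not
# improved for `patience` consecutive epochs.
def stopped_epoch(val_losses, patience=6):
    best = float('inf')
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None  # ran to completion without triggering early stopping

# Loss improves twice, then plateaus: training halts 6 epochs after the best.
losses = [0.9, 0.8] + [0.85] * 10
print(stopped_epoch(losses))  # 7
```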
Plot the training curves:
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(len(acc))

plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()

plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()