在各類產品的廣告或宣傳營銷材料中,咱們常常會看到一些由模特展現產品使用場景的精美大圖在千方百計吸引咱們產生購買慾。然而在從事過相關工做後就會知道,這些東西準備起來有多麻煩。採用的圖片必須具有適當受權,而若是圖片中包含人物肖像,相關的受權工做就更顯得麻煩,甚至不一樣國家和地區對此都有着不一樣的要求。python
先看看下面這張女孩照片:git
很漂亮對吧!並且拍攝質量也挺高的,細節豐富,色彩逼真。不過真相呢?這個女孩她並不存在!這只是由一個機器學習模型創造出來的虛擬人物(圖片取自維基百科 GAN 條目)。github
生成對抗網絡(Generative Adversarial Networks,GAN)是一種生成式機器學習模型,它已經被普遍應用於廣告、遊戲、娛樂、媒體、製藥等行業,能夠用來創造虛構的人物、場景,模擬人臉老化和圖像風格變換,甚至用於產生化學分子式等。算法
下面的兩張圖片,就分別展現了圖片到圖片轉換的效果,以及基於語義佈局合成景物的效果:安全
下文將引領你們從工程實踐角度出發,藉助AWS機器學習相關雲計算服務,基於 PyTorch 機器學習框架,構建一個生成對抗網絡,並藉此開啓全新、有趣的機器學習和人工智能體驗。網絡
首先一塊兒看看下面顯示的兩組手寫體數字圖片,你是否能從中辨認出哪一組是真人手寫,哪一組又是由計算機生成的?session
本文的課題是用機器學習方法「模仿手寫字體」。爲了完成這個課題,咱們將親手體驗生成對抗網絡的設計和實現。模仿手寫字體與人像生成的基本原理和工程流程基本是一致的,雖然它們的複雜性和精度要求有必定差距,但經過解決模仿手寫字體問題,能夠爲生成對抗網絡的原理和工程實踐打下基礎,進而逐步嘗試和探索更加複雜先進的網絡架構和應用場景。架構
生成對抗網絡(GAN)由 Ian Goodfellow 等人在2014年提出,它是一種深度神經網絡架構,由一個生成網絡和一個判別網絡組成。生成網絡產生「假」數據並試圖欺騙判別網絡;判別網絡對所生成數據進行真僞鑑別,試圖正確識別全部「假」數據。在訓練迭代過程當中,兩個網絡將持續進化和對抗,直到達到平衡狀態(參考:納什均衡),判別網絡沒法再識別「假」數據,訓練結束。app
2016年,Alec Radford 等人發表的論文《深度卷積生成對抗網絡》(DCGAN)中,開創性地將卷積神經網絡應用到生成對抗網絡的模型算法設計當中,替代了全連接層,提升了圖片場景裏訓練的穩定性。框架
Amazon SageMaker 是 AWS 徹底託管的機器學習服務,數據處理和機器學習訓練工做能夠經過 Amazon SageMaker 快速、輕鬆地完成,訓練好的模型能夠直接部署到全託管的生產環境中。
Amazon SageMaker 提供了託管的 Jupyter Notebook 實例,經過 SageMaker SDK 與 AWS 的多種雲服務集成,方便您訪問數據源,進行探索和分析。SageMaker SDK 是一套開源的 Amazon SageMaker 的開發包,能夠協助咱們更好地使用 Amazon SageMaker 提供的託管容器鏡像,以及 AWS 的其餘雲服務,如計算和存儲資源。
如上圖所示,訓練用數據未來自 Amazon S3 的存儲桶;訓練用的框架和託管算法以容器鏡像的形式提供服務,在訓練時與代碼結合;模型代碼運行在 Amazon SageMaker 託管的計算實例中,在訓練時與數據結合;訓練輸出物將進入 Amazon S3 專門的存儲桶裏。後面的講解中,咱們會了解到如何經過 SageMaker SDK 使用這些資源。
下文的操做將用到 Amazon SageMaker、Amazon S三、Amazon EC2等AWS服務,會產生必定的雲資源使用費用。
打開 Amazon SageMaker 儀表板(點擊打開北京區域 | 寧夏區域),隨後點擊 Notebook instances 按鈕進入筆記本實例列表。
若是是第一次使用 Amazon SageMaker,您的 Notebook instances 列表將顯示爲空列表,此時需點擊 Create notebook instance 按鈕來建立全新 Jupyter Notebook 實例。
進入 Create notebook instance 頁面後,請在 Notebook instance name 字段輸入實例名字,本文將使用「MySageMakerInstance」做爲實例名。此處能夠選用本身認爲合適的名字。本文將使用默認實例類型,所以 Notebook instance type 選項將保持爲 ml.t2.medium。
若是是第一次使用 Amazon SageMaker,還須要建立一個 IAM role,以便筆記本實例可以訪問Amazon S3服務。請在 IAM role 選項點擊爲Create a new role。Amazon SageMaker 將建立一個具備必要權限的角色,並將這個角色分配給正在建立的實例。另外,根據實際狀況,咱們也能夠選擇一個已經存在的角色。
在 Create an IAM role 彈出窗口裏,能夠選擇 Any S3 bucket,這樣筆記本實例將可以訪問您帳戶裏的全部桶。另外,根據須要,還能夠選擇 Specific S3 buckets 並輸入桶名。點擊 Create role 按鈕,這個新角色將被建立。
隨後能夠看到 Amazon SageMaker 建立了一個名字相似*AmazonSageMaker-ExecutionRole-****
的角色。對於其餘字段,可使用默認值,請點擊 Create notebook instance 按鈕建立實例。
回到 Notebook instances 頁面,會看到 MySageMakerInstance 筆記本實例顯示爲 Pending 狀態,這將持續2分鐘左右,直到轉爲 InService 狀態。
點擊 Open JupyterLab 連接,在新頁面裏將看到熟悉的 Jupyter Notebook 加載界面。本文默認以 JupyterLab 筆記本做爲工程環境,根據須要,也能夠選擇使用傳統的 Jupyter 筆記本。
隨後點擊 conda_pytorch_p36 筆記本圖標建立一個叫作 Untitled.ipynb 的筆記本,稍後能夠更改它的名字。另外,也能夠經過 File > New > Notebook 菜單路徑,並選擇 conda_pytorch_p36 做爲 Kernel 來建立這個筆記本。
在新建的 Untitled.ipynb 筆記本里輸入第一行指令,以下:
import torch print(f"Hello PyTorch {torch.__version__}")
請在筆記本中輸入以下指令,下載代碼到實例本地文件系統:
!git clone "https://github.com/mf523/ml-on-aws.git" "ml-on-aws"
下載完成後,能夠經過 File browser 瀏覽源代碼結構。
本文涉及的代碼和筆記本均經過了 Amazon SageMaker 託管的 Python 3.六、PyTorch 1.4和 JupyterLab 驗證。相關代碼和筆記本能夠經過這裏獲取。
DCGAN 模型的生成網絡包含10層,它使用跨步轉置卷積層來提升張量的分辨率,輸入形狀爲(batchsize, 100),輸出形狀爲 (batchsize, 64, 64, 3)。換句話說,生成網絡接受噪聲向量,而後通過不斷變換,直到生成最終的圖像。
判別網絡也包含10層,它接收(64, 64, 3)格式的圖片,使用2D卷積層進行下采樣,最後傳遞給全連接層進行分類,分類結果是1或0,即真與假。
DCGAN 模型的訓練過程大體能夠分爲三個子過程。
首先,Generator網絡以一個隨機數做爲輸入,生成一張「假」圖片;接下來分別用「真」圖片和「假」圖片訓練Discriminator網絡並更新參數;最後,更新Generator網絡參數。
項目目錄 byos-pytorch-gan 的文件結構以下,
├── data │ └── empty ├── dcgan │ ├── entry_point.py │ └── model.py ├── dcgan.ipynb ├── helper.py ├── model │ └── empty └── tmp └── empty
文件 model.py 中包含3個類,分別是生成網絡 Generator 和判別網絡Discriminator:
class Generator(nn.Module): ... class Discriminator(nn.Module): ... class DCGAN(object): """ A wrapper class for Generator and Discriminator, 'train_step' method is for single batch training. """ ...
文件 train.py 用於 Generator 和 Discriminator 兩個神經網絡的訓練,主要包含如下幾個方法:
def parse_args(): ... def get_datasets(dataset_name, ...): ... def train(dataloader, hps, ...):
開發和調試階段,能夠從 Linux 命令行直接運行 train.py 腳本。超參數、輸入數據通道、模型和其餘訓練產出物存放目錄均可以經過命令行參數指定。
python dcgan/train.py --dataset qmnist \ --model-dir '/home/myhome/byom-pytorch-gan/model' \ --output-dir '/home/myhome/byom-pytorch-gan/tmp' \ --data-dir '/home/myhome/byom-pytorch-gan/data' \ --hps '{"beta1":0.5,"dataset":"qmnist","epochs":15,"learning-rate":0.0002,"log-interval":64,"nc":1,"nz":100,"sample-interval":100}'
這樣的訓練腳本參數設計,既提供了很好的調試方法,又是與 SageMaker Container 集成的規約和必要條件,很好地兼顧了模型開發的自由度和訓練環境的可移植性。
請查找並打開名爲 dcgan.ipynb 的筆記本文件,訓練過程將由這個筆記本介紹並執行,本節內容代碼部分從略,請以筆記本代碼爲準。
互聯網環境裏有不少公開的數據集,對於機器學習的工程和科研頗有幫助,好比算法學習和效果評價。咱們將使用 QMNIST 這個手寫字體數據集訓練模型,最終生成逼真的「手寫」字體效果圖樣。
PyTorch 框架的 torchvision.datasets 包提供了 QMNIST 數據集,咱們能夠經過以下指令下載 QMNIST 數據集到本地備用:
from torchvision import datasets dataroot = './data' trainset = datasets.QMNIST(root=dataroot, train=True, download=True) testset = datasets.QMNIST(root=dataroot, train=False, download=True)
Amazon SageMaker爲咱們建立了一個默認的 Amazon S3 桶,用來存取機器學習工做流程中可能須要的各類文件和數據。咱們能夠經過 SageMaker SDK 中 sagemaker.session.Session 類的default_bucket 方法得到這個桶的名字:
from sagemaker.session import Session
sess = Session() # S3 bucket for saving code and model artifacts. # Feel free to specify a different bucket here if you wish. bucket = sess.default_bucket()
SageMaker SDK 提供了操做 Amazon S3 服務的包和類,其中 S3Downloader 類用於訪問或下載 S3 裏的對象,而 S3Uploader 則用於將本地文件上傳至 S3。請將已經下載的數據上傳至Amazon S3供模型訓練使用。模型訓練過程不要從互聯網下載數據,避免經過互聯網獲取訓練數據的產生的網絡延遲,同時也規避了因直接訪問互聯網對模型訓練可能產生的安全風險。
from sagemaker.s3 import S3Uploader as s3up s3_data_location = s3up.upload(f"{dataroot}/QMNIST", f"s3://{bucket}/data/qmnist")
經過 sagemaker.getexecutionrole ()方法,當前筆記本能夠獲得預先分配給筆記本實例的角色,這個角色將被用來獲取訓練用的資源,好比下載訓練用框架鏡像、分配 Amazon EC2 計算資源等等。
訓練模型用的超參數能夠在筆記本里定義,實現與算法代碼的分離,在建立訓練任務時傳入超參數,與訓練任務動態結合。
hps = { "learning-rate": 0.0002, "epochs": 15, "dataset": "qmnist", "beta1": 0.5, "sample-interval": 200, "log-interval": 64 }
sagemaker.pytorch 包裏的 PyTorch 類是基於 PyTorch 框架的模型擬合器,能夠用來建立、執行訓練任務,還能夠對訓練完的模型進行部署。參數列表中,train_instance_type 用來指定 CPU 或者 GPU 實例類型,訓練腳本和包括模型代碼所在的目錄經過 source_dir 指定,訓練腳本文件名必須經過 entry_point 明肯定義。這些參數將和其他參數一塊兒被傳遞給訓練任務,他們決定了訓練任務的運行環境和模型訓練時參數。
from sagemaker.pytorch import PyTorch estimator = PyTorch(role=role, entry_point='train.py', source_dir='dcgan', output_path=s3_model_artifacts_location, code_location=s3_custom_code_upload_location, train_instance_count=1, train_instance_type='ml.c5.xlarge', train_use_spot_instances=True, train_max_wait=86400, framework_version='1.4.0', py_version='py3', hyperparameters=hps)
請特別注意 train_use_spot_instances 參數,True 值表明但願優先使用 SPOT 實例。因爲機器學習訓練工做一般須要大量計算資源長時間運行,善用 SPOT 能夠實現有效的成本控制,SPOT 實例價格多是按需實例價格的20%到60%,依據選擇實例類型、區域、時間不一樣實際價格有所不一樣。
建立 PyTorch 對象後,能夠用它來擬合預先存在 Amazon S3 上的數據了。下面的指令將執行訓練任務,訓練數據將以名爲 QMNIST 的輸入通道的方式導入訓練環境。訓練開始執行過程當中,Amazon S3 上的訓練數據將被下載到模型訓練環境的本地文件系統,訓練腳本 train.py 將從本地磁盤加載數據進行訓練。
# Start training estimator.fit({'QMNIST': s3_data_location}, wait=False)
根據選擇的訓練實例不一樣,訓練過程當中可能持續幾十分鐘到幾個小時不等。建議設置 wait 參數爲 False,這個選項將使筆記本與訓練任務分離,在訓練時間長、訓練日誌多的場景下,能夠避免筆記本上下文由於網絡中斷或者會話超時而丟失。訓練任務脫離筆記本後,輸出將暫時不可見,能夠執行以下代碼,筆記本將獲取並載入此前的訓練會話:
%%time from sagemaker.estimator import Estimator # Attaching previous training session training_job_name = estimator.latest_training_job.name attached_estimator = Estimator.attach(training_job_name)
因爲的模型設計考慮到了 GPU 對訓練加速的能力,因此用 GPU 實例訓練會比 CPU 實例快一些。例如 p3.2xlarge 實例大概須要15分鐘左右,而c5.xlarge 實例則可能須要6小時以上。目前模型不支持分佈、並行訓練,因此多實例、多 CPU/GPU 並不會帶來更多的訓練速度提高。
訓練完成後,模型將被上傳到 Amazon S3,上傳位置由建立 PyTorch 對象時提供的 output_path 參數指定。
爲此,咱們須要從 Amazon S3 下載通過訓練的模型到筆記本所在實例的本地文件系統,下面的代碼將載入模型,而後輸入一個隨機數,得到推理結果,以圖片形式展示出來。
執行以下指令加載訓練好的模型,並經過這個模型產生一組「手寫」的數字字體:
from helper import * import matplotlib.pyplot as plt import numpy as np import torch from dcgan.model import Generator device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") params = {'nz': nz, 'nc': nc, 'ngf': ngf} model = load_model(Generator, params, "./model/generator_state.pth", device=device) img = generate_fake_handwriting(model, batch_size=batch_size, nz=nz, device=device) plt.imshow(np.asarray(img))
近些年成長快速的 PyTorch 框架正在獲得普遍的承認和應用,愈來愈多的新模型採用 PyTorch 框架,也有模型被遷移到 PyTorch 上,或者基於 PyTorch 被完整再實現。生態環境持續豐富,應用領域不斷拓展,PyTorch 已成爲事實上的主流框架之一。
Amazon SageMaker 與多種 AWS 服務緊密集成,例如,各類類型和尺寸的 Amazon EC2 計算實例、Amazon S三、Amazon ECR 等,爲機器學習工程實踐提供了端到端的一致體驗。Amazon SageMaker 持續支持主流機器學習框架,PyTorch 就是其中之一。
用 PyTorch 開發的機器學習算法和模型,能夠輕鬆移植到 Amazon SageMaker 的工程和服務環境裏,進而利用 Amazon SageMaker 全託管的Jupyter Notebook、訓練容器鏡像、服務容器鏡像、訓練任務管理、部署環境託管等功能,簡化機器學習工程複雜度,提升生產效率,下降運維成本。
DCGAN 是生成對抗網絡領域中具里程碑意義的一個,是現今不少複雜生成對抗網絡的基石。文首提到的 StyleGAN,用文本合成圖像的 StackGAN,從草圖生成圖像的 Pix2pix,以及互聯網上爭議不斷的 DeepFakes 等,都有 DCGAN 的影子。相信經過本文的介紹和工程實踐,對你們瞭解生成對抗網絡的原理和工程方法會有所幫助。