1. The Idea of GAN
What GAN does at its core is solve an arg-min-max problem, in two parts:

1. Find a Discriminator that measures, as sharply as possible, the distance between the distribution of the data the Generator produces and the distribution of the real data.
2. Find a Generator that minimizes, as far as possible, the distance between the generated data and the real data.
The original GAN objective is as follows:

\[
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
\]
In practice we cannot compute the expectations exactly; we can only sample data and approximate them, so the objective becomes:

\[
V \approx \frac{1}{m} \sum_{i=1}^{m} \log D\big(x^{(i)}\big) + \frac{1}{m} \sum_{i=1}^{m} \log\Big(1 - D\big(G(z^{(i)})\big)\Big)
\]
Maximizing V with respect to the Discriminator is therefore a binary classification problem: real samples carry label 1, generated samples carry label 0, and maximizing V is the same as minimizing the binary cross-entropy.
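To make that equivalence concrete, here is a minimal plain-Java sketch (the class name BceDemo and the discriminator scores dReal and dFake are made up purely for illustration): maximizing the sampled V is exactly minimizing the binary cross-entropy of a classifier that labels real samples 1 and generated samples 0.

public class BceDemo {
    public static void main(String[] args) {
        // Made-up discriminator outputs: D(x) on real samples, D(G(z)) on generated ones.
        double[] dReal = { 0.9, 0.8, 0.7 };
        double[] dFake = { 0.2, 0.1, 0.3 };
        int m = dReal.length;

        // Sampled objective: V = (1/m) * sum(log D(x)) + (1/m) * sum(log(1 - D(G(z))))
        double v = 0;
        for (double d : dReal) v += Math.log(d) / m;
        for (double d : dFake) v += Math.log(1 - d) / m;

        // Binary cross-entropy over all 2m samples, label 1 for real and 0 for fake:
        // BCE = -(1/2m) * [sum(log D(x)) + sum(log(1 - D(G(z))))] = -V / 2
        double bce = 0;
        for (double d : dReal) bce -= Math.log(d);
        for (double d : dFake) bce -= Math.log(1 - d);
        bce /= 2 * m;

        // Maximizing V over D and minimizing the cross-entropy are the same thing.
        System.out.println("V = " + v + ", BCE = " + bce + ", -V/2 = " + (-v / 2));
    }
}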
2. The Code
// DL4J / ND4J imports (package paths follow the 1.0.0-beta releases and may differ slightly in other versions)
import java.io.File;

import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.StackVertex;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class Gan {

    static double lr = 0.01;

    public static void main(String[] args) throws Exception {
        final NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
                .updater(new Sgd(lr))
                .weightInit(WeightInit.XAVIER);
        // Generator: input1 (10-dim noise z) -> g1 -> g2 -> g3 (a 28*28 "image").
        // Discriminator: stacked real + generated batch -> d1 -> d2 -> d3 -> out (sigmoid).
        final GraphBuilder graphBuilder = builder.graphBuilder()
                .backpropType(BackpropType.Standard)
                .addInputs("input1", "input2")
                .addLayer("g1", new DenseLayer.Builder().nIn(10).nOut(128).activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER).build(), "input1")
                .addLayer("g2", new DenseLayer.Builder().nIn(128).nOut(512).activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER).build(), "g1")
                .addLayer("g3", new DenseLayer.Builder().nIn(512).nOut(28 * 28).activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER).build(), "g2")
                // Stack the real samples (input2) on top of the generated samples (g3) along dimension 0.
                .addVertex("stack", new StackVertex(), "input2", "g3")
                .addLayer("d1", new DenseLayer.Builder().nIn(28 * 28).nOut(256).activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER).build(), "stack")
                .addLayer("d2", new DenseLayer.Builder().nIn(256).nOut(128).activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER).build(), "d1")
                .addLayer("d3", new DenseLayer.Builder().nIn(128).nOut(128).activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER).build(), "d2")
                .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(128).nOut(1)
                        .activation(Activation.SIGMOID).build(), "d3")
                .setOutputs("out");

        ComputationGraph net = new ComputationGraph(graphBuilder.build());
        net.init();
        System.out.println(net.summary());

        // Start the DL4J training UI and log the score every 100 iterations.
        UIServer uiServer = UIServer.getInstance();
        StatsStorage statsStorage = new InMemoryStatsStorage();
        uiServer.attach(statsStorage);
        net.setListeners(new ScoreIterationListener(100));
        net.getLayers();

        DataSetIterator train = new MnistDataSetIterator(30, true, 12345);
        // Discriminator labels: 1 for the 30 real samples, 0 for the 30 generated ones (same order as the stack).
        INDArray labelD = Nd4j.vstack(Nd4j.ones(30, 1), Nd4j.zeros(30, 1));
        // Generator labels: everything marked 1 ("real"), so G is pushed to fool D.
        INDArray labelG = Nd4j.ones(60, 1);

        for (int i = 1; i <= 100000; i++) {
            if (!train.hasNext()) {
                train.reset();
            }
            INDArray trueExp = train.next().getFeatures();
            INDArray z = Nd4j.rand(new long[] { 30, 10 }, new NormalDistribution());
            MultiDataSet dataSetD = new org.nd4j.linalg.dataset.MultiDataSet(
                    new INDArray[] { z, trueExp }, new INDArray[] { labelD });
            // Update the Discriminator several times per Generator update.
            for (int m = 0; m < 10; m++) {
                trainD(net, dataSetD);
            }
            z = Nd4j.rand(new long[] { 30, 10 }, new NormalDistribution());
            MultiDataSet dataSetG = new org.nd4j.linalg.dataset.MultiDataSet(
                    new INDArray[] { z, trueExp }, new INDArray[] { labelG });
            trainG(net, dataSetG);
            if (i % 10000 == 0) {
                net.save(new File("E:/gan.zip"), true);
            }
        }
    }

    // Train only the Discriminator: freeze the Generator layers by zeroing their learning rate.
    public static void trainD(ComputationGraph net, MultiDataSet dataSet) {
        net.setLearningRate("g1", 0);
        net.setLearningRate("g2", 0);
        net.setLearningRate("g3", 0);
        net.setLearningRate("d1", lr);
        net.setLearningRate("d2", lr);
        net.setLearningRate("d3", lr);
        net.setLearningRate("out", lr);
        net.fit(dataSet);
    }

    // Train only the Generator: freeze the Discriminator layers by zeroing their learning rate.
    public static void trainG(ComputationGraph net, MultiDataSet dataSet) {
        net.setLearningRate("g1", lr);
        net.setLearningRate("g2", lr);
        net.setLearningRate("g3", lr);
        net.setLearningRate("d1", 0);
        net.setLearningRate("d2", 0);
        net.setLearningRate("d3", 0);
        net.setLearningRate("out", 0);
        net.fit(dataSet);
    }
}
Notes:
1. DL4J does not provide a Keras-style way to freeze the parameters of selected layers, so here the learning rate of those layers is set to 0 to freeze them (see the sketch after this list).
2. The updater is SGD, and it cannot be replaced by an adaptive one (such as Adam or RMSProp): adaptive updaters fold gradients from earlier batches into the current update, so a zero learning rate would not actually stop the parameters from moving.
3. A StackVertex merges tensors along the first dimension, i.e. it stacks the real samples and the Generator's samples into one batch, so that both train the Discriminator together (also shown in the sketch below).
4. During training the Discriminator's parameters are updated several times (10 here), so that it measures the largest distance, and then the Generator is updated once.
5. Training runs for 100,000 iterations.
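A minimal, self-contained sketch of notes 1 to 3 (the class name FreezeAndStackDemo, the tiny two-layer graph, and all array shapes are made up for illustration; the DL4J/ND4J calls are the same ones used in the training code above). It checks that stacking along dimension 0 matches the ones-then-zeros label layout, and that zeroing a layer's learning rate under plain SGD leaves its parameters untouched.

import java.util.Arrays;

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class FreezeAndStackDemo {
    public static void main(String[] args) {
        // Note 3: StackVertex concatenates along dimension 0, like Nd4j.vstack.
        // Real rows go first, generated rows second, so the Discriminator labels
        // must be ones first, zeros second, exactly as labelD is built above.
        INDArray real = Nd4j.rand(30, 784);
        INDArray fake = Nd4j.rand(30, 784);
        INDArray stacked = Nd4j.vstack(real, fake);
        INDArray labelD = Nd4j.vstack(Nd4j.ones(30, 1), Nd4j.zeros(30, 1));
        System.out.println(Arrays.toString(stacked.shape())); // [60, 784]
        System.out.println(Arrays.toString(labelD.shape()));  // [60, 1]

        // Notes 1 and 2: with plain SGD the update is exactly lr * gradient,
        // so a zero learning rate freezes a layer completely.
        ComputationGraph net = new ComputationGraph(new NeuralNetConfiguration.Builder()
                .updater(new Sgd(0.01))
                .graphBuilder()
                .addInputs("in")
                .addLayer("hidden", new DenseLayer.Builder().nIn(4).nOut(3)
                        .activation(Activation.RELU).build(), "in")
                .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.XENT)
                        .nIn(3).nOut(1).activation(Activation.SIGMOID).build(), "hidden")
                .setOutputs("out")
                .build());
        net.init();

        net.setLearningRate("hidden", 0.0); // "freeze" the hidden layer
        INDArray before = net.getLayer("hidden").params().dup();
        net.fit(new DataSet(Nd4j.rand(8, 4), Nd4j.ones(8, 1))); // one training step
        System.out.println("hidden layer unchanged: "
                + before.equals(net.getLayer("hidden").params())); // prints true
    }
}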
3. Generating Handwritten Digits with the Generator
Load the trained model, draw some random noise from a NormalDistribution, feed it to the model, run feedForward, and take the activations of the Generator's last layer; those are the images we want. The code is as follows:
import java.awt.GridLayout;
import java.awt.Image;
import java.awt.image.BufferedImage;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import javax.swing.ImageIcon;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

public class LoadGan {

    public static void main(String[] args) throws Exception {
        // Load the trained GAN and run a forward pass on fresh noise.
        ComputationGraph restored = ComputationGraph.load(new File("E:/gan.zip"), true);
        DataSetIterator train = new MnistDataSetIterator(30, true, 12345);
        INDArray trueExp = train.next().getFeatures();
        Map<String, INDArray> map = restored.feedForward(
                new INDArray[] { Nd4j.rand(new long[] { 50, 10 }, new NormalDistribution()), trueExp }, false);
        // The activations of "g3", the Generator's last layer, are the generated images.
        INDArray indArray = map.get("g3"); // .reshape(20,28,28);
        List<INDArray> list = new ArrayList<>();
        for (int j = 0; j < indArray.size(0); j++) {
            list.add(indArray.getRow(j));
        }
        MNISTVisualizer bestVisualizer = new MNISTVisualizer(1, list, "Gan");
        bestVisualizer.visualize();
    }

    public static class MNISTVisualizer {
        private double imageScale;
        private List<INDArray> digits; // digits (as 784-element row vectors), one per INDArray
        private String title;
        private int gridWidth;

        public MNISTVisualizer(double imageScale, List<INDArray> digits, String title) {
            this(imageScale, digits, title, 5);
        }

        public MNISTVisualizer(double imageScale, List<INDArray> digits, String title, int gridWidth) {
            this.imageScale = imageScale;
            this.digits = digits;
            this.title = title;
            this.gridWidth = gridWidth;
        }

        // Show all digits in a Swing grid, gridWidth images per row.
        public void visualize() {
            JFrame frame = new JFrame();
            frame.setTitle(title);
            frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
            JPanel panel = new JPanel();
            panel.setLayout(new GridLayout(0, gridWidth));
            List<JLabel> list = getComponents();
            for (JLabel image : list) {
                panel.add(image);
            }
            frame.add(panel);
            frame.setVisible(true);
            frame.pack();
        }

        // Turn each 784-element row vector into a scaled 28x28 grayscale image.
        public List<JLabel> getComponents() {
            List<JLabel> images = new ArrayList<>();
            for (INDArray arr : digits) {
                BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
                for (int i = 0; i < 784; i++) {
                    bi.getRaster().setSample(i % 28, i / 28, 0, (int) (255 * arr.getDouble(i)));
                }
                ImageIcon orig = new ImageIcon(bi);
                Image imageScaled = orig.getImage().getScaledInstance((int) (imageScale * 28),
                        (int) (imageScale * 28), Image.SCALE_DEFAULT);
                ImageIcon scaled = new ImageIcon(imageScaled);
                images.add(new JLabel(scaled));
            }
            return images;
        }
    }
}
In practice the generated digits come out fairly clear.
Happiness comes from sharing.

This post is the author's original work; please credit the source when reposting.