How to Implement a GAN with Deeplearning4j

1. The Idea Behind GANs

    The core of what a GAN does is solve an arg min max problem, which breaks down into two parts:

    1. Find a Discriminator that measures, as sharply as possible, the distance between the distribution of the data produced by the Generator and the distribution of the real data.

    2. Find a Generator that minimizes the distance between the generated data and the real data.

    The original GAN objective is the following:
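    This is the standard minimax objective from the original GAN paper: D is trained to assign high probability to real samples and low probability to generated ones, while G is trained to fool D.

    $$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]$$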

    In practice we cannot compute these expectations exactly; we can only sample data and approximate them, so the formula becomes:
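    Replacing each expectation with an average over a minibatch of m samples gives the usual empirical estimate:

    $$\tilde{V} = \frac{1}{m}\sum_{i=1}^{m}\log D\big(x^{(i)}\big) + \frac{1}{m}\sum_{i=1}^{m}\log\Big(1 - D\big(G(z^{(i)})\big)\Big)$$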

    Maximizing V is therefore just a binary classification problem: it is equivalent to minimizing a cross-entropy loss, with real samples labelled 1 and generated samples labelled 0.
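    Concretely, with label y = 1 for real samples and y = 0 for generated ones, the binary cross-entropy minimized by the sigmoid output layer (LossFunction.XENT in the code below) over the 2m stacked samples is

    $$L_D = -\frac{1}{2m}\Big[\sum_{i=1}^{m}\log D\big(x^{(i)}\big) + \sum_{i=1}^{m}\log\Big(1 - D\big(G(z^{(i)})\big)\Big)\Big] = -\frac{1}{2}\,\tilde{V},$$

    so training the discriminator to maximize \(\tilde{V}\) is exactly minimizing this loss.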

2. Code

public class Gan {
	static double lr = 0.01; // learning rate shared by all trainable layers

	public static void main(String[] args) throws Exception {

		// shared configuration: plain SGD updater (see note 2 below) and Xavier weight initialization
		final NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new Sgd(lr))
				.weightInit(WeightInit.XAVIER);

		// input1: a 10-dimensional noise vector z; input2: a real, flattened 28x28 MNIST image
		final GraphBuilder graphBuilder = builder.graphBuilder().backpropType(BackpropType.Standard)
				.addInputs("input1", "input2")
				// generator g1 -> g2 -> g3: maps the noise vector to a 784-dimensional fake image
				.addLayer("g1",
						new DenseLayer.Builder().nIn(10).nOut(128).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"input1")
				.addLayer("g2",
						new DenseLayer.Builder().nIn(128).nOut(512).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"g1")
				.addLayer("g3",
						new DenseLayer.Builder().nIn(512).nOut(28 * 28).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"g2")
				// stack real images (input2) on top of generated images (g3) along dimension 0,
				// so the discriminator sees one minibatch containing both real and fake samples
				.addVertex("stack", new StackVertex(), "input2", "g3")
				// discriminator d1 -> d2 -> d3 -> out: a binary real/fake classifier with a sigmoid
				// output trained with binary cross-entropy (XENT)
				.addLayer("d1",
						new DenseLayer.Builder().nIn(28 * 28).nOut(256).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"stack")
				.addLayer("d2",
						new DenseLayer.Builder().nIn(256).nOut(128).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"d1")
				.addLayer("d3",
						new DenseLayer.Builder().nIn(128).nOut(128).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"d2")
				.addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(128).nOut(1)
						.activation(Activation.SIGMOID).build(), "d3")
				.setOutputs("out");

		ComputationGraph net = new ComputationGraph(graphBuilder.build());
		net.init();
		System.out.println(net.summary());
		// optional: attach the DL4J training UI and print the score every 100 iterations
		UIServer uiServer = UIServer.getInstance();
		StatsStorage statsStorage = new InMemoryStatsStorage();
		uiServer.attach(statsStorage);
		net.setListeners(new ScoreIterationListener(100));
		DataSetIterator train = new MnistDataSetIterator(30, true, 12345); // minibatches of 30 real MNIST images

		// discriminator labels: 1 for the 30 real images, 0 for the 30 generated ones,
		// in the same order in which StackVertex concatenates "input2" and "g3"
		INDArray labelD = Nd4j.vstack(Nd4j.ones(30, 1), Nd4j.zeros(30, 1));

		// generator labels: all 1, i.e. the generator is rewarded when the discriminator says "real"
		INDArray labelG = Nd4j.ones(60, 1);

		for (int i = 1; i <= 100000; i++) {
			if (!train.hasNext()) {
				train.reset();
			}
			INDArray trueExp = train.next().getFeatures(); // 30 real images, shape [30, 784]
			INDArray z = Nd4j.rand(new long[] { 30, 10 }, new NormalDistribution()); // 30 noise vectors
			MultiDataSet dataSetD = new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[] { z, trueExp },
					new INDArray[] { labelD });
			// update the discriminator several times per generator update (see note 4 below)
			for (int m = 0; m < 10; m++) {
				trainD(net, dataSetD);
			}
			z = Nd4j.rand(new long[] { 30, 10 }, new NormalDistribution()); // fresh noise for the generator step
			MultiDataSet dataSetG = new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[] { z, trueExp },
					new INDArray[] { labelG });
			trainG(net, dataSetG);

			if (i % 10000 == 0) {
				net.save(new File("E:/gan.zip"), true); // checkpoint the whole graph every 10,000 iterations
			}

		}

	}

	// train only the discriminator: "freeze" the generator by setting its layers' learning rates to 0 (see note 1 below)
	public static void trainD(ComputationGraph net, MultiDataSet dataSet) {
		net.setLearningRate("g1", 0);
		net.setLearningRate("g2", 0);
		net.setLearningRate("g3", 0);
		net.setLearningRate("d1", lr);
		net.setLearningRate("d2", lr);
		net.setLearningRate("d3", lr);
		net.setLearningRate("out", lr);
		net.fit(dataSet);
	}

	// train only the generator: freeze the discriminator layers and the output layer in the same way
	public static void trainG(ComputationGraph net, MultiDataSet dataSet) {
		net.setLearningRate("g1", lr);
		net.setLearningRate("g2", lr);
		net.setLearningRate("g3", lr);
		net.setLearningRate("d1", 0);
		net.setLearningRate("d2", 0);
		net.setLearningRate("d3", 0);
		net.setLearningRate("out", 0);
		net.fit(dataSet);
	}
}

    Notes:

    1. DL4J does not provide a Keras-style way to freeze the parameters of selected layers, so here those layers are frozen by setting their learning rate to 0 (a quick check of this trick is sketched after this list).

    2. The updater must be plain SGD; adaptive updaters such as Adam or RMSProp cannot be used here, because they fold gradients from earlier batches into the current update through their internal state, so setting the learning rate to 0 would not achieve the goal of leaving those parameters untouched.

    3. StackVertex is used to concatenate tensors along the first dimension (dimension 0), i.e. to merge the real samples and the samples produced by the Generator into one batch so that the Discriminator is trained on both together.

    4. During training the Discriminator's parameters are updated several times (10 per iteration here) so that it measures as large a distance as possible, and only then is the Generator updated once.

    5. The whole loop runs for 100,000 iterations.
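    As a quick sanity check on the learning-rate trick from note 1, the snippet below is a minimal sketch (hypothetical; it assumes the `net` and `dataSetD` variables from the training loop above) that compares a generator layer's parameters before and after one discriminator-only step. With plain SGD and a learning rate of 0 on the generator layers, the two should be identical:

// Hypothetical freeze check: with SGD and lr = 0 on g1/g2/g3, trainD() must leave them unchanged.
INDArray g1Before = net.getLayer("g1").params().dup(); // copy of g1's parameters before the update
trainD(net, dataSetD);                                 // one discriminator-only fit() call
INDArray g1After = net.getLayer("g1").params();
System.out.println("g1 frozen during trainD: " + g1Before.equals(g1After)); // expected: true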

3. Generating Handwritten Digits with the Generator

    Load the trained model, draw some random noise from a NormalDistribution, feed it to the model with feedForward, and take the activations of the last generator layer (g3); those activations are the generated digits. The code is as follows:

public class LoadGan {

	public static void main(String[] args) throws Exception {
		ComputationGraph restored = ComputationGraph.load(new File("E:/gan.zip"), true); // load the trained GAN
		
		// real images are needed only because the graph declares two inputs;
		// the generator output "g3" depends solely on the noise input
		DataSetIterator train = new MnistDataSetIterator(30, true, 12345);
		INDArray trueExp = train.next().getFeatures();
		Map<String, INDArray> map = restored.feedForward(
				new INDArray[] { Nd4j.rand(new long[] { 50, 10 }, new NormalDistribution()), trueExp }, false);
		INDArray indArray = map.get("g3"); // one generated 784-pixel image per row
		List<INDArray> list = new ArrayList<>();
		for (int j = 0; j < indArray.size(0); j++) {
			list.add(indArray.getRow(j)); // one generated digit (row vector of 784 pixels) per list entry
		}
	    
		MNISTVisualizer bestVisualizer = new MNISTVisualizer(1, list, "Gan");

		bestVisualizer.visualize();
	}
	
	
	public static class MNISTVisualizer {
		private double imageScale;
		private List<INDArray> digits; // Digits (as row vectors), one per
										// INDArray
		private String title;
		private int gridWidth;

		public MNISTVisualizer(double imageScale, List<INDArray> digits, String title) {
			this(imageScale, digits, title, 5);
		}

		public MNISTVisualizer(double imageScale, List<INDArray> digits, String title, int gridWidth) {
			this.imageScale = imageScale;
			this.digits = digits;
			this.title = title;
			this.gridWidth = gridWidth;
		}

		public void visualize() {
			JFrame frame = new JFrame();
			frame.setTitle(title);
			frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);

			JPanel panel = new JPanel();
			panel.setLayout(new GridLayout(0, gridWidth));

			List<JLabel> list = getComponents();
			for (JLabel image : list) {
				panel.add(image);
			}

			frame.add(panel);
			frame.setVisible(true);
			frame.pack();
		}

		public List<JLabel> getComponents() {
			List<JLabel> images = new ArrayList<>();
			for (INDArray arr : digits) {
				BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
				for (int i = 0; i < 784; i++) {
					// map each activation (roughly in [0, 1]) to an 8-bit grayscale pixel
					bi.getRaster().setSample(i % 28, i / 28, 0, (int) (255 * arr.getDouble(i)));
				}
				ImageIcon orig = new ImageIcon(bi);
				Image imageScaled = orig.getImage().getScaledInstance((int) (imageScale * 28), (int) (imageScale * 28),
						Image.SCALE_DEFAULT);
				ImageIcon scaled = new ImageIcon(imageScaled);
				images.add(new JLabel(scaled));
			}
			return images;
		}
	}
}

    In practice the generated digits come out reasonably clear.

 

 

Happiness comes from sharing.

This post is the author's original work; please credit the source when reposting.
