Text Classification of Toutiao Headlines with a CNN in DL4J

1. Dataset Introduction

    Data source: the Toutiao (今日頭條) mobile client

    The data format is as follows:

6551700932705387022_!_101_!_news_culture_!_京城最值得你來場文化之旅的博物館_!_保利集團,馬未都,中國科學技術館,博物館,新中國
6552368441838272771_!_101_!_news_culture_!_發酵牀的墊料種類有哪些?哪一種更好?_!_
6552407965343678723_!_101_!_news_culture_!_上聯:黃山黃河黃皮膚黃土高原。怎麼對下聯?_!_
6552332417753940238_!_101_!_news_culture_!_林徽因什麼理由拒絕了徐志摩而選擇梁思成爲終身伴侶?_!_
6552475601595269390_!_101_!_news_culture_!_黃楊木是什麼樹?_!_

    Each line is one record, with five fields separated by _!_. From left to right they are: news ID, category code (see below), category name (see below), news text (title only), and news keywords.
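    Note that the keyword field can be empty (as in several of the sample lines above), and Java's String.split drops trailing empty strings, so the last field needs a guard. A minimal parsing sketch, using one of the sample lines:

String line = "6552475601595269390_!_101_!_news_culture_!_黃楊木是什麼樹?_!_";
String[] fields = line.split("_!_");                   // "_!_" contains no regex metacharacters
String newsId = fields[0];                             // 6552475601595269390
String categoryCode = fields[1];                       // 101
String categoryName = fields[2];                       // news_culture
String title = fields[3];                              // the news title
String keywords = fields.length > 4 ? fields[4] : "";  // trailing empty field is dropped by split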

    Category codes and names:

100 民生 故事 news_story
101 文化 文化 news_culture
102 娛樂 娛樂 news_entertainment
103 體育 體育 news_sports
104 財經 財經 news_finance
106 房產 房產 news_house
107 汽車 汽車 news_car
108 教育 教育 news_edu 
109 科技 科技 news_tech
110 軍事 軍事 news_military
112 旅遊 旅遊 news_travel
113 國際 國際 news_world
114 證券 股票 stock
115 農業 三農 news_agriculture
116 電競 遊戲 news_game

    GitHub repository: https://github.com/fate233/toutiao-text-classfication-dataset

    The dataset repository reports the following baseline classification results:

Test Loss:   0.57, Test Acc:  83.81%

                    precision    recall  f1-score   support

        news_story       0.66      0.75      0.70       848
      news_culture       0.57      0.83      0.68      1531
news_entertainment       0.86      0.86      0.86      8078
       news_sports       0.94      0.91      0.92      7338
      news_finance       0.59      0.67      0.63      1594
        news_house       0.84      0.89      0.87      1478
          news_car       0.92      0.90      0.91      6481
          news_edu       0.71      0.86      0.77      1425
         news_tech       0.85      0.84      0.85      6944
     news_military       0.90      0.78      0.84      6174
       news_travel       0.58      0.76      0.66      1287
        news_world       0.72      0.69      0.70      3823
             stock       0.00      0.00      0.00        53
  news_agriculture       0.80      0.88      0.84      1701
         news_game       0.92      0.87      0.89      6244

       avg / total       0.85      0.84      0.84     54999

   Below, we will use Deeplearning4j to build a convolutional architecture for classifying this dataset and see whether we can do better.

2. Why Convolutional Networks Work for Text

    CNNs are a natural fit for image data; the previous post, "Deeplearning4j: CNN-based CAPTCHA recognition" (《deeplearning4j——卷積神經網絡對驗證碼進行識別》), applied a CNN to CAPTCHA recognition. This post applies a CNN to text classification. Before we start, let's look intuitively at what a convolution really does. A convolution is essentially a dot product between two vectors: the more aligned the two vectors are, the larger the dot product. After ReLU and max pooling, the net effect is to extract the structures most aligned with the kernel; for images, those "structures" are line patterns.
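    To make that intuition concrete, here is a toy 1D sketch (plain Java, purely illustrative) of dot product + ReLU + max pooling:

double[] signal = {0.1, 0.9, 1.0, 0.2, -0.5, 0.0};
double[] kernel = {1.0, 1.0}; // "detects" two consecutive high values
double best = Double.NEGATIVE_INFINITY;
for (int i = 0; i + kernel.length <= signal.length; i++) {
    double dot = 0;
    for (int k = 0; k < kernel.length; k++) {
        dot += signal[i + k] * kernel[k]; // convolution at position i = a dot product
    }
    best = Math.max(best, Math.max(0, dot)); // ReLU, then max pooling over all positions
}
System.out.println(best); // 1.9 -- the window (0.9, 1.0) is the one most aligned with the kernel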

    So can text be processed with a CNN? Yes. Once every word is represented as a vector and the vectors are stacked in order, the text becomes a two-dimensional image, as the figure below illustrates. Looking along the red arrow (the direction of the text), two sentences rendered this way can share identical local patterns, which is exactly what a CNN can pick up.
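    As a minimal sketch of how a sentence becomes such a matrix (it assumes the toutiao.vec word2vec model trained in section 5, and the tokenization shown is only illustrative):

Word2Vec vec = WordVectorSerializer.readWord2VecModel("/toutiao_cat_data/toutiao.vec");
String[] words = "京城 最 值得 你 來 場 文化 之 旅 的 博物館".split(" ");
INDArray[] rows = new INDArray[words.length];
for (int i = 0; i < words.length; i++) {
    rows[i] = vec.getWordVectorMatrix(words[i]); // 1 x 100 row; null if the word is out of vocabulary
}
INDArray sentenceMatrix = Nd4j.vstack(rows);     // words.length x 100 -- the 2D "image" the CNN convolves over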

    [Figure: two sentences laid out as word-vector matrices; the red arrow marks the text direction along which shared patterns appear]

3. A Convolutional Architecture for Text

    So how do we design the CNN architecture? See the figure below (paper: https://arxiv.org/abs/1408.5882).

    [Figure: the sentence-classification CNN architecture from the paper above (Kim, 2014)]

   Key points:

   1. The convolution kernel must slide along the direction of the sentence.

   2. Each kernel produces a feature map that is an N×1 column vector.

   3. Max pooling operates on each feature map individually, i.e., it selects the single largest value from each N×1 vector.

   4. All of the selected maxima are concatenated and passed through a few fully connected layers for classification (see the shape sketch after this list).
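   To get a feel for the shapes involved, here is a small sketch of the arithmetic (it mirrors the Same-mode convolutions configured later; the numbers are illustrative):

int h = 20;                    // sentence length in words
int stride = 1;
int featureMapHeight = (h + stride - 1) / stride;   // Same mode: ceil(h / stride) = 20
int mapsPerKernelSize = 50;    // cnnLayerFeatureMaps in the code below
int kernelSizes = 4;           // kernel heights 3, 4, 5, 6
int pooledLength = mapsPerKernelSize * kernelSizes; // 200 values after max pooling + concatenation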

4. Data Preprocessing and Word Vectors

    1. Tokenizer: HanLP

    2. The processed data format is: category code_!_words, where words are separated by spaces and _!_ is the delimiter.
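    For example, the first sample title above would come out roughly like this (the exact HanLP tokenization may differ):

101_!_京城 最 值得 你 來 場 文化 之 旅 的 博物館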


    The preprocessing code is as follows:

// Requires java.io.*, com.hankcs.hanlp.HanLP and com.hankcs.hanlp.seg.common.Term
public static void main(String[] args) throws Exception {
	// Read the raw dataset and emit one "categoryCode_!_space-separated words" line per record
	BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(
			new FileInputStream(new File("/toutiao_cat_data/toutiao_cat_data.txt")), "UTF-8"));
	OutputStreamWriter writerStream = new OutputStreamWriter(
			new FileOutputStream("/toutiao_cat_data/toutiao_data_type_word.txt"), "UTF-8");
	BufferedWriter writer = new BufferedWriter(writerStream);
	String line = null;
	long startTime = System.currentTimeMillis();
	while ((line = bufferedReader.readLine()) != null) {
		String[] array = line.split("_!_");
		StringBuilder stringBuilder = new StringBuilder();
		for (Term term : HanLP.segment(array[3])) { // segment the title field with HanLP
			if (stringBuilder.length() > 0) {
				stringBuilder.append(" ");
			}
			stringBuilder.append(term.word.trim());
		}
		writer.write(Integer.parseInt(array[1].trim()) + "_!_" + stringBuilder.toString() + "\n");
	}
	writer.flush();
	writer.close();
	System.out.println(System.currentTimeMillis() - startTime); // elapsed time in ms
	bufferedReader.close();
}

5. Vector Representations of Words

    1. One-hot

    One-hot represents each word with a vector orthogonal to every other word's, so it cannot capture any relationship between words. For two sentences to reuse the same convolution kernel, they would have to contain exactly the same words. In practice we want the model to generalize, extracting similar (not just identical) structures; word2vec solves this.

    2. word2vec

    word2vec fully captures relationships between words: similar words end up close along certain dimensions, so relationships within a sentence are also taken into account. There are two ways to train word2vec, skip-gram and CBOW. Below we use CBOW to train the word vectors; the result is persisted to a file, toutiao.vec, which can later be reloaded to look up word vectors. The code is as follows:

// Assumes an SLF4J logger named "log" and the DL4J word2vec imports
String filePath = new ClassPathResource("toutiao_data_word.txt").getFile().getAbsolutePath();
SentenceIterator iter = new BasicLineIterator(filePath); // one segmented sentence per line
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
VocabCache<VocabWord> cache = new AbstractCache<>();
WeightLookupTable<VocabWord> table = new InMemoryLookupTable.Builder<VocabWord>()
		.vectorLength(100).useAdaGrad(false).cache(cache).build();

log.info("Building model....");
Word2Vec vec = new Word2Vec.Builder()
		.elementsLearningAlgorithm("org.deeplearning4j.models.embeddings.learning.impl.elements.CBOW")
		.minWordFrequency(0).iterations(1).epochs(20).layerSize(100).seed(42).windowSize(8)
		.iterate(iter).tokenizerFactory(t).lookupTable(table).vocabCache(cache).build();

vec.fit();
// Persist the trained vectors; they are reloaded by the classifier below
WordVectorSerializer.writeWord2VecModel(vec, "/toutiao_cat_data/toutiao.vec");
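
Once persisted, the model can be reloaded and sanity-checked; a minimal sketch (the query word is just an example):

Word2Vec reloaded = WordVectorSerializer.readWord2VecModel("/toutiao_cat_data/toutiao.vec");
Collection<String> similar = reloaded.wordsNearest("文化", 10); // 10 nearest words in the embedding space
System.out.println(similar);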

6. The CNN Architecture

    The CNN is structured as follows:

    Notes:

    1. cnn3, cnn4, cnn5 and cnn6 use kernel sizes (3, vectorSize), (4, vectorSize), (5, vectorSize) and (6, vectorSize) with stride 1; that is, they read 3, 4, 5 or 6 words at a time to extract features.

    2. cnn3-stride2, cnn4-stride2, cnn5-stride2 and cnn6-stride2 use the same kernel sizes (3, vectorSize), (4, vectorSize), (5, vectorSize) and (6, vectorSize), but with stride 2.

    3. The outputs of the two groups of kernels are merged into merge1 and merge2, both 4-D tensors, with shapes (batchSize, depth1+depth2+depth3+depth4, height/1, 1) and (batchSize, depth1+depth2+depth3+depth4, height/2, 1) respectively (four branches of 50 feature maps each, so depth 200). Note that the convolution mode here is ConvolutionMode.Same.

    4. merge1 and merge2 each go through max pooling, implemented here with a GlobalPoolingLayer. Unlike an ordinary pooling layer, it takes a single maximum over the specified dimensions, so after the GlobalPoolingLayer, merge1 and merge2 become 2-D tensors of shape (batchSize, depth1+depth2+depth3+depth4). How does GlobalPoolingLayer compute the max? The source is below (a tiny ND4J sketch after this list demonstrates the reduction):

private INDArray activateHelperFullArray(INDArray inputArray, int[] poolDim) {
        switch (poolingType) {
            case MAX:
                return inputArray.max(poolDim);
            case AVG:
                return inputArray.mean(poolDim);
            case SUM:
                return inputArray.sum(poolDim);
            case PNORM:
                //P norm: https://arxiv.org/pdf/1311.1780.pdf
                //out = (1/N * sum( |in| ^ p) ) ^ (1/p)
                int pnorm = layerConf().getPnorm();

                INDArray abs = Transforms.abs(inputArray, true);
                Transforms.pow(abs, pnorm, false);
                INDArray pNorm = abs.sum(poolDim);

                return Transforms.pow(pNorm, 1.0 / pnorm, false);
            default:
                throw new RuntimeException("Unknown or not supported pooling type: " + poolingType + " " + layerId());
        }
    }

    5. The outputs of the two GlobalPoolingLayers are concatenated and fed to a fully connected network, with a softmax classifier at the end.

    6. The fc layer uses dropout of 0.5 to prevent overfitting, as you can see in the code below.
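    As a quick standalone illustration of the inputArray.max(poolDim) reduction used above (not part of the network code):

INDArray in = Nd4j.create(new double[][] {{1, 5, 3}, {4, 2, 6}}); // shape [2, 3]
System.out.println(in.max(1)); // reduce over dimension 1: one max per row -> [5.0, 6.0]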

The complete code is as follows:

public class CnnSentenceClassificationTouTiao {

	public static void main(String[] args) throws Exception {

		List<String> trainLabelList = new ArrayList<>(); // training-set labels
		List<String> trainSentences = new ArrayList<>(); // training-set sentences
		List<String> testLabelList = new ArrayList<>();  // test-set labels
		List<String> testSentences = new ArrayList<>();  // test-set sentences
		Map<String, List<String>> map = new HashMap<>();

		BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(
				new FileInputStream(new File("/toutiao_cat_data/toutiao_data_type_word.txt")), "UTF-8"));
		String line = null;
		int truncateReviewsToLength = 0;
		Random random = new Random(123);
		while ((line = bufferedReader.readLine()) != null) {
			String[] array = line.split("_!_");
			if (map.get(array[0]) == null) {
				map.put(array[0], new ArrayList<String>());
			}
			map.get(array[0]).add(array[1]); // group all samples by category
			int length = array[1].split(" ").length;
			if (length > truncateReviewsToLength) {
				truncateReviewsToLength = length; // track the longest sentence in the corpus
			}
		}
		bufferedReader.close();
		for (Map.Entry<String, List<String>> entry : map.entrySet()) {
			for (String sentence : entry.getValue()) {
				if (random.nextInt() % 5 == 0) { // sample roughly 20% of each category as the test set
					testLabelList.add(entry.getKey());
					testSentences.add(sentence);
				} else {
					trainLabelList.add(entry.getKey());
					trainSentences.add(sentence);
				}
			}

		}
		int batchSize = 64;
		int vectorSize = 100;
		int nEpochs = 10;

		int cnnLayerFeatureMaps = 50;
		PoolingType globalPoolingType = PoolingType.MAX;
		Random rng = new Random(12345);
		Nd4j.getMemoryManager().setAutoGcWindow(5000);

		ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder().weightInit(WeightInit.RELU)
				.activation(Activation.LEAKYRELU).updater(new Nesterovs(0.01, 0.9))
				.convolutionMode(ConvolutionMode.Same).l2(0.0001).graphBuilder().addInputs("input")
				.addLayer("cnn3",
						new ConvolutionLayer.Builder().kernelSize(3, vectorSize).stride(1, vectorSize)
								.nOut(cnnLayerFeatureMaps).build(),
						"input")
				.addLayer("cnn4",
						new ConvolutionLayer.Builder().kernelSize(4, vectorSize).stride(1, vectorSize)
								.nOut(cnnLayerFeatureMaps).build(),
						"input")
				.addLayer("cnn5",
						new ConvolutionLayer.Builder().kernelSize(5, vectorSize).stride(1, vectorSize)
								.nOut(cnnLayerFeatureMaps).build(),
						"input")
				.addLayer("cnn6",
						new ConvolutionLayer.Builder().kernelSize(6, vectorSize).stride(1, vectorSize)
								.nOut(cnnLayerFeatureMaps).build(),
						"input")
				.addLayer("cnn3-stride2",
						new ConvolutionLayer.Builder().kernelSize(3, vectorSize).stride(2, vectorSize)
								.nOut(cnnLayerFeatureMaps).build(),
						"input")
				.addLayer("cnn4-stride2",
						new ConvolutionLayer.Builder().kernelSize(4, vectorSize).stride(2, vectorSize)
								.nOut(cnnLayerFeatureMaps).build(),
						"input")
				.addLayer("cnn5-stride2",
						new ConvolutionLayer.Builder().kernelSize(5, vectorSize).stride(2, vectorSize)
								.nOut(cnnLayerFeatureMaps).build(),
						"input")
				.addLayer("cnn6-stride2",
						new ConvolutionLayer.Builder().kernelSize(6, vectorSize).stride(2, vectorSize)
								.nOut(cnnLayerFeatureMaps).build(),
						"input")
				.addVertex("merge1", new MergeVertex(), "cnn3", "cnn4", "cnn5", "cnn6")
				.addLayer("globalPool1", new GlobalPoolingLayer.Builder().poolingType(globalPoolingType).build(),
						"merge1")
				.addVertex("merge2", new MergeVertex(), "cnn3-stride2", "cnn4-stride2", "cnn5-stride2", "cnn6-stride2")
				.addLayer("globalPool2", new GlobalPoolingLayer.Builder().poolingType(globalPoolingType).build(),
						"merge2")
				.addLayer("fc",
						new DenseLayer.Builder().nOut(200).dropOut(0.5).activation(Activation.LEAKYRELU).build(),
						"globalPool1", "globalPool2")
				.addLayer("out",
						new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
								.activation(Activation.SOFTMAX).nOut(15).build(),
						"fc")
				.setOutputs("out").setInputTypes(InputType.convolutional(truncateReviewsToLength, vectorSize, 1))
				.build();

		ComputationGraph net = new ComputationGraph(config);
		net.init();
		System.out.println(net.summary());
		Word2Vec word2Vec = WordVectorSerializer.readWord2VecModel("/toutiao_cat_data/toutiao.vec");
		System.out.println("Loading word vectors and creating DataSetIterators");
		DataSetIterator trainIter = getDataSetIterator(word2Vec, batchSize, truncateReviewsToLength, trainLabelList,
				trainSentences, rng);
		DataSetIterator testIter = getDataSetIterator(word2Vec, batchSize, truncateReviewsToLength, testLabelList,
				testSentences, rng);

		UIServer uiServer = UIServer.getInstance();
		StatsStorage statsStorage = new InMemoryStatsStorage();
		uiServer.attach(statsStorage);
		net.setListeners(new ScoreIterationListener(100), new StatsListener(statsStorage, 20),
				new EvaluativeListener(testIter, 1, InvocationType.EPOCH_END));

		// net.setListeners(new ScoreIterationListener(100),
		// new EvaluativeListener(testIter, 1, InvocationType.EPOCH_END));
		net.fit(trainIter, nEpochs);
	}

	private static DataSetIterator getDataSetIterator(WordVectors wordVectors, int minibatchSize, int maxSentenceLength,
			List<String> labelList, List<String> sentences, Random rng) {

		LabeledSentenceProvider sentenceProvider = new CollectionLabeledSentenceProvider(sentences, labelList, rng);

		return new CnnSentenceDataSetIterator.Builder().sentenceProvider(sentenceProvider).wordVectors(wordVectors)
				.minibatchSize(minibatchSize).maxSentenceLength(maxSentenceLength).useNormalizedWordVectors(false)
				.build();
	}
}

 Code notes:

    1. The code has two parts. The first is data preprocessing, which splits the data into a 20% test set and an 80% training set.

    2. The second part defines the basic network structure.
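    After training, the graph can be persisted for later inference; a minimal sketch using DL4J's ModelSerializer (the file path is illustrative):

// save the trained graph, including the updater state
ModelSerializer.writeModel(net, new File("/toutiao_cat_data/toutiao-cnn.zip"), true);
// reload it later for inference
ComputationGraph restored = ModelSerializer.restoreComputationGraph(new File("/toutiao_cat_data/toutiao-cnn.zip"));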

The network parameters in detail:

===============================================================================================================================================
VertexName (VertexType)            nIn,nOut   TotalParams   ParamsShape                Vertex Inputs                                           
===============================================================================================================================================
input (InputVertex)                -,-        -             -                          -                                                       
cnn3 (ConvolutionLayer)            1,50       15050         W:{50,1,3,100}, b:{1,50}   [input]                                                 
cnn4 (ConvolutionLayer)            1,50       20050         W:{50,1,4,100}, b:{1,50}   [input]                                                 
cnn5 (ConvolutionLayer)            1,50       25050         W:{50,1,5,100}, b:{1,50}   [input]                                                 
cnn6 (ConvolutionLayer)            1,50       30050         W:{50,1,6,100}, b:{1,50}   [input]                                                 
cnn3-stride2 (ConvolutionLayer)    1,50       15050         W:{50,1,3,100}, b:{1,50}   [input]                                                 
cnn4-stride2 (ConvolutionLayer)    1,50       20050         W:{50,1,4,100}, b:{1,50}   [input]                                                 
cnn5-stride2 (ConvolutionLayer)    1,50       25050         W:{50,1,5,100}, b:{1,50}   [input]                                                 
cnn6-stride2 (ConvolutionLayer)    1,50       30050         W:{50,1,6,100}, b:{1,50}   [input]                                                 
merge1 (MergeVertex)               -,-        -             -                          [cnn3, cnn4, cnn5, cnn6]                                
merge2 (MergeVertex)               -,-        -             -                          [cnn3-stride2, cnn4-stride2, cnn5-stride2, cnn6-stride2]
globalPool1 (GlobalPoolingLayer)   -,-        0             -                          [merge1]                                                
globalPool2 (GlobalPoolingLayer)   -,-        0             -                          [merge2]                                                
fc-merge (MergeVertex)             -,-        -             -                          [globalPool1, globalPool2]                              
fc (DenseLayer)                    400,200    80200         W:{400,200}, b:{1,200}     [fc-merge]                                              
out (OutputLayer)                  200,15     3015          W:{200,15}, b:{1,15}       [fc]                                                    
-----------------------------------------------------------------------------------------------------------------------------------------------
            Total Parameters:  263615
        Trainable Parameters:  263615
           Frozen Parameters:  0
===============================================================================================================================================

 The DL4J UIServer looks as follows. Here I set the port to 9001; opening the web UI shows details such as the average loss and gradient updates:

http://localhost:9001/train/overview
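
The UI defaults to port 9000; one way to pin it to 9001 (assuming DL4J's org.deeplearning4j.ui.port system property, which must be set before UIServer.getInstance() is first called):

System.setProperty("org.deeplearning4j.ui.port", "9001");
UIServer uiServer = UIServer.getInstance();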

7. Masking

    Sentences vary in length; how does the CNN deal with that?

    The approach is actually quite brute-force: find the longest sentence in a minibatch, allocate a tensor of that maximum length, and mask out the surplus entries. Without further ado, here is the relevant DL4J code (from the sentence dataset iterator):

if(sentencesAlongHeight){
                    featuresMask = Nd4j.create(currMinibatchSize, 1, maxLength, 1);
                    for (int i = 0; i < currMinibatchSize; i++) {
                        int sentenceLength = tokenizedSentences.get(i).getFirst().size();
                        if (sentenceLength >= maxLength) {
                            featuresMask.slice(i).assign(1.0);
                        } else {
                            featuresMask.get(NDArrayIndex.point(i), NDArrayIndex.point(0), NDArrayIndex.interval(0, sentenceLength), NDArrayIndex.point(0)).assign(1.0);
                        }
                    }
                } else {
                    featuresMask = Nd4j.create(currMinibatchSize, 1, 1, maxLength);
                    for (int i = 0; i < currMinibatchSize; i++) {
                        int sentenceLength = tokenizedSentences.get(i).getFirst().size();
                        if (sentenceLength >= maxLength) {
                            featuresMask.slice(i).assign(1.0);
                        } else {
                            featuresMask.get(NDArrayIndex.point(i), NDArrayIndex.point(0), NDArrayIndex.point(0), NDArrayIndex.interval(0, sentenceLength)).assign(1.0);
                        }
                    }
                }

    Why the if? When building the sentence tensor, the sentence can be laid out along either the height or the width dimension of the matrix; depending on the direction, the dimension that gets filled with the mask differs.

8. Results

    After running 10 epochs, the results are as follows:

========================Evaluation Metrics========================
 # of classes:    15
 Accuracy:        0.8420
 Precision:       0.8362	(1 class excluded from average)
 Recall:          0.7783
 F1 Score:        0.8346	(1 class excluded from average)
Precision, recall & F1: macro-averaged (equally weighted avg. of 15 classes)

Warning: 1 class was never predicted by the model and was excluded from average precision
Classes excluded from average precision: [12]

=========================Confusion Matrix=========================
    0    1    2    3    4    5    6    7    8    9   10   11   12   13   14
----------------------------------------------------------------------------
  973   35  114    2    9    8   11   19   14    6   19   11    0   22   13 | 0 = 0
   17 4636  250   37   51   16   14  151   47   29  232   36    0   82   44 | 1 = 1
  103  176 6980  108   16    8   31   62   83   41   53   77    0   36  163 | 2 = 2
    9   78  244 6692   37    9   52   59   33   27   57   54    0   10   96 | 3 = 3
    7   52   36   31 4072   96  101  107  581   20   64  108    0  135   37 | 4 = 4
   12   18   22    8  150 3061   27   36   53    2  100   16    0   56    2 | 5 = 5
   17   38   71   26   94   13 6443   43  174   31  121   39    0   32   34 | 6 = 6
   17  157   93   49   62   20   34 4793   85   14   58   36    0   49   31 | 7 = 7
    1   45   71   21  436   30  195  138 7018   48   54   49    0   45  148 | 8 = 8
   24   74   84   47   24    1   57   50   68 3963   45  431    0    9   65 | 9 = 9
    9  165   90   21   40   37   61   40   42   21 3428  111    0   78   30 | 10 = 10
   47   78  173   52  114   20   48   67   93  320  140 4097    0   48   29 | 11 = 11
    0    0    0    0   60    0    1    0    5    0    0    0    0    0    0 | 12 = 12
   35  105   31    6  139   37   34   61   79   11  153   35    0 3187   12 | 13 = 13
   14   36  210  128   31    2   19   20  164   44   38   15    0   19 5183 | 14 = 14

    The average accuracy of 0.8420 is slightly better than the baseline reported with the dataset, while the F1 score is slightly worse. In the confusion matrix, one class is never predicted: that class has very few samples, so the model cannot learn its common features. With careful hyperparameter tuning and more training epochs, the results should in principle improve further.

9. Afterword

    Reading the Deeplearning4j source is a pleasure: elegant architecture, clear logic, a variety of design patterns, and strong extensibility. Follow-up posts will dissect the DL4J source code.


Happiness comes from sharing.

This post is the author's original work; please credit the source when reposting.
