FPGA上如何求32個輸入的最大值和次大值:分治

上午在論壇看到個熱帖,裏頭的題目挺有意思的,簡單的記錄了一下。面試

0. 題目 

在FPGA上實現一個模塊,求32個輸入中的最大值和次大值,32個輸入由一個時鐘週期給出。(題目來自論壇,面試題,若是以爲不合適請留言刪除)算法

    從我我的的觀點來看,這是一道很好的面試題目:數組

  • 其一是這大概是某些機器學習算法實現過程當中遇到的問題的簡化,是頗有意義的一道題目;
  • 其二是這道題目不只要求FPGA代碼能力,還有不少能夠在算法上優化的可能;

    固然,輸入的位寬可能會影響最終的解題思路和最終的實現可能性。但位寬在必定範圍內,譬如8或者32,解題的方案應該都是一致的,只是會影響最終的頻率。後文針對這一題目作具體分析。(題目沒有說明重複元素如何處理,這裏認爲最大值和次大值能夠是同樣的,即計算重複元素)機器學習

1. 解法

    從算法自己來看,找最大值和次大值的過程很簡單;經過兩次遍歷:第一次求最大值,第二次求次大值; 算法複雜度是O(2n)。FPGA顯然不可能在一個週期內完成如此複雜的操做,通常須要流水設計。這一方法下,整個結構是這樣的ide

  1. 經過比較,求最大值,經過流水線實現兩兩之間的比較,32-16-8-4-2-1經過5個clk的延遲能夠求得最大值;
  2. 因爲須要求取次大值,所以須要肯定最大值的位置,在求最大值的過程當中須要維持最大值的座標;
  3. 最大值座標處取值清零(置爲最小)
  4. 經過流水線實現兩兩之間的比較,32-16-8-4-2-1,再通過5個clk的延遲能夠求得次大值;

    這種解法有若干個缺點,包括:延遲求最大值和次大值分別須要5clk延時,總延遲會超過10個cycles;資源佔用較高,維持最大值座標和清零操做耗費了較多資源,同時爲了計算次大值,須要將輸入寄存若干個週期,寄存器消耗較多。oop

 

    另外一個種思路考慮同時求最大值和次大值,因爲這一邏輯較爲複雜,能夠將其流水化,以下圖。(以8輸入爲例,32輸入須要增長兩級)學習

image

    其中sort模塊完成對4輸入進行排序,獲得最大值和次大值輸出的功能。4個數的排序較爲複雜,這一過程大概須要2-3個cycles完成。對於32輸入而言,輸入數據通過32-16-8-4-2輸出獲得結果,延遲大概也有10個週期。測試

2. 分治

    若是須要在FPGA上實現一個特定的算法,那麼去找一個合適的方法去實現就行了;但若是是要實現一個特定的功能,那麼須要找一個優秀的且適合FPGA實現的方法優化

    求最大值和次大值是一個很不徹底的排序,經過簡單的查找複雜度爲O(2n),且不利於硬件實現。對於排序而言,不管快速排序或者歸併排序都用了分治的思想,若是咱們試圖用分治的思想來解決這一問題。考慮當只有2個輸入時,經過一個比較就能夠獲得輸出,此時獲得的是一個長度爲2的有序數組。若是兩個有序數組,那麼經過兩次比較就能夠獲得最大值和次大值。採用歸併排序的思想,查找最大值和次大值的複雜度爲O(1.5n)(即爲n/2+n/2+n/4… ,不知道有沒有算錯)。採用歸併排序的思想,從算法時間複雜度上看更爲高效了。spa

    那麼這一方案是否適合FPGA實現呢,答案是確定的。分治的局部性適合FPGA的流水實現,框圖以下。(以8輸入爲例,32輸入須要增長兩級)

image

    其中meg模塊內部有兩級的比較器,通常而言1clk就能夠完成,輸入數據通過32-32-16-8-4-2獲得結果,延遲爲5個時鐘週期。實現代碼以下

module test#(
parameter DW = 8
)
(
input clk,
input [32*DW-1 :0] din,
output [DW-1:0] max1,
output [DW-1:0] max2
);

wire[DW-1:0] d[31:0];
generate
    genvar i;
    for(i=0;i<32;i=i+1)
    begin:loop_assign
        assign d[i] = din[DW*i+DW-1:DW*i];
    end
endgenerate

// stage 1,comp
reg[DW-1:0] s1_max[15:0];
reg[DW-1:0] s1_min[15:0];
generate
    for(i=0;i<16;i=i+1)
    begin:loop_comp
        always@(posedge clk)
            if(d[2*i]>d[2*i+1])begin
                s1_max[i] <= d[2*i];
                s1_min[i] <= d[2*i+1];
            end
            else begin
                s1_max[i] <= d[2*i+1];
                s1_min[i] <= d[2*i];        
            end
    end
endgenerate

// stage 2,
wire[DW-1:0] s2_max[7:0];
wire[DW-1:0] s2_min[7:0];
generate
    for(i=0;i<8;i=i+1)
    begin:loop_megs2
        meg u_s2meg(
            .clk(clk),
            .g1_max(s1_max[2*i]),
            .g1_min(s1_min[2*i]),
            .g2_max(s1_max[2*i+1]),
            .g2_min(s1_min[2*i+1]),            
            .max1(s2_max[i]),
            .max2(s2_min[i])
        );
    end
endgenerate
// stage 3,
wire[DW-1:0] s3_max[3:0];
wire[DW-1:0] s3_min[3:0];
generate
    for(i=0;i<4;i=i+1)
    begin:loop_megs3
        meg u_s3meg(
            .clk(clk),
            .g1_max(s2_max[2*i]),
            .g1_min(s2_min[2*i]),
            .g2_max(s2_max[2*i+1]),
            .g2_min(s2_min[2*i+1]),            
            .max1(s3_max[i]),
            .max2(s3_min[i])
        );
    end
endgenerate

// stage 4,
wire[DW-1:0] s4_max[1:0];
wire[DW-1:0] s4_min[1:0];
generate
    for(i=0;i<2;i=i+1)
    begin:loop_megs4
        meg u_s4meg(
            .clk(clk),
            .g1_max(s3_max[2*i]),
            .g1_min(s3_min[2*i]),
            .g2_max(s3_max[2*i+1]),
            .g2_min(s3_min[2*i+1]),            
            .max1(s4_max[i]),
            .max2(s4_min[i])
        );
    end
endgenerate

// stage 5,
meg u_s5meg(
    .clk(clk),
    .g1_max(s4_max[0]),
    .g1_min(s4_min[0]),
    .g2_max(s4_max[1]),
    .g2_min(s4_min[1]),            
    .max1(max1),
    .max2(max2)
);
endmodule

module meg#(
parameter DW = 8
)
(
input clk,
input [DW-1 :0] g1_max,
input [DW-1 :0] g1_min,
input [DW-1 :0] g2_max,
input [DW-1 :0] g2_min,
output reg [DW-1:0] max1,
output reg [DW-1:0] max2
);
always@(posedge clk)
begin
    if(g1_max>g2_max) begin
        max1 <= g1_max;
        if(g2_max>g1_min)
            max2 <= g2_max;
        else
            max2 <= g1_min;
    end
    else begin
        max1 <= g2_max;
        if(g1_max>g2_min)
            max2 <= g1_max;
        else
            max2 <= g2_min;
    end
end
endmodule
View Code

3. 其餘

    簡單測試了上面的代碼,在上一代器件上(20nm FPGA),8bit數據輸入模塊能綜合到很高的頻率,邏輯級數大概是5級左右,對於整個工程而言瓶頸基本不會出如今這一部分。32bit數據輸入因爲數據位寬太大,頻率不會過高,可是經過將meg模塊作一級流水,也幾乎不會成爲整個系統的瓶頸。

    32bit32輸入狀況下,數據輸入位寬爲1024(不是IO輸入,是內部信號)。以前在通訊/數字信號處理方面可能不會用到這麼大位寬的數據,但對於AI領域FPGA的應用,數千比特的輸入應該是很日常的,這的確會影響最終FPGA上實現的效果。要想讓機器學習算法在FPGA上跑得更好,還須要算法和FPGA共同努力纔是。

相關文章
相關標籤/搜索