給定一個無序的整型數組arr,找到其中最小的k個數程序員
該題是互聯網面試中十分高頻的一道題,若是用普通的排序算法,排序以後天然能夠獲得最小的k個數,但時間複雜度高達O(NlogN),且普通的排序算法均屬於內部排序,須要一次性將所有數據裝入內存,對於求解海量數據的top k問題是無能爲力的。面試
針對海量數據的top k問題,這裏實現了一種時間複雜度爲O(Nlogk)的有效算法:初始時一次性從文件中讀取k個數據,並創建一個有k個數的最大堆,表明目前選出的最小的k個數。而後從文件中一個一個的讀取剩餘數據,若是讀取的數據比堆頂元素小,則把堆頂元素替換成當前的數,而後從堆頂向下從新進行堆調整;不然不進行任何操做,繼續讀取下一個數據。直到文件中的全部數據讀取完畢,堆中的k個數就是海量數據中最小的k個數(若是是找最大的k個數,則使用最小堆)。具體過程請參看以下代碼:算法
public class FindKMinNums { /** * 維護一個有k個數的最大堆,表明目前選出的最小的k個數 * * @param read 實際場景中,read提供的數據須要從文件中讀取,這裏爲了方便用數組表示 * @param k * @return */ public static int[] getKMinsByHeap(int[] read, int k) { if (k < 1 || k > read.length) { return read; } int[] kHeap = new int[k]; for (int i = 0; i < k; i++) { // 初始時一次性從文件中讀取k個數據 kHeap[i] = read[i]; } buildHeap(kHeap, k); // 建堆,時間複雜度O(k) for (int i = k; i < read.length; i++) { // 從文件中一個一個的讀取剩餘數據 if (read[i] < kHeap[0]) { kHeap[0] = read[i]; heapify(kHeap, 0, k); // 從堆頂開始向下進行調整,時間複雜度O(logk) } } return kHeap; } /** * 建堆函數 * * @param arr * @param n */ public static void buildHeap(int[] arr, int n) { for (int i = n / 2 - 1; i >= 0; i--) { heapify(arr, i, n); } } /** * 從arr[i]向下進行堆調整 * * @param arr * @param i * @param heapSize */ public static void heapify(int[] arr, int i, int heapSize) { int leftChild = 2 * i + 1; int rightChild = 2 * i + 2; int max = i; if (leftChild < heapSize && arr[leftChild] > arr[max]) { max = leftChild; } if (rightChild < heapSize && arr[rightChild] > arr[max]) { max = rightChild; } if (max != i) { swap(arr, i, max); heapify(arr, max, heapSize); // 堆結構發生了變化,繼續向下進行堆調整 } } public static void swap(int[] arr, int i, int j) { int tmp = arr[i]; arr[i] = arr[j]; arr[j] = tmp; } public static void printArray(int[] arr) { for (int i = 0; i <= arr.length; i++) { System.out.print(arr[i] + " "); } System.out.println(); } public static void main(String[] args) { int[] arr = {6, 9, 1, 3, 1, 2, 2, 5, 6, 1, 3, 5, 9, 7, 2, 5, 6, 1, 9}; // sorted : { 1, 1, 1, 1, 2, 2, 2, 3, 3, 5, 5, 5, 6, 6, 6, 7, 9, 9, 9 } printArray(getKMinsByHeap(arr, 10)); } }
對於從海量數據(N)中找出TOP K,這種算法僅需一次性將k個數裝入內存,其他數據從文件一個一個讀便可,因此它是針對海量數據TOP K問題最爲有效的算法api
對於非海量數據的狀況,還有一種時間複雜度僅爲O(N)的經典算法 —— BFPRT算法,該算法於1973年由Blum、Floyd、Pratt、Rivest和Tarjan聯合發明,其中蘊含的深入思想改變了世界。數組
BFPRT算法解決了這樣一個問題:在時間複雜度O(N)內,從無序的數組中找到第k小的數。顯而易見的是,若是咱們找到了第k小的數,那麼想求arr中最小的k個數,只需再遍歷一遍數組,把小於第k小的數都蒐集起來,再把不足部分用第k小的數補全便可。函數
BFPRT算法是如何找到第k小的數?如下是BFPRT算法的過程,假設BFPRT算法的函數是int select(int[] arr, int k)
,該函數的功能爲在arr中找到第k小的數,而後返回該數。select(arr, k)
的過程以下:優化
將arr中的n個元素劃分紅 n/5 組,每組5個元素,若是最後的組不夠5個元素,那麼最後剩下的元素爲一組(n%5 個元素)。時間複雜度O(1)ui
對每一個組進行排序,好比選擇簡單的插入排序,只針對每一個組最多5個元素之間的組內排序,組與組之間不排序。時間複雜度 N/5O(1)code
找到每一個組的中位數,若是元素個數爲偶數能夠找下中位數,讓這些中位數組成一個新的數組,記爲mArr。時間複雜度O(N/5)排序
遞歸調用select(mArr, mArr.length / 2)
,意義是找到mArr這個數組的中位數x,即中位數的中位數。時間複雜度T(N/2)
根據上面獲得的x劃分整個arr數組(partition過程),劃分的過程爲:在arr中,比x小的都在x左邊,比x大的都在x右邊,x在中間。時間複雜度O(N)
假設劃分完成後,x在arr中的位置記爲i,關於i與k的相對大小,有以下三種狀況:
上述過程的代碼實現以下:
public class FindKMinNums { /** * 先用BFPRT算法求出第k小的數,再遍歷一遍數組才能求出最小的k個數,時間複雜度O(N) * 須要將全部數據一次性裝入內存,適用於非海量數據的狀況 * * @param arr * @param k * @return */ public static int[] getKMins(int[] arr, int k) { if (k < 1 || k > arr.length) { return arr; } int kthMin = getKthMinByBFPRT(arr, k); // 使用BFPRT算法求得第k小的數,O(N) int[] kMins = new int[k]; // 下面遍歷一遍數組,利用第k小的數找到最小的k個數,O(N) int index = 0; for (int i = 0; i < arr.length; i++) { if (arr[i] < kthMin) { // 小於第k小的數,必然屬於最小的k個數 kMins[index++] = arr[i]; } } while (index < k) { kMins[index++] = kthMin; // 不足部分用第k小的數補全 } return kMins; } /** * 使用BFPRT算法求第k小的數 * * @param arr * @param k * @return */ public static int getKthMinByBFPRT(int[] arr, int k) { int[] arrCopy = copyArray(arr); // 在獲得第k小的數以後還要遍歷一遍原數組,因此並不直接操做原數組 return select(arrCopy, 0, arrCopy.length - 1, k - 1); // 第k小的數,即排好序後下標爲k-1的數 } /** * 拷貝數組 * * @param arr * @return */ public static int[] copyArray(int[] arr) { int[] arrCopy = new int[arr.length]; for (int i = 0; i < arrCopy.length; i++) { arrCopy[i] = arr[i]; } return arrCopy; } /** * 在數組arr的下標範圍[begin, end]內,找到排序後位於整個arr數組下標爲index的數 * * @param arr * @param begin * @param end * @param index * @return */ public static int select(int[] arr, int begin, int end, int index) { if (begin == end) { return arr[begin]; } int pivot = medianOfMedians(arr, begin, end); // 核心操做:中位數的中位數做爲基準 int[] pivotRange = partition(arr, begin, end, pivot); // 拿到分區後中區的範圍 if (index >= pivotRange[0] && index <= pivotRange[1]) { // 命中 return arr[index]; } else if (index < pivotRange[0]) { return select(arr, begin, pivotRange[0] - 1, index); } else { return select(arr, pivotRange[1] + 1, end, index); } } /** * 選基準 * * @param arr * @param begin * @param end * @return */ public static int medianOfMedians(int[] arr, int begin, int end) { int num = end - begin + 1; int offset = num % 5 == 0 ? 0 : 1; // 5個成一組,不滿5個的本身成一組 int[] mArr = new int[num / 5 + offset]; // 每組的中位數取出構成中位數數組mArr for (int i = 0; i < mArr.length; i++) { int beginI = begin + i * 5; int endI = beginI + 4; mArr[i] = getMedian(arr, beginI, Math.min(endI, end)); } // 求中位數數組mArr的中位數,做爲基準返回 return select(mArr, 0, mArr.length - 1, mArr.length / 2); } /** * 在數組arr的下標範圍[begin, end]內,找中位數,若是元素個數爲偶數則找下中位數 * * @param arr * @param begin * @param end * @return */ public static int getMedian(int[] arr, int begin, int end) { insertionSort(arr, begin, end); int sum = begin + end; int mid = (sum / 2) + (sum % 2); return arr[mid]; } /** * 這裏僅用於對一組5個數進行插入排序,時間複雜度O(1) * * @param arr * @param begin * @param end */ public static void insertionSort(int[] arr, int begin, int end) { for (int i = begin + 1; i <= end; i++) { int get = arr[i]; int j = i - 1; while (j >= begin && arr[j] > get) { arr[j + 1] = arr[j]; j--; } arr[j + 1] = get; } } /** * 優化後的快排partition操做 * * @param arr * @param begin * @param end * @param pivot * @return 返回劃分後等於基準的元素下標範圍 */ public static int[] partition(int[] arr, int begin, int end, int pivot) { int small = begin - 1; // 小區最後一個元素下標 int big = end + 1; // 大區第一個元素下標 int cur = begin; while (cur < big) { if (arr[cur] < pivot) { swap(arr, ++small, cur++); } else if (arr[cur] > pivot) { swap(arr, --big, cur); } else { cur++; } } int[] range = new int[2]; range[0] = small + 1; // 中區第一個元素下標 range[1] = big - 1; // 中區最後一個元素下標 return range; } public static void swap(int[] arr, int i, int j) { int tmp = arr[i]; arr[i] = arr[j]; arr[j] = tmp; } public static void printArray(int[] arr) { for (int i = 0; i < arr.length; i++) { System.out.print(arr[i] + " "); } System.out.println(); } public static void main(String[] args) { int[] arr = {6, 9, 1, 3, 1, 2, 2, 5, 6, 1, 3, 5, 9, 7, 2, 5, 6, 1, 9}; // sorted : { 1, 1, 1, 1, 2, 2, 2, 3, 3, 5, 5, 5, 6, 6, 6, 7, 9, 9, 9 } printArray(getKMins(arr, 10)); } }
關於BFPRT算法爲何在時間複雜度上能夠作到穩定的O(N),能夠參考《程序員代碼面試指南》P339或《算法導論》9.3節內容,這裏不作證實。