题目
在一个由 n 个元素组成的集合中,按从小到大顺序排序的话,第 K 个顺序统计位即指第 K 个数,当 K = n 时即最大值,当 K = 1 时即最小值。先给定一个无序的元素集合,求集合中第 K 统计位的值是多少?
同理,若求 top K 的数据的话,即求集合中最大的前 K 个数分别是多少?
例如,给定数组 [ 0, 9, 3, 6, 8, 2, 1, 5, 7, 4 ] ,则第 4 统计位的数字是 3,top 3 大的数是 [ 9, 8, 7 ]。
最大值和最小值
- 简介
先说一个很特殊的场景,碰巧 K = 1 或者 K = n 的时候,即我们常说的最小值和最大值。解法很简单,即遍历一遍集合就可以找出最大值和最小值了,遍历一次集合找出最大值或最小值需要比较 n – 1 次,时间复杂度为 O(n)。
假如有种场景我们要同时得到最大值和最小值,最直观的解法就是遍历一次,通过比较 2 * (n – 1) 次就可以得到最大值和最小值了。但是其实我们只需要比较 n * 3 / 2 次就可以了。
- 代码
#include <iostream>
#include <algorithm>
using namespace std;
int main() {
int arr[] = { 12, 2, 32, 64, -10, 33, -6, 0, 2, 10 };
int length = sizeof(arr) / sizeof(arr[0]);
int max;
int min;
int *p = arr;
if (length & 1) { // 数组长度为奇数
max = min = arr[0];
p++;
}
else { // 数组长度为偶数
if (arr[0] > arr[1]) {
max = arr[0];
min = arr[1];
}
else {
max = arr[1];
min = arr[0];
}
p += 2;
}
while (p < &arr[length]) {
int first = *(p++);
int second = *(p++);
if (first > second) {
max = std::max(max, first);
min = std::min(min, second);
}
else {
max = std::max(max, second);
min = std::min(min, first);
}
}
cout << "max: " << max << ", min: " << min << endl;
return 0;
}
- 评价
该方法只适用于求最大值和最小值的情况。我们每次取两个数,先比较一次得到他们的大小关系,然后我们用较大者和当前最大值比较,用较小者和当前最小值比较,这样我们就只需要比较 3 次就可以比较完 2 个新的元素了。
排序法
- 简介
回到主题,K 往往并不是 0 或者 n,那么怎么求解呢?很多人最直接的想法就是,我把集合排个序,然后再取第 K 位的数或者 top K 那还不是轻而易举了。
当然,排序也是可以偷懒的,例如我们求第 10 大的元素的话,我们只需要进行 10 次冒泡排序或者选择排序即可,即遍历 10 次集合就可以达到我们的目的。
- 代码
冒泡排序或选择排序的代码都很简单,这里不做示例了。 -
评价
值得说明的是,该方法只适合数据集不大或者 K 值很小的情况下使用,假如我们要求海量数据的中位数时,排序法的效率都远不如其它方法了。
最小堆和最大堆法
- 简介
最小堆和最大堆的性质还不知道的赶紧去百度了。
假如我们要求集合中最大的前 K 个数的话,我们可以创建一个大小为 K 的最小堆,它的根结点一定是这 K 个元素中的最小值,然后我们只需要遍历 1 遍集合即可。分别和堆中的最小值(即根结点)比较,如果大于它,则我们使用这个新的值替代根结点并刷新堆,使得堆的根结点依旧是这 K 个元素中的最小值。最后遍历完集合后,堆里的这 K 个值就是整个集合中的 top K 了,堆的根结点就是第 K 大的值了。
同理,求最小的前 K 个数的话,就使用一个大小为 K 的最大堆。
- 代码
public class Main {
public static void main(String[] args) {
int[] arr = new int[] { 12, 0, 88, -36, 24, 256, 4, -2, 64, 56, 88,
72, 100, 6, 12, 32, 96, 54, 48, 36 };
int[] heap = getTopK(arr, 5);
printArray(heap);
}
public static int[] getTopK(int[] arr, int k) {
if (k >= arr.length) {
return arr;
}
int[] heap = new int[k];
System.arraycopy(arr, 0, heap, 0, heap.length);
buildMinHeap(heap);
for (int i = k; i < arr.length; i++) {
if (arr[i] > heap[0]) {
heap[0] = arr[i];
minHeapify(heap, 0, heap.length);
}
}
// 如果需要 Top K 按照从大到小的顺序排序的话
int heapSize = heap.length;
for (int i = heap.length - 1; i > 0; i--) {
int min = heap[0];
heap[0] = heap[i];
heap[i] = min;
minHeapify(heap, 0, --heapSize);
}
return heap;
}
public static void buildMinHeap(int[] heap) {
// 堆的最后一个分支结点索引为 arr.length / 2 - 1
for (int i = heap.length / 2 - 1; i >= 0; i--) {
minHeapify(heap, i, heap.length);
}
}
/**
* 调整堆,使其满足最小堆的性质
*/
public static void minHeapify(int[] heap, int index, int heapSize) {
int leftIndex = index * 2 + 1; // 左子节点对应数组中的索引
int rightIndex = index * 2 + 2; // 右子节点对应数组中的索引
int minIndex = index;
// 如果左子结点较小,则将最小值索引设为左子节点
if (leftIndex < heapSize && heap[leftIndex] < heap[index]) {
minIndex = leftIndex;
}
// 如果右子结点比 min(this, left)还小,则将最小值索引设为右子节点
if (rightIndex < heapSize && heap[rightIndex] < heap[minIndex]) {
minIndex = rightIndex;
}
// 如果当前结点的值不是最小的,则需要交换最小值,并继续遍历交换后的子结点
if (minIndex != index) {
int temp = heap[minIndex];
heap[minIndex] = heap[index];
heap[index] = temp;
minHeapify(heap, minIndex, heapSize);
}
}
public static void printArray(int[] arr) {
for (int i = 0; i < arr.length; i++) {
System.out.print(arr[i] + " ");
}
System.out.println();
}
}
执行结果:
256 100 96 88 88
- 评价
由于最小堆或者最大堆的操作时间复杂度均为 O(lg n),n 是堆的大小,大小为 K 的堆即 O(lg K)。且只需要遍历一遍集合,则时间复杂度基本为 O(n * lg K)。可以说该方法的执行效率总是会高于排序法了。
且当 K 值较小时,该算法拥有一个较小的常数系数 lg K ,效率还是很高的。
快速选择法
- 简介
快速选择法是一个期望为线性时间的选择算法。它是以快速排序算法为模型修改的。
简单介绍一下原理:首先我们知道快速排序的原理是根据一个 pivot 值,将集合中小于 pivot 的值放置在其左侧,将大于 pivot 的值放置在其右侧,然后再递归地处理左右两侧,最终完成整个集合的排序。快速选择则只处理其中一边,那么根据 pivot 的坐标判断, K 小于 pivot 的话那么我们的数据肯定在左侧,相反则在右侧,等于则直接返回,因为它就是我们要找的数。
如果是求 top K,那我们再以这个第 K 统计位的数为 pivot,划分一次集合即可。
- 代码
#include <stdio.h>
int partition(int *arr, int start, int end) {
if (start >= end) return arr[start];
int pivot = arr[start];
while (end > start) {
while (end > start && arr[end] >= pivot) {
end--;
}
arr[start] = arr[end]; // 将小于 pivot 的数放在低位
while (end > start && arr[start] <= pivot) {
start++;
}
arr[end] = arr[start]; // 将大于 pivot 的数放在高位
}
arr[start] = pivot;
return start; // 返回当前轴点位置
}
int quickSelect(int *arr, int length, int k) {
int start = 0;
int end = length - 1;
while (end >= start) {
int p = partition(arr, start, end);
if (p == k - 1) { // 数组的索引是0开始的,第k大的索引是1开始的
return arr[p];
} else if (p < k - 1) {
start = p + 1;
} else {
end = p - 1;
}
}
return 0;
}
int main() {
int arr[] = { 5, 4, 8, 6, 3, 9, 10, 1, 7, 2 };
printf("第5位的元素为:%d\n", quickSelect(arr, sizeof(arr) / sizeof(arr[0]), 5));
return 0;
}
- 评价
快速选择的效率比快速排序已经高了很多,平均时间复杂度通常为 O(n * lg n) 到 O(n)。然后,最坏情况下,它的时间复杂度仍然为 O(n ^ 2)。
即便如此,快速选择及其变种是实际应用中最常使用的高效选择算法,适用于求海量数据的中位数这种 K 也很大的情况。
BFPRT 算法
- 简介
BFPRT 算法是一个最坏情况下仍为线性时间的选择算法,它是上述快速选择算法的变种,避免了最坏情况的产生。
BFPRT 算法又称为中位数的中位数算法,是由 5 位大牛 (Blum, Floyd, Pratt, Rivest, Tarjan) 提出,并以他们的名字命名的算法。和选择算法不同的是,它首先将集合以 5 个 5 个元素的划开,先求出每 5 个元素的中位数,再求出这些中位数的中位数,最后以这个中位数的中位数为 pivot 划分集合,以此避免了最坏情况的产生。
- 代码
#include <stdio.h>
int BFPRT(int *arr, int start, int end, int k);
void swap(int *a, int *b) {
if (a != b) {
int temp = *a;
*a = *b;
*b = temp;
}
}
int insertSort(int *arr, int start, int end) {
for (int i = start + 1; i <= end; i++) {
for (int j = i; j > start; j--) {
if (arr[j] < arr[j - 1]) {
swap(&arr[j], &arr[j - 1]);
} else {
break;
}
}
}
return ((end - start) >> 1) + start; // 返回中位数的下标
}
int getMedianIndex(int *arr, int start, int end) {
int length = end - start + 1;
if (length <= 5) {
return insertSort(arr, start, end);
}
int subEnd = start; // 中位数的结束位置
for (int i = start; i + 4 <= end; i += 5) {
int index = insertSort(arr, i, i + 4);
swap(&arr[subEnd++], &arr[index]);
}
int module = length % 5; // 不能被 5 整除的余数部分
if (module != 0) {
int index = insertSort(arr, end - module + 1, end);
swap(&arr[subEnd++], &arr[index]);
}
return getMedianIndex(arr, start, subEnd - 1);
}
int partition(int *arr, int start, int end, int pivotIndex) {
if (start >= end) return arr[start];
swap(&arr[start], &arr[pivotIndex]);
int pivot = arr[start];
while (end > start) {
while (end > start && arr[end] >= pivot) {
end--;
}
arr[start] = arr[end]; // 将小于 pivot 的数放在低位
while (end > start && arr[start] <= pivot) {
start++;
}
arr[end] = arr[start]; // 将大于 pivot 的数放在高位
}
arr[start] = pivot;
return start; // 返回当前轴点位置
}
int BFPRT(int *arr, int start, int end, int k) {
if (end - start < 5) {
insertSort(arr, start, end);
return arr[k - 1]; // 数组长度太短的话直接处理
}
int medianIndex = getMedianIndex(arr, start, end); // 中位数的中位数下标
int pivotIndex = partition(arr, start, end, medianIndex); // 划分后当前轴点的下标
if (pivotIndex == k - 1) {
return arr[pivotIndex];
} else if (pivotIndex < k - 1) {
return BFPRT(arr, pivotIndex + 1, end, k);
} else {
return BFPRT(arr, start, pivotIndex - 1, k);
}
return 0;
}
int main() {
int arr[] = { 5, 4, 8, 6, 3, 9, 0, 1, 7, 2 };
for (int i = 0; i < 10; i++) {
printf("第%d位的元素是:%d\n", i + 1, arr[BFPRT(arr, 0, 9, i + 1)]);
}
return 0;
}
- 评价
BFPRT 算法代码比较复杂,但它基本保证了 O(n) 时间下求出结果,只是它的常数系数并不小,所以适合海量数据下,同时 K 也很大的情况,典型的就是求海量数据的中位数了。