Jul
29
记一个高效、简单的nth_element算法
nth_element是一个用烂了的面试题,09年的时候我也曾经跑过一点数据(回头一看好像有点不太对,那时候CPU有那么慢吗?应该是当时没有把读取数据的时间给去掉吧),昨天看到 从一道笔试题谈算法优化(上) 和 从一道笔试题谈算法优化(下) 这个系列,里面提到了从简单到复杂的6种算法,根据作者的说法,经过各种优化以后solution6达到了相当的效率。受到作者启发,我也想到了一种算法,于是花了些时间,把常用的几种算法实现了进行对比(包括对比stl里的nth_element)。
回到 nth_element 的问题:从 n 个数里头,找出其中的 top k (top可以是找最大 也可以是找最小)。简单起见,后面都以找最小为例。
那篇文章里的后面的3种算法和我的算法都是基于它的solution3进行优化的,所以先介绍一下solution3:
1. 将 n 个数的前 k 个拷贝出来
2. 找出这 k 个数的最大值 m
3. 对于每一个剩下的 n - k 个数 i :如果 i 小于 m ,将 m 替换为 i ,然后再找出新的 k 个数中的最大值 m
4. 返回过滤后得到的 k 个数
虽然第3步的判断条件成立时,这一步是 O(k) 的复杂度,但是在大部分情况下,这个判断条件都不成立,所以它需要执行的次数很少(这种效果越往后越明显,因为 m 的值越来越小,可以滤掉更多的数)。但是对于极端情况(或者接近极端的情况)——如果数组完全是降序排列的话,那么这个条件每次都会成立,会导致算法退化成 O(n * m) ,性能就不可接受了。
solution4~6的优化我觉得读起来有点晦涩,有兴趣的同学可以自己去看,这里不展开说了。它们都存在算法退化的情况。
我的算法可能理解起来要简单些,基本思路是偷懒:把缓冲区设置为 2 * k ,在扫数组的时候,如果这个数比当前保留的最大值还要小,就把它塞进缓冲区,直到缓冲区塞满了,再排序、取最大值、删多余的数。
1. 开辟一个 2 × k 的临时数组 r
2. 将 N[1..k] 拷贝到 r[1..k],并记录 r 的末尾 e = k + 1
3. 找出 r[1..k] 中的最大值 m
4. 对于每一个剩下的 n - k 个数 i:
4.1 如果 i 小于 m,将 m 塞到 r 的末尾 r[e];e = e + 1
4.1.1 如果 e == 2×k(r被塞满了),对 r 进行排序,得出前k个数中的最大值;抛弃后面的k个数,e = k + 1
5. 对 r[1..e] 进行排序
6. 返回 r[1..k]
这个算法的实际效果很好,在 n = 1亿、k = 1万 的情况下,对于随机的n个数,4.1条件成立的次数通常只有十几次,几乎可以忽略,因此大约只需要扫描整个数组时间的2~3倍即可。
但是这个算法同样存在退化的问题:如果全都是倒序排列的话,也变成O(n*m)了。幸而解决方案也很简单:对这个数组进行 n / 100 次的随机化(考虑到随机化耗时也比较大,100这个数字是试了几次以后发现的合适值,不一定最优,但是都差不多了):
p.s. 很不幸的是随机化这个方法对于solution3来说虽然提升也很明显,但是远远不够。solution4~6也不太行。
这里给出随机情况下和完全倒序情况下的性能对比吧(ubuntu 12.04 x86_64,i5 2400@3.1G):
补充说明一下,nth_element算法在面试的时候可能给出的 n 会更大,以至于在内存中存不下,这时候通常会认为用堆来实现最好——因为只要O(k)的空间就可以搞定,而且最差时间复杂度是O(logk * n),但是要注意logk的常数实际上是很大的,所以你可以看到前面的数据里 堆算法 的效率并不是最好的。事实上这里给出的所有算法(不考虑随机化的话)也都只需要O(k)的空间;如果需要随机化的话也很简单,分段处理,只要保证每段能在内存中保存下来就行了。
最后上代码存档(为了方便测试用了些全局变量,看起来可能有点挫):
转载请注明出自 ,如是转载文则注明原出处,谢谢:)
RSS订阅地址: https://www.felix021.com/blog/feed.php 。
回到 nth_element 的问题:从 n 个数里头,找出其中的 top k (top可以是找最大 也可以是找最小)。简单起见,后面都以找最小为例。
那篇文章里的后面的3种算法和我的算法都是基于它的solution3进行优化的,所以先介绍一下solution3:
1. 将 n 个数的前 k 个拷贝出来
2. 找出这 k 个数的最大值 m
3. 对于每一个剩下的 n - k 个数 i :如果 i 小于 m ,将 m 替换为 i ,然后再找出新的 k 个数中的最大值 m
4. 返回过滤后得到的 k 个数
虽然第3步的判断条件成立时,这一步是 O(k) 的复杂度,但是在大部分情况下,这个判断条件都不成立,所以它需要执行的次数很少(这种效果越往后越明显,因为 m 的值越来越小,可以滤掉更多的数)。但是对于极端情况(或者接近极端的情况)——如果数组完全是降序排列的话,那么这个条件每次都会成立,会导致算法退化成 O(n * m) ,性能就不可接受了。
solution4~6的优化我觉得读起来有点晦涩,有兴趣的同学可以自己去看,这里不展开说了。它们都存在算法退化的情况。
我的算法可能理解起来要简单些,基本思路是偷懒:把缓冲区设置为 2 * k ,在扫数组的时候,如果这个数比当前保留的最大值还要小,就把它塞进缓冲区,直到缓冲区塞满了,再排序、取最大值、删多余的数。
1. 开辟一个 2 × k 的临时数组 r
2. 将 N[1..k] 拷贝到 r[1..k],并记录 r 的末尾 e = k + 1
3. 找出 r[1..k] 中的最大值 m
4. 对于每一个剩下的 n - k 个数 i:
4.1 如果 i 小于 m,将 m 塞到 r 的末尾 r[e];e = e + 1
4.1.1 如果 e == 2×k(r被塞满了),对 r 进行排序,得出前k个数中的最大值;抛弃后面的k个数,e = k + 1
5. 对 r[1..e] 进行排序
6. 返回 r[1..k]
这个算法的实际效果很好,在 n = 1亿、k = 1万 的情况下,对于随机的n个数,4.1条件成立的次数通常只有十几次,几乎可以忽略,因此大约只需要扫描整个数组时间的2~3倍即可。
但是这个算法同样存在退化的问题:如果全都是倒序排列的话,也变成O(n*m)了。幸而解决方案也很简单:对这个数组进行 n / 100 次的随机化(考虑到随机化耗时也比较大,100这个数字是试了几次以后发现的合适值,不一定最优,但是都差不多了):
void rand_factor()
{
//randomization
const int factor = 100;
int r, t, i;
for (i = n / factor; i > 0; i--)
{
r = rand() % n;
t = a[0], a[0] = a[r], a[r] = t;
}
}
{
//randomization
const int factor = 100;
int r, t, i;
for (i = n / factor; i > 0; i--)
{
r = rand() % n;
t = a[0], a[0] = a[r], a[r] = t;
}
}
p.s. 很不幸的是随机化这个方法对于solution3来说虽然提升也很明显,但是远远不够。solution4~6也不太行。
这里给出随机情况下和完全倒序情况下的性能对比吧(ubuntu 12.04 x86_64,i5 2400@3.1G):
引用
随机数据
========
随机化生成数据: 7.087s
遍历耗时: 0.053s
1/100随机化耗时: 0.064s
STL的nth_element: 0.841s
最大堆+随机化: 0.162s
Solution3+随机化: 3.121s
solution4+随机化: 2.825s
solution6+随机化: 0.249s
我的算法+随机化: 0.163s
倒序数据
========
无随机化生成数据: 0.150s
遍历耗时: 0.052s
1/100随机化耗时: 0.067s
STL的nth_element: 1.420s
最大堆+随机化: 0.301s
Solution3+随机化: 67.552s
solution4+随机化: 66.405s
solution6+随机化: 2.388s
我的算法+随机化: 0.170s
========
随机化生成数据: 7.087s
遍历耗时: 0.053s
1/100随机化耗时: 0.064s
STL的nth_element: 0.841s
最大堆+随机化: 0.162s
Solution3+随机化: 3.121s
solution4+随机化: 2.825s
solution6+随机化: 0.249s
我的算法+随机化: 0.163s
倒序数据
========
无随机化生成数据: 0.150s
遍历耗时: 0.052s
1/100随机化耗时: 0.067s
STL的nth_element: 1.420s
最大堆+随机化: 0.301s
Solution3+随机化: 67.552s
solution4+随机化: 66.405s
solution6+随机化: 2.388s
我的算法+随机化: 0.170s
补充说明一下,nth_element算法在面试的时候可能给出的 n 会更大,以至于在内存中存不下,这时候通常会认为用堆来实现最好——因为只要O(k)的空间就可以搞定,而且最差时间复杂度是O(logk * n),但是要注意logk的常数实际上是很大的,所以你可以看到前面的数据里 堆算法 的效率并不是最好的。事实上这里给出的所有算法(不考虑随机化的话)也都只需要O(k)的空间;如果需要随机化的话也很简单,分段处理,只要保证每段能在内存中保存下来就行了。
最后上代码存档(为了方便测试用了些全局变量,看起来可能有点挫):
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <unistd.h>
#include <sys/time.h>
#ifdef _WIN32
#include <windows.h>
#endif
using namespace std;
const int N = 100000000, R = 10000;
int orig[N], a[N], n, res[R];
int nargs;
#define time_test(func) _time_test(func, #func)
void _time_test(void (*func)(void), const char *name)
{
struct timeval begin, end;
gettimeofday(&begin, NULL);
func();
gettimeofday(&end, NULL);
printf("%s: %.3fs\n", name,
((end.tv_sec - begin.tv_sec) * 1000000
+ (end.tv_usec - begin.tv_usec)) / 1000000.0);
}
void gen_int()
{
int i;
for (i = 0; i < N; i++)
orig[i] = N - i;
printf("generated\n");
if (nargs == 1)
{
random_shuffle(orig, orig + N);
printf("randomization ok\n");
}
else
printf("skip randomization\n");
n = N;
printf("%d numbers\n", n);
}
void do_copy()
{
memcpy(a, orig, sizeof(a));
}
void rand_factor()
{
//randomization
const int factor = 100;
int r, t, i;
for (i = n / factor; i > 0; i--)
{
r = rand() % n;
t = a[0], a[0] = a[r], a[r] = t;
}
}
void nth_stl()
{
nth_element(a, a + R, a + n);
int i;
for (i = 0; i < R; i++)
res[i] = a[i];
}
int findmax(int x[], int n)
{
int i, m = 0;
for (i = 1; i < n; i++)
m = x[m] > x[i] ? m : i;
return m;
}
void nth_selection()
{
int i, m;
for (i = 0; i < R; i++)
res[i] = a[i];
for (i = R; i < n; i++)
{
m = findmax(res, R);
if (a[i] < res[m])
{
res[m] = a[i];
}
}
}
void nth_sol3()
{
int i, m;
rand_factor();
for (i = 0; i < R; i++)
res[i] = a[i];
m = findmax(res, R);
for (i = R; i < n; i++)
{
if (a[i] < res[m])
{
res[m] = a[i];
m = findmax(res, R);
}
}
}
void nth_sel_opt() //我的算法
{
int i, m, idx = R, cnt = 0;
rand_factor();
const int buflen= R * 2;
int *r = new int[buflen];
for (i = 0; i < R; i++)
r[i] = a[i];
m = r[findmax(r, R)];
for (i = R; i < n; i++)
{
if (a[i] < m)
{
r[idx++] = a[i];
if (idx == buflen)
{
cnt += 1;
idx = R;
nth_element(r, r + R, r + buflen); //这里用nth_element进一步优化,不过跟用sort差不多
m = r[findmax(r, R)];
}
}
}
nth_element(r, r + R, r + idx);
for (i = 0; i < R; i++)
res[i] = r[i];
delete[] r;
printf("%d reorg\n", cnt);
}
void nth_sol4()
{
int i;
rand_factor();
for (i = 0; i < R; i++)
res[i] = a[i];
sort(res, res + R);
int mIdx = R - 1, zoneIdx = mIdx;
for (i = R; i < n; i++)
{
if (a[i] < res[mIdx])
{
res[mIdx] = a[i];
if (mIdx == zoneIdx)
zoneIdx--;
mIdx = zoneIdx;
for (int j = zoneIdx + 1; j < R; j++)
if (res[mIdx] < res[j])
mIdx = j;
}
}
}
void nth_sol6()
{
int i;
rand_factor();
for (i = 0; i < R; i++)
res[i] = a[i];
sort(res, res + R);
int mIdx = R - 1, zoneIdx = mIdx;
for (i = R; i < n; i++)
{
if (a[i] < res[mIdx])
{
res[mIdx] = a[i];
if (mIdx == zoneIdx)
zoneIdx--;
if (zoneIdx >= 9400)
{
mIdx = zoneIdx;
for (int j = zoneIdx + 1; j < R; j++)
if (res[mIdx] < res[j])
mIdx = j;
}
else
{
sort(res + zoneIdx, res + R);
inplace_merge(res, res + zoneIdx, res + R);
mIdx = R - 1, zoneIdx = mIdx;
}
}
}
}
void sift_down(int *res, int R, int i)
{
int *x = res - 1, t;
while (i * 2 <= R)
{
i *= 2;
if (i + 1 <= R && x[i] < x[i+1])
i++;
if (x[i/2] < x[i])
t = x[i/2], x[i/2] = x[i], x[i] = t;
else
break;
}
}
void nth_heap()
{
int i, cnt = 0;
rand_factor();
for (i = 0; i < R; i++)
res[i] = a[i];
for (i = (R + 1) / 2; i >= 1; i--)
sift_down(res, R, i);
for (i = R; i < n; i++)
{
if (a[i] < res[0])
{
res[0] = a[i];
sift_down(res, R, 1);
cnt += 1;
}
}
printf("%d sift_down\n", cnt);
}
void traverse()
{
int x = 0;
for (int i = 0; i < n; i++)
x = orig[i];
printf("x = %d\n", x);
}
void verify()
{
int i;
nth_element(a, a + R, a + n);
sort(a, a + R);
sort(res, res + R);
for (i = 0; i < R; i++)
{
if (a[i] != res[i])
{
printf("wrong algo(%d): %d vs %d!\n", i, a[i], res[i]);
break;
}
}
if (i == R)
printf("verify ok\n");
puts("");
}
int main(int argc, char *argv[])
{
nargs = argc;
time_test(gen_int);
time_test(traverse);
puts("");
//*
do_copy();
time_test(nth_stl);
verify();
// */
//*
do_copy();
time_test(nth_sol3);
verify();
// */
do_copy();
time_test(nth_heap);
verify();
do_copy();
time_test(nth_sel_opt);
verify();
do_copy();
time_test(nth_sol4);
verify();
do_copy();
time_test(nth_sol6);
verify();
time_test(rand_factor);
printf("\ntests over\n");
return 0;
}
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <unistd.h>
#include <sys/time.h>
#ifdef _WIN32
#include <windows.h>
#endif
using namespace std;
const int N = 100000000, R = 10000;
int orig[N], a[N], n, res[R];
int nargs;
#define time_test(func) _time_test(func, #func)
void _time_test(void (*func)(void), const char *name)
{
struct timeval begin, end;
gettimeofday(&begin, NULL);
func();
gettimeofday(&end, NULL);
printf("%s: %.3fs\n", name,
((end.tv_sec - begin.tv_sec) * 1000000
+ (end.tv_usec - begin.tv_usec)) / 1000000.0);
}
void gen_int()
{
int i;
for (i = 0; i < N; i++)
orig[i] = N - i;
printf("generated\n");
if (nargs == 1)
{
random_shuffle(orig, orig + N);
printf("randomization ok\n");
}
else
printf("skip randomization\n");
n = N;
printf("%d numbers\n", n);
}
void do_copy()
{
memcpy(a, orig, sizeof(a));
}
void rand_factor()
{
//randomization
const int factor = 100;
int r, t, i;
for (i = n / factor; i > 0; i--)
{
r = rand() % n;
t = a[0], a[0] = a[r], a[r] = t;
}
}
void nth_stl()
{
nth_element(a, a + R, a + n);
int i;
for (i = 0; i < R; i++)
res[i] = a[i];
}
int findmax(int x[], int n)
{
int i, m = 0;
for (i = 1; i < n; i++)
m = x[m] > x[i] ? m : i;
return m;
}
void nth_selection()
{
int i, m;
for (i = 0; i < R; i++)
res[i] = a[i];
for (i = R; i < n; i++)
{
m = findmax(res, R);
if (a[i] < res[m])
{
res[m] = a[i];
}
}
}
void nth_sol3()
{
int i, m;
rand_factor();
for (i = 0; i < R; i++)
res[i] = a[i];
m = findmax(res, R);
for (i = R; i < n; i++)
{
if (a[i] < res[m])
{
res[m] = a[i];
m = findmax(res, R);
}
}
}
void nth_sel_opt() //我的算法
{
int i, m, idx = R, cnt = 0;
rand_factor();
const int buflen= R * 2;
int *r = new int[buflen];
for (i = 0; i < R; i++)
r[i] = a[i];
m = r[findmax(r, R)];
for (i = R; i < n; i++)
{
if (a[i] < m)
{
r[idx++] = a[i];
if (idx == buflen)
{
cnt += 1;
idx = R;
nth_element(r, r + R, r + buflen); //这里用nth_element进一步优化,不过跟用sort差不多
m = r[findmax(r, R)];
}
}
}
nth_element(r, r + R, r + idx);
for (i = 0; i < R; i++)
res[i] = r[i];
delete[] r;
printf("%d reorg\n", cnt);
}
void nth_sol4()
{
int i;
rand_factor();
for (i = 0; i < R; i++)
res[i] = a[i];
sort(res, res + R);
int mIdx = R - 1, zoneIdx = mIdx;
for (i = R; i < n; i++)
{
if (a[i] < res[mIdx])
{
res[mIdx] = a[i];
if (mIdx == zoneIdx)
zoneIdx--;
mIdx = zoneIdx;
for (int j = zoneIdx + 1; j < R; j++)
if (res[mIdx] < res[j])
mIdx = j;
}
}
}
void nth_sol6()
{
int i;
rand_factor();
for (i = 0; i < R; i++)
res[i] = a[i];
sort(res, res + R);
int mIdx = R - 1, zoneIdx = mIdx;
for (i = R; i < n; i++)
{
if (a[i] < res[mIdx])
{
res[mIdx] = a[i];
if (mIdx == zoneIdx)
zoneIdx--;
if (zoneIdx >= 9400)
{
mIdx = zoneIdx;
for (int j = zoneIdx + 1; j < R; j++)
if (res[mIdx] < res[j])
mIdx = j;
}
else
{
sort(res + zoneIdx, res + R);
inplace_merge(res, res + zoneIdx, res + R);
mIdx = R - 1, zoneIdx = mIdx;
}
}
}
}
void sift_down(int *res, int R, int i)
{
int *x = res - 1, t;
while (i * 2 <= R)
{
i *= 2;
if (i + 1 <= R && x[i] < x[i+1])
i++;
if (x[i/2] < x[i])
t = x[i/2], x[i/2] = x[i], x[i] = t;
else
break;
}
}
void nth_heap()
{
int i, cnt = 0;
rand_factor();
for (i = 0; i < R; i++)
res[i] = a[i];
for (i = (R + 1) / 2; i >= 1; i--)
sift_down(res, R, i);
for (i = R; i < n; i++)
{
if (a[i] < res[0])
{
res[0] = a[i];
sift_down(res, R, 1);
cnt += 1;
}
}
printf("%d sift_down\n", cnt);
}
void traverse()
{
int x = 0;
for (int i = 0; i < n; i++)
x = orig[i];
printf("x = %d\n", x);
}
void verify()
{
int i;
nth_element(a, a + R, a + n);
sort(a, a + R);
sort(res, res + R);
for (i = 0; i < R; i++)
{
if (a[i] != res[i])
{
printf("wrong algo(%d): %d vs %d!\n", i, a[i], res[i]);
break;
}
}
if (i == R)
printf("verify ok\n");
puts("");
}
int main(int argc, char *argv[])
{
nargs = argc;
time_test(gen_int);
time_test(traverse);
puts("");
//*
do_copy();
time_test(nth_stl);
verify();
// */
//*
do_copy();
time_test(nth_sol3);
verify();
// */
do_copy();
time_test(nth_heap);
verify();
do_copy();
time_test(nth_sel_opt);
verify();
do_copy();
time_test(nth_sol4);
verify();
do_copy();
time_test(nth_sol6);
verify();
time_test(rand_factor);
printf("\ntests over\n");
return 0;
}
欢迎扫码关注:
转载请注明出自 ,如是转载文则注明原出处,谢谢:)
RSS订阅地址: https://www.felix021.com/blog/feed.php 。