标题:记一个高效、简单的nth_element算法 出处:Felix021 时间:Mon, 29 Jul 2013 11:25:58 +0000 作者:felix021 地址:https://www.felix021.com/blog/read.php?2122 内容: nth_element是一个用烂了的面试题,09年的时候我也曾经跑过一点数据(回头一看好像有点不太对,那时候CPU有那么慢吗?应该是当时没有把读取数据的时间给去掉吧),昨天看到 从一道笔试题谈算法优化(上) 和 从一道笔试题谈算法优化(下) 这个系列,里面提到了从简单到复杂的6种算法,根据作者的说法,经过各种优化以后solution6达到了相当的效率。受到作者启发,我也想到了一种算法,于是花了些时间,把常用的几种算法实现了进行对比(包括对比stl里的nth_element)。 回到 nth_element 的问题:从 n 个数里头,找出其中的 top k (top可以是找最大 也可以是找最小)。简单起见,后面都以找最小为例。 那篇文章里的后面的3种算法和我的算法都是基于它的solution3进行优化的,所以先介绍一下solution3: 1. 将 n 个数的前 k 个拷贝出来 2. 找出这 k 个数的最大值 m 3. 对于每一个剩下的 n - k 个数 i :如果 i 小于 m ,将 m 替换为 i ,然后再找出新的 k 个数中的最大值 m 4. 返回过滤后得到的 k 个数 虽然第3步的判断条件成立时,这一步是 O(k) 的复杂度,但是在大部分情况下,这个判断条件都不成立,所以它需要执行的次数很少(这种效果越往后越明显,因为 m 的值越来越小,可以滤掉更多的数)。但是对于极端情况(或者接近极端的情况)——如果数组完全是降序排列的话,那么这个条件每次都会成立,会导致算法退化成 O(n * m) ,性能就不可接受了。 solution4~6的优化我觉得读起来有点晦涩,有兴趣的同学可以自己去看,这里不展开说了。它们都存在算法退化的情况。 我的算法可能理解起来要简单些,基本思路是偷懒:把缓冲区设置为 2 * k ,在扫数组的时候,如果这个数比当前保留的最大值还要小,就把它塞进缓冲区,直到缓冲区塞满了,再排序、取最大值、删多余的数。 1. 开辟一个 2 × k 的临时数组 r 2. 将 N[1..k] 拷贝到 r[1..k],并记录 r 的末尾 e = k + 1 3. 找出 r[1..k] 中的最大值 m 4. 对于每一个剩下的 n - k 个数 i: 4.1 如果 i 小于 m,将 m 塞到 r 的末尾 r[e];e = e + 1 4.1.1 如果 e == 2×k(r被塞满了),对 r 进行排序,得出前k个数中的最大值;抛弃后面的k个数,e = k + 1 5. 对 r[1..e] 进行排序 6. 返回 r[1..k] 这个算法的实际效果很好,在 n = 1亿、k = 1万 的情况下,对于随机的n个数,4.1条件成立的次数通常只有十几次,几乎可以忽略,因此大约只需要扫描整个数组时间的2~3倍即可。 但是这个算法同样存在退化的问题:如果全都是倒序排列的话,也变成O(n*m)了。幸而解决方案也很简单:对这个数组进行 n / 100 次的随机化(考虑到随机化耗时也比较大,100这个数字是试了几次以后发现的合适值,不一定最优,但是都差不多了): void rand_factor() { //randomization const int factor = 100; int r, t, i; for (i = n / factor; i > 0; i--) { r = rand() % n; t = a[0], a[0] = a[r], a[r] = t; } } p.s. 很不幸的是随机化这个方法对于solution3来说虽然提升也很明显,但是远远不够。solution4~6也不太行。 这里给出随机情况下和完全倒序情况下的性能对比吧(ubuntu 12.04 x86_64,i5 2400@3.1G): 引用 随机数据 ======== 随机化生成数据: 7.087s 遍历耗时: 0.053s 1/100随机化耗时: 0.064s STL的nth_element: 0.841s 最大堆+随机化: 0.162s Solution3+随机化: 3.121s solution4+随机化: 2.825s solution6+随机化: 0.249s 我的算法+随机化: 0.163s 倒序数据 ======== 无随机化生成数据: 0.150s 遍历耗时: 0.052s 1/100随机化耗时: 0.067s STL的nth_element: 1.420s 最大堆+随机化: 0.301s Solution3+随机化: 67.552s solution4+随机化: 66.405s solution6+随机化: 2.388s 我的算法+随机化: 0.170s 补充说明一下,nth_element算法在面试的时候可能给出的 n 会更大,以至于在内存中存不下,这时候通常会认为用堆来实现最好——因为只要O(k)的空间就可以搞定,而且最差时间复杂度是O(logk * n),但是要注意logk的常数实际上是很大的,所以你可以看到前面的数据里 堆算法 的效率并不是最好的。事实上这里给出的所有算法(不考虑随机化的话)也都只需要O(k)的空间;如果需要随机化的话也很简单,分段处理,只要保证每段能在内存中保存下来就行了。 最后上代码存档(为了方便测试用了些全局变量,看起来可能有点挫): #include #include #include #include #include #include #include #include #ifdef _WIN32 #include #endif using namespace std; const int N = 100000000, R = 10000; int orig[N], a[N], n, res[R]; int nargs; #define time_test(func) _time_test(func, #func) void _time_test(void (*func)(void), const char *name) { struct timeval begin, end; gettimeofday(&begin, NULL); func(); gettimeofday(&end, NULL); printf("%s: %.3fs\n", name, ((end.tv_sec - begin.tv_sec) * 1000000 + (end.tv_usec - begin.tv_usec)) / 1000000.0); } void gen_int() { int i; for (i = 0; i < N; i++) orig[i] = N - i; printf("generated\n"); if (nargs == 1) { random_shuffle(orig, orig + N); printf("randomization ok\n"); } else printf("skip randomization\n"); n = N; printf("%d numbers\n", n); } void do_copy() { memcpy(a, orig, sizeof(a)); } void rand_factor() { //randomization const int factor = 100; int r, t, i; for (i = n / factor; i > 0; i--) { r = rand() % n; t = a[0], a[0] = a[r], a[r] = t; } } void nth_stl() { nth_element(a, a + R, a + n); int i; for (i = 0; i < R; i++) res[i] = a[i]; } int findmax(int x[], int n) { int i, m = 0; for (i = 1; i < n; i++) m = x[m] > x[i] ? m : i; return m; } void nth_selection() { int i, m; for (i = 0; i < R; i++) res[i] = a[i]; for (i = R; i < n; i++) { m = findmax(res, R); if (a[i] < res[m]) { res[m] = a[i]; } } } void nth_sol3() { int i, m; rand_factor(); for (i = 0; i < R; i++) res[i] = a[i]; m = findmax(res, R); for (i = R; i < n; i++) { if (a[i] < res[m]) { res[m] = a[i]; m = findmax(res, R); } } } void nth_sel_opt() //我的算法 { int i, m, idx = R, cnt = 0; rand_factor(); const int buflen= R * 2; int *r = new int[buflen]; for (i = 0; i < R; i++) r[i] = a[i]; m = r[findmax(r, R)]; for (i = R; i < n; i++) { if (a[i] < m) { r[idx++] = a[i]; if (idx == buflen) { cnt += 1; idx = R; nth_element(r, r + R, r + buflen); //这里用nth_element进一步优化,不过跟用sort差不多 m = r[findmax(r, R)]; } } } nth_element(r, r + R, r + idx); for (i = 0; i < R; i++) res[i] = r[i]; delete[] r; printf("%d reorg\n", cnt); } void nth_sol4() { int i; rand_factor(); for (i = 0; i < R; i++) res[i] = a[i]; sort(res, res + R); int mIdx = R - 1, zoneIdx = mIdx; for (i = R; i < n; i++) { if (a[i] < res[mIdx]) { res[mIdx] = a[i]; if (mIdx == zoneIdx) zoneIdx--; mIdx = zoneIdx; for (int j = zoneIdx + 1; j < R; j++) if (res[mIdx] < res[j]) mIdx = j; } } } void nth_sol6() { int i; rand_factor(); for (i = 0; i < R; i++) res[i] = a[i]; sort(res, res + R); int mIdx = R - 1, zoneIdx = mIdx; for (i = R; i < n; i++) { if (a[i] < res[mIdx]) { res[mIdx] = a[i]; if (mIdx == zoneIdx) zoneIdx--; if (zoneIdx >= 9400) { mIdx = zoneIdx; for (int j = zoneIdx + 1; j < R; j++) if (res[mIdx] < res[j]) mIdx = j; } else { sort(res + zoneIdx, res + R); inplace_merge(res, res + zoneIdx, res + R); mIdx = R - 1, zoneIdx = mIdx; } } } } void sift_down(int *res, int R, int i) { int *x = res - 1, t; while (i * 2 <= R) { i *= 2; if (i + 1 <= R && x[i] < x[i+1]) i++; if (x[i/2] < x[i]) t = x[i/2], x[i/2] = x[i], x[i] = t; else break; } } void nth_heap() { int i, cnt = 0; rand_factor(); for (i = 0; i < R; i++) res[i] = a[i]; for (i = (R + 1) / 2; i >= 1; i--) sift_down(res, R, i); for (i = R; i < n; i++) { if (a[i] < res[0]) { res[0] = a[i]; sift_down(res, R, 1); cnt += 1; } } printf("%d sift_down\n", cnt); } void traverse() { int x = 0; for (int i = 0; i < n; i++) x = orig[i]; printf("x = %d\n", x); } void verify() { int i; nth_element(a, a + R, a + n); sort(a, a + R); sort(res, res + R); for (i = 0; i < R; i++) { if (a[i] != res[i]) { printf("wrong algo(%d): %d vs %d!\n", i, a[i], res[i]); break; } } if (i == R) printf("verify ok\n"); puts(""); } int main(int argc, char *argv[]) { nargs = argc; time_test(gen_int); time_test(traverse); puts(""); //* do_copy(); time_test(nth_stl); verify(); // */ //* do_copy(); time_test(nth_sol3); verify(); // */ do_copy(); time_test(nth_heap); verify(); do_copy(); time_test(nth_sel_opt); verify(); do_copy(); time_test(nth_sol4); verify(); do_copy(); time_test(nth_sol6); verify(); time_test(rand_factor); printf("\ntests over\n"); return 0; } Generated by Bo-blog 2.1.0