Jul 29

记一个高效、简单的nth_element算法 不指定

felix021 @ 2013-7-29 11:25 [IT » 程序设计] 评论(0) , 引用(0) , 阅读(16396) | Via 本站原创 | |
nth_element是一个用烂了的面试题,09年的时候我也曾经跑过一点数据(回头一看好像有点不太对,那时候CPU有那么慢吗?应该是当时没有把读取数据的时间给去掉吧),昨天看到 从一道笔试题谈算法优化(上)从一道笔试题谈算法优化(下) 这个系列,里面提到了从简单到复杂的6种算法,根据作者的说法,经过各种优化以后solution6达到了相当的效率。受到作者启发,我也想到了一种算法,于是花了些时间,把常用的几种算法实现了进行对比(包括对比stl里的nth_element)。

回到 nth_element 的问题:从 n 个数里头,找出其中的 top k (top可以是找最大 也可以是找最小)。简单起见,后面都以找最小为例。

那篇文章里的后面的3种算法和我的算法都是基于它的solution3进行优化的,所以先介绍一下solution3:

1. 将 n 个数的前 k 个拷贝出来
2. 找出这 k 个数的最大值 m
3. 对于每一个剩下的 n - k 个数 i :如果 i 小于 m ,将 m 替换为 i ,然后再找出新的 k 个数中的最大值 m
4. 返回过滤后得到的 k 个数

虽然第3步的判断条件成立时,这一步是 O(k) 的复杂度,但是在大部分情况下,这个判断条件都不成立,所以它需要执行的次数很少(这种效果越往后越明显,因为 m 的值越来越小,可以滤掉更多的数)。但是对于极端情况(或者接近极端的情况)——如果数组完全是降序排列的话,那么这个条件每次都会成立,会导致算法退化成 O(n * m) ,性能就不可接受了。

solution4~6的优化我觉得读起来有点晦涩,有兴趣的同学可以自己去看,这里不展开说了。它们都存在算法退化的情况。

我的算法可能理解起来要简单些,基本思路是偷懒:把缓冲区设置为 2 * k ,在扫数组的时候,如果这个数比当前保留的最大值还要小,就把它塞进缓冲区,直到缓冲区塞满了,再排序、取最大值、删多余的数。

1. 开辟一个 2 × k 的临时数组 r
2. 将 N[1..k] 拷贝到 r[1..k],并记录 r 的末尾 e = k + 1
3. 找出 r[1..k] 中的最大值 m
4. 对于每一个剩下的 n - k 个数 i:
4.1 如果 i 小于 m,将 m 塞到 r 的末尾 r[e];e = e + 1
4.1.1 如果 e == 2×k(r被塞满了),对 r 进行排序,得出前k个数中的最大值;抛弃后面的k个数,e = k + 1
5. 对 r[1..e] 进行排序
6. 返回 r[1..k]

这个算法的实际效果很好,在 n = 1亿、k = 1万 的情况下,对于随机的n个数,4.1条件成立的次数通常只有十几次,几乎可以忽略,因此大约只需要扫描整个数组时间的2~3倍即可。

但是这个算法同样存在退化的问题:如果全都是倒序排列的话,也变成O(n*m)了。幸而解决方案也很简单:对这个数组进行 n / 100 次的随机化(考虑到随机化耗时也比较大,100这个数字是试了几次以后发现的合适值,不一定最优,但是都差不多了):
void rand_factor()
{
    //randomization
    const int factor = 100;
    int r, t, i;
    for (i = n / factor; i > 0; i--)
    {
        r = rand() % n;
        t = a[0], a[0] = a[r], a[r] = t;
    }
}

p.s. 很不幸的是随机化这个方法对于solution3来说虽然提升也很明显,但是远远不够。solution4~6也不太行。

这里给出随机情况下和完全倒序情况下的性能对比吧(ubuntu 12.04 x86_64,i5 2400@3.1G):

引用
随机数据
========

随机化生成数据:    7.087s
遍历耗时:          0.053s
1/100随机化耗时:    0.064s

STL的nth_element:  0.841s

最大堆+随机化:    0.162s

Solution3+随机化:  3.121s
solution4+随机化:  2.825s
solution6+随机化:  0.249s

我的算法+随机化:  0.163s


倒序数据
========

无随机化生成数据:  0.150s
遍历耗时:          0.052s
1/100随机化耗时:    0.067s

STL的nth_element:  1.420s

最大堆+随机化:    0.301s

Solution3+随机化:  67.552s
solution4+随机化:  66.405s
solution6+随机化:  2.388s

我的算法+随机化:  0.170s



补充说明一下,nth_element算法在面试的时候可能给出的 n 会更大,以至于在内存中存不下,这时候通常会认为用堆来实现最好——因为只要O(k)的空间就可以搞定,而且最差时间复杂度是O(logk * n),但是要注意logk的常数实际上是很大的,所以你可以看到前面的数据里 堆算法 的效率并不是最好的。事实上这里给出的所有算法(不考虑随机化的话)也都只需要O(k)的空间;如果需要随机化的话也很简单,分段处理,只要保证每段能在内存中保存下来就行了。



最后上代码存档(为了方便测试用了些全局变量,看起来可能有点挫):
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <unistd.h>
#include <sys/time.h>
#ifdef _WIN32
#include <windows.h>
#endif

using namespace std;

const int N = 100000000, R = 10000;
int orig[N], a[N], n, res[R];

int nargs;

#define time_test(func) _time_test(func, #func)

void _time_test(void (*func)(void), const char *name)
{
    struct timeval begin, end;
    gettimeofday(&begin, NULL);
    func();
    gettimeofday(&end, NULL);
    printf("%s: %.3fs\n", name,
            ((end.tv_sec - begin.tv_sec) * 1000000
          + (end.tv_usec - begin.tv_usec)) / 1000000.0);
}

void gen_int()
{
    int i;
    for (i = 0; i < N; i++)
        orig[i] = N - i;
    printf("generated\n");
    if (nargs == 1)
    {
        random_shuffle(orig, orig + N);
        printf("randomization ok\n");
    }
    else
        printf("skip randomization\n");
    n = N;
    printf("%d numbers\n", n);
}

void do_copy()
{
    memcpy(a, orig, sizeof(a));
}

void rand_factor()
{
    //randomization
    const int factor = 100;
    int r, t, i;
    for (i = n / factor; i > 0; i--)
    {
        r = rand() % n;
        t = a[0], a[0] = a[r], a[r] = t;
    }
}


void nth_stl()
{
    nth_element(a, a + R, a + n);
    int i;
    for (i = 0; i < R; i++)
        res[i] = a[i];
}

int findmax(int x[], int n)
{
    int i, m = 0;
    for (i = 1; i < n; i++)
        m = x[m] > x[i] ? m : i;
    return m;
}

void nth_selection()
{
    int i, m;
    for (i = 0; i < R; i++)
        res[i] = a[i];
    for (i = R; i < n; i++)
    {
        m = findmax(res, R);
        if (a[i] < res[m])
        {
            res[m] = a[i];
        }
    }
}

void nth_sol3()
{
    int i, m;
    rand_factor();
    for (i = 0; i < R; i++)
        res[i] = a[i];
    m = findmax(res, R);
    for (i = R; i < n; i++)
    {
        if (a[i] < res[m])
        {
            res[m] = a[i];
            m = findmax(res, R);
        }
    }
}

void nth_sel_opt() //我的算法
{
    int i, m, idx = R, cnt = 0;

    rand_factor();

    const int buflen= R * 2;
    int *r = new int[buflen];
    for (i = 0; i < R; i++)
        r[i] = a[i];

    m = r[findmax(r, R)];

    for (i = R; i < n; i++)
    {
        if (a[i] < m)
        {
            r[idx++] = a[i];
            if (idx == buflen)
            {
                cnt += 1;
                idx = R;
                nth_element(r, r + R, r + buflen); //这里用nth_element进一步优化,不过跟用sort差不多
                m = r[findmax(r, R)];
            }
        }
    }
    nth_element(r, r + R, r + idx);
    for (i = 0; i < R; i++)
        res[i] = r[i];
    delete[] r;

    printf("%d reorg\n", cnt);
}

void nth_sol4()
{
    int i;

    rand_factor();

    for (i = 0; i < R; i++)
        res[i] = a[i];
    sort(res, res + R);
    int mIdx = R - 1, zoneIdx = mIdx;
    for (i = R; i < n; i++)
    {
        if (a[i] < res[mIdx])
        {
            res[mIdx] = a[i];
            if (mIdx == zoneIdx)
                zoneIdx--;
            mIdx = zoneIdx;
            for (int j = zoneIdx + 1; j < R; j++)
                if (res[mIdx] < res[j])
                    mIdx = j;
        }
    }
}

void nth_sol6()
{
    int i;

    rand_factor();

    for (i = 0; i < R; i++)
        res[i] = a[i];
    sort(res, res + R);
    int mIdx = R - 1, zoneIdx = mIdx;
    for (i = R; i < n; i++)
    {
        if (a[i] < res[mIdx])
        {
            res[mIdx] = a[i];
            if (mIdx == zoneIdx)
                zoneIdx--;
            if (zoneIdx >= 9400)
            {
                mIdx = zoneIdx;
                for (int j = zoneIdx + 1; j < R; j++)
                    if (res[mIdx] < res[j])
                        mIdx = j;
            }
            else
            {
                sort(res + zoneIdx, res + R);
                inplace_merge(res, res + zoneIdx, res + R);
                mIdx = R - 1, zoneIdx = mIdx;
            }
        }
    }
}

void sift_down(int *res, int R, int i)
{
    int *x = res - 1, t;
    while (i * 2 <= R)
    {
        i *= 2;
        if (i + 1 <= R && x[i] < x[i+1])
            i++;
        if (x[i/2] < x[i])
            t = x[i/2], x[i/2] = x[i], x[i] = t;
        else
            break;
    }
}

void nth_heap()
{
    int i, cnt = 0;

    rand_factor();

    for (i = 0; i < R; i++)
        res[i] = a[i];

    for (i = (R + 1) / 2; i >= 1; i--)
        sift_down(res, R, i);

    for (i = R; i < n; i++)
    {
        if (a[i] < res[0])
        {
            res[0] = a[i];
            sift_down(res, R, 1);
            cnt += 1;
        }
    }
    printf("%d sift_down\n", cnt);
}

void traverse()
{
    int x = 0;
    for (int i = 0; i < n; i++)
        x = orig[i];
    printf("x = %d\n", x);
}

void verify()
{
    int i;
    nth_element(a, a + R, a + n);
    sort(a, a + R);
    sort(res, res + R);
    for (i = 0; i < R; i++)
    {
        if (a[i] != res[i])
        {
            printf("wrong algo(%d): %d vs %d!\n", i, a[i], res[i]);
            break;
        }
    }
    if (i == R)
        printf("verify ok\n");
    puts("");
}

int main(int argc, char *argv[])
{
    nargs = argc;

    time_test(gen_int);
    time_test(traverse);

    puts("");

    //*
    do_copy();
    time_test(nth_stl);
    verify();
    // */

    //*
    do_copy();
    time_test(nth_sol3);
    verify();
    // */

    do_copy();
    time_test(nth_heap);
    verify();

    do_copy();
    time_test(nth_sel_opt);
    verify();

    do_copy();
    time_test(nth_sol4);
    verify();

    do_copy();
    time_test(nth_sol6);
    verify();

    time_test(rand_factor);

    printf("\ntests over\n");
    return 0;
}


转载请注明出自 ,如是转载文则注明原出处,谢谢:)
RSS订阅地址: http://www.felix021.com/blog/feed.php
发表评论
表情
emotemotemotemotemot
emotemotemotemotemot
emotemotemotemotemot
emotemotemotemotemot
emotemotemotemotemot
打开HTML
打开UBB
打开表情
隐藏
记住我
昵称   密码   *非必须
网址   电邮   [注册]