Randomized-Select 算法详解

前言

在一个长为 n 的无序序列中,查找第 k 个大或小的元素,Randomized-Select 算法可以实现时间复杂度为 O(n) 的查找。

在网上查了一些资料,都没有讲解为什么该算法时间复杂度是 O(n),于是看了《算法导论》,看了原版的推导和证明,这里做个简单的讲解,具体还是看《算法导论》:9.2 以期望线性时间做选择。

前置知识:快速排序

算法原理

以下默认为升序排序

以快速排序为基点,我们都知道快排很快,能做到 O(nlogn) 的时间复杂度来排序,但当前我们只需要第 k 小的元素,能否对算法做一些缩减,来降低时间复杂度呢。

在原本的快排中,我们将元素分布在一个基准元素的两边,然后再分别递归地执行该算法,但此时我们只要第 k 个小的元素。

通过将元素分布在基准元素两边,假设基准元素下标为 i,可以确定基准元素是第 i 小的元素,同时它左边的元素比基准元素小,右边的元素比基准元素大。也就是说假如 k > i,那么目标元素在右边,反之则在左边,如果 k == i,则基准元素就是目标元素。

所以,我们实际上只需要对一边递归执行算法,另一边是可以确定不可能出现第 k 小的元素的。

到了这一步,算法的基本框架已经完成,但是考虑到最坏的情况,如果序列本身有序(本例中为升序)。那么每次递归,基准元素是第一个元素(假设基准元素都是取范围中的第一个元素),也就是左边没有元素,只有右边有元素,此时丢掉一边的元素就失去了意义,时间复杂度此时会上升至 O(n^2)

时间复杂度证明如下:

  1. 假设 k 的均值为 n/2
  2. E(T(n))=i=0n2iE(T(n))=\sum_{i=0}^{\frac{n}{2}}{i}
  3. E(T(n))=n28+n4E(T(n))=\frac{n^2}{8}+\frac{n}{4}
    所以该情况下时间复杂度为 O(n^2)

为了避免产生这种有序的最坏情况,我们随机挑选基准元素,于是就有了算法名:Randomized-Select 算法。

至此就明白了该算法的所有原理,但是仍然不能直观地理解,为什么避开最坏情况之后,时间复杂度就变成 O(n) 了呢?

时间复杂度证明

直观上,平均情况下,每次分割左右是均等的。

那么找到目标元素大概需要 n + n/2 + n/4 + ... + 1 次,为了方便,就把它当作等比数列计算,求其无穷级数的和(知道是收敛的,无穷不影响复杂度的结果)

于是 E(T(n))=2nE(T(n))=2n,所以时间复杂度是 O(n)

更严格的证明则是证明 T(n)T(n) 存在一个上界,而该上界的均值 E(T(n))E(T(n)) 是线性的,所以该算法的时间复杂度是 O(n)

具体的数学推导就不罗列在这了,可以翻阅《算法导论》详细推导一遍

算法实现 C++

inline int randomIndex(int l, int r){
    // [l, r) 的随机数
    return rand() % (r - l) + l;
}
void randomized_select(vector<int> &nums, int l, int r, int k) {
    if (l + 1 >= r) {
        return;
    }
    // 随机将一个元素放在第一个
    int random = randomIndex(l, r);
    swap(nums[l], nums[random]);

    int left = l, right = r - 1, key = nums[l];
    while (left < right){
        while(left < right && nums[right] >= key) {
            --right;
        }
        while (left < right && nums[left] <= key) {
            ++left;
        }
        swap(nums[left], nums[right]);
    }
    swap(nums[l], nums[left]);
    if(left == k){
        return;
    }else if(left > k){
        randomized_select(nums, l, left, k);
    }else{
        randomized_select(nums, left + 1, r, k);
    }
}
int findKthSmallest(vector<int>& nums, int k) {
    randomized_select(nums, 0, nums.size(), k - 1);
    return nums[k - 1];
}
上一篇 下一篇

评论 | 0条评论