二维数组第k个最小元素(或中位数)的java最快算法?

u5i3ibmn  于 2021-07-06  发布在  Java
关注(0)|答案(5)|浏览(290)

我在相关的主题上看到了很多这样的主题,但没有一个能提供有效的方法。
我想找到 k-th 二维阵列上的最小元素(或中值) [1..M][1..N] 其中每行按升序排序,所有元素都是不同的。
我想有 O(M log MN) 解决方案,但我不知道如何实现(中位数或使用线性复杂度的划分是一些方法,但没有更多的想法。
这是一个老的谷歌面试问题,可以在这里搜索。
但是现在我想要提示或描述最有效的算法(最快的算法)。
我也读了一篇关于这里的文章,但我不明白。
更新1:这里有一个解决方案,但当维度为奇数时。

qgelzfjb

qgelzfjb1#

添加了另一个答案以提供实际的解决方案。这一个已经被留下,因为它是相当兔子洞的评论。
我相信最快的解决方案是k路合并算法。它是一个 O(N log K) 合并算法 K 已排序的列表,总共 N 将项目放入单个大小排序列表中 N .
https://en.wikipedia.org/wiki/k-way_merge_algorithm#k-方式\u合并
给予 MxN 列表。结果是 O(MNlog(M)) . 但是,这是为了对整个列表进行排序。既然你只需要第一个 K 最小的项目而不是全部 N*M ,性能为 O(Klog(M)) . 这比你想要的要好一点 O(K) <= O(M) .
虽然这假设你有 N 大小排序列表 M . 如果你真的有 M 大小排序列表 N ,这可以很容易地处理,尽管只需更改数据的循环方式(请参阅下面的伪代码),但这确实意味着性能会有所提高 O(K log(N)) 相反。
k-way合并只是将每个列表的第一项添加到堆或具有 O(log N) 插入和 O(log N) 找到心灵。
k-way merge的伪代码如下所示:
对于每个排序的列表,将第一个值插入到数据结构中,并使用某种方法确定值来自哪个列表。ie:你可以插入 [value, row_index, col_index] 而不仅仅是 value . 这还可以让您轻松地处理列或行上的循环。
从数据结构中删除最小值并附加到排序列表。
考虑到第2步中的项目来自列表 I 从列表中添加下一个最低值 I 到数据结构。ie:如果值为 row 5 col 4 (data[5][4]) . 如果使用行作为列表,那么下一个值是 row 5 col 5 (data[5][5]) . 如果您使用的是列,那么下一个值是 row 6 col 4 (data[6][4]) . 将下一个值插入到数据结构中,就像插入#1(即: [value, row_index, col_index] )
根据需要返回步骤2。
根据您的需要,请执行步骤2-4 K 次。

idfiyjo8

idfiyjo82#

btilly和nuclearman的答案提供了两种不同的方法,一种是二进制搜索,另一种是k路行合并。
我的建议是把这两种方法结合起来。
如果k很小(比如说小于m乘以2或3)或很大(对于simmetry,接近nxm)enoug

vmpqdwk3

vmpqdwk33#

所以要解决这个问题,需要解决一个稍微不同的问题。我们想知道每一行的上界/下界,总的第k个截止点在哪里。然后我们可以通过,验证在下界或下界的事物数是<k,在上界或下界的事物数是>k,并且它们之间只有一个值。
我想出了一个策略,可以同时在所有行中对这些边界进行二进制搜索。作为一个二进制搜索它“应该”采取 O(log(n)) 通行证。每一关都涉及 O(m) 总共工作了 O(m log(n)) 次。我应该加引号,因为我没有证据证明 O(log(n)) 通行证。事实上,在一行中可能过于激进,从其他行中发现选择的轴心点已关闭,然后不得不后退。但我相信它很少后退,实际上是 O(m log(n)) .
策略是跟踪每一行的下界、上界和中间值。每一次传递我们都会生成一个范围的加权序列,从下到中、从中到上、从上到尾,权重是其中的事物数,值是序列中的最后一个。然后我们在该数据结构中找到第k个值(按权重),并将其用作每个维度的二进制搜索的轴心。
如果轴从下到上超出了范围,我们可以通过在纠正错误的方向上加宽间隔来进行纠正。
当我们有了正确的顺序,我们就有了答案。
有很多边缘情况,所以关注完整的代码可能会有所帮助。
我还假设每行的所有元素都是不同的。如果他们不是,你可以进入无尽的循环(解决这意味着更多的边缘情况……)

import random

# This takes (k, [(value1, weight1), (value2, weight2), ...])

def weighted_kth (k, pairs):
    # This does quickselect for average O(len(pairs)).
    # Median of medians is deterministically the same, but a bit slower
    pivot = pairs[int(random.random() * len(pairs))][0]

    # Which side of our answer is the pivot on?
    weight_under_pivot = 0
    pivot_weight = 0
    for value, weight in pairs:
        if value < pivot:
            weight_under_pivot += weight
        elif value == pivot:
            pivot_weight += weight

    if weight_under_pivot + pivot_weight < k:
        filtered_pairs = []
        for pair in pairs:
            if pivot < pair[0]:
                filtered_pairs.append(pair)
        return weighted_kth (k - weight_under_pivot - pivot_weight, filtered_pairs)
    elif k <= weight_under_pivot:
        filtered_pairs = []
        for pair in pairs:
            if pair[0] < pivot:
                filtered_pairs.append(pair)
        return weighted_kth (k, filtered_pairs)
    else:
        return pivot

# This takes (k, [[...], [...], ...])

def kth_in_row_sorted_matrix (k, matrix):
    # The strategy is to discover the k'th value, and also discover where
    # that would be in each row.
    #
    # For each row we will track what we think the lower and upper bounds
    # are on where it is.  Those bounds start as the start and end and
    # will do a binary search.
    #
    # In each pass we will break each row into ranges from start to lower,
    # lower to mid, mid to upper, and upper to end.  Some ranges may be
    # empty.  We will then create a weighted list of ranges with the weight
    # being the length, and the value being the end of the list.  We find
    # where the k'th spot is in that list, and use that approximate value
    # to refine each range.  (There is a chance that a range is wrong, and
    # we will have to deal with that.)
    #
    # We finish when all of the uppers are above our k, all the lowers
    # one are below, and the upper/lower gap is more than 1 only when our
    # k'th element is in the middle.

    # Our data structure is simply [row, lower, upper, bound] for each row.
    data = [[row, 0, min(k, len(row)-1), min(k, len(row)-1)] for row in matrix]
    is_search = True
    while is_search:
        pairs = []
        for row, lower, upper, bound in data:
            # Literal edge cases
            if 0 == upper:
                pairs.append((row[upper], 1))
                if upper < bound:
                    pairs.append((row[bound], bound - upper))
            elif lower == bound:
                pairs.append((row[lower], lower + 1))
            elif lower + 1 == upper: # No mid.
                pairs.append((row[lower], lower + 1))
                pairs.append((row[upper], 1))
                if upper < bound:
                    pairs.append((row[bound], bound - upper))
            else:
                mid = (upper + lower) // 2
                pairs.append((row[lower], lower + 1))
                pairs.append((row[mid], mid - lower))
                pairs.append((row[upper], upper - mid))
                if upper < bound:
                    pairs.append((row[bound], bound - upper))

        pivot = weighted_kth(k, pairs)

        # Now that we have our pivot, we try to adjust our parameters.
        # If any adjusts we continue our search.
        is_search = False
        new_data = []
        for row, lower, upper, bound in data:
            # First cases where our bounds weren't bounds for our pivot.
            # We rebase the interval and either double the range.
            # - double the size of the range
            # - go halfway to the edge
            if 0 < lower and pivot <= row[lower]:
                is_search = True
                if pivot == row[lower]:
                    new_data.append((row, lower-1, min(lower+1, bound), bound))
                elif upper <= lower:
                    new_data.append((row, lower-1, lower, bound))
                else:
                    new_data.append((row, max(lower // 2, lower - 2*(upper - lower)), lower, bound))
            elif upper < bound and row[upper] <= pivot:
                is_search = True
                if pivot == row[upper]:
                    new_data.append((row, upper-1, upper+1, bound))
                elif lower < upper:
                    new_data.append((row, upper, min((upper+bound+1)//2, upper + 2*(upper - lower)), bound))
                else:
                    new_data.append((row, upper, upper+1, bound))
            elif lower + 1 < upper:
                if upper == lower+2 and pivot == row[lower+1]:
                    new_data.append((row, lower, upper, bound)) # Looks like we found the pivot.
                else:
                    # We will split this interval.
                    is_search = True
                    mid = (upper + lower) // 2
                    if row[mid] < pivot:
                        new_data.append((row, mid, upper, bound))
                    elif pivot < row[mid] pivot:
                        new_data.append((row, lower, mid, bound))
                    else:
                        # We center our interval on the pivot
                        new_data.append((row, (lower+mid)//2, (mid+upper+1)//2, bound))
            else:
                # We look like we found where the pivot would be in this row.
                new_data.append((row, lower, upper, bound))
        data = new_data # And set up the next search
    return pivot
50pmv0ei

50pmv0ei4#

也许我错过了什么,但如果你 NxM 矩阵 AM 行已经升序排序,没有重复的元素 k -行的最小值只是拾取 k -行中的第个元素 O(1) . 要移动到2d,只需选择 k -改为按升序排序 O(M.log(M)) 再挑一次 k-th 导致 O(N.log(N)) .
让我们有矩阵 A[N][M] 元素所在的位置 A[column][row] 排序 k-thA 提升 O(M.log(M)) 如此排序 A[k][i] 哪里 i = { 1,2,3,...M } 提升
挑选 A[k][k] 结果呢
如果你想在所有元素中取第k个最小值 A 相反,您需要以类似于merge sort的形式利用已经排序的行。
创建空列表 c[] 等待 k 最小值
处理列
创建临时数组 b[] 它保存处理过的列 O(N.log(N)) 合并 c[] 以及 b[] 所以呢 c[] 坚持到 k 最小值
使用临时数组 d[] 将导致 O(k+n) 如果在合并过程中没有使用 b 停止处理列
这可以通过添加标志数组来完成 f 从哪来的 b,c 该值是在合并过程中获取的,然后只是检查是否从中获取了任何值 b 输出 c[k-1] 当把所有这些放在一起时,最后的复杂性是 O(min(k,M).N.log(N)) 如果我们考虑一下 k 小于 M 我们可以改写成 O(k.N.log(N)) 否则 O(M.N.log(N)) . 而且平均来说,要迭代的列的数量将更不可能 ~(1+(k/N)) 所以平均复杂度是 ~O(N.log(N)) 但这只是我的猜测,可能是错的。
下面是小型c++/vcl示例:

//$$---- Form CPP ----
//---------------------------------------------------------------------------

# include <vcl.h>

# pragma hdrstop

# include "Unit1.h"

# include "sorts.h"

//---------------------------------------------------------------------------

# pragma package(smart_init)

# pragma resource "*.dfm"

TForm1 *Form1;
//---------------------------------------------------------------------------
const int m=10,n=8; int a[m][n],a0[m][n]; // a[col][row]
//---------------------------------------------------------------------------
void generate()
    {
    int i,j,k,ii,jj,d=13,b[m];
    Randomize();
    RandSeed=0x12345678;
    // a,a0 = some distinct pseudorandom values (fully ordered asc)
    for (k=Random(d),j=0;j<n;j++)
     for (i=0;i<m;i++,k+=Random(d)+1)
      { a0[i][j]=k; a[i][j]=k; }
    // schuffle a
    for (j=0;j<n;j++)
     for (i=0;i<m;i++)
        {
        ii=Random(m);
        jj=Random(n);
        k=a[i][j]; a[i][j]=a[ii][jj]; a[ii][jj]=k;
        }
    // sort rows asc
    for (j=0;j<n;j++)
        {
        for (i=0;i<m;i++) b[i]=a[i][j];
        sort_asc_quick(b,m);
        for (i=0;i<m;i++) a[i][j]=b[i];
        }

    }
//---------------------------------------------------------------------------
int kmin(int k) // k-th min from a[m][n] where a rows are already sorted
    {
    int i,j,bi,ci,di,b[n],*c,*d,*e,*f,cn;
    c=new int[k+k+k]; d=c+k; f=d+k;
    // handle edge cases
    if (m<1) return -1;
    if (k>m*n) return -1;
    if (m==1) return a[0][k];
    // process columns
    for (cn=0,i=0;i<m;i++)
        {
        // b[] = sorted_asc a[i][]
        for (j=0;j<n;j++) b[j]=a[i][j];     // O(n)
        sort_asc_quick(b,n);                // O(n.log(n))
        // c[] = c[] + b[] asc sorted and limited to cn size
        for (bi=0,ci=0,di=0;;)              // O(k+n)
            {
                 if ((ci>=cn)&&(bi>=n)) break;
            else if (ci>=cn)     { d[di]=b[bi]; f[di]=1; bi++; di++; }
            else if (bi>= n)     { d[di]=c[ci]; f[di]=0; ci++; di++; }
            else if (b[bi]<c[ci]){ d[di]=b[bi]; f[di]=1; bi++; di++; }
            else                 { d[di]=c[ci]; f[di]=0; ci++; di++; }
            if (di>k) di=k;
            }
        e=c; c=d; d=e; cn=di;
        for (ci=0,j=0;j<cn;j++) ci|=f[j];   // O(k)
        if (!ci) break;
        }
    k=c[k-1];
    delete[] c;
    return k;
    }
//---------------------------------------------------------------------------
__fastcall TForm1::TForm1(TComponent* Owner):TForm(Owner)
    {
    int i,j,k;
    AnsiString txt="";

    generate();

    txt+="a0[][]\r\n";
    for (j=0;j<n;j++,txt+="\r\n")
     for (i=0;i<m;i++) txt+=AnsiString().sprintf("%4i ",a0[i][j]);

    txt+="\r\na[][]\r\n";
    for (j=0;j<n;j++,txt+="\r\n")
     for (i=0;i<m;i++) txt+=AnsiString().sprintf("%4i ",a[i][j]);

    k=20;
    txt+=AnsiString().sprintf("\r\n%ith smallest from a0 = %4i\r\n",k,a0[(k-1)%m][(k-1)/m]);
    txt+=AnsiString().sprintf("\r\n%ith smallest from a  = %4i\r\n",k,kmin(k));

    mm_log->Lines->Add(txt);
    }
//-------------------------------------------------------------------------

忽略vcl的东西。函数生成计算 a0, a 矩阵,其中 a0 完全分类和 a 只对行进行排序,所有值都是不同的。函数 kmin 上面描述的算法是否返回第k个最小值 a[m][n] 为了分类,我用了这个:

template <class T> void sort_asc_quick(T *a,int n)
    {
    int i,j; T a0,a1,p;
    if (n<=1) return;                                   // stop recursion
    if (n==2)                                           // edge case
        {
        a0=a[0];
        a1=a[1];
        if (a0>a1) { a[0]=a1; a[1]=a0; }                // condition
        return;
        }
    for (a0=a1=a[0],i=0;i<n;i++)                        // pivot = midle (should be median)
        {
        p=a[i];
        if (a0>p) a0=p;
        if (a1<p) a1=p;
        } if (a0==a1) return; p=(a0+a1+1)/2;            // if the same values stop
    if (a0==p) p++;
    for (i=0,j=n-1;i<=j;)                               // regroup
        {
        a0=a[i];
        if (a0<p) i++; else { a[i]=a[j]; a[j]=a0; j--; }// condition
        }
    sort_asc_quick(a  ,  i);                            // recursion a[]<=p
    sort_asc_quick(a+i,n-i);                            // recursion a[]> p
    }

这里是输出:

a0[][]
  10   17   29   42   54   66   74   85   90  102 
 112  114  123  129  142  145  146  150  157  161 
 166  176  184  191  195  205  213  216  222  224 
 226  237  245  252  264  273  285  290  291  296 
 309  317  327  334  336  349  361  370  381  390 
 397  398  401  411  422  426  435  446  452  462 
 466  477  484  496  505  515  522  524  525  530 
 542  545  548  553  555  560  563  576  588  590 

a[][]
 114  142  176  264  285  317  327  422  435  466 
 166  336  349  381  452  477  515  530  542  553 
 157  184  252  273  291  334  446  524  545  563 
  17  145  150  237  245  290  370  397  484  576 
  42  129  195  205  216  309  398  411  505  560 
  10  102  123  213  222  224  226  390  496  555 
  29   74   85  146  191  361  426  462  525  590 
  54   66   90  112  161  296  401  522  548  588 

20th smallest from a0 =  161

20th smallest from a  =  161

这个例子只迭代了5列。。。

dzjeubhm

dzjeubhm5#

似乎最好的办法是在越来越大的区块中进行k-way合并。k-way合并试图构建一个排序列表,但是我们不需要对它进行排序,也不需要考虑每个元素。相反,我们将创建一个半排序的间隔。间隔将被排序,但仅按最高值排序。
https://en.wikipedia.org/wiki/k-way_merge_algorithm#k-方式\u合并
我们使用与k-way合并相同的方法,但是有一个扭曲。基本上,它的目的是间接地建立一个半排序的子列表。例如,它不是找到[1,2,3,4,5,6,7,8,10]来确定k=10,而是找到类似于[(1,3),(4,6),(7,15)]的东西。对于k-way合并,我们每次从每个列表中考虑一个项目。在这种方法中,当从给定的列表中提取时,我们首先要考虑z项,然后是2z项,然后是22z项,因此第i次要考虑2^iz项。给定一个mxn矩阵,这意味着我们需要 O(log(N)) 列表中的项目 M 次。
对于每个排序的列表,插入第一个 K 使用某种方法确定值来自哪个列表,将子列表添加到数据结构中。我们希望数据结构使用插入其中的子列表中的最高值。在本例中,我们需要类似于[max\u value of sublist,row index,start\u index,end\u index]的内容。 O(m) 从数据结构中删除最小的值(现在是一个值列表)并附加到排序列表。 O(log (m)) 考虑到第2步中的项目来自列表 I 添加下一个 2^i * Z 列表中的值 I 在第i次从特定列表中提取数据结构时(基本上只是从数据结构中移除的子列表中出现的数字的两倍)。 O(log m) 如果半排序子列表的大小大于k,则使用二进制搜索查找第k个值。 O(log N)) . 如果数据结构中还有任何子列表,其中最小值小于k。转到步骤1,将列表作为输入,并使用新的 K 存在 k - (size of semi-sorted list) .
如果半排序子列表的大小等于k,则返回半排序子列表中的最后一个值,这是第k个值。
如果半排序子列表的大小小于k,请返回步骤2。
至于表现。让我们看看这里:
O(m log m) 将初始值添加到数据结构。
最多需要考虑一下 O(m) 每个子列表需要 O(log n) o(m log n)的时间。
最后需要执行二进制搜索, O(log m) ,如果k的值不确定(第4步),可能需要将问题简化为递归子列表,但我认为这不会影响大o。编辑:我相信这只是增加了另一个 O(mlog(n)) 在最坏的情况下,这对大o没有影响。
看起来像是 O(mlog(m) + mlog(n)) 或者只是 O(mlog(mn)) .
作为优化,如果k大于 NM/2 考虑最小值时考虑最大值,考虑最大值时考虑最小值。当k接近时,这将大大提高性能 NM .

相关问题