c++ 需要一个大的约束子集和类问题的优化技巧

wnavrhmk  于 2023-02-01  发布在  其他
关注(0)|答案(1)|浏览(118)

给定一个数1〈= N〈= 3*10^5,计算集合{1,2,...,N-1}中所有和为N的子集。这本质上是子集求和问题的一个修改版本,但有一个修改,即求和和元素数相同,并且集合/数组线性增加1到N-1。
我想我已经用dp有序Map和包含/排除递归算法解决了这个问题,但是由于时间和空间的复杂性,我不能计算超过10000个元素。

#include <iostream>
#include <chrono>
#include <map>
#include "bigint.h"

using namespace std;

//2d hashmap to store values from recursion; keys- i & sum; value- count
map<pair<int, int>, bigint> hmap;

bigint counter(int n, int i, int sum){

    //end case
    if(i == 0){ 
        if(sum == 0){
            return 1;
        }
        return 0;
    }

    //alternative end case if its sum is zero before it has finished iterating through all of the possible combinations
    if(sum == 0){
        return 1;
    }

    //case if the result of the recursion is already in the hashmap
    if(hmap.find(make_pair(i, sum)) != hmap.end()){
        return hmap[make_pair(i, sum)];
    }

    //only proceed further recursion if resulting sum wouldnt be negative
    if(sum - i < 0){
        //optimization that skips unecessary recursive branches
        return hmap[make_pair(i, sum)] = counter(n, sum, sum);
    }
    else{
                                        //include the number   dont include the number
        return hmap[make_pair(i, sum)] = counter(n, i - 1, sum - i) + counter(n, i - 1, sum);
    }
}

该函数的起始值为N、N-1和N,分别表示元素数、迭代器(递减)和递归分支的总和(随每个包含的值而递减)。
这是计算子集数量的代码。对于输入3000,大约需要22秒才能输出40位长的结果。由于位数较长,我不得不使用rgroshanrg的任意精度库bigint,该库适用于小于10000的值。超出此范围的测试在第28-29行给我一个segfault。也许是由于存储的任意精度值变得太大,并在Map冲突。我需要以某种方式了这段代码,使它可以与值超过10000,但我难倒了它。任何想法或我应该切换到另一种算法和数据存储?

vdzxcuhz

vdzxcuhz1#

下面是Evangelos Georgiadis在一篇论文“Computing Partition Numbers q(n)"中描述的另一种算法:

std::vector<BigInt> RestrictedPartitionNumbers(int n)
{
    std::vector<BigInt> q(n, 0);
    // initialize q with A010815
    for (int i = 0; ; i++)
    {
        int n0 = i * (3 * i - 1) >> 1;
        if (n0 >= q.size())
            break;
        q[n0] = 1 - 2 * (i & 1);
        int n1 = i * (3 * i + 1) >> 1;
        if (n1 < q.size())
            q[n1] = 1 - 2 * (i & 1);
    }
    // construct A000009 as per "Evangelos Georgiadis, Computing Partition Numbers q(n)"
    for (size_t k = 0; k < q.size(); k++)
    {
        size_t j = 1;
        size_t m = k + 1;
        while (m < q.size())
        {
            if ((j & 1) != 0)
                q[m] += q[k] << 1;
            else
                q[m] -= q[k] << 1;
            j++;
            m = k + j * j;
        }
    }
    return q;
}

这不是最快的算法,在我的计算机上,对于n = 300000,这花了大约半分钟,但你只需要做一次(因为它计算所有的分区号,直到某个界限),而且它不需要很多内存(稍微超过150MB)。
结果上升到但 * 排除 * n,他们假设,对于每一个数字,该数字本身是允许的一个划分本身,例如集{4}是一个划分的数字4,在您的定义的问题,您排除了这种情况,所以您需要减去1的结果。
也许有一个更好的方式来表达A010815,这部分代码并不慢,虽然,我只是觉得它看起来很糟糕。

相关问题