SOS dp

在算法竞赛中,处理涉及位运算的问题时,我们通常会遇到一类需求:给定一个数组,要求出每个二进制状态的所有子集的权值之和

当然,首先我们要解释几个变量,这样方便我们后续的学习:

  • n

    表示二进制位数,也可以理解成“集合里有多少个元素”。如果:n = 3,那么状态数量就是:1 << n,也就是:2^3 = 8,状态范围是:000 ~ 111

  • mask

    表示一个状态,也就是一个集合。比如:mask = 10110表示第 1、2、4 位被选中。for(int mask = 0;mask < (1 << n);mask++) 意思是枚举所有状态。

  • sub

    表示 mask 的某个子集。比如:mask = 110,它的 sub 可以是:000、010、100、110。for(int sub = mask; ;sub = (sub - 1) & mask) 表示枚举 mask 的所有子集。

  • sup

    表示 mask 的超集。如果:mask = 010,那么它的超集可以是:010、011、110、111。意思是:包含 mask 的集合。

  • bit

    表示正在处理哪一位

  • A

    表示原数组,或者每个状态本来的值/权值

  • F/dp

    表示处理后的数组

1. 引入

让我们从一道极其经典的背景题切入:

题目描述:给定一个长度为 的数组 (下标从 ),你需要计算一个同等大小的新数组 ,使得 等于 所有二进制子集的 之和。

用公式表达就是:

(注:这里的 表示 的子集,即满足 (mask & i) == i)

数据范围

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
什么意思呢?这题要你做的是:对每个 mask,把所有“二进制上被它包含”的下标 i 的 A[i] 加起来。换句话说,i 只能在 mask 有 1 的位置上取 1,不能在 mask 是 0 的位置上取 1。

举个例子,假设 n = 3,下标是 0 ~ 7,二进制如下:

0 = 000
1 = 001
2 = 010
3 = 011
4 = 100
5 = 101
6 = 110
7 = 111

如果 mask = 6,也就是:mask = 110

它的所有子集是:

000 = 0
010 = 2
100 = 4
110 = 6

所以:F[6] = A[0] + A[2] + A[4] + A[6]

再比如 mask = 7:mask = 111

因为它三个位置都是 1,所以所有下标都是它的子集:F[7] = A[0] + A[1] + A[2] + A[3] + A[4] + A[5] + A[6] + A[7]

面对子集求和,最直观的反应就是暴力枚举。在 C++ 中,我们可以通过位运算技巧 sub = (sub - 1) & mask 来高效地遍历一个状态的所有子集。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
这句:sub = (sub - 1) & mask是在说,从当前子集 sub,跳到 mask 的下一个更小子集。
先看完整写法:
for(int sub = mask; ;sub = (sub - 1) & mask){
// sub 是 mask 的一个子集

if (sub == 0) break;
}
比如:mask = 110
它的子集有:
110
100
010
000
代码过程是:
sub = 110

sub - 1 = 101
101 & 110 = 100

sub = 100

sub - 1 = 011
011 & 110 = 010

sub = 010

sub - 1 = 001
001 & 110 = 000

sub = 000
结束

所以枚举顺序是:110 -> 100 -> 010 -> 000

为什么要 & mask?
因为 sub - 1 可能会把原来 mask 里没有的位变成 1。
例如:
mask = 110
sub = 100
做 sub - 1,得到100 - 1 = 011
但是 011 不是 110 的子集,因为最低位是 1,而 mask 的最低位是 0。
所以要再做:011 & 110 = 010,把 mask 中没有的位强制清零,于是回到合法子集。

于是我们可以有

1
2
3
4
5
6
7
8
9
10
11
12
13
void solve(){
int n = 20;
vector<ll> a(1 << n,1);
vector<ll> f(1 << n,0);

for(int mask = 0;mask < (1<<n);++mask){
for(int sub = mask;sub > 0;sub = (sub-1)&mask){
f[mask] += a[sub];
}
//最后加上空集0的情况
f[mask] += a[0];
}
}

虽然这段代码写起来好像很优雅,但但但但但是,由二项式定理可知,所有状态的子集数量之和为 。 当 时,。在绝大多数限时 1 秒或 2 秒的比赛中,超过 的计算量就会面临极高的 TLE 风险。我们需要一种更优雅的解法。

2. 想法

不要把 mask 当成一个抽象的十进制数字,而是把它看作一个 维空间中的坐标点。

  • 在一维数组求前缀和:是一条线上的累加
  • 在二维数组求前缀和:是一个矩形内的累加
  • 在二进制状态求子集和:就是一个 维空间中(每一维的坐标只有 0 和 1),从原点 到目标状态点的多维前缀和

既然是前缀和,我们就不需要每次都从头枚举所有子集,而是可以利用“之前已经算过的状态”,一层一层(一维一维)地进行递推。

3. 状态转移

状态定义: 定义 表示:在计算 的子集和时,只允许前 个二进制位(即第 到第 位)发生变化(即可以为 0 也可以为 1,只要不超出 的限制),而第 位及以后的高位必须与 严格相同的子集之和。

状态转移与推导: 我们要从 递推到 ,也就是现在我们“开放了第 位的变化权限”。我们需要观察 自身的第 位是 0 还是 1:

情况 1: 的第 位是

  • 推导:因为目标状态这一维是 0,它的子集在这一位也只能是 0。开放变化权限对它毫无影响,它只能继承之前的结果。
  • 转移方程:

情况 2: 的第 位是

  • 推导:既然目标这一维是 1,它的子集在这一位既可以是 ,也可以是
    • 当子集第 位是 1 时:这部分子集的和就是 (继承之前的结果)。
    • 当子集第 位是 0 时:我们需要把 的第 位强行变成 0(即 mask ^ (1 << i)),去拿那部分状态之前计算好的结果,即
  • 转移方程:
1
2
3
4
5
6
7
8
9
10
11
12
13
vector<vector<ll>> dp(n+1,vector<ll> (1 << n,0));
//当i = 0,一位都不允许改变,所以mask只能是他自己
for(int mask = 0;mask < (1 << n);++mask) dp[0][mask] = a[mask];

for(int i = 0;i < n;++i){ //枚举开放哪一位
for(int mask = 0;mask < (1 << n);++mask){ //枚举所有状态
if(mask & (1 << i)){
dp[i+1][mask] = dp[i][mask] + dp[i][mask^(1<<i)];
}else{
dp[i+1][mask] = dp[i][mask];
}
}
}

也许还是有些难理解,没事,可以稍微举个例子理解一下

假设n = 3,mask = 101,它的子集有000,001,100,101

初始化:dp[0][101] = a[101]

处理第0位时,101的第0位是1:dp[1][101] = dp[0][101] + dp[0][100]

所以现在包含a[101]+a[100]

处理第1位时,101的第1位是0:dp[2][101] = dp[1][101]

不变

处理第2位时,101的第2位是1:dp[3][101] = dp[2][101] + dp[2][001]

dp[2][001]已经包含a[001]+a[000],所以最终dp[3][101] = a[101]+a[100]+a[001]+a[000]

它的空间复杂度是$O(n2^n)O(n2^n)$

4. 优化

仔细观察上面的转移方程,我们会发现一个极其优美的性质: 只依赖于 ,并且我们在更新 时,用到的 一定是一个比 小的状态

和溢位

我们可以像01背包那样,把第一维完全砍掉,直接在一个一维数组上进行原地滚动更新,于是我们得到了终极模板

1
2
3
4
5
6
7
8
9
10
11
12
int max_mask = 1 << n;

vector<ll> dp(max_mask);
for(int i = 0;i < max_mask;++i) cin >> dp[i];

for(int i = 0;i < n;++i){
for(int mask = 0;mask < max_mask;++mask){
if(mask & (1 << i)){
dp[mask] += dp[mask ^ (1 << i)];
}
}
}

5. 例题

为了检验我们的学习成果,我们来看一道非常震撼的真题

例1:E. XOR Again?

门钥匙

Ques

题目描述

给定一个长度为 N 的整数数组:

1
A1, A2, ..., AN

对于每一个 M,其中 1 ≤ M ≤ N,解决下面的问题:

将数组划分成 恰好 M 个连续的非空块

每个块的代价为该块内所有元素的 按位异或和

一次划分的总代价为所有块代价的 按位或

请你求出对于每一个 M,划分成恰好 M 个连续块时,可能得到的最小总代价。

输入格式

第一行包含一个整数 N

第二行包含 N 个用空格分隔的整数:

1
A1 A2 ... AN

输出格式

输出一行,包含 N 个整数。

i 个整数表示当 M = i 时的答案。

数据范围

1
2
1 ≤ N ≤ 10^6
0 ≤ Ai ≤ 10^6

样例

1
2
6
0 3 10 2 4 5
1
10 10 11 11 11 15
1
2
4
0 1 0 1
1
0 0 1 1

Ans

原题要求:

把数组切成恰好 M 段,每段的代价是这一段的 XOR,整个划分的代价是所有段代价的 OR。
对每个 M = 1...N,求最小代价。

如果直接想怎么切,会很难,因为既有分段,又有 XOR,又有 OR。

所以第一步可以先观察数据范围:

1
2
N ≤ 10^6
Ai ≤ 10^6

N 很大,说明不能做普通的分段 DP。但是 Ai ≤ 10^6,也就是说每个数最多只有大约 20 个二进制位。因此这题大概率不是按位置做复杂 DP,而是要从二进制状态入手。


而这道题的总代价是:

1
每一段 XOR 的按位 OR

按位 OR 有一个特点:

只要某一段 XOR 中出现了某一位 1,那么最终答案这一位就一定是 1

所以如果我们假设最终答案的二进制位只能出现在某个 mask 里面,那么每一段的 XOR 都不能出现 mask 之外的位。

也就是说,我们可以把问题转化成:

枚举一个 mask,判断能不能切成 M 段,使得每一段的 XOR 都是 mask 的子集。也就是 x 的所有 1 位,mask 中也都有。


定义前缀异或:pre[i] = A1 ^ A2 ^ … ^ Ai

那么一段区间 [l + 1, r] 的 XOR 是:pre[r] ^ pre[l]

假设我们把数组切成几段:[1 … c1] [c1+1 … c2] [c2+1 … c3] … [ck+1 … N]

这些切分点是:c1, c2, c3, …

每一段的 XOR 分别是:

1
2
3
4
5
pre[c1] ^ pre[0]
pre[c2] ^ pre[c1]
pre[c3] ^ pre[c2]
...
pre[N] ^ pre[ck]

因为:pre[0] = 0

第一段 XOR 就是 pre[c1]

如果第一段 XOR 必须是 mask 的子集,那么 pre[c1] 也必须是 mask 的子集。

再看前两段整体的 XOR pre[c2]

它等于第一段 XOR 再异或第二段 XOR。

如果每一段 XOR 都只包含 mask 中的位,那么它们异或起来也只会包含 mask 中的位。

所以 pre[c2] 也必须是 mask 的子集。

继续往后推,所有被选中的切分点 c 都必须满足 pre[c] 是 mask 的子集

也就是 (pre[c] & mask) == pre[c]

同时最后的 pre[N] 也必须是 mask 的子集。


当然,反过来也成立

如果两个切分点 xy 满足:

1
2
pre[x] 是 mask 的子集
pre[y] 是 mask 的子集

那么区间 (x, y] 的 XOR 是 pre[y] ^ pre[x]

因为 pre[x]pre[y] 都没有 mask 之外的位,所以它们异或之后也不会出现 mask 之外的位。

也就是说:

只要选中的切分点的前缀异或都是 mask 的子集,那么切出来的每一段 XOR 也一定是 mask 的子集。

所以对于一个固定的 mask,问题就变成:

有多少个位置 i 可以作为切分点,使得 pre[i]mask 的子集?

注意切分点只能在数组内部,所以只看 i = 1, 2, …, N - 1


当然,如果要切成 M 段,需要选:M - 1个内部切分点。

对于一个固定的 mask,假设有 cnt[mask] 个内部位置 i 满足 pre[i] 是 mask 的子集

那么只要:cnt[mask] >= M - 1 就可以切成 M 段。

同时还要保证整个数组的异或 pre[N] 也是 mask 的子集。否则最后终点不合法,无论怎么切都不行。

所以固定 mask 的可行条件是:

1
2
3
pre[N] 是 mask 的子集
并且
cnt[mask] >= M - 1

现在我们再看一遍题干

原题是:

1
对每个 M,求最小划分代价

现在变成:

1
2
3
对每个 M,找最小的 mask,
满足 pre[N] 是 mask 的子集,
并且 cnt[mask] >= M - 1

其中:

1
cnt[mask] = 内部前缀异或 pre[i] 中,有多少个是 mask 的子集

这一步就是整道题的核心转化。


cnt[mask] 怎么快速求?

我们先统计每种前缀异或出现了多少次。

设 freq[x] = 内部位置中,pre[i] == x 的数量

那么 cnt[mask] = sum(freq[x]),其中 x 是 mask 的子集

也就是:cnt[mask] = 所有子集 x ⊆ mask 的 freq[x] 之和

这正好是 SOS DP 的经典形式。

初始化:cnt[mask] = freq[mask]

然后做 SOS DP:

1
2
3
4
5
6
7
for (int bit = 0; bit < B; bit++) {
for (int mask = 0; mask < (1 << B); mask++) {
if (mask & (1 << bit)) {
cnt[mask] += cnt[mask ^ (1 << bit)];
}
}
}

做完之后:cnt[mask] = 所有 x ⊆ mask 的 freq[x] 之和,也就是这个 mask 能提供多少个合法切分点。


对于每个 mask,我们已经知道它能提供 cnt[mask] 个合法切分点。

如果它能提供 c 个切分点,那么它可以支持:1 段,2 段,…,c + 1 段,因为切成 M 段只需要 M - 1 个切分点。

所以我们令:best[k] = 至少能提供 k 个切分点时的最小 mask

那么答案就是:ans[M] = best[M - 1]

具体做法是:

先枚举所有 mask

1
2
c = cnt[mask]
best[c] = min(best[c], mask)

这里先表示“恰好有 c 个合法切分点”的最小 mask

但是我们真正需要的是“至少有 k 个合法切分点”,所以再做一遍后缀最小值:

1
2
3
for (int k = N - 2; k >= 0; k--) {
best[k] = min(best[k], best[k + 1]);
}

这样 best[k] 就表示:cnt[mask] >= k 的所有 mask 中,最小的 mask

最后输出:best[0], best[1], best[2], …, best[N - 1]

分别对应:M = 1, 2, 3, …, N


而后,我们稍微总结一下

这题因为 N 很大,不能直接分段 DP;但 Ai ≤ 10^6,说明数值只有 20 位左右,可以从二进制状态入手。

由于最终代价是所有段 XOR 的 OR,所以如果我们枚举一个候选 mask,就可以把问题变成:能不能切出若干段,使得每一段的 XOR 都只包含 mask 中的位。

用前缀异或表示区间 XOR 后,可以发现:一段 XOR 合法等价于相邻切分点的前缀异或都在 mask 内。因此,对于固定的 mask,所有合法切分点就是那些满足 pre[i] ⊆ mask 的内部位置。

如果这样的内部位置有 cnt[mask] 个,那么这个 mask 最多可以切成 cnt[mask] + 1 段。要切成 M 段,只需要满足 cnt[mask] >= M - 1。同时整个数组的异或 pre[N] 也必须是 mask 的子集。

于是问题变成:对每个 mask,求有多少个内部前缀异或是它的子集。这个就是标准 SOS DP:

1
cnt[mask] = sum freq[sub], sub ⊆ mask

最后枚举所有合法 mask,根据它能提供的切分点数量更新答案,再做后缀最小值,就能得到所有 M 的答案。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
const int B = 20;
const int S = 1 << B;
const int INF = 1e9;

void solve(){
int n; cin >> n;

vector<int> freq(S,0);
int pre = 0;
for(int i = 1;i <= n;++i){
int x; cin >> x;
pre ^= x;
if(i < n) freq[pre]++;
}

int tot = pre;
vector<int> cnt = freq;

for(int i = 0;i < B;++i){
for(int mask = 0;mask < S;++mask){
if(mask & (1 << i)){
cnt[mask] += cnt[mask ^ (1 << i)];
}
}
}

vector<int> best(n,INF);
for(int mask = 0;mask < S;++mask){
if((mask&tot) == tot){
int c = cnt[mask];
best[c] = min(best[c],mask);
}
}

for(int k = n-2;k >= 0;--k){
best[k] = min(best[k],best[k+1]);
}

for(int M = 1;M <= n;++M){
if(M > 1) cout << ' ';
cout << best[M-1];
}
cout << '\n';
}

相关文章

[[4-1-Linear-DP]]