SOS DP
SOS dp
在算法竞赛中,处理涉及位运算的问题时,我们通常会遇到一类需求:给定一个数组,要求出每个二进制状态的所有子集的权值之和
当然,首先我们要解释几个变量,这样方便我们后续的学习:
n
表示二进制位数,也可以理解成“集合里有多少个元素”。如果:n = 3,那么状态数量就是:1 << n,也就是:2^3 = 8,状态范围是:000 ~ 111
mask
表示一个状态,也就是一个集合。比如:mask = 10110表示第 1、2、4 位被选中。for(int mask = 0;mask < (1 << n);mask++) 意思是枚举所有状态。
sub
表示 mask 的某个子集。比如:mask = 110,它的 sub 可以是:000、010、100、110。for(int sub = mask; ;sub = (sub - 1) & mask) 表示枚举 mask 的所有子集。
sup
表示 mask 的超集。如果:mask = 010,那么它的超集可以是:010、011、110、111。意思是:包含 mask 的集合。
bit
表示正在处理哪一位
A
表示原数组,或者每个状态本来的值/权值
F/dp
表示处理后的数组
1. 引入
让我们从一道极其经典的背景题切入:
题目描述:给定一个长度为
用公式表达就是:
(注:这里的 (mask & i) == i)
数据范围:
1 | 什么意思呢?这题要你做的是:对每个 mask,把所有“二进制上被它包含”的下标 i 的 A[i] 加起来。换句话说,i 只能在 mask 有 1 的位置上取 1,不能在 mask 是 0 的位置上取 1。 |
面对子集求和,最直观的反应就是暴力枚举。在 C++ 中,我们可以通过位运算技巧 sub = (sub - 1) & mask 来高效地遍历一个状态的所有子集。
1 | 这句:sub = (sub - 1) & mask是在说,从当前子集 sub,跳到 mask 的下一个更小子集。 |
于是我们可以有
1 | void solve(){ |
虽然这段代码写起来好像很优雅,但但但但但是,由二项式定理可知,所有状态的子集数量之和为
2. 想法
不要把 mask 当成一个抽象的十进制数字,而是把它看作一个
- 在一维数组求前缀和:是一条线上的累加
- 在二维数组求前缀和:是一个矩形内的累加
- 在二进制状态求子集和:就是一个
维空间中(每一维的坐标只有 0 和 1),从原点 到目标状态点的多维前缀和
既然是前缀和,我们就不需要每次都从头枚举所有子集,而是可以利用“之前已经算过的状态”,一层一层(一维一维)地进行递推。
3. 状态转移
状态定义: 定义
状态转移与推导: 我们要从
情况 1:
- 推导:因为目标状态这一维是 0,它的子集在这一位也只能是 0。开放变化权限对它毫无影响,它只能继承之前的结果。
- 转移方程:
情况 2:
- 推导:既然目标这一维是 1,它的子集在这一位既可以是
,也可以是 。 - 当子集第
位是 1 时:这部分子集的和就是 (继承之前的结果)。 - 当子集第
位是 0 时:我们需要把 的第 位强行变成 0(即 mask ^ (1 << i)),去拿那部分状态之前计算好的结果,即。
- 当子集第
- 转移方程:
1 | vector<vector<ll>> dp(n+1,vector<ll> (1 << n,0)); |
也许还是有些难理解,没事,可以稍微举个例子理解一下
假设n = 3,mask = 101,它的子集有000,001,100,101
初始化:dp[0][101] = a[101]
处理第0位时,101的第0位是1:dp[1][101] = dp[0][101] + dp[0][100]
所以现在包含a[101]+a[100]
处理第1位时,101的第1位是0:dp[2][101] = dp[1][101]
不变
处理第2位时,101的第2位是1:dp[3][101] = dp[2][101] + dp[2][001]
而dp[2][001]已经包含a[001]+a[000],所以最终dp[3][101] = a[101]+a[100]+a[001]+a[000]
它的空间复杂度是$O(n2^n)
4. 优化
仔细观察上面的转移方程,我们会发现一个极其优美的性质:
和溢位
我们可以像01背包那样,把第一维完全砍掉,直接在一个一维数组上进行原地滚动更新,于是我们得到了终极模板
1 | int max_mask = 1 << n; |
5. 例题
为了检验我们的学习成果,我们来看一道非常震撼的真题
例1:E. XOR Again?
Ques
题目描述
给定一个长度为 N 的整数数组:
1 | A1, A2, ..., AN |
对于每一个 M,其中 1 ≤ M ≤ N,解决下面的问题:
将数组划分成 恰好 M 个连续的非空块。
每个块的代价为该块内所有元素的 按位异或和。
一次划分的总代价为所有块代价的 按位或。
请你求出对于每一个 M,划分成恰好 M 个连续块时,可能得到的最小总代价。
输入格式
第一行包含一个整数 N。
第二行包含 N 个用空格分隔的整数:
1 | A1 A2 ... AN |
输出格式
输出一行,包含 N 个整数。
第 i 个整数表示当 M = i 时的答案。
数据范围
1 | 1 ≤ N ≤ 10^6 |
样例
1 | 6 |
1 | 10 10 11 11 11 15 |
1 | 4 |
1 | 0 0 1 1 |
Ans
原题要求:
把数组切成恰好
M段,每段的代价是这一段的 XOR,整个划分的代价是所有段代价的 OR。
对每个M = 1...N,求最小代价。
如果直接想怎么切,会很难,因为既有分段,又有 XOR,又有 OR。
所以第一步可以先观察数据范围:
1 | N ≤ 10^6 |
N 很大,说明不能做普通的分段 DP。但是 Ai ≤ 10^6,也就是说每个数最多只有大约 20 个二进制位。因此这题大概率不是按位置做复杂 DP,而是要从二进制状态入手。
而这道题的总代价是:
1 | 每一段 XOR 的按位 OR |
按位 OR 有一个特点:
只要某一段 XOR 中出现了某一位
1,那么最终答案这一位就一定是1。
所以如果我们假设最终答案的二进制位只能出现在某个 mask 里面,那么每一段的 XOR 都不能出现 mask 之外的位。
也就是说,我们可以把问题转化成:
枚举一个
mask,判断能不能切成M段,使得每一段的 XOR 都是mask的子集。也就是x的所有1位,mask中也都有。
定义前缀异或:pre[i] = A1 ^ A2 ^ … ^ Ai
那么一段区间 [l + 1, r] 的 XOR 是:pre[r] ^ pre[l]
假设我们把数组切成几段:[1 … c1] [c1+1 … c2] [c2+1 … c3] … [ck+1 … N]
这些切分点是:c1, c2, c3, …
每一段的 XOR 分别是:
1 | pre[c1] ^ pre[0] |
因为:pre[0] = 0
第一段 XOR 就是 pre[c1]
如果第一段 XOR 必须是 mask 的子集,那么 pre[c1] 也必须是 mask 的子集。
再看前两段整体的 XOR pre[c2]
它等于第一段 XOR 再异或第二段 XOR。
如果每一段 XOR 都只包含 mask 中的位,那么它们异或起来也只会包含 mask 中的位。
所以 pre[c2] 也必须是 mask 的子集。
继续往后推,所有被选中的切分点 c 都必须满足 pre[c] 是 mask 的子集
也就是 (pre[c] & mask) == pre[c]
同时最后的 pre[N] 也必须是 mask 的子集。
当然,反过来也成立
如果两个切分点 x 和 y 满足:
1 | pre[x] 是 mask 的子集 |
那么区间 (x, y] 的 XOR 是 pre[y] ^ pre[x]
因为 pre[x] 和 pre[y] 都没有 mask 之外的位,所以它们异或之后也不会出现 mask 之外的位。
也就是说:
只要选中的切分点的前缀异或都是
mask的子集,那么切出来的每一段 XOR 也一定是mask的子集。
所以对于一个固定的 mask,问题就变成:
有多少个位置
i可以作为切分点,使得pre[i]是mask的子集?
注意切分点只能在数组内部,所以只看 i = 1, 2, …, N - 1
当然,如果要切成 M 段,需要选:M - 1个内部切分点。
对于一个固定的 mask,假设有 cnt[mask] 个内部位置 i 满足 pre[i] 是 mask 的子集
那么只要:cnt[mask] >= M - 1 就可以切成 M 段。
同时还要保证整个数组的异或 pre[N] 也是 mask 的子集。否则最后终点不合法,无论怎么切都不行。
所以固定 mask 的可行条件是:
1 | pre[N] 是 mask 的子集 |
现在我们再看一遍题干
原题是:
1 | 对每个 M,求最小划分代价 |
现在变成:
1 | 对每个 M,找最小的 mask, |
其中:
1 | cnt[mask] = 内部前缀异或 pre[i] 中,有多少个是 mask 的子集 |
这一步就是整道题的核心转化。
cnt[mask] 怎么快速求?
我们先统计每种前缀异或出现了多少次。
设 freq[x] = 内部位置中,pre[i] == x 的数量
那么 cnt[mask] = sum(freq[x]),其中 x 是 mask 的子集
也就是:cnt[mask] = 所有子集 x ⊆ mask 的 freq[x] 之和
这正好是 SOS DP 的经典形式。
初始化:cnt[mask] = freq[mask]
然后做 SOS DP:
1 | for (int bit = 0; bit < B; bit++) { |
做完之后:cnt[mask] = 所有 x ⊆ mask 的 freq[x] 之和,也就是这个 mask 能提供多少个合法切分点。
对于每个 mask,我们已经知道它能提供 cnt[mask] 个合法切分点。
如果它能提供 c 个切分点,那么它可以支持:1 段,2 段,…,c + 1 段,因为切成 M 段只需要 M - 1 个切分点。
所以我们令:best[k] = 至少能提供 k 个切分点时的最小 mask
那么答案就是:ans[M] = best[M - 1]
具体做法是:
先枚举所有 mask:
1 | c = cnt[mask] |
这里先表示“恰好有 c 个合法切分点”的最小 mask。
但是我们真正需要的是“至少有 k 个合法切分点”,所以再做一遍后缀最小值:
1 | for (int k = N - 2; k >= 0; k--) { |
这样 best[k] 就表示:cnt[mask] >= k 的所有 mask 中,最小的 mask
最后输出:best[0], best[1], best[2], …, best[N - 1]
分别对应:M = 1, 2, 3, …, N
而后,我们稍微总结一下
这题因为 N 很大,不能直接分段 DP;但 Ai ≤ 10^6,说明数值只有 20 位左右,可以从二进制状态入手。
由于最终代价是所有段 XOR 的 OR,所以如果我们枚举一个候选 mask,就可以把问题变成:能不能切出若干段,使得每一段的 XOR 都只包含 mask 中的位。
用前缀异或表示区间 XOR 后,可以发现:一段 XOR 合法等价于相邻切分点的前缀异或都在 mask 内。因此,对于固定的 mask,所有合法切分点就是那些满足 pre[i] ⊆ mask 的内部位置。
如果这样的内部位置有 cnt[mask] 个,那么这个 mask 最多可以切成 cnt[mask] + 1 段。要切成 M 段,只需要满足 cnt[mask] >= M - 1。同时整个数组的异或 pre[N] 也必须是 mask 的子集。
于是问题变成:对每个 mask,求有多少个内部前缀异或是它的子集。这个就是标准 SOS DP:
1 | cnt[mask] = sum freq[sub], sub ⊆ mask |
最后枚举所有合法 mask,根据它能提供的切分点数量更新答案,再做后缀最小值,就能得到所有 M 的答案。
1 | const int B = 20; |
相关文章
[[4-1-Linear-DP]]

