算法竞赛 Python 常见算法模板整理
本文对应原《常见算法模板整理.md》的算法范围,将常用模板改写为 Python 竞赛版本。模板默认使用 0 下标;树状数组、部分区间结构会注明 1 下标。代码以可直接粘贴和快速修改为目标。
零、通用准备
import sys
from collections import deque, defaultdict, Counter
from heapq import heappush, heappop, heapify
from bisect import bisect_left, bisect_right
from functools import lru_cache, cache
from math import gcd, lcm, isqrt, inf
# Fast line reader. Each returned line keeps its trailing "\n":
# int()/float()/split() ignore it, but raw strings need .strip().
input = sys.stdin.readline
# Allow deep recursive DFS templates (very deep recursion can still exhaust the C stack).
sys.setrecursionlimit(10 ** 6)
INF = 10 ** 30  # integer "infinity": exact arithmetic, far above realistic path sums
MOD = 1_000_000_007
图常用建法:
# Read an undirected weighted graph: first line "n m", then m lines "u v w"
# with 1-indexed vertices; g[u] becomes a list of (neighbor, weight) pairs.
n, m = map(int, input().split())
g = [[] for _ in range(n)]
for _ in range(m):
    u, v, w = map(int, input().split())
    u, v = u - 1, v - 1  # switch to 0-indexed vertices
    g[u].append((v, w))
    g[v].append((u, w))
一、基础算法与技巧
1. 整数二分
找第一个满足条件的位置,闭区间 [lo, hi]。要求 check 单调(先 False 后 True);若区间内没有满足条件的位置,返回 hi,使用时需自行复核或把 hi 设为哨兵:
def first_true(lo, hi, check):
    """Return the smallest x in [lo, hi] for which check(x) is true.

    check must be monotone on [lo, hi] (False ... False True ... True).
    If no x satisfies it, hi is returned, so pass a sentinel hi or re-check.
    """
    while hi > lo:
        mid = (lo + hi) >> 1  # floor midpoint, also for negative bounds
        if not check(mid):
            lo = mid + 1
        else:
            hi = mid
    return lo
找最后一个满足条件的位置,闭区间 [lo, hi]。要求 check 单调(先 True 后 False);若没有满足条件的位置,返回 lo:
def last_true(lo, hi, check):
    """Return the largest x in [lo, hi] for which check(x) is true.

    check must be monotone on [lo, hi] (True ... True False ... False).
    If no x satisfies it, lo is returned.
    """
    while hi > lo:
        mid = (lo + hi + 1) >> 1  # round up so that lo = mid always makes progress
        if not check(mid):
            hi = mid - 1
        else:
            lo = mid
    return lo
有序数组二分优先用标准库:
i = bisect_left(a, x)   # index of the first element >= x
j = bisect_right(a, x)  # index of the first element > x
cnt_x = j - i           # how many elements equal x
exists = i != len(a) and a[i] == x
2. 浮点数二分
适合单调函数求根、求最小可行半径等精度题。
def cube_root(n):
    """Return the real cube root of n by bisection on [-10000, 10000].

    Only valid for |n| <= 1e12 (the cube of the bracket). The 100 fixed halvings
    shrink the bracket far below float resolution.
    """
    left, right = -10000.0, 10000.0
    rounds = 100
    while rounds:
        rounds -= 1
        middle = (left + right) / 2
        if middle * middle * middle >= n:
            right = middle
        else:
            left = middle
    return left
print(f"{cube_root(float(input())):.6f}")
3. 三分查找
适合单峰/单谷函数极值。下面是单谷函数求最小值。
def ternary_search(lo, hi, f):
    """Minimize a unimodal (single-valley) function f on the real interval [lo, hi].

    Runs 120 rounds, each discarding a third of the interval, and returns
    (x, f(x)) at the midpoint of the final interval. For a maximum, flip the comparison.
    """
    for _ in range(120):
        third = (hi - lo) / 3
        m1, m2 = lo + third, hi - third
        lo, hi = (lo, m2) if f(m1) < f(m2) else (m1, hi)
    x = (lo + hi) / 2
    return x, f(x)
4. 双指针
有序数组两数之和:
def two_sum_sorted(a, target):
    """Find indices (i, j), i < j, with a[i] + a[j] == target in an ascending list a.

    Two pointers move inward from both ends; returns (-1, -1) if no pair exists.
    """
    left, right = 0, len(a) - 1
    while left < right:
        total = a[left] + a[right]
        if total == target:
            return left, right
        if total > target:
            right -= 1
        else:
            left += 1
    return -1, -1
最短满足条件的连续子数组,例:和至少为 target,数组元素非负。
def min_len_sum_at_least(a, target):
    """Return the length of the shortest non-empty contiguous subarray with sum >= target.

    Elements of a must be non-negative, which keeps the sliding window monotone.
    Returns 0 when no such subarray exists (including for an empty a).
    """
    n = len(a)
    ans = n + 1  # sentinel: longer than any real window
    cur = 0
    l = 0
    for r, x in enumerate(a):
        cur += x
        # `l <= r` keeps the window non-empty; without it a target <= 0 would keep
        # shrinking past r and index beyond the end of a.
        while l <= r and cur >= target:
            ans = min(ans, r - l + 1)
            cur -= a[l]
            l += 1
    return 0 if ans > n else ans
最长无重复子串:
def longest_unique_substring(s):
    """Length of the longest substring of s with no repeated element (sliding window)."""
    seen_at = {}  # element -> most recent index
    start = best = 0
    for end, ch in enumerate(s):
        prev = seen_at.get(ch, -1)
        if prev >= start:  # ch repeats inside the current window
            start = prev + 1
        seen_at[ch] = end
        best = max(best, end - start + 1)
    return best
5. 单调队列:滑动窗口最值
每个长度为 k 的窗口最大值:
def sliding_window_max(a, k):
    """Return the maximum of every length-k window of a, in window order."""
    dq = deque()  # indices whose values strictly decrease from front to back
    out = []
    for i, x in enumerate(a):
        while dq and x >= a[dq[-1]]:
            dq.pop()
        dq.append(i)
        if dq[0] + k <= i:  # front index has left the window [i - k + 1, i]
            dq.popleft()
        if i + 1 >= k:
            out.append(a[dq[0]])
    return out
最小值只需把 <= 改成 >=。
6. 前缀和
一维前缀和:
def prefix_sum(a):
    """Return pre with pre[0] = 0 and pre[i + 1] = a[0] + ... + a[i] (length len(a) + 1)."""
    running = 0
    pre = [running]
    for x in a:
        running = running + x
        pre.append(running)
    return pre
pre = prefix_sum(a)  # pre[i] = a[0] + ... + a[i - 1]
sum_l_r = pre[r + 1] - pre[l]  # sum over the 0-indexed closed range [l, r]
二维前缀和:
def prefix_sum_2d(mat):
    """Return pre of size (n + 1) x (m + 1) with pre[i][j] = sum of mat[0..i-1][0..j-1].

    An empty matrix yields [[0]] instead of crashing on mat[0].
    """
    n = len(mat)
    m = len(mat[0]) if n else 0
    pre = [[0] * (m + 1) for _ in range(n + 1)]
    for i, row in enumerate(mat):
        above, cur = pre[i], pre[i + 1]
        row_sum = 0
        for j in range(m):
            row_sum += row[j]
            cur[j + 1] = above[j + 1] + row_sum  # block above + current row prefix
    return pre
def rect_sum(pre, x1, y1, x2, y2):
    """Sum of the 0-indexed closed rectangle rows x1..x2, columns y1..y2.

    pre is the (n + 1) x (m + 1) table from prefix_sum_2d (inclusion-exclusion).
    """
    bottom, top = pre[x2 + 1], pre[x1]
    return bottom[y2 + 1] - top[y2 + 1] - bottom[y1] + top[y1]
7. 差分
一维区间加:
def range_add(n, ops):
    """Start from n zeros and apply each op (l, r, v): add v to a[l..r] (0-indexed, inclusive).

    Uses a difference array, O(n + len(ops)). A right end past the array simply
    extends the update to the last element.
    """
    diff = [0] * (n + 1)
    for l, r, v in ops:
        diff[l] += v
        end = r + 1
        if end < n:
            diff[end] -= v
    out = []
    running = 0
    for i in range(n):
        running += diff[i]
        out.append(running)
    return out
二维矩形加:
def range_add_2d(n, m, ops):
    """Start from an n x m zero grid and apply each op (x1, y1, x2, y2, v):
    add v to every cell of the inclusive rectangle (x1..x2, y1..y2), 0-indexed.

    Marks the four corners in a 2-D difference array, then takes an in-place 2-D
    prefix sum to recover the cell values.
    """
    diff = [[0] * (m + 2) for _ in range(n + 2)]
    for x1, y1, x2, y2, v in ops:
        diff[x1][y1] += v
        diff[x1][y2 + 1] -= v
        diff[x2 + 1][y1] -= v
        diff[x2 + 1][y2 + 1] += v
    for i in range(n):
        row = diff[i]
        up = diff[i - 1] if i else None
        for j in range(m):
            s = row[j]
            if i:
                s += up[j]
            if j:
                s += row[j - 1]
            if i and j:
                s -= up[j - 1]
            row[j] = s
    return [row[:m] for row in diff[:n]]
8. 离散化
vals = sorted(set(a))  # distinct values, ascending


def rank(x):
    """0-indexed rank of x among vals (index of the first value >= x)."""
    return bisect_left(vals, x)


compressed = list(map(rank, a))
如果后续要用树状数组,常用 1 下标:
idx = bisect_left(vals, x) + 1
9. 位运算子集枚举
枚举 n 个元素的所有子集:
# Enumerate all 2**n subsets of {0, ..., n-1}; bit i of mask set <=> element i chosen.
for mask in range(1 << n):
    subset = []
    for i in range(n):
        if mask & (1 << i):
            subset.append(i)
枚举一个集合 mask 的所有非空子集:
# Visit every non-empty submask of mask in decreasing order.
sub = mask
while sub != 0:
    # ... use sub here ...
    sub = mask & (sub - 1)
包含空集:
# Visit every submask of mask, including the empty set (0) last.
sub = mask
while True:
    # ... use sub here ...
    if not sub:
        break
    sub = mask & (sub - 1)
10. 快速选择:第 K 小
k 为 0 下标,平均 O(n)。需要稳定最坏复杂度时用排序。
def kth_smallest(a, k):
    """Return the k-th smallest element of a (k is 0-indexed) in expected O(n).

    Quickselect with a random pivot and three-way partitioning, so duplicates are cheap
    and no fixed input order can force quadratic behaviour. a is not modified and may
    be any sequence. Raises IndexError if k is outside [0, len(a)).
    """
    from random import randint

    if not 0 <= k < len(a):
        # An out-of-range k would otherwise shrink [l, r] forever without finding k.
        raise IndexError("kth_smallest: k out of range")
    a = list(a)
    l, r = 0, len(a) - 1
    while True:
        pivot = a[randint(l, r)]
        lt, i, gt = l, l, r
        # Invariant: a[l:lt] < pivot, a[lt:i] == pivot, a[gt + 1:r + 1] > pivot.
        while i <= gt:
            if a[i] < pivot:
                a[lt], a[i] = a[i], a[lt]
                lt += 1
                i += 1
            elif a[i] > pivot:
                a[i], a[gt] = a[gt], a[i]
                gt -= 1
            else:
                i += 1
        if k < lt:
            r = lt - 1
        elif k > gt:
            l = gt + 1
        else:
            return a[k]
二、高级数据结构
11. 单调栈
右边第一个更大元素:
def next_greater(a):
    """For each i, return the first value to the right of a[i] that is strictly greater, or -1."""
    res = [-1] * len(a)
    pending = []  # indices still waiting for an answer; their values are non-increasing
    for i in range(len(a)):
        x = a[i]
        while pending and a[pending[-1]] < x:
            res[pending.pop()] = x
        pending.append(i)
    return res
如果要求距离而不是值,把 ans[st.pop()] = x 改为 j = st.pop(); ans[j] = i - j(没有更大元素的位置保持初始值)。
12. 并查集 DSU
class DSU:
    """Disjoint-set union with union by size and path halving.

    Attributes:
        parent: parent pointer of each element; roots point to themselves.
        size: component size, meaningful only at roots.
        count: current number of components.
    """

    def __init__(self, n):
        self.parent = list(range(n))
        self.size = [1] * n
        self.count = n

    def find(self, x):
        """Return the root of x, pointing each visited node at its grandparent."""
        parent = self.parent
        while x != parent[x]:
            grand = parent[parent[x]]
            parent[x] = grand
            x = grand
        return x

    def union(self, a, b):
        """Merge the sets containing a and b; return True if they were different sets."""
        ra = self.find(a)
        rb = self.find(b)
        if ra == rb:
            return False
        big, small = (rb, ra) if self.size[ra] < self.size[rb] else (ra, rb)
        self.parent[small] = big
        self.size[big] += self.size[small]
        self.count -= 1
        return True

    def same(self, a, b):
        """True if a and b are in the same set."""
        return self.find(a) == self.find(b)
13. 树状数组 Fenwick Tree
1 下标,支持单点加、前缀和、区间和。
class Fenwick:
    """Binary indexed tree over positions 1..n: point add, prefix sum, range sum."""

    def __init__(self, n):
        self.n = n
        self.tree = [0] * (n + 1)  # tree[0] is unused

    def add(self, i, delta):
        """a[i] += delta for 1 <= i <= n."""
        tree, n = self.tree, self.n
        while i <= n:
            tree[i] += delta
            i += i & (-i)

    def sum(self, i):
        """Return a[1] + ... + a[i] (0 for i <= 0)."""
        total = 0
        tree = self.tree
        while i > 0:
            total += tree[i]
            i &= i - 1  # clear the lowest set bit
        return total

    def range_sum(self, l, r):
        """Return a[l] + ... + a[r] (1-indexed, inclusive)."""
        return self.sum(r) - self.sum(l - 1)
树状数组上二分,找最小 idx 使前缀和 >= k,要求所有频次非负;若全部元素之和小于 k,返回 n + 1:
def kth(bit, k):
    """Return the smallest 1-indexed idx whose Fenwick prefix sum is >= k.

    bit needs `.n` and `.tree` as in Fenwick; all stored frequencies must be non-negative.
    Returns bit.n + 1 when the total is below k (including an empty tree).
    """
    idx = 0
    # Highest power of two <= n; max(..., 0) avoids a negative shift when n == 0.
    step = 1 << max(bit.n.bit_length() - 1, 0)
    while step:
        nxt = idx + step
        if nxt <= bit.n and bit.tree[nxt] < k:
            # Whole block (idx, nxt] is still below k: take it and keep descending.
            idx = nxt
            k -= bit.tree[nxt]
        step >>= 1
    return idx + 1
14. 线段树:区间加、区间求和
0 下标接口,闭区间 [l, r]。
class SegmentTree:
    """Lazy segment tree: range add and range sum over a 0-indexed array, closed [l, r].

    tree[p] is the sum of node p's segment (pending adds included); lazy[p] is an add
    not yet pushed to p's children. Node 1 is the root covering [0, n - 1].
    """

    def __init__(self, a):
        self.n = len(a)
        self.tree = [0] * (4 * self.n)
        self.lazy = [0] * (4 * self.n)
        if self.n:  # building over the empty range [0, -1] would recurse forever
            self._build(1, 0, self.n - 1, a)

    def _build(self, p, l, r, a):
        """Fill node p covering [l, r] from a."""
        if l == r:
            self.tree[p] = a[l]
            return
        m = (l + r) // 2
        self._build(p * 2, l, m, a)
        self._build(p * 2 + 1, m + 1, r, a)
        self.tree[p] = self.tree[p * 2] + self.tree[p * 2 + 1]

    def _apply(self, p, l, r, v):
        """Add v to every element of node p's segment [l, r]."""
        self.tree[p] += v * (r - l + 1)
        self.lazy[p] += v

    def _push(self, p, l, r):
        """Hand node p's pending add down to its two children."""
        if self.lazy[p] and l != r:
            m = (l + r) // 2
            v = self.lazy[p]
            self._apply(p * 2, l, m, v)
            self._apply(p * 2 + 1, m + 1, r, v)
            self.lazy[p] = 0

    def add(self, ql, qr, v):
        """Add v to a[ql..qr]; an empty range (ql > qr) or empty tree is a no-op."""
        if ql > qr or not self.n:
            return
        self._add(1, 0, self.n - 1, ql, qr, v)

    def _add(self, p, l, r, ql, qr, v):
        if ql <= l and r <= qr:
            self._apply(p, l, r, v)
            return
        self._push(p, l, r)
        m = (l + r) // 2
        if ql <= m:
            self._add(p * 2, l, m, ql, qr, v)
        if qr > m:
            self._add(p * 2 + 1, m + 1, r, ql, qr, v)
        self.tree[p] = self.tree[p * 2] + self.tree[p * 2 + 1]

    def query(self, ql, qr):
        """Return a[ql] + ... + a[qr]; an empty range or empty tree sums to 0."""
        if ql > qr or not self.n:
            return 0
        return self._query(1, 0, self.n - 1, ql, qr)

    def _query(self, p, l, r, ql, qr):
        if ql <= l and r <= qr:
            return self.tree[p]
        self._push(p, l, r)
        m = (l + r) // 2
        res = 0
        if ql <= m:
            res += self._query(p * 2, l, m, ql, qr)
        if qr > m:
            res += self._query(p * 2 + 1, m + 1, r, ql, qr)
        return res
15. ST 表 Sparse Table
适合静态区间 min/max/gcd。函数必须满足结合律和幂等性(f(x, x) = x),因为查询时两段区间会重叠;sum 不满足幂等性,不能直接用。
class SparseTable:
    """Static range query for an associative, idempotent func (min, max, gcd, ...).

    st[k][i] = func over a[i .. i + 2**k - 1]. Build is O(n log n) and a query is O(1).
    """

    def __init__(self, a, func=max):
        self.func = func
        self.n = len(a)
        self.st = [a[:]]
        for k in range(1, self.n.bit_length()):
            prev = self.st[-1]
            half = 1 << (k - 1)
            width = self.n - (1 << k) + 1
            self.st.append([func(prev[i], prev[i + half]) for i in range(width)])

    def query(self, l, r):
        """func over a[l..r] (0-indexed, inclusive) via two overlapping 2**k blocks."""
        k = (r - l + 1).bit_length() - 1
        level = self.st[k]
        return self.func(level[l], level[r - (1 << k) + 1])
16. 字典树 Trie
字符串前缀树:
class Trie:
    """Prefix tree over sequences of hashable symbols (usually characters), as nested dicts.

    Each node maps a symbol to its child node. The END key stores how many inserted
    words end at that node.
    """

    # Word-end marker. A private object() can never equal a character; the string "#"
    # collided with words containing '#' (crashing insert/count and letting
    # starts_with report prefixes that were never inserted).
    END = object()

    def __init__(self):
        self.root = {}

    def insert(self, word):
        """Add one occurrence of word."""
        node = self.root
        for ch in word:
            node = node.setdefault(ch, {})
        node[self.END] = node.get(self.END, 0) + 1

    def count(self, word):
        """Return how many times word was inserted (0 if never)."""
        node = self.root
        for ch in word:
            node = node.get(ch)
            if node is None:
                return 0
        return node.get(self.END, 0)

    def starts_with(self, prefix):
        """True if some inserted word starts with prefix (always True for an empty prefix)."""
        node = self.root
        for ch in prefix:
            node = node.get(ch)
            if node is None:
                return False
        return True
01-Trie,求最大异或:
class BinaryTrie:
    """Bitwise trie over non-negative ints (bits max_bit..0) for maximum-XOR queries."""

    def __init__(self, max_bit=30):
        self.ch = [[-1, -1]]  # ch[node][bit] -> child index or -1; node 0 is the root
        self.max_bit = max_bit

    def insert(self, x):
        """Insert x, creating missing nodes along its bit path."""
        node = 0
        for b in reversed(range(self.max_bit + 1)):
            bit = x >> b & 1
            child = self.ch[node][bit]
            if child == -1:
                child = len(self.ch)
                self.ch[node][bit] = child
                self.ch.append([-1, -1])
            node = child

    def query_max_xor(self, x):
        """Return max(x ^ y) over inserted y, greedily taking the opposite bit when possible.

        An empty trie yields 0.
        """
        node = 0
        best = 0
        for b in reversed(range(self.max_bit + 1)):
            bit = x >> b & 1
            opposite = self.ch[node][bit ^ 1]
            if opposite == -1:
                node = self.ch[node][bit]
            else:
                best |= 1 << b
                node = opposite
        return best
三、图论算法
17. Dijkstra 堆优化最短路
适合非负边权。
def dijkstra(n, g, s):
    """Single-source shortest paths with non-negative weights; g[u] = [(v, w), ...].

    Returns dist with INF for unreachable vertices.
    """
    dist = [INF] * n
    dist[s] = 0
    heap = [(0, s)]
    while heap:
        du, u = heappop(heap)
        if du > dist[u]:  # stale entry: u was already settled with a shorter distance
            continue
        for v, w in g[u]:
            cand = du + w
            if dist[v] > cand:
                dist[v] = cand
                heappush(heap, (cand, v))
    return dist
18. SPFA 与负环判定
能处理负边权,但最坏复杂度为 O(nm),容易被构造数据卡满;除非题目明确适合,否则优先考虑复杂度稳定的 Bellman-Ford,或用 Johnson 势能重赋权后跑 Dijkstra。
def spfa(n, g, s):
    """Queue-based Bellman-Ford (SPFA) from s; g[u] = [(v, w), ...], negative weights allowed.

    Returns dist (INF for unreachable), or None if a negative cycle reachable from s
    is detected: a shortest path using n or more edges must repeat a vertex.
    """
    dist = [INF] * n
    in_queue = [False] * n
    edges_used = [0] * n  # edges on the current best path to each vertex
    dist[s] = 0
    queue = deque([s])
    in_queue[s] = True
    while queue:
        u = queue.popleft()
        in_queue[u] = False
        for v, w in g[u]:
            cand = dist[u] + w
            if cand < dist[v]:
                dist[v] = cand
                edges_used[v] = edges_used[u] + 1
                if edges_used[v] >= n:
                    return None
                if not in_queue[v]:
                    in_queue[v] = True
                    queue.append(v)
    return dist
19. Floyd-Warshall 多源最短路
适合 n 较小的稠密图,复杂度 O(n^3)。
def floyd(dist):
    """All-pairs shortest paths (Floyd-Warshall), O(n^3), updating dist in place.

    dist is an n x n matrix with INF for "no edge" and 0 on the diagonal. Negative
    edges are allowed as long as there is no negative cycle. Returns dist.
    """
    n = len(dist)
    for k in range(n):
        dk = dist[k]
        for i in range(n):
            dik = dist[i][k]
            if dik == INF:
                continue
            di = dist[i]
            for j in range(n):
                dkj = dk[j]
                # Skip unreachable k -> j: with a negative dik, INF + dik < INF would
                # otherwise be stored as a bogus "finite" distance.
                if dkj == INF:
                    continue
                nd = dik + dkj
                if nd < di[j]:
                    di[j] = nd
    return dist
初始化(按有向边 u -> v 写入;无向图需同时更新 dist[v][u]):
dist = [[INF] * n for _ in range(n)]
for i in range(n):
    dist[i][i] = 0
for u, v, w in edges:
    # directed edge u -> v; keep the lightest of any parallel edges
    if w < dist[u][v]:
        dist[u][v] = w
20. Kruskal 最小生成树
def kruskal(n, edges):
    """Total weight of a minimum spanning tree, or None if the graph is disconnected.

    edges: [(w, u, v), ...] with 0-indexed vertices. Relies on the DSU class.
    """
    dsu = DSU(n)
    total = 0
    remaining = n - 1  # tree edges still needed
    for w, u, v in sorted(edges):
        if not dsu.union(u, v):
            continue
        total += w
        remaining -= 1
        if remaining == 0:
            break
    return total if remaining == 0 else None
21. 拓扑排序 Kahn
def topo_sort(n, g):
    """Kahn's algorithm; g[u] = [v, ...] for edges u -> v.

    Returns a topological order, or None if the graph has a cycle.
    """
    indeg = [0] * n
    for u in range(n):
        for v in g[u]:
            indeg[v] += 1
    ready = deque(u for u in range(n) if not indeg[u])
    order = []
    while ready:
        u = ready.popleft()
        order.append(u)
        for v in g[u]:
            indeg[v] -= 1
            if not indeg[v]:
                ready.append(v)
    return None if len(order) != n else order
22. Tarjan 强连通分量 SCC
有向图,返回每个强连通分量。
def tarjan_scc(g):
    """Strongly connected components of a directed graph g (adjacency lists).

    Returns a list of components (lists of vertices). Each component is emitted as soon
    as its root finishes, so sink components come first. Uses recursion, so deep graphs
    need a raised recursion limit.
    """
    n = len(g)
    index = [0] * n  # discovery time, 0 = unvisited
    low = [0] * n
    on_stack = [False] * n
    stack = []
    comps = []
    clock = 0

    def visit(u):
        nonlocal clock
        clock += 1
        index[u] = low[u] = clock
        stack.append(u)
        on_stack[u] = True
        for v in g[u]:
            if index[v] == 0:
                visit(v)
                if low[v] < low[u]:
                    low[u] = low[v]
            elif on_stack[v] and index[v] < low[u]:
                low[u] = index[v]
        if low[u] != index[u]:
            return
        comp = []
        while True:
            x = stack.pop()
            on_stack[x] = False
            comp.append(x)
            if x == u:
                break
        comps.append(comp)

    for u in range(n):
        if index[u] == 0:
            visit(u)
    return comps
23. Tarjan 桥与割点
无向图,邻接表需要带边编号:g[u].append((v, eid)),同一条边的两个方向使用同一个 eid。这样 DFS 只跳过来时的那一条边,重边也能被正确处理。
def bridges_cutpoints(n, g):
    """Bridges and articulation points of an undirected graph.

    g[u] = [(v, eid), ...] where both directions of one edge share eid; only the exact
    edge we arrived by is ignored, so parallel edges work. Returns (bridges as
    (parent, child) pairs in discovery order, cut vertices in increasing order).
    """
    tin = [0] * n  # discovery time, 0 = unvisited
    low = [0] * n
    is_cut = [False] * n
    bridges = []
    clock = 0

    def visit(u, in_edge):
        nonlocal clock
        clock += 1
        tin[u] = low[u] = clock
        children = 0
        for v, eid in g[u]:
            if tin[v]:
                if eid != in_edge and tin[v] < low[u]:
                    low[u] = tin[v]
                continue
            children += 1
            visit(v, eid)
            if low[v] < low[u]:
                low[u] = low[v]
            if low[v] > tin[u]:
                bridges.append((u, v))
            if in_edge != -1 and low[v] >= tin[u]:
                is_cut[u] = True
        if in_edge == -1 and children > 1:
            is_cut[u] = True

    for u in range(n):
        if not tin[u]:
            visit(u, -1)
    return bridges, [u for u in range(n) if is_cut[u]]
24. 二分图判定
def is_bipartite(g):
    """BFS 2-coloring of an undirected graph.

    Returns (ok, color) where color[v] is 0 or 1 (-1 if unvisited when a conflict stopped the search).
    """
    n = len(g)
    color = [-1] * n
    for s in range(n):
        if color[s] >= 0:
            continue
        color[s] = 0
        q = deque([s])
        while q:
            u = q.popleft()
            cu = color[u]
            for v in g[u]:
                if color[v] < 0:
                    color[v] = cu ^ 1
                    q.append(v)
                elif color[v] == cu:
                    return False, color
    return True, color
25. 匈牙利算法:二分图最大匹配
左部点 0..n_left-1,右部点 0..n_right-1,g[u] 存左点 u 能连到的右点。
def hungarian(g, n_left, n_right):
    """Maximum bipartite matching by augmenting paths (Kuhn's algorithm), O(V * E).

    g[u] lists the right vertices adjacent to left vertex u. Returns (size, match_r)
    where match_r[v] is the left partner of right vertex v, or -1.
    """
    match_r = [-1] * n_right

    def augment(u, seen):
        for v in g[u]:
            if seen[v]:
                continue
            seen[v] = True
            owner = match_r[v]
            if owner == -1 or augment(owner, seen):
                match_r[v] = u
                return True
        return False

    size = sum(augment(u, [False] * n_right) for u in range(n_left))
    return size, match_r
四、树上问题
26. 树的直径:两次 DFS/BFS
带权树,g[u] = [(v, w), ...]。
def tree_diameter(n, g):
    """Diameter of a weighted tree via two farthest-vertex searches (iterative).

    g[u] = [(v, w), ...] with non-negative weights. Returns (length, a, b) where a and b
    are the endpoints.
    """
    def farthest(src):
        dist = [-1] * n
        dist[src] = 0
        todo = [src]
        while todo:
            u = todo.pop()
            du = dist[u]
            for v, w in g[u]:
                if dist[v] == -1:
                    dist[v] = du + w
                    todo.append(v)
        best = max(range(n), key=lambda v: dist[v])  # first vertex at max distance
        return best, dist

    a = farthest(0)[0]
    b, dist = farthest(a)
    return dist[b], a, b
27. 树形 DP 求直径
def diameter_dp(n, g):
    """Tree diameter by DP: at each vertex join its two longest downward paths.

    g[u] = [(v, w), ...]; rooted at vertex 0. Uses recursion.
    """
    result = 0

    def down(u, parent):
        nonlocal result
        top1 = top2 = 0  # two longest downward path lengths from u
        for v, w in g[u]:
            if v == parent:
                continue
            d = w + down(v, u)
            if d > top1:
                top2 = top1
                top1 = d
            elif d > top2:
                top2 = d
        result = max(result, top1 + top2)
        return top1

    down(0, -1)
    return result
28. 树的重心
def tree_centroids(n, g):
    """Centroids of an unweighted tree (one or two vertices), in DFS finishing order.

    A centroid is a vertex whose removal leaves no component with more than n // 2
    vertices. The DFS is recursive and rooted at vertex 0.
    """
    size = [0] * n
    result = []

    def walk(u, parent):
        size[u] = 1
        heaviest = 0
        for v in g[u]:
            if v == parent:
                continue
            walk(v, u)
            size[u] += size[v]
            if size[v] > heaviest:
                heaviest = size[v]
        if max(heaviest, n - size[u]) <= n // 2:  # n - size[u]: the part above u
            result.append(u)

    walk(0, -1)
    return result
29. LCA 倍增
无权树邻接表 g[u] = [v1, v2, ...]。
class LCA:
    """Lowest common ancestor by binary lifting on an unweighted tree.

    g[u] = [v1, v2, ...] (undirected adjacency lists). up[k][v] is the 2**k-th ancestor
    of v; the root is its own ancestor, so jumps past the root stay at the root.
    Preprocessing uses an explicit stack (no recursion), O(n log n).
    """

    def __init__(self, g, root=0):
        self.n = len(g)
        self.LOG = self.n.bit_length()
        self.depth = [0] * self.n
        self.up = [[0] * self.n for _ in range(self.LOG)]
        if not self.n:  # empty tree: up has no rows, nothing to preprocess
            return
        self.up[0][root] = root
        stack = [(root, root)]  # (vertex, parent); the root acts as its own parent
        while stack:
            u, p = stack.pop()
            for v in g[u]:
                if v == p:
                    continue
                self.depth[v] = self.depth[u] + 1
                self.up[0][v] = u
                stack.append((v, u))
        for k in range(1, self.LOG):
            prev = self.up[k - 1]
            cur = self.up[k]
            for v in range(self.n):
                cur[v] = prev[prev[v]]

    def kth_ancestor(self, u, k):
        """Return the k-th ancestor of u (the root if k exceeds depth[u]); needs k < 2**LOG."""
        bit = 0
        while k:
            if k & 1:
                u = self.up[bit][u]
            k >>= 1
            bit += 1
        return u

    def lca(self, a, b):
        """Return the lowest common ancestor of a and b."""
        if self.depth[a] < self.depth[b]:
            a, b = b, a
        a = self.kth_ancestor(a, self.depth[a] - self.depth[b])
        if a == b:
            return a
        # Jump both up as long as they stay distinct; they end just below the LCA.
        for k in range(self.LOG - 1, -1, -1):
            if self.up[k][a] != self.up[k][b]:
                a = self.up[k][a]
                b = self.up[k][b]
        return self.up[0][a]

    def distance_edges(self, a, b):
        """Number of edges on the path between a and b."""
        c = self.lca(a, b)
        return self.depth[a] + self.depth[b] - 2 * self.depth[c]
30. 树链剖分 HLD
将树上路径拆成若干 DFS 序连续区间,用于配合线段树/树状数组。
class HLD:
    """Heavy-light decomposition of a tree given as unweighted adjacency lists g[u] = [v, ...].

    After construction:
        parent[u]  parent of u (-1 for the root)
        depth[u]   edge depth of u
        size[u]    subtree size
        heavy[u]   child with the largest subtree (-1 for a leaf; first maximal child on ties)
        head[u]    top vertex of u's heavy chain
        pos[u]     DFS index; every heavy chain and every subtree is a contiguous range
    Both passes use explicit stacks, so path-like trees cannot overflow the recursion limit.
    """

    def __init__(self, g, root=0):
        self.g = g
        self.n = len(g)
        self.parent = [-1] * self.n
        self.depth = [0] * self.n
        self.size = [0] * self.n
        self.heavy = [-1] * self.n
        self.head = [0] * self.n
        self.pos = [0] * self.n
        self.cur = 0
        if self.n:  # an empty tree has nothing to decompose
            self._dfs1(root, -1)
            self._dfs2(root, root)

    def _dfs1(self, u, p):
        """Set parent/depth top-down from u, then size/heavy bottom-up."""
        g, parent, depth = self.g, self.parent, self.depth
        size, heavy = self.size, self.heavy
        parent[u] = p
        order = []
        stack = [u]
        while stack:
            x = stack.pop()
            order.append(x)
            for v in g[x]:
                if v != parent[x]:
                    parent[v] = x
                    depth[v] = depth[x] + 1
                    stack.append(v)
        for x in reversed(order):  # every child appears after its parent in `order`
            size[x] = 1
            best = 0
            for v in g[x]:
                if v == parent[x]:
                    continue
                size[x] += size[v]
                if size[v] > best:
                    best = size[v]
                    heavy[x] = v

    def _dfs2(self, u, h):
        """Assign head and pos in heavy-child-first preorder, starting at u with chain head h."""
        g, parent, heavy = self.g, self.parent, self.heavy
        head, pos = self.head, self.pos
        stack = [(u, h)]
        while stack:
            x, hx = stack.pop()
            head[x] = hx
            pos[x] = self.cur
            self.cur += 1
            # Light children go on in reverse adjacency order and the heavy child last,
            # so the heavy child is visited next and light subtrees follow in order.
            for v in reversed(g[x]):
                if v != parent[x] and v != heavy[x]:
                    stack.append((v, v))
            if heavy[x] != -1:
                stack.append((heavy[x], hx))

    def path_segments(self, a, b):
        """Split the path a..b into inclusive pos-intervals (l, r), l <= r, O(log n) of them."""
        segs = []
        while self.head[a] != self.head[b]:
            if self.depth[self.head[a]] < self.depth[self.head[b]]:
                a, b = b, a
            segs.append((self.pos[self.head[a]], self.pos[a]))
            a = self.parent[self.head[a]]
        if self.depth[a] > self.depth[b]:
            a, b = b, a
        segs.append((self.pos[a], self.pos[b]))
        return segs

    def subtree_segment(self, u):
        """Inclusive pos-interval covering the subtree of u."""
        return self.pos[u], self.pos[u] + self.size[u] - 1
五、动态规划核心模型
31. 01 背包
每个物品最多选一次。
def knapsack_01(items, W):
    """0/1 knapsack: items = [(weight, value), ...], each used at most once.

    Returns the best total value with total weight <= W.
    """
    best = [0] * (W + 1)
    for w, v in items:
        # Descending capacities so best[cap - w] still excludes the current item.
        for cap in reversed(range(w, W + 1)):
            cand = best[cap - w] + v
            if cand > best[cap]:
                best[cap] = cand
    return best[W]
32. 完全背包
每个物品可选无限次。
def knapsack_complete(items, W):
    """Unbounded knapsack: items = [(weight, value), ...], each usable any number of times."""
    best = [0] * (W + 1)
    for w, v in items:
        # Ascending capacities: best[cap - w] may already contain this item.
        for cap in range(w, W + 1):
            if best[cap - w] + v > best[cap]:
                best[cap] = best[cap - w] + v
    return best[W]
33. 多重背包:二进制拆分
def knapsack_multiple(items, W):
    """Bounded knapsack: items = [(w, v, cnt), ...], item usable at most cnt times.

    Splits each cnt into pieces 1, 2, 4, ..., remainder (every count 0..cnt is a sum of
    distinct pieces), then solves a 0/1 knapsack over the pieces with knapsack_01.
    """
    pieces = []
    for w, v, cnt in items:
        chunk = 1
        left = cnt
        while chunk <= left:
            pieces.append((w * chunk, v * chunk))
            left -= chunk
            chunk *= 2
        if left:
            pieces.append((w * left, v * left))
    return knapsack_01(pieces, W)
34. 最长递增子序列 LIS
严格递增:
def lis_length(a):
    """Length of the longest strictly increasing subsequence, O(n log n).

    tails[i] is the smallest possible tail of an increasing subsequence of length i + 1.
    """
    tails = []
    for x in a:
        pos = bisect_left(tails, x)
        if pos < len(tails):
            tails[pos] = x
        else:
            tails.append(x)
    return len(tails)
非递减把 bisect_left 改成 bisect_right。
35. 最长公共子序列 LCS
def lcs_length(a, b):
    """Length of the longest common subsequence of a and b; O(|a||b|) time, O(|b|) memory."""
    m = len(b)
    prev = [0] * (m + 1)
    for ai in a:
        cur = [0]
        for j, bj in enumerate(b, 1):
            if ai == bj:
                cur.append(prev[j - 1] + 1)
            else:
                cur.append(max(prev[j], cur[j - 1]))
        prev = cur
    return prev[m]
36. 区间 DP:石子合并
def stone_merge(a):
    """Minimum total cost to merge adjacent piles into one; a merge costs the combined size.

    Interval DP: dp[l][r] = sum(a[l..r]) + min over k of dp[l][k] + dp[k + 1][r], O(n^3).
    Returns 0 for an empty or single-pile input.
    """
    n = len(a)
    if n == 0:  # dp[0][n - 1] would index an empty table
        return 0
    pre = [0]
    for x in a:
        pre.append(pre[-1] + x)
    dp = [[0] * n for _ in range(n)]
    for length in range(2, n + 1):
        for l in range(n - length + 1):
            r = l + length - 1
            total = pre[r + 1] - pre[l]
            dp[l][r] = total + min(dp[l][k] + dp[k + 1][r] for k in range(l, r))
    return dp[0][n - 1]
37. 状态压缩 DP:TSP
从 0 出发访问所有点,最后不强制回到 0。
def tsp(dist):
    """Cheapest path starting at vertex 0 that visits every vertex (no return to 0).

    Bitmask DP, O(2^n * n^2): dp[mask][u] = min cost starting at 0, having visited
    exactly mask, and currently at u.
    """
    n = len(dist)
    full = 1 << n
    dp = [[INF] * n for _ in range(full)]
    dp[1][0] = 0
    for mask in range(full):
        row = dp[mask]
        for u in range(n):
            cost = row[u]
            if cost == INF:
                continue
            du = dist[u]
            for v in range(n):
                if mask >> v & 1:
                    continue
                nxt = dp[mask | 1 << v]
                cand = cost + du[v]
                if cand < nxt[v]:
                    nxt[v] = cand
    return min(dp[full - 1])
38. 数位 DP
例:统计 [1, n] 中数位和能被 mod 整除的正整数个数(0 不计入)。
def count_digit_sum_mod_zero(n, mod):
    """Count integers x in [1, n] whose decimal digit sum is divisible by mod (digit DP)."""
    if n <= 0:
        return 0
    digits = list(map(int, str(n)))
    length = len(digits)

    @cache
    def go(pos, tight, started, rem):
        # pos: next digit index; tight: prefix so far equals n's prefix;
        # started: a non-zero digit has been placed (excludes the number 0)
        if pos == length:
            return 1 if started and rem == 0 else 0
        limit = digits[pos] if tight else 9
        total = 0
        for d in range(limit + 1):
            still_tight = tight and d == limit
            if started or d:
                total += go(pos + 1, still_tight, True, (rem + d) % mod)
            else:
                total += go(pos + 1, still_tight, False, 0)
        return total

    return go(0, True, False, 0)
六、数学与数论
39. 快速幂
Python 优先用内置:
pow(a, b, mod)
手写模板:
def qpow(a, b, mod):
    """Return a**b % mod by binary exponentiation; matches pow(a, b, mod) for b >= 0.

    Raises ValueError for b < 0: halving a negative b never reaches 0, so the loop
    would not terminate.
    """
    if b < 0:
        raise ValueError("qpow: exponent must be non-negative")
    res = 1 % mod  # 1 % 1 == 0, so qpow(a, 0, 1) == 0 like pow(a, 0, 1)
    a %= mod
    while b:
        if b & 1:
            res = res * a % mod
        a = a * a % mod
        b >>= 1
    return res
40. 矩阵快速幂
def mat_mul(A, B, mod):
    """Return the matrix product A @ B with every entry reduced modulo mod.

    A is n x m and B is m x p (lists of lists of ints). Each entry is accumulated exactly
    with Python ints and reduced once, instead of once per multiply-add.
    """
    n, m, p = len(A), len(B), len(B[0])
    C = []
    for i in range(n):
        Ai = A[i]
        acc = [0] * p
        for k in range(m):
            aik = Ai[k]
            if aik:  # zero entries contribute nothing
                Bk = B[k]
                for j in range(p):
                    acc[j] += aik * Bk[j]
        C.append([x % mod for x in acc])
    return C


def mat_pow(A, e, mod):
    """Return A**e modulo mod for a square matrix A and an integer e >= 0.

    e == 0 yields the identity reduced mod `mod` (all zeros when mod == 1, like pow()).
    Raises ValueError for e < 0, where the halving loop would never terminate.
    """
    if e < 0:
        raise ValueError("mat_pow: exponent must be non-negative")
    n = len(A)
    res = [[(1 if i == j else 0) % mod for j in range(n)] for i in range(n)]
    base = A
    while e:
        if e & 1:
            res = mat_mul(res, base, mod)
        base = mat_mul(base, base, mod)
        e >>= 1
    return res
41. 素数筛
埃氏筛:
def sieve(n):
    """Sieve of Eratosthenes up to n.

    Returns (primes, is_prime) where is_prime has length n + 1 (empty for n < 0).
    Multiples are cleared with one slice assignment per prime, which runs at C speed.
    """
    is_prime = [True] * (n + 1)
    if n < 0:  # nothing to sieve; isqrt would raise on a negative argument
        return [], is_prime
    is_prime[0] = False
    if n >= 1:
        is_prime[1] = False
    for i in range(2, isqrt(n) + 1):
        if is_prime[i]:
            start = i * i
            is_prime[start::i] = [False] * ((n - start) // i + 1)
    primes = [i for i, ok in enumerate(is_prime) if ok]
    return primes, is_prime
欧拉线性筛:
def linear_sieve(n):
    """Euler (linear) sieve: each composite is crossed out once, by its smallest prime factor.

    Returns (primes <= n, is_comp) where is_comp[x] is True for composite x (0 and 1 stay False).
    """
    primes = []
    is_comp = [False] * (n + 1)
    for i in range(2, n + 1):
        if not is_comp[i]:
            primes.append(i)
        limit = n // i  # p * i <= n  <=>  p <= n // i
        for p in primes:
            if p > limit:
                break
            is_comp[i * p] = True
            if i % p == 0:
                break
    return primes, is_comp
42. 扩展欧几里得
def exgcd(a, b):
    """Extended Euclid: return (g, x, y) with a*x + b*y == g, g = gcd(a, b).

    Iterative form of the usual recursion; it produces the same coefficients.
    """
    old_r, r = a, b
    old_x, x = 1, 0
    old_y, y = 0, 1
    while r:
        q = old_r // r
        old_r, r = r, old_r - q * r
        old_x, x = x, old_x - q * x
        old_y, y = y, old_y - q * y
    return old_r, old_x, old_y
43. 模逆元与模除
模数为质数:
# MOD is prime, so by Fermat's little theorem a^(MOD-2) is the inverse of a (needs a % MOD != 0).
inv_a = pow(a, MOD - 2, MOD)
ans = b * inv_a % MOD  # b / a modulo MOD
一般互质情况:
def mod_inv(a, mod):
    """Inverse of a modulo mod via extended Euclid, or None when gcd(a, mod) != 1."""
    g, x, _ = exgcd(a, mod)
    return x % mod if g == 1 else None
Python 3.8+ 可直接:
inv_a = pow(a, -1, mod)
44. 组合数取模预处理
适合固定质数模数,多次查询 C(n, k)。
def build_comb(N, mod):
    """Precompute factorials and inverse factorials modulo a prime `mod` up to N.

    Returns (C, fac, ifac); C(n, k) is n choose k mod `mod` for 0 <= n <= N and is 0
    when k is outside [0, n].
    """
    fac = [1] * (N + 1)
    for i in range(1, N + 1):
        fac[i] = fac[i - 1] * i % mod
    ifac = [1] * (N + 1)
    ifac[N] = pow(fac[N], mod - 2, mod)  # Fermat inverse of N!
    for i in reversed(range(1, N + 1)):
        ifac[i - 1] = ifac[i] * i % mod  # 1/(i-1)! = i * 1/i!

    def C(n, k):
        if not 0 <= k <= n:
            return 0
        return fac[n] * ifac[k] % mod * ifac[n - k] % mod

    return C, fac, ifac
七、字符串高级算法
45. 字符串哈希
双模哈希,查询子串 [l, r)。
class StringHash:
    """Double polynomial rolling hash of s; get(l, r) hashes the half-open slice s[l:r]."""

    MOD1 = 1_000_000_007
    MOD2 = 1_000_000_009
    BASE = 911382323

    def __init__(self, s):
        n = len(s)
        base, m1, m2 = self.BASE, self.MOD1, self.MOD2
        p1 = [1] * (n + 1)
        p2 = [1] * (n + 1)
        h1 = [0] * (n + 1)
        h2 = [0] * (n + 1)
        for i, code in enumerate(map(ord, s)):
            p1[i + 1] = p1[i] * base % m1
            p2[i + 1] = p2[i] * base % m2
            h1[i + 1] = (h1[i] * base + code) % m1
            h2[i + 1] = (h2[i] * base + code) % m2
        self.p1, self.p2, self.h1, self.h2 = p1, p2, h1, h2

    def get(self, l, r):
        """Hash pair of s[l:r]: prefix hash of r minus the prefix of l shifted by r - l."""
        span = r - l
        return (
            (self.h1[r] - self.h1[l] * self.p1[span]) % self.MOD1,
            (self.h2[r] - self.h2[l] * self.p2[span]) % self.MOD2,
        )
46. KMP 字符串匹配
def prefix_function(p):
    """pi[i] = length of the longest proper prefix of p[:i + 1] that is also its suffix."""
    pi = [0] * len(p)
    k = 0  # equals pi[i - 1] at the start of each iteration
    for i in range(1, len(p)):
        while k and p[i] != p[k]:
            k = pi[k - 1]
        if p[i] == p[k]:
            k += 1
        pi[i] = k
    return pi


def kmp_search(text, pattern):
    """Start indices of every (possibly overlapping) occurrence of pattern in text.

    An empty pattern matches at every position 0..len(text).
    """
    m = len(pattern)
    if m == 0:
        return list(range(len(text) + 1))
    pi = prefix_function(pattern)
    hits = []
    k = 0  # length of the pattern prefix currently matched
    for i, ch in enumerate(text):
        while k and ch != pattern[k]:
            k = pi[k - 1]
        if ch == pattern[k]:
            k += 1
        if k == m:
            hits.append(i - m + 1)
            k = pi[k - 1]
    return hits
47. Manacher 最长回文子串
def manacher(s):
    """Longest palindromic substring in O(n). Returns (substring, length); ties keep the leftmost.

    Works on t = "^#c1#c2#...#$" so odd and even palindromes look alike; radius[i] in t
    equals the length of the corresponding palindrome in s.
    """
    t = "^#" + "#".join(s) + "#$"
    radius = [0] * len(t)
    c = r = 0  # center and right edge of the rightmost palindrome seen so far
    best_len = best_center = 0
    for i in range(1, len(t) - 1):
        if i < r:
            radius[i] = min(r - i, radius[2 * c - i])
        while t[i - radius[i] - 1] == t[i + radius[i] + 1]:
            radius[i] += 1
        if i + radius[i] > r:
            c, r = i, i + radius[i]
        if radius[i] > best_len:
            best_len, best_center = radius[i], i
    start = (best_center - best_len) // 2
    return s[start:start + best_len], best_len
48. AC 自动机
多模式串匹配。
class AhoCorasick:
    """Aho-Corasick automaton for multi-pattern matching.

    next[u]: dict symbol -> child; fail[u]: failure link; out[u]: ids of patterns ending
    at u, including those inherited through failure links. Node 0 is the root.
    """

    def __init__(self):
        self.next = [{}]
        self.fail = [0]
        self.out = [[]]

    def add(self, word, idx):
        """Insert pattern `word` labelled idx; call build() after the last add."""
        node = 0
        for ch in word:
            child = self.next[node].get(ch)
            if child is None:
                child = len(self.next)
                self.next[node][ch] = child
                self.next.append({})
                self.fail.append(0)
                self.out.append([])
            node = child
        self.out[node].append(idx)

    def build(self):
        """Compute failure links breadth-first and merge output lists along them."""
        queue = deque(self.next[0].values())  # depth-1 nodes keep fail = root
        while queue:
            u = queue.popleft()
            for ch, v in self.next[u].items():
                f = self.fail[u]
                while f != 0 and ch not in self.next[f]:
                    f = self.fail[f]
                self.fail[v] = self.next[f].get(ch, 0)
                self.out[v] += self.out[self.fail[v]]
                queue.append(v)

    def search(self, text):
        """Return (end_index, pattern_id) for every match, ordered by end index."""
        matches = []
        node = 0
        for i, ch in enumerate(text):
            while node and ch not in self.next[node]:
                node = self.fail[node]
            node = self.next[node].get(ch, 0)
            matches.extend((i, idx) for idx in self.out[node])
        return matches
八、计算几何
49. 二维点/向量
点用 (x, y) 元组。
def sub(a, b):
    """Vector a - b for points given as (x, y)."""
    return (a[0] - b[0], a[1] - b[1])


def dot(a, b):
    """Dot product of a and b."""
    return a[0] * b[0] + a[1] * b[1]


def cross(a, b):
    """z-component of a x b: positive when b is counter-clockwise from a."""
    return a[0] * b[1] - a[1] * b[0]


def cross3(o, a, b):
    """cross(a - o, b - o): > 0 left turn o->a->b, < 0 right turn, 0 collinear."""
    return cross(sub(a, o), sub(b, o))


def norm2(a):
    """Squared length of a."""
    return dot(a, a)
50. 线段相交
整数坐标可直接精确判断。
def on_segment(a, b, p):
    """True if p lies on the closed segment ab (collinear and inside its bounding box)."""
    if cross3(a, b, p) != 0:
        return False
    return (min(a[0], b[0]) <= p[0] <= max(a[0], b[0])
            and min(a[1], b[1]) <= p[1] <= max(a[1], b[1]))


def segments_intersect(a, b, c, d):
    """True if closed segments ab and cd share a point (exact for integer coordinates)."""
    s1 = cross3(a, b, c)
    s2 = cross3(a, b, d)
    s3 = cross3(c, d, a)
    s4 = cross3(c, d, b)
    # Touching or collinear overlap: some endpoint lies on the other segment.
    if ((s1 == 0 and on_segment(a, b, c)) or (s2 == 0 and on_segment(a, b, d))
            or (s3 == 0 and on_segment(c, d, a)) or (s4 == 0 and on_segment(c, d, b))):
        return True
    # Proper crossing: each segment strictly straddles the other's line.
    return s1 * s2 < 0 and s3 * s4 < 0
51. 凸包 Andrew
返回逆时针凸包,不保留边上的共线点。
def convex_hull(points):
    """Andrew's monotone chain convex hull.

    Returns the hull counter-clockwise, starting at the lexicographically smallest point.
    Duplicates are removed and collinear points on edges are dropped.
    """
    pts = sorted(set(points))
    if len(pts) <= 1:
        return pts

    def chain(seq):
        hull = []
        for p in seq:
            while len(hull) >= 2 and cross3(hull[-2], hull[-1], p) <= 0:
                hull.pop()
            hull.append(p)
        return hull

    lower = chain(pts)
    upper = chain(reversed(pts))
    return lower[:-1] + upper[:-1]  # each chain's last point starts the other chain
若要保留边上共线点,把 <= 0 改成 < 0。注意:这样改后若所有点共线,下链和上链会重复输出中间的点,需要特判(例如所有点共线时直接返回排序去重后的 pts)。
九、搜索与回溯
52. 网格 BFS
DIR4 = [(1, 0), (-1, 0), (0, 1), (0, -1)]  # 4-neighbour (row, col) steps


def grid_bfs(grid, start):
    """Step distance from start to every cell; "#" cells are walls, -1 means unreachable.

    The start cell itself is not checked for being a wall.
    """
    rows, cols = len(grid), len(grid[0])
    dist = [[-1] * cols for _ in range(rows)]
    sx, sy = start
    dist[sx][sy] = 0
    frontier = deque([(sx, sy)])
    while frontier:
        x, y = frontier.popleft()
        step = dist[x][y] + 1
        for dx, dy in DIR4:
            nx, ny = x + dx, y + dy
            if not (0 <= nx < rows and 0 <= ny < cols):
                continue
            if grid[nx][ny] == "#" or dist[nx][ny] != -1:
                continue
            dist[nx][ny] = step
            frontier.append((nx, ny))
    return dist
53. 多源 BFS
def multi_source_bfs(grid, sources):
    """Distance from each cell to its nearest source; "#" cells are walls, -1 = unreachable.

    All sources start in the queue at distance 0, so one BFS serves them all.
    """
    n, m = len(grid), len(grid[0])
    dist = [[-1] * m for _ in range(n)]
    q = deque()
    for sx, sy in sources:
        dist[sx][sy] = 0
        q.append((sx, sy))
    while q:
        x, y = q.popleft()
        nd = dist[x][y] + 1
        for dx, dy in DIR4:
            nx, ny = x + dx, y + dy
            if 0 <= nx < n and 0 <= ny < m and dist[nx][ny] == -1 and grid[nx][ny] != "#":
                dist[nx][ny] = nd
                q.append((nx, ny))
    return dist
54. 0-1 BFS
边权只能是 0 或 1。
def zero_one_bfs(n, g, s):
    """Shortest paths when every weight is 0 or 1; g[u] = [(v, w), ...].

    0-edges push to the deque front and 1-edges to the back. A vertex may be popped
    more than once; later pops find no improvement.
    """
    dist = [INF] * n
    dist[s] = 0
    dq = deque([s])
    while dq:
        u = dq.popleft()
        for v, w in g[u]:
            cand = dist[u] + w
            if cand >= dist[v]:
                continue
            dist[v] = cand
            if w == 0:
                dq.appendleft(v)
            else:
                dq.append(v)
    return dist
55. 双向 BFS
适合无权图、状态空间巨大且可以从终点反向扩展。模板两端共用同一个 neighbors,因此要求转移是对称的(无向);有向转移需要为反向搜索单独提供前驱函数。
def bidirectional_bfs(start, target, neighbors):
    """Fewest moves from start to target, or -1 if unreachable.

    neighbors(state) must be symmetric (undirected) because the backward search uses it
    too. Each round expands the smaller frontier by one full level.
    """
    if start == target:
        return 0
    frontier, other = {start}, {target}
    seen, seen_other = {start: 0}, {target: 0}
    while frontier and other:
        if len(frontier) > len(other):
            frontier, other = other, frontier
            seen, seen_other = seen_other, seen
        layer = set()
        for u in frontier:
            du = seen[u] + 1
            for v in neighbors(u):
                if v in seen:
                    continue
                if v in seen_other:
                    return du + seen_other[v]
                seen[v] = du
                layer.add(v)
        frontier = layer
    return -1
56. 状态空间 BFS
状态必须可哈希,如整数、字符串、元组。
def state_bfs(start, is_goal, next_states):
    """Fewest transitions from start to a state satisfying is_goal, or -1 if none is reachable.

    States must be hashable.
    """
    dist = {start: 0}
    queue = deque([start])
    while queue:
        cur = queue.popleft()
        d = dist[cur]
        if is_goal(cur):
            return d
        for nxt in next_states(cur):
            if nxt in dist:
                continue
            dist[nxt] = d + 1
            queue.append(nxt)
    return -1
57. DFS Flood Fill
递归版:
def flood_fill(grid, sx, sy, old, new):
    """Recursively repaint the 4-connected region of value old containing (sx, sy) with new.

    Modifies grid in place. Recursion depth can reach the region size.
    """
    n, m = len(grid), len(grid[0])
    if old == new or grid[sx][sy] != old:
        return  # nothing to repaint; old == new would also recurse forever

    def paint(x, y):
        grid[x][y] = new
        for dx, dy in DIR4:
            nx, ny = x + dx, y + dy
            if not (0 <= nx < n and 0 <= ny < m):
                continue
            if grid[nx][ny] == old:
                paint(nx, ny)

    paint(sx, sy)
显式栈版:
def flood_fill_iter(grid, sx, sy, old, new):
    """Explicit-stack flood fill: repaint the 4-connected region of old around (sx, sy).

    Cells are repainted when pushed, so each cell enters the stack at most once.
    """
    n, m = len(grid), len(grid[0])
    if old == new or grid[sx][sy] != old:
        return
    grid[sx][sy] = new
    todo = [(sx, sy)]
    while todo:
        x, y = todo.pop()
        for dx, dy in DIR4:
            nx, ny = x + dx, y + dy
            if not (0 <= nx < n and 0 <= ny < m):
                continue
            if grid[nx][ny] != old:
                continue
            grid[nx][ny] = new
            todo.append((nx, ny))
58. 图 DFS 显式栈
def iterative_dfs(g, start):
    """Vertices reachable from start, in true DFS preorder (same as recursive DFS).

    A vertex is marked when popped, not when pushed; marking on push gives a different,
    non-DFS order. Neighbours are pushed in reverse so they are explored in adjacency
    order. The stack may hold a vertex several times (O(edges) entries); repeats are skipped.
    """
    vis = [False] * len(g)
    order = []
    st = [start]
    while st:
        u = st.pop()
        if vis[u]:
            continue
        vis[u] = True
        order.append(u)
        for v in reversed(g[u]):
            if not vis[v]:
                st.append(v)
    return order
59. 回溯
全排列模板:
def permute(a):
    """All permutations of a as lists, generated by backtracking in index order."""
    n = len(a)
    taken = [False] * n
    cur = []
    out = []

    def build():
        if len(cur) == n:
            out.append(list(cur))
            return
        for i, x in enumerate(a):
            if taken[i]:
                continue
            taken[i] = True
            cur.append(x)
            build()
            cur.pop()
            taken[i] = False

    build()
    return out
有重复元素时跳过同层重复:
def permute_unique(a):
    """All distinct permutations of a (which may contain duplicates), in sorted order.

    The caller's sequence is not modified; any iterable of comparable items is accepted.
    """
    a = sorted(a)  # sort a copy instead of calling a.sort() on the caller's list
    n = len(a)
    used = [False] * n
    path = []
    ans = []

    def dfs():
        if len(path) == n:
            ans.append(path[:])
            return
        for i in range(n):
            if used[i]:
                continue
            # Among equal values, only the leftmost unused copy may start a branch at
            # this depth; otherwise the same permutation would be produced again.
            if i and a[i] == a[i - 1] and not used[i - 1]:
                continue
            used[i] = True
            path.append(a[i])
            dfs()
            path.pop()
            used[i] = False

    dfs()
    return ans
60. 记忆化搜索
@cache
def dfs(state):
    """Best total gain from state onward (template: supply the four problem hooks).

    Terminal states return base_value(state); otherwise the maximum over transitions of
    gain(state, nxt) + dfs(nxt), or -INF if a non-terminal state has no transitions.
    """
    if is_terminal(state):
        return base_value(state)
    best = -INF
    for nxt in transitions(state):
        cand = gain(state, nxt) + dfs(nxt)
        if cand > best:
            best = cand
    return best
常见要求:
- `state` 必须可哈希,列表要先转成 `tuple`。
- 多组测试时,记得调用 `dfs.cache_clear()` 清空缓存。
- 递归深度大时改成递推 DP 或显式栈。
61. 剪枝常用写法
回溯中常见剪枝位置:
def dfs(i, cur):
    """Backtracking skeleton showing where each kind of pruning goes.

    feasible, upper_bound, best_answer, encode_state and seen are problem-specific
    placeholders that the surrounding solution must define.
    """
    # 1. Feasibility pruning: the remaining resources can no longer meet the constraints.
    if not feasible(i, cur):
        return
    # 2. Optimality pruning: even the optimistic bound cannot beat the best answer so far.
    if upper_bound(i, cur) <= best_answer():
        return
    # 3. Memo pruning: this state was already reached with a value >= cur
    #    (assumes a larger cur is better; flip the comparison otherwise).
    key = encode_state(i, cur)
    if key in seen and seen[key] >= cur:
        return
    seen[key] = cur
    # Normal search continues here.
实战顺序通常是:先排序制造单调性,再写可行性剪枝,最后加记忆化或上界剪枝。