yanchang
Published 2026-06-05

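Common Python Algorithm Templates for Competitive Programming

This article covers the same algorithms as the original 常见算法模板整理.md and rewrites the common templates as Python versions for competitive programming. Templates use 0-based indexing by default; the Fenwick tree and some interval structures are noted as 1-based. The code is meant to be pasted directly and modified quickly.

Part 0: General Setup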

import sys
from collections import deque, defaultdict, Counter
from heapq import heappush, heappop, heapify
from bisect import bisect_left, bisect_right
from functools import lru_cache, cache
from math import gcd, lcm, isqrt, inf

input = sys.stdin.readline
sys.setrecursionlimit(1_000_000)

INF = 10**30
MOD = 10**9 + 7

Common way to build a graph (undirected, weighted, 1-based input converted to 0-based):

n, m = map(int, input().split())
g = [[] for _ in range(n)]
for _ in range(m):
    u, v, w = map(int, input().split())
    u -= 1
    v -= 1
    g[u].append((v, w))
    g[v].append((u, w))

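Part 1: Basic Algorithms and Techniques

1. Integer Binary Search

Find the first position where check is true on the closed interval [lo, hi]. check must be monotone (False ... False, True ... True). If no position satisfies it, the function returns hi, so verify the result with check when that can happen.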

def first_true(lo, hi, check):
    while lo < hi:
        mid = (lo + hi) // 2
        if check(mid):
            hi = mid
        else:
            lo = mid + 1
    return lo

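Find the last position where check is true (True ... True, False ... False). If no position satisfies it, the original lo is returned: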

def last_true(lo, hi, check):
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if check(mid):
            lo = mid
        else:
            hi = mid - 1
    return lo

For binary search on a sorted array, prefer the standard library:

i = bisect_left(a, x)     # first index with a[i] >= x
j = bisect_right(a, x)    # first index with a[j] > x
exists = i < len(a) and a[i] == x
cnt_x = j - i
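
As a small illustration of the two calls above, the following sketch counts how many elements of a sorted list fall inside a closed value range. The helper name count_in_range is hypothetical and not part of the original templates.

```python
from bisect import bisect_left, bisect_right

def count_in_range(a, lo, hi):
    """Count elements x of the sorted list a with lo <= x <= hi."""
    left = bisect_left(a, lo)     # first index with a[i] >= lo
    right = bisect_right(a, hi)   # first index with a[i] > hi
    return max(0, right - left)   # max() guards against lo > hi

print(count_in_range([1, 2, 2, 3, 5], 2, 3))  # 3
```

Both lookups are O(log n), so this stays cheap even with many queries on the same array.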

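2. Floating-Point Binary Search

Suitable for precision problems such as finding the root of a monotone function or the minimum feasible radius.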

def cube_root(n):
    lo, hi = -10000.0, 10000.0
    for _ in range(100):
        mid = (lo + hi) / 2
        if mid * mid * mid >= n:
            hi = mid
        else:
            lo = mid
    return lo

print(f"{cube_root(float(input())):.6f}")

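3. Ternary Search

Suitable for finding the extremum of a unimodal function. The version below finds the minimum of a function with a single valley.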

def ternary_search(lo, hi, f):
    for _ in range(120):
        m1 = lo + (hi - lo) / 3
        m2 = hi - (hi - lo) / 3
        if f(m1) < f(m2):
            hi = m2
        else:
            lo = m1
    x = (lo + hi) / 2
    return x, f(x)

4. Two Pointers

Two-sum on a sorted array:

def two_sum_sorted(a, target):
    l, r = 0, len(a) - 1
    while l < r:
        s = a[l] + a[r]
        if s == target:
            return l, r
        if s < target:
            l += 1
        else:
            r -= 1
    return -1, -1

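Shortest contiguous subarray satisfying a condition. Example: sum at least target, with non-negative array elements. Returns 0 if no such subarray exists.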

def min_len_sum_at_least(a, target):
    ans = 10**18
    cur = 0
    l = 0
    for r, x in enumerate(a):
        cur += x
        while cur >= target:
            ans = min(ans, r - l + 1)
            cur -= a[l]
            l += 1
    return 0 if ans == 10**18 else ans

Longest substring without repeated characters:

def longest_unique_substring(s):
    last = {}
    l = 0
    ans = 0
    for r, ch in enumerate(s):
        if ch in last and last[ch] >= l:
            l = last[ch] + 1
        last[ch] = r
        ans = max(ans, r - l + 1)
    return ans

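5. Monotonic Queue: Sliding Window Extremum

Maximum of every window of length k:

def sliding_window_max(a, k):
    q = deque()  # stores indices; the corresponding values are monotonically decreasing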
    ans = []
    for i, x in enumerate(a):
        while q and a[q[-1]] <= x:
            q.pop()
        q.append(i)
        if q[0] <= i - k:
            q.popleft()
        if i >= k - 1:
            ans.append(a[q[0]])
    return ans

For the minimum, just change <= to >=.
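
A minimal sketch of the minimum variant described above, written out in full so the flipped comparison is easy to see. The function name sliding_window_min is an assumption; only the comparison in the inner loop differs from the maximum version.

```python
from collections import deque

def sliding_window_min(a, k):
    q = deque()  # indices; the corresponding values increase from front to back
    ans = []
    for i, x in enumerate(a):
        # Drop candidates that can never be the minimum again.
        while q and a[q[-1]] >= x:
            q.pop()
        q.append(i)
        # Drop the front index once it leaves the window [i - k + 1, i].
        if q[0] <= i - k:
            q.popleft()
        if i >= k - 1:
            ans.append(a[q[0]])
    return ans

print(sliding_window_min([1, 3, -1, -3, 5, 3, 6, 7], 3))  # [-1, -3, -3, -3, 3, 3]
```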

6. Prefix Sums

1D prefix sum:

def prefix_sum(a):
    pre = [0]
    for x in a:
        pre.append(pre[-1] + x)
    return pre

pre = prefix_sum(a)
sum_l_r = pre[r + 1] - pre[l]  # 0-based closed interval [l, r]

2D prefix sum:

def prefix_sum_2d(mat):
    n, m = len(mat), len(mat[0])
    pre = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n):
        row_sum = 0
        for j in range(m):
            row_sum += mat[i][j]
            pre[i + 1][j + 1] = pre[i][j + 1] + row_sum
    return pre

def rect_sum(pre, x1, y1, x2, y2):
    # Query the 0-based closed rectangle [(x1, y1), (x2, y2)]
    return pre[x2 + 1][y2 + 1] - pre[x1][y2 + 1] - pre[x2 + 1][y1] + pre[x1][y1]

7. Difference Arrays

1D range add (each op adds v to the 0-based closed interval [l, r]):

def range_add(n, ops):
    diff = [0] * (n + 1)
    for l, r, v in ops:
        diff[l] += v
        if r + 1 < n:
            diff[r + 1] -= v
    a = [0] * n
    cur = 0
    for i in range(n):
        cur += diff[i]
        a[i] = cur
    return a
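
The same idea can be written more compactly with itertools.accumulate, which performs the final prefix-sum pass. This is only a sketch under the assumption that every op satisfies 0 <= l <= r < n; the name range_add_accumulate is hypothetical.

```python
from itertools import accumulate

def range_add_accumulate(n, ops):
    diff = [0] * (n + 1)          # one extra slot so diff[r + 1] is always valid
    for l, r, v in ops:
        diff[l] += v              # the addition starts at l
        diff[r + 1] -= v          # and stops after r
    return list(accumulate(diff[:n]))  # prefix sums restore the array

print(range_add_accumulate(5, [(0, 2, 1), (1, 4, 2)]))  # [1, 3, 3, 2, 2]
```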

2D rectangle add:

def range_add_2d(n, m, ops):
    diff = [[0] * (m + 2) for _ in range(n + 2)]
    for x1, y1, x2, y2, v in ops:
        diff[x1][y1] += v
        diff[x2 + 1][y1] -= v
        diff[x1][y2 + 1] -= v
        diff[x2 + 1][y2 + 1] += v

    a = [[0] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            if i:
                diff[i][j] += diff[i - 1][j]
            if j:
                diff[i][j] += diff[i][j - 1]
            if i and j:
                diff[i][j] -= diff[i - 1][j - 1]
            a[i][j] = diff[i][j]
    return a

8. Coordinate Compression

vals = sorted(set(a))

def rank(x):
    return bisect_left(vals, x)  # 0-based rank

compressed = [rank(x) for x in a]

If a Fenwick tree is used afterwards, 1-based indices are common:

idx = bisect_left(vals, x) + 1
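
To make the mapping concrete, here is a small self-contained sketch that compresses a list and optionally shifts ranks to 1-based. The function name compress and its one_based flag are illustrative assumptions.

```python
from bisect import bisect_left

def compress(a, one_based=False):
    vals = sorted(set(a))                 # distinct values in increasing order
    offset = 1 if one_based else 0
    ranks = [bisect_left(vals, x) + offset for x in a]
    return ranks, vals                    # vals[rank - offset] recovers the value

ranks, vals = compress([100, 5, 100, 42])
print(ranks, vals)  # [2, 0, 2, 1] [5, 42, 100]
```

Equal values receive equal ranks, and the relative order of distinct values is preserved.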

9. Bitmask Subset Enumeration

Enumerate all subsets of n elements:

for mask in range(1 << n):
    subset = []
    for i in range(n):
        if (mask >> i) & 1:
            subset.append(i)

Enumerate all non-empty subsets of a set mask:

sub = mask
while sub:
    # use sub
    sub = (sub - 1) & mask

Including the empty set:

sub = mask
while True:
    # use sub
    if sub == 0:
        break
    sub = (sub - 1) & mask
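
The loop above can be wrapped in a generator to see the visiting order. This is a sketch with a hypothetical name, submasks; it yields submasks in decreasing numeric order and ends with the empty set.

```python
def submasks(mask):
    """Yield every submask of mask, from mask itself down to 0."""
    sub = mask
    while True:
        yield sub
        if sub == 0:
            break
        sub = (sub - 1) & mask  # next smaller submask

print([bin(s) for s in submasks(0b101)])  # ['0b101', '0b100', '0b1', '0b0']
```

A mask with b set bits has 2**b submasks, so summing over all masks of n bits costs O(3**n) in total.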

10. Quickselect: K-th Smallest

k is 0-based; average O(n). Use sorting when a guaranteed worst-case bound is needed.

def kth_smallest(a, k):
    a = a[:]
    l, r = 0, len(a) - 1
    while True:
        pivot = a[(l + r) // 2]
        lt, i, gt = l, l, r
        while i <= gt:
            if a[i] < pivot:
                a[lt], a[i] = a[i], a[lt]
                lt += 1
                i += 1
            elif a[i] > pivot:
                a[i], a[gt] = a[gt], a[i]
                gt -= 1
            else:
                i += 1
        if k < lt:
            r = lt - 1
        elif k > gt:
            l = gt + 1
        else:
            return a[k]

Part 2: Advanced Data Structures

11. Monotonic Stack

First greater element to the right:

def next_greater(a):
    n = len(a)
    ans = [-1] * n
    st = []  # indices still waiting for an answer; their values are non-increasing
    for i, x in enumerate(a):
        while st and x > a[st[-1]]:
            ans[st.pop()] = x
        st.append(i)
    return ans

For the distance instead, pop the index first (j = st.pop()) and assign ans[j] = i - j.
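
A minimal sketch of the distance variant just described. The function name next_greater_distance is an assumption; positions with no greater element to the right keep -1, matching the value-returning template.

```python
def next_greater_distance(a):
    ans = [-1] * len(a)
    st = []  # indices still waiting for a greater element
    for i, x in enumerate(a):
        while st and x > a[st[-1]]:
            j = st.pop()
            ans[j] = i - j  # distance from j to its next greater element
        st.append(i)
    return ans

print(next_greater_distance([73, 74, 75, 71, 69, 72, 76, 73]))
# [1, 1, 4, 2, 1, 1, -1, -1]
```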

12. Disjoint Set Union (DSU)

class DSU:
    def __init__(self, n):
        self.parent = list(range(n))
        self.size = [1] * n
        self.count = n

    def find(self, x):
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def union(self, a, b):
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return False
        if self.size[ra] < self.size[rb]:
            ra, rb = rb, ra
        self.parent[rb] = ra
        self.size[ra] += self.size[rb]
        self.count -= 1
        return True

    def same(self, a, b):
        return self.find(a) == self.find(b)

13. Fenwick Tree (Binary Indexed Tree)

1-based; supports point add, prefix sum, and range sum.

class Fenwick:
    def __init__(self, n):
        self.n = n
        self.tree = [0] * (n + 1)

    def add(self, i, delta):
        while i <= self.n:
            self.tree[i] += delta
            i += i & -i

    def sum(self, i):
        res = 0
        while i > 0:
            res += self.tree[i]
            i -= i & -i
        return res

    def range_sum(self, l, r):
        return self.sum(r) - self.sum(l - 1)

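Binary search on a Fenwick tree: find the smallest idx whose prefix sum is >= k. All frequencies must be non-negative, and k must not exceed the total sum (otherwise n + 1 is returned):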

def kth(bit, k):
    idx = 0
    step = 1 << (bit.n.bit_length() - 1)
    while step:
        nxt = idx + step
        if nxt <= bit.n and bit.tree[nxt] < k:
            idx = nxt
            k -= bit.tree[nxt]
        step >>= 1
    return idx + 1
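
A typical use of a Fenwick tree together with coordinate compression is counting inversions. The sketch below inlines the point-add and prefix-sum loops so that it is self-contained; the function name count_inversions is an assumption, and an inversion is taken to be a pair i < j with a[i] > a[j].

```python
def count_inversions(a):
    vals = sorted(set(a))
    rank = {v: i + 1 for i, v in enumerate(vals)}  # 1-based ranks
    n = len(vals)
    tree = [0] * (n + 1)
    inv = 0
    for seen, x in enumerate(a):
        # Prefix sum: how many earlier elements are <= x.
        j, le = rank[x], 0
        while j > 0:
            le += tree[j]
            j -= j & -j
        inv += seen - le  # the remaining earlier elements are > x
        # Point add: record x.
        i = rank[x]
        while i <= n:
            tree[i] += 1
            i += i & -i
    return inv

print(count_inversions([2, 4, 1, 3, 5]))  # 3
```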

14. Segment Tree: Range Add, Range Sum

0-based interface, closed interval [l, r].

class SegmentTree:
    def __init__(self, a):
        self.n = len(a)
        self.tree = [0] * (4 * self.n)
        self.lazy = [0] * (4 * self.n)
        self._build(1, 0, self.n - 1, a)

    def _build(self, p, l, r, a):
        if l == r:
            self.tree[p] = a[l]
            return
        m = (l + r) // 2
        self._build(p * 2, l, m, a)
        self._build(p * 2 + 1, m + 1, r, a)
        self.tree[p] = self.tree[p * 2] + self.tree[p * 2 + 1]

    def _apply(self, p, l, r, v):
        self.tree[p] += v * (r - l + 1)
        self.lazy[p] += v

    def _push(self, p, l, r):
        if self.lazy[p] and l != r:
            m = (l + r) // 2
            v = self.lazy[p]
            self._apply(p * 2, l, m, v)
            self._apply(p * 2 + 1, m + 1, r, v)
            self.lazy[p] = 0

    def add(self, ql, qr, v):
        self._add(1, 0, self.n - 1, ql, qr, v)

    def _add(self, p, l, r, ql, qr, v):
        if ql <= l and r <= qr:
            self._apply(p, l, r, v)
            return
        self._push(p, l, r)
        m = (l + r) // 2
        if ql <= m:
            self._add(p * 2, l, m, ql, qr, v)
        if qr > m:
            self._add(p * 2 + 1, m + 1, r, ql, qr, v)
        self.tree[p] = self.tree[p * 2] + self.tree[p * 2 + 1]

    def query(self, ql, qr):
        return self._query(1, 0, self.n - 1, ql, qr)

    def _query(self, p, l, r, ql, qr):
        if ql <= l and r <= qr:
            return self.tree[p]
        self._push(p, l, r)
        m = (l + r) // 2
        res = 0
        if ql <= m:
            res += self._query(p * 2, l, m, ql, qr)
        if qr > m:
            res += self._query(p * 2 + 1, m + 1, r, ql, qr)
        return res

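15. Sparse Table

Suitable for static range min/max/gcd queries. The function must be associative and idempotent (f(x, x) = x), because the two blocks combined in a query overlap.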

class SparseTable:
    def __init__(self, a, func=max):
        self.func = func
        self.n = len(a)
        self.st = [a[:]]
        k = 1
        while (1 << k) <= self.n:
            prev = self.st[k - 1]
            half = 1 << (k - 1)
            cur = [func(prev[i], prev[i + half]) for i in range(self.n - (1 << k) + 1)]
            self.st.append(cur)
            k += 1

    def query(self, l, r):
        k = (r - l + 1).bit_length() - 1
        return self.func(self.st[k][l], self.st[k][r - (1 << k) + 1])

16. Trie

String prefix tree:

class Trie:
    END = "#"

    def __init__(self):
        self.root = {}

    def insert(self, word):
        node = self.root
        for ch in word:
            node = node.setdefault(ch, {})
        node[self.END] = node.get(self.END, 0) + 1

    def count(self, word):
        node = self.root
        for ch in word:
            if ch not in node:
                return 0
            node = node[ch]
        return node.get(self.END, 0)

    def starts_with(self, prefix):
        node = self.root
        for ch in prefix:
            if ch not in node:
                return False
            node = node[ch]
        return True

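Binary (0/1) trie for maximum XOR. Insert at least one number before querying, and choose max_bit so that it covers the highest bit of every value: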

class BinaryTrie:
    def __init__(self, max_bit=30):
        self.ch = [[-1, -1]]
        self.max_bit = max_bit

    def insert(self, x):
        u = 0
        for b in range(self.max_bit, -1, -1):
            bit = (x >> b) & 1
            if self.ch[u][bit] == -1:
                self.ch[u][bit] = len(self.ch)
                self.ch.append([-1, -1])
            u = self.ch[u][bit]

    def query_max_xor(self, x):
        u = 0
        ans = 0
        for b in range(self.max_bit, -1, -1):
            bit = (x >> b) & 1
            want = bit ^ 1
            if self.ch[u][want] != -1:
                ans |= 1 << b
                u = self.ch[u][want]
            else:
                u = self.ch[u][bit]
        return ans

Part 3: Graph Algorithms

17. Dijkstra Shortest Paths with a Heap

Requires non-negative edge weights.

def dijkstra(n, g, s):
    dist = [INF] * n
    dist[s] = 0
    pq = [(0, s)]
    while pq:
        d, u = heappop(pq)
        if d != dist[u]:
            continue
        for v, w in g[u]:
            nd = d + w
            if nd < dist[v]:
                dist[v] = nd
                heappush(pq, (nd, v))
    return dist

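18. SPFA and Negative Cycle Detection

Handles negative edges, but the worst-case complexity is high (O(nm)). Unless the problem clearly suits it, prefer Bellman-Ford or a potential-based reweighting.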

def spfa(n, g, s):
    dist = [INF] * n
    inq = [False] * n
    cnt = [0] * n
    dist[s] = 0
    q = deque([s])
    inq[s] = True

    while q:
        u = q.popleft()
        inq[u] = False
        for v, w in g[u]:
            if dist[u] + w < dist[v]:
                dist[v] = dist[u] + w
                cnt[v] = cnt[u] + 1
                if cnt[v] >= n:
                    return None  # a negative cycle reachable from s exists
                if not inq[v]:
                    q.append(v)
                    inq[v] = True
    return dist
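
Since the note above recommends Bellman-Ford as the safer default, here is a minimal sketch of it. It assumes a directed edge list of (u, v, w) triples rather than the adjacency list used elsewhere, and the function name bellman_ford is not from the original templates.

```python
def bellman_ford(n, edges, s):
    INF = float("inf")
    dist = [INF] * n
    dist[s] = 0
    # At most n - 1 rounds are needed when no negative cycle is reachable.
    for _ in range(n - 1):
        changed = False
        for u, v, w in edges:
            if dist[u] != INF and dist[u] + w < dist[v]:
                dist[v] = dist[u] + w
                changed = True
        if not changed:
            return dist  # converged early
    # One more pass: any further improvement means a reachable negative cycle.
    for u, v, w in edges:
        if dist[u] != INF and dist[u] + w < dist[v]:
            return None
    return dist

print(bellman_ford(3, [(0, 1, 4), (0, 2, 5), (1, 2, -2)], 0))  # [0, 4, 2]
```

Like the SPFA template, it returns None when a negative cycle is reachable from s, so the two can be swapped without changing the calling code much.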

19. Floyd-Warshall All-Pairs Shortest Paths

Suitable for dense graphs with small n; complexity O(n^3).

def floyd(dist):
    n = len(dist)
    for k in range(n):
        dk = dist[k]
        for i in range(n):
            dik = dist[i][k]
            if dik == INF:
                continue
            di = dist[i]
            for j in range(n):
                nd = dik + dk[j]
                if nd < di[j]:
                    di[j] = nd
    return dist

Initialization (directed edges; also assign dist[v][u] for undirected graphs):

dist = [[INF] * n for _ in range(n)]
for i in range(n):
    dist[i][i] = 0
for u, v, w in edges:
    dist[u][v] = min(dist[u][v], w)

20. Kruskal Minimum Spanning Tree

def kruskal(n, edges):
    # edges: [(w, u, v), ...]
    dsu = DSU(n)
    total = 0
    chosen = 0
    for w, u, v in sorted(edges):
        if dsu.union(u, v):
            total += w
            chosen += 1
            if chosen == n - 1:
                break
    return total if chosen == n - 1 else None

21. Topological Sort (Kahn)

def topo_sort(n, g):
    indeg = [0] * n
    for u in range(n):
        for v in g[u]:
            indeg[v] += 1

    q = deque(i for i in range(n) if indeg[i] == 0)
    order = []
    while q:
        u = q.popleft()
        order.append(u)
        for v in g[u]:
            indeg[v] -= 1
            if indeg[v] == 0:
                q.append(v)

    return order if len(order) == n else None

22. Tarjan Strongly Connected Components (SCC)

Directed graph; returns every strongly connected component.

def tarjan_scc(g):
    n = len(g)
    dfn = [0] * n
    low = [0] * n
    in_st = [False] * n
    st = []
    timer = 0
    comps = []

    def dfs(u):
        nonlocal timer
        timer += 1
        dfn[u] = low[u] = timer
        st.append(u)
        in_st[u] = True

        for v in g[u]:
            if not dfn[v]:
                dfs(v)
                low[u] = min(low[u], low[v])
            elif in_st[v]:
                low[u] = min(low[u], dfn[v])

        if low[u] == dfn[u]:
            comp = []
            while True:
                x = st.pop()
                in_st[x] = False
                comp.append(x)
                if x == u:
                    break
            comps.append(comp)

    for i in range(n):
        if not dfn[i]:
            dfs(i)
    return comps

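23. Tarjan Bridges and Articulation Points

Undirected graph; the adjacency list must carry edge ids: g[u].append((v, eid)). Skipping by edge id rather than by parent vertex keeps parallel edges correct.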

def bridges_cutpoints(n, g):
    dfn = [0] * n
    low = [0] * n
    is_cut = [False] * n
    bridges = []
    timer = 0

    def dfs(u, peid):
        nonlocal timer
        timer += 1
        dfn[u] = low[u] = timer
        child = 0
        for v, eid in g[u]:
            if not dfn[v]:
                child += 1
                dfs(v, eid)
                low[u] = min(low[u], low[v])
                if low[v] > dfn[u]:
                    bridges.append((u, v))
                if peid != -1 and low[v] >= dfn[u]:
                    is_cut[u] = True
            elif eid != peid:
                low[u] = min(low[u], dfn[v])
        if peid == -1 and child > 1:
            is_cut[u] = True

    for i in range(n):
        if not dfn[i]:
            dfs(i, -1)
    cuts = [i for i, ok in enumerate(is_cut) if ok]
    return bridges, cuts

24. Bipartite Check

def is_bipartite(g):
    n = len(g)
    color = [-1] * n
    for s in range(n):
        if color[s] != -1:
            continue
        color[s] = 0
        q = deque([s])
        while q:
            u = q.popleft()
            for v in g[u]:
                if color[v] == -1:
                    color[v] = color[u] ^ 1
                    q.append(v)
                elif color[v] == color[u]:
                    return False, color
    return True, color

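25. Hungarian Algorithm: Maximum Bipartite Matching

Left vertices are 0..n_left-1 and right vertices are 0..n_right-1; g[u] stores the right vertices that left vertex u can connect to.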

def hungarian(g, n_left, n_right):
    match_r = [-1] * n_right

    def dfs(u, vis):
        for v in g[u]:
            if vis[v]:
                continue
            vis[v] = True
            if match_r[v] == -1 or dfs(match_r[v], vis):
                match_r[v] = u
                return True
        return False

    ans = 0
    for u in range(n_left):
        if dfs(u, [False] * n_right):
            ans += 1
    return ans, match_r

Part 4: Tree Problems

26. Tree Diameter: Two DFS/BFS Passes

Weighted tree with non-negative weights, g[u] = [(v, w), ...].

def tree_diameter(n, g):
    def farthest(src):
        dist = [-1] * n
        dist[src] = 0
        st = [src]
        while st:
            u = st.pop()
            for v, w in g[u]:
                if dist[v] == -1:
                    dist[v] = dist[u] + w
                    st.append(v)
        node = max(range(n), key=dist.__getitem__)
        return node, dist

    a, _ = farthest(0)
    b, dist = farthest(a)
    return dist[b], a, b

27. Tree Diameter via Tree DP

def diameter_dp(n, g):
    ans = 0

    def dfs(u, p):
        nonlocal ans
        best1 = best2 = 0
        for v, w in g[u]:
            if v == p:
                continue
            d = dfs(v, u) + w
            if d > best1:
                best1, best2 = d, best1
            elif d > best2:
                best2 = d
        ans = max(ans, best1 + best2)
        return best1

    dfs(0, -1)
    return ans

28. Tree Centroid

def tree_centroids(n, g):
    size = [0] * n
    centroids = []

    def dfs(u, p):
        size[u] = 1
        mx = 0
        for v in g[u]:
            if v == p:
                continue
            dfs(v, u)
            size[u] += size[v]
            mx = max(mx, size[v])
        mx = max(mx, n - size[u])
        if mx <= n // 2:
            centroids.append(u)

    dfs(0, -1)
    return centroids

29. LCA with Binary Lifting

Unweighted tree adjacency list g[u] = [v1, v2, ...].

class LCA:
    def __init__(self, g, root=0):
        self.n = len(g)
        self.LOG = self.n.bit_length()
        self.depth = [0] * self.n
        self.up = [[0] * self.n for _ in range(self.LOG)]

        st = [(root, root)]
        order = []
        self.up[0][root] = root
        while st:
            u, p = st.pop()
            order.append((u, p))
            for v in g[u]:
                if v == p:
                    continue
                self.depth[v] = self.depth[u] + 1
                self.up[0][v] = u
                st.append((v, u))

        for k in range(1, self.LOG):
            prev = self.up[k - 1]
            cur = self.up[k]
            for v in range(self.n):
                cur[v] = prev[prev[v]]

    def kth_ancestor(self, u, k):
        bit = 0
        while k:
            if k & 1:
                u = self.up[bit][u]
            k >>= 1
            bit += 1
        return u

    def lca(self, a, b):
        if self.depth[a] < self.depth[b]:
            a, b = b, a
        a = self.kth_ancestor(a, self.depth[a] - self.depth[b])
        if a == b:
            return a
        for k in range(self.LOG - 1, -1, -1):
            if self.up[k][a] != self.up[k][b]:
                a = self.up[k][a]
                b = self.up[k][b]
        return self.up[0][a]

    def distance_edges(self, a, b):
        c = self.lca(a, b)
        return self.depth[a] + self.depth[b] - 2 * self.depth[c]

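30. Heavy-Light Decomposition (HLD)

Splits a tree path into several contiguous intervals in DFS order, for use with a segment tree or Fenwick tree.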

class HLD:
    def __init__(self, g, root=0):
        self.g = g
        self.n = len(g)
        self.parent = [-1] * self.n
        self.depth = [0] * self.n
        self.size = [0] * self.n
        self.heavy = [-1] * self.n
        self.head = [0] * self.n
        self.pos = [0] * self.n
        self.cur = 0

        self._dfs1(root, -1)
        self._dfs2(root, root)

    def _dfs1(self, u, p):
        self.parent[u] = p
        self.size[u] = 1
        max_size = 0
        for v in self.g[u]:
            if v == p:
                continue
            self.depth[v] = self.depth[u] + 1
            self._dfs1(v, u)
            self.size[u] += self.size[v]
            if self.size[v] > max_size:
                max_size = self.size[v]
                self.heavy[u] = v

    def _dfs2(self, u, h):
        self.head[u] = h
        self.pos[u] = self.cur
        self.cur += 1
        if self.heavy[u] != -1:
            self._dfs2(self.heavy[u], h)
            for v in self.g[u]:
                if v != self.parent[u] and v != self.heavy[u]:
                    self._dfs2(v, v)

    def path_segments(self, a, b):
        segs = []
        while self.head[a] != self.head[b]:
            if self.depth[self.head[a]] < self.depth[self.head[b]]:
                a, b = b, a
            segs.append((self.pos[self.head[a]], self.pos[a]))
            a = self.parent[self.head[a]]
        if self.depth[a] > self.depth[b]:
            a, b = b, a
        segs.append((self.pos[a], self.pos[b]))
        return segs

    def subtree_segment(self, u):
        return self.pos[u], self.pos[u] + self.size[u] - 1

Part 5: Core Dynamic Programming Models

31. 0/1 Knapsack

Each item can be chosen at most once.

def knapsack_01(items, W):
    dp = [0] * (W + 1)
    for w, v in items:
        for c in range(W, w - 1, -1):
            dp[c] = max(dp[c], dp[c - w] + v)
    return dp[W]

32. Unbounded Knapsack

Each item can be chosen an unlimited number of times.

def knapsack_complete(items, W):
    dp = [0] * (W + 1)
    for w, v in items:
        for c in range(w, W + 1):
            dp[c] = max(dp[c], dp[c - w] + v)
    return dp[W]

33. Bounded Knapsack: Binary Splitting

def knapsack_multiple(items, W):
    # items: [(w, v, cnt), ...]
    bundles = []
    for w, v, cnt in items:
        k = 1
        while k <= cnt:
            bundles.append((w * k, v * k))
            cnt -= k
            k <<= 1
        if cnt:
            bundles.append((w * cnt, v * cnt))
    return knapsack_01(bundles, W)

34. Longest Increasing Subsequence (LIS)

Strictly increasing:

def lis_length(a):
    d = []
    for x in a:
        i = bisect_left(d, x)
        if i == len(d):
            d.append(x)
        else:
            d[i] = x
    return len(d)

For non-decreasing, change bisect_left to bisect_right.
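
Written out, the non-decreasing variant looks like the sketch below. The function name lis_non_decreasing is an assumption; with bisect_right an element equal to an existing tail extends the sequence instead of replacing that tail.

```python
from bisect import bisect_right

def lis_non_decreasing(a):
    d = []  # d[i] = smallest possible tail of a non-decreasing subsequence of length i + 1
    for x in a:
        i = bisect_right(d, x)  # first tail strictly greater than x
        if i == len(d):
            d.append(x)
        else:
            d[i] = x
    return len(d)

print(lis_non_decreasing([1, 2, 2, 3]))  # 4
```

On the same input the strict version returns 3, because the repeated 2 cannot be used twice.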

35. Longest Common Subsequence (LCS)

def lcs_length(a, b):
    n, m = len(a), len(b)
    prev = [0] * (m + 1)
    for i in range(1, n + 1):
        cur = [0] * (m + 1)
        ai = a[i - 1]
        for j in range(1, m + 1):
            if ai == b[j - 1]:
                cur[j] = prev[j - 1] + 1
            else:
                cur[j] = max(prev[j], cur[j - 1])
        prev = cur
    return prev[m]

36. Interval DP: Stone Merging

def stone_merge(a):
    n = len(a)
    pre = [0]
    for x in a:
        pre.append(pre[-1] + x)

    dp = [[0] * n for _ in range(n)]
    for length in range(2, n + 1):
        for l in range(n - length + 1):
            r = l + length - 1
            total = pre[r + 1] - pre[l]
            dp[l][r] = min(dp[l][k] + dp[k + 1][r] + total for k in range(l, r))
    return dp[0][n - 1]

37. Bitmask DP: TSP

Start from 0 and visit every vertex, without being required to return to 0.

def tsp(dist):
    n = len(dist)
    dp = [[INF] * n for _ in range(1 << n)]
    dp[1][0] = 0
    for mask in range(1 << n):
        for u in range(n):
            if dp[mask][u] == INF:
                continue
            for v in range(n):
                if not (mask >> v) & 1:
                    nm = mask | (1 << v)
                    dp[nm][v] = min(dp[nm][v], dp[mask][u] + dist[u][v])
    return min(dp[(1 << n) - 1])

38. Digit DP

Example: count the positive integers in [1, n] whose digit sum is divisible by mod.

def count_digit_sum_mod_zero(n, mod):
    if n <= 0:
        return 0
    s = str(n)

    @cache
    def dfs(pos, tight, started, rem):
        if pos == len(s):
            return int(started and rem == 0)
        up = int(s[pos]) if tight else 9
        ans = 0
        for d in range(up + 1):
            ntight = tight and d == up
            if not started and d == 0:
                ans += dfs(pos + 1, ntight, False, 0)
            else:
                ans += dfs(pos + 1, ntight, True, (rem + d) % mod)
        return ans

    return dfs(0, True, False, 0)

Part 6: Mathematics and Number Theory

39. Fast Exponentiation

In Python, prefer the built-in:

pow(a, b, mod)

Hand-written template:

def qpow(a, b, mod):
    res = 1 % mod  # so that mod == 1 yields 0, matching pow(a, b, 1)
    a %= mod
    while b:
        if b & 1:
            res = res * a % mod
        a = a * a % mod
        b >>= 1
    return res

40. Matrix Fast Exponentiation

def mat_mul(A, B, mod):
    n, m, p = len(A), len(B), len(B[0])
    C = [[0] * p for _ in range(n)]
    for i in range(n):
        for k in range(m):
            if A[i][k]:
                aik = A[i][k]
                for j in range(p):
                    C[i][j] = (C[i][j] + aik * B[k][j]) % mod
    return C

def mat_pow(A, e, mod):
    n = len(A)
    res = [[0] * n for _ in range(n)]
    for i in range(n):
        res[i][i] = 1
    while e:
        if e & 1:
            res = mat_mul(res, A, mod)
        A = mat_mul(A, A, mod)
        e >>= 1
    return res
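
A classic application is computing Fibonacci numbers in O(log n) multiplications. The sketch below is self-contained and specialised to 2x2 matrices; the name fib_mod is an assumption. It relies on the identity that [[1, 1], [1, 0]] raised to the n-th power equals [[F(n+1), F(n)], [F(n), F(n-1)]].

```python
def fib_mod(n, mod):
    def mul(A, B):
        # 2x2 matrix product modulo mod
        return [[sum(A[i][k] * B[k][j] for k in range(2)) % mod
                 for j in range(2)] for i in range(2)]

    res = [[1, 0], [0, 1]]    # identity matrix
    base = [[1, 1], [1, 0]]
    while n:
        if n & 1:
            res = mul(res, base)
        base = mul(base, base)
        n >>= 1
    return res[0][1]          # F(n), with F(0) = 0 and F(1) = 1

print(fib_mod(10, 10**9 + 7))  # 55
```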

41. Prime Sieves

Sieve of Eratosthenes:

def sieve(n):
    is_prime = [True] * (n + 1)
    if n >= 0:
        is_prime[0] = False
    if n >= 1:
        is_prime[1] = False
    for i in range(2, isqrt(n) + 1):
        if is_prime[i]:
            for j in range(i * i, n + 1, i):
                is_prime[j] = False
    primes = [i for i in range(2, n + 1) if is_prime[i]]
    return primes, is_prime

Euler's linear sieve:

def linear_sieve(n):
    primes = []
    is_comp = [False] * (n + 1)
    for i in range(2, n + 1):
        if not is_comp[i]:
            primes.append(i)
        for p in primes:
            if i * p > n:
                break
            is_comp[i * p] = True
            if i % p == 0:
                break
    return primes, is_comp

42. Extended Euclidean Algorithm

def exgcd(a, b):
    if b == 0:
        return a, 1, 0
    g, x1, y1 = exgcd(b, a % b)
    x = y1
    y = x1 - (a // b) * y1
    return g, x, y

43. Modular Inverse and Modular Division

Prime modulus (a must not be divisible by MOD):

inv_a = pow(a, MOD - 2, MOD)
ans = b * inv_a % MOD

General case, with a coprime to mod:

def mod_inv(a, mod):
    g, x, _ = exgcd(a, mod)
    if g != 1:
        return None
    return x % mod

Python 3.8+ can do this directly:

inv_a = pow(a, -1, mod)
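
The built-in raises ValueError when a has no inverse modulo mod. A small wrapper, sketched here under the hypothetical name safe_mod_inv, gives the same None-on-failure behaviour as the exgcd-based mod_inv.

```python
def safe_mod_inv(a, mod):
    """Return the inverse of a modulo mod, or None if it does not exist (Python 3.8+)."""
    try:
        return pow(a, -1, mod)
    except ValueError:  # raised when gcd(a, mod) != 1
        return None

print(safe_mod_inv(3, 11))  # 4, since 3 * 4 = 12 = 1 (mod 11)
print(safe_mod_inv(2, 4))   # None
```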

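44. Precomputed Binomial Coefficients Modulo a Prime

Suitable for a fixed prime modulus with many queries of C(n, k); requires N < mod.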

def build_comb(N, mod):
    fac = [1] * (N + 1)
    ifac = [1] * (N + 1)
    for i in range(1, N + 1):
        fac[i] = fac[i - 1] * i % mod
    ifac[N] = pow(fac[N], mod - 2, mod)
    for i in range(N, 0, -1):
        ifac[i - 1] = ifac[i] * i % mod

    def C(n, k):
        if k < 0 or k > n:
            return 0
        return fac[n] * ifac[k] % mod * ifac[n - k] % mod

    return C, fac, ifac

Part 7: Advanced String Algorithms

45. String Hashing

Double-modulus hash; queries the half-open substring [l, r).

class StringHash:
    MOD1 = 1_000_000_007
    MOD2 = 1_000_000_009
    BASE = 911382323

    def __init__(self, s):
        n = len(s)
        self.p1 = [1] * (n + 1)
        self.p2 = [1] * (n + 1)
        self.h1 = [0] * (n + 1)
        self.h2 = [0] * (n + 1)
        for i, ch in enumerate(s):
            x = ord(ch)
            self.p1[i + 1] = self.p1[i] * self.BASE % self.MOD1
            self.p2[i + 1] = self.p2[i] * self.BASE % self.MOD2
            self.h1[i + 1] = (self.h1[i] * self.BASE + x) % self.MOD1
            self.h2[i + 1] = (self.h2[i] * self.BASE + x) % self.MOD2

    def get(self, l, r):
        x1 = (self.h1[r] - self.h1[l] * self.p1[r - l]) % self.MOD1
        x2 = (self.h2[r] - self.h2[l] * self.p2[r - l]) % self.MOD2
        return x1, x2

46. KMP String Matching

def prefix_function(p):
    n = len(p)
    pi = [0] * n
    for i in range(1, n):
        j = pi[i - 1]
        while j and p[i] != p[j]:
            j = pi[j - 1]
        if p[i] == p[j]:
            j += 1
        pi[i] = j
    return pi

def kmp_search(text, pattern):
    if not pattern:
        return list(range(len(text) + 1))
    pi = prefix_function(pattern)
    ans = []
    j = 0
    for i, ch in enumerate(text):
        while j and ch != pattern[j]:
            j = pi[j - 1]
        if ch == pattern[j]:
            j += 1
        if j == len(pattern):
            ans.append(i - j + 1)
            j = pi[j - 1]
    return ans

47. Manacher: Longest Palindromic Substring

def manacher(s):
    t = "^#" + "#".join(s) + "#$"
    p = [0] * len(t)
    center = right = 0
    best_len = best_center = 0

    for i in range(1, len(t) - 1):
        mirror = 2 * center - i
        if i < right:
            p[i] = min(right - i, p[mirror])
        while t[i + p[i] + 1] == t[i - p[i] - 1]:
            p[i] += 1
        if i + p[i] > right:
            center, right = i, i + p[i]
        if p[i] > best_len:
            best_len, best_center = p[i], i

    start = (best_center - best_len) // 2
    return s[start:start + best_len], best_len

48. Aho-Corasick Automaton

Multi-pattern matching.

class AhoCorasick:
    def __init__(self):
        self.next = [{}]
        self.fail = [0]
        self.out = [[]]

    def add(self, word, idx):
        u = 0
        for ch in word:
            if ch not in self.next[u]:
                self.next[u][ch] = len(self.next)
                self.next.append({})
                self.fail.append(0)
                self.out.append([])
            u = self.next[u][ch]
        self.out[u].append(idx)

    def build(self):
        q = deque()
        for v in self.next[0].values():
            q.append(v)
        while q:
            u = q.popleft()
            for ch, v in self.next[u].items():
                f = self.fail[u]
                while f and ch not in self.next[f]:
                    f = self.fail[f]
                self.fail[v] = self.next[f].get(ch, 0)
                self.out[v].extend(self.out[self.fail[v]])
                q.append(v)

    def search(self, text):
        res = []
        u = 0
        for i, ch in enumerate(text):
            while u and ch not in self.next[u]:
                u = self.fail[u]
            u = self.next[u].get(ch, 0)
            for idx in self.out[u]:
                res.append((i, idx))  # pattern idx ends at position i
        return res

Part 8: Computational Geometry

49. 2D Points and Vectors

Points are (x, y) tuples.

def sub(a, b):
    return a[0] - b[0], a[1] - b[1]

def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]

def cross(a, b):
    return a[0] * b[1] - a[1] * b[0]

def cross3(o, a, b):
    return cross(sub(a, o), sub(b, o))

def norm2(a):
    return a[0] * a[0] + a[1] * a[1]

50. Segment Intersection

With integer coordinates the test is exact.

def on_segment(a, b, p):
    return cross3(a, b, p) == 0 and min(a[0], b[0]) <= p[0] <= max(a[0], b[0]) and min(a[1], b[1]) <= p[1] <= max(a[1], b[1])

def segments_intersect(a, b, c, d):
    c1 = cross3(a, b, c)
    c2 = cross3(a, b, d)
    c3 = cross3(c, d, a)
    c4 = cross3(c, d, b)

    if c1 == 0 and on_segment(a, b, c):
        return True
    if c2 == 0 and on_segment(a, b, d):
        return True
    if c3 == 0 and on_segment(c, d, a):
        return True
    if c4 == 0 and on_segment(c, d, b):
        return True
    return c1 * c2 < 0 and c3 * c4 < 0

51. Convex Hull (Andrew's Monotone Chain)

Returns the convex hull in counterclockwise order, without collinear points on the edges.

def convex_hull(points):
    pts = sorted(set(points))
    if len(pts) <= 1:
        return pts

    lower = []
    for p in pts:
        while len(lower) >= 2 and cross3(lower[-2], lower[-1], p) <= 0:
            lower.pop()
        lower.append(p)

    upper = []
    for p in reversed(pts):
        while len(upper) >= 2 and cross3(upper[-2], upper[-1], p) <= 0:
            upper.pop()
        upper.append(p)

    return lower[:-1] + upper[:-1]

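To keep collinear points on the edges, change <= 0 to < 0. When all input points are collinear, this variant lists the interior points twice, so handle that case separately.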

Part 9: Search and Backtracking

52. Grid BFS

DIR4 = [(1, 0), (-1, 0), (0, 1), (0, -1)]

def grid_bfs(grid, start):
    n, m = len(grid), len(grid[0])
    sx, sy = start
    dist = [[-1] * m for _ in range(n)]
    dist[sx][sy] = 0
    q = deque([(sx, sy)])
    while q:
        x, y = q.popleft()
        for dx, dy in DIR4:
            nx, ny = x + dx, y + dy
            if 0 <= nx < n and 0 <= ny < m and grid[nx][ny] != "#" and dist[nx][ny] == -1:
                dist[nx][ny] = dist[x][y] + 1
                q.append((nx, ny))
    return dist

53. Multi-Source BFS

def multi_source_bfs(grid, sources):
    n, m = len(grid), len(grid[0])
    dist = [[-1] * m for _ in range(n)]
    q = deque()
    for x, y in sources:
        dist[x][y] = 0
        q.append((x, y))

    while q:
        x, y = q.popleft()
        for dx, dy in DIR4:
            nx, ny = x + dx, y + dy
            if 0 <= nx < n and 0 <= ny < m and grid[nx][ny] != "#" and dist[nx][ny] == -1:
                dist[nx][ny] = dist[x][y] + 1
                q.append((nx, ny))
    return dist

54. 0-1 BFS

Edge weights must be 0 or 1.

def zero_one_bfs(n, g, s):
    dist = [INF] * n
    dist[s] = 0
    q = deque([s])
    while q:
        u = q.popleft()
        for v, w in g[u]:
            nd = dist[u] + w
            if nd < dist[v]:
                dist[v] = nd
                if w == 0:
                    q.appendleft(v)
                else:
                    q.append(v)
    return dist

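55. Bidirectional BFS

Suitable for unweighted graphs with a huge state space where the search can also expand backward from the target. The code uses the same neighbors function in both directions, so moves must be reversible.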

def bidirectional_bfs(start, target, neighbors):
    if start == target:
        return 0
    front = {start}
    back = {target}
    dist_f = {start: 0}
    dist_b = {target: 0}

    while front and back:
        if len(front) > len(back):
            front, back = back, front
            dist_f, dist_b = dist_b, dist_f

        nxt_front = set()
        for u in front:
            for v in neighbors(u):
                if v in dist_f:
                    continue
                if v in dist_b:
                    return dist_f[u] + 1 + dist_b[v]
                dist_f[v] = dist_f[u] + 1
                nxt_front.add(v)
        front = nxt_front
    return -1

56. State-Space BFS

States must be hashable, such as integers, strings, or tuples.

def state_bfs(start, is_goal, next_states):
    dist = {start: 0}
    q = deque([start])
    while q:
        state = q.popleft()
        if is_goal(state):
            return dist[state]
        for ns in next_states(state):
            if ns not in dist:
                dist[ns] = dist[state] + 1
                q.append(ns)
    return -1

57. DFS Flood Fill

Recursive version:

def flood_fill(grid, sx, sy, old, new):
    n, m = len(grid), len(grid[0])
    if old == new or grid[sx][sy] != old:
        return

    def dfs(x, y):
        grid[x][y] = new
        for dx, dy in DIR4:
            nx, ny = x + dx, y + dy
            if 0 <= nx < n and 0 <= ny < m and grid[nx][ny] == old:
                dfs(nx, ny)

    dfs(sx, sy)

Explicit-stack version:

def flood_fill_iter(grid, sx, sy, old, new):
    n, m = len(grid), len(grid[0])
    if old == new or grid[sx][sy] != old:
        return
    st = [(sx, sy)]
    grid[sx][sy] = new
    while st:
        x, y = st.pop()
        for dx, dy in DIR4:
            nx, ny = x + dx, y + dy
            if 0 <= nx < n and 0 <= ny < m and grid[nx][ny] == old:
                grid[nx][ny] = new
                st.append((nx, ny))

58. Graph DFS with an Explicit Stack

def iterative_dfs(g, start):
    n = len(g)
    vis = [False] * n
    order = []
    st = [start]
    vis[start] = True
    while st:
        u = st.pop()
        order.append(u)
        for v in g[u]:
            if not vis[v]:
                vis[v] = True
                st.append(v)
    return order

59. Backtracking

Permutation template:

def permute(a):
    n = len(a)
    used = [False] * n
    path = []
    ans = []

    def dfs():
        if len(path) == n:
            ans.append(path[:])
            return
        for i in range(n):
            if used[i]:
                continue
            used[i] = True
            path.append(a[i])
            dfs()
            path.pop()
            used[i] = False

    dfs()
    return ans

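With duplicate elements, skip duplicates at the same level:

def permute_unique(a):
    a = sorted(a)  # sort a copy so the caller's list is not mutated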
    n = len(a)
    used = [False] * n
    path = []
    ans = []

    def dfs():
        if len(path) == n:
            ans.append(path[:])
            return
        for i in range(n):
            if used[i]:
                continue
            if i and a[i] == a[i - 1] and not used[i - 1]:
                continue
            used[i] = True
            path.append(a[i])
            dfs()
            path.pop()
            used[i] = False

    dfs()
    return ans

60. Memoized Search

@cache
def dfs(state):
    if is_terminal(state):
        return base_value(state)
    best = -INF
    for nxt in transitions(state):
        best = max(best, gain(state, nxt) + dfs(nxt))
    return best

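Common requirements:

  • state must be hashable; convert lists to tuple.
  • With multiple test cases, remember to call dfs.cache_clear().
  • When the recursion depth is large, switch to iterative DP or an explicit stack.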
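
The points above can be seen in a concrete, self-contained sketch: counting the ways to climb a staircase when each move covers between 1 and max_step steps. The function name ways and the (remaining, max_step) state layout are illustrative assumptions.

```python
from functools import cache

@cache
def ways(state):
    # state is a hashable tuple (remaining, max_step)
    remaining, max_step = state
    if remaining == 0:
        return 1
    return sum(ways((remaining - s, max_step))
               for s in range(1, min(max_step, remaining) + 1))

print(ways((4, 2)))          # 5
ways.cache_clear()           # reset between independent test cases
print(ways.cache_info().currsize)  # 0
```

Clearing the cache matters mainly when the function reads global input that changes between test cases, since stale entries would otherwise be reused.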

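61. Common Pruning Patterns

Common pruning points in backtracking:

def dfs(i, cur):
    # feasible, upper_bound, best_answer, encode_state and seen are
    # problem-specific placeholders to be supplied by the solution.

    # 1. Feasibility pruning: the remaining resources can no longer satisfy the requirements
    if not feasible(i, cur):
        return

    # 2. Optimality pruning: even the theoretical upper bound cannot beat the current answer
    if upper_bound(i, cur) <= best_answer():
        return

    # 3. Memoization pruning: keep only the better value for each state
    key = encode_state(i, cur)
    if key in seen and seen[key] >= cur:
        return
    seen[key] = cur

    # regular search continues here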

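In practice the usual order is: sort first to create monotonicity, then add feasibility pruning, and finally add memoization or upper-bound pruning.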
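
As one possible concrete instance of this order, the sketch below finds the largest subset sum not exceeding cap. It assumes non-negative integers; the function name max_subset_sum and the suffix-sum upper bound are illustrative choices, not part of the original templates.

```python
def max_subset_sum(a, cap):
    a = sorted(a, reverse=True)          # sort first: large items fail fast
    n = len(a)
    suffix = [0] * (n + 1)               # suffix[i] = a[i] + ... + a[n - 1]
    for i in range(n - 1, -1, -1):
        suffix[i] = suffix[i + 1] + a[i]
    best = 0

    def dfs(i, cur):
        nonlocal best
        if cur > cap:                    # feasibility pruning
            return
        if cur + suffix[i] <= best:      # upper-bound pruning
            return
        best = max(best, cur)
        if i == n:
            return
        dfs(i + 1, cur + a[i])           # take a[i]
        dfs(i + 1, cur)                  # skip a[i]

    dfs(0, 0)
    return best

print(max_subset_sum([3, 34, 4, 12, 5, 2], 9))  # 9
```

The search is still exponential in the worst case; the pruning only cuts branches that provably cannot improve the answer.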

