Title here
Summary here
在实际应用中,很多请求共享相同的前缀:
典型场景:
问题:
Radix Tree(基数树)是一种压缩的 Trie 树,用于高效存储和查找字符串:
特点:
在 SGLang 中,Radix Tree 存储 Token 序列:
class RadixNode:
    """One node of the radix tree used for prefix caching.

    A node carries the tokens on the edge that leads into it from its
    parent, plus the KV-cache positions recorded for that edge.
    """

    def __init__(self) -> None:
        # Position in the tree.
        self.parent: Optional["RadixNode"] = None
        self.children: Dict[int, "RadixNode"] = {}  # token -> child node

        # Payload of the incoming edge.
        self.edge_tokens: List[int] = []  # tokens on the edge
        self.kv_indices: Optional[torch.Tensor] = None  # KV-cache positions

        # Bookkeeping.
        self.last_access_time: float = 0  # time of the last access
        self.ref_count: int = 0  # reference count


# Key file: python/sglang/srt/mem_cache/radix_cache.py
class RadixCache:
    """Radix-tree prefix cache that maps token sequences to KV-cache slots.

    Requests whose token sequences share a prefix share the tree nodes that
    cover that prefix, and with them the KV slots stored on those nodes.
    """

    def __init__(
        self,
        token_to_kv_pool: "KVCache",
        page_size: int = 1,
        disable: bool = False,
    ):
        """Create an empty cache.

        Args:
            token_to_kv_pool: Pool owning the KV slots; ``evict`` returns
                slots to it through its ``free`` method.
            page_size: Stored on the instance; no method of this class
                consults it.
            disable: When true, lookups match nothing, inserts are skipped
                and nothing is evicted.
        """
        self.root = RadixNode()
        self.token_to_kv_pool = token_to_kv_pool
        self.page_size = page_size
        self.disable = disable
        # Statistics.
        self.total_tokens = 0
        self.evictable_tokens = 0
        # Lookup counters; a lookup that reuses at least one token is a hit.
        self.hits = 0
        self.misses = 0

    def match_prefix(self, token_ids: List[int]) -> Tuple["RadixNode", int]:
        """Find the longest cached prefix of ``token_ids``.

        If the match ends in the middle of an edge, that edge is split so
        that the returned node covers exactly the matched tokens. Every node
        on the matched path gets its ``last_access_time`` refreshed, which is
        what ``evict`` sorts on.

        Args:
            token_ids: Token sequence of the incoming request.

        Returns:
            ``(node, matched_len)``: the deepest matching node and the number
            of leading tokens it covers; ``(self.root, 0)`` when nothing
            matches or the cache is disabled.
        """
        if self.disable:
            return self.root, 0

        import time  # function-scope import keeps the block self-contained

        now = time.monotonic()
        node = self.root
        matched_len = 0
        total = len(token_ids)
        # Walk with an offset instead of re-slicing the input on every step.
        while matched_len < total:
            child = node.children.get(token_ids[matched_len])
            if child is None:
                break
            edge_len = len(child.edge_tokens)
            window = token_ids[matched_len:matched_len + edge_len]
            # ``edge_len > 0`` guards against looping forever on an empty edge.
            if edge_len > 0 and window == child.edge_tokens:
                # Whole edge matched: descend and keep searching.
                child.last_access_time = now
                node = child
                matched_len += edge_len
                continue
            # Only part of the edge matches (or the input ends inside it).
            # NOTE(review): _get_partial_match_len is defined outside this
            # excerpt; presumably it returns the common-prefix length.
            partial_len = self._get_partial_match_len(
                token_ids[matched_len:], child.edge_tokens
            )
            if 0 < partial_len < edge_len:
                node = self._split_child(node, child, partial_len)
                node.last_access_time = now
                matched_len += partial_len
            break

        if matched_len > 0:
            self.hits += 1
        else:
            self.misses += 1
        return node, matched_len

    def _split_child(
        self, parent: "RadixNode", child: "RadixNode", split_pos: int
    ) -> "RadixNode":
        """Split ``child``'s edge at ``split_pos`` and return the new middle node.

        Delegates the split to ``_split_node`` and then makes sure the result
        is reachable from ``parent`` and inherits the reference count, so
        that a later ``dec_ref`` walking through it stays balanced.
        """
        key = child.edge_tokens[0]  # read before the split shortens the edge
        middle = self._split_node(child, split_pos)
        parent.children[key] = middle
        middle.ref_count = child.ref_count
        middle.last_access_time = child.last_access_time
        return middle

    def cache_req(
        self,
        token_ids: List[int],
        kv_indices: torch.Tensor,
        last_node: "RadixNode",
        prefix_len: int,
    ) -> "RadixNode":
        """Insert the uncached tail of a request into the tree.

        Args:
            token_ids: Full token sequence of the request.
            kv_indices: KV positions aligned with ``token_ids``.
            last_node: Node returned by the earlier ``match_prefix`` call.
            prefix_len: Number of leading tokens already covered by
                ``last_node``.

        Returns:
            The node at which the request's sequence ends; ``self.root``
            when the cache is disabled.
        """
        if self.disable:
            return self.root

        import time  # function-scope import keeps the block self-contained

        now = time.monotonic()
        # Continue from where the previous match stopped.
        node = last_node
        remaining_tokens = token_ids[prefix_len:]
        remaining_indices = kv_indices[prefix_len:]
        while remaining_tokens:
            first_token = remaining_tokens[0]
            child = node.children.get(first_token)
            if child is None:
                # Nothing cached below this point: hang the rest on a new leaf.
                new_node = RadixNode()
                new_node.parent = node
                new_node.edge_tokens = remaining_tokens.copy()
                new_node.kv_indices = remaining_indices.clone()
                new_node.last_access_time = now
                node.children[first_token] = new_node
                added = len(new_node.edge_tokens)
                self.total_tokens += added
                # A fresh node has ref_count 0, so it is evictable right away.
                self.evictable_tokens += added
                return new_node

            # Length of the common prefix of the remaining tokens and the edge.
            shared = 0
            for ours, theirs in zip(remaining_tokens, child.edge_tokens):
                if ours != theirs:
                    break
                shared += 1
            if shared == 0:
                # Children are keyed by their first edge token, so this means
                # the tree is inconsistent; stop instead of looping forever.
                break
            if shared < len(child.edge_tokens):
                # The sequences diverge inside the edge: split it there.
                child = self._split_child(node, child, shared)
            child.last_access_time = now
            node = child
            remaining_tokens = remaining_tokens[shared:]
            remaining_indices = remaining_indices[shared:]
        return node

    def inc_ref(self, node: "RadixNode"):
        """Increment the reference count of ``node`` and all its ancestors.

        A node whose count leaves zero stops being evictable, so its tokens
        are taken out of ``evictable_tokens`` (the root is never counted).
        """
        while node is not None:
            if node.ref_count == 0 and node is not self.root:
                self.evictable_tokens -= len(node.edge_tokens)
            node.ref_count += 1
            node = node.parent

    def dec_ref(self, node: "RadixNode"):
        """Decrement the reference count of ``node`` and all its ancestors.

        A non-root node whose count drops to zero becomes evictable and its
        tokens are added to ``evictable_tokens``.
        """
        while node is not None:
            node.ref_count -= 1
            if node.ref_count == 0 and node is not self.root:
                # May be evicted from now on.
                self.evictable_tokens += len(node.edge_tokens)
            node = node.parent

    def evict(self, num_tokens: int) -> int:
        """Evict unreferenced nodes, least recently used first.

        Only leaves are removed, so a subtree is never orphaned with its KV
        slots still allocated. A parent is considered once its last child
        is gone.

        Args:
            num_tokens: Minimum number of tokens to release, if available.

        Returns:
            Number of tokens actually evicted; ``0`` when disabled.
        """
        if self.disable:
            return 0

        import heapq  # function-scope import keeps the block self-contained

        # (time, id, node): ``id`` breaks ties so nodes are never compared.
        heap = [
            (candidate.last_access_time, id(candidate), candidate)
            for candidate in self._get_evictable_nodes()
            if not candidate.children
        ]
        heapq.heapify(heap)

        evicted = 0
        while heap and evicted < num_tokens:
            _, _, node = heapq.heappop(heap)
            parent = node.parent
            tokens_in_node = len(node.edge_tokens)
            indices = node.kv_indices  # read before the node is torn down
            # NOTE(review): _remove_node is defined outside this excerpt. It
            # is presumably what detaches the node from its parent and
            # adjusts total_tokens / evictable_tokens -- confirm.
            self._remove_node(node)
            # Release the KV space.
            self.token_to_kv_pool.free(indices)
            evicted += tokens_in_node
            if (
                parent is not None
                and parent is not self.root
                and parent.ref_count == 0
                and not parent.children
            ):
                heapq.heappush(
                    heap, (parent.last_access_time, id(parent), parent)
                )
        return evicted

    def _get_evictable_nodes(self) -> List["RadixNode"]:
        """Return every non-root node with ``ref_count == 0`` in pre-order."""
        evictable = []
        stack = [self.root]
        while stack:
            node = stack.pop()
            if node.ref_count == 0 and node is not self.root:
                evictable.append(node)
            # Reversed so children are visited in dict order, like a
            # recursive depth-first walk, without risking the recursion limit.
            stack.extend(reversed(list(node.children.values())))
        return evictable


def init_next_round_input(self, tree_cache: "RadixCache"):
    """Prepare a request for its next forward pass.

    This is a method of the request object (it reads ``self.input_ids``),
    kept as a free-standing definition because its class is not part of
    this excerpt.

    Sets ``last_node``, ``prefix_len``, ``extend_input_len`` and
    ``prefix_kv_indices`` on the request and takes a reference on the
    matched path.
    """
    # Match the cached prefix.
    self.last_node, prefix_len = tree_cache.match_prefix(self.input_ids)
    # Tokens that still have to be computed.
    self.extend_input_len = len(self.input_ids) - prefix_len
    self.prefix_len = prefix_len
    # Pin the matched path so it cannot be evicted while in use.
    tree_cache.inc_ref(self.last_node)

    # Each node only stores the KV positions of its own edge, so the prefix
    # is the concatenation along the path from the root to last_node.
    chunks = []
    node = self.last_node
    while node is not None and node.kv_indices is not None:
        chunks.append(node.kv_indices)
        node = node.parent
    if prefix_len > 0 and chunks:
        chunks.reverse()
        # assumes kv_indices are tensors joinable along dim 0 -- TODO confirm
        self.prefix_kv_indices = torch.cat(chunks)[:prefix_len]
    else:
        # Reset explicitly so a value from a previous round is never reused.
        self.prefix_kv_indices = None


# Why this improves performance:
| 场景 | 传统方法 | RadixAttention |
|---|---|---|
| 10 请求,相同系统提示 | 10x 系统提示 KV | 1x 系统提示 KV |
| RAG,相似上下文 | 重复存储 | 共享公共部分 |
| 多轮对话 | 完整历史 | 复用历史 KV |
传统: TTFT = Prefill(系统提示 + 用户输入)
RadixAttention: TTFT = Prefill(用户输入) // 系统提示已缓存
当新请求只匹配前缀的一部分时:
def _split_node(self, node: "RadixNode", split_pos: int) -> "RadixNode":
    """Split the edge leading into ``node`` at ``split_pos``.

    A new intermediate node takes the first ``split_pos`` tokens (and KV
    positions) of the edge; ``node`` keeps the rest and becomes the only
    child of the new node. The parent's child entry is re-pointed at the new
    node, and the new node inherits ``ref_count`` and ``last_access_time``
    so it is not treated as unreferenced while ``node`` is still in use.

    Args:
        node: Node whose incoming edge is split.
        split_pos: Number of tokens that move to the new intermediate node;
            must satisfy ``0 < split_pos < len(node.edge_tokens)``.

    Returns:
        The new intermediate node.

    Raises:
        ValueError: If ``split_pos`` would leave either half empty.
    """
    edge_len = len(node.edge_tokens)
    if not 0 < split_pos < edge_len:
        raise ValueError(
            f"split_pos must be between 1 and {edge_len - 1}, got {split_pos}"
        )

    parent = node.parent
    # Children are keyed by the first token of their edge; remember the key
    # before the edge is shortened.
    head_key = node.edge_tokens[0]

    # Create the new intermediate node (same node type as the one split).
    new_node = type(node)()
    new_node.parent = parent
    new_node.edge_tokens = node.edge_tokens[:split_pos]
    if node.kv_indices is not None:
        new_node.kv_indices = node.kv_indices[:split_pos]
    new_node.ref_count = node.ref_count
    new_node.last_access_time = node.last_access_time

    # Shrink the original node to the tail of the edge.
    node.edge_tokens = node.edge_tokens[split_pos:]
    if node.kv_indices is not None:
        node.kv_indices = node.kv_indices[split_pos:]
    node.parent = new_node
    new_node.children[node.edge_tokens[0]] = node

    # Re-point the parent at the intermediate node so it is reachable.
    if parent is not None:
        parent.children[head_key] = new_node
    return new_node


# Combined with Paged Attention:
def cache_req_with_pages(
    self,
    token_ids: List[int],
    page_indices: List[int],
    page_size: int,
):
    """Cache a request when the KV cache is managed in pages.

    Placeholder: the body is intentionally empty, so calling it has no
    effect and returns ``None``. The intended steps are:

    * convert token positions into pages;
    * support prefix matches that cross page boundaries.
    """


# Caching for multimodal inputs such as images:
def match_prefix_multimodal(
    self,
    token_ids: List[int],
    image_hash: Optional[str],
) -> Tuple["RadixNode", int]:
    """Match a prefix for a multimodal request.

    Placeholder: the intent is to combine the image hash with the token
    match, but the body is empty.

    NOTE(review): as written this returns ``None`` rather than the annotated
    ``(node, matched_len)`` tuple, so callers must not unpack the result yet.
    """


def get_stats(self) -> Dict:
    """Return cache statistics.

    Returns:
        A dict with ``total_nodes``, ``total_tokens``, ``evictable_tokens``
        and ``hit_rate``. ``hit_rate`` is ``0.0`` until a lookup has been
        recorded, instead of raising ``ZeroDivisionError``.
    """
    # The counters may not exist on the instance yet; treat them as zero.
    hits = getattr(self, "hits", 0)
    misses = getattr(self, "misses", 0)
    lookups = hits + misses
    return {
        "total_nodes": self._count_nodes(),
        "total_tokens": self.total_tokens,
        "evictable_tokens": self.evictable_tokens,
        "hit_rate": hits / lookups if lookups else 0.0,
    }


def visualize(self) -> str:
    """Render the radix tree as indented text, one node per line.

    Each line shows up to the first five edge tokens followed by ``...`` and
    the node's reference count; the root itself is not printed and every
    level adds one space of indentation.
    """
    lines = []

    def walk(node, indent):
        if node is not self.root:
            label = f"{node.edge_tokens[:5]}..."
            lines.append(f"{indent}{label} (ref={node.ref_count})")
        for child in node.children.values():
            walk(child, indent + " ")

    walk(self.root, "")
    return "\n".join(lines)


# | Advantage | Description |
|---|---|
| 自动复用 | 无需手动管理前缀缓存 |
| 细粒度 | 任意长度前缀都可复用 |
| 动态调整 | 根据访问模式自动优化 |
在下一章《内存池设计》中,我们将: