底层数据结构设计
学习目标
- 深入理解 Redis 的核心数据结构设计
- 掌握 SDS、字典、跳跃表、压缩列表等关键数据结构
- 实现简化版的数据结构库
- 理解数据结构选择对性能的影响
️ Redis 数据结构体系
Redis 的数据结构分为两个层次:
1. 基础数据结构(底层)
- SDS(Simple Dynamic String):字符串实现
- 字典(Dict):哈希表实现
- 跳跃表(Skip List):有序集合实现
- 压缩列表(ZipList):紧凑的列表和哈希实现
- 整数集合(IntSet):整数集合实现
2. 对象类型(上层)
- String:字符串对象
- Hash:哈希对象
- List:列表对象
- Set:集合对象
- ZSet:有序集合对象
SDS(Simple Dynamic String)
设计目标
SDS 是 Redis 的字符串实现,解决了 C 字符串的以下问题:
- 长度获取:O(1) 时间复杂度获取字符串长度
- 二进制安全:可以存储任意二进制数据
- 内存安全:避免缓冲区溢出
- 高效操作:减少内存重分配次数
SDS 结构设计
// sds/sds.go
package sds
import (
"unsafe"
)
// SDS 结构体
type SDS struct {
len int // 字符串长度
free int // 剩余空间
buf []byte // 字符数组
}
// 创建 SDS
func NewSDS(initStr string) *SDS {
buf := make([]byte, len(initStr))
copy(buf, initStr)
return &SDS{
len: len(initStr),
free: 0,
buf: buf,
}
}
// 获取长度 O(1)
func (s *SDS) Len() int {
return s.len
}
// 获取字符串内容
func (s *SDS) String() string {
return string(s.buf[:s.len])
}
// 追加字符串
func (s *SDS) Append(str string) {
strLen := len(str)
totalLen := s.len + strLen
// 空间不足时扩容
if s.free < strLen {
s.expand(totalLen)
}
// 追加字符串
copy(s.buf[s.len:], str)
s.len = totalLen
s.free = len(s.buf) - s.len
}
// 扩容策略
func (s *SDS) expand(newLen int) {
// 计算新容量(2倍增长)
newCap := len(s.buf) * 2
if newCap < newLen {
newCap = newLen
}
// 分配新空间
newBuf := make([]byte, newCap)
copy(newBuf, s.buf[:s.len])
s.buf = newBuf
s.free = newCap - s.len
}
// 截取字符串
func (s *SDS) Substr(start, end int) *SDS {
if start < 0 {
start = 0
}
if end >= s.len {
end = s.len - 1
}
if start > end {
return NewSDS("")
}
substr := string(s.buf[start : end+1])
return NewSDS(substr)
}
// 比较字符串
func (s *SDS) Compare(other *SDS) int {
minLen := s.len
if other.len < minLen {
minLen = other.len
}
for i := 0; i < minLen; i++ {
if s.buf[i] < other.buf[i] {
return -1
} else if s.buf[i] > other.buf[i] {
return 1
}
}
if s.len < other.len {
return -1
} else if s.len > other.len {
return 1
}
return 0
}
SDS 优化策略
// sds/optimization.go
package sds
// 空间预分配策略
func (s *SDS) expandWithStrategy(newLen int) {
newCap := len(s.buf)
// 如果新长度小于 1MB,则翻倍
if newLen < 1024*1024 {
newCap = newLen * 2
} else {
// 否则每次增加 1MB
newCap = newLen + 1024*1024
}
newBuf := make([]byte, newCap)
copy(newBuf, s.buf[:s.len])
s.buf = newBuf
s.free = newCap - s.len
}
// 惰性空间释放
func (s *SDS) Trim() {
if s.free > s.len {
newBuf := make([]byte, s.len)
copy(newBuf, s.buf[:s.len])
s.buf = newBuf
s.free = 0
}
}
️ 字典(Dict)
设计目标
字典是 Redis 的核心数据结构,用于实现:
- 数据库:键值对存储
- 哈希对象:字段值映射
- 集合对象:成员存储
- 有序集合对象:成员到分数的映射
字典结构设计
// dict/dict.go
package dict
import (
"hash/fnv"
"math"
)
// 字典节点
type DictEntry struct {
key interface{}
value interface{}
next *DictEntry
}
// 哈希表
type HashTable struct {
table []*DictEntry
size int64
sizemask int64
used int64
}
// 字典
type Dict struct {
ht [2]*HashTable // 两个哈希表用于渐进式 rehash
rehashidx int64 // rehash 进度,-1 表示未进行 rehash
iterators int64 // 正在运行的迭代器数量
}
// 创建字典
func NewDict() *Dict {
return &Dict{
ht: [2]*HashTable{newHashTable(4), nil},
rehashidx: -1,
iterators: 0,
}
}
// 创建哈希表
func newHashTable(size int64) *HashTable {
// 确保大小为 2 的幂
actualSize := nextPowerOfTwo(size)
return &HashTable{
table: make([]*DictEntry, actualSize),
size: actualSize,
sizemask: actualSize - 1,
used: 0,
}
}
// 计算下一个 2 的幂
func nextPowerOfTwo(n int64) int64 {
if n <= 0 {
return 1
}
n--
n |= n >> 1
n |= n >> 2
n |= n >> 4
n |= n >> 8
n |= n >> 16
n |= n >> 32
n++
return n
}
// 哈希函数
func hashKey(key interface{}) uint64 {
switch k := key.(type) {
case string:
h := fnv.New64a()
h.Write([]byte(k))
return h.Sum64()
case int:
return uint64(k)
case int64:
return uint64(k)
default:
// 默认使用字符串表示
h := fnv.New64a()
h.Write([]byte(fmt.Sprintf("%v", k)))
return h.Sum64()
}
}
// 设置键值对
func (d *Dict) Set(key, value interface{}) {
// 如果正在 rehash,执行一步 rehash
if d.isRehashing() {
d.rehashStep()
}
// 计算哈希值
hash := hashKey(key)
idx := hash & d.ht[0].sizemask
// 检查键是否已存在
entry := d.ht[0].table[idx]
for entry != nil {
if entry.key == key {
entry.value = value
return
}
entry = entry.next
}
// 添加新节点
newEntry := &DictEntry{
key: key,
value: value,
next: d.ht[0].table[idx],
}
d.ht[0].table[idx] = newEntry
d.ht[0].used++
// 检查是否需要扩容
if !d.isRehashing() && d.ht[0].used >= d.ht[0].size {
d.expand()
}
}
// 获取值
func (d *Dict) Get(key interface{}) (interface{}, bool) {
// 如果正在 rehash,先在 ht[0] 中查找
if d.isRehashing() {
if value, found := d.getInTable(d.ht[0], key); found {
return value, true
}
if value, found := d.getInTable(d.ht[1], key); found {
return value, true
}
} else {
if value, found := d.getInTable(d.ht[0], key); found {
return value, true
}
}
return nil, false
}
// 在指定哈希表中查找
func (d *Dict) getInTable(ht *HashTable, key interface{}) (interface{}, bool) {
if ht == nil {
return nil, false
}
hash := hashKey(key)
idx := hash & ht.sizemask
entry := ht.table[idx]
for entry != nil {
if entry.key == key {
return entry.value, true
}
entry = entry.next
}
return nil, false
}
// 删除键
func (d *Dict) Delete(key interface{}) bool {
if d.isRehashing() {
d.rehashStep()
}
hash := hashKey(key)
idx := hash & d.ht[0].sizemask
var prev *DictEntry
entry := d.ht[0].table[idx]
for entry != nil {
if entry.key == key {
if prev == nil {
d.ht[0].table[idx] = entry.next
} else {
prev.next = entry.next
}
d.ht[0].used--
return true
}
prev = entry
entry = entry.next
}
return false
}
// 检查是否正在 rehash
func (d *Dict) isRehashing() bool {
return d.rehashidx != -1
}
// 扩容
func (d *Dict) expand() {
if d.isRehashing() {
return
}
// 计算新大小
newSize := d.ht[0].used * 2
if newSize < 4 {
newSize = 4
}
// 创建新的哈希表
d.ht[1] = newHashTable(newSize)
d.rehashidx = 0
}
// 渐进式 rehash
func (d *Dict) rehashStep() {
if !d.isRehashing() {
return
}
// 每次移动一个桶
ht0 := d.ht[0]
ht1 := d.ht[1]
for d.rehashidx < ht0.size {
if ht0.table[d.rehashidx] != nil {
// 移动整个链表
entry := ht0.table[d.rehashidx]
ht0.table[d.rehashidx] = nil
for entry != nil {
next := entry.next
hash := hashKey(entry.key)
idx := hash & ht1.sizemask
entry.next = ht1.table[idx]
ht1.table[idx] = entry
entry = next
ht0.used--
ht1.used++
}
}
d.rehashidx++
// 限制每次 rehash 的时间
if ht0.used == 0 {
break
}
}
// rehash 完成
if ht0.used == 0 {
d.ht[0] = d.ht[1]
d.ht[1] = nil
d.rehashidx = -1
}
}
跳跃表(Skip List)
设计目标
跳跃表用于实现有序集合,提供:
- 有序性:元素按分数排序
- 高效查找:O(log n) 时间复杂度
- 范围查询:支持按范围获取元素
- 简单实现:比平衡树更容易实现
跳跃表结构设计
// skiplist/skiplist.go
package skiplist
import (
"math/rand"
"time"
)
// 跳跃表节点
type SkipListNode struct {
key interface{}
value interface{}
score float64
level int
forward []*SkipListNode
}
// 跳跃表
type SkipList struct {
header *SkipListNode
level int
length int64
}
// 创建跳跃表
func NewSkipList() *SkipList {
header := &SkipListNode{
level: MAX_LEVEL,
forward: make([]*SkipListNode, MAX_LEVEL+1),
}
return &SkipList{
header: header,
level: 1,
length: 0,
}
}
const MAX_LEVEL = 16
const P = 0.25 // 概率因子
// 随机生成层数
func randomLevel() int {
level := 1
for rand.Float64() < P && level < MAX_LEVEL {
level++
}
return level
}
// 插入元素
func (sl *SkipList) Insert(key interface{}, value interface{}, score float64) {
// 查找插入位置
update := make([]*SkipListNode, MAX_LEVEL+1)
current := sl.header
// 从最高层开始查找
for i := sl.level; i >= 1; i-- {
for current.forward[i] != nil &&
(current.forward[i].score < score ||
(current.forward[i].score == score &&
compareKeys(current.forward[i].key, key) < 0)) {
current = current.forward[i]
}
update[i] = current
}
// 移动到第 0 层
current = current.forward[0]
// 如果键已存在,更新值
if current != nil && current.key == key && current.score == score {
current.value = value
return
}
// 创建新节点
newLevel := randomLevel()
newNode := &SkipListNode{
key: key,
value: value,
score: score,
level: newLevel,
forward: make([]*SkipListNode, newLevel+1),
}
// 如果新层数大于当前层数,更新头节点
if newLevel > sl.level {
for i := sl.level + 1; i <= newLevel; i++ {
update[i] = sl.header
}
sl.level = newLevel
}
// 插入节点
for i := 1; i <= newLevel; i++ {
newNode.forward[i] = update[i].forward[i]
update[i].forward[i] = newNode
}
sl.length++
}
// 查找元素
func (sl *SkipList) Find(key interface{}, score float64) (interface{}, bool) {
current := sl.header
// 从最高层开始查找
for i := sl.level; i >= 1; i-- {
for current.forward[i] != nil &&
(current.forward[i].score < score ||
(current.forward[i].score == score &&
compareKeys(current.forward[i].key, key) < 0)) {
current = current.forward[i]
}
}
// 移动到第 0 层
current = current.forward[0]
if current != nil && current.key == key && current.score == score {
return current.value, true
}
return nil, false
}
// 删除元素
func (sl *SkipList) Delete(key interface{}, score float64) bool {
update := make([]*SkipListNode, MAX_LEVEL+1)
current := sl.header
// 查找要删除的节点
for i := sl.level; i >= 1; i-- {
for current.forward[i] != nil &&
(current.forward[i].score < score ||
(current.forward[i].score == score &&
compareKeys(current.forward[i].key, key) < 0)) {
current = current.forward[i]
}
update[i] = current
}
current = current.forward[0]
if current != nil && current.key == key && current.score == score {
// 删除节点
for i := 1; i <= sl.level; i++ {
if update[i].forward[i] != current {
break
}
update[i].forward[i] = current.forward[i]
}
// 更新层数
for sl.level > 1 && sl.header.forward[sl.level] == nil {
sl.level--
}
sl.length--
return true
}
return false
}
// 范围查询
func (sl *SkipList) RangeByScore(minScore, maxScore float64) []*SkipListNode {
var result []*SkipListNode
current := sl.header
// 找到起始位置
for i := sl.level; i >= 1; i-- {
for current.forward[i] != nil && current.forward[i].score < minScore {
current = current.forward[i]
}
}
// 遍历范围内的节点
current = current.forward[0]
for current != nil && current.score <= maxScore {
result = append(result, current)
current = current.forward[0]
}
return result
}
// 比较键的大小
func compareKeys(key1, key2 interface{}) int {
switch k1 := key1.(type) {
case string:
if k2, ok := key2.(string); ok {
if k1 < k2 {
return -1
} else if k1 > k2 {
return 1
}
return 0
}
case int:
if k2, ok := key2.(int); ok {
if k1 < k2 {
return -1
} else if k1 > k2 {
return 1
}
return 0
}
}
// 默认按字符串比较
s1 := fmt.Sprintf("%v", key1)
s2 := fmt.Sprintf("%v", key2)
if s1 < s2 {
return -1
} else if s1 > s2 {
return 1
}
return 0
}
压缩列表(ZipList)
设计目标
压缩列表用于实现:
- 内存效率:紧凑存储小列表和哈希
- 缓存友好:连续内存布局
- 简单操作:支持基本的增删改查
压缩列表结构设计
// ziplist/ziplist.go
package ziplist
import (
"encoding/binary"
"fmt"
)
// 压缩列表结构
type ZipList struct {
data []byte
}
// 创建压缩列表
func NewZipList() *ZipList {
// 初始大小:头部(10字节) + 结束标记(1字节)
data := make([]byte, 11)
// 设置头部信息
binary.LittleEndian.PutUint32(data[0:4], 11) // zlbytes
binary.LittleEndian.PutUint32(data[4:8], 10) // zltail
binary.LittleEndian.PutUint16(data[8:10], 0) // zllen
data[10] = 0xFF // zlend
return &ZipList{data: data}
}
// 添加元素到末尾
func (zl *ZipList) Push(data []byte) {
// 计算新元素的大小
entrySize := zl.calculateEntrySize(data)
// 扩展压缩列表
oldLen := len(zl.data)
newLen := oldLen - 1 + entrySize // 减去结束标记,加上新元素
newData := make([]byte, newLen+1) // +1 为新的结束标记
// 复制原有数据
copy(newData, zl.data[:oldLen-1])
// 添加新元素
zl.encodeEntry(newData[oldLen-1:], data)
// 添加结束标记
newData[newLen] = 0xFF
// 更新头部信息
binary.LittleEndian.PutUint32(newData[0:4], uint32(newLen+1)) // zlbytes
binary.LittleEndian.PutUint32(newData[4:8], uint32(oldLen-1)) // zltail
zl.updateLength(newData)
zl.data = newData
}
// 计算元素大小
func (zl *ZipList) calculateEntrySize(data []byte) int {
dataLen := len(data)
// 前一个元素长度字段:1 或 5 字节
prevLenSize := 1
if dataLen >= 254 {
prevLenSize = 5
}
// 编码字段:1 字节
encodingSize := 1
// 数据长度
dataSize := dataLen
return prevLenSize + encodingSize + dataSize
}
// 编码元素
func (zl *ZipList) encodeEntry(dst []byte, data []byte) {
pos := 0
dataLen := len(data)
// 编码前一个元素长度
if dataLen >= 254 {
dst[pos] = 0xFE
binary.LittleEndian.PutUint32(dst[pos+1:pos+5], uint32(dataLen))
pos += 5
} else {
dst[pos] = byte(dataLen)
pos++
}
// 编码数据
if dataLen <= 63 {
// 6位长度 + 数据
dst[pos] = byte(dataLen)
pos++
} else if dataLen <= 16383 {
// 14位长度 + 数据
binary.LittleEndian.PutUint16(dst[pos:pos+2], uint16(dataLen)|0x4000)
pos += 2
} else {
// 32位长度 + 数据
binary.LittleEndian.PutUint32(dst[pos:pos+4], uint32(dataLen)|0x80000000)
pos += 4
}
// 复制数据
copy(dst[pos:], data)
}
// 更新长度字段
func (zl *ZipList) updateLength(data []byte) {
// 计算元素数量
count := 0
pos := 10 // 跳过头部
for pos < len(data)-1 {
// 读取前一个元素长度
if data[pos] == 0xFE {
pos += 5
} else {
pos++
}
// 读取编码
if data[pos]&0x80 == 0 {
pos++
} else if data[pos]&0x40 == 0 {
pos += 2
} else {
pos += 4
}
// 跳过数据
dataLen := zl.getDataLength(data, pos-1)
pos += dataLen
count++
}
binary.LittleEndian.PutUint16(data[8:10], uint16(count))
}
// 获取数据长度
func (zl *ZipList) getDataLength(data []byte, encodingPos int) int {
encoding := data[encodingPos]
if encoding&0x80 == 0 {
return int(encoding)
} else if encoding&0x40 == 0 {
return int(binary.LittleEndian.Uint16(data[encodingPos:encodingPos+2]) & 0x3FFF)
} else {
return int(binary.LittleEndian.Uint32(data[encodingPos:encodingPos+4]) & 0x3FFFFFFF)
}
}
// 遍历元素
func (zl *ZipList) ForEach(fn func([]byte) bool) {
pos := 10 // 跳过头部
for pos < len(zl.data)-1 {
// 读取前一个元素长度
if zl.data[pos] == 0xFE {
pos += 5
} else {
pos++
}
// 读取编码
encodingPos := pos
if zl.data[pos]&0x80 == 0 {
pos++
} else if zl.data[pos]&0x40 == 0 {
pos += 2
} else {
pos += 4
}
// 获取数据
dataLen := zl.getDataLength(zl.data, encodingPos)
data := zl.data[pos : pos+dataLen]
if !fn(data) {
break
}
pos += dataLen
}
}
// 获取长度
func (zl *ZipList) Len() int {
return int(binary.LittleEndian.Uint16(zl.data[8:10]))
}
测试验证
1. SDS 测试
// sds/sds_test.go
package sds
import (
"testing"
)
func TestSDS(t *testing.T) {
sds := NewSDS("Hello")
// 测试长度
if sds.Len() != 5 {
t.Errorf("Expected length 5, got %d", sds.Len())
}
// 测试追加
sds.Append(" World")
if sds.String() != "Hello World" {
t.Errorf("Expected 'Hello World', got '%s'", sds.String())
}
// 测试截取
substr := sds.Substr(0, 4)
if substr.String() != "Hello" {
t.Errorf("Expected 'Hello', got '%s'", substr.String())
}
}
2. 字典测试
// dict/dict_test.go
package dict
import (
"testing"
)
func TestDict(t *testing.T) {
dict := NewDict()
// 测试设置和获取
dict.Set("key1", "value1")
dict.Set("key2", "value2")
if value, ok := dict.Get("key1"); !ok || value != "value1" {
t.Errorf("Expected 'value1', got %v", value)
}
// 测试删除
if !dict.Delete("key1") {
t.Error("Failed to delete key1")
}
if _, ok := dict.Get("key1"); ok {
t.Error("Key1 should be deleted")
}
}
3. 跳跃表测试
// skiplist/skiplist_test.go
package skiplist
import (
"testing"
)
func TestSkipList(t *testing.T) {
sl := NewSkipList()
// 测试插入
sl.Insert("key1", "value1", 1.0)
sl.Insert("key2", "value2", 2.0)
sl.Insert("key3", "value3", 1.5)
// 测试查找
if value, ok := sl.Find("key1", 1.0); !ok || value != "value1" {
t.Errorf("Expected 'value1', got %v", value)
}
// 测试范围查询
results := sl.RangeByScore(1.0, 2.0)
if len(results) != 3 {
t.Errorf("Expected 3 results, got %d", len(results))
}
}
性能分析
时间复杂度对比
操作 | SDS | 字典 | 跳跃表 | 压缩列表 |
---|---|---|---|---|
查找 | O(1) | O(1) | O(log n) | O(n) |
插入 | O(1) | O(1) | O(log n) | O(n) |
删除 | O(1) | O(1) | O(log n) | O(n) |
范围查询 | - | - | O(log n + k) | O(n) |
内存使用分析
// benchmark/memory_benchmark.go
package benchmark
import (
"testing"
"unsafe"
)
func BenchmarkSDSMemory(b *testing.B) {
for i := 0; i < b.N; i++ {
sds := NewSDS("Hello World")
_ = sds.String()
}
}
func BenchmarkDictMemory(b *testing.B) {
for i := 0; i < b.N; i++ {
dict := NewDict()
dict.Set("key", "value")
_, _ = dict.Get("key")
}
}
func BenchmarkSkipListMemory(b *testing.B) {
for i := 0; i < b.N; i++ {
sl := NewSkipList()
sl.Insert("key", "value", 1.0)
_, _ = sl.Find("key", 1.0)
}
}
面试要点
1. 为什么 Redis 选择这些数据结构?
答案要点:
- SDS:解决 C 字符串的局限性,提供二进制安全和高效操作
- 字典:哈希表提供 O(1) 的查找性能,渐进式 rehash 避免阻塞
- 跳跃表:比平衡树简单,支持范围查询,适合有序集合
- 压缩列表:内存紧凑,适合小数据量的列表和哈希
2. 渐进式 rehash 的优势
答案要点:
- 避免阻塞:将 rehash 分散到多次操作中
- 内存控制:避免大量内存分配
- 性能稳定:保持操作的响应时间稳定
3. 跳跃表 vs 平衡树
答案要点:
- 实现简单:跳跃表比平衡树更容易实现
- 并发友好:跳跃表更容易支持并发操作
- 范围查询:跳跃表天然支持范围查询
- 内存开销:跳跃表可能有额外的指针开销
总结
通过本章学习,我们深入理解了:
- Redis 核心数据结构的设计原理和实现
- 性能优化技巧和内存管理策略
- 数据结构选择对系统性能的影响
- 实际代码实现和测试验证
这些底层数据结构为 Redis 的高性能提供了坚实的基础。在下一章中,我们将学习字符串和 SDS 的详细实现,了解 Redis 如何优化字符串操作。