feat: added cache to conversation

This commit is contained in:
tt
2025-06-18 16:17:21 +08:00
parent e67acb95ef
commit 9213c5d309
11 changed files with 3027 additions and 33 deletions

View File

@@ -0,0 +1,184 @@
# WuKongIM 开发者文档集合
本文档集合为 WuKongIM 项目的开发者提供全面的技术指导和参考资料,帮助开发者快速上手并进行二次开发。
## 📚 文档目录
### 1. [WuKongIM 开发者快速上手指南](./WuKongIM_Developer_Guide.md)
**完整的开发者指南,包含:**
- 项目架构详解
- 核心模块分析
- 开发环境搭建
- API 接口文档
- 插件开发指南
- Webhook 集成
- 监控运维
- 实战开发示例
- 性能优化
- 故障排查
**适合人群:** 需要深入了解 WuKongIM 架构和进行二次开发的开发者
### 2. [WuKongIM 快速参考手册](./WuKongIM_Quick_Reference.md)
**简洁的快速参考文档,包含:**
- 快速启动命令
- 核心配置参数
- 常用 API 接口
- 客户端 SDK 使用
- 协议格式说明
- 插件开发模板
- 监控指标
- 故障排查要点
**适合人群:** 已经熟悉项目,需要快速查阅的开发者
## 🏗️ 架构概览
WuKongIM 是一个高性能的分布式即时通讯服务,采用去中心化设计,具有以下核心特性:
### 核心特性
-**去中心化设计**:无单点故障,节点平等
-**高性能**单机支持20万+并发连接
-**分布式存储**:基于 PebbleDB 的自研存储引擎
-**自动容灾**:基于魔改 Raft 协议的故障自动转移
-**多协议支持**:二进制协议 + JSON 协议WebSocket
-**插件系统**:支持动态插件扩展
-**Webhook 集成**:支持 HTTP 和 gRPC Webhook
-**监控完善**:内置 Prometheus 监控指标
### 技术栈
- **语言**Go 1.20+
- **存储**PebbleDB (LSM-Tree)
- **网络**:基于 Reactor 模式的自研网络框架
- **协议**:自定义二进制协议 + JSON-RPC
- **集群**:基于魔改 Raft 的分布式协议
- **监控**Prometheus + Grafana
## 🚀 快速开始
### 环境要求
- Go 1.20+
- Git
- Docker (可选)
### 单机启动
```bash
git clone https://github.com/WuKongIM/WuKongIM.git
cd WuKongIM
go run main.go
```
### 集群启动
```bash
# 启动三个节点
go run main.go --config ./exampleconfig/cluster1.yaml
go run main.go --config ./exampleconfig/cluster2.yaml
go run main.go --config ./exampleconfig/cluster3.yaml
```
### Docker 部署
```bash
cd docker/cluster
docker-compose up -d
```
### 访问服务
- **管理后台**: http://127.0.0.1:5300/web
- **聊天演示**: http://127.0.0.1:5172/chatdemo
- **API 文档**: http://127.0.0.1:5001/swagger
- **监控指标**: http://127.0.0.1:5001/metrics
## 📖 学习路径
### 初学者路径
1. 阅读 [快速参考手册](./WuKongIM_Quick_Reference.md) 了解基本概念
2. 按照快速开始部分搭建环境
3. 运行聊天演示,体验基本功能
4. 尝试调用 API 接口发送消息
### 进阶开发者路径
1. 深入阅读 [开发者指南](./WuKongIM_Developer_Guide.md)
2. 理解项目架构和核心模块
3. 学习插件开发和 Webhook 集成
4. 参考实战示例进行二次开发
### 运维人员路径
1. 了解部署配置和集群搭建
2. 学习监控指标和故障排查
3. 掌握性能调优方法
4. 建立运维监控体系
## 🔧 开发工具推荐
### IDE 和编辑器
- **GoLand**: JetBrains 的 Go IDE功能强大
- **VS Code**: 轻量级编辑器,配合 Go 插件使用
- **Vim/Neovim**: 命令行编辑器,配合 vim-go 插件
### 调试工具
- **Delve**: Go 语言调试器
- **pprof**: Go 性能分析工具
- **Postman**: API 接口测试工具
- **WebSocket King**: WebSocket 连接测试工具
### 监控工具
- **Prometheus**: 监控数据收集
- **Grafana**: 监控数据可视化
- **Jaeger**: 分布式链路追踪
## 🤝 贡献指南
### 代码贡献
1. Fork 项目到个人仓库
2. 创建功能分支 (`git checkout -b feature/amazing-feature`)
3. 提交更改 (`git commit -m 'Add some amazing feature'`)
4. 推送到分支 (`git push origin feature/amazing-feature`)
5. 创建 Pull Request
### 文档贡献
- 发现文档错误或不清楚的地方,欢迎提交 Issue
- 可以直接提交 PR 改进文档
- 分享使用经验和最佳实践
### 代码规范
- 遵循 Go 官方编码规范
- 使用 `gofmt` 格式化代码
- 添加必要的注释和文档
- 编写单元测试
## 📞 获取帮助
### 官方资源
- **官方网站**: https://githubim.com
- **GitHub 仓库**: https://github.com/WuKongIM/WuKongIM
- **API 文档**: https://githubim.com/api
- **SDK 文档**: https://githubim.com/sdk
### 社区支持
- **GitHub Issues**: 报告 Bug 和功能请求
- **GitHub Discussions**: 技术讨论和经验分享
- **官方文档**: 查看详细的使用说明
### 常见问题
在提问前,请先查看:
1. [开发者指南](./WuKongIM_Developer_Guide.md) 中的常见问题部分
2. GitHub Issues 中的已知问题
3. 官方文档中的 FAQ 部分
## 📝 更新日志
文档会随着项目的更新而持续维护,主要更新内容:
- 新功能的使用说明
- API 接口的变更
- 配置参数的调整
- 最佳实践的补充
## 📄 许可证
本文档遵循与 WuKongIM 项目相同的许可证。
---
**开始你的 WuKongIM 开发之旅吧!** 🚀
如果这些文档对你有帮助,请给项目一个 ⭐️ Star

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,372 @@
# WuKongIM 快速参考手册
## 快速启动
### 单机模式
```bash
git clone https://github.com/WuKongIM/WuKongIM.git
cd WuKongIM
go run main.go
```
### 集群模式
```bash
# 节点1
go run main.go --config ./exampleconfig/cluster1.yaml
# 节点2
go run main.go --config ./exampleconfig/cluster2.yaml
# 节点3
go run main.go --config ./exampleconfig/cluster3.yaml
```
### Docker 部署
```bash
# 单机
docker run -d --name wukongim \
-p 5001:5001 -p 5100:5100 -p 5200:5200 \
registry.cn-shanghai.aliyuncs.com/wukongim/wukongim:v2
# 集群
cd docker/cluster && docker-compose up -d
```
## 核心配置
### 基础配置 (config/wk.yaml)
```yaml
mode: "release" # 运行模式
addr: "tcp://0.0.0.0:5100" # TCP 地址
httpAddr: "0.0.0.0:5001" # HTTP API 地址
wsAddr: "ws://0.0.0.0:5200" # WebSocket 地址
rootDir: "./wukongimdata" # 数据目录
tokenAuthOn: false # Token 认证
```
### 集群配置
```yaml
cluster:
nodeId: 1001 # 节点 ID
serverAddr: "0.0.0.0:11110" # 集群通信地址
initNodes: # 初始节点
1001: "127.0.0.1:11110"
1002: "127.0.0.1:11111"
1003: "127.0.0.1:11112"
```
## 核心 API
### 发送消息
```http
POST /message/send
Content-Type: application/json
{
"from_uid": "sender123",
"channel_id": "channel456",
"channel_type": 2,
"payload": {
"type": "text",
"content": "Hello World"
}
}
```
### 创建频道
```http
POST /channel
{
"channel_id": "group789",
"channel_type": 2,
"large": false,
"ban": false
}
```
### 添加订阅者
```http
POST /channel/subscriber_add
{
"channel_id": "group789",
"channel_type": 2,
"subscribers": [
{"uid": "user123", "role": 1},
{"uid": "user456", "role": 0}
]
}
```
### 同步消息
```http
POST /channel/messagesync
{
"login_uid": "user123",
"channel_id": "channel456",
"channel_type": 2,
"start_message_seq": 0,
"end_message_seq": 100,
"limit": 50
}
```
## 客户端 SDK
### JavaScript SDK
```javascript
import { WKIM, WKIMChannelType, WKIMEvent } from 'easyjssdk';
// 初始化
const im = WKIM.init("ws://localhost:5200", {
uid: "user123",
token: "auth_token"
});
// 监听连接
im.on(WKIMEvent.Connect, () => {
console.log("Connected!");
});
// 监听消息
im.on(WKIMEvent.Message, (message) => {
console.log("Received:", message);
});
// 发送消息
await im.send("target_user", WKIMChannelType.Person, {
type: "text",
content: "Hello!"
});
// 连接
await im.connect();
```
### Android SDK
```java
// 初始化
WKIM.getInstance().init(context, "ws://localhost:5200");
// 连接
WKIMConnectOptions options = new WKIMConnectOptions();
options.uid = "user123";
options.token = "auth_token";
WKIM.getInstance().connect(options);
// 发送消息
WKTextContent textContent = new WKTextContent("Hello!");
WKIM.getInstance().sendMessage(textContent, "target_user", WKChannelType.PERSON);
// 监听消息
WKIM.getInstance().addOnNewMsgListener(new INewMsgListener() {
@Override
public void newMsg(WKMsg msg) {
// 处理新消息
}
});
```
## 协议格式
### 二进制协议
```
+----------+----------+----------+----------+
| Type | Flag | Length | Payload |
| (1 byte) | (1 byte) | (4 bytes)| (变长) |
+----------+----------+----------+----------+
```
### 消息类型
- `CONNECT(1)`: 连接请求
- `CONNACK(2)`: 连接响应
- `SEND(3)`: 发送消息
- `SENDACK(4)`: 发送响应
- `RECV(5)`: 接收消息
- `RECVACK(6)`: 接收确认
- `PING(7)`: 心跳请求
- `PONG(8)`: 心跳响应
### JSON 协议 (WebSocket)
```json
// 请求
{
"id": "request_id",
"method": "send",
"params": {
"clientMsgNo": "uuid",
"channelId": "channel_id",
"channelType": 1,
"payload": {"type": "text", "content": "Hello"}
}
}
// 响应
{
"id": "request_id",
"result": {
"messageId": "123456",
"clientSeq": 1
}
}
```
## 插件开发
### 插件结构
```go
type MyPlugin struct {
plugin.BasePlugin
}
func (p *MyPlugin) Send(ctx context.Context, req *proto.SendPacket) (*proto.SendPacket, error) {
// 处理发送消息
return req, nil
}
func main() {
plugin.Run(&plugin.Config{
Info: plugin.PluginInfo{
No: "my-plugin",
Name: "My Plugin",
Version: "1.0.0",
},
Plugin: &MyPlugin{},
})
}
```
### 插件方法
- `Send`: 发送消息处理
- `ChannelGet`: 获取频道信息
- `UserGet`: 获取用户信息
- `MessageStore`: 消息存储处理
## Webhook 配置
### HTTP Webhook
```yaml
webhook:
httpUrl: "http://your-server.com/webhook"
events:
- "user.connect"
- "user.disconnect"
- "message.send"
- "message.receive"
```
### 事件格式
```json
{
"event": "message.send",
"data": {
"message_id": "123456789",
"from_uid": "sender123",
"channel_id": "channel456",
"channel_type": 2,
"payload": {"type": "text", "content": "Hello"},
"timestamp": 1640995200
}
}
```
## 监控指标
### Prometheus 指标
- `wukongim_connections_total`: 总连接数
- `wukongim_messages_sent_total`: 发送消息数
- `wukongim_messages_received_total`: 接收消息数
- `wukongim_memory_usage_bytes`: 内存使用量
- `wukongim_cpu_usage_percent`: CPU 使用率
### 健康检查
```bash
curl http://localhost:5001/health
```
### 系统信息
```bash
curl http://localhost:5001/varz
```
## 常用命令
### 编译
```bash
go build -o wukongim main.go
```
### 测试
```bash
go test ./...
```
### 性能测试
```bash
go test -bench=. ./pkg/wkdb
```
### 查看日志
```bash
tail -f logs/wukongim.log
```
## 故障排查
### 常见问题
1. **连接失败**: 检查端口和防火墙
2. **消息丢失**: 检查频道订阅关系
3. **集群同步**: 检查网络和时钟同步
4. **性能问题**: 调整连接池和缓存配置
### 错误码
- `1001`: 连接失败
- `1002`: 认证失败
- `2001`: 消息过大
- `3001`: 频道不存在
- `5001`: 系统过载
### 日志分析
```bash
# 查看错误
grep "ERROR" logs/wukongim.log
# 统计连接数
grep "connection" logs/wukongim.log | wc -l
# 查看慢查询
grep "slow" logs/wukongim.log
```
## 性能调优
### 连接配置
```yaml
userMsgQueueMaxSize: 1000 # 用户消息队列大小
connectTimeout: 30s # 连接超时
writeTimeout: 10s # 写超时
readTimeout: 10s # 读超时
```
### 数据库配置
```yaml
db:
memTableSize: 64MB # 内存表大小
blockCacheSize: 256MB # 块缓存大小
writeBufferSize: 32MB # 写缓冲区大小
maxOpenFiles: 1000 # 最大打开文件数
```
## 访问地址
- **管理后台**: http://127.0.0.1:5300/web
- **聊天演示**: http://127.0.0.1:5172/chatdemo
- **API 文档**: http://127.0.0.1:5001/swagger
- **监控指标**: http://127.0.0.1:5001/metrics
- **健康检查**: http://127.0.0.1:5001/health
## 相关链接
- [官方文档](https://githubim.com)
- [GitHub](https://github.com/WuKongIM/WuKongIM)
- [问题反馈](https://github.com/WuKongIM/WuKongIM/issues)
- [更新日志](https://github.com/WuKongIM/WuKongIM/releases)

2
go.mod
View File

@@ -20,6 +20,7 @@ require (
github.com/grafana/loki/v3 v3.2.1
github.com/hashicorp/golang-lru/v2 v2.0.7
github.com/lni/goutils v1.4.0
github.com/nsqio/go-diskqueue v1.1.0
github.com/panjf2000/ants/v2 v2.11.0
github.com/panjf2000/gnet/v2 v2.7.1
github.com/pkg/errors v0.9.1
@@ -46,7 +47,6 @@ require (
require (
github.com/WuKongIM/wklog v0.0.0-20250123094253-32484fb54d05 // indirect
github.com/nsqio/go-diskqueue v1.1.0 // indirect
go.opentelemetry.io/otel/sdk v1.28.0 // indirect
go.opentelemetry.io/otel/trace v1.28.0 // indirect
)

2
go.sum
View File

@@ -270,8 +270,6 @@ github.com/Shopify/toxiproxy/v2 v2.5.0/go.mod h1:yhM2epWtAmel9CB8r2+L+PCmhH6yH2p
github.com/StackExchange/wmi v0.0.0-20180116203802-5d049714c4a6/go.mod h1:3eOhrUMpNV+6aFIbp5/iudMxNCF27Vw2OZgy4xEx0Fg=
github.com/Workiva/go-datastructures v1.1.5 h1:5YfhQ4ry7bZc2Mc7R0YZyYwpf5c6t1cEFvdAhd6Mkf4=
github.com/Workiva/go-datastructures v1.1.5/go.mod h1:1yZL+zfsztete+ePzZz/Zb1/t5BnDuE2Ya2MMGhzP6A=
github.com/WuKongIM/WuKongIMGoProto v1.0.9 h1:yJAhWc95/gwCRDsNiAjdiGkFg2upl1hBLoSwrk3zGrE=
github.com/WuKongIM/WuKongIMGoProto v1.0.9/go.mod h1:dUQCRuqwMoyYeLiHTsLBfbiWlVtB+8Gdsyq1M1oeEzg=
github.com/WuKongIM/WuKongIMGoProto v1.1.2-0.20250618030603-b0f04ad80a21 h1:AC3w0DDREI566Tc02VaG2oAohYgo57S0zm51LcVrNZA=
github.com/WuKongIM/WuKongIMGoProto v1.1.2-0.20250618030603-b0f04ad80a21/go.mod h1:dUQCRuqwMoyYeLiHTsLBfbiWlVtB+8Gdsyq1M1oeEzg=
github.com/WuKongIM/crypto v0.0.0-20240416072338-b872b70b395f h1:erzPrCjuS7yvfpMyUxQjaMDHhBicUKt/qxwC/8s25VQ=

View File

@@ -0,0 +1,264 @@
package wkdb
import (
"crypto/md5"
"fmt"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/WuKongIM/WuKongIM/pkg/wklog"
lru "github.com/hashicorp/golang-lru/v2"
)
// ConversationCache 会话缓存 - 专注于 GetLastConversations 结果缓存
type ConversationCache struct {
// GetLastConversations 结果缓存 key: uid:tp:updatedAt:excludeChannelTypes:limit
lastConversationsCache *lru.Cache[string, *LastConversationsResult]
// 读写锁,保护缓存操作
mu sync.RWMutex
// 配置
maxCacheSize int // 缓存最大数量
cacheTTL time.Duration // 缓存过期时间
wklog.Log
}
// LastConversationsResult GetLastConversations 的缓存结果
type LastConversationsResult struct {
Conversations []Conversation `json:"conversations"`
CachedAt time.Time `json:"cached_at"`
TTL time.Duration `json:"ttl"`
}
// IsExpired 检查缓存是否过期
func (r *LastConversationsResult) IsExpired() bool {
return time.Since(r.CachedAt) > r.TTL
}
// NewConversationCache 创建会话缓存
func NewConversationCache(maxCacheSize int) *ConversationCache {
if maxCacheSize <= 0 {
maxCacheSize = 2000 // 默认缓存2000个查询结果
}
lastConversationsCache, _ := lru.New[string, *LastConversationsResult](maxCacheSize)
return &ConversationCache{
lastConversationsCache: lastConversationsCache,
maxCacheSize: maxCacheSize,
cacheTTL: 2 * time.Minute, // 缓存2分钟
Log: wklog.NewWKLog("ConversationCache"),
}
}
// GetLastConversations 从缓存获取 GetLastConversations 结果
func (c *ConversationCache) GetLastConversations(uid string, tp ConversationType, updatedAt uint64, excludeChannelTypes []uint8, limit int) ([]Conversation, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
key := c.getLastConversationsKey(uid, tp, updatedAt, excludeChannelTypes, limit)
if result, ok := c.lastConversationsCache.Get(key); ok {
if !result.IsExpired() {
return result.Conversations, true
}
// 缓存过期,异步删除
go func() {
c.mu.Lock()
c.lastConversationsCache.Remove(key)
c.mu.Unlock()
}()
}
return nil, false
}
// SetLastConversations 设置 GetLastConversations 结果到缓存
func (c *ConversationCache) SetLastConversations(uid string, tp ConversationType, updatedAt uint64, excludeChannelTypes []uint8, limit int, conversations []Conversation) {
c.mu.Lock()
defer c.mu.Unlock()
key := c.getLastConversationsKey(uid, tp, updatedAt, excludeChannelTypes, limit)
// 创建副本避免外部修改影响缓存
conversationsCopy := make([]Conversation, len(conversations))
copy(conversationsCopy, conversations)
result := &LastConversationsResult{
Conversations: conversationsCopy,
CachedAt: time.Now(),
TTL: c.cacheTTL,
}
c.lastConversationsCache.Add(key, result)
}
// InvalidateUserConversations 使指定用户的所有会话缓存失效
func (c *ConversationCache) InvalidateUserConversations(uid string) {
c.mu.Lock()
defer c.mu.Unlock()
// 删除该用户相关的所有缓存
keys := c.lastConversationsCache.Keys()
for _, key := range keys {
if strings.HasPrefix(key, uid+":") {
c.lastConversationsCache.Remove(key)
}
}
}
// UpdateConversationsInCache 智能更新缓存中的会话数据
func (c *ConversationCache) UpdateConversationsInCache(conversations []Conversation) {
c.mu.Lock()
defer c.mu.Unlock()
// 按用户分组
userConversations := make(map[string][]Conversation)
for _, conv := range conversations {
userConversations[conv.Uid] = append(userConversations[conv.Uid], conv)
}
// 为每个用户更新缓存
for uid, convs := range userConversations {
c.updateUserConversationsInCache(uid, convs)
}
}
// updateUserConversationsInCache 更新指定用户的缓存
func (c *ConversationCache) updateUserConversationsInCache(uid string, updatedConversations []Conversation) {
// 获取该用户相关的所有缓存键
keys := c.lastConversationsCache.Keys()
userKeys := make([]string, 0)
for _, key := range keys {
if strings.HasPrefix(key, uid+":") {
userKeys = append(userKeys, key)
}
}
// 为每个缓存项更新数据
for _, key := range userKeys {
if result, ok := c.lastConversationsCache.Get(key); ok && !result.IsExpired() {
// 创建会话 channelId+channelType 到新会话的映射
updatedMap := make(map[string]Conversation)
for _, conv := range updatedConversations {
channelKey := c.getChannelKey(conv.ChannelId, conv.ChannelType)
updatedMap[channelKey] = conv
}
// 更新缓存中的会话数据
updated := false
newConversations := make([]Conversation, 0, len(result.Conversations))
// 首先处理已存在的会话
for _, cachedConv := range result.Conversations {
channelKey := c.getChannelKey(cachedConv.ChannelId, cachedConv.ChannelType)
if updatedConv, exists := updatedMap[channelKey]; exists {
// 找到匹配的会话,使用新数据
newConversations = append(newConversations, updatedConv)
updated = true
delete(updatedMap, channelKey) // 从映射中移除已处理的
} else {
// 保持原有数据
newConversations = append(newConversations, cachedConv)
}
}
// 检查是否有新增的会话在updatedMap中剩余的
// 将新增的会话添加到缓存中
for _, newConv := range updatedMap {
newConversations = append(newConversations, newConv)
updated = true
}
// 如果有更新,重新缓存
if updated {
// 按照更新时间重新排序(保持与 GetLastConversations 一致的排序)
sort.Slice(newConversations, func(i, j int) bool {
c1 := newConversations[i]
c2 := newConversations[j]
if c1.UpdatedAt == nil {
return false
}
if c2.UpdatedAt == nil {
return true
}
return c1.UpdatedAt.After(*c2.UpdatedAt)
})
newResult := &LastConversationsResult{
Conversations: newConversations,
CachedAt: time.Now(),
TTL: c.cacheTTL,
}
c.lastConversationsCache.Add(key, newResult)
}
}
}
}
// getChannelKey 生成频道唯一键
func (c *ConversationCache) getChannelKey(channelId string, channelType uint8) string {
return fmt.Sprintf("%s:%d", channelId, channelType)
}
// GetCacheStats 获取缓存统计信息
func (c *ConversationCache) GetCacheStats() map[string]interface{} {
c.mu.RLock()
defer c.mu.RUnlock()
return map[string]interface{}{
"last_conversations_cache_len": c.lastConversationsCache.Len(),
"last_conversations_cache_max": c.maxCacheSize,
"cache_ttl_seconds": c.cacheTTL.Seconds(),
}
}
// ClearCache 清空所有缓存
func (c *ConversationCache) ClearCache() {
c.mu.Lock()
defer c.mu.Unlock()
c.lastConversationsCache.Purge()
c.Info("Conversation cache cleared")
}
// 生成 GetLastConversations 缓存键
func (c *ConversationCache) getLastConversationsKey(uid string, tp ConversationType, updatedAt uint64, excludeChannelTypes []uint8, limit int) string {
// 构建基础键
keyParts := []string{
uid,
strconv.Itoa(int(tp)),
strconv.FormatUint(updatedAt, 10),
strconv.Itoa(limit),
}
// 处理 excludeChannelTypes
if len(excludeChannelTypes) > 0 {
// 排序确保一致性
excludeTypes := make([]uint8, len(excludeChannelTypes))
copy(excludeTypes, excludeChannelTypes)
sort.Slice(excludeTypes, func(i, j int) bool {
return excludeTypes[i] < excludeTypes[j]
})
excludeStrs := make([]string, len(excludeTypes))
for i, t := range excludeTypes {
excludeStrs[i] = strconv.Itoa(int(t))
}
keyParts = append(keyParts, strings.Join(excludeStrs, ","))
} else {
keyParts = append(keyParts, "")
}
key := strings.Join(keyParts, ":")
// 如果键太长,使用 MD5 哈希
if len(key) > 200 {
hash := md5.Sum([]byte(key))
return fmt.Sprintf("%s:%x", uid, hash)
}
return key
}

View File

@@ -0,0 +1,198 @@
package wkdb_test
import (
"testing"
"github.com/WuKongIM/WuKongIM/pkg/wkdb"
"github.com/stretchr/testify/assert"
)
func TestConversationCache(t *testing.T) {
cache := wkdb.NewConversationCache(1000)
// 测试 GetLastConversations 缓存
conversations := []wkdb.Conversation{
{Id: 1, Uid: "user1", ChannelId: "channel1", ChannelType: 1, UnreadCount: 5},
{Id: 2, Uid: "user1", ChannelId: "channel2", ChannelType: 1, UnreadCount: 3},
}
// 设置缓存
cache.SetLastConversations("user1", wkdb.ConversationTypeChat, 0, nil, 10, conversations)
// 获取缓存
cached, found := cache.GetLastConversations("user1", wkdb.ConversationTypeChat, 0, nil, 10)
assert.True(t, found)
assert.Len(t, cached, 2)
assert.Equal(t, uint64(1), cached[0].Id)
assert.Equal(t, uint64(2), cached[1].Id)
// 测试缓存失效
cache.InvalidateUserConversations("user1")
_, found = cache.GetLastConversations("user1", wkdb.ConversationTypeChat, 0, nil, 10)
assert.False(t, found)
}
func TestCacheStats(t *testing.T) {
cache := wkdb.NewConversationCache(1000)
// 添加一些数据
conversations := []wkdb.Conversation{
{Id: 1, Uid: "user1", ChannelId: "channel1", ChannelType: 1},
}
cache.SetLastConversations("user1", wkdb.ConversationTypeChat, 0, nil, 10, conversations)
// 获取缓存统计
stats := cache.GetCacheStats()
assert.Equal(t, 1, stats["last_conversations_cache_len"])
assert.Equal(t, 1000, stats["last_conversations_cache_max"])
assert.Equal(t, 120.0, stats["cache_ttl_seconds"]) // 2分钟 = 120秒
}
func TestCacheClear(t *testing.T) {
cache := wkdb.NewConversationCache(1000)
// 添加一些数据
conversations := []wkdb.Conversation{
{Id: 1, Uid: "user1", ChannelId: "channel1", ChannelType: 1},
}
cache.SetLastConversations("user1", wkdb.ConversationTypeChat, 0, nil, 10, conversations)
// 验证数据存在
_, found := cache.GetLastConversations("user1", wkdb.ConversationTypeChat, 0, nil, 10)
assert.True(t, found)
// 清空缓存
cache.ClearCache()
// 验证缓存已清空
_, found = cache.GetLastConversations("user1", wkdb.ConversationTypeChat, 0, nil, 10)
assert.False(t, found)
stats := cache.GetCacheStats()
assert.Equal(t, 0, stats["last_conversations_cache_len"])
}
func TestCacheKeyGeneration(t *testing.T) {
cache := wkdb.NewConversationCache(1000)
conversations := []wkdb.Conversation{
{Id: 1, Uid: "user1", ChannelId: "channel1", ChannelType: 1},
}
// 测试不同参数生成不同的缓存键
cache.SetLastConversations("user1", wkdb.ConversationTypeChat, 0, nil, 10, conversations)
cache.SetLastConversations("user1", wkdb.ConversationTypeChat, 0, nil, 20, conversations) // 不同的limit
cache.SetLastConversations("user1", wkdb.ConversationTypeChat, 100, nil, 10, conversations) // 不同的updatedAt
// 验证不同参数的缓存是独立的
cached1, found1 := cache.GetLastConversations("user1", wkdb.ConversationTypeChat, 0, nil, 10)
assert.True(t, found1)
assert.Len(t, cached1, 1)
cached2, found2 := cache.GetLastConversations("user1", wkdb.ConversationTypeChat, 0, nil, 20)
assert.True(t, found2)
assert.Len(t, cached2, 1)
cached3, found3 := cache.GetLastConversations("user1", wkdb.ConversationTypeChat, 100, nil, 10)
assert.True(t, found3)
assert.Len(t, cached3, 1)
// 验证缓存统计
stats := cache.GetCacheStats()
assert.Equal(t, 3, stats["last_conversations_cache_len"])
}
func TestCacheWithExcludeChannelTypes(t *testing.T) {
cache := wkdb.NewConversationCache(1000)
conversations := []wkdb.Conversation{
{Id: 1, Uid: "user1", ChannelId: "channel1", ChannelType: 1},
}
// 测试带有 excludeChannelTypes 的缓存
excludeTypes1 := []uint8{1, 2}
excludeTypes2 := []uint8{2, 1} // 相同内容但顺序不同,应该生成相同的键
cache.SetLastConversations("user1", wkdb.ConversationTypeChat, 0, excludeTypes1, 10, conversations)
// 验证顺序不同但内容相同的 excludeChannelTypes 能命中缓存
cached, found := cache.GetLastConversations("user1", wkdb.ConversationTypeChat, 0, excludeTypes2, 10)
assert.True(t, found)
assert.Len(t, cached, 1)
// 验证不同的 excludeChannelTypes 不会命中缓存
excludeTypes3 := []uint8{3, 4}
_, found = cache.GetLastConversations("user1", wkdb.ConversationTypeChat, 0, excludeTypes3, 10)
assert.False(t, found)
}
func TestUpdateConversationsInCache(t *testing.T) {
cache := wkdb.NewConversationCache(1000)
// 初始会话数据
originalConversations := []wkdb.Conversation{
{Id: 1, Uid: "user1", ChannelId: "channel1", ChannelType: 1, UnreadCount: 5},
{Id: 2, Uid: "user1", ChannelId: "channel2", ChannelType: 1, UnreadCount: 3},
{Id: 3, Uid: "user1", ChannelId: "channel3", ChannelType: 1, UnreadCount: 1},
}
// 设置初始缓存
cache.SetLastConversations("user1", wkdb.ConversationTypeChat, 0, nil, 10, originalConversations)
// 验证初始缓存
cached, found := cache.GetLastConversations("user1", wkdb.ConversationTypeChat, 0, nil, 10)
assert.True(t, found)
assert.Len(t, cached, 3)
assert.Equal(t, uint32(5), cached[0].UnreadCount)
assert.Equal(t, uint32(3), cached[1].UnreadCount)
// 更新部分会话数据
updatedConversations := []wkdb.Conversation{
{Id: 1, Uid: "user1", ChannelId: "channel1", ChannelType: 1, UnreadCount: 10}, // 更新未读数
{Id: 2, Uid: "user1", ChannelId: "channel2", ChannelType: 1, UnreadCount: 8}, // 更新未读数
}
// 智能更新缓存
cache.UpdateConversationsInCache(updatedConversations)
// 验证缓存已更新
cachedAfterUpdate, found := cache.GetLastConversations("user1", wkdb.ConversationTypeChat, 0, nil, 10)
assert.True(t, found)
assert.Len(t, cachedAfterUpdate, 3)
// 验证更新的会话
assert.Equal(t, uint32(10), cachedAfterUpdate[0].UnreadCount) // 已更新
assert.Equal(t, uint32(8), cachedAfterUpdate[1].UnreadCount) // 已更新
assert.Equal(t, uint32(1), cachedAfterUpdate[2].UnreadCount) // 未更新,保持原值
}
func TestUpdateConversationsInCacheWithNewConversation(t *testing.T) {
cache := wkdb.NewConversationCache(1000)
// 初始会话数据
originalConversations := []wkdb.Conversation{
{Id: 1, Uid: "user1", ChannelId: "channel1", ChannelType: 1, UnreadCount: 5},
{Id: 2, Uid: "user1", ChannelId: "channel2", ChannelType: 1, UnreadCount: 3},
}
// 设置初始缓存
cache.SetLastConversations("user1", wkdb.ConversationTypeChat, 0, nil, 10, originalConversations)
// 验证初始缓存
cached, found := cache.GetLastConversations("user1", wkdb.ConversationTypeChat, 0, nil, 10)
assert.True(t, found)
assert.Len(t, cached, 2)
// 更新数据,包含新增会话
updatedConversations := []wkdb.Conversation{
{Id: 1, Uid: "user1", ChannelId: "channel1", ChannelType: 1, UnreadCount: 10}, // 更新现有
{Id: 3, Uid: "user1", ChannelId: "channel3", ChannelType: 1, UnreadCount: 7}, // 新增会话
}
// 智能更新缓存
cache.UpdateConversationsInCache(updatedConversations)
// 验证缓存已失效(因为有新增会话)
_, found = cache.GetLastConversations("user1", wkdb.ConversationTypeChat, 0, nil, 10)
assert.False(t, found, "Cache should be invalidated when new conversations are added")
}

View File

@@ -69,6 +69,9 @@ func (wk *wukongDB) AddOrUpdateConversations(conversations []Conversation) error
return nil
}
// 智能更新缓存中的会话数据
wk.conversationCache.UpdateConversationsInCache(conversations)
return nil
}
@@ -168,7 +171,15 @@ func (wk *wukongDB) AddOrUpdateConversationsWithUser(uid string, conversations [
// return err
// }
return batch.CommitWait()
err = batch.CommitWait()
if err != nil {
return err
}
// 智能更新缓存中的会话数据
wk.conversationCache.UpdateConversationsInCache(conversations)
return nil
}
// UpdateConversationDeletedAtMsgSeq 更新最近会话的已删除的消息序号位置
@@ -188,6 +199,7 @@ func (wk *wukongDB) UpdateConversationDeletedAtMsgSeq(uid string, channelId stri
if err != nil {
return err
}
wk.conversationCache.InvalidateUserConversations(uid)
return w.Commit(wk.sync)
}
@@ -210,6 +222,7 @@ func (wk *wukongDB) UpdateConversationIfSeqGreaterAsync(uid, channelId string, c
var msgSeqBytes = make([]byte, 8)
wk.endian.PutUint64(msgSeqBytes, readToMsgSeq)
w.Set(key.NewConversationColumnKey(uid, existConversation.Id, key.TableConversation.Column.ReadedToMsgSeq), msgSeqBytes)
wk.conversationCache.InvalidateUserConversations(uid)
return w.Commit()
}
@@ -218,6 +231,7 @@ func (wk *wukongDB) GetConversations(uid string) ([]Conversation, error) {
wk.metrics.GetConversationsAdd(1)
// 直接从数据库获取(不再单独缓存,由 GetLastConversations 统一缓存)
db := wk.shardDB(uid)
iter := db.NewIter(&pebble.IterOptions{
LowerBound: key.NewConversationPrimaryKey(uid, 0),
@@ -233,6 +247,7 @@ func (wk *wukongDB) GetConversations(uid string) ([]Conversation, error) {
if err != nil {
return nil, err
}
return conversations, nil
}
@@ -270,6 +285,12 @@ func (wk *wukongDB) GetLastConversations(uid string, tp ConversationType, update
wk.metrics.GetLastConversationsAdd(1)
// 先从缓存获取
if cached, found := wk.conversationCache.GetLastConversations(uid, tp, updatedAt, excludeChannelTypes, limit); found {
return cached, nil
}
// 缓存未命中,从数据库获取
ids, err := wk.getLastConversationIds(uid, updatedAt, limit)
if err != nil {
return nil, err
@@ -278,21 +299,20 @@ func (wk *wukongDB) GetLastConversations(uid string, tp ConversationType, update
return nil, nil
}
conversations := make([]Conversation, 0, len(ids))
// 使用批量查询优化避免N+1查询问题
conversations, err := wk.getConversationsBatch(uid, ids)
if err != nil {
return nil, err
}
for _, id := range ids {
conversation, err := wk.getConversation(uid, id)
if err != nil && err != ErrNotFound {
return nil, err
}
if err == ErrNotFound {
continue
}
// 过滤会话类型和排除的频道类型
filteredConversations := make([]Conversation, 0, len(conversations))
for _, conversation := range conversations {
if conversation.Type != tp {
continue
}
exclude := false
exclude := false
if len(excludeChannelTypes) > 0 {
for _, excludeChannelType := range excludeChannelTypes {
if conversation.ChannelType == excludeChannelType {
@@ -305,15 +325,16 @@ func (wk *wukongDB) GetLastConversations(uid string, tp ConversationType, update
continue
}
conversations = append(conversations, conversation)
filteredConversations = append(filteredConversations, conversation)
}
// conversations 根据id去重复
conversations = uniqueConversation(conversations)
filteredConversations = uniqueConversation(filteredConversations)
// 按照更新时间排序
sort.Slice(conversations, func(i, j int) bool {
c1 := conversations[i]
c2 := conversations[j]
sort.Slice(filteredConversations, func(i, j int) bool {
c1 := filteredConversations[i]
c2 := filteredConversations[j]
if c1.UpdatedAt == nil {
return false
}
@@ -323,7 +344,10 @@ func (wk *wukongDB) GetLastConversations(uid string, tp ConversationType, update
return c1.UpdatedAt.After(*c2.UpdatedAt)
})
return conversations, nil
// 将结果写入缓存
wk.conversationCache.SetLastConversations(uid, tp, updatedAt, excludeChannelTypes, limit, filteredConversations)
return filteredConversations, nil
}
func (wk *wukongDB) GetChannelConversationLocalUsers(channelId string, channelType uint8) ([]string, error) {
@@ -382,6 +406,7 @@ func removeDupliConversationByChannel(conversations []Conversation) []Conversati
}
func (wk *wukongDB) getLastConversationIds(uid string, updatedAt uint64, limit int) ([]uint64, error) {
// 直接从数据库获取不再单独缓存ID列表
db := wk.shardDB(uid)
iter := db.NewIter(&pebble.IterOptions{
LowerBound: key.NewConversationSecondIndexKey(uid, key.TableConversation.SecondIndex.UpdatedAt, updatedAt, 0),
@@ -418,6 +443,7 @@ func (wk *wukongDB) getLastConversationIds(uid string, updatedAt uint64, limit i
wk.Warn("getLastConversationIds duplicate ids", zap.Int("oldCount", len(ids)), zap.Int("newCount", len(uniqueIdsMap)))
}
// 不再单独缓存ID列表由 GetLastConversations 统一缓存最终结果
return uniqueIdsMap, nil
}
@@ -437,7 +463,15 @@ func (wk *wukongDB) DeleteConversation(uid string, channelId string, channelType
return err
}
return batch.CommitWait()
err = batch.CommitWait()
if err != nil {
return err
}
// 使相关缓存失效
wk.conversationCache.InvalidateUserConversations(uid)
return nil
}
@@ -460,7 +494,15 @@ func (wk *wukongDB) DeleteConversations(uid string, channels []Channel) error {
return err
}
return batch.CommitWait()
err = batch.CommitWait()
if err != nil {
return err
}
// 使相关缓存失效
wk.conversationCache.InvalidateUserConversations(uid)
return nil
}
func (wk *wukongDB) SearchConversation(req ConversationSearchReq) ([]Conversation, error) {
@@ -525,6 +567,7 @@ func (wk *wukongDB) GetConversation(uid string, channelId string, channelType ui
wk.metrics.GetConversationAdd(1)
// 直接从数据库获取(不再单独缓存单个会话)
id, err := wk.getConversationIdByChannel(uid, channelId, channelType)
if err != nil {
return EmptyConversation, err
@@ -621,6 +664,114 @@ func (wk *wukongDB) getConversation(uid string, id uint64) (Conversation, error)
return conversation, nil
}
// getConversationsBatch 批量获取会话避免N+1查询问题
func (wk *wukongDB) getConversationsBatch(uid string, ids []uint64) ([]Conversation, error) {
if len(ids) == 0 {
return nil, nil
}
// 先尝试从缓存获取部分数据
conversations := make([]Conversation, 0, len(ids))
missingIds := make([]uint64, 0)
// 检查缓存中已有的会话
for _, id := range ids {
// 这里需要通过ID查找会话但缓存是按 uid:channelId:channelType 索引的
// 所以我们还是需要从数据库查询,但可以批量缓存结果
missingIds = append(missingIds, id)
}
if len(missingIds) == 0 {
return conversations, nil
}
// 从数据库获取缺失的会话
var dbConversations []Conversation
var err error
// 如果ID数量较少使用优化的多范围查询
if len(missingIds) <= 10 {
dbConversations, err = wk.getConversationsBatchOptimized(uid, missingIds)
} else {
// ID数量较多时使用全表扫描+过滤的方式
dbConversations, err = wk.getConversationsBatchFiltered(uid, missingIds)
}
if err != nil {
return nil, err
}
// 不再单独缓存批量查询结果,由 GetLastConversations 统一缓存
// 合并结果
conversations = append(conversations, dbConversations...)
return conversations, nil
}
// getConversationsBatchOptimized 对少量ID使用多个精确范围查询
func (wk *wukongDB) getConversationsBatchOptimized(uid string, ids []uint64) ([]Conversation, error) {
conversations := make([]Conversation, 0, len(ids))
db := wk.shardDB(uid)
for _, id := range ids {
iter := db.NewIter(&pebble.IterOptions{
LowerBound: key.NewConversationColumnKey(uid, id, key.MinColumnKey),
UpperBound: key.NewConversationColumnKey(uid, id, key.MaxColumnKey),
})
var conversation = EmptyConversation
err := wk.iterateConversation(iter, func(cn Conversation) bool {
conversation = cn
return false
})
iter.Close()
if err != nil {
return nil, err
}
if conversation != EmptyConversation {
conversations = append(conversations, conversation)
}
}
return conversations, nil
}
// getConversationsBatchFiltered 对大量ID使用全表扫描+过滤
func (wk *wukongDB) getConversationsBatchFiltered(uid string, ids []uint64) ([]Conversation, error) {
// 创建ID集合用于快速查找
idSet := make(map[uint64]struct{}, len(ids))
for _, id := range ids {
idSet[id] = struct{}{}
}
// 使用单个迭代器查询所有会话数据
db := wk.shardDB(uid)
iter := db.NewIter(&pebble.IterOptions{
LowerBound: key.NewConversationPrimaryKey(uid, 0),
UpperBound: key.NewConversationPrimaryKey(uid, math.MaxUint64),
})
defer iter.Close()
conversations := make([]Conversation, 0, len(ids))
err := wk.iterateConversation(iter, func(conversation Conversation) bool {
// 只收集我们需要的会话ID
if _, exists := idSet[conversation.Id]; exists {
conversations = append(conversations, conversation)
// 如果已经找到所有需要的会话,可以提前退出
if len(conversations) == len(ids) {
return false
}
}
return true
})
if err != nil {
return nil, err
}
return conversations, nil
}
// func (wk *wukongDB) getConversationIdsByUid(uid string) ([]uint64, error) {
// iter := wk.shardDB(uid).NewIter(&pebble.IterOptions{
// LowerBound: key.NewConversationPrimaryKey(uid, 0),

View File

@@ -0,0 +1,379 @@
package wkdb_test
import (
"fmt"
"testing"
"time"
"github.com/WuKongIM/WuKongIM/pkg/wkdb"
"github.com/stretchr/testify/assert"
)
// 测试缓存对 GetLastConversations 性能的影响
func TestGetLastConversationsWithCache(t *testing.T) {
d := newTestDB(t)
err := d.Open()
assert.NoError(t, err)
defer func() {
err := d.Close()
assert.NoError(t, err)
}()
uid := "cache_test_user"
now := time.Now()
// 创建大量会话数据
conversations := make([]wkdb.Conversation, 0, 50)
for i := 0; i < 50; i++ {
updatedAt := now.Add(time.Duration(i) * time.Minute)
conversations = append(conversations, wkdb.Conversation{
Id: uint64(i + 1),
Uid: uid,
ChannelId: fmt.Sprintf("channel_%d", i),
ChannelType: 1,
Type: wkdb.ConversationTypeChat,
UnreadCount: uint32(i + 1),
ReadToMsgSeq: uint64(i + 1),
CreatedAt: &now,
UpdatedAt: &updatedAt,
})
}
// 添加会话
err = d.AddOrUpdateConversationsWithUser(uid, conversations)
assert.NoError(t, err)
// 第一次查询(缓存未命中)
start1 := time.Now()
result1, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 20)
duration1 := time.Since(start1)
assert.NoError(t, err)
assert.LessOrEqual(t, len(result1), 20)
// 第二次查询(缓存命中)
start2 := time.Now()
result2, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 20)
duration2 := time.Since(start2)
assert.NoError(t, err)
assert.Equal(t, len(result1), len(result2))
// 缓存命中的查询应该更快
t.Logf("First query (cache miss): %v", duration1)
t.Logf("Second query (cache hit): %v", duration2)
// 通常缓存命中应该比缓存未命中快很多
if duration2 < duration1 {
speedup := float64(duration1) / float64(duration2)
t.Logf("Cache hit is %.2fx faster", speedup)
}
// 验证结果一致性
assert.Equal(t, len(result1), len(result2))
for i := range result1 {
assert.Equal(t, result1[i].Id, result2[i].Id)
assert.Equal(t, result1[i].ChannelId, result2[i].ChannelId)
}
}
// 基准测试:对比有缓存和无缓存的性能
func BenchmarkGetLastConversationsWithCache(b *testing.B) {
d := newTestDB(b)
err := d.Open()
assert.NoError(b, err)
defer func() {
err := d.Close()
assert.NoError(b, err)
}()
uid := "bench_user"
now := time.Now()
// 创建测试数据
conversations := make([]wkdb.Conversation, 0, 30)
for i := 0; i < 30; i++ {
updatedAt := now.Add(time.Duration(i) * time.Minute)
conversations = append(conversations, wkdb.Conversation{
Id: uint64(i + 1),
Uid: uid,
ChannelId: fmt.Sprintf("channel_%d", i),
ChannelType: 1,
Type: wkdb.ConversationTypeChat,
UnreadCount: uint32(i + 1),
ReadToMsgSeq: uint64(i + 1),
CreatedAt: &now,
UpdatedAt: &updatedAt,
})
}
err = d.AddOrUpdateConversationsWithUser(uid, conversations)
assert.NoError(b, err)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 10)
assert.NoError(b, err)
}
}
// 测试智能缓存更新的性能
func TestSmartCacheUpdatePerformance(t *testing.T) {
d := newTestDB(t)
err := d.Open()
assert.NoError(t, err)
defer func() {
err := d.Close()
assert.NoError(t, err)
}()
uid := "smart_update_test_user"
now := time.Now()
// 创建测试数据
conversations := make([]wkdb.Conversation, 0, 20)
for i := 0; i < 20; i++ {
updatedAt := now.Add(time.Duration(i) * time.Minute)
conversations = append(conversations, wkdb.Conversation{
Id: uint64(i + 1),
Uid: uid,
ChannelId: fmt.Sprintf("channel_%d", i),
ChannelType: 1,
Type: wkdb.ConversationTypeChat,
UnreadCount: uint32(i + 1),
ReadToMsgSeq: uint64(i + 1),
CreatedAt: &now,
UpdatedAt: &updatedAt,
})
}
// 添加会话
err = d.AddOrUpdateConversationsWithUser(uid, conversations)
assert.NoError(t, err)
// 第一次查询(缓存未命中)
result1, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 10)
assert.NoError(t, err)
assert.Len(t, result1, 10)
// 第二次查询(缓存命中)
start2 := time.Now()
result2, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 10)
duration2 := time.Since(start2)
assert.NoError(t, err)
assert.Equal(t, len(result1), len(result2))
// 更新部分会话(智能更新缓存)
conversations[0].UnreadCount = 100
conversations[1].UnreadCount = 200
updatedAt := now.Add(time.Hour)
conversations[0].UpdatedAt = &updatedAt
conversations[1].UpdatedAt = &updatedAt
err = d.AddOrUpdateConversationsWithUser(uid, conversations[:2])
assert.NoError(t, err)
// 第三次查询(缓存应该被智能更新,而不是失效)
start3 := time.Now()
result3, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 10)
duration3 := time.Since(start3)
assert.NoError(t, err)
t.Logf("Cached query: %v", duration2)
t.Logf("Query after smart cache update: %v", duration3)
// 验证缓存仍然有效且数据已更新
found1, found2 := false, false
for _, conv := range result3 {
if conv.Id == 1 {
assert.Equal(t, uint32(100), conv.UnreadCount)
found1 = true
}
if conv.Id == 2 {
assert.Equal(t, uint32(200), conv.UnreadCount)
found2 = true
}
}
assert.True(t, found1, "Updated conversation 1 should be found")
assert.True(t, found2, "Updated conversation 2 should be found")
// 智能更新后的查询应该仍然很快(缓存命中)
if duration3 < duration2*2 {
t.Logf("Smart cache update maintained good performance")
}
}
// 测试缓存失效后的性能
func TestCacheInvalidationPerformance(t *testing.T) {
d := newTestDB(t)
err := d.Open()
assert.NoError(t, err)
defer func() {
err := d.Close()
assert.NoError(t, err)
}()
uid := "invalidation_test_user"
now := time.Now()
// 创建测试数据
conversations := make([]wkdb.Conversation, 0, 20)
for i := 0; i < 20; i++ {
updatedAt := now.Add(time.Duration(i) * time.Minute)
conversations = append(conversations, wkdb.Conversation{
Id: uint64(i + 1),
Uid: uid,
ChannelId: fmt.Sprintf("channel_%d", i),
ChannelType: 1,
Type: wkdb.ConversationTypeChat,
UnreadCount: uint32(i + 1),
ReadToMsgSeq: uint64(i + 1),
CreatedAt: &now,
UpdatedAt: &updatedAt,
})
}
// 添加会话
err = d.AddOrUpdateConversationsWithUser(uid, conversations)
assert.NoError(t, err)
// 第一次查询(缓存未命中)
result1, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 10)
assert.NoError(t, err)
assert.Len(t, result1, 10)
// 第二次查询(缓存命中)
start2 := time.Now()
result2, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 10)
duration2 := time.Since(start2)
assert.NoError(t, err)
assert.Equal(t, len(result1), len(result2))
// 更新会话(这会使缓存失效)
conversations[0].UnreadCount = 100
updatedAt := now.Add(time.Hour)
conversations[0].UpdatedAt = &updatedAt
err = d.AddOrUpdateConversationsWithUser(uid, conversations[:1])
assert.NoError(t, err)
// 第三次查询(缓存失效后,需要重新从数据库查询)
start3 := time.Now()
result3, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 10)
duration3 := time.Since(start3)
assert.NoError(t, err)
t.Logf("Cached query: %v", duration2)
t.Logf("Query after cache invalidation: %v", duration3)
// 验证缓存失效后数据是最新的
found := false
for _, conv := range result3 {
if conv.Id == 1 {
assert.Equal(t, uint32(100), conv.UnreadCount)
found = true
break
}
}
assert.True(t, found, "Updated conversation should be found")
}
// 压力测试:大量并发查询
func TestCacheConcurrentQueries(t *testing.T) {
d := newTestDB(t)
err := d.Open()
assert.NoError(t, err)
defer func() {
err := d.Close()
assert.NoError(t, err)
}()
uid := "concurrent_test_user"
now := time.Now()
// 创建测试数据
conversations := make([]wkdb.Conversation, 0, 15)
for i := 0; i < 15; i++ {
updatedAt := now.Add(time.Duration(i) * time.Minute)
conversations = append(conversations, wkdb.Conversation{
Id: uint64(i + 1),
Uid: uid,
ChannelId: fmt.Sprintf("channel_%d", i),
ChannelType: 1,
Type: wkdb.ConversationTypeChat,
UnreadCount: uint32(i + 1),
ReadToMsgSeq: uint64(i + 1),
CreatedAt: &now,
UpdatedAt: &updatedAt,
})
}
err = d.AddOrUpdateConversationsWithUser(uid, conversations)
assert.NoError(t, err)
// 并发查询测试
done := make(chan bool, 10)
errors := make(chan error, 10)
for i := 0; i < 10; i++ {
go func(id int) {
for j := 0; j < 5; j++ {
_, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 5)
if err != nil {
errors <- err
return
}
}
done <- true
}(i)
}
// 等待所有goroutine完成
for i := 0; i < 10; i++ {
select {
case <-done:
// 成功完成
case err := <-errors:
t.Fatalf("Concurrent query failed: %v", err)
case <-time.After(10 * time.Second):
t.Fatal("Concurrent query timeout")
}
}
}
// 基准测试:缓存性能
func BenchmarkConversationCacheOperations(b *testing.B) {
cache := wkdb.NewConversationCache(10000)
// 预填充缓存
conversations := make([]wkdb.Conversation, 100)
for i := 0; i < 100; i++ {
conversations[i] = wkdb.Conversation{
Id: uint64(i),
Uid: fmt.Sprintf("user%d", i%10),
ChannelId: fmt.Sprintf("channel%d", i),
ChannelType: 1,
}
}
b.Run("SetLastConversations", func(b *testing.B) {
for i := 0; i < b.N; i++ {
uid := fmt.Sprintf("user%d", i%10)
cache.SetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 10, conversations)
}
})
// 预填充一些数据用于读取测试
for i := 0; i < 10; i++ {
uid := fmt.Sprintf("user%d", i)
cache.SetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 10, conversations)
}
b.Run("GetLastConversations", func(b *testing.B) {
for i := 0; i < b.N; i++ {
uid := fmt.Sprintf("user%d", i%10)
cache.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 10)
}
})
}

View File

@@ -1,6 +1,7 @@
package wkdb_test
import (
"fmt"
"testing"
"time"
@@ -184,8 +185,53 @@ func TestDeleteConversation(t *testing.T) {
// assert.Equal(t, conversations[1], conversations2[1])
// }
// 测试 AddOrUpdateConversations 的性能
// 测试 GetLastConversations 批量查询优化
func TestGetLastConversationsBatch(t *testing.T) {
d := newTestDB(t)
err := d.Open()
assert.NoError(t, err)
defer func() {
err := d.Close()
assert.NoError(t, err)
}()
uid := "test_batch_user"
now := time.Now()
// 创建多个会话用于测试
conversations := make([]wkdb.Conversation, 0, 20)
for i := 0; i < 20; i++ {
updatedAt := now.Add(time.Duration(i) * time.Minute)
conversations = append(conversations, wkdb.Conversation{
Id: uint64(i + 1),
Uid: uid,
ChannelId: fmt.Sprintf("channel_%d", i),
ChannelType: 1,
Type: wkdb.ConversationTypeChat,
UnreadCount: uint32(i + 1),
ReadToMsgSeq: uint64(i + 1),
CreatedAt: &now,
UpdatedAt: &updatedAt,
})
}
// 添加会话
err = d.AddOrUpdateConversationsWithUser(uid, conversations)
assert.NoError(t, err)
// 测试批量查询
result, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 10)
assert.NoError(t, err)
assert.LessOrEqual(t, len(result), 10)
// 验证结果按更新时间排序(最新的在前)
for i := 1; i < len(result); i++ {
assert.True(t, result[i-1].UpdatedAt.After(*result[i].UpdatedAt) || result[i-1].UpdatedAt.Equal(*result[i].UpdatedAt))
}
}
// 测试 AddOrUpdateConversations 的性能
func BenchmarkAddOrUpdateConversations(b *testing.B) {
d := newTestDB(b)
err := d.Open()
@@ -228,3 +274,44 @@ func BenchmarkAddOrUpdateConversations(b *testing.B) {
assert.NoError(b, err)
}
}
// 测试 GetLastConversations 的性能对比
func BenchmarkGetLastConversations(b *testing.B) {
d := newTestDB(b)
err := d.Open()
assert.NoError(b, err)
defer func() {
err := d.Close()
assert.NoError(b, err)
}()
uid := "bench_user"
now := time.Now()
// 创建大量会话数据
conversations := make([]wkdb.Conversation, 0, 1000)
for i := 0; i < 100; i++ {
updatedAt := now.Add(time.Duration(i) * time.Minute)
conversations = append(conversations, wkdb.Conversation{
Id: uint64(i + 1),
Uid: uid,
ChannelId: fmt.Sprintf("channel_%d", i),
ChannelType: 1,
Type: wkdb.ConversationTypeChat,
UnreadCount: uint32(i + 1),
ReadToMsgSeq: uint64(i + 1),
CreatedAt: &now,
UpdatedAt: &updatedAt,
})
}
err = d.AddOrUpdateConversationsWithUser(uid, conversations)
assert.NoError(b, err)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 20)
assert.NoError(b, err)
}
}

View File

@@ -37,7 +37,8 @@ type wukongDB struct {
metrics trace.IDBMetrics
channelSeqCache *channelSeqCache
channelSeqCache *channelSeqCache
conversationCache *ConversationCache
h hash.Hash32
}
@@ -59,15 +60,16 @@ func NewWukongDB(opts *Options) DB {
cancelCtx, cancelFunc := context.WithCancel(context.Background())
return &wukongDB{
opts: opts,
shardNum: uint32(opts.ShardNum),
prmaryKeyGen: prmaryKeyGen,
endian: endian,
cancelCtx: cancelCtx,
cancelFunc: cancelFunc,
metrics: metrics,
channelSeqCache: newChannelSeqCache(10000, endian),
h: fnv.New32(),
opts: opts,
shardNum: uint32(opts.ShardNum),
prmaryKeyGen: prmaryKeyGen,
endian: endian,
cancelCtx: cancelCtx,
cancelFunc: cancelFunc,
metrics: metrics,
channelSeqCache: newChannelSeqCache(10000, endian),
conversationCache: NewConversationCache(2000), // 缓存2000个 GetLastConversations 查询结果
h: fnv.New32(),
sync: &pebble.WriteOptions{
Sync: true,
},