mirror of
https://gitee.com/WuKongDev/WuKongIM.git
synced 2025-12-06 14:59:08 +08:00
ref: user reactor
This commit is contained in:
27
internal/reactor/action_channel.go
Normal file
27
internal/reactor/action_channel.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package reactor
|
||||
|
||||
type ChannelActionType uint8
|
||||
|
||||
const (
|
||||
// 未知
|
||||
ChannelActionUnknown ChannelActionType = iota
|
||||
// 选举
|
||||
ChannelActionElection
|
||||
// 加入
|
||||
ChannelActionJoin
|
||||
ChannelActionJoinResp
|
||||
// 心跳
|
||||
ChannelActionHeartbeatReq
|
||||
ChannelActionHeartbeatResp
|
||||
// 收件箱
|
||||
ChannelActionInboundAdd
|
||||
ChannelActionInbound
|
||||
|
||||
// 发件箱
|
||||
ChannelActionOutboundAdd
|
||||
ChannelActionOutboundForward
|
||||
ChannelActionOutboundForwardResp
|
||||
)
|
||||
|
||||
type ChannelAction struct {
|
||||
}
|
||||
@@ -14,8 +14,6 @@ const (
|
||||
UserActionJoinResp
|
||||
// 认证
|
||||
UserActionAuthAdd
|
||||
UserActionAuth
|
||||
UserActionAuthResp
|
||||
// 收件箱
|
||||
UserActionInboundAdd
|
||||
UserActionInbound
|
||||
@@ -27,6 +25,8 @@ const (
|
||||
UserActionNodeHeartbeatReq
|
||||
// 节点心跳返回 replica --> leader
|
||||
UserActionNodeHeartbeatResp
|
||||
// write
|
||||
UserActionWrite
|
||||
// 关闭连接
|
||||
UserActionConnClose
|
||||
// 用户关闭
|
||||
@@ -47,8 +47,6 @@ func (a UserActionType) String() string {
|
||||
return "UserActionConfigUpdate"
|
||||
case UserActionAuthAdd:
|
||||
return "UserActionAuthAdd"
|
||||
case UserActionAuth:
|
||||
return "UserActionAuth"
|
||||
case UserActionInboundAdd:
|
||||
return "UserActionInboundAdd"
|
||||
case UserActionInbound:
|
||||
@@ -65,25 +63,31 @@ func (a UserActionType) String() string {
|
||||
return "UserActionConnClose"
|
||||
case UserActionUserClose:
|
||||
return "UserActionUserClose"
|
||||
case UserActionWrite:
|
||||
return "UserActionWrite"
|
||||
case UserActionOutboundAdd:
|
||||
return "UserActionOutboundAdd"
|
||||
default:
|
||||
return "UserUnknown"
|
||||
}
|
||||
}
|
||||
|
||||
type UserAction struct {
|
||||
No string // 唯一编号
|
||||
From uint64 // 发送节点
|
||||
To uint64 // 接收节点
|
||||
No string // 唯一编号
|
||||
Uid string
|
||||
Type UserActionType
|
||||
Messages []UserMessage
|
||||
Messages UserMessageBatch
|
||||
Index uint64
|
||||
LeaderId uint64
|
||||
Cfg UserConfig
|
||||
Conns []Conn
|
||||
Conns []*Conn
|
||||
Term uint32 // 任期
|
||||
NodeVersion uint64 // 节点的数据版本
|
||||
Success bool
|
||||
Role Role
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (a UserAction) Size() uint64 {
|
||||
|
||||
@@ -1,18 +1,122 @@
|
||||
package reactor
|
||||
|
||||
import wkproto "github.com/WuKongIM/WuKongIMGoProto"
|
||||
import (
|
||||
"github.com/WuKongIM/WuKongIM/pkg/wkutil"
|
||||
wkproto "github.com/WuKongIM/WuKongIMGoProto"
|
||||
)
|
||||
|
||||
type Conn interface {
|
||||
// ConnId 用户唯一连接id
|
||||
ConnId() int64
|
||||
// Uid 用户uid
|
||||
Uid() string
|
||||
// DeviceFlag 设备标记
|
||||
DeviceFlag() wkproto.DeviceFlag
|
||||
// FromNode 连接属于节点
|
||||
FromNode() uint64
|
||||
// SetAuth 设置是否认证
|
||||
SetAuth(auth bool)
|
||||
// IsAuth 是否认证
|
||||
IsAuth() bool
|
||||
// type Conn interface {
|
||||
// // ConnId 用户唯一连接id
|
||||
// ConnId() int64
|
||||
// // Uid 用户uid
|
||||
// Uid() string
|
||||
// // DeviceId 设备id
|
||||
// DeviceId() string
|
||||
// // DeviceFlag 设备标记
|
||||
// DeviceFlag() wkproto.DeviceFlag
|
||||
// // FromNode 连接属于节点
|
||||
// FromNode() uint64
|
||||
// // SetAuth 设置是否认证
|
||||
// SetAuth(auth bool)
|
||||
// // IsAuth 是否认证
|
||||
// IsAuth() bool
|
||||
// // Equal 判断是否相等
|
||||
// Equal(conn Conn) bool
|
||||
// // SetString 设置扩展字段
|
||||
// SetString(key string, value string)
|
||||
|
||||
// SetProtoVersion(version uint8)
|
||||
// GetProtoVersion() uint8
|
||||
|
||||
// DeviceLevel() wkproto.DeviceLevel
|
||||
// SetDeviceLevel(level wkproto.DeviceLevel)
|
||||
|
||||
// // Encode 编码连接数据
|
||||
// Encode() ([]byte, error)
|
||||
// // Decode 解密连接数据
|
||||
// Decode(data []byte) error
|
||||
// }
|
||||
|
||||
// const (
|
||||
// ConnAesIV string = "aesIV"
|
||||
// ConnAesKey string = "aesKey"
|
||||
// )
|
||||
|
||||
type Conn struct {
|
||||
ConnId int64
|
||||
Uid string
|
||||
DeviceId string
|
||||
DeviceFlag wkproto.DeviceFlag
|
||||
DeviceLevel wkproto.DeviceLevel
|
||||
FromNode uint64
|
||||
Auth bool
|
||||
AesIV []byte
|
||||
AesKey []byte
|
||||
ProtoVersion uint8
|
||||
}
|
||||
|
||||
func (c *Conn) Encode() ([]byte, error) {
|
||||
enc := wkproto.NewEncoder()
|
||||
enc.WriteInt64(c.ConnId)
|
||||
enc.WriteString(c.Uid)
|
||||
enc.WriteString(c.DeviceId)
|
||||
enc.WriteUint8(c.DeviceFlag.ToUint8())
|
||||
enc.WriteUint8(uint8(c.DeviceLevel))
|
||||
enc.WriteUint64(c.FromNode)
|
||||
enc.WriteUint8(wkutil.BoolToUint8(c.Auth))
|
||||
enc.WriteBinary(c.AesIV)
|
||||
enc.WriteBinary(c.AesKey)
|
||||
enc.WriteUint8(c.ProtoVersion)
|
||||
return enc.Bytes(), nil
|
||||
}
|
||||
|
||||
func (c *Conn) Decode(data []byte) error {
|
||||
dec := wkproto.NewDecoder(data)
|
||||
var err error
|
||||
if c.ConnId, err = dec.Int64(); err != nil {
|
||||
return err
|
||||
}
|
||||
if c.Uid, err = dec.String(); err != nil {
|
||||
return err
|
||||
}
|
||||
if c.DeviceId, err = dec.String(); err != nil {
|
||||
return err
|
||||
}
|
||||
var deviceFlag uint8
|
||||
if deviceFlag, err = dec.Uint8(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.DeviceFlag = wkproto.DeviceFlag(deviceFlag)
|
||||
|
||||
var deviceLevel uint8
|
||||
if deviceLevel, err = dec.Uint8(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.DeviceLevel = wkproto.DeviceLevel(deviceLevel)
|
||||
if c.FromNode, err = dec.Uint64(); err != nil {
|
||||
return err
|
||||
}
|
||||
var auth uint8
|
||||
if auth, err = dec.Uint8(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.Auth = wkutil.Uint8ToBool(auth)
|
||||
|
||||
if c.AesIV, err = dec.Binary(); err != nil {
|
||||
return err
|
||||
}
|
||||
if c.AesKey, err = dec.Binary(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.ProtoVersion, err = dec.Uint8(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) Equal(cn *Conn) bool {
|
||||
|
||||
return c.Uid == cn.Uid && c.ConnId == cn.ConnId && c.FromNode == cn.FromNode
|
||||
}
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
package reactor
|
||||
|
||||
import wkproto "github.com/WuKongIM/WuKongIMGoProto"
|
||||
|
||||
type UserMessage interface {
|
||||
Conn() Conn
|
||||
Frame() wkproto.Frame
|
||||
Size() uint64
|
||||
SetIndex(index uint64)
|
||||
Index() uint64
|
||||
}
|
||||
|
||||
type defaultUserMessage struct {
|
||||
conn Conn
|
||||
frame wkproto.Frame
|
||||
index uint64
|
||||
}
|
||||
|
||||
func (m *defaultUserMessage) Conn() Conn {
|
||||
return m.conn
|
||||
}
|
||||
|
||||
func (m *defaultUserMessage) Frame() wkproto.Frame {
|
||||
return m.frame
|
||||
}
|
||||
|
||||
func (m *defaultUserMessage) Size() uint64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *defaultUserMessage) SetIndex(index uint64) {
|
||||
m.index = index
|
||||
}
|
||||
|
||||
func (m *defaultUserMessage) Index() uint64 {
|
||||
return m.index
|
||||
}
|
||||
4
internal/reactor/message_channel.go
Normal file
4
internal/reactor/message_channel.go
Normal file
@@ -0,0 +1,4 @@
|
||||
package reactor
|
||||
|
||||
type ChannelMessage interface {
|
||||
}
|
||||
249
internal/reactor/message_user.go
Normal file
249
internal/reactor/message_user.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package reactor
|
||||
|
||||
import wkproto "github.com/WuKongIM/WuKongIMGoProto"
|
||||
|
||||
// type UserMessage interface {
|
||||
// GetConn() *Conn
|
||||
// GetFrame() wkproto.Frame
|
||||
// Size() uint64
|
||||
// // 消息顺序下标
|
||||
// SetIndex(index uint64)
|
||||
// GetIndex() uint64
|
||||
// // ToNode 此条消息只发给此节点
|
||||
// GetToNode() uint64
|
||||
// SetToNode(to uint64)
|
||||
|
||||
// // GetWriteData 获取写入数据
|
||||
// GetWriteData() []byte
|
||||
// SetWriteData(data []byte)
|
||||
// }
|
||||
|
||||
type UserMessage struct {
|
||||
Conn *Conn
|
||||
Frame wkproto.Frame
|
||||
WriteData []byte
|
||||
|
||||
Index uint64
|
||||
ToNode uint64
|
||||
}
|
||||
|
||||
// func (m *DefaultUserMessage) GetConn() Conn {
|
||||
// return m.Conn
|
||||
// }
|
||||
|
||||
// func (m *DefaultUserMessage) GetFrame() wkproto.Frame {
|
||||
// return m.Frame
|
||||
// }
|
||||
|
||||
// func (m *DefaultUserMessage) Size() uint64 {
|
||||
// return 0
|
||||
// }
|
||||
|
||||
// func (m *DefaultUserMessage) SetIndex(index uint64) {
|
||||
// m.Index = index
|
||||
// }
|
||||
|
||||
// func (m *DefaultUserMessage) GetIndex() uint64 {
|
||||
// return m.Index
|
||||
// }
|
||||
|
||||
// func (m *DefaultUserMessage) GetToNode() uint64 {
|
||||
// return m.ToNode
|
||||
// }
|
||||
|
||||
// func (m *DefaultUserMessage) SetToNode(to uint64) {
|
||||
// m.ToNode = to
|
||||
// }
|
||||
|
||||
// func (m *DefaultUserMessage) GetWriteData() []byte {
|
||||
// return m.WriteData
|
||||
// }
|
||||
// func (m *DefaultUserMessage) SetWriteData(data []byte) {
|
||||
// m.WriteData = data
|
||||
// }
|
||||
|
||||
func (m *UserMessage) Size() uint64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *UserMessage) Encode() ([]byte, error) {
|
||||
// flag
|
||||
// conn frame writeData 0 0 0 0
|
||||
hasConn := m.hasConn()
|
||||
hasFrame := m.hasFrame()
|
||||
hasWriteData := m.hasWriteData()
|
||||
|
||||
var flag uint8 = hasConn<<7 | hasFrame<<6 | hasWriteData<<5
|
||||
|
||||
enc := wkproto.NewEncoder()
|
||||
defer enc.End()
|
||||
|
||||
enc.WriteUint8(flag)
|
||||
if hasConn == 1 {
|
||||
data, err := m.Conn.Encode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
enc.WriteBinary(data)
|
||||
}
|
||||
if hasFrame == 1 {
|
||||
data, err := Proto.EncodeFrame(m.Frame, wkproto.LatestVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
enc.WriteUint32(uint32(len(data)))
|
||||
enc.WriteBytes(data)
|
||||
}
|
||||
if hasWriteData == 1 {
|
||||
enc.WriteUint32(uint32(len(m.WriteData)))
|
||||
enc.WriteBytes(m.WriteData)
|
||||
}
|
||||
|
||||
enc.WriteUint64(m.Index)
|
||||
enc.WriteUint64(m.ToNode)
|
||||
|
||||
return enc.Bytes(), nil
|
||||
}
|
||||
|
||||
func (m *UserMessage) Decode(data []byte) error {
|
||||
|
||||
dec := wkproto.NewDecoder(data)
|
||||
flag, err := dec.Uint8()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hasConn := flag >> 7 & 0x01
|
||||
hasFrame := flag >> 6 & 0x01
|
||||
hasWriteData := flag >> 5 & 0x01
|
||||
|
||||
if hasConn == 1 {
|
||||
connBytes, err := dec.Binary()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn := &Conn{}
|
||||
err = conn.Decode(connBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Conn = conn
|
||||
}
|
||||
if hasFrame == 1 {
|
||||
frameLen, err := dec.Uint32()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := dec.Bytes(int(frameLen))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
frame, _, err := Proto.DecodeFrame(data, wkproto.LatestVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Frame = frame
|
||||
}
|
||||
|
||||
if hasWriteData == 1 {
|
||||
writeDataLen, err := dec.Uint32()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := dec.Bytes(int(writeDataLen))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.WriteData = data
|
||||
}
|
||||
if m.Index, err = dec.Uint64(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.ToNode, err = dec.Uint64(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (m *UserMessage) hasConn() uint8 {
|
||||
if m.Conn != nil {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
func (m *UserMessage) hasFrame() uint8 {
|
||||
if m.Frame != nil {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *UserMessage) hasWriteData() uint8 {
|
||||
if len(m.WriteData) > 0 {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
type UserMessageBatch []*UserMessage
|
||||
|
||||
func (u UserMessageBatch) Encode() ([]byte, error) {
|
||||
// 创建一个新的编码器
|
||||
enc := wkproto.NewEncoder()
|
||||
defer enc.End()
|
||||
|
||||
// 编码 UserMessage 的数量
|
||||
enc.WriteUint16(uint16(len(u)))
|
||||
|
||||
// 编码每个 UserMessage
|
||||
for _, msg := range u {
|
||||
// 确保每个消息都有正确的编码
|
||||
encodedMsg, err := msg.Encode() // 解引用 msg 获取 UserMessage
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
enc.WriteUint32(uint32(len(encodedMsg)))
|
||||
enc.WriteBytes(encodedMsg)
|
||||
}
|
||||
|
||||
// 返回最终编码的字节数据
|
||||
return enc.Bytes(), nil
|
||||
}
|
||||
|
||||
func (u *UserMessageBatch) Decode(data []byte) error {
|
||||
// 创建解码器
|
||||
dec := wkproto.NewDecoder(data)
|
||||
|
||||
// 解码 UserMessage 的数量
|
||||
numMessages, err := dec.Uint32()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 解码每个 UserMessage
|
||||
*u = make(UserMessageBatch, numMessages)
|
||||
for i := uint32(0); i < numMessages; i++ {
|
||||
// 解码每个 UserMessage 的字节数据
|
||||
dataLen, err := dec.Uint32()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
msgData, err := dec.Bytes(int(dataLen))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 创建一个新的 UserMessage
|
||||
msg := &UserMessage{}
|
||||
// 使用解码数据填充 UserMessage
|
||||
err = msg.Decode(msgData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 将解码后的消息添加到 UserMessageBatch 中
|
||||
(*u)[i] = msg
|
||||
}
|
||||
return nil
|
||||
|
||||
}
|
||||
106
internal/reactor/message_user_test.go
Normal file
106
internal/reactor/message_user_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package reactor
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
wkproto "github.com/WuKongIM/WuKongIMGoProto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type MockFrame struct{}
|
||||
|
||||
func TestDefaultUserMessage_EncodeDecode(t *testing.T) {
|
||||
// Initialize the message with mock values
|
||||
msg := &UserMessage{
|
||||
Conn: &Conn{},
|
||||
Frame: &wkproto.PingPacket{},
|
||||
WriteData: []byte("some-data"),
|
||||
Index: 12345,
|
||||
ToNode: 67890,
|
||||
}
|
||||
|
||||
// Test Encode
|
||||
encoded, err := msg.Encode()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, encoded)
|
||||
t.Logf("Encoded message: %v", encoded)
|
||||
|
||||
// Test Decode
|
||||
decodedMsg := &UserMessage{}
|
||||
err = decodedMsg.Decode(encoded)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, msg.Index, decodedMsg.Index)
|
||||
assert.Equal(t, msg.ToNode, decodedMsg.ToNode)
|
||||
assert.Equal(t, string(msg.WriteData), string(decodedMsg.WriteData))
|
||||
}
|
||||
|
||||
func TestDefaultUserMessage_EncodeNoConn(t *testing.T) {
|
||||
// Test the case when Conn is nil
|
||||
msg := &UserMessage{
|
||||
Frame: &wkproto.PingPacket{},
|
||||
WriteData: []byte("some-data"),
|
||||
Index: 12345,
|
||||
ToNode: 67890,
|
||||
}
|
||||
|
||||
encoded, err := msg.Encode()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, encoded)
|
||||
|
||||
// Decode it back
|
||||
decodedMsg := &UserMessage{}
|
||||
err = decodedMsg.Decode(encoded)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, msg.Index, decodedMsg.Index)
|
||||
assert.Equal(t, msg.ToNode, decodedMsg.ToNode)
|
||||
assert.Equal(t, string(msg.WriteData), string(decodedMsg.WriteData))
|
||||
}
|
||||
|
||||
func TestDefaultUserMessage_EncodeNoFrame(t *testing.T) {
|
||||
// Test the case when Frame is nil
|
||||
msg := &UserMessage{
|
||||
Conn: &Conn{
|
||||
Uid: "u1",
|
||||
},
|
||||
WriteData: []byte("some-data"),
|
||||
Index: 12345,
|
||||
ToNode: 67890,
|
||||
}
|
||||
|
||||
encoded, err := msg.Encode()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, encoded)
|
||||
|
||||
// Decode it back
|
||||
decodedMsg := &UserMessage{}
|
||||
err = decodedMsg.Decode(encoded)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, msg.Index, decodedMsg.Index)
|
||||
assert.Equal(t, msg.ToNode, decodedMsg.ToNode)
|
||||
assert.Equal(t, string(msg.WriteData), string(decodedMsg.WriteData))
|
||||
assert.Equal(t, msg.Conn.Uid, decodedMsg.Conn.Uid)
|
||||
}
|
||||
|
||||
func TestDefaultUserMessage_EncodeNoWriteData(t *testing.T) {
|
||||
// Test the case when WriteData is empty
|
||||
msg := &UserMessage{
|
||||
Conn: &Conn{
|
||||
Uid: "u1",
|
||||
},
|
||||
Frame: &wkproto.PingPacket{},
|
||||
Index: 12345,
|
||||
ToNode: 67890,
|
||||
}
|
||||
|
||||
encoded, err := msg.Encode()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, encoded)
|
||||
|
||||
// Decode it back
|
||||
decodedMsg := &UserMessage{}
|
||||
err = decodedMsg.Decode(encoded)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, msg.Index, decodedMsg.Index)
|
||||
assert.Equal(t, msg.ToNode, decodedMsg.ToNode)
|
||||
assert.Equal(t, msg.Conn.Uid, decodedMsg.Conn.Uid)
|
||||
}
|
||||
@@ -1,14 +1,15 @@
|
||||
package reactor
|
||||
|
||||
import wkproto "github.com/WuKongIM/WuKongIMGoProto"
|
||||
|
||||
var User *UserPlus
|
||||
var Channel IChannel
|
||||
var Channel *ChannelPlus
|
||||
var Proto wkproto.Protocol = wkproto.New()
|
||||
|
||||
func RegisterUser(u IUser) {
|
||||
User = &UserPlus{
|
||||
user: u,
|
||||
}
|
||||
User = newUserPlus(u)
|
||||
}
|
||||
|
||||
func RegisterChannel(c IChannel) {
|
||||
Channel = c
|
||||
Channel = newChannelPlus(c)
|
||||
}
|
||||
|
||||
@@ -2,3 +2,16 @@ package reactor
|
||||
|
||||
type IChannel interface {
|
||||
}
|
||||
|
||||
type ChannelPlus struct {
|
||||
ch IChannel
|
||||
}
|
||||
|
||||
func newChannelPlus(ch IChannel) *ChannelPlus {
|
||||
return &ChannelPlus{
|
||||
ch: ch,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ChannelPlus) AddMessagesToInBound(channelId string, channelType uint8, messages []ChannelMessage) {
|
||||
}
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
package reactor
|
||||
|
||||
import wkproto "github.com/WuKongIM/WuKongIMGoProto"
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/WuKongIM/WuKongIM/pkg/wklog"
|
||||
wkproto "github.com/WuKongIM/WuKongIMGoProto"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type IUser interface {
|
||||
// Start 开始
|
||||
@@ -11,10 +17,20 @@ type IUser interface {
|
||||
WakeIfNeed(uid string)
|
||||
// AddAction 添加用户行为
|
||||
AddAction(a UserAction) bool
|
||||
// CloseConn 关闭连接
|
||||
CloseConn(conn *Conn)
|
||||
}
|
||||
|
||||
type UserPlus struct {
|
||||
user IUser
|
||||
wklog.Log
|
||||
}
|
||||
|
||||
func newUserPlus(user IUser) *UserPlus {
|
||||
return &UserPlus{
|
||||
user: user,
|
||||
Log: wklog.NewWKLog("UserPlus"),
|
||||
}
|
||||
}
|
||||
|
||||
// WakeIfNeed 根据需要唤醒用户(如果用户在就不需要唤醒)
|
||||
@@ -32,25 +48,26 @@ func (u *UserPlus) UpdateConfig(uid string, cfg UserConfig) {
|
||||
}
|
||||
|
||||
// AddAuth 添加认证
|
||||
func (u *UserPlus) AddAuth(conn Conn, connectPacket *wkproto.ConnectPacket) {
|
||||
func (u *UserPlus) AddAuth(conn *Conn, connectPacket *wkproto.ConnectPacket) {
|
||||
u.user.AddAction(UserAction{
|
||||
Type: UserActionAuthAdd,
|
||||
Uid: connectPacket.UID,
|
||||
Messages: []UserMessage{
|
||||
&defaultUserMessage{
|
||||
conn: conn,
|
||||
frame: connectPacket,
|
||||
Messages: []*UserMessage{
|
||||
{
|
||||
Conn: conn,
|
||||
Frame: connectPacket,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// AddMessages 添加消息
|
||||
func (u *UserPlus) AddMessages(uid string, msgs []UserMessage) {
|
||||
// Join 副本加入到领导
|
||||
func (u *UserPlus) Join(uid string, nodeId uint64) {
|
||||
u.user.AddAction(UserAction{
|
||||
Type: UserActionInboundAdd,
|
||||
Uid: uid,
|
||||
Messages: msgs,
|
||||
Type: UserActionAuthAdd,
|
||||
Uid: uid,
|
||||
From: nodeId,
|
||||
Success: true,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -70,14 +87,14 @@ func (u *UserPlus) LeaderId(uid string) uint64 {
|
||||
}
|
||||
|
||||
// Kick 踢掉连接
|
||||
func (u *UserPlus) Kick(conn Conn, reasonCode wkproto.ReasonCode, reason string) {
|
||||
func (u *UserPlus) Kick(conn *Conn, reasonCode wkproto.ReasonCode, reason string) {
|
||||
u.user.AddAction(UserAction{
|
||||
Type: UserActionOutboundAdd,
|
||||
Uid: conn.Uid(),
|
||||
Messages: []UserMessage{
|
||||
&defaultUserMessage{
|
||||
conn: conn,
|
||||
frame: &wkproto.DisconnectPacket{
|
||||
Uid: conn.Uid,
|
||||
Messages: []*UserMessage{
|
||||
&UserMessage{
|
||||
Conn: conn,
|
||||
Frame: &wkproto.DisconnectPacket{
|
||||
ReasonCode: reasonCode,
|
||||
Reason: reason,
|
||||
},
|
||||
@@ -90,49 +107,100 @@ func (u *UserPlus) AllUserCount() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
// ========================================== conn ==========================================
|
||||
// ========================================== message ==========================================
|
||||
|
||||
// ConnsByDeviceFlag 根据设备标识获取连接
|
||||
func (u *UserPlus) ConnsByDeviceFlag(uid string, deviceFlag wkproto.DeviceFlag) []Conn {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConnsByUid 根据用户uid获取连接
|
||||
func (u *UserPlus) ConnsByUid(uid string) []Conn {
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalConnById 获取本地连接
|
||||
func (u *UserPlus) LocalConnById(uid string, id int64) Conn {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConnById 获取连接
|
||||
func (u *UserPlus) ConnById(uid string, fromNode uint64, id int64) Conn {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseConn 关闭连接
|
||||
func (u *UserPlus) CloseConn(conn Conn) {
|
||||
|
||||
}
|
||||
|
||||
// ConnWrite 连接写包
|
||||
func (u *UserPlus) ConnWrite(conn Conn, frame wkproto.Frame) {
|
||||
// AddMessages 添加消息到收件箱
|
||||
func (u *UserPlus) AddMessages(uid string, msgs []*UserMessage) {
|
||||
u.user.AddAction(UserAction{
|
||||
Type: UserActionOutboundAdd,
|
||||
Uid: conn.Uid(),
|
||||
Messages: []UserMessage{
|
||||
&defaultUserMessage{
|
||||
conn: conn,
|
||||
frame: frame,
|
||||
},
|
||||
Type: UserActionInboundAdd,
|
||||
Uid: uid,
|
||||
Messages: msgs,
|
||||
})
|
||||
}
|
||||
|
||||
// AddMessage 添加消息到收件箱
|
||||
func (u *UserPlus) AddMessage(uid string, msg *UserMessage) {
|
||||
u.user.AddAction(UserAction{
|
||||
Type: UserActionInboundAdd,
|
||||
Uid: uid,
|
||||
Messages: []*UserMessage{
|
||||
msg,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (u *UserPlus) ConnWriteBytes(conn Conn, bytes []byte) {
|
||||
// AddMessageToOutbound 添加消息到发件箱
|
||||
func (u *UserPlus) AddMessageToOutbound(uid string, msg *UserMessage) {
|
||||
u.user.AddAction(UserAction{
|
||||
Type: UserActionOutboundAdd,
|
||||
Uid: uid,
|
||||
Messages: []*UserMessage{msg},
|
||||
})
|
||||
}
|
||||
|
||||
// ========================================== conn ==========================================
|
||||
|
||||
// ConnsByDeviceFlag 根据设备标识获取连接
|
||||
func (u *UserPlus) ConnsByDeviceFlag(uid string, deviceFlag wkproto.DeviceFlag) []*Conn {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UserPlus) ConnCountByDeviceFlag(uid string, deviceFlag wkproto.DeviceFlag) int {
|
||||
return 0
|
||||
}
|
||||
|
||||
// ConnsByUid 根据用户uid获取连接
|
||||
func (u *UserPlus) ConnsByUid(uid string) []*Conn {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UserPlus) ConnCountByUid(uid string) int {
|
||||
return 0
|
||||
}
|
||||
|
||||
// LocalConnById 获取本地连接
|
||||
func (u *UserPlus) LocalConnById(uid string, id int64) *Conn {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConnById 获取连接
|
||||
func (u *UserPlus) ConnById(uid string, fromNode uint64, id int64) *Conn {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseConn 关闭连接
|
||||
func (u *UserPlus) CloseConn(conn *Conn) {
|
||||
|
||||
fmt.Println("CloseConn---->", conn.Uid)
|
||||
u.user.AddAction(UserAction{
|
||||
Type: UserActionConnClose,
|
||||
Uid: conn.Uid,
|
||||
Conns: []*Conn{conn},
|
||||
})
|
||||
}
|
||||
|
||||
// ConnWrite 连接写包
|
||||
func (u *UserPlus) ConnWrite(conn *Conn, frame wkproto.Frame) {
|
||||
|
||||
data, err := Proto.EncodeFrame(frame, conn.ProtoVersion)
|
||||
if err != nil {
|
||||
u.Error("encode failed", zap.Error(err))
|
||||
return
|
||||
}
|
||||
u.ConnWriteBytes(conn, data)
|
||||
}
|
||||
|
||||
func (u *UserPlus) ConnWriteBytes(conn *Conn, bytes []byte) {
|
||||
u.user.AddAction(UserAction{
|
||||
Type: UserActionWrite,
|
||||
Uid: conn.Uid,
|
||||
Messages: []*UserMessage{
|
||||
{
|
||||
Conn: conn,
|
||||
WriteData: bytes,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// AllConnCount 所有连接数量
|
||||
|
||||
@@ -7,32 +7,32 @@ import (
|
||||
)
|
||||
|
||||
type conns struct {
|
||||
conns []reactor.Conn // 一个用户有多个连接
|
||||
conns []*reactor.Conn // 一个用户有多个连接
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
func (c *conns) add(cn reactor.Conn) {
|
||||
func (c *conns) add(cn *reactor.Conn) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
c.conns = append(c.conns, cn)
|
||||
}
|
||||
|
||||
func (c *conns) remove(cn reactor.Conn) {
|
||||
func (c *conns) remove(cn *reactor.Conn) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
for i, conn := range c.conns {
|
||||
if conn == cn {
|
||||
if conn.ConnId == cn.ConnId && conn.FromNode == cn.FromNode {
|
||||
c.conns = append(c.conns[:i], c.conns[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *conns) connByConnId(nodeId uint64, connId int64) reactor.Conn {
|
||||
func (c *conns) connByConnId(nodeId uint64, connId int64) *reactor.Conn {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
for _, conn := range c.conns {
|
||||
if conn.ConnId() == connId && conn.FromNode() == nodeId {
|
||||
if conn.ConnId == connId && conn.FromNode == nodeId {
|
||||
return conn
|
||||
}
|
||||
}
|
||||
@@ -43,18 +43,18 @@ func (c *conns) updateConnAuth(nodeId uint64, connId int64, auth bool) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
for i, conn := range c.conns {
|
||||
if conn.ConnId() == connId && conn.FromNode() == nodeId {
|
||||
c.conns[i].SetAuth(auth)
|
||||
if conn.ConnId == connId && conn.FromNode == nodeId {
|
||||
c.conns[i].Auth = auth
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *conns) updateConn(connId int64, nodeId uint64, newConn reactor.Conn) {
|
||||
func (c *conns) updateConn(connId int64, nodeId uint64, newConn *reactor.Conn) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
for i, conn := range c.conns {
|
||||
if conn.ConnId() == connId && conn.FromNode() == nodeId {
|
||||
if conn.ConnId == connId && conn.FromNode == nodeId {
|
||||
c.conns[i] = newConn
|
||||
return
|
||||
}
|
||||
@@ -67,18 +67,18 @@ func (c *conns) len() int {
|
||||
return len(c.conns)
|
||||
}
|
||||
|
||||
func (c *conns) allConns() []reactor.Conn {
|
||||
func (c *conns) allConns() []*reactor.Conn {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
return c.conns
|
||||
}
|
||||
|
||||
func (c *conns) connsByNodeId(nodeId uint64) []reactor.Conn {
|
||||
func (c *conns) connsByNodeId(nodeId uint64) []*reactor.Conn {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
conns := make([]reactor.Conn, 0, len(c.conns))
|
||||
conns := make([]*reactor.Conn, 0, len(c.conns))
|
||||
for _, conn := range c.conns {
|
||||
if conn.FromNode() == nodeId {
|
||||
if conn.FromNode == nodeId {
|
||||
conns = append(conns, conn)
|
||||
}
|
||||
}
|
||||
@@ -91,11 +91,11 @@ func (c *conns) nodeIds() []uint64 {
|
||||
nodeIds := make([]uint64, 0, len(c.conns))
|
||||
for _, conn := range c.conns {
|
||||
for _, nodeId := range nodeIds {
|
||||
if nodeId == conn.FromNode() {
|
||||
if nodeId == conn.FromNode {
|
||||
continue
|
||||
}
|
||||
}
|
||||
nodeIds = append(nodeIds, conn.FromNode())
|
||||
nodeIds = append(nodeIds, conn.FromNode)
|
||||
}
|
||||
return nodeIds
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
type msgQueue struct {
|
||||
messages []reactor.UserMessage
|
||||
messages []*reactor.UserMessage
|
||||
offsetMsgIndex uint64
|
||||
wklog.Log
|
||||
lastIndex uint64 // 最新下标
|
||||
@@ -21,7 +21,7 @@ func newMsgQueue(prefix string) *msgQueue {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *msgQueue) append(message reactor.UserMessage) {
|
||||
func (m *msgQueue) append(message *reactor.UserMessage) {
|
||||
m.messages = append(m.messages, message)
|
||||
m.lastIndex++
|
||||
}
|
||||
@@ -31,12 +31,12 @@ func (m *msgQueue) len() int {
|
||||
}
|
||||
|
||||
// [lo,hi)
|
||||
func (m *msgQueue) slice(startMsgIndex uint64, endMsgIndex uint64) []reactor.UserMessage {
|
||||
func (m *msgQueue) slice(startMsgIndex uint64, endMsgIndex uint64) []*reactor.UserMessage {
|
||||
|
||||
return m.messages[startMsgIndex-m.offsetMsgIndex : endMsgIndex-m.offsetMsgIndex : endMsgIndex-m.offsetMsgIndex]
|
||||
}
|
||||
|
||||
func (m *msgQueue) sliceWithSize(startMsgIndex uint64, endMsgIndex uint64, maxSize uint64) []reactor.UserMessage {
|
||||
func (m *msgQueue) sliceWithSize(startMsgIndex uint64, endMsgIndex uint64, maxSize uint64) []*reactor.UserMessage {
|
||||
if startMsgIndex == endMsgIndex {
|
||||
return nil
|
||||
}
|
||||
@@ -74,13 +74,13 @@ func (m *msgQueue) shrinkMessagesArray() {
|
||||
if len(m.messages) == 0 {
|
||||
m.messages = nil
|
||||
} else if len(m.messages)*lenMultiple < cap(m.messages) {
|
||||
newMessages := make([]reactor.UserMessage, len(m.messages))
|
||||
newMessages := make([]*reactor.UserMessage, len(m.messages))
|
||||
copy(newMessages, m.messages)
|
||||
m.messages = newMessages
|
||||
}
|
||||
}
|
||||
|
||||
func limitSize(messages []reactor.UserMessage, maxSize uint64) []reactor.UserMessage {
|
||||
func limitSize(messages []*reactor.UserMessage, maxSize uint64) []*reactor.UserMessage {
|
||||
if len(messages) == 0 {
|
||||
return messages
|
||||
}
|
||||
|
||||
@@ -3,13 +3,14 @@ package reactor
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/WuKongIM/WuKongIM/internal/reactor"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMsgQueue(t *testing.T) {
|
||||
q := newMsgQueue("queue")
|
||||
for i := 0; i < 100; i++ {
|
||||
q.append(&testMessage{})
|
||||
q.append(&reactor.UserMessage{})
|
||||
}
|
||||
|
||||
msgs := q.slice(1, 51)
|
||||
|
||||
@@ -6,12 +6,13 @@ import (
|
||||
|
||||
"github.com/WuKongIM/WuKongIM/internal/reactor"
|
||||
"github.com/WuKongIM/WuKongIM/pkg/wkutil"
|
||||
wkproto "github.com/WuKongIM/WuKongIMGoProto"
|
||||
)
|
||||
|
||||
type Reactor struct {
|
||||
subs []*reactorSub
|
||||
|
||||
mu sync.Mutex
|
||||
subs []*reactorSub
|
||||
proto wkproto.Protocol
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewReactor(opt ...Option) *Reactor {
|
||||
@@ -56,29 +57,30 @@ func (r *Reactor) WakeIfNeed(uid string) {
|
||||
sub.addUser(user)
|
||||
}
|
||||
|
||||
// func (r *Reactor) AddConn(c reactor.Conn) {
|
||||
// r.mu.Lock()
|
||||
// defer r.mu.Unlock()
|
||||
func (r *Reactor) CloseConn(c *reactor.Conn) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// sub := r.getSub(c.Uid())
|
||||
// user := sub.user(c.Uid())
|
||||
// if user == nil {
|
||||
// user = NewUser(wkutil.GenUUID(), c.Uid())
|
||||
// sub.addUser(user)
|
||||
// }
|
||||
|
||||
// sub.addAction(reactor.UserAction{
|
||||
// No: user.no,
|
||||
// Uid: c.Uid(),
|
||||
// Type: reactor.UserActionConnectAdd,
|
||||
// Conns: []reactor.Conn{c},
|
||||
// })
|
||||
// }
|
||||
sub := r.getSub(c.Uid)
|
||||
user := sub.user(c.Uid)
|
||||
if user == nil {
|
||||
return
|
||||
}
|
||||
user.conns.remove(c)
|
||||
}
|
||||
|
||||
func (r *Reactor) AddAction(a reactor.UserAction) bool {
|
||||
return r.getSub(a.Uid).addAction(a)
|
||||
}
|
||||
|
||||
func (r *Reactor) SetProto(proto wkproto.Protocol) {
|
||||
r.proto = proto
|
||||
}
|
||||
|
||||
func (r *Reactor) GetProto() wkproto.Protocol {
|
||||
return r.proto
|
||||
}
|
||||
|
||||
func (r *Reactor) getSub(uid string) *reactorSub {
|
||||
h := fnv.New32a()
|
||||
h.Write([]byte(uid))
|
||||
|
||||
@@ -127,6 +127,13 @@ func (r *reactorSub) handleEvent(u *User) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, action := range actions {
|
||||
switch action.Type {
|
||||
case reactor.UserActionUserClose:
|
||||
r.users.remove(u.uid)
|
||||
}
|
||||
}
|
||||
|
||||
r.r.send(actions)
|
||||
|
||||
return true
|
||||
@@ -138,7 +145,6 @@ func (r *reactorSub) handleReceivedActions() bool {
|
||||
if len(actions) == 0 {
|
||||
return false
|
||||
}
|
||||
fmt.Println("a.Uid--->", actions)
|
||||
for _, a := range actions {
|
||||
user := r.users.get(a.Uid)
|
||||
if user == nil {
|
||||
|
||||
@@ -22,16 +22,16 @@ func newReady(logPrefix string) *ready {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ready) slice() []reactor.UserMessage {
|
||||
func (r *ready) slice() []*reactor.UserMessage {
|
||||
r.endIndex = 0
|
||||
msgs := r.queue.sliceWithSize(r.offsetIndex+1, r.queue.lastIndex+1, 0)
|
||||
if len(msgs) > 0 {
|
||||
r.endIndex = msgs[len(msgs)-1].Index()
|
||||
r.endIndex = msgs[len(msgs)-1].Index
|
||||
}
|
||||
return msgs
|
||||
}
|
||||
|
||||
func (r *ready) sliceWith(startIndex, endIndex uint64) []reactor.UserMessage {
|
||||
func (r *ready) sliceWith(startIndex, endIndex uint64) []*reactor.UserMessage {
|
||||
if endIndex == 0 {
|
||||
endIndex = r.queue.lastIndex + 1
|
||||
}
|
||||
@@ -40,10 +40,10 @@ func (r *ready) sliceWith(startIndex, endIndex uint64) []reactor.UserMessage {
|
||||
|
||||
}
|
||||
|
||||
func (r *ready) sliceAndTruncate() []reactor.UserMessage {
|
||||
func (r *ready) sliceAndTruncate() []*reactor.UserMessage {
|
||||
msgs := r.queue.sliceWithSize(r.offsetIndex+1, r.queue.lastIndex+1, 0)
|
||||
if len(msgs) > 0 {
|
||||
r.endIndex = msgs[len(msgs)-1].Index()
|
||||
r.endIndex = msgs[len(msgs)-1].Index
|
||||
}
|
||||
r.truncate()
|
||||
return msgs
|
||||
@@ -57,14 +57,6 @@ func (r *ready) truncate() {
|
||||
r.offsetIndex = r.endIndex
|
||||
}
|
||||
|
||||
func (r *ready) truncateTo(index uint64) {
|
||||
r.queue.truncateTo(index)
|
||||
}
|
||||
|
||||
func (r *ready) resetState() {
|
||||
r.state.Reset()
|
||||
}
|
||||
|
||||
func (r *ready) reset() {
|
||||
r.state.Reset()
|
||||
r.queue.reset()
|
||||
@@ -84,57 +76,41 @@ func (r *ready) tick() {
|
||||
r.state.Tick()
|
||||
}
|
||||
|
||||
func (r *ready) startProcessing() {
|
||||
r.state.StartProcessing()
|
||||
}
|
||||
|
||||
func (r *ready) processSuccess() {
|
||||
r.state.ProcessSuccess()
|
||||
}
|
||||
|
||||
func (r *ready) processFail() {
|
||||
r.state.ProcessFail()
|
||||
}
|
||||
|
||||
func (r *ready) isMaxRetry() bool {
|
||||
return r.state.IsMaxRetry()
|
||||
}
|
||||
|
||||
func (r *ready) append(m reactor.UserMessage) {
|
||||
m.SetIndex(r.queue.lastIndex + 1)
|
||||
func (r *ready) append(m *reactor.UserMessage) {
|
||||
m.Index = r.queue.lastIndex + 1
|
||||
r.queue.append(m)
|
||||
}
|
||||
|
||||
type outboundReady struct {
|
||||
queue *msgQueue // 消息队列
|
||||
followers map[uint64]*followerState // 副本数据状态
|
||||
offsetIndex uint64 // 当前偏移的下标
|
||||
commitIndex uint64 // 已提交的下标
|
||||
queue *msgQueue // 消息队列
|
||||
replicas map[uint64]*replicaState // 副本数据状态
|
||||
offsetIndex uint64 // 当前偏移的下标
|
||||
commitIndex uint64 // 已提交的下标
|
||||
wklog.Log
|
||||
}
|
||||
|
||||
func newOutboundReady(logPrefix string) *outboundReady {
|
||||
|
||||
return &outboundReady{
|
||||
queue: newMsgQueue(logPrefix),
|
||||
followers: map[uint64]*followerState{},
|
||||
Log: wklog.NewWKLog(logPrefix),
|
||||
queue: newMsgQueue(logPrefix),
|
||||
replicas: map[uint64]*replicaState{},
|
||||
Log: wklog.NewWKLog(logPrefix),
|
||||
}
|
||||
}
|
||||
|
||||
func (o *outboundReady) append(m reactor.UserMessage) {
|
||||
m.SetIndex(o.queue.lastIndex + 1)
|
||||
func (o *outboundReady) append(m *reactor.UserMessage) {
|
||||
m.Index = o.queue.lastIndex + 1
|
||||
o.queue.append(m)
|
||||
}
|
||||
|
||||
func (o *outboundReady) has() bool {
|
||||
for _, follower := range o.followers {
|
||||
for _, replica := range o.replicas {
|
||||
// 需要等会再发起转发,别太快
|
||||
if follower.forwardIdleTick < options.OutboundForwardIntervalTick {
|
||||
if replica.forwardIdleTick < options.OutboundForwardIntervalTick {
|
||||
return false
|
||||
}
|
||||
// 是否符合转发条件
|
||||
if follower.outboundForwardedIndex >= o.commitIndex && follower.outboundForwardedIndex < o.queue.lastIndex {
|
||||
if replica.outboundForwardedIndex >= o.commitIndex && replica.outboundForwardedIndex < o.queue.lastIndex {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -144,65 +120,83 @@ func (o *outboundReady) has() bool {
|
||||
func (o *outboundReady) ready() []reactor.UserAction {
|
||||
var actions []reactor.UserAction
|
||||
var endIndex = o.queue.lastIndex
|
||||
for nodeId, follower := range o.followers {
|
||||
msgs := o.queue.sliceWithSize(o.offsetIndex+1, endIndex+1, 0)
|
||||
if len(msgs) == 0 {
|
||||
return nil
|
||||
}
|
||||
lastIndex := msgs[len(msgs)-1].Index
|
||||
o.queue.truncateTo(lastIndex + 1)
|
||||
|
||||
if follower.forwardIdleTick < options.OutboundForwardIntervalTick {
|
||||
hasToNode := false
|
||||
for _, msg := range msgs {
|
||||
// 如果ToNode有值,说明指定了接收的节点,这种情况需要判断是不是当前节点
|
||||
if msg.ToNode != 0 {
|
||||
actions = append(actions, reactor.UserAction{
|
||||
Type: reactor.UserActionOutboundForward,
|
||||
From: options.NodeId,
|
||||
To: msg.ToNode,
|
||||
Messages: []*reactor.UserMessage{
|
||||
msg,
|
||||
},
|
||||
})
|
||||
hasToNode = true
|
||||
continue
|
||||
}
|
||||
if follower.outboundForwardedIndex >= o.commitIndex && follower.outboundForwardedIndex < o.queue.lastIndex {
|
||||
if actions == nil {
|
||||
actions = make([]reactor.UserAction, 0, len(o.followers))
|
||||
}
|
||||
|
||||
var newMsgs []*reactor.UserMessage
|
||||
if hasToNode {
|
||||
for _, msg := range msgs {
|
||||
if msg.ToNode == 0 {
|
||||
newMsgs = append(newMsgs, msg)
|
||||
}
|
||||
// 不能超过指定数量
|
||||
if o.queue.lastIndex-follower.outboundForwardedIndex > options.OutboundForwardMaxMessageCount {
|
||||
endIndex = o.queue.lastIndex - follower.outboundForwardedIndex + options.OutboundForwardMaxMessageCount
|
||||
}
|
||||
msgs := o.queue.sliceWithSize(follower.outboundForwardedIndex+1, endIndex+1, 0)
|
||||
if len(msgs) == 0 {
|
||||
continue
|
||||
}
|
||||
actions = append(actions, reactor.UserAction{
|
||||
Type: reactor.UserActionOutboundForward,
|
||||
From: options.NodeId,
|
||||
To: nodeId,
|
||||
Messages: msgs,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
newMsgs = msgs
|
||||
}
|
||||
|
||||
for nodeId := range o.replicas {
|
||||
actions = append(actions, reactor.UserAction{
|
||||
Type: reactor.UserActionOutboundForward,
|
||||
From: options.NodeId,
|
||||
To: nodeId,
|
||||
Messages: newMsgs,
|
||||
})
|
||||
}
|
||||
return actions
|
||||
}
|
||||
|
||||
func (o *outboundReady) updateFollowerIndex(nodeId uint64, index uint64) {
|
||||
follower := o.followers[nodeId]
|
||||
if follower == nil {
|
||||
o.Info("follower not exist", zap.Uint64("nodeId", nodeId))
|
||||
func (o *outboundReady) updateReplicaIndex(nodeId uint64, index uint64) {
|
||||
replica := o.replicas[nodeId]
|
||||
if replica == nil {
|
||||
o.Info("replica not exist", zap.Uint64("nodeId", nodeId))
|
||||
return
|
||||
}
|
||||
if index > follower.outboundForwardedIndex {
|
||||
follower.outboundForwardedIndex = index
|
||||
if index > replica.outboundForwardedIndex {
|
||||
replica.outboundForwardedIndex = index
|
||||
o.checkCommit()
|
||||
}
|
||||
follower.forwardIdleTick = 0
|
||||
replica.forwardIdleTick = 0
|
||||
}
|
||||
|
||||
func (o *outboundReady) updateFollowerHeartbeat(nodeId uint64) {
|
||||
follower := o.followers[nodeId]
|
||||
if follower == nil {
|
||||
o.Info("follower not exist", zap.Uint64("nodeId", nodeId))
|
||||
func (o *outboundReady) updateReplicaHeartbeat(nodeId uint64) {
|
||||
replica := o.replicas[nodeId]
|
||||
if replica == nil {
|
||||
o.Info("replica not exist", zap.Uint64("nodeId", nodeId))
|
||||
return
|
||||
}
|
||||
follower.heartbeatIdleTick = 0
|
||||
replica.heartbeatIdleTick = 0
|
||||
}
|
||||
|
||||
func (o *outboundReady) addNewFollower(nodeId uint64) {
|
||||
o.followers[nodeId] = newFollowerState(nodeId, o.commitIndex)
|
||||
func (o *outboundReady) addNewReplica(nodeId uint64) {
|
||||
o.replicas[nodeId] = newReplicaState(nodeId, o.commitIndex)
|
||||
}
|
||||
|
||||
func (o *outboundReady) checkCommit() {
|
||||
var minIndex = o.commitIndex
|
||||
for _, follower := range o.followers {
|
||||
if follower.outboundForwardedIndex > minIndex {
|
||||
minIndex = follower.outboundForwardedIndex
|
||||
for _, replica := range o.replicas {
|
||||
if replica.outboundForwardedIndex > minIndex {
|
||||
minIndex = replica.outboundForwardedIndex
|
||||
}
|
||||
}
|
||||
if minIndex > o.commitIndex {
|
||||
@@ -216,13 +210,13 @@ func (o *outboundReady) commit(index uint64) {
|
||||
}
|
||||
|
||||
func (o *outboundReady) tick() {
|
||||
for _, follower := range o.followers {
|
||||
follower.heartbeatIdleTick++
|
||||
follower.forwardIdleTick++
|
||||
for _, replica := range o.replicas {
|
||||
replica.heartbeatIdleTick++
|
||||
replica.forwardIdleTick++
|
||||
|
||||
if follower.heartbeatIdleTick >= options.NodeHeartbeatTimeoutTick {
|
||||
o.Info("follower heartbeat timeout", zap.Uint64("nodeId", follower.nodeId))
|
||||
delete(o.followers, follower.nodeId)
|
||||
if replica.heartbeatIdleTick >= options.NodeHeartbeatTimeoutTick {
|
||||
o.Info("replica heartbeat timeout", zap.Uint64("nodeId", replica.nodeId))
|
||||
delete(o.replicas, replica.nodeId)
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -232,20 +226,20 @@ func (o *outboundReady) reset() {
|
||||
o.queue.reset()
|
||||
o.commitIndex = 0
|
||||
o.offsetIndex = 0
|
||||
o.followers = map[uint64]*followerState{}
|
||||
o.replicas = map[uint64]*replicaState{}
|
||||
|
||||
}
|
||||
|
||||
type followerState struct {
|
||||
type replicaState struct {
|
||||
nodeId uint64 // 节点id
|
||||
outboundForwardedIndex uint64 // 发件箱已转发索引
|
||||
forwardIdleTick int // 转发空闲tick
|
||||
heartbeatIdleTick int // 心跳空闲tick
|
||||
}
|
||||
|
||||
func newFollowerState(nodeId uint64, outboundForwardedIndex uint64) *followerState {
|
||||
func newReplicaState(nodeId uint64, outboundForwardedIndex uint64) *replicaState {
|
||||
|
||||
return &followerState{
|
||||
return &replicaState{
|
||||
nodeId: nodeId,
|
||||
outboundForwardedIndex: outboundForwardedIndex,
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package reactor
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/WuKongIM/WuKongIM/internal/reactor"
|
||||
wkproto "github.com/WuKongIM/WuKongIMGoProto"
|
||||
)
|
||||
@@ -11,6 +13,12 @@ type testConn struct {
|
||||
from uint64
|
||||
auth bool
|
||||
deviceFlag wkproto.DeviceFlag
|
||||
deviceId string
|
||||
valueLock sync.RWMutex
|
||||
valueMap map[string]string
|
||||
|
||||
protoVersion uint8
|
||||
deviceLevel wkproto.DeviceLevel
|
||||
}
|
||||
|
||||
func (t *testConn) ConnId() int64 {
|
||||
@@ -35,25 +43,50 @@ func (t *testConn) DeviceFlag() wkproto.DeviceFlag {
|
||||
return t.deviceFlag
|
||||
}
|
||||
|
||||
type testMessage struct {
|
||||
index uint64
|
||||
conn *testConn
|
||||
func (t *testConn) DeviceId() string {
|
||||
return t.deviceId
|
||||
}
|
||||
|
||||
func (t *testMessage) Conn() reactor.Conn {
|
||||
return t.conn
|
||||
func (t *testConn) Equal(conn *reactor.Conn) bool {
|
||||
return t.connId == conn.ConnId && t.uid == conn.Uid && t.from == conn.FromNode
|
||||
}
|
||||
func (t *testMessage) Frame() wkproto.Frame {
|
||||
|
||||
func (t *testConn) SetString(key string, value string) {
|
||||
t.valueLock.Lock()
|
||||
defer t.valueLock.Unlock()
|
||||
if t.valueMap == nil {
|
||||
t.valueMap = make(map[string]string)
|
||||
}
|
||||
t.valueMap[key] = value
|
||||
}
|
||||
|
||||
func (t *testConn) GetString(key string) string {
|
||||
t.valueLock.RLock()
|
||||
defer t.valueLock.RUnlock()
|
||||
return t.valueMap[key]
|
||||
}
|
||||
|
||||
func (t *testConn) SetProtoVersion(version uint8) {
|
||||
t.protoVersion = version
|
||||
}
|
||||
|
||||
func (t *testConn) GetProtoVersion() uint8 {
|
||||
return t.protoVersion
|
||||
}
|
||||
|
||||
func (t *testConn) DeviceLevel() wkproto.DeviceLevel {
|
||||
return t.deviceLevel
|
||||
}
|
||||
|
||||
func (t *testConn) SetDeviceLevel(level wkproto.DeviceLevel) {
|
||||
t.deviceLevel = level
|
||||
}
|
||||
|
||||
func (t *testConn) Encode() ([]byte, error) {
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (t *testConn) Decode(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
func (t *testMessage) Size() uint64 {
|
||||
return 10
|
||||
}
|
||||
|
||||
func (t *testMessage) SetIndex(index uint64) {
|
||||
t.index = index
|
||||
}
|
||||
|
||||
func (t *testMessage) Index() uint64 {
|
||||
return t.index
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ type User struct {
|
||||
no string // 唯一编号
|
||||
uid string
|
||||
conns *conns
|
||||
authReady *ready //认证
|
||||
role reactor.Role
|
||||
stepFnc func(a reactor.UserAction)
|
||||
tickFnc func()
|
||||
@@ -22,32 +21,40 @@ type User struct {
|
||||
cfg reactor.UserConfig
|
||||
joined bool // 是否已成功加入集群
|
||||
wklog.Log
|
||||
inbound *ready // 收件箱
|
||||
outbound *outboundReady // 发送箱
|
||||
// 节点收件箱
|
||||
// 来接收其他节点投递到当前节点的消息,此收件箱的数据会进入process
|
||||
inbound *ready
|
||||
// 节点发送箱
|
||||
// (节点直接投递数据用,如果当前节点是领导,转发给追随者,如果是追随者,转发给领导)
|
||||
outbound *outboundReady
|
||||
// 客户端发送箱
|
||||
// 将数据投递给自己直连的客户端
|
||||
clientOutbound *ready
|
||||
}
|
||||
|
||||
func NewUser(no, uid string) *User {
|
||||
|
||||
prefix := fmt.Sprintf("user[%s]", uid)
|
||||
return &User{
|
||||
no: no,
|
||||
uid: uid,
|
||||
authReady: newReady(prefix),
|
||||
conns: &conns{},
|
||||
Log: wklog.NewWKLog(prefix),
|
||||
inbound: newReady(prefix),
|
||||
outbound: newOutboundReady(prefix),
|
||||
no: no,
|
||||
uid: uid,
|
||||
conns: &conns{},
|
||||
Log: wklog.NewWKLog(prefix),
|
||||
inbound: newReady(prefix),
|
||||
outbound: newOutboundReady(prefix),
|
||||
clientOutbound: newReady(prefix),
|
||||
}
|
||||
}
|
||||
|
||||
// ==================================== ready ====================================
|
||||
|
||||
func (u *User) hasReady() bool {
|
||||
if u.needElection() {
|
||||
|
||||
if len(u.actions) > 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
if u.authReady.has() {
|
||||
if u.needElection() {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -59,7 +66,7 @@ func (u *User) hasReady() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
if len(u.actions) > 0 {
|
||||
if u.clientOutbound.has() {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -67,28 +74,18 @@ func (u *User) hasReady() bool {
|
||||
}
|
||||
|
||||
func (u *User) ready() []reactor.UserAction {
|
||||
if u.electioned() {
|
||||
// ---------- auth ----------
|
||||
if u.authReady.has() {
|
||||
msgs := u.authReady.sliceAndTruncate()
|
||||
u.actions = append(u.actions, reactor.UserAction{
|
||||
No: u.no,
|
||||
From: options.NodeId,
|
||||
To: reactor.LocalNode,
|
||||
Type: reactor.UserActionAuth,
|
||||
Messages: msgs,
|
||||
})
|
||||
}
|
||||
|
||||
if u.allowReady() {
|
||||
// ---------- inbound ----------
|
||||
if u.inbound.has() {
|
||||
msgs := u.inbound.sliceAndTruncate()
|
||||
u.actions = append(u.actions, reactor.UserAction{
|
||||
No: u.no,
|
||||
Uid: u.uid,
|
||||
From: options.NodeId,
|
||||
To: reactor.LocalNode,
|
||||
Type: reactor.UserActionInbound,
|
||||
Messages: msgs,
|
||||
Role: u.role,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -99,38 +96,59 @@ func (u *User) ready() []reactor.UserAction {
|
||||
u.actions = append(u.actions, actions...)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- clientOutbound ----------
|
||||
if u.clientOutbound.has() {
|
||||
msgs := u.clientOutbound.sliceAndTruncate()
|
||||
u.actions = append(u.actions, reactor.UserAction{
|
||||
No: u.no,
|
||||
Uid: u.uid,
|
||||
From: options.NodeId,
|
||||
To: reactor.LocalNode,
|
||||
Type: reactor.UserActionWrite,
|
||||
Messages: msgs,
|
||||
})
|
||||
}
|
||||
}
|
||||
actions := u.actions
|
||||
u.actions = u.actions[:0]
|
||||
return actions
|
||||
}
|
||||
|
||||
func (u *User) allowReady() bool {
|
||||
|
||||
if u.electioned() {
|
||||
if u.role == reactor.RoleLeader {
|
||||
return true
|
||||
}
|
||||
if u.role == reactor.RoleFollower && u.joined {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ==================================== step ====================================
|
||||
|
||||
func (u *User) step(action reactor.UserAction) {
|
||||
fmt.Println("step----->", action.Uid, action.Type.String())
|
||||
u.idleTick = 0
|
||||
switch action.Type {
|
||||
case reactor.UserActionConfigUpdate:
|
||||
fmt.Println("UserActionConfigUpdate....")
|
||||
u.handleConfigUpdate(action.Cfg)
|
||||
case reactor.UserActionAuthAdd:
|
||||
for _, msg := range action.Messages {
|
||||
if msg.Conn() == nil {
|
||||
if msg.Conn == nil {
|
||||
u.Warn("add auth failed, msg conn not exist", zap.String("uid", action.Uid))
|
||||
return
|
||||
}
|
||||
conn := u.conns.connByConnId(msg.Conn().FromNode(), msg.Conn().ConnId())
|
||||
conn := u.conns.connByConnId(msg.Conn.FromNode, msg.Conn.ConnId)
|
||||
if conn != nil {
|
||||
u.Warn("add auth failed, conn exist", zap.String("uid", action.Uid))
|
||||
return
|
||||
}
|
||||
u.conns.add(msg.Conn())
|
||||
u.authReady.append(msg)
|
||||
}
|
||||
case reactor.UserActionAuthResp: // 认证处理返回
|
||||
for _, conn := range action.Conns {
|
||||
conn.SetAuth(action.Success)
|
||||
u.conns.updateConn(conn.ConnId(), conn.FromNode(), conn)
|
||||
u.conns.add(msg.Conn)
|
||||
u.inbound.append(msg)
|
||||
}
|
||||
case reactor.UserActionInboundAdd: // 收件箱
|
||||
for _, msg := range action.Messages {
|
||||
@@ -141,7 +159,17 @@ func (u *User) step(action reactor.UserAction) {
|
||||
u.outbound.append(msg)
|
||||
}
|
||||
case reactor.UserActionOutboundForwardResp: // 转发返回
|
||||
u.outbound.updateFollowerIndex(action.From, action.Index)
|
||||
u.outbound.updateReplicaIndex(action.From, action.Index)
|
||||
case reactor.UserActionConnClose: // 关闭连接
|
||||
for _, conn := range u.conns.conns {
|
||||
u.conns.remove(conn)
|
||||
}
|
||||
u.sendConnClose(action.Conns)
|
||||
case reactor.UserActionWrite: // 写数据
|
||||
for _, msg := range action.Messages {
|
||||
u.clientOutbound.append(msg)
|
||||
}
|
||||
|
||||
default:
|
||||
if u.stepFnc != nil {
|
||||
u.stepFnc(action)
|
||||
@@ -155,21 +183,21 @@ func (u *User) stepFollower(action reactor.UserAction) {
|
||||
case reactor.UserActionNodeHeartbeatReq:
|
||||
u.heartbeatTick = 0
|
||||
localConns := u.conns.connsByNodeId(options.NodeId)
|
||||
var closeConns []reactor.Conn
|
||||
var closeConns []*reactor.Conn
|
||||
for _, localConn := range localConns {
|
||||
if !localConn.IsAuth() {
|
||||
if !localConn.Auth {
|
||||
continue
|
||||
}
|
||||
exist := false
|
||||
for _, conn := range action.Conns {
|
||||
if localConn.ConnId() == conn.ConnId() {
|
||||
if localConn.ConnId == conn.ConnId {
|
||||
exist = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !exist {
|
||||
if closeConns == nil {
|
||||
closeConns = make([]reactor.Conn, 0, len(localConns))
|
||||
closeConns = make([]*reactor.Conn, 0, len(localConns))
|
||||
}
|
||||
closeConns = append(closeConns, localConn)
|
||||
}
|
||||
@@ -188,10 +216,10 @@ func (u *User) stepLeader(action reactor.UserAction) {
|
||||
switch action.Type {
|
||||
// 副本收到心跳回应
|
||||
case reactor.UserActionNodeHeartbeatResp:
|
||||
u.outbound.updateFollowerHeartbeat(action.From)
|
||||
u.outbound.updateReplicaHeartbeat(action.From)
|
||||
// 副本请求加入
|
||||
case reactor.UserActionJoin:
|
||||
u.outbound.addNewFollower(action.From)
|
||||
u.outbound.addNewReplica(action.From)
|
||||
u.sendJoinResp(action.From)
|
||||
}
|
||||
}
|
||||
@@ -200,7 +228,6 @@ func (u *User) stepLeader(action reactor.UserAction) {
|
||||
|
||||
func (u *User) tick() {
|
||||
|
||||
u.authReady.tick()
|
||||
u.inbound.tick()
|
||||
u.outbound.tick()
|
||||
|
||||
@@ -217,7 +244,7 @@ func (u *User) tickLeader() {
|
||||
u.heartbeatTick++
|
||||
u.idleTick++
|
||||
|
||||
if len(u.outbound.followers) == 0 && u.idleTick >= options.LeaderIdleTimeoutTick {
|
||||
if u.conns.len() == 0 && len(u.outbound.replicas) == 0 && u.idleTick >= options.LeaderIdleTimeoutTick {
|
||||
u.idleTick = 0
|
||||
u.sendUserClose()
|
||||
return
|
||||
@@ -243,12 +270,12 @@ func (u *User) tickFollower() {
|
||||
// ==================================== send ====================================
|
||||
|
||||
func (u *User) sendHeartbeatReq() {
|
||||
for _, follower := range u.outbound.followers {
|
||||
conns := u.conns.connsByNodeId(follower.nodeId)
|
||||
for _, replica := range u.outbound.replicas {
|
||||
conns := u.conns.connsByNodeId(replica.nodeId)
|
||||
u.actions = append(u.actions, reactor.UserAction{
|
||||
No: u.no,
|
||||
From: options.NodeId,
|
||||
To: follower.nodeId,
|
||||
To: replica.nodeId,
|
||||
Uid: u.uid,
|
||||
Type: reactor.UserActionNodeHeartbeatReq,
|
||||
Conns: conns,
|
||||
@@ -269,7 +296,7 @@ func (u *User) sendHeartbeatResp(to uint64) {
|
||||
})
|
||||
}
|
||||
|
||||
func (u *User) sendConnClose(conns []reactor.Conn) {
|
||||
func (u *User) sendConnClose(conns []*reactor.Conn) {
|
||||
u.actions = append(u.actions, reactor.UserAction{
|
||||
No: u.no,
|
||||
From: options.NodeId,
|
||||
@@ -331,6 +358,9 @@ func (u *User) becomeFollower() {
|
||||
u.stepFnc = u.stepFollower
|
||||
u.tickFnc = u.tickFollower
|
||||
|
||||
// 如果是追随者,需要添加领导到副本列表
|
||||
u.outbound.addNewReplica(u.cfg.LeaderId)
|
||||
|
||||
u.Info("become follower")
|
||||
}
|
||||
|
||||
@@ -371,7 +401,7 @@ func (u *User) reset() {
|
||||
u.cfg = reactor.UserConfig{}
|
||||
u.inbound.reset()
|
||||
u.outbound.reset()
|
||||
u.authReady.reset()
|
||||
u.clientOutbound.reset()
|
||||
u.conns.reset()
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/WuKongIM/WuKongIM/internal/reactor"
|
||||
wkproto "github.com/WuKongIM/WuKongIMGoProto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -18,6 +19,15 @@ func hasAction(t testing.TB, actionType reactor.UserActionType, actions []reacto
|
||||
assert.True(t, exist)
|
||||
}
|
||||
|
||||
func getAction(t testing.TB, actionType reactor.UserActionType, actions []reactor.UserAction) reactor.UserAction {
|
||||
for _, action := range actions {
|
||||
if action.Type == actionType {
|
||||
return action
|
||||
}
|
||||
}
|
||||
return reactor.UserAction{}
|
||||
}
|
||||
|
||||
func TestUser(t *testing.T) {
|
||||
|
||||
options = NewOptions()
|
||||
@@ -49,6 +59,7 @@ func TestUser(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("testElection", func(t *testing.T) {
|
||||
u.tick()
|
||||
actions := u.ready()
|
||||
hasAction(t, reactor.UserActionElection, actions)
|
||||
})
|
||||
@@ -71,30 +82,27 @@ func TestUser(t *testing.T) {
|
||||
becomeLeader()
|
||||
u.step(reactor.UserAction{
|
||||
Type: reactor.UserActionAuthAdd,
|
||||
Messages: []reactor.UserMessage{
|
||||
&testMessage{
|
||||
conn: &testConn{
|
||||
from: 1,
|
||||
connId: 1,
|
||||
Messages: []*reactor.UserMessage{
|
||||
&reactor.UserMessage{
|
||||
Conn: &reactor.Conn{
|
||||
FromNode: 1,
|
||||
ConnId: 1,
|
||||
},
|
||||
Frame: wkproto.ConnectPacket{},
|
||||
},
|
||||
},
|
||||
})
|
||||
actions := u.ready()
|
||||
hasAction(t, reactor.UserActionAuth, actions)
|
||||
hasAction(t, reactor.UserActionInbound, actions)
|
||||
|
||||
action := getAction(t, reactor.UserActionInbound, actions)
|
||||
|
||||
assert.Equal(t, 1, len(action.Messages))
|
||||
|
||||
u.conns.updateConnAuth(1, 1, true)
|
||||
|
||||
u.step(reactor.UserAction{
|
||||
Type: reactor.UserActionAuthResp,
|
||||
Conns: []reactor.Conn{
|
||||
&testConn{
|
||||
from: 1,
|
||||
connId: 1,
|
||||
},
|
||||
},
|
||||
Success: true,
|
||||
})
|
||||
conn := u.conns.connByConnId(1, 1)
|
||||
assert.Equal(t, true, conn.IsAuth())
|
||||
assert.Equal(t, true, conn.Auth)
|
||||
|
||||
// only test
|
||||
u.conns.conns = nil
|
||||
@@ -105,8 +113,8 @@ func TestUser(t *testing.T) {
|
||||
|
||||
u.step(reactor.UserAction{
|
||||
Type: reactor.UserActionInboundAdd,
|
||||
Messages: []reactor.UserMessage{
|
||||
&testMessage{},
|
||||
Messages: []*reactor.UserMessage{
|
||||
&reactor.UserMessage{},
|
||||
},
|
||||
})
|
||||
actions := u.ready()
|
||||
@@ -126,9 +134,9 @@ func TestUser(t *testing.T) {
|
||||
|
||||
u.step(reactor.UserAction{
|
||||
Type: reactor.UserActionOutboundAdd,
|
||||
Messages: []reactor.UserMessage{
|
||||
&testMessage{},
|
||||
&testMessage{},
|
||||
Messages: []*reactor.UserMessage{
|
||||
&reactor.UserMessage{},
|
||||
&reactor.UserMessage{},
|
||||
},
|
||||
})
|
||||
|
||||
@@ -160,11 +168,11 @@ func TestUser(t *testing.T) {
|
||||
becomeReplica()
|
||||
u.step(reactor.UserAction{
|
||||
Type: reactor.UserActionAuthAdd,
|
||||
Messages: []reactor.UserMessage{
|
||||
&testMessage{
|
||||
conn: &testConn{
|
||||
connId: 1,
|
||||
from: 1,
|
||||
Messages: []*reactor.UserMessage{
|
||||
&reactor.UserMessage{
|
||||
Conn: &reactor.Conn{
|
||||
ConnId: 1,
|
||||
FromNode: 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -177,16 +185,7 @@ func TestUser(t *testing.T) {
|
||||
})
|
||||
assert.Equal(t, 1, u.conns.len())
|
||||
|
||||
u.step(reactor.UserAction{
|
||||
Type: reactor.UserActionAuthResp,
|
||||
Success: true,
|
||||
Conns: []reactor.Conn{
|
||||
&testConn{
|
||||
from: 1,
|
||||
connId: 1,
|
||||
},
|
||||
},
|
||||
})
|
||||
u.conns.updateConnAuth(1, 1, true)
|
||||
|
||||
u.step(reactor.UserAction{
|
||||
Type: reactor.UserActionNodeHeartbeatReq,
|
||||
@@ -224,25 +223,17 @@ func TestUserRoleChange(t *testing.T) {
|
||||
return 1
|
||||
}
|
||||
u := NewUser("no", "uid")
|
||||
u.step(reactor.UserAction{
|
||||
Type: reactor.UserActionAuth,
|
||||
Messages: []reactor.UserMessage{
|
||||
&testMessage{
|
||||
conn: &testConn{},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
u.step(reactor.UserAction{
|
||||
Type: reactor.UserActionInboundAdd,
|
||||
Messages: []reactor.UserMessage{
|
||||
&testMessage{},
|
||||
Messages: []*reactor.UserMessage{
|
||||
&reactor.UserMessage{},
|
||||
},
|
||||
})
|
||||
u.step(reactor.UserAction{
|
||||
Type: reactor.UserActionOutboundAdd,
|
||||
Messages: []reactor.UserMessage{
|
||||
&testMessage{},
|
||||
Messages: []*reactor.UserMessage{
|
||||
&reactor.UserMessage{},
|
||||
},
|
||||
})
|
||||
|
||||
@@ -266,13 +257,11 @@ func TestUserRoleChange(t *testing.T) {
|
||||
|
||||
becomeFollower()
|
||||
|
||||
assert.Equal(t, 1, u.authReady.queue.len())
|
||||
assert.Equal(t, 1, u.inbound.queue.len())
|
||||
assert.Equal(t, 1, u.outbound.queue.len())
|
||||
|
||||
becomeLeader()
|
||||
|
||||
assert.Equal(t, 0, u.authReady.queue.len())
|
||||
assert.Equal(t, 0, u.inbound.queue.len())
|
||||
assert.Equal(t, 0, u.outbound.queue.len())
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -240,8 +240,8 @@ func (u *UserAPI) getOnlineConns(uids []string) []*OnlinestatusResp {
|
||||
conns := reactor.User.ConnsByUid(uid)
|
||||
for _, conn := range conns {
|
||||
onlineStatusResps = append(onlineStatusResps, &OnlinestatusResp{
|
||||
UID: conn.Uid(),
|
||||
DeviceFlag: conn.DeviceFlag().ToUint8(),
|
||||
UID: conn.Uid,
|
||||
DeviceFlag: conn.DeviceFlag.ToUint8(),
|
||||
Online: 1,
|
||||
})
|
||||
}
|
||||
@@ -368,7 +368,7 @@ func (u *UserAPI) updateToken(c *wkhttp.Context) {
|
||||
oldConns := reactor.User.ConnsByDeviceFlag(req.UID, req.DeviceFlag)
|
||||
if len(oldConns) > 0 {
|
||||
for _, oldConn := range oldConns {
|
||||
u.Debug("更新Token时,存在旧连接!", zap.String("uid", req.UID), zap.Int64("id", oldConn.ConnId()), zap.String("deviceFlag", req.DeviceFlag.String()))
|
||||
u.Debug("更新Token时,存在旧连接!", zap.String("uid", req.UID), zap.Int64("id", oldConn.ConnId), zap.String("deviceFlag", req.DeviceFlag.String()))
|
||||
// 踢旧连接
|
||||
reactor.User.Kick(oldConn, wkproto.ReasonConnectKick, "账号在其他设备上登录")
|
||||
u.s.timingWheel.AfterFunc(time.Second*10, func() {
|
||||
|
||||
@@ -37,7 +37,7 @@ func (v *VarzAPI) Route(r *wkhttp.WKHttp) {
|
||||
|
||||
func (v *VarzAPI) HandleVarz(c *wkhttp.Context) {
|
||||
|
||||
show := c.Query("show")
|
||||
// show := c.Query("show")
|
||||
connLimit, _ := strconv.Atoi(c.Query("conn_limit"))
|
||||
nodeId := wkutil.ParseUint64(c.Query("node_id"))
|
||||
|
||||
@@ -61,17 +61,17 @@ func (v *VarzAPI) HandleVarz(c *wkhttp.Context) {
|
||||
}
|
||||
varz := CreateVarz(v.s)
|
||||
|
||||
if show == "conn" {
|
||||
resultConns := v.s.GetConnInfos("", ByInMsgDesc, 0, connLimit)
|
||||
connInfos := make([]*ConnInfo, 0, len(resultConns))
|
||||
for _, resultConn := range resultConns {
|
||||
if resultConn == nil || !resultConn.isAuth.Load() {
|
||||
continue
|
||||
}
|
||||
connInfos = append(connInfos, newConnInfo(resultConn))
|
||||
}
|
||||
varz.Conns = connInfos
|
||||
}
|
||||
// if show == "conn" {
|
||||
// resultConns := v.s.GetConnInfos("", ByInMsgDesc, 0, connLimit)
|
||||
// connInfos := make([]*ConnInfo, 0, len(resultConns))
|
||||
// for _, resultConn := range resultConns {
|
||||
// if resultConn == nil || !resultConn.isAuth.Load() {
|
||||
// continue
|
||||
// }
|
||||
// connInfos = append(connInfos, newConnInfo(resultConn))
|
||||
// }
|
||||
// varz.Conns = connInfos
|
||||
// }
|
||||
|
||||
c.JSON(http.StatusOK, varz)
|
||||
}
|
||||
@@ -166,10 +166,10 @@ type Varz struct {
|
||||
TreeState string `json:"tree_state"` // git tree state
|
||||
APIURL string `json:"api_url"` // api地址
|
||||
|
||||
ManagerUID string `json:"manager_uid"` // 管理员uid
|
||||
ManagerTokenOn int `json:"manager_token_on"` // 管理员token是否开启
|
||||
Conns []*ConnInfo `json:"conns,omitempty"` // 连接信息
|
||||
ConversationCacheCount int `json:"conversation_cache_count"` // 最近会话缓存数量
|
||||
ManagerUID string `json:"manager_uid"` // 管理员uid
|
||||
ManagerTokenOn int `json:"manager_token_on"` // 管理员token是否开启
|
||||
// Conns []*ConnInfo `json:"conns,omitempty"` // 连接信息
|
||||
ConversationCacheCount int `json:"conversation_cache_count"` // 最近会话缓存数量
|
||||
}
|
||||
|
||||
type SystemSetting struct {
|
||||
|
||||
@@ -195,7 +195,7 @@ func (r *channelReactor) handlePayloadDecrypt(req *payloadDecryptReq) {
|
||||
var decryptPayload []byte
|
||||
conn := reactor.User.LocalConnById(msg.FromUid, msg.FromConnId)
|
||||
if conn != nil && len(msg.SendPacket.Payload) > 0 {
|
||||
decryptPayload, err = r.s.checkAndDecodePayload(msg.SendPacket, conn.(*connContext))
|
||||
decryptPayload, err = r.s.checkAndDecodePayload(msg.SendPacket, conn)
|
||||
if err != nil {
|
||||
msg.ReasonCode = wkproto.ReasonPayloadDecodeError
|
||||
r.Warn("decrypt payload error", zap.String("uid", msg.FromUid), zap.String("deviceId", msg.FromDeviceId), zap.Int64("connId", msg.FromConnId), zap.Error(err))
|
||||
|
||||
@@ -1,115 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/WuKongIM/WuKongIM/pkg/wklog"
|
||||
"github.com/WuKongIM/WuKongIM/pkg/wknet"
|
||||
wkproto "github.com/WuKongIM/WuKongIMGoProto"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
type connStats struct {
|
||||
inPacketCount atomic.Int64 // 输入包数量
|
||||
outPacketCount atomic.Int64 // 输出包数量
|
||||
|
||||
inPacketByteCount atomic.Int64 // 输入包字节数量
|
||||
outPacketByteCount atomic.Int64 // 输出包字节数量
|
||||
|
||||
inMsgCount atomic.Int64 // 输入消息数量
|
||||
outMsgCount atomic.Int64 // 输出消息数量
|
||||
|
||||
inMsgByteCount atomic.Int64 // 输入消息字节数量
|
||||
outMsgByteCount atomic.Int64 // 输出消息字节数量
|
||||
}
|
||||
|
||||
type connInfo struct {
|
||||
connId int64 // 连接在本节点的id
|
||||
proxyConnId int64 // 连接在代理节点的id
|
||||
uid string
|
||||
deviceId string
|
||||
deviceFlag wkproto.DeviceFlag
|
||||
deviceLevel wkproto.DeviceLevel
|
||||
aesKey []byte
|
||||
aesIV []byte
|
||||
protoVersion uint8
|
||||
|
||||
closed atomic.Bool
|
||||
|
||||
isAuth atomic.Bool // 是否已经认证
|
||||
|
||||
}
|
||||
|
||||
type connContext struct {
|
||||
connInfo // 连接信息
|
||||
connStats // 统计信息
|
||||
conn wknet.Conn
|
||||
|
||||
realNodeId uint64 // 真实的节点id
|
||||
isRealConn bool // 是否是真实的连接
|
||||
|
||||
uptime atomic.Time // 启动时间
|
||||
|
||||
lastActivity atomic.Int64 // 最后活动时间
|
||||
|
||||
wklog.Log
|
||||
}
|
||||
|
||||
func newConnContext(connInfo connInfo, conn wknet.Conn) *connContext {
|
||||
c := &connContext{
|
||||
connInfo: connInfo,
|
||||
conn: conn,
|
||||
isRealConn: true,
|
||||
Log: wklog.NewWKLog(fmt.Sprintf("connContext[%s]", connInfo.uid)),
|
||||
}
|
||||
|
||||
c.uptime.Store(time.Now())
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func newConnContextProxy(realNodeId uint64, connInfo connInfo) *connContext {
|
||||
return &connContext{
|
||||
connInfo: connInfo,
|
||||
realNodeId: realNodeId,
|
||||
isRealConn: false,
|
||||
Log: wklog.NewWKLog(fmt.Sprintf("connContext[%s]", connInfo.uid)),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connContext) keepActivity() {
|
||||
c.lastActivity.Store(time.Now().Unix())
|
||||
}
|
||||
|
||||
func (c *connContext) isClosed() bool {
|
||||
return c.closed.Load()
|
||||
}
|
||||
|
||||
func (c *connContext) String() string {
|
||||
fd := 0
|
||||
if c.conn != nil {
|
||||
fd = c.conn.Fd().Fd()
|
||||
}
|
||||
return fmt.Sprintf("uid: %s connId: %d deviceId: %s deviceFlag: %s isRealConn: %v proxyConnId: %d realNodeId: %d fd: %d", c.uid, c.connId, c.deviceId, c.deviceFlag.String(), c.isRealConn, c.proxyConnId, c.realNodeId, fd)
|
||||
}
|
||||
|
||||
func (c *connContext) ConnId() int64 {
|
||||
return c.connId
|
||||
}
|
||||
func (c *connContext) Uid() string {
|
||||
return c.uid
|
||||
}
|
||||
func (c *connContext) FromNode() uint64 {
|
||||
return c.realNodeId
|
||||
}
|
||||
func (c *connContext) SetAuth(auth bool) {
|
||||
c.isAuth.Store(auth)
|
||||
}
|
||||
func (c *connContext) IsAuth() bool {
|
||||
return c.isAuth.Load()
|
||||
}
|
||||
|
||||
func (c *connContext) DeviceFlag() wkproto.DeviceFlag {
|
||||
return c.deviceFlag
|
||||
}
|
||||
41
internal/server/conn_manager.go
Normal file
41
internal/server/conn_manager.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/WuKongIM/WuKongIM/pkg/wknet"
|
||||
)
|
||||
|
||||
type connManager struct {
|
||||
connBlucket []map[int64]wknet.Conn
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
func newConnManager(blucketCount int) *connManager {
|
||||
connBlucket := make([]map[int64]wknet.Conn, blucketCount)
|
||||
for i := 0; i < blucketCount; i++ {
|
||||
connBlucket[i] = make(map[int64]wknet.Conn)
|
||||
}
|
||||
return &connManager{connBlucket: connBlucket}
|
||||
}
|
||||
|
||||
func (m *connManager) addConn(conn wknet.Conn) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
blucketIndex := conn.ID() % int64(len(m.connBlucket))
|
||||
m.connBlucket[blucketIndex][conn.ID()] = conn
|
||||
}
|
||||
|
||||
func (m *connManager) removeConn(conn wknet.Conn) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
blucketIndex := conn.ID() % int64(len(m.connBlucket))
|
||||
delete(m.connBlucket[blucketIndex], conn.ID())
|
||||
}
|
||||
|
||||
func (m *connManager) getConn(connID int64) wknet.Conn {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
blucketIndex := connID % int64(len(m.connBlucket))
|
||||
return m.connBlucket[blucketIndex][connID]
|
||||
}
|
||||
@@ -343,9 +343,9 @@ func (d *deliverr) getPersonTag(fakeChannelId string) (*tag, error) {
|
||||
}
|
||||
|
||||
type deliverUserSlice struct {
|
||||
offlineUids []string // 离线用户(只要主设备不在线就算离线)
|
||||
toConns []*connContext // 在线接受用户的连接对象
|
||||
onlineUsers []string // 在线用户数量(只要一个客户端在线就算在线)
|
||||
offlineUids []string // 离线用户(只要主设备不在线就算离线)
|
||||
toConns []*reactor.Conn // 在线接受用户的连接对象
|
||||
onlineUsers []string // 在线用户数量(只要一个客户端在线就算在线)
|
||||
}
|
||||
|
||||
func (d *deliverUserSlice) reset() {
|
||||
@@ -358,7 +358,7 @@ var deliverSlicePool = &sync.Pool{
|
||||
New: func() any {
|
||||
return &deliverUserSlice{
|
||||
offlineUids: make([]string, 0),
|
||||
toConns: make([]*connContext, 0),
|
||||
toConns: make([]*reactor.Conn, 0),
|
||||
onlineUsers: make([]string, 0),
|
||||
}
|
||||
},
|
||||
@@ -387,12 +387,9 @@ func (d *deliverr) deliver(req *deliverReq, uids []string) {
|
||||
}
|
||||
// 获取当前用户的所有连接
|
||||
conns := reactor.User.ConnsByUid(toUid)
|
||||
connCtxs := make([]*connContext, 0, len(conns))
|
||||
hasMasterDevice := false
|
||||
for _, conn := range conns {
|
||||
connCtx := conn.(*connContext)
|
||||
connCtxs = append(connCtxs, connCtx)
|
||||
if connCtx.deviceLevel == wkproto.DeviceLevelMaster {
|
||||
if conn.DeviceLevel == wkproto.DeviceLevelMaster {
|
||||
hasMasterDevice = true
|
||||
}
|
||||
}
|
||||
@@ -405,8 +402,8 @@ func (d *deliverr) deliver(req *deliverReq, uids []string) {
|
||||
if len(conns) == 0 {
|
||||
slices.offlineUids = append(slices.offlineUids, toUid)
|
||||
} else {
|
||||
for _, conn := range connCtxs {
|
||||
if !conn.IsAuth() {
|
||||
for _, conn := range conns {
|
||||
if !conn.Auth {
|
||||
continue
|
||||
}
|
||||
slices.toConns = append(slices.toConns, conn)
|
||||
@@ -422,10 +419,10 @@ func (d *deliverr) deliver(req *deliverReq, uids []string) {
|
||||
existSendSelfDevice := false // 存在发送者自己的连接
|
||||
existSendSelfNotSendDevice := false // 存在发送者自己但是不是发送设备
|
||||
for _, toConn := range slices.toConns {
|
||||
if toConn.uid == msg.FromUid && toConn.deviceId == msg.FromDeviceId {
|
||||
if toConn.Uid == msg.FromUid && toConn.DeviceId == msg.FromDeviceId {
|
||||
existSendSelfDevice = true
|
||||
}
|
||||
if toConn.uid == msg.FromUid && toConn.deviceId != msg.FromDeviceId {
|
||||
if toConn.Uid == msg.FromUid && toConn.DeviceId != msg.FromDeviceId {
|
||||
existSendSelfNotSendDevice = true
|
||||
}
|
||||
}
|
||||
@@ -499,33 +496,44 @@ func (d *deliverr) deliver(req *deliverReq, uids []string) {
|
||||
}
|
||||
|
||||
for _, conn := range slices.toConns {
|
||||
if conn.uid == message.FromUid && conn.deviceId == message.FromDeviceId { // 自己发的不处理
|
||||
if conn.Uid == message.FromUid && conn.DeviceId == message.FromDeviceId { // 自己发的不处理
|
||||
continue
|
||||
}
|
||||
|
||||
// 这里需要把channelID改成fromUID 比如A给B发消息,B收到的消息channelID应该是A A收到的消息channelID应该是B
|
||||
recvPacket.ChannelID = sendPacket.ChannelID
|
||||
if recvPacket.ChannelType == wkproto.ChannelTypePerson && recvPacket.ChannelID == conn.uid {
|
||||
if recvPacket.ChannelType == wkproto.ChannelTypePerson &&
|
||||
recvPacket.ChannelID == conn.Uid {
|
||||
recvPacket.ChannelID = recvPacket.FromUID
|
||||
}
|
||||
|
||||
// 红点设置
|
||||
recvPacket.RedDot = sendPacket.RedDot
|
||||
if conn.uid == recvPacket.FromUID { // 如果是自己则不显示红点
|
||||
if conn.Uid == recvPacket.FromUID { // 如果是自己则不显示红点
|
||||
recvPacket.RedDot = false
|
||||
}
|
||||
|
||||
// payload内容加密
|
||||
payloadBuffer.Reset()
|
||||
|
||||
if len(conn.aesIV) == 0 || len(conn.aesKey) == 0 {
|
||||
d.Error("aesIV or aesKey is empty", zap.String("uid", conn.uid), zap.String("deviceId", conn.deviceId), zap.String("channelId", recvPacket.ChannelID), zap.Uint8("channelType", recvPacket.ChannelType))
|
||||
if len(conn.AesIV) == 0 || len(conn.AesKey) == 0 {
|
||||
d.Error("aesIV or aesKey is empty",
|
||||
zap.String("uid", conn.Uid),
|
||||
zap.String("deviceId", conn.DeviceId),
|
||||
zap.String("channelId", recvPacket.ChannelID),
|
||||
zap.Uint8("channelType", recvPacket.ChannelType),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
err = encryptMessagePayload(sendPacket.Payload, conn, payloadBuffer)
|
||||
if err != nil {
|
||||
d.Error("加密payload失败!", zap.Error(err), zap.String("uid", conn.uid), zap.String("channelId", recvPacket.ChannelID), zap.Uint8("channelType", recvPacket.ChannelType))
|
||||
d.Error("加密payload失败!",
|
||||
zap.Error(err),
|
||||
zap.String("uid", conn.Uid),
|
||||
zap.String("channelId", recvPacket.ChannelID),
|
||||
zap.Uint8("channelType", recvPacket.ChannelType),
|
||||
)
|
||||
continue
|
||||
}
|
||||
recvPacket.Payload = payloadBuffer.Bytes()
|
||||
@@ -548,9 +556,9 @@ func (d *deliverr) deliver(req *deliverReq, uids []string) {
|
||||
|
||||
// 编码接受包
|
||||
recvPacketBuffer.Reset()
|
||||
err = d.dm.s.opts.Proto.WriteFrame(recvPacketBuffer, recvPacket, conn.protoVersion)
|
||||
err = d.dm.s.opts.Proto.WriteFrame(recvPacketBuffer, recvPacket, conn.ProtoVersion)
|
||||
if err != nil {
|
||||
d.Error("encode recvPacket failed", zap.String("uid", conn.uid), zap.String("channelId", recvPacket.ChannelID), zap.Uint8("channelType", recvPacket.ChannelType), zap.Error(err))
|
||||
d.Error("encode recvPacket failed", zap.String("uid", conn.Uid), zap.String("channelId", recvPacket.ChannelID), zap.Uint8("channelType", recvPacket.ChannelType), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -560,9 +568,9 @@ func (d *deliverr) deliver(req *deliverReq, uids []string) {
|
||||
|
||||
if !recvPacket.NoPersist { // 只有存储的消息才重试
|
||||
d.dm.s.retryManager.addRetry(&retryMessage{
|
||||
uid: conn.uid,
|
||||
fromNode: conn.FromNode(),
|
||||
connId: conn.connId,
|
||||
uid: conn.Uid,
|
||||
fromNode: conn.FromNode,
|
||||
connId: conn.ConnId,
|
||||
messageId: message.MessageId,
|
||||
recvPacketData: recvPacketData,
|
||||
channelId: req.channelId,
|
||||
@@ -599,8 +607,8 @@ func (d *deliverr) releaseRecvPacket(recvPacket *wkproto.RecvPacket) {
|
||||
}
|
||||
|
||||
// 加密消息
|
||||
func encryptMessagePayload(payload []byte, conn *connContext, resultBuff *bytebufferpool.ByteBuffer) error {
|
||||
aesKey, aesIV := conn.aesKey, conn.aesIV
|
||||
func encryptMessagePayload(payload []byte, conn *reactor.Conn, resultBuff *bytebufferpool.ByteBuffer) error {
|
||||
aesKey, aesIV := conn.AesKey, conn.AesIV
|
||||
// 加密payload
|
||||
err := wkutil.AesEncryptPkcs7Base64ForPool(payload, aesKey, aesIV, resultBuff)
|
||||
if err != nil {
|
||||
@@ -611,8 +619,8 @@ func encryptMessagePayload(payload []byte, conn *connContext, resultBuff *bytebu
|
||||
}
|
||||
|
||||
// 加密消息
|
||||
func encryptMessagePayload2(payload []byte, conn *connContext) ([]byte, error) {
|
||||
aesKey, aesIV := conn.aesKey, conn.aesIV
|
||||
func encryptMessagePayload2(payload []byte, conn *reactor.Conn) ([]byte, error) {
|
||||
aesKey, aesIV := conn.AesKey, conn.AesIV
|
||||
// 加密payload
|
||||
payloadEnc, err := wkutil.AesEncryptPkcs7Base64(payload, aesKey, aesIV)
|
||||
if err != nil {
|
||||
@@ -621,8 +629,8 @@ func encryptMessagePayload2(payload []byte, conn *connContext) ([]byte, error) {
|
||||
return payloadEnc, nil
|
||||
}
|
||||
|
||||
func writeAesEncrypt(aesResultBuffer *bytebufferpool.ByteBuffer, signBuffer *bytebufferpool.ByteBuffer, conn *connContext) error {
|
||||
aesKey, aesIV := conn.aesKey, conn.aesIV
|
||||
func writeAesEncrypt(aesResultBuffer *bytebufferpool.ByteBuffer, signBuffer *bytebufferpool.ByteBuffer, conn *reactor.Conn) error {
|
||||
aesKey, aesIV := conn.AesKey, conn.AesIV
|
||||
|
||||
// 生成MsgKey
|
||||
err := wkutil.AesEncryptPkcs7Base64ForPool(signBuffer.Bytes(), aesKey, aesIV, aesResultBuffer)
|
||||
|
||||
@@ -159,6 +159,14 @@ type Options struct {
|
||||
WhitelistOffOfPerson bool // 是否关闭个人白名单验证
|
||||
DeliveryMsgPoolSize int // 投递消息协程池大小,此池的协程主要用来将消息投递给在线用户 默认大小为 10240
|
||||
|
||||
// go协程池
|
||||
GoPool struct {
|
||||
// UserProcess 用户逻辑处理协程池
|
||||
UserProcess int
|
||||
// ChannelProcess 频道逻辑处理协程池
|
||||
ChannelProcess int
|
||||
}
|
||||
|
||||
MessageRetry struct {
|
||||
Interval time.Duration // 消息重试间隔,如果消息发送后在此间隔内没有收到ack,将会在此间隔后重新发送
|
||||
MaxCount int // 消息最大重试次数
|
||||
@@ -363,6 +371,13 @@ func NewOptions(op ...Option) *Options {
|
||||
},
|
||||
DeliveryMsgPoolSize: 10240,
|
||||
EventPoolSize: 1024,
|
||||
GoPool: struct {
|
||||
UserProcess int
|
||||
ChannelProcess int
|
||||
}{
|
||||
UserProcess: 4096,
|
||||
ChannelProcess: 4096,
|
||||
},
|
||||
MessageRetry: struct {
|
||||
Interval time.Duration
|
||||
MaxCount int
|
||||
@@ -651,6 +666,9 @@ func (o *Options) ConfigureWithViper(vp *viper.Viper) {
|
||||
|
||||
o.WhitelistOffOfPerson = o.getBool("whitelistOffOfPerson", o.WhitelistOffOfPerson)
|
||||
|
||||
o.GoPool.UserProcess = o.getInt("goPool.userProcess", o.GoPool.UserProcess)
|
||||
o.GoPool.ChannelProcess = o.getInt("goPool.channelProcess", o.GoPool.ChannelProcess)
|
||||
|
||||
o.MessageRetry.Interval = o.getDuration("messageRetry.interval", o.MessageRetry.Interval)
|
||||
o.MessageRetry.ScanInterval = o.getDuration("messageRetry.scanInterval", o.MessageRetry.ScanInterval)
|
||||
o.MessageRetry.MaxCount = o.getInt("messageRetry.maxCount", o.MessageRetry.MaxCount)
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/WuKongIM/WuKongIM/internal/reactor"
|
||||
"github.com/WuKongIM/WuKongIM/pkg/wklog"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type processUser struct {
|
||||
s *Server
|
||||
wklog.Log
|
||||
}
|
||||
|
||||
func newProcessUser(s *Server) *processUser {
|
||||
|
||||
return &processUser{
|
||||
s: s,
|
||||
Log: wklog.NewWKLog("processUser"),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *processUser) send(actions []reactor.UserAction) {
|
||||
|
||||
for _, a := range actions {
|
||||
switch a.Type {
|
||||
case reactor.UserActionElection:
|
||||
p.processElection(a)
|
||||
case reactor.UserActionAuth:
|
||||
p.processAuth(a)
|
||||
default:
|
||||
fmt.Println("a-->", a.Type.String())
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理选举
|
||||
func (p *processUser) processElection(a reactor.UserAction) {
|
||||
slotId := p.s.cluster.GetSlotId(a.Uid)
|
||||
leaderInfo, err := p.s.cluster.SlotLeaderNodeInfo(slotId)
|
||||
if err != nil {
|
||||
p.Error("get slot leader info failed", zap.Error(err), zap.Uint32("slotId", slotId))
|
||||
return
|
||||
}
|
||||
if leaderInfo == nil {
|
||||
p.Error("slot not exist", zap.Uint32("slotId", slotId))
|
||||
return
|
||||
}
|
||||
if leaderInfo.Id == 0 {
|
||||
p.Error("slot leader id is 0", zap.Uint32("slotId", slotId))
|
||||
return
|
||||
}
|
||||
|
||||
reactor.User.UpdateConfig(a.Uid, reactor.UserConfig{
|
||||
LeaderId: leaderInfo.Id,
|
||||
})
|
||||
}
|
||||
|
||||
func (p *processUser) processAuth(a reactor.UserAction) {
|
||||
fmt.Println("processAuth....")
|
||||
}
|
||||
170
internal/server/process_user_connect.go
Normal file
170
internal/server/process_user_connect.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"time"
|
||||
|
||||
"github.com/WuKongIM/WuKongIM/internal/reactor"
|
||||
"github.com/WuKongIM/WuKongIM/pkg/wkdb"
|
||||
"github.com/WuKongIM/WuKongIM/pkg/wkutil"
|
||||
wkproto "github.com/WuKongIM/WuKongIMGoProto"
|
||||
"github.com/pkg/errors"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func (p *processUser) handleConnect(msg *reactor.UserMessage) (wkproto.ReasonCode, *wkproto.ConnackPacket, error) {
|
||||
var (
|
||||
conn = msg.Conn
|
||||
connectPacket = msg.Frame.(*wkproto.ConnectPacket)
|
||||
devceLevel wkproto.DeviceLevel
|
||||
uid = connectPacket.UID
|
||||
)
|
||||
// -------------------- token verify --------------------
|
||||
if connectPacket.UID == p.s.opts.ManagerUID {
|
||||
if p.s.opts.ManagerTokenOn && connectPacket.Token != p.s.opts.ManagerToken {
|
||||
p.Error("manager token verify fail", zap.String("uid", uid), zap.String("token", connectPacket.Token))
|
||||
return wkproto.ReasonAuthFail, nil, nil
|
||||
}
|
||||
devceLevel = wkproto.DeviceLevelSlave // 默认都是slave设备
|
||||
} else if p.s.opts.TokenAuthOn {
|
||||
if connectPacket.Token == "" {
|
||||
p.Error("token is empty")
|
||||
return wkproto.ReasonAuthFail, nil, errors.New("token is empty")
|
||||
}
|
||||
device, err := p.s.store.GetDevice(uid, connectPacket.DeviceFlag)
|
||||
if err != nil {
|
||||
p.Error("get device token err", zap.Error(err))
|
||||
return wkproto.ReasonAuthFail, nil, err
|
||||
}
|
||||
if device.Token != connectPacket.Token {
|
||||
p.Error("token verify fail", zap.String("expectToken", device.Token), zap.String("actToken", connectPacket.Token))
|
||||
return wkproto.ReasonAuthFail, nil, errors.New("token verify fail")
|
||||
}
|
||||
devceLevel = wkproto.DeviceLevel(device.DeviceLevel)
|
||||
} else {
|
||||
devceLevel = wkproto.DeviceLevelSlave // 默认都是slave设备
|
||||
}
|
||||
|
||||
// -------------------- ban --------------------
|
||||
userChannelInfo, err := p.s.store.GetChannel(uid, wkproto.ChannelTypePerson)
|
||||
if err != nil {
|
||||
p.Error("get device channel info err", zap.Error(err))
|
||||
return wkproto.ReasonAuthFail, nil, err
|
||||
}
|
||||
ban := false
|
||||
if !wkdb.IsEmptyChannelInfo(userChannelInfo) {
|
||||
ban = userChannelInfo.Ban
|
||||
}
|
||||
if ban {
|
||||
p.Error("device is ban", zap.String("uid", uid))
|
||||
return wkproto.ReasonBan, nil, errors.New("device is ban")
|
||||
}
|
||||
|
||||
// -------------------- get message encrypt key --------------------
|
||||
dhServerPrivKey, dhServerPublicKey := wkutil.GetCurve25519KeypPair() // 生成服务器的DH密钥对
|
||||
aesKey, aesIV, err := p.getClientAesKeyAndIV(connectPacket.ClientKey, dhServerPrivKey)
|
||||
if err != nil {
|
||||
p.Error("get client aes key and iv err", zap.Error(err))
|
||||
return wkproto.ReasonAuthFail, nil, err
|
||||
}
|
||||
dhServerPublicKeyEnc := base64.StdEncoding.EncodeToString(dhServerPublicKey[:])
|
||||
|
||||
// -------------------- same master kicks each other --------------------
|
||||
oldConns := reactor.User.ConnsByDeviceFlag(uid, connectPacket.DeviceFlag)
|
||||
if len(oldConns) > 0 {
|
||||
if devceLevel == wkproto.DeviceLevelMaster { // 如果设备是master级别,则把旧连接都踢掉
|
||||
for _, oldConn := range oldConns {
|
||||
if oldConn.Equal(conn) { // 不能把自己踢了
|
||||
continue
|
||||
}
|
||||
if oldConn.DeviceId != connectPacket.DeviceID {
|
||||
p.Info("auth: same master kicks each other",
|
||||
zap.String("devceLevel", devceLevel.String()),
|
||||
zap.String("uid", uid),
|
||||
zap.String("deviceID", connectPacket.DeviceID),
|
||||
zap.String("oldDeviceId", oldConn.DeviceId),
|
||||
)
|
||||
reactor.User.Kick(oldConn, wkproto.ReasonConnectKick, "login in other device")
|
||||
} else {
|
||||
reactor.User.CloseConn(oldConn)
|
||||
}
|
||||
p.Info("auth: close old conn for master", zap.Any("oldConn", oldConn))
|
||||
}
|
||||
} else if devceLevel == wkproto.DeviceLevelSlave { // 如果设备是slave级别,则把相同的deviceId关闭
|
||||
for _, oldConn := range oldConns {
|
||||
if oldConn.ConnId != conn.ConnId && oldConn.DeviceId == connectPacket.DeviceID {
|
||||
reactor.User.CloseConn(oldConn)
|
||||
p.Info("auth: close old conn for slave", zap.Any("oldConn", oldConn))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------- set conn info --------------------
|
||||
timeDiff := time.Now().UnixNano()/1000/1000 - connectPacket.ClientTimestamp
|
||||
|
||||
// connCtx := p.connContextPool.Get().(*connContext)
|
||||
|
||||
lastVersion := connectPacket.Version
|
||||
hasServerVersion := false
|
||||
if connectPacket.Version > wkproto.LatestVersion {
|
||||
lastVersion = wkproto.LatestVersion
|
||||
}
|
||||
|
||||
conn.AesIV = aesIV
|
||||
conn.AesKey = aesKey
|
||||
conn.Auth = true
|
||||
conn.ProtoVersion = lastVersion
|
||||
conn.DeviceLevel = devceLevel
|
||||
|
||||
realConn := p.s.connManager.getConn(conn.ConnId)
|
||||
if realConn == nil {
|
||||
p.Error("auth: conn not exist", zap.Int64("connId", conn.ConnId))
|
||||
return wkproto.ReasonAuthFail, nil, errors.New("conn not exist")
|
||||
}
|
||||
realConn.SetMaxIdle(p.s.opts.ConnIdleTime)
|
||||
|
||||
// -------------------- response connack --------------------
|
||||
|
||||
if connectPacket.Version > 3 {
|
||||
hasServerVersion = true
|
||||
}
|
||||
|
||||
p.Debug("auth: auth Success", zap.Uint8("protoVersion", connectPacket.Version), zap.Bool("hasServerVersion", hasServerVersion))
|
||||
connack := &wkproto.ConnackPacket{
|
||||
Salt: string(aesIV),
|
||||
ServerKey: dhServerPublicKeyEnc,
|
||||
ReasonCode: wkproto.ReasonSuccess,
|
||||
TimeDiff: timeDiff,
|
||||
ServerVersion: lastVersion,
|
||||
NodeId: p.s.opts.Cluster.NodeId,
|
||||
}
|
||||
connack.HasServerVersion = hasServerVersion
|
||||
// -------------------- user online --------------------
|
||||
// 在线webhook
|
||||
deviceOnlineCount := reactor.User.ConnCountByDeviceFlag(uid, connectPacket.DeviceFlag)
|
||||
totalOnlineCount := reactor.User.ConnCountByUid(uid)
|
||||
p.s.webhook.Online(uid, connectPacket.DeviceFlag, conn.ConnId, deviceOnlineCount, totalOnlineCount)
|
||||
|
||||
return wkproto.ReasonSuccess, connack, nil
|
||||
}
|
||||
|
||||
// 获取客户端的aesKey和aesIV
|
||||
// dhServerPrivKey 服务端私钥
|
||||
func (p *processUser) getClientAesKeyAndIV(clientKey string, dhServerPrivKey [32]byte) ([]byte, []byte, error) {
|
||||
|
||||
clientKeyBytes, err := base64.StdEncoding.DecodeString(clientKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var dhClientPubKeyArray [32]byte
|
||||
copy(dhClientPubKeyArray[:], clientKeyBytes[:32])
|
||||
|
||||
// 获得DH的共享key
|
||||
shareKey := wkutil.GetCurve25519Key(dhServerPrivKey, dhClientPubKeyArray) // 共享key
|
||||
|
||||
aesIV := wkutil.GetRandomString(16)
|
||||
aesKey := wkutil.MD5(base64.StdEncoding.EncodeToString(shareKey[:]))[:16]
|
||||
return []byte(aesKey), []byte(aesIV), nil
|
||||
}
|
||||
56
internal/server/process_user_in.go
Normal file
56
internal/server/process_user_in.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
reactor "github.com/WuKongIM/WuKongIM/internal/reactor"
|
||||
"github.com/WuKongIM/WuKongIM/pkg/wkserver/proto"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// 收到消息
|
||||
func (p *processUser) onMessage(m *proto.Message) {
|
||||
err := p.s.userProcessPool.Submit(func() {
|
||||
p.handleMessage(m)
|
||||
})
|
||||
if err != nil {
|
||||
p.Error("onMessage: submit error", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func (p *processUser) handleMessage(m *proto.Message) {
|
||||
fmt.Println("onMessage------>", m.MsgType)
|
||||
switch msgType(m.MsgType) {
|
||||
// 节点加入
|
||||
case msgUserJoinReq:
|
||||
p.handleJoin(m)
|
||||
// 收到发件箱
|
||||
case msgOutboundReq:
|
||||
p.handleOutboundReq(m)
|
||||
}
|
||||
}
|
||||
|
||||
// 收到加入请求
|
||||
func (p *processUser) handleJoin(m *proto.Message) {
|
||||
req := &userJoinReq{}
|
||||
err := req.decode(m.Content)
|
||||
if err != nil {
|
||||
p.Error("decode joinReq failed", zap.Error(err))
|
||||
return
|
||||
}
|
||||
reactor.User.Join(req.uid, req.from)
|
||||
}
|
||||
|
||||
func (p *processUser) handleOutboundReq(m *proto.Message) {
|
||||
req := &outboundReq{}
|
||||
err := req.decode(m.Content)
|
||||
if err != nil {
|
||||
p.Error("decode outbound failed", zap.Error(err))
|
||||
return
|
||||
}
|
||||
if req.fromNode == p.s.opts.Cluster.NodeId {
|
||||
p.Warn("outbound request from self", zap.Uint64("fromNode", req.fromNode))
|
||||
return
|
||||
}
|
||||
reactor.User.AddMessages(req.uid, req.messages)
|
||||
}
|
||||
248
internal/server/process_user_out.go
Normal file
248
internal/server/process_user_out.go
Normal file
@@ -0,0 +1,248 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/WuKongIM/WuKongIM/internal/reactor"
|
||||
"github.com/WuKongIM/WuKongIM/pkg/wklog"
|
||||
"github.com/WuKongIM/WuKongIM/pkg/wknet"
|
||||
"github.com/WuKongIM/WuKongIM/pkg/wkserver/proto"
|
||||
wkproto "github.com/WuKongIM/WuKongIMGoProto"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type processUser struct {
|
||||
s *Server
|
||||
wklog.Log
|
||||
}
|
||||
|
||||
func newProcessUser(s *Server) *processUser {
|
||||
|
||||
return &processUser{
|
||||
s: s,
|
||||
Log: wklog.NewWKLog("processUser"),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *processUser) send(actions []reactor.UserAction) {
|
||||
|
||||
var err error
|
||||
for _, a := range actions {
|
||||
err = p.s.userProcessPool.Submit(func() {
|
||||
p.processAction(a)
|
||||
})
|
||||
if err != nil {
|
||||
p.Error("submit err", zap.Error(err), zap.String("uid", a.Uid), zap.String("actionType", a.Type.String()))
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *processUser) processAction(a reactor.UserAction) {
|
||||
fmt.Println("send-->", a.Type.String())
|
||||
switch a.Type {
|
||||
case reactor.UserActionElection: // 选举
|
||||
p.processElection(a)
|
||||
case reactor.UserActionJoin: // 加入
|
||||
p.processJoin(a)
|
||||
case reactor.UserActionOutboundForward: // 发件
|
||||
p.processOutbound(a)
|
||||
case reactor.UserActionInbound: // 收件
|
||||
p.processInbound(a)
|
||||
case reactor.UserActionWrite: // 连接写
|
||||
p.processWrite(a)
|
||||
case reactor.UserActionConnClose: // 连接关闭
|
||||
p.processConnClose(a)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// 处理选举
|
||||
func (p *processUser) processElection(a reactor.UserAction) {
|
||||
slotId := p.s.cluster.GetSlotId(a.Uid)
|
||||
leaderInfo, err := p.s.cluster.SlotLeaderNodeInfo(slotId)
|
||||
if err != nil {
|
||||
p.Error("get slot leader info failed", zap.Error(err), zap.Uint32("slotId", slotId))
|
||||
return
|
||||
}
|
||||
if leaderInfo == nil {
|
||||
p.Error("slot not exist", zap.Uint32("slotId", slotId))
|
||||
return
|
||||
}
|
||||
if leaderInfo.Id == 0 {
|
||||
p.Error("slot leader id is 0", zap.Uint32("slotId", slotId))
|
||||
return
|
||||
}
|
||||
|
||||
reactor.User.UpdateConfig(a.Uid, reactor.UserConfig{
|
||||
LeaderId: leaderInfo.Id,
|
||||
})
|
||||
}
|
||||
|
||||
func (p *processUser) processJoin(a reactor.UserAction) {
|
||||
req := &userJoinReq{
|
||||
from: p.s.opts.Cluster.NodeId,
|
||||
uid: a.Uid,
|
||||
}
|
||||
|
||||
err := p.s.cluster.Send(a.To, &proto.Message{
|
||||
MsgType: uint32(msgUserJoinReq),
|
||||
Content: req.encode(),
|
||||
})
|
||||
if err != nil {
|
||||
p.Error("send join req failed", zap.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (p *processUser) processConnClose(a reactor.UserAction) {
|
||||
if len(a.Conns) == 0 {
|
||||
p.Warn("processConnClose: conns is empty", zap.String("uid", a.Uid))
|
||||
return
|
||||
}
|
||||
for _, c := range a.Conns {
|
||||
if !p.s.opts.IsLocalNode(c.FromNode) {
|
||||
p.Info("processConnClose: conn not local node", zap.String("uid", a.Uid), zap.Uint64("fromNode", c.FromNode))
|
||||
continue
|
||||
}
|
||||
conn := p.s.connManager.getConn(c.ConnId)
|
||||
if conn == nil {
|
||||
p.Warn("processConnClose: conn not exist", zap.String("uid", a.Uid), zap.Int64("connId", c.ConnId))
|
||||
continue
|
||||
}
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
p.Debug("Failed to close the conn", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *processUser) processOutbound(a reactor.UserAction) {
|
||||
fmt.Println("processOutbound....")
|
||||
if len(a.Messages) == 0 {
|
||||
p.Warn("processOutbound: messages is empty")
|
||||
return
|
||||
}
|
||||
req := &outboundReq{
|
||||
fromNode: p.s.opts.Cluster.NodeId,
|
||||
uid: a.Uid,
|
||||
messages: a.Messages,
|
||||
}
|
||||
data, err := req.encode()
|
||||
if err != nil {
|
||||
p.Error("encode failed", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
err = p.s.cluster.Send(a.To, &proto.Message{
|
||||
MsgType: uint32(msgOutboundReq),
|
||||
Content: data,
|
||||
})
|
||||
if err != nil {
|
||||
p.Error("processOutbound: send failed", zap.Error(err))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (p *processUser) processInbound(a reactor.UserAction) {
|
||||
if len(a.Messages) == 0 {
|
||||
return
|
||||
}
|
||||
// 从收件箱中取出消息
|
||||
for _, m := range a.Messages {
|
||||
if m.Frame == nil {
|
||||
continue
|
||||
}
|
||||
fmt.Println("processInbound-->", m.Frame.GetFrameType().String())
|
||||
switch m.Frame.GetFrameType() {
|
||||
case wkproto.CONNECT: // 连接包
|
||||
if a.Role == reactor.RoleLeader {
|
||||
p.processConnect(a.Uid, m)
|
||||
} else {
|
||||
// 如果不是领导节点,则专投递给发件箱这样就会被领导节点处理
|
||||
reactor.User.AddMessageToOutbound(a.Uid, m)
|
||||
}
|
||||
case wkproto.CONNACK: // 连接回执包
|
||||
p.processConnack(a.Uid, m)
|
||||
case wkproto.PING: // 心跳包
|
||||
p.processPing(m)
|
||||
case wkproto.SEND: // 发送消息
|
||||
p.processSend(m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *processUser) processWrite(a reactor.UserAction) {
|
||||
|
||||
if len(a.Messages) == 0 {
|
||||
return
|
||||
}
|
||||
for _, m := range a.Messages {
|
||||
if m.Conn == nil {
|
||||
continue
|
||||
}
|
||||
if !p.s.opts.IsLocalNode(m.Conn.FromNode) {
|
||||
reactor.User.AddMessageToOutbound(a.Uid, m)
|
||||
continue
|
||||
}
|
||||
conn := p.s.connManager.getConn(m.Conn.ConnId)
|
||||
if conn == nil {
|
||||
p.Warn("conn not exist", zap.String("uid", a.Uid), zap.Int64("connId", m.Conn.ConnId))
|
||||
continue
|
||||
}
|
||||
wsConn, wsok := conn.(wknet.IWSConn) // websocket连接
|
||||
if wsok {
|
||||
err := wsConn.WriteServerBinary(m.WriteData)
|
||||
if err != nil {
|
||||
p.Warn("Failed to ws write the message", zap.Error(err))
|
||||
}
|
||||
} else {
|
||||
_, err := conn.WriteToOutboundBuffer(m.WriteData)
|
||||
if err != nil {
|
||||
p.Warn("Failed to write the message", zap.Error(err))
|
||||
}
|
||||
}
|
||||
_ = conn.WakeWrite()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (p *processUser) processConnect(uid string, msg *reactor.UserMessage) {
|
||||
reasonCode, packet, err := p.handleConnect(msg)
|
||||
if err != nil {
|
||||
p.Error("handle connect failed", zap.Error(err), zap.String("uid", uid))
|
||||
return
|
||||
}
|
||||
if reasonCode != wkproto.ReasonSuccess && packet == nil {
|
||||
packet = &wkproto.ConnackPacket{
|
||||
ReasonCode: reasonCode,
|
||||
}
|
||||
}
|
||||
|
||||
reactor.User.AddMessage(uid, &reactor.UserMessage{
|
||||
Conn: msg.Conn,
|
||||
Frame: packet,
|
||||
ToNode: msg.Conn.FromNode,
|
||||
})
|
||||
}
|
||||
|
||||
func (p *processUser) processConnack(uid string, msg *reactor.UserMessage) {
|
||||
conn := msg.Conn
|
||||
if conn.FromNode == 0 {
|
||||
p.Error("from node is 0", zap.String("uid", uid))
|
||||
return
|
||||
}
|
||||
if p.s.opts.IsLocalNode(conn.FromNode) {
|
||||
reactor.User.ConnWrite(conn, msg.Frame)
|
||||
} else {
|
||||
reactor.User.AddMessageToOutbound(uid, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *processUser) processPing(msg *reactor.UserMessage) {
|
||||
p.handlePing(msg)
|
||||
}
|
||||
|
||||
func (p *processUser) processSend(msg *reactor.UserMessage) {
|
||||
p.handleSend(msg)
|
||||
}
|
||||
13
internal/server/process_user_ping.go
Normal file
13
internal/server/process_user_ping.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/WuKongIM/WuKongIM/internal/reactor"
|
||||
wkproto "github.com/WuKongIM/WuKongIMGoProto"
|
||||
)
|
||||
|
||||
func (p *processUser) handlePing(msg *reactor.UserMessage) {
|
||||
fmt.Println("handlePing---->")
|
||||
reactor.User.ConnWrite(msg.Conn, &wkproto.PongPacket{})
|
||||
}
|
||||
77
internal/server/process_user_proto.go
Normal file
77
internal/server/process_user_proto.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/WuKongIM/WuKongIM/internal/reactor"
|
||||
wkproto "github.com/WuKongIM/WuKongIMGoProto"
|
||||
)
|
||||
|
||||
type msgType uint32
|
||||
|
||||
const (
|
||||
// 用户节点加入,用户副本节点加入领导节点时发起
|
||||
msgUserJoinReq msgType = 2000
|
||||
// 消息发件箱请求
|
||||
msgOutboundReq msgType = 2001
|
||||
)
|
||||
|
||||
type userJoinReq struct {
|
||||
from uint64
|
||||
uid string
|
||||
}
|
||||
|
||||
func (u *userJoinReq) encode() []byte {
|
||||
enc := wkproto.NewEncoder()
|
||||
defer enc.End()
|
||||
enc.WriteUint64(u.from)
|
||||
enc.WriteString(u.uid)
|
||||
|
||||
return enc.Bytes()
|
||||
}
|
||||
|
||||
func (u *userJoinReq) decode(data []byte) error {
|
||||
dec := wkproto.NewDecoder(data)
|
||||
var err error
|
||||
if u.from, err = dec.Uint64(); err != nil {
|
||||
return err
|
||||
}
|
||||
if u.uid, err = dec.String(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type outboundReq struct {
|
||||
fromNode uint64
|
||||
uid string
|
||||
messages reactor.UserMessageBatch
|
||||
}
|
||||
|
||||
func (o *outboundReq) encode() ([]byte, error) {
|
||||
enc := wkproto.NewEncoder()
|
||||
enc.WriteString(o.uid)
|
||||
enc.WriteUint64(o.fromNode)
|
||||
if len(o.messages) > 0 {
|
||||
msgData, err := o.messages.Encode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
enc.WriteBytes(msgData)
|
||||
}
|
||||
return enc.Bytes(), nil
|
||||
}
|
||||
|
||||
func (o *outboundReq) decode(data []byte) error {
|
||||
dec := wkproto.NewDecoder(data)
|
||||
var err error
|
||||
if o.uid, err = dec.String(); err != nil {
|
||||
return err
|
||||
}
|
||||
if o.fromNode, err = dec.Uint64(); err != nil {
|
||||
return err
|
||||
}
|
||||
msgData, err := dec.BinaryAll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return o.messages.Decode(msgData)
|
||||
}
|
||||
7
internal/server/process_user_send.go
Normal file
7
internal/server/process_user_send.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package server
|
||||
|
||||
import "github.com/WuKongIM/WuKongIM/internal/reactor"
|
||||
|
||||
func (p *processUser) handleSend(msg *reactor.UserMessage) {
|
||||
|
||||
}
|
||||
@@ -19,11 +19,11 @@ func (s *Server) onData(conn wknet.Conn) error {
|
||||
}
|
||||
|
||||
var isAuth bool
|
||||
var connCtx *connContext
|
||||
var connCtx *reactor.Conn
|
||||
connCtxObj := conn.Context()
|
||||
if connCtxObj != nil {
|
||||
connCtx = connCtxObj.(*connContext)
|
||||
isAuth = connCtx.isAuth.Load()
|
||||
connCtx = connCtxObj.(*reactor.Conn)
|
||||
isAuth = connCtx.Auth
|
||||
} else {
|
||||
isAuth = false
|
||||
}
|
||||
@@ -83,14 +83,14 @@ func (s *Server) onData(conn wknet.Conn) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
connInfo := connInfo{
|
||||
connId: conn.ID(),
|
||||
uid: connectPacket.UID,
|
||||
deviceId: connectPacket.DeviceID,
|
||||
deviceFlag: wkproto.DeviceFlag(connectPacket.DeviceFlag),
|
||||
protoVersion: connectPacket.Version,
|
||||
connCtx = &reactor.Conn{
|
||||
FromNode: s.opts.Cluster.NodeId,
|
||||
ConnId: conn.ID(),
|
||||
Uid: connectPacket.UID,
|
||||
DeviceId: connectPacket.DeviceID,
|
||||
DeviceFlag: wkproto.DeviceFlag(connectPacket.DeviceFlag),
|
||||
ProtoVersion: connectPacket.Version,
|
||||
}
|
||||
connCtx = newConnContext(connInfo, conn)
|
||||
conn.SetContext(connCtx)
|
||||
|
||||
// 如果用户不存在则唤醒用户
|
||||
@@ -101,9 +101,9 @@ func (s *Server) onData(conn wknet.Conn) error {
|
||||
_, _ = conn.Discard(len(data))
|
||||
} else {
|
||||
offset := 0
|
||||
var messages []reactor.UserMessage
|
||||
var messages []*reactor.UserMessage
|
||||
for len(data) > offset {
|
||||
frame, size, err := s.opts.Proto.DecodeFrame(data[offset:], connCtx.protoVersion)
|
||||
frame, size, err := s.opts.Proto.DecodeFrame(data[offset:], connCtx.ProtoVersion)
|
||||
if err != nil { //
|
||||
s.Warn("Failed to decode the message", zap.Error(err))
|
||||
conn.Close()
|
||||
@@ -113,16 +113,17 @@ func (s *Server) onData(conn wknet.Conn) error {
|
||||
break
|
||||
}
|
||||
if messages == nil {
|
||||
messages = make([]reactor.UserMessage, 0, 10)
|
||||
messages = make([]*reactor.UserMessage, 0, 10)
|
||||
}
|
||||
messages = append(messages, &reactorUserMessage{
|
||||
frame: frame,
|
||||
messages = append(messages, &reactor.UserMessage{
|
||||
Frame: frame,
|
||||
Conn: connCtx,
|
||||
})
|
||||
offset += size
|
||||
}
|
||||
if len(messages) > 0 {
|
||||
// 添加消息
|
||||
reactor.User.AddMessages(connCtx.uid, messages)
|
||||
reactor.User.AddMessages(connCtx.Uid, messages)
|
||||
}
|
||||
|
||||
_, _ = conn.Discard(offset)
|
||||
|
||||
@@ -1,32 +1,28 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/WuKongIM/WuKongIM/internal/reactor"
|
||||
wkproto "github.com/WuKongIM/WuKongIMGoProto"
|
||||
)
|
||||
// type reactorUserMessage struct {
|
||||
// conn reactor.Conn
|
||||
// frame wkproto.Frame
|
||||
// index uint64
|
||||
// toNodeId uint64
|
||||
// }
|
||||
|
||||
type reactorUserMessage struct {
|
||||
conn *connContext
|
||||
frame wkproto.Frame
|
||||
index uint64
|
||||
}
|
||||
// func (m *reactorUserMessage) Conn() reactor.Conn {
|
||||
// return m.conn
|
||||
// }
|
||||
|
||||
func (m *reactorUserMessage) Conn() reactor.Conn {
|
||||
return m.conn
|
||||
}
|
||||
// func (m *reactorUserMessage) Frame() wkproto.Frame {
|
||||
// return m.frame
|
||||
// }
|
||||
|
||||
func (m *reactorUserMessage) Frame() wkproto.Frame {
|
||||
return m.frame
|
||||
}
|
||||
// func (m *reactorUserMessage) Size() uint64 {
|
||||
// return 0
|
||||
// }
|
||||
|
||||
func (m *reactorUserMessage) Size() uint64 {
|
||||
return 0
|
||||
}
|
||||
// func (m *reactorUserMessage) SetIndex(index uint64) {
|
||||
// m.index = index
|
||||
// }
|
||||
|
||||
func (m *reactorUserMessage) SetIndex(index uint64) {
|
||||
m.index = index
|
||||
}
|
||||
|
||||
func (m *reactorUserMessage) Index() uint64 {
|
||||
return m.index
|
||||
}
|
||||
// func (m *reactorUserMessage) Index() uint64 {
|
||||
// return m.index
|
||||
// }
|
||||
|
||||
@@ -29,6 +29,7 @@ import (
|
||||
wkproto "github.com/WuKongIM/WuKongIMGoProto"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/judwhite/go-svc"
|
||||
"github.com/panjf2000/ants/v2"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/valyala/bytebufferpool"
|
||||
"go.etcd.io/etcd/pkg/v3/idutil"
|
||||
@@ -75,8 +76,11 @@ type Server struct {
|
||||
|
||||
promtailServer *promtail.Promtail // 日志收集, 负责收集WuKongIM的日志 上报给Loki
|
||||
|
||||
userReactor reactor.IUser
|
||||
processUser *processUser
|
||||
connManager *connManager // 连接管理
|
||||
userReactor reactor.IUser // 用户控制中心
|
||||
processUser *processUser // 用户逻辑处理
|
||||
userProcessPool *ants.Pool // 用户逻辑处理协程池
|
||||
|
||||
}
|
||||
|
||||
func New(opts *Options) *Server {
|
||||
@@ -88,9 +92,16 @@ func New(opts *Options) *Server {
|
||||
reqIDGen: idutil.NewGenerator(uint16(opts.Cluster.NodeId), time.Now()),
|
||||
start: now,
|
||||
}
|
||||
|
||||
s.connManager = newConnManager(18)
|
||||
var err error
|
||||
s.userProcessPool, err = ants.NewPool(opts.GoPool.UserProcess, ants.WithPanicHandler(func(i interface{}) {
|
||||
s.Panic("user process pool is panic", zap.Any("err", err), zap.Stack("stack"))
|
||||
}))
|
||||
if err != nil {
|
||||
s.Panic("new user process pool failed", zap.Error(err))
|
||||
}
|
||||
// 配置检查
|
||||
err := opts.Check()
|
||||
err = opts.Check()
|
||||
if err != nil {
|
||||
s.Panic("config check error", zap.Error(err))
|
||||
}
|
||||
@@ -144,6 +155,7 @@ func New(opts *Options) *Server {
|
||||
userreactor.WithNodeId(opts.Cluster.NodeId),
|
||||
userreactor.WithSend(s.processUser.send),
|
||||
)
|
||||
reactor.Proto = s.opts.Proto
|
||||
// 注册user reactor
|
||||
reactor.RegisterUser(s.userReactor)
|
||||
s.demoServer = NewDemoServer(s) // demo server
|
||||
@@ -438,6 +450,8 @@ func (s *Server) onConnect(conn wknet.Conn) error {
|
||||
conn.SetMaxIdle(time.Second * 2) // 在认证之前,连接最多空闲2秒
|
||||
s.trace.Metrics.App().ConnCountAdd(1)
|
||||
|
||||
s.connManager.addConn(conn)
|
||||
|
||||
// if conn.InboundBuffer().BoundBufferSize() == 0 {
|
||||
// conn.SetValue(ConnKeyParseProxyProto, true) // 设置需要解析代理协议
|
||||
// return nil
|
||||
@@ -478,15 +492,10 @@ func (s *Server) onClose(conn wknet.Conn) {
|
||||
s.trace.Metrics.App().ConnCountAdd(-1)
|
||||
connCtxObj := conn.Context()
|
||||
if connCtxObj != nil {
|
||||
connCtx := connCtxObj.(*connContext)
|
||||
connCtx := connCtxObj.(*reactor.Conn)
|
||||
reactor.User.CloseConn(connCtx)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// 代理节点关闭
|
||||
func (s *Server) onCloseForProxy(conn *connContext) {
|
||||
|
||||
s.connManager.removeConn(conn)
|
||||
}
|
||||
|
||||
// Schedule 延迟任务
|
||||
@@ -497,9 +506,9 @@ func (s *Server) Schedule(interval time.Duration, f func()) *timingwheel.Timer {
|
||||
}
|
||||
|
||||
// decode payload
|
||||
func (s *Server) checkAndDecodePayload(sendPacket *wkproto.SendPacket, conn *connContext) ([]byte, error) {
|
||||
func (s *Server) checkAndDecodePayload(sendPacket *wkproto.SendPacket, conn *reactor.Conn) ([]byte, error) {
|
||||
|
||||
aesKey, aesIV := conn.aesKey, conn.aesIV
|
||||
aesKey, aesIV := conn.AesKey, conn.AesIV
|
||||
vail, err := s.sendPacketIsVail(sendPacket, conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -518,8 +527,8 @@ func (s *Server) checkAndDecodePayload(sendPacket *wkproto.SendPacket, conn *con
|
||||
}
|
||||
|
||||
// send packet is vail
|
||||
func (s *Server) sendPacketIsVail(sendPacket *wkproto.SendPacket, conn *connContext) (bool, error) {
|
||||
aesKey, aesIV := conn.aesKey, conn.aesIV
|
||||
func (s *Server) sendPacketIsVail(sendPacket *wkproto.SendPacket, conn *reactor.Conn) (bool, error) {
|
||||
aesKey, aesIV := conn.AesKey, conn.AesIV
|
||||
signStr := sendPacket.VerityString()
|
||||
|
||||
signBuff := bytebufferpool.Get()
|
||||
|
||||
@@ -21,6 +21,10 @@ func (s *Server) handleClusterMessage(fromNodeId uint64, msg *proto.Message) {
|
||||
go s.handleNodePing(fromNodeId, msg)
|
||||
case ClusterMsgTypeNodePong: // 节点Pong
|
||||
go s.handleNodePong(fromNodeId, msg)
|
||||
default:
|
||||
if msg.MsgType >= 2000 && msg.MsgType < 3000 {
|
||||
s.processUser.onMessage(msg)
|
||||
}
|
||||
|
||||
}
|
||||
// switch ClusterMsgType(msg.MsgType) {
|
||||
|
||||
73
pkg/bytequeue/queue.go
Normal file
73
pkg/bytequeue/queue.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package bytequeue
|
||||
|
||||
import "github.com/valyala/bytebufferpool"
|
||||
|
||||
type ByteQueue struct {
|
||||
buffer *bytebufferpool.ByteBuffer
|
||||
offsetSize uint64 // 偏移大小
|
||||
totalSize uint64 // 总大小
|
||||
|
||||
}
|
||||
|
||||
func New() *ByteQueue {
|
||||
return &ByteQueue{
|
||||
buffer: bytebufferpool.Get(),
|
||||
}
|
||||
}
|
||||
|
||||
// Write 写入字节
|
||||
func (b *ByteQueue) Write(p []byte) (int, error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
n, err := b.buffer.Write(p)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
b.totalSize += uint64(n)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// 从指定位置开始读取n个字节
|
||||
func (b *ByteQueue) Peek(startPosition uint64, n int) []byte {
|
||||
if b.totalSize == 0 {
|
||||
return nil
|
||||
}
|
||||
if startPosition >= b.totalSize {
|
||||
return nil
|
||||
}
|
||||
startIndex := int(startPosition - b.offsetSize)
|
||||
if startIndex < 0 {
|
||||
return nil
|
||||
}
|
||||
if startIndex+n > len(b.buffer.B) {
|
||||
n = len(b.buffer.B) - int(startIndex)
|
||||
}
|
||||
return b.buffer.B[startIndex : startIndex+n]
|
||||
}
|
||||
|
||||
func (b *ByteQueue) Discard(endPosition uint64) {
|
||||
n := int(endPosition - b.offsetSize)
|
||||
_ = b.discard(n)
|
||||
}
|
||||
|
||||
func (b *ByteQueue) discard(n int) int {
|
||||
if n == 0 {
|
||||
return 0
|
||||
}
|
||||
if n >= len(b.buffer.B) {
|
||||
b.offsetSize += uint64(len(b.buffer.B))
|
||||
b.buffer.B = b.buffer.B[:0]
|
||||
return len(b.buffer.B)
|
||||
}
|
||||
b.buffer.B = b.buffer.B[n:]
|
||||
b.offsetSize += uint64(n)
|
||||
return n
|
||||
}
|
||||
|
||||
func (b *ByteQueue) Reset() {
|
||||
b.buffer.Reset()
|
||||
b.offsetSize = 0
|
||||
b.totalSize = 0
|
||||
bytebufferpool.Put(b.buffer)
|
||||
}
|
||||
94
pkg/bytequeue/queue_test.go
Normal file
94
pkg/bytequeue/queue_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package bytequeue
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestByteQueue(t *testing.T) {
|
||||
// 创建一个新的 ByteQueue
|
||||
bq := New()
|
||||
defer bq.Reset() // 确保测试结束后重置
|
||||
|
||||
// 测试 Write 方法
|
||||
t.Run("Write", func(t *testing.T) {
|
||||
data := []byte("hello")
|
||||
n, err := bq.Write(data)
|
||||
if err != nil {
|
||||
t.Fatalf("Write() failed: %v", err)
|
||||
}
|
||||
if n != len(data) {
|
||||
t.Errorf("Write() = %d, want %d", n, len(data))
|
||||
}
|
||||
if !bytes.Equal(bq.buffer.B, data) {
|
||||
t.Errorf("Write() buffer = %v, want %v", bq.buffer.B, data)
|
||||
}
|
||||
})
|
||||
|
||||
// 测试 Peek 方法
|
||||
t.Run("Peek", func(t *testing.T) {
|
||||
data := []byte("hello")
|
||||
// 在偏移量 0 处读取 5 个字节
|
||||
result := bq.Peek(0, 5)
|
||||
if !bytes.Equal(result, data) {
|
||||
t.Errorf("Peek() = %v, want %v", result, data)
|
||||
}
|
||||
|
||||
// 读取部分数据,从位置 2 开始读取 3 个字节
|
||||
result = bq.Peek(2, 3)
|
||||
expected := []byte("llo")
|
||||
if !bytes.Equal(result, expected) {
|
||||
t.Errorf("Peek() = %v, want %v", result, expected)
|
||||
}
|
||||
|
||||
// 测试越界,startPosition 大于数据总大小
|
||||
result = bq.Peek(100, 3)
|
||||
if result != nil {
|
||||
t.Errorf("Peek() = %v, want nil", result)
|
||||
}
|
||||
|
||||
// 测试读取超过缓冲区剩余数据
|
||||
result = bq.Peek(0, 10) // 请求更多字节
|
||||
if !bytes.Equal(result, data) {
|
||||
t.Errorf("Peek() = %v, want %v", result, data)
|
||||
}
|
||||
})
|
||||
|
||||
// 测试 Discard 方法
|
||||
t.Run("Discard", func(t *testing.T) {
|
||||
// 写入一些数据
|
||||
_, _ = bq.Write([]byte("world"))
|
||||
|
||||
// 丢弃前 5 个字节
|
||||
bq.Discard(5)
|
||||
if len(bq.buffer.B) != 5 {
|
||||
t.Errorf("Discard() buffer size = %d, want 5", len(bq.buffer.B))
|
||||
}
|
||||
|
||||
// 丢弃所有数据
|
||||
bq.Discard(10)
|
||||
if len(bq.buffer.B) != 0 {
|
||||
t.Errorf("Discard() buffer size = %d, want 0", len(bq.buffer.B))
|
||||
}
|
||||
})
|
||||
|
||||
// 测试 Reset 方法
|
||||
t.Run("Reset", func(t *testing.T) {
|
||||
// 写入数据
|
||||
_, _ = bq.Write([]byte("reset"))
|
||||
|
||||
// 重置
|
||||
bq.Reset()
|
||||
|
||||
// 确保缓冲区已清空
|
||||
if len(bq.buffer.B) != 0 {
|
||||
t.Errorf("Reset() buffer size = %d, want 0", len(bq.buffer.B))
|
||||
}
|
||||
if bq.offsetSize != 0 {
|
||||
t.Errorf("Reset() offsetSize = %d, want 0", bq.offsetSize)
|
||||
}
|
||||
if bq.totalSize != 0 {
|
||||
t.Errorf("Reset() totalSize = %d, want 0", bq.totalSize)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -129,6 +129,50 @@ func (rb *Buffer) Peek(n int) (head []byte, tail []byte) {
|
||||
return
|
||||
}
|
||||
|
||||
// PeekFromPos returns the next n bytes from the specified start position without advancing the read pointer,
|
||||
// it returns all bytes when n <= 0.
|
||||
func (rb *Buffer) PeekFromPos(start, n int) (head []byte, tail []byte) {
|
||||
// 检查起始位置是否有效
|
||||
if start < 0 || start >= rb.size {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if rb.isEmpty {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// 如果 n <= 0,返回所有数据
|
||||
if n <= 0 {
|
||||
return rb.peekAll()
|
||||
}
|
||||
|
||||
// 计算读取的数据的实际长度
|
||||
var m int
|
||||
if start < rb.r {
|
||||
// 如果起始位置小于读取位置 r,说明数据是跨越了缓冲区的尾部
|
||||
m = rb.size - start + rb.w
|
||||
} else {
|
||||
// 如果起始位置大于等于读取位置 r,则在缓冲区内部连续读取
|
||||
m = rb.w - start
|
||||
}
|
||||
|
||||
if m > n {
|
||||
m = n
|
||||
}
|
||||
|
||||
// 从 start 开始读取数据,检查是否跨越环形缓冲区的边界
|
||||
if start+m <= rb.size {
|
||||
head = rb.buf[start : start+m]
|
||||
} else {
|
||||
c1 := rb.size - start
|
||||
head = rb.buf[start:]
|
||||
c2 := m - c1
|
||||
tail = rb.buf[:c2]
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// peekAll returns all bytes without advancing the read pointer.
|
||||
func (rb *Buffer) peekAll() (head []byte, tail []byte) {
|
||||
if rb.isEmpty {
|
||||
@@ -148,6 +192,25 @@ func (rb *Buffer) peekAll() (head []byte, tail []byte) {
|
||||
return
|
||||
}
|
||||
|
||||
// peekAll returns all bytes without advancing the read pointer.
|
||||
func (rb *Buffer) peekAllFrom(start int) (head []byte, tail []byte) {
|
||||
if rb.isEmpty {
|
||||
return
|
||||
}
|
||||
|
||||
if rb.w > rb.r {
|
||||
head = rb.buf[rb.r:rb.w]
|
||||
return
|
||||
}
|
||||
|
||||
head = rb.buf[rb.r:]
|
||||
if rb.w != 0 {
|
||||
tail = rb.buf[:rb.w]
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Discard skips the next n bytes by advancing the read pointer.
|
||||
func (rb *Buffer) Discard(n int) (discarded int, err error) {
|
||||
if n <= 0 {
|
||||
|
||||
Reference in New Issue
Block a user