Refactors JSON-RPC implementation for improved handling

Updates the JSON-RPC implementation to enhance error handling and streamline data processing.

The changes include removing the `ReqId` field from the `Conn` struct and adding it to the `Event` struct for better context management in asynchronous operations. Modifies the `ConnWrite` method in `eventbus/user.go` to accept a `reqId`, ensuring proper association of requests and responses in JSON-RPC flows. Implements error handling in `handleAuthenticatedConn` to send errors to the client, preventing connection closures due to JSON-RPC processing issues.

These changes enhance the reliability and maintainability of the JSON-RPC communication within the system.
This commit is contained in:
tt
2025-04-18 09:20:51 +08:00
parent 16dde49bb5
commit 9fbc5d9804
22 changed files with 754 additions and 111 deletions

View File

@@ -118,7 +118,7 @@ func (cn *connApi) kick(c *wkhttp.Context) {
conn := eventbus.User.ConnById(req.Uid, req.NodeId, req.ConnID)
if conn != nil {
eventbus.User.ConnWrite(conn, &wkproto.DisconnectPacket{
eventbus.User.ConnWrite("", conn, &wkproto.DisconnectPacket{
ReasonCode: wkproto.ReasonConnectKick,
Reason: "server kick",
})

View File

@@ -119,7 +119,7 @@ func (u *user) quitUserDevice(uid string, deviceFlag wkproto.DeviceFlag) error {
oldConns := eventbus.User.ConnsByDeviceFlag(uid, deviceFlag)
if len(oldConns) > 0 {
for _, oldConn := range oldConns {
eventbus.User.ConnWrite(oldConn, &wkproto.DisconnectPacket{
eventbus.User.ConnWrite("", oldConn, &wkproto.DisconnectPacket{
ReasonCode: wkproto.ReasonConnectKick,
})
u.s.timingWheel.AfterFunc(time.Second*2, func() {
@@ -366,7 +366,7 @@ func (u *user) updateToken(c *wkhttp.Context) {
for _, oldConn := range oldConns {
u.Debug("更新Token时存在旧连接", zap.String("uid", req.UID), zap.Int64("id", oldConn.ConnId), zap.String("deviceFlag", req.DeviceFlag.String()))
// 踢旧连接
eventbus.User.ConnWrite(oldConn, &wkproto.DisconnectPacket{
eventbus.User.ConnWrite("", oldConn, &wkproto.DisconnectPacket{
ReasonCode: wkproto.ReasonConnectKick,
Reason: "账号在其他设备上登录",
})

View File

@@ -39,7 +39,7 @@ func (h *Handler) sendack(ctx *eventbus.ChannelContext) {
}
sendPacket := e.Frame.(*wkproto.SendPacket)
eventbus.User.ConnWrite(e.Conn, &wkproto.SendackPacket{
eventbus.User.ConnWrite(e.ReqId, e.Conn, &wkproto.SendackPacket{
Framer: sendPacket.Framer,
MessageID: e.MessageId,
MessageSeq: uint32(e.MessageSeq),

View File

@@ -51,8 +51,7 @@ type Conn struct {
// 不参与编码
LastActive uint64 // 最后一次活动时间单位秒
IsJsonRpc bool // 是否是jsonrpc连接
ReqId string // 请求id (jsonrpc)
IsJsonRpc bool // 是否是jsonrpc连接
}
func (c *Conn) Encode() ([]byte, error) {
@@ -69,7 +68,6 @@ func (c *Conn) Encode() ([]byte, error) {
enc.WriteUint8(c.ProtoVersion)
enc.WriteUint64(c.Uptime)
enc.WriteUint8(wkutil.BoolToUint8(c.IsJsonRpc))
enc.WriteString(c.ReqId)
return enc.Bytes(), nil
}
@@ -126,10 +124,6 @@ func (c *Conn) Decode(data []byte) error {
}
c.IsJsonRpc = wkutil.Uint8ToBool(isJsonRpc)
if c.ReqId, err = dec.String(); err != nil {
return err
}
return nil
}

View File

@@ -98,6 +98,7 @@ type Event struct {
OfflineUsers []string // 离线用户集合
ChannelId string // 频道ID
ChannelType uint8 // 频道类型
ReqId string // 请求ID(非必填)(jsonrpc)
}
func (e *Event) Clone() *Event {
@@ -118,6 +119,7 @@ func (e *Event) Clone() *Event {
OfflineUsers: e.OfflineUsers,
ChannelId: e.ChannelId,
ChannelType: e.ChannelType,
ReqId: e.ReqId,
}
}
@@ -149,11 +151,15 @@ func (e *Event) Size() uint64 {
size += 1 // channel type
}
if e.hasReqId() == 1 {
size += uint64(2 + len(e.ReqId)) // req id
}
return size
}
func (e Event) encodeWithEcoder(enc *wkproto.Encoder) error {
var flag uint8 = e.hasConn()<<7 | e.hasFrame()<<6 | e.hasTrack()<<5 | e.hasChannel()<<4
var flag uint8 = e.hasConn()<<7 | e.hasFrame()<<6 | e.hasTrack()<<5 | e.hasChannel()<<4 | e.hasReqId()<<3
enc.WriteUint8(flag)
enc.WriteUint8(e.Type.Uint8())
@@ -191,6 +197,10 @@ func (e Event) encodeWithEcoder(enc *wkproto.Encoder) error {
enc.WriteUint8(e.ChannelType)
}
if e.hasReqId() == 1 {
enc.WriteString(e.ReqId)
}
return nil
}
@@ -203,7 +213,7 @@ func (e *Event) decodeWithDecoder(dec *wkproto.Decoder) error {
hasFrame := (flag >> 6) & 0x01
hasTrack := (flag >> 5) & 0x01
hasChannel := (flag >> 4) & 0x01
hasReqId := (flag >> 3) & 0x01
typeUint8, err := dec.Uint8()
if err != nil {
return err
@@ -296,6 +306,12 @@ func (e *Event) decodeWithDecoder(dec *wkproto.Decoder) error {
e.ChannelType = channelType
}
if hasReqId == 1 {
if e.ReqId, err = dec.String(); err != nil {
return err
}
}
return nil
}
@@ -326,6 +342,13 @@ func (e *Event) hasChannel() uint8 {
return 0
}
func (e *Event) hasReqId() uint8 {
if e.ReqId != "" {
return 1
}
return 0
}
type EventBatch []*Event
func (e EventBatch) Encode() ([]byte, error) {

View File

@@ -45,6 +45,9 @@ type IUser interface {
AllConnCount() int
// AllConn 获取所有连接
AllConn() []*Conn
// WriteLocalData 写入本地数据(conn 是本地连接)
WriteLocalData(conn *Conn, data []byte) error
}
type userPlus struct {
@@ -74,12 +77,13 @@ func (u *userPlus) Advance(uid string) {
// ========================================== conn ==========================================
// Connect 请求连接
func (u *userPlus) Connect(conn *Conn, connectPacket *wkproto.ConnectPacket) {
func (u *userPlus) Connect(reqId string, conn *Conn, connectPacket *wkproto.ConnectPacket) {
u.user.AddEvent(conn.Uid, &Event{
Type: EventConnect,
Frame: connectPacket,
Conn: conn,
SourceNodeId: options.G.Cluster.NodeId,
ReqId: reqId,
})
}
@@ -127,12 +131,13 @@ func (u *userPlus) UpdateConn(conn *Conn) {
}
// ConnWrite 连接写包
func (u *userPlus) ConnWrite(conn *Conn, frame wkproto.Frame) {
func (u *userPlus) ConnWrite(reqId string, conn *Conn, frame wkproto.Frame) {
u.user.AddEvent(conn.Uid, &Event{
Type: EventConnWriteFrame,
Conn: conn,
Frame: frame,
SourceNodeId: options.G.Cluster.NodeId,
ReqId: reqId,
})
}
@@ -179,3 +184,7 @@ func (u *userPlus) AllConnCount() int {
func (u *userPlus) AllConn() []*Conn {
return u.user.AllConn()
}
func (u *userPlus) WriteLocalData(conn *Conn, data []byte) error {
return u.user.WriteLocalData(conn, data)
}

View File

@@ -92,7 +92,7 @@ func (r *RetryManager) retry(msg *types.RetryMessage) {
r.Info("retry send message", zap.Int("retry", msg.Retry), zap.Uint64("fromNode", msg.FromNode), zap.String("uid", msg.Uid), zap.Int64("messageId", msg.MessageId), zap.Int64("connId", msg.ConnId))
}
eventbus.User.ConnWrite(conn, msg.RecvPacket)
eventbus.User.ConnWrite("", conn, msg.RecvPacket)
}

View File

@@ -148,7 +148,7 @@ func (h *Handler) processChannelPush(events []*eventbus.Event) {
})
}
eventbus.User.ConnWrite(toConn, recvPacket)
eventbus.User.ConnWrite(e.ReqId, toConn, recvPacket)
}
eventbus.User.Advance(e.ToUid)
}

View File

@@ -50,7 +50,15 @@ func (s *Server) onData(conn wknet.Conn) error {
_, _ = conn.Discard(consumedBytes)
return nil
} else {
return s.handleAuthenticatedConn(conn, connCtx, buff, isJson)
err = s.handleAuthenticatedConn(conn, connCtx, buff, isJson)
// 如果jsonrpc请求则不返回错误, 因为jsonrpc请求将错误返回给客户端了这里不返回error是为了防止返回error服务将此连接关闭
if isJson {
if err != nil {
s.Warn("Failed to handle authenticated conn", zap.Error(err))
return nil
}
}
return err
}
}
@@ -62,29 +70,31 @@ func (s *Server) handleAuthenticatedConn(conn wknet.Conn, connCtx *eventbus.Conn
var events []*eventbus.Event
frames := make([]wkproto.Frame, 0, 10)
reqIds := make([]string, 0, 10)
if isJson {
reader := bytes.NewReader(buff)
decoder := json.NewDecoder(reader)
for {
packet, _, err := jsonrpc.Decode(decoder)
packet, probe, err := jsonrpc.Decode(decoder)
if err != nil {
if err == io.EOF {
break
}
s.Warn("Failed to decode jsonrpc packet", zap.Error(err))
conn.Close()
eventbus.User.WriteLocalData(connCtx, jsonrpc.EncodeErrorResponse(jsonrpc.DecodeID(probe.ID), err))
return err
}
if packet == nil {
break
}
frame, err := jsonrpc.ToFrame(packet)
frame, reqId, err := jsonrpc.ToFrame(packet)
if err != nil {
s.Warn("Failed to convert jsonrpc packet to frame", zap.Error(err))
conn.Close()
eventbus.User.WriteLocalData(connCtx, jsonrpc.EncodeErrorResponse(jsonrpc.DecodeID(probe.ID), err))
return err
}
frames = append(frames, frame)
reqIds = append(reqIds, reqId)
}
offset += (len(buff) - reader.Len())
@@ -109,7 +119,13 @@ func (s *Server) handleAuthenticatedConn(conn wknet.Conn, connCtx *eventbus.Conn
return nil
}
for _, frame := range frames {
for i, frame := range frames {
var reqId string
if len(reqIds) > 0 && len(reqIds) == len(frames) {
reqId = reqIds[i]
}
event := &eventbus.Event{
Type: eventbus.EventOnSend, // Assuming all data frames trigger an OnSend event internally
Frame: frame,
@@ -118,6 +134,7 @@ func (s *Server) handleAuthenticatedConn(conn wknet.Conn, connCtx *eventbus.Conn
Track: track.Message{
PreStart: time.Now(),
},
ReqId: reqId,
}
event.Track.Record(track.PositionStart)
@@ -244,13 +261,12 @@ func (s *Server) handleUnauthenticatedConn(conn wknet.Conn, buff []byte, isJson
ProtoVersion: connectPacket.Version,
Uptime: fasttime.UnixTimestamp(),
IsJsonRpc: isJson,
ReqId: reqId,
}
conn.SetContext(connCtx)
conn.SetMaxIdle(time.Second * 4)
eventbus.User.Connect(connCtx, connectPacket)
eventbus.User.Connect(reqId, connCtx, connectPacket)
eventbus.User.Advance(connCtx.Uid)
return connCtx, consumedBytes, nil

View File

@@ -1,11 +1,14 @@
package event
import (
"fmt"
"hash/fnv"
"github.com/WuKongIM/WuKongIM/internal/eventbus"
"github.com/WuKongIM/WuKongIM/internal/options"
"github.com/WuKongIM/WuKongIM/internal/service"
"github.com/WuKongIM/WuKongIM/pkg/wklog"
"github.com/WuKongIM/WuKongIM/pkg/wknet"
wkproto "github.com/WuKongIM/WuKongIMGoProto"
"go.uber.org/zap"
)
@@ -122,3 +125,36 @@ func (e *EventPool) AllConnCount() int {
func (e *EventPool) RemoveConn(conn *eventbus.Conn) {
e.pollerByUid(conn.Uid).removeConn(conn)
}
func (e *EventPool) WriteLocalData(conn *eventbus.Conn, data []byte) error {
if !options.G.IsLocalNode(conn.NodeId) {
e.Error("writeLocalData: conn not local node", zap.String("uid", conn.Uid), zap.Uint64("nodeId", conn.NodeId), zap.Int64("connId", conn.ConnId))
return fmt.Errorf("writeLocalData: conn not local node")
}
realConn := service.ConnManager.GetConn(conn.ConnId)
if realConn == nil {
e.Error("writeLocalData: conn not exist", zap.String("uid", conn.Uid), zap.Uint64("nodeId", conn.NodeId), zap.Int64("connId", conn.ConnId))
return fmt.Errorf("writeLocalData: conn not exist")
}
wsConn, wsok := realConn.(wknet.IWSConn) // websocket连接
var err error
if wsok {
if conn.IsJsonRpc {
err = wsConn.WriteServerText(data)
} else {
err = wsConn.WriteServerBinary(data)
}
if err != nil {
e.Warn("writeFrame: Failed to ws write the message", zap.Error(err))
}
} else {
_, err := realConn.WriteToOutboundBuffer(data)
if err != nil {
e.Warn("writeFrame: Failed to write the message", zap.Error(err))
return err
}
}
return realConn.WakeWrite()
}

View File

@@ -35,7 +35,7 @@ func (h *Handler) connack(ctx *eventbus.UserContext) {
// 更新连接
eventbus.User.UpdateConn(conn)
}
eventbus.User.ConnWrite(conn, connack)
eventbus.User.ConnWrite(event.ReqId, conn, connack)
}
}

View File

@@ -35,6 +35,7 @@ func (h *Handler) connect(ctx *eventbus.UserContext) {
Conn: conn,
Frame: packet,
SourceNodeId: options.G.Cluster.NodeId,
ReqId: event.ReqId,
}
if options.G.IsLocalNode(conn.NodeId) {
eventbus.User.AddEvent(uid, connackEvent)
@@ -45,6 +46,7 @@ func (h *Handler) connect(ctx *eventbus.UserContext) {
Conn: conn,
Frame: packet,
SourceNodeId: options.G.Cluster.NodeId,
ReqId: event.ReqId,
})
}
@@ -130,7 +132,7 @@ func (h *Handler) handleConnect(event *eventbus.Event) (wkproto.ReasonCode, *wkp
zap.String("deviceID", connectPacket.DeviceID),
zap.String("oldDeviceId", oldConn.DeviceId),
)
eventbus.User.ConnWrite(oldConn, &wkproto.DisconnectPacket{
eventbus.User.ConnWrite(event.ReqId, oldConn, &wkproto.DisconnectPacket{
ReasonCode: wkproto.ReasonConnectKick,
Reason: "login in other device",
})

View File

@@ -59,7 +59,7 @@ func (h *Handler) handleOnSend(event *eventbus.Event) {
ClientMsgNo: sendPacket.ClientMsgNo,
ReasonCode: wkproto.ReasonPayloadDecodeError,
}
eventbus.User.ConnWrite(conn, sendack)
eventbus.User.ConnWrite(event.ReqId, conn, sendack)
return
}
sendPacket.Payload = newPayload // 使用解密后的 Payload
@@ -80,6 +80,7 @@ func (h *Handler) handleOnSend(event *eventbus.Event) {
Frame: sendPacket,
MessageId: event.MessageId,
Track: event.Track,
ReqId: event.ReqId,
})
// 推进
eventbus.Channel.Advance(fakeChannelId, channelType)

View File

@@ -14,5 +14,5 @@ func (h *Handler) ping(event *eventbus.Event) {
trace.GlobalTrace.Metrics.App().PongCountAdd(1)
trace.GlobalTrace.Metrics.App().PongBytesAdd(1)
eventbus.User.ConnWrite(conn, &wkproto.PongPacket{})
eventbus.User.ConnWrite(event.ReqId, conn, &wkproto.PongPacket{})
}

View File

@@ -46,7 +46,7 @@ func (h *Handler) writeLocalFrame(event *eventbus.Event) {
)
if conn.IsJsonRpc {
req, err := jsonrpc.FromFrame(conn.ReqId, frame)
req, err := jsonrpc.FromFrame(event.ReqId, frame)
if err != nil {
h.Error("writeFrame jsonrpc: from frame err", zap.Error(err))
return
@@ -86,7 +86,11 @@ func (h *Handler) writeLocalFrame(event *eventbus.Event) {
}
wsConn, wsok := realConn.(wknet.IWSConn) // websocket连接
if wsok {
err := wsConn.WriteServerBinary(data)
if conn.IsJsonRpc {
err = wsConn.WriteServerText(data)
} else {
err = wsConn.WriteServerBinary(data)
}
if err != nil {
h.Warn("writeFrame: Failed to ws write the message", zap.Error(err))
}

View File

@@ -2,6 +2,7 @@ package jsonrpc
import (
"encoding/json"
"errors"
"fmt"
wkproto "github.com/WuKongIM/WuKongIMGoProto"
@@ -24,6 +25,20 @@ const (
MethodRecv = "recv" // Notification method
)
// Predefined decoding errors
var (
ErrInvalidVersion = errors.New("jsonrpc: invalid version")
ErrInvalidStructure = errors.New("jsonrpc: invalid message structure or field combination")
ErrAmbiguousMessageType = errors.New("jsonrpc: ambiguous message type")
ErrResponseFormat = errors.New("jsonrpc: invalid response format (missing id, result/error mismatch)")
ErrRequestFormat = errors.New("jsonrpc: invalid request format (missing id or method)")
ErrNotificationFormat = errors.New("jsonrpc: invalid notification format (missing method)")
ErrUnknownMethod = errors.New("jsonrpc: unknown method")
ErrMissingParams = errors.New("jsonrpc: missing params field")
ErrUnmarshalFieldFailed = errors.New("jsonrpc: failed to unmarshal field") // Base error for wrapping
ErrInternal = errors.New("jsonrpc: internal codec error") // For unexpected logic paths
)
// Probe is a temporary structure used to determine the type of an incoming JSON-RPC message
// by checking the presence of key fields like id, method, result, error.
type Probe struct {
@@ -43,6 +58,8 @@ const (
)
// decodingError creates a formatted error specific to JSON-RPC decoding.
// It's a simple wrapper around fmt.Errorf with a prefix.
// Callers should use %w in the format string to wrap base errors when needed.
func decodingError(format string, args ...interface{}) error {
return fmt.Errorf("jsonrpc decode: "+format, args...)
}
@@ -57,6 +74,30 @@ func Encode(msg interface{}) ([]byte, error) {
return bytes, nil
}
func EncodeErrorResponse(id string, err error) []byte {
data, err := Encode(GenericResponse{
BaseResponse: BaseResponse{
Jsonrpc: jsonRPCVersion,
ID: id,
Error: &ErrorObject{
Message: err.Error(),
},
},
})
if err != nil {
return nil
}
return data
}
func DecodeID(id json.RawMessage) string {
var idStr string
if err := json.Unmarshal(id, &idStr); err != nil {
return ""
}
return idStr
}
// determineMessageType probes the basic structure, validates version and fields,
// and determines if the message is a Request, Response, or Notification.
func determineMessageType(probe *Probe) (msgType int, version string, err error) {
@@ -65,14 +106,14 @@ func determineMessageType(probe *Probe) (msgType int, version string, err error)
if probe.Jsonrpc != nil {
var parsedVersion string
if jsonErr := json.Unmarshal(probe.Jsonrpc, &parsedVersion); jsonErr != nil {
err = decodingError("failed to unmarshal jsonrpc field: %w", jsonErr)
err = fmt.Errorf("%w: jsonrpc field: %w", ErrUnmarshalFieldFailed, jsonErr) // Wrap base error
return
}
if parsedVersion != jsonRPCVersion {
err = decodingError("invalid 'jsonrpc' version '%s', must be %s or omitted", parsedVersion, jsonRPCVersion)
err = fmt.Errorf("%w: expected '%s', got '%s'", ErrInvalidVersion, jsonRPCVersion, parsedVersion)
return
}
version = parsedVersion // Use provided version if valid
version = parsedVersion
}
// Check field presence
@@ -91,13 +132,13 @@ func determineMessageType(probe *Probe) (msgType int, version string, err error)
// Validate field combinations
switch {
case prelimIsRequest && prelimIsResponse:
err = decodingError("message cannot have both 'method' and ('result' or 'error')")
err = ErrInvalidStructure // Use predefined error
return
case prelimIsResponse && !resultIsPresent && !errorIsPresent:
err = decodingError("response must contain either 'result' or 'error'")
err = ErrResponseFormat // Use predefined error
return
case prelimIsResponse && resultIsPresent && errorIsPresent:
err = decodingError("response cannot contain both 'result' and 'error'")
err = ErrResponseFormat // Use predefined error
return
// case prelimIsRequest && prelimIsNotification: // This overlap is handled by specific type assignment below
// err = decodingError("message ambiguity: matches request and notification criteria (id: %s, method: %v)", string(probe.ID), probe.Method)
@@ -143,7 +184,7 @@ func determineMessageType(probe *Probe) (msgType int, version string, err error)
}
if msgType == msgTypeUnknown && err == nil { // Assign error if type is unknown and no specific validation failed
err = decodingError("unable to determine message type (invalid field combination)")
err = ErrInvalidStructure // Assign general structure error if type unknown
}
return
@@ -157,24 +198,32 @@ func Decode(decoder *json.Decoder) (interface{}, Probe, error) {
// 1. Probe the message structure
var probe Probe
if err := decoder.Decode(&probe); err != nil {
return nil, probe, err // Return zero Probe on initial decode error
// If it's a json syntax error, wrap it?
var syntaxErr *json.SyntaxError
if errors.As(err, &syntaxErr) {
return nil, probe, fmt.Errorf("%w: %w", ErrInvalidStructure, err)
}
return nil, probe, err // Includes io.EOF
}
// 2. Determine message type and validate basic structure
msgType, version, err := determineMessageType(&probe)
if err != nil {
return nil, probe, err // Return probe even if type determination fails, might be useful
return nil, probe, err // Return error from determination (already specific)
}
// 3. Construct and Populate the Specific Type based on msgType
switch msgType {
case msgTypeRequest:
if probe.Method == "" { // Should be caught by determineMessageType
return nil, probe, decodingError("internal: msgTypeRequest but method is nil")
return nil, probe, ErrRequestFormat
}
baseReq := BaseRequest{Jsonrpc: version, Method: probe.Method}
if probe.ID == nil || string(probe.ID) == "null" { // ID is mandatory and non-null for requests
return nil, probe, ErrRequestFormat
}
if err := json.Unmarshal(probe.ID, &baseReq.ID); err != nil {
return nil, probe, decodingError("failed to unmarshal request ID: %w", err)
return nil, probe, fmt.Errorf("%w: request ID: %w", ErrUnmarshalFieldFailed, err)
}
switch probe.Method {
@@ -182,50 +231,50 @@ func Decode(decoder *json.Decoder) (interface{}, Probe, error) {
var req ConnectRequest
req.BaseRequest = baseReq
if probe.Params == nil {
return nil, probe, decodingError("missing params for %s request", MethodConnect)
return nil, probe, fmt.Errorf("%w: method %s", ErrMissingParams, MethodConnect)
}
if err := json.Unmarshal(probe.Params, &req.Params); err != nil {
return nil, probe, decodingError("unmarshal %s params: %w", MethodConnect, err)
return nil, probe, fmt.Errorf("%w: %s params: %w", ErrUnmarshalFieldFailed, MethodConnect, err)
}
return req, probe, nil
case MethodSend:
var req SendRequest
req.BaseRequest = baseReq
if probe.Params == nil {
return nil, probe, decodingError("missing params for %s request", MethodSend)
return nil, probe, fmt.Errorf("%w: method %s", ErrMissingParams, MethodSend)
}
if err := json.Unmarshal(probe.Params, &req.Params); err != nil {
return nil, probe, decodingError("unmarshal %s params: %w", MethodSend, err)
return nil, probe, fmt.Errorf("%w: %s params: %w", ErrUnmarshalFieldFailed, MethodSend, err)
}
return req, probe, nil
case MethodRecvAck:
var req RecvAckRequest
req.BaseRequest = baseReq
if probe.Params == nil {
return nil, probe, decodingError("missing params for %s request", MethodRecvAck)
return nil, probe, fmt.Errorf("%w: method %s", ErrMissingParams, MethodRecvAck)
}
if err := json.Unmarshal(probe.Params, &req.Params); err != nil {
return nil, probe, decodingError("unmarshal %s params: %w", MethodRecvAck, err)
return nil, probe, fmt.Errorf("%w: %s params: %w", ErrUnmarshalFieldFailed, MethodRecvAck, err)
}
return req, probe, nil
case MethodSubscribe:
var req SubscribeRequest
req.BaseRequest = baseReq
if probe.Params == nil {
return nil, probe, decodingError("missing params for %s request", MethodSubscribe)
return nil, probe, fmt.Errorf("%w: method %s", ErrMissingParams, MethodSubscribe)
}
if err := json.Unmarshal(probe.Params, &req.Params); err != nil {
return nil, probe, decodingError("unmarshal %s params: %w", MethodSubscribe, err)
return nil, probe, fmt.Errorf("%w: %s params: %w", ErrUnmarshalFieldFailed, MethodSubscribe, err)
}
return req, probe, nil
case MethodUnsubscribe:
var req UnsubscribeRequest
req.BaseRequest = baseReq
if probe.Params == nil {
return nil, probe, decodingError("missing params for %s request", MethodUnsubscribe)
return nil, probe, fmt.Errorf("%w: method %s", ErrMissingParams, MethodUnsubscribe)
}
if err := json.Unmarshal(probe.Params, &req.Params); err != nil {
return nil, probe, decodingError("unmarshal %s params: %w", MethodUnsubscribe, err)
return nil, probe, fmt.Errorf("%w: %s params: %w", ErrUnmarshalFieldFailed, MethodUnsubscribe, err)
}
return req, probe, nil
case MethodPing:
@@ -235,7 +284,7 @@ func Decode(decoder *json.Decoder) (interface{}, Probe, error) {
var p PingParams
if err := json.Unmarshal(probe.Params, &p); err != nil {
if string(probe.Params) != "{}" { // Allow empty object
return nil, probe, decodingError("failed to unmarshal %s params: %w", MethodPing, err)
return nil, probe, fmt.Errorf("%w: %s params: %w", ErrUnmarshalFieldFailed, MethodPing, err)
}
}
req.Params = &p
@@ -245,23 +294,29 @@ func Decode(decoder *json.Decoder) (interface{}, Probe, error) {
var req DisconnectRequest
req.BaseRequest = baseReq
if probe.Params == nil {
return nil, probe, decodingError("missing params for %s request", MethodDisconnect)
return nil, probe, fmt.Errorf("%w: method %s", ErrMissingParams, MethodDisconnect)
}
if err := json.Unmarshal(probe.Params, &req.Params); err != nil {
return nil, probe, decodingError("unmarshal %s params: %w", MethodDisconnect, err)
return nil, probe, fmt.Errorf("%w: %s params: %w", ErrUnmarshalFieldFailed, MethodDisconnect, err)
}
return req, probe, nil
default:
return nil, probe, decodingError("unknown request method '%s'", probe.Method)
return nil, probe, fmt.Errorf("%w: %s", ErrUnknownMethod, probe.Method)
}
case msgTypeResponse:
baseResp := BaseResponse{Jsonrpc: version}
if probe.ID == nil { // Should be caught by determineMessageType
return nil, probe, decodingError("internal: msgTypeResponse but ID is nil")
if probe.ID == nil || string(probe.ID) == "null" { // ID is mandatory and non-null for responses
return nil, probe, ErrResponseFormat
}
if err := json.Unmarshal(probe.ID, &baseResp.ID); err != nil {
return nil, probe, decodingError("failed to unmarshal response ID: %w", err)
return nil, probe, fmt.Errorf("%w: response ID: %w", ErrUnmarshalFieldFailed, err)
}
if probe.Result == nil && probe.Error == nil {
return nil, probe, ErrResponseFormat // Must have result or error
}
if probe.Result != nil && probe.Error != nil {
return nil, probe, ErrResponseFormat // Cannot have both
}
resp := GenericResponse{
@@ -271,7 +326,7 @@ func Decode(decoder *json.Decoder) (interface{}, Probe, error) {
if probe.Error != nil {
var errObj ErrorObject
if err := json.Unmarshal(probe.Error, &errObj); err != nil {
return nil, probe, decodingError("failed to unmarshal error object: %w", err)
return nil, probe, fmt.Errorf("%w: error object: %w", ErrUnmarshalFieldFailed, err)
}
resp.Error = &errObj
}
@@ -279,7 +334,7 @@ func Decode(decoder *json.Decoder) (interface{}, Probe, error) {
case msgTypeNotification:
if probe.Method == "" { // Should be caught by determineMessageType
return nil, probe, decodingError("internal: msgTypeNotification but method is nil")
return nil, probe, ErrNotificationFormat
}
baseNotif := BaseNotification{Jsonrpc: version, Method: probe.Method}
@@ -288,20 +343,20 @@ func Decode(decoder *json.Decoder) (interface{}, Probe, error) {
var notif RecvNotification
notif.BaseNotification = baseNotif
if probe.Params == nil {
return nil, probe, decodingError("missing params for %s notification", MethodRecv)
return nil, probe, fmt.Errorf("%w: method %s", ErrMissingParams, MethodRecv)
}
if err := json.Unmarshal(probe.Params, &notif.Params); err != nil {
return nil, probe, decodingError("unmarshal %s params: %w", MethodRecv, err)
return nil, probe, fmt.Errorf("%w: %s params: %w", ErrUnmarshalFieldFailed, MethodRecv, err)
}
return notif, probe, nil
case MethodDisconnect:
var notif DisconnectNotification
notif.BaseNotification = baseNotif
if probe.Params == nil {
return nil, probe, decodingError("missing params for %s notification", MethodDisconnect)
return nil, probe, fmt.Errorf("%w: method %s", ErrMissingParams, MethodDisconnect)
}
if err := json.Unmarshal(probe.Params, &notif.Params); err != nil {
return nil, probe, decodingError("unmarshal %s params: %w", MethodDisconnect, err)
return nil, probe, fmt.Errorf("%w: %s params: %w", ErrUnmarshalFieldFailed, MethodDisconnect, err)
}
return notif, probe, nil
case MethodPong:
@@ -309,33 +364,41 @@ func Decode(decoder *json.Decoder) (interface{}, Probe, error) {
notif.BaseNotification = baseNotif
return notif, probe, nil
default:
return nil, probe, decodingError("unknown notification method '%s'", probe.Method)
return nil, probe, fmt.Errorf("%w: %s", ErrUnknownMethod, probe.Method)
}
default: // msgTypeUnknown or other unexpected case
// Error was already generated by determineMessageType if type was unknown
// If we reach here unexpectedly, return a generic internal error
// If determineMessageType returned an error, it's already specific.
// Otherwise, return the general invalid structure error.
if err == nil {
err = decodingError("internal error - unexpected message type state")
err = ErrInvalidStructure
}
return nil, probe, err
}
}
func ToFrame(packet interface{}) (wkproto.Frame, error) {
func ToFrame(packet interface{}) (wkproto.Frame, string, error) {
switch p := packet.(type) {
case ConnectRequest:
return p.Params.ToProto(), nil
return p.Params.ToProto(), p.ID, nil
case SendRequest:
return p.Params.ToProto(p.ID), nil
return p.Params.ToProto(), p.ID, nil
case RecvAckRequest:
return p.Params.ToProto()
frame, err := p.Params.ToProto()
if err != nil {
return nil, "", err
}
return frame, p.ID, nil
case PingRequest:
return &wkproto.PingPacket{}, p.ID, nil
case DisconnectRequest:
return p.Params.ToProto(), p.ID, nil
}
return nil, fmt.Errorf("unknown packet type: %T", packet)
return nil, "", fmt.Errorf("unknown packet type: %T", packet)
}
func FromFrame(id string, frame wkproto.Frame) (interface{}, error) {
func FromFrame(reqId string, frame wkproto.Frame) (interface{}, error) {
switch frame.GetFrameType() {
case wkproto.CONNACK:
@@ -344,7 +407,7 @@ func FromFrame(id string, frame wkproto.Frame) (interface{}, error) {
return ConnectResponse{
BaseResponse: BaseResponse{
Jsonrpc: jsonRPCVersion,
ID: id,
ID: reqId,
},
Result: params,
}, nil
@@ -354,7 +417,7 @@ func FromFrame(id string, frame wkproto.Frame) (interface{}, error) {
return SendResponse{
BaseResponse: BaseResponse{
Jsonrpc: jsonRPCVersion,
ID: sendack.ClientMsgNo,
ID: reqId,
},
Result: result,
}, nil
@@ -380,7 +443,7 @@ func FromFrame(id string, frame wkproto.Frame) (interface{}, error) {
},
}, nil
}
return nil, fmt.Errorf("unknown frame type: %d", frame.GetFrameType())
return nil, fmt.Errorf("jsonrpc: unknown frame type: %d", frame.GetFrameType())
}
// IsJSONObjectPrefix checks if the byte slice likely starts with a JSON object,

View File

@@ -80,7 +80,7 @@ func TestEncodeDecode_Connect(t *testing.T) {
// --- Test Response (Success) ---
respResult := ConnectResult{
Header: Header{NoPersist: false},
Header: &Header{NoPersist: false},
ServerVersion: 1,
ServerKey: "testServerKey",
Salt: "testSalt",
@@ -242,7 +242,7 @@ func TestDecode_EdgeCases(t *testing.T) { // Renamed for clarity
decoder := json.NewDecoder(bytes.NewReader(data))
_, _, err := Decode(decoder) // Ignore msg and probe
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid 'jsonrpc' version")
assert.Contains(t, err.Error(), "invalid version: expected")
})
t.Run("RequestMissingId", func(t *testing.T) {
@@ -250,7 +250,7 @@ func TestDecode_EdgeCases(t *testing.T) { // Renamed for clarity
decoder := json.NewDecoder(bytes.NewReader(data))
_, _, err := Decode(decoder) // Ignore msg and probe
assert.Error(t, err)
assert.Contains(t, err.Error(), "unknown notification method 'get_data'") // Updated assertion based on actual behavior
assert.Contains(t, err.Error(), "jsonrpc decode: unknown notification method") // Updated assertion based on actual behavior
})
t.Run("RequestNullId", func(t *testing.T) {
@@ -259,7 +259,7 @@ func TestDecode_EdgeCases(t *testing.T) { // Renamed for clarity
decoder := json.NewDecoder(bytes.NewReader(data))
msg, _, err := Decode(decoder) // Ignore probe
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown notification method 'get_data'")
assert.Contains(t, err.Error(), "invalid request format")
assert.Nil(t, msg)
})

445
pkg/jsonrpc/protocol.md Normal file
View File

@@ -0,0 +1,445 @@
# WuKongIM JSON-RPC 协议文档
## 概述
本文档描述了 WuKongIM 使用 JSON-RPC 2.0 规范进行通信的协议格式。所有请求、响应和通知都遵循 JSON-RPC 2.0 标准结构。
* `jsonrpc`: **可选** 字符串,固定为 "2.0"。如果省略,服务器应假定其为 "2.0"。
* `method`: 请求或通知的方法名。
* `params`: 请求或通知的参数,通常是一个对象。
* `id`: 请求的唯一标识符(字符串类型)。响应必须包含与请求相同的 id。通知没有 id。
* `result`: 成功响应的结果数据。
* `error`: 错误响应的错误对象。
## 注意事项
1. 建立Websocket连接后需要在2秒之内进行认证(`connect`),超过2秒或者认证失败的连接将断开
2. 发送错误的数据格式或系统不支持的`method`方法,系统都将断开连接
## 通用组件
以下是一些在多个消息类型中复用的组件定义:
### ErrorObject
当请求处理失败时,响应中会包含此对象。
| 字段 | 类型 | 必填 | 描述 |
| :-------- | :------ | :--- | :------------- |
| `code` | integer | 是 | 错误码 |
| `message` | string | 是 | 错误描述 |
| `data` | any | 否 | 附加错误数据 |
### Header
可选的消息头信息。
| 字段 | 类型 | 必填 | 描述 |
| :---------- | :------ | :--- | :------------------- |
| `noPersist` | boolean | 否 | 消息是否不存储 |
| `redDot` | boolean | 否 | 是否显示红点 |
| `syncOnce` | boolean | 否 | 是否只被同步一次 |
| `dup` | boolean | 否 | 是否是重发的消息 |
### SettingFlags
消息设置标记位。
| 字段 | 类型 | 必填 | 描述 |
| :-------- | :------ | :--- | :----------------- |
| `receipt` | boolean | 否 | 消息已读回执 |
| `stream` | boolean | 否 | 是否为流式消息 |
| `topic` | boolean | 否 | 是否包含 Topic |
## 消息类型详解
### 1. 连接 (Connect)
#### Connect Request (`connect`)
客户端发起的第一个请求,用于建立连接和认证。
**参数 (`params`)**
| 字段 | 类型 | 必填 | 描述 |
| :-------------- | :------ | :--- | :------------------------- |
| `uid` | string | 是 | 用户ID |
| `token` | string | 是 | 认证Token |
| `header` | Header | 否 | 消息头 |
| `version` | integer | 否 | 客户端协议版本 |
| `clientKey` | string | 否 | 客户端公钥 |
| `deviceId` | string | 否 | 设备ID |
| `deviceFlag` | integer | 否 | 设备标识 (1:APP, 2:WEB...) |
| `clientTimestamp`| integer | 否 | 客户端13位毫秒时间戳 |
**最小示例**
```json
{
"method": "connect",
"params": {
"uid": "testUser",
"token": "testToken"
},
"id": "req-conn-1"
}
```
#### Connect Response
服务器对 `connect` 请求的响应。
**成功结果 (`result`)**
| 字段 | 类型 | 必填 | 描述 |
| :------------ | :------ | :--- | :----------------------------- |
| `serverKey` | string | 是 | 服务端的DH公钥 |
| `salt` | string | 是 | 加密盐值 |
| `timeDiff` | integer | 是 | 客户端与服务器时间差(毫秒) |
| `reasonCode` | integer | 是 | 原因码 (成功时通常为0) |
| `header` | Header | 否 | 消息头 |
| `serverVersion`| integer | 否 | 服务端版本 |
| `nodeId` | integer | 否 | 连接的节点ID (协议版本 >= 4) |
**错误结果 (`error`)**: 参考 `ErrorObject`
**最小成功示例**
```json
{
"result": {
"serverKey": "serverPublicKey",
"salt": "randomSalt",
"timeDiff": -15,
"reasonCode": 0
},
"id": "req-conn-1"
}
```
**最小错误示例**
```json
{
"error": {
"code": 1001,
"message": "Authentication Failed"
},
"id": "req-conn-1"
}
```
### 2. 发送消息 (Send)
#### Send Request (`send`)
客户端发送消息到指定频道。
**参数 (`params`)**
| 字段 | 类型 | 必填 | 描述 |
| :------------ | :----------- | :--- | :------------------------------- |
| `clientMsgNo` | string | 是 | 客户端消息唯一编号(UUID) |
| `channelId` | string | 是 | 频道ID |
| `channelType` | integer | 是 | 频道类型 (1:个人, 2:群组) |
| `payload` | object | 是 | 消息内容 (业务自定义JSON对象) |
| `header` | Header | 否 | 消息头 |
| `setting` | SettingFlags | 否 | 消息设置 |
| `msgKey` | string | 否 | 消息验证Key |
| `expire` | integer | 否 | 消息过期时间(秒), 0表示不过期 |
| `streamNo` | string | 否 | 流编号 (如果 setting.stream 为 true) |
| `topic` | string | 否 | 消息 Topic (如果 setting.topic 为 true) |
**最小示例**
```json
{
"method": "send",
"params": {
"clientMsgNo": "uuid-12345",
"channelId": "targetUser",
"channelType": 1,
"payload": {"content": "Hello!","type":1}
},
"id": "req-send-1"
}
```
#### Send Response
服务器对 `send` 请求的响应,表示消息已收到并分配了服务端 ID。
**成功结果 (`result`)**
| 字段 | 类型 | 必填 | 描述 |
| :------------ | :------ | :--- | :------------------------- |
| `messageId` | string | 是 | 服务端消息ID |
| `messageSeq` | integer | 是 | 服务端消息序列号 |
| `reasonCode` | integer | 是 | 原因码 (成功时通常为0) |
| `header` | Header | 否 | 消息头 |
**错误结果 (`error`)**: 参考 `ErrorObject`
**最小成功示例**
```json
{
"result": {
"messageId": "serverMsgId1",
"messageSeq": 1001,
"reasonCode": 0
},
"id": "req-send-1"
}
```
### 3. 收到消息 (Recv)
#### Recv Notification (`recv`)
服务器推送消息给客户端。
**参数 (`params`)**
| 字段 | 类型 | 必填 | 描述 |
| :------------ | :----------- | :--- | :------------------------------------- |
| `messageId` | string | 是 | 服务端消息ID |
| `messageSeq` | integer | 是 | 服务端消息序列号 |
| `timestamp` | integer | 是 | 服务端消息时间戳(秒) |
| `channelId` | string | 是 | 频道ID |
| `channelType` | integer | 是 | 频道类型 |
| `fromUid` | string | 是 | 发送者UID |
| `payload` | object | 是 | 消息内容 (业务自定义JSON对象) |
| `header` | Header | 否 | 消息头 |
| `setting` | SettingFlags | 否 | 消息设置 |
| `msgKey` | string | 否 | 消息验证Key |
| `expire` | integer | 否 | 消息过期时间(秒) (协议版本 >= 3) |
| `clientMsgNo` | string | 否 | 客户端消息唯一编号 (用于去重) |
| `streamNo` | string | 否 | 流编号 (协议版本 >= 2) |
| `streamId` | string | 否 | 流序列号 (协议版本 >= 2) |
| `streamFlag` | integer | 否 | 流标记 (0:Start, 1:Ing, 2:End) |
| `topic` | string | 否 | 消息 Topic (如果 setting.topic 为 true) |
**最小示例**
```json
{
"method": "recv",
"params": {
"messageId": "serverMsgId2",
"messageSeq": 50,
"timestamp": 1678886400,
"channelId": "senderUser",
"channelType": 1,
"fromUid": "senderUser",
"payload": {"content": "How are you?","type":1}
}
}
```
### 4. 收到消息确认 (RecvAck)
#### RecvAck Request (`recvack`)
客户端确认收到某条消息。
**参数 (`params`)**
| 字段 | 类型 | 必填 | 描述 |
| :----------- | :------ | :--- | :----------------------- |
| `messageId` | string | 是 | 要确认的服务端消息ID |
| `messageSeq` | integer | 是 | 要确认的服务端消息序列号 |
| `header` | Header | 否 | 消息头 |
**最小示例**
```json
{
"method": "recvack",
"params": {
"messageId": "serverMsgId2",
"messageSeq": 50
},
"id": "req-ack-1"
}
```
*(注:`recvack` 通常没有特定的响应体,如果需要响应,服务器可能会返回一个空的成功 `result` 或错误)*
### 5. 订阅频道 (Subscribe)(暂不支持)
#### Subscribe Request (`subscribe`)
客户端请求订阅指定频道的消息。
**参数 (`params`)**
| 字段 | 类型 | 必填 | 描述 |
| :------------ | :----------- | :--- | :------------------------- |
| `subNo` | string | 是 | 订阅请求编号 (客户端生成) |
| `channelId` | string | 是 | 要订阅的频道ID |
| `channelType` | integer | 是 | 频道类型 |
| `header` | Header | 否 | 消息头 |
| `setting` | SettingFlags | 否 | 消息设置 |
| `param` | string | 否 | 订阅参数 (可选) |
**最小示例**
```json
{
"method": "subscribe",
"params": {
"subNo": "sub-req-1",
"channelId": "group123",
"channelType": 2
},
"id": "req-sub-1"
}
```
### 6. 取消订阅频道 (Unsubscribe)(暂不支持)
#### Unsubscribe Request (`unsubscribe`)
客户端请求取消订阅指定频道。
**参数 (`params`)**
| 字段 | 类型 | 必填 | 描述 |
| :------------ | :----------- | :--- | :----------------------------- |
| `subNo` | string | 是 | 取消订阅请求编号 (客户端生成) |
| `channelId` | string | 是 | 要取消订阅的频道ID |
| `channelType` | integer | 是 | 频道类型 |
| `header` | Header | 否 | 消息头 |
| `setting` | SettingFlags | 否 | 消息设置 |
**最小示例**
```json
{
"method": "unsubscribe",
"params": {
"subNo": "unsub-req-1",
"channelId": "group123",
"channelType": 2
},
"id": "req-unsub-1"
}
```
#### Subscription Response
服务器对 `subscribe``unsubscribe` 请求的响应。
**成功结果 (`result`)**
| 字段 | 类型 | 必填 | 描述 |
| :------------ | :------ | :--- | :----------------------------- |
| `subNo` | string | 是 | 对应的请求编号 |
| `channelId` | string | 是 | 对应的频道ID |
| `channelType` | integer | 是 | 对应的频道类型 |
| `action` | integer | 是 | 动作 (0: Subscribe, 1: Unsubscribe) |
| `reasonCode` | integer | 是 | 原因码 (成功时通常为0) |
| `header` | Header | 否 | 消息头 |
**错误结果 (`error`)**: 参考 `ErrorObject`
**最小成功示例 (Subscribe)**
```json
{
"result": {
"subNo": "sub-req-1",
"channelId": "group123",
"channelType": 2,
"action": 0,
"reasonCode": 0
},
"id": "req-sub-1"
}
```
### 7. 心跳 (Ping/Pong)
#### Ping Request (`ping`)
客户端发送心跳以保持连接。
**参数 (`params`)**: 通常为 `null` 或空对象 `{}`
**最小示例**
```json
{
"method": "ping",
"id": "req-ping-1"
}
```
#### Pong Response
服务器对 `ping` 请求的响应。
**成功结果 (`result`)**: 通常为 `null` 或空对象 `{}`
**错误结果 (`error`)**: 参考 `ErrorObject`
**最小成功示例**
```json
{
"method": "pong"
}
```
### 8. 断开连接 (Disconnect)(暂不支持)
#### Disconnect Request (`disconnect`)
客户端主动通知服务器断开连接。
**参数 (`params`)**
| 字段 | 类型 | 必填 | 描述 |
| :----------- | :------ | :--- | :--------------------- |
| `reasonCode` | integer | 是 | 原因码 |
| `header` | Header | 否 | 消息头 |
| `reason` | string | 否 | 断开原因描述 (可选) |
**最小示例**
```json
{
"method": "disconnect",
"params": {
"reasonCode": 0
},
"id": "req-disc-1"
}
```
*(注:`disconnect` 请求通常不需要服务器响应)*
#### Disconnect Notification (`disconnect`)
服务器通知客户端连接已断开(例如,被踢出)。
**参数 (`params`)**
| 字段 | 类型 | 必填 | 描述 |
| :----------- | :------ | :--- | :--------------------- |
| `reasonCode` | integer | 是 | 原因码 |
| `header` | Header | 否 | 消息头 |
| `reason` | string | 否 | 断开原因描述 (可选) |
**最小示例**
```json
{
"method": "disconnect",
"params": {
"reasonCode": 401,
"reason": "Kicked by another device"
}
}
```

View File

@@ -7,6 +7,7 @@ import (
// Import the WuKongIMGoProto package
"strconv" // Added for MessageID parsing
"github.com/WuKongIM/WuKongIM/pkg/wkutil"
wkproto "github.com/WuKongIM/WuKongIMGoProto"
)
@@ -104,6 +105,7 @@ type SendParams struct {
Setting SettingFlags `json:"setting,omitempty"`
MsgKey string `json:"msgKey,omitempty"`
Expire uint32 `json:"expire,omitempty"`
ClientMsgNo string `json:"clientMsgNo,omitempty"`
StreamNo string `json:"streamNo,omitempty"`
ChannelID string `json:"channelId"`
ChannelType int `json:"channelType"`
@@ -142,25 +144,24 @@ type DisconnectParams struct {
// --- Specific Result Payloads ---
type ConnectResult struct {
Header Header `json:"header,omitempty"`
Header *Header `json:"header,omitempty"`
ServerVersion int `json:"serverVersion,omitempty"`
ServerKey string `json:"serverKey"`
Salt string `json:"salt"`
TimeDiff int64 `json:"timeDiff"`
ServerKey string `json:"serverKey,omitempty"`
Salt string `json:"salt,omitempty"`
TimeDiff int64 `json:"timeDiff,omitempty"`
ReasonCode ReasonCodeEnum `json:"reasonCode"`
NodeID uint64 `json:"nodeId,omitempty"`
NodeID uint64 `json:"nodeId"`
}
type SendResult struct {
Header Header `json:"header,omitempty"`
MessageID string `json:"messageId"`
MessageSeq uint32 `json:"messageSeq"`
ClientMsgNo string `json:"clientMsgNo,omitempty"`
ReasonCode ReasonCodeEnum `json:"reasonCode"`
Header *Header `json:"header,omitempty"`
MessageID string `json:"messageId"`
MessageSeq uint32 `json:"messageSeq"`
ReasonCode ReasonCodeEnum `json:"reasonCode"`
}
type SubscriptionResult struct {
Header Header `json:"header,omitempty"`
Header *Header `json:"header,omitempty"`
SubNo string `json:"subNo"`
ChannelID string `json:"channelId"`
ChannelType int `json:"channelType"`
@@ -173,8 +174,8 @@ type SubscriptionResult struct {
// --- Specific Notification Payloads (Params) ---
type RecvNotificationParams struct {
Header Header `json:"header,omitempty"`
Setting SettingFlags `json:"setting,omitempty"`
Header *Header `json:"header,omitempty"`
Setting *SettingFlags `json:"setting,omitempty"`
MsgKey string `json:"msgKey,omitempty"`
Expire uint32 `json:"expire,omitempty"`
MessageID string `json:"messageId"`
@@ -317,9 +318,15 @@ func (h Header) ToProto() *wkproto.Framer {
// ToProto converts JSON-RPC ConnectParams to wkproto.ConnectReq
func (p ConnectParams) ToProto() *wkproto.ConnectPacket {
var version uint8 = uint8(p.Version)
if p.Version == 0 {
version = wkproto.LatestVersion
}
req := &wkproto.ConnectPacket{
Framer: headerToFramer(p.Header),
Version: uint8(p.Version),
Version: version,
ClientKey: p.ClientKey,
DeviceID: p.DeviceID,
DeviceFlag: wkproto.DeviceFlag(p.DeviceFlag),
@@ -348,12 +355,16 @@ func FromProtoConnectAck(ack *wkproto.ConnackPacket) *ConnectResult {
}
// ToProto converts JSON-RPC SendParams to wkproto.SendReq
func (p SendParams) ToProto(id string) *wkproto.SendPacket {
func (p SendParams) ToProto() *wkproto.SendPacket {
payloadBytes := []byte(p.Payload)
clientMsgNo := p.ClientMsgNo
if clientMsgNo == "" {
clientMsgNo = wkutil.GenUUID()
}
req := &wkproto.SendPacket{
Framer: headerToFramer(p.Header),
Setting: p.Setting.ToProto(),
ClientMsgNo: id,
ClientMsgNo: clientMsgNo,
ChannelID: p.ChannelID,
ChannelType: uint8(p.ChannelType),
Payload: payloadBytes,
@@ -372,11 +383,10 @@ func FromProtoSendAck(ack *wkproto.SendackPacket) *SendResult {
}
messageID := strconv.FormatInt(ack.MessageID, 10)
res := &SendResult{
Header: fromProtoHeader(ack.Framer),
MessageID: messageID,
MessageSeq: ack.MessageSeq,
ClientMsgNo: ack.ClientMsgNo,
ReasonCode: ReasonCodeEnum(ack.ReasonCode),
Header: fromProtoHeader(ack.Framer),
MessageID: messageID,
MessageSeq: ack.MessageSeq,
ReasonCode: ReasonCodeEnum(ack.ReasonCode),
}
return res
}
@@ -478,8 +488,11 @@ func FromProtoPongPacket(pkt *wkproto.PongPacket) {
// --- Reverse Helper Functions (Proto -> JSON-RPC) ---
// fromProtoHeader converts wkproto.Header to JSON-RPC Header
func fromProtoHeader(protoHeader wkproto.Framer) Header {
return Header{
func fromProtoHeader(protoHeader wkproto.Framer) *Header {
if !protoHeader.NoPersist && !protoHeader.RedDot && !protoHeader.SyncOnce && !protoHeader.DUP {
return nil
}
return &Header{
NoPersist: protoHeader.NoPersist,
RedDot: protoHeader.RedDot,
SyncOnce: protoHeader.SyncOnce,
@@ -497,8 +510,13 @@ func headerToFramer(header Header) wkproto.Framer {
}
// fromProtoSetting converts wkproto.Setting to JSON-RPC SettingFlags
func fromProtoSetting(setting wkproto.Setting) SettingFlags {
flags := SettingFlags{}
func fromProtoSetting(setting wkproto.Setting) *SettingFlags {
if setting == 0 {
return nil
}
flags := &SettingFlags{}
flags.Receipt = (setting & wkproto.SettingReceiptEnabled) != 0
flags.Signal = (setting & wkproto.SettingSignal) != 0
flags.Stream = (setting & wkproto.SettingStream) != 0
@@ -552,7 +570,27 @@ func NewRequest(method string, id string, params interface{}) interface{} {
// Helper function/type for generic response decoding later
type GenericResponse struct {
BaseResponse
Result json.RawMessage
Result json.RawMessage `json:"result,omitempty"`
}
func NewGenericResponse(id string, result json.RawMessage) GenericResponse {
return GenericResponse{
BaseResponse: BaseResponse{
Jsonrpc: jsonRPCVersion,
ID: id,
},
Result: result,
}
}
func NewGenericResponseWithErr(id string, err *ErrorObject) GenericResponse {
return GenericResponse{
BaseResponse: BaseResponse{
Jsonrpc: jsonRPCVersion,
ID: id,
Error: err,
},
}
}
// Add conversions for full Request/Response types if needed, e.g.:
@@ -577,7 +615,7 @@ func (r SendRequest) ToProto() (*wkproto.SendPacket, error) {
pkt := &wkproto.SendPacket{
Framer: headerToFramer(r.Params.Header),
Setting: r.Params.Setting.ToProto(),
ClientMsgNo: r.ID,
ClientMsgNo: r.Params.ClientMsgNo,
ChannelID: r.Params.ChannelID,
ChannelType: uint8(r.Params.ChannelType),
Payload: payloadBytes,

View File

@@ -181,7 +181,6 @@
"header": { "$ref": "#/components/schemas/Header" },
"messageId": { "type": "string", "description": "服务端消息ID" },
"messageSeq": { "type": "integer", "format": "uint32", "description": "服务端消息序列号" },
"clientMsgNo": { "type": "string", "description": "客户端消息编号(原样返回,可选)" },
"reasonCode": { "$ref": "#/components/schemas/ReasonCodeEnum" }
},
"required": ["messageId", "messageSeq", "reasonCode"]
@@ -189,7 +188,7 @@
"error": { "$ref": "#/components/schemas/ErrorObject" },
"id": { "type": "string" }
},
"required": ["method"]
"required": ["id"]
},
"RecvNotification": {
"type": "object",

View File

@@ -118,6 +118,7 @@ type Conn interface {
type IWSConn interface {
WriteServerBinary(data []byte) error
WriteServerText(data []byte) error
}
type DefaultConn struct {

View File

@@ -70,6 +70,12 @@ func (w *WSConn) WriteServerBinary(data []byte) error {
return wsutil.WriteServerBinary(w.outboundBuffer, data)
}
func (w *WSConn) WriteServerText(data []byte) error {
w.mu.Lock()
defer w.mu.Unlock()
return wsutil.WriteServerText(w.outboundBuffer, data)
}
// 解包ws的数据
func (w *WSConn) unpacketWSData() error {
@@ -407,6 +413,12 @@ func (w *WSSConn) WriteServerBinary(data []byte) error {
return wsutil.WriteServerBinary(w.TLSConn, data)
}
func (w *WSSConn) WriteServerText(data []byte) error {
w.d.mu.Lock()
defer w.d.mu.Unlock()
return wsutil.WriteServerText(w.TLSConn, data)
}
func (w *WSSConn) decode() ([]wsutil.Message, error) {
buff, err := w.peekFromWSTemp(-1)
if err != nil {