Files
rulego/test/test_rule_context.go
2025-08-23 00:13:40 +08:00

498 lines
14 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/*
* Copyright 2023 The RuleGo Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package test
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/rulego/rulego/api/types"
"github.com/rulego/rulego/utils/cache"
)
var _ types.RuleContext = (*NodeTestRuleContext)(nil)
// NodeTestRuleContext
// 只为测试单节点,临时创建的上下文
// 无法把多个节点组成链式
// callback 回调处理结果
type NodeTestRuleContext struct {
context context.Context
config types.Config
callback func(msg types.RuleMsg, relationType string, err error)
self types.Node
selfId string
//所有子节点处理完成事件,只执行一次
onAllNodeCompleted func()
onEndFunc types.OnEndFunc
childrenNodes sync.Map
out types.RuleMsg
globalCache types.Cache
chainCache types.Cache
mutex sync.RWMutex // Add mutex for thread safety
}
func (ctx *NodeTestRuleContext) GlobalCache() types.Cache {
return ctx.globalCache
}
func (ctx *NodeTestRuleContext) ChainCache() types.Cache {
return ctx.chainCache
}
func NewRuleContext(config types.Config, callback func(msg types.RuleMsg, relationType string, err error)) types.RuleContext {
globalCache := cache.NewMemoryCache(time.Minute * 5)
return &NodeTestRuleContext{
context: context.TODO(),
config: config,
callback: callback,
globalCache: globalCache,
chainCache: cache.NewNamespaceCache(globalCache, "test"),
}
}
func NewRuleContextFull(config types.Config, self types.Node, childrenNodes map[string]types.Node, callback func(msg types.RuleMsg, relationType string, err error)) types.RuleContext {
ctx := &NodeTestRuleContext{
config: config,
self: self,
callback: callback,
context: context.TODO(),
globalCache: config.Cache,
chainCache: cache.NewNamespaceCache(config.Cache, "test"),
}
for k, v := range childrenNodes {
ctx.childrenNodes.Store(k, v)
}
return ctx
}
func (ctx *NodeTestRuleContext) TellSuccess(msg types.RuleMsg) {
ctx.mutex.RLock()
callback := ctx.callback
onEndFunc := ctx.onEndFunc
ctx.mutex.RUnlock()
if callback != nil {
callback(msg, types.Success, nil)
}
if onEndFunc != nil {
onEndFunc(ctx, msg, nil, types.Success)
}
}
func (ctx *NodeTestRuleContext) TellFailure(msg types.RuleMsg, err error) {
ctx.mutex.RLock()
callback := ctx.callback
onEndFunc := ctx.onEndFunc
ctx.mutex.RUnlock()
if callback != nil {
callback(msg, types.Failure, err)
}
if onEndFunc != nil {
onEndFunc(ctx, msg, err, types.Failure)
}
}
func (ctx *NodeTestRuleContext) TellNext(msg types.RuleMsg, relationTypes ...string) {
ctx.mutex.RLock()
callback := ctx.callback
onEndFunc := ctx.onEndFunc
ctx.mutex.RUnlock()
for _, relationType := range relationTypes {
if callback != nil {
callback(msg, relationType, nil)
}
if onEndFunc != nil {
onEndFunc(ctx, msg, nil, relationType)
}
}
}
func (ctx *NodeTestRuleContext) TellSelf(msg types.RuleMsg, delayMs int64) {
time.AfterFunc(time.Millisecond*time.Duration(delayMs), func() {
if ctx.self != nil {
ctx.self.OnMsg(ctx, msg)
}
})
}
func (ctx *NodeTestRuleContext) TellNextOrElse(msg types.RuleMsg, defaultRelationType string, relationTypes ...string) {
ctx.TellNext(msg, relationTypes...)
}
func (ctx *NodeTestRuleContext) NewMsg(msgType string, metaData *types.Metadata, data string) types.RuleMsg {
return types.NewMsg(0, msgType, types.JSON, metaData, data)
}
func (ctx *NodeTestRuleContext) GetSelfId() string {
ctx.mutex.RLock()
defer ctx.mutex.RUnlock()
return ctx.selfId
}
func (ctx *NodeTestRuleContext) Self() types.NodeCtx {
return nil
}
func (ctx *NodeTestRuleContext) From() types.NodeCtx {
return nil
}
func (ctx *NodeTestRuleContext) RuleChain() types.NodeCtx {
return nil
}
func (ctx *NodeTestRuleContext) Config() types.Config {
return ctx.config
}
func (ctx *NodeTestRuleContext) SubmitTack(task func()) {
ctx.SubmitTask(task)
}
func (ctx *NodeTestRuleContext) SubmitTask(task func()) {
go task()
}
func (ctx *NodeTestRuleContext) SetEndFunc(onEndFunc types.OnEndFunc) types.RuleContext {
ctx.mutex.Lock()
defer ctx.mutex.Unlock()
ctx.onEndFunc = onEndFunc
return ctx
}
func (ctx *NodeTestRuleContext) GetEndFunc() types.OnEndFunc {
ctx.mutex.RLock()
defer ctx.mutex.RUnlock()
return ctx.onEndFunc
}
func (ctx *NodeTestRuleContext) SetContext(c context.Context) types.RuleContext {
ctx.context = c
return ctx
}
func (ctx *NodeTestRuleContext) GetContext() context.Context {
return ctx.context
}
func (ctx *NodeTestRuleContext) TellFlow(chainId string, msg types.RuleMsg, opts ...types.RuleContextOption) {
for _, opt := range opts {
opt(ctx)
}
if chainId == "" {
if ctx.onEndFunc != nil {
ctx.onEndFunc(ctx, msg, errors.New("chainId can not nil"), types.Failure)
}
} else if chainId == "notfound" {
if ctx.onEndFunc != nil {
ctx.onEndFunc(ctx, msg, fmt.Errorf("ruleChain id=%s not found", chainId), types.Failure)
}
if ctx.onAllNodeCompleted != nil {
ctx.onAllNodeCompleted()
}
} else if chainId == "toTrue" {
if ctx.onEndFunc != nil {
ctx.onEndFunc(ctx, msg, nil, types.True)
}
if ctx.onAllNodeCompleted != nil {
ctx.onAllNodeCompleted()
}
} else {
if ctx.onEndFunc != nil {
ctx.onEndFunc(ctx, msg, nil, types.Success)
}
if ctx.onAllNodeCompleted != nil {
ctx.onAllNodeCompleted()
}
}
}
// TellNode 独立执行某个节点通过callback获取节点执行情况用于节点分组类节点控制执行某个节点
func (ctx *NodeTestRuleContext) TellNode(context context.Context, nodeId string, msg types.RuleMsg, skipTellNext bool, callback types.OnEndFunc, onAllNodeCompleted func()) {
if v, ok := ctx.childrenNodes.Load(nodeId); ok {
// 线程安全地设置 selfId
ctx.mutex.Lock()
ctx.selfId = nodeId
ctx.mutex.Unlock()
subCtx := NewRuleContext(ctx.config, func(msg types.RuleMsg, relationType string, err error) {
if callback != nil {
callback(ctx, msg, err, relationType)
}
if onAllNodeCompleted != nil {
onAllNodeCompleted()
}
})
v.(types.Node).OnMsg(subCtx, msg)
} else {
if callback != nil {
callback(ctx, msg, fmt.Errorf("node id=%s not found", nodeId), types.Failure)
}
if onAllNodeCompleted != nil {
onAllNodeCompleted()
}
}
}
// TellChainNode 独立执行某个节点通过callback获取节点执行情况用于节点分组类节点控制执行某个节点
func (ctx *NodeTestRuleContext) TellChainNode(context context.Context, chainId string, nodeId string, msg types.RuleMsg, skipTellNext bool, callback types.OnEndFunc, onAllNodeCompleted func()) {
ctx.TellNode(context, nodeId, msg, skipTellNext, callback, onAllNodeCompleted)
}
// SetOnAllNodeCompleted 设置所有节点执行完回调
func (ctx *NodeTestRuleContext) SetOnAllNodeCompleted(onAllNodeCompleted func()) {
ctx.onAllNodeCompleted = onAllNodeCompleted
}
func (ctx *NodeTestRuleContext) DoOnEnd(msg types.RuleMsg, err error, relationType string) {
}
// SetCallbackFunc 设置回调函数
func (ctx *NodeTestRuleContext) SetCallbackFunc(functionName string, f interface{}) {
}
// GetCallbackFunc 获取回调函数
func (ctx *NodeTestRuleContext) GetCallbackFunc(functionName string) interface{} {
return nil
}
// OnDebug 调用配置的OnDebug回调函数
func (ctx *NodeTestRuleContext) OnDebug(ruleChainId string, flowType string, nodeId string, msg types.RuleMsg, relationType string, err error) {
}
func (ctx *NodeTestRuleContext) SetExecuteNode(nodeId string, relationTypes ...string) {
}
func (ctx *NodeTestRuleContext) TellCollect(msg types.RuleMsg, callback func(msgList []types.WrapperMsg)) bool {
callback(nil)
return true
}
func (ctx *NodeTestRuleContext) GetOut() types.RuleMsg {
ctx.mutex.RLock()
defer ctx.mutex.RUnlock()
return ctx.out
}
func (ctx *NodeTestRuleContext) GetRelationTypes() []string {
return nil
}
// setOut safely sets the out field
func (ctx *NodeTestRuleContext) setOut(msg types.RuleMsg) {
ctx.mutex.Lock()
defer ctx.mutex.Unlock()
ctx.out = msg
}
func (ctx *NodeTestRuleContext) GetErr() error {
return nil
}
func (ctx *NodeTestRuleContext) TellStream(msg types.RuleMsg) {
ctx.TellNext(msg, types.Stream)
}
// GetEnv 获取环境变量和元数据
func (ctx *NodeTestRuleContext) GetEnv(msg types.RuleMsg, useMetadata bool) map[string]interface{} {
// 创建环境变量map
envVars := make(map[string]interface{})
// 设置基础环境变量
envVars["id"] = msg.GetId()
envVars["ts"] = msg.GetTs()
envVars["data"] = msg.GetData()
envVars["msgType"] = msg.GetType()
envVars["type"] = msg.GetType()
envVars["dataType"] = string(msg.GetDataType())
// 使用 GetJsonData() 避免重复JSON解析
if msg.DataType == types.JSON {
if jsonData, err := msg.GetJsonData(); err == nil {
envVars[types.MsgKey] = jsonData
} else {
// 解析失败,使用原始数据
envVars[types.MsgKey] = msg.GetData()
}
} else {
// 如果不是 JSON 类型,直接使用原始数据
envVars[types.MsgKey] = msg.GetData()
}
// 优化 metadata 处理
if msg.Metadata != nil {
if useMetadata {
// 遍历metadata将键值对添加到环境变量中 - use zero-copy ForEach
msg.Metadata.ForEach(func(k, v string) bool {
envVars[k] = v
return true // continue iteration
})
}
envVars[types.MetadataKey] = msg.Metadata.Values()
}
return envVars
}
// GetNodeRuleMsg 获取节点的完整消息信息(测试上下文中暂不支持跨节点取值)
// GetNodeRuleMsg retrieves the complete RuleMsg of a node (not supported in test context)
func (ctx *NodeTestRuleContext) GetNodeRuleMsg(nodeId string) (types.RuleMsg, bool) {
return types.RuleMsg{}, false
}
// ExtendedTestRuleContext 扩展的测试上下文,支持结果收集和节点处理器设置
// 可以替代 SimpleTestContext 和 MockRuleContext
type ExtendedTestRuleContext struct {
*NodeTestRuleContext
nodeHandlers map[string]func(msg types.RuleMsg) (string, error)
results []string
resultsChan chan TestResult
handlerMutex sync.RWMutex
}
// TestResult 测试结果结构
type TestResult struct {
RelationType string
Err error
}
// NewExtendedTestRuleContext 创建扩展的测试上下文
// 用于替代 SimpleTestContext 和 MockRuleContext
func NewExtendedTestRuleContext(config types.Config, callback func(msg types.RuleMsg, relationType string, err error)) *ExtendedTestRuleContext {
baseCtx := NewRuleContext(config, callback).(*NodeTestRuleContext)
return &ExtendedTestRuleContext{
NodeTestRuleContext: baseCtx,
nodeHandlers: make(map[string]func(msg types.RuleMsg) (string, error)),
results: make([]string, 0),
resultsChan: make(chan TestResult, 10),
}
}
// NewExtendedTestRuleContextWithChannel 创建带结果通道的扩展测试上下文
// 主要用于替代 SimpleTestContext
func NewExtendedTestRuleContextWithChannel() *ExtendedTestRuleContext {
config := types.NewConfig()
baseCtx := NewRuleContext(config, nil).(*NodeTestRuleContext)
return &ExtendedTestRuleContext{
NodeTestRuleContext: baseCtx,
nodeHandlers: make(map[string]func(msg types.RuleMsg) (string, error)),
results: make([]string, 0),
resultsChan: make(chan TestResult, 10),
}
}
// SetNodeHandler 设置节点处理器,用于模拟节点行为
// 替代 MockRuleContext 的 SetNodeHandler 方法
func (ctx *ExtendedTestRuleContext) SetNodeHandler(nodeId string, handler func(msg types.RuleMsg) (string, error)) {
ctx.handlerMutex.Lock()
defer ctx.handlerMutex.Unlock()
ctx.nodeHandlers[nodeId] = handler
}
// GetResults 获取收集的结果
// 替代 MockRuleContext 的 GetResults 方法
func (ctx *ExtendedTestRuleContext) GetResults() []string {
ctx.mutex.RLock()
defer ctx.mutex.RUnlock()
results := make([]string, len(ctx.results))
copy(results, ctx.results)
return results
}
// GetResultsChannel 获取结果通道
// 用于替代 SimpleTestContext 的 results 通道
func (ctx *ExtendedTestRuleContext) GetResultsChannel() <-chan TestResult {
return ctx.resultsChan
}
// TellNode 重写 TellNode 方法以支持节点处理器
func (ctx *ExtendedTestRuleContext) TellNode(context context.Context, nodeId string, msg types.RuleMsg, skipTellNext bool, callback types.OnEndFunc, onAllNodeCompleted func()) {
ctx.handlerMutex.RLock()
handler, hasHandler := ctx.nodeHandlers[nodeId]
ctx.handlerMutex.RUnlock()
if hasHandler {
// 使用自定义处理器(模拟节点行为)
go func() {
relationType, err := handler(msg)
if callback != nil {
callback(ctx, msg, err, relationType)
}
if onAllNodeCompleted != nil {
onAllNodeCompleted()
}
}()
} else {
// 使用原有的 TellNode 逻辑
ctx.NodeTestRuleContext.TellNode(context, nodeId, msg, skipTellNext, callback, onAllNodeCompleted)
}
}
// TellNext 重写以支持结果收集
func (ctx *ExtendedTestRuleContext) TellNext(msg types.RuleMsg, relationTypes ...string) {
// 调用原有逻辑
ctx.NodeTestRuleContext.TellNext(msg, relationTypes...)
// 收集结果
if len(relationTypes) > 0 {
ctx.mutex.Lock()
ctx.results = append(ctx.results, relationTypes[0])
ctx.mutex.Unlock()
// 发送到结果通道
select {
case ctx.resultsChan <- TestResult{RelationType: relationTypes[0], Err: nil}:
default:
}
}
}
// TellSuccess 重写以支持结果收集
func (ctx *ExtendedTestRuleContext) TellSuccess(msg types.RuleMsg) {
// 调用原有逻辑
ctx.NodeTestRuleContext.TellSuccess(msg)
// 收集结果
ctx.mutex.Lock()
ctx.results = append(ctx.results, "Success")
ctx.mutex.Unlock()
// 发送到结果通道
select {
case ctx.resultsChan <- TestResult{RelationType: "Success", Err: nil}:
default:
}
}
// TellFailure 重写以支持结果收集
func (ctx *ExtendedTestRuleContext) TellFailure(msg types.RuleMsg, err error) {
// 调用原有逻辑
ctx.NodeTestRuleContext.TellFailure(msg, err)
// 收集结果
ctx.mutex.Lock()
ctx.results = append(ctx.results, "Failure")
ctx.mutex.Unlock()
// 发送到结果通道
select {
case ctx.resultsChan <- TestResult{RelationType: "Failure", Err: err}:
default:
}
}