mirror of
https://gitee.com/samwaf/SamWaf.git
synced 2025-12-06 06:58:54 +08:00
@@ -34,6 +34,7 @@ type APIGroup struct {
|
||||
WafHttpAuthBaseApi
|
||||
WafTaskApi
|
||||
WafBlockingPageApi
|
||||
WafGPTApi
|
||||
}
|
||||
|
||||
var APIGroupAPP = new(APIGroup)
|
||||
|
||||
226
api/waf_gpt.go
Normal file
226
api/waf_gpt.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"SamWaf/global"
|
||||
"SamWaf/model"
|
||||
"SamWaf/model/request"
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type WafGPTApi struct {
|
||||
}
|
||||
|
||||
// 新增用于解析流式响应的结构体
|
||||
type StreamResponse struct {
|
||||
ID string `json:"id"`
|
||||
Choices []StreamChoice `json:"choices"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
SystemFingerprint string `json:"system_fingerprint"`
|
||||
Object string `json:"object"`
|
||||
Usage *TokenUsage `json:"usage,omitempty"` // 只有最后一条消息包含
|
||||
}
|
||||
|
||||
type StreamChoice struct {
|
||||
Index int `json:"index"`
|
||||
Delta Delta `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason"` // 使用指针类型处理 null
|
||||
Logprobs interface{} `json:"logprobs"`
|
||||
}
|
||||
|
||||
type Delta struct {
|
||||
Content string `json:"content"` // 内容增量
|
||||
Role *string `json:"role,omitempty"` // 使用指针处理 null
|
||||
}
|
||||
|
||||
type TokenUsage struct {
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// SendDeltaMessage 发送信息
|
||||
func SendDeltaMessage(messageChan chan<- string, content string, role ...string) {
|
||||
// 设置默认角色为 assistant
|
||||
r := "assistant"
|
||||
if len(role) > 0 {
|
||||
r = role[0]
|
||||
}
|
||||
|
||||
// 创建消息结构
|
||||
msg := Delta{
|
||||
Content: content,
|
||||
Role: &r,
|
||||
}
|
||||
|
||||
// 序列化并发送
|
||||
if payload, err := json.Marshal(msg); err == nil {
|
||||
messageChan <- string(payload)
|
||||
}
|
||||
}
|
||||
func (w *WafGPTApi) ChatApi(c *gin.Context) {
|
||||
|
||||
var req request.WafGptSendReq
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err == nil {
|
||||
// 创建一个取消信号通道,用于触发异常退出
|
||||
stopChan := make(chan bool)
|
||||
messageChan := make(chan string)
|
||||
|
||||
// 启动一个 goroutine,发送流请求并推送时间
|
||||
go func() {
|
||||
defer close(stopChan)
|
||||
defer close(messageChan)
|
||||
|
||||
// 构造基础消息数组
|
||||
messages := []model.GptMessage{
|
||||
{
|
||||
Content: "你是一位信息安全专家,输出如下格式:\n\n风险等级: 0-100\n风险类型:某种注入,跨站等\n风险说明:对风险的阐释",
|
||||
Role: "user",
|
||||
},
|
||||
}
|
||||
// 将History内容转换为消息并追加
|
||||
for _, historyItem := range req.History {
|
||||
if len(historyItem) < 2 {
|
||||
continue // 跳过无效条目
|
||||
}
|
||||
if historyItem[1] == "远程服务器未返回信息,请检查配置" {
|
||||
continue
|
||||
}
|
||||
messages = append(messages, model.GptMessage{
|
||||
Role: historyItem[0], // 角色类型(system/user/assistant)
|
||||
Content: historyItem[1], // 对话内容
|
||||
})
|
||||
}
|
||||
gptReq := model.GPTRequest{
|
||||
Messages: messages,
|
||||
Model: global.GCONFIG_RECORD_GPT_MODEL,
|
||||
FrequencyPenalty: 0,
|
||||
MaxTokens: 2048,
|
||||
PresencePenalty: 0,
|
||||
ResponseFormat: model.GptResponseFormat{Type: "text"},
|
||||
Stop: nil,
|
||||
Stream: true,
|
||||
Temperature: 1,
|
||||
TopP: 1,
|
||||
}
|
||||
|
||||
// 序列化为JSON字符串
|
||||
bodyBytes, _ := json.Marshal(gptReq)
|
||||
requestBody := string(bodyBytes)
|
||||
|
||||
// 创建请求
|
||||
req, err := http.NewRequest("POST", global.GCONFIG_RECORD_GPT_URL+"/v1/chat/completions", strings.NewReader(requestBody))
|
||||
if err != nil {
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
req.Header.Set("Authorization", "Bearer "+global.GCONFIG_RECORD_GPT_TOKEN)
|
||||
|
||||
// 创建 HTTP 客户端并发送请求
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
SendDeltaMessage(messageChan, fmt.Sprintf("访问报错%v", err.Error()), "assistant")
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取流
|
||||
// 创建带缓冲的读取器
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
var buffer bytes.Buffer
|
||||
var residual []byte
|
||||
|
||||
for {
|
||||
// 读取数据块
|
||||
chunk := make([]byte, 1024)
|
||||
n, err := reader.Read(chunk)
|
||||
if err != nil && err != io.EOF {
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
|
||||
// 合并残留数据和新数据
|
||||
buffer.Write(append(residual, chunk[:n]...))
|
||||
residual = nil
|
||||
|
||||
// 分割数据包
|
||||
for {
|
||||
line, err := buffer.ReadBytes('\n')
|
||||
if err == io.EOF {
|
||||
// 判断残留数据中是否有错误信息
|
||||
lineStr := strings.TrimSpace(string(line))
|
||||
if strings.Contains(lineStr, `"error":`) {
|
||||
SendDeltaMessage(messageChan, fmt.Sprintf("Error: %s", lineStr), "assistant")
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
residual = line
|
||||
break
|
||||
}
|
||||
|
||||
// 处理单行数据
|
||||
lineStr := strings.TrimSpace(string(line))
|
||||
if strings.HasPrefix(lineStr, "data: ") {
|
||||
content := strings.TrimPrefix(lineStr, "data: ")
|
||||
|
||||
// 处理流结束标记
|
||||
if content == "[DONE]" {
|
||||
SendDeltaMessage(messageChan, "[DONE]", "assistant")
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
|
||||
// 解析JSON数据
|
||||
var response StreamResponse
|
||||
if err := json.Unmarshal([]byte(content), &response); err != nil {
|
||||
continue // 忽略解析错误
|
||||
}
|
||||
|
||||
// 处理消息内容
|
||||
for _, choice := range response.Choices {
|
||||
// 发送内容增量
|
||||
if choice.Delta.Content != "" {
|
||||
SendDeltaMessage(messageChan, choice.Delta.Content, "assistant")
|
||||
}
|
||||
|
||||
// 处理停止条件
|
||||
if choice.FinishReason != nil && *choice.FinishReason == "stop" {
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
// 判断是否接收到停止信号
|
||||
select {
|
||||
case <-stopChan:
|
||||
return false // 退出流式推送
|
||||
case message := <-messageChan:
|
||||
c.SSEvent("message", message) // 发送事件到客户端
|
||||
return true // 继续推送
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
@@ -191,7 +192,12 @@ func (w *WafLogAPi) GetHttpCopyMaskApi(c *gin.Context) {
|
||||
}
|
||||
wafLog, _ := wafLogService.GetDetailApi(req)
|
||||
|
||||
response.OkWithDetailed(GenerateCurlRequest(wafLog), "获取成功", c)
|
||||
if req.OutputFormat == "curl" {
|
||||
response.OkWithDetailed(GenerateCurlRequest(wafLog), "获取成功", c)
|
||||
} else {
|
||||
response.OkWithDetailed(GenerateRawHTTPRequest(wafLog), "获取成功", c)
|
||||
}
|
||||
|
||||
} else {
|
||||
response.FailWithMessage("解析失败", c)
|
||||
}
|
||||
@@ -229,6 +235,86 @@ func (w *WafLogAPi) GetAllIpTagApi(c *gin.Context) {
|
||||
response.OkWithDetailed(ipAttackTags, "获取成功", c)
|
||||
}
|
||||
}
|
||||
func GenerateRawHTTPRequest(weblog innerbean.WebLog) string {
|
||||
parsedURL, err := url.Parse(weblog.URL)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 构建请求行
|
||||
pathWithQuery := parsedURL.Path
|
||||
if parsedURL.RawQuery != "" {
|
||||
pathWithQuery += "?" + parsedURL.RawQuery
|
||||
}
|
||||
|
||||
// 根据协议确定 HTTP 版本
|
||||
httpVersion := "HTTP/1.1"
|
||||
if weblog.Scheme != "" {
|
||||
httpVersion = weblog.Scheme
|
||||
}
|
||||
|
||||
// 处理敏感头信息
|
||||
maskedHeaders := maskSensitiveHeader(weblog.HEADER)
|
||||
headers := strings.Split(maskedHeaders, "\n")
|
||||
|
||||
// 处理 Cookie
|
||||
maskedCookies := maskSensitiveCookies(weblog.COOKIES)
|
||||
if maskedCookies != "" {
|
||||
cookieHeader := fmt.Sprintf("Cookie: %s", maskedCookies)
|
||||
// 替换或添加 Cookie 头
|
||||
cookieFound := false
|
||||
for i, h := range headers {
|
||||
if strings.HasPrefix(strings.TrimSpace(h), "Cookie:") {
|
||||
headers[i] = cookieHeader
|
||||
cookieFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !cookieFound {
|
||||
headers = append(headers, cookieHeader)
|
||||
}
|
||||
}
|
||||
|
||||
// 确保 Host 头存在
|
||||
host := parsedURL.Host
|
||||
if host != "" {
|
||||
hostExists := false
|
||||
for _, h := range headers {
|
||||
if strings.HasPrefix(strings.TrimSpace(strings.ToLower(h)), "host:") {
|
||||
hostExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hostExists {
|
||||
headers = append(headers, fmt.Sprintf("Host: %s", host))
|
||||
}
|
||||
}
|
||||
|
||||
// 构建最终 header
|
||||
var cleanHeaders []string
|
||||
for _, h := range headers {
|
||||
if trimmed := strings.TrimSpace(h); trimmed != "" {
|
||||
cleanHeaders = append(cleanHeaders, trimmed)
|
||||
}
|
||||
}
|
||||
|
||||
// 构建完整请求
|
||||
requestLines := []string{
|
||||
fmt.Sprintf("%s %s %s",
|
||||
weblog.METHOD,
|
||||
pathWithQuery,
|
||||
httpVersion,
|
||||
),
|
||||
}
|
||||
requestLines = append(requestLines, cleanHeaders...)
|
||||
|
||||
// 添加 body(如果有)
|
||||
if weblog.BODY != "" {
|
||||
requestLines = append(requestLines, "", weblog.BODY)
|
||||
}
|
||||
|
||||
return strings.Join(requestLines, "\n")
|
||||
}
|
||||
func GenerateCurlRequest(weblog innerbean.WebLog) string {
|
||||
|
||||
headers := strings.Split(weblog.HEADER, "\n")
|
||||
|
||||
@@ -32,4 +32,8 @@ var (
|
||||
|
||||
GCONFIG_RECORD_DEBUG_ENABLE int64 = 0 //调试开关 默认关闭
|
||||
GCONFIG_RECORD_DEBUG_PWD string = "" //调试密码 如果未空则不需要密码
|
||||
|
||||
GCONFIG_RECORD_GPT_URL string = "https://api.deepseek.com" //GPT远程地址 DeepSeek ChatGpt 以及使用one-api封装好的接口
|
||||
GCONFIG_RECORD_GPT_TOKEN string = "SamWaf提示请输入密钥" //GPT远程授权密钥
|
||||
GCONFIG_RECORD_GPT_MODEL string = "deepseek-chat" //GPT 模型名称
|
||||
)
|
||||
|
||||
@@ -37,6 +37,7 @@ type WebLog struct {
|
||||
SrcByteBody []byte `json:"src_byte_body"` //原始body信息
|
||||
SrcByteResBody []byte `json:"src_byte_res_body"` //返回body bytes信息
|
||||
WebLogVersion int `json:"web_log_version"` //日志版本信息早期的是空和0,后期实时增加
|
||||
Scheme string `json:"scheme"` //HTTP 协议
|
||||
}
|
||||
|
||||
// 在 GORM 的 Model 方法中定义复合索引
|
||||
|
||||
13
middleware/stream_middleware.go
Normal file
13
middleware/stream_middleware.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package middleware
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
func StreamMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
23
model/gpt.go
Normal file
23
model/gpt.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package model
|
||||
|
||||
type GptMessage struct {
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
Unimportant bool `json:"unimportant"`
|
||||
}
|
||||
|
||||
type GptResponseFormat struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
type GPTRequest struct {
|
||||
Messages []GptMessage `json:"messages"`
|
||||
Model string `json:"model"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
PresencePenalty float64 `json:"presence_penalty"`
|
||||
ResponseFormat GptResponseFormat `json:"response_format"`
|
||||
Stop interface{} `json:"stop"`
|
||||
Stream bool `json:"stream"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
TopP float64 `json:"top_p"`
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import "SamWaf/model/common/request"
|
||||
type WafAttackLogDetailReq struct {
|
||||
CurrrentDbName string `json:"current_db_name"`
|
||||
REQ_UUID string `json:"req_uuid"`
|
||||
OutputFormat string `json:"output_format"` //输出格式 raw,curl
|
||||
}
|
||||
|
||||
type WafAttackLogDoExport struct {
|
||||
|
||||
5
model/request/waf_gpt_send_req.go
Normal file
5
model/request/waf_gpt_send_req.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package request
|
||||
|
||||
type WafGptSendReq struct {
|
||||
History [][]string `json:"history"` // 定义结构体
|
||||
}
|
||||
@@ -32,6 +32,7 @@ type ApiGroup struct {
|
||||
WafHttpAuthBaseRouter
|
||||
WafTaskRouter
|
||||
WafBlockingPageRouter
|
||||
WafGPTRouter
|
||||
}
|
||||
type PublicApiGroup struct {
|
||||
LoginRouter
|
||||
|
||||
15
router/waf_gpt.go
Normal file
15
router/waf_gpt.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"SamWaf/api"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type WafGPTRouter struct {
|
||||
}
|
||||
|
||||
func (receiver *WafGPTRouter) InitGPTRouter(group *gin.RouterGroup) {
|
||||
api := api.APIGroupAPP.WafGPTApi
|
||||
router := group.Group("")
|
||||
router.POST("/samwaf/gpt/chat", api.ChatApi)
|
||||
}
|
||||
@@ -217,6 +217,7 @@ func (waf *WafEngine) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
NetSrcIp: utils.GetSourceClientIP(r.RemoteAddr),
|
||||
SrcByteBody: bodyByte,
|
||||
WebLogVersion: global.GWEBLOG_VERSION,
|
||||
Scheme: r.Proto,
|
||||
}
|
||||
|
||||
formValues := url.Values{}
|
||||
@@ -447,6 +448,7 @@ func (waf *WafEngine) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
TimeSpent: 0,
|
||||
NetSrcIp: utils.GetSourceClientIP(r.RemoteAddr),
|
||||
WebLogVersion: global.GWEBLOG_VERSION,
|
||||
Scheme: r.Proto,
|
||||
}
|
||||
|
||||
//记录响应body
|
||||
|
||||
@@ -64,8 +64,11 @@ func (web *WafWebManager) initRouter(r *gin.Engine) {
|
||||
router.ApiGroupApp.InitWafTaskRouter(RouterGroup)
|
||||
router.ApiGroupApp.InitWafBlockingPageRouter(RouterGroup)
|
||||
|
||||
gptRouterGroup := r.Group("")
|
||||
gptRouterGroup.Use(middleware.StreamMiddleware())
|
||||
router.ApiGroupApp.InitGPTRouter(gptRouterGroup)
|
||||
}
|
||||
//r.Use(middleware.GinGlobalExceptionMiddleWare())
|
||||
|
||||
if global.GWAF_RELEASE == "true" {
|
||||
static.Static(r, func(handlers ...gin.HandlerFunc) {
|
||||
r.NoRoute(handlers...)
|
||||
|
||||
@@ -12,21 +12,28 @@ func setConfigIntValue(name string, value int64, change int) {
|
||||
switch name {
|
||||
case "record_max_req_body_length":
|
||||
global.GCONFIG_RECORD_MAX_BODY_LENGTH = value
|
||||
break
|
||||
case "record_max_res_body_length":
|
||||
global.GCONFIG_RECORD_MAX_RES_BODY_LENGTH = value
|
||||
break
|
||||
case "record_resp":
|
||||
global.GCONFIG_RECORD_RESP = value
|
||||
break
|
||||
case "delete_history_log_day":
|
||||
global.GDATA_DELETE_INTERVAL = value
|
||||
break
|
||||
case "log_db_size":
|
||||
global.GDATA_SHARE_DB_SIZE = value
|
||||
break
|
||||
case "auto_load_ssl_file":
|
||||
global.GCONFIG_RECORD_AUTO_LOAD_SSL = value
|
||||
break
|
||||
case "kafka_enable":
|
||||
if global.GCONFIG_RECORD_KAFKA_ENABLE != value && global.GNOTIFY_KAKFA_SERVICE != nil {
|
||||
global.GNOTIFY_KAKFA_SERVICE.ChangeEnable(value)
|
||||
}
|
||||
global.GCONFIG_RECORD_KAFKA_ENABLE = value
|
||||
break
|
||||
case "redirect_https_code":
|
||||
global.GCONFIG_RECORD_REDIRECT_HTTPS_CODE = value
|
||||
break
|
||||
@@ -76,21 +83,37 @@ func setConfigStringValue(name string, value string, change int) {
|
||||
switch name {
|
||||
case "dns_server":
|
||||
global.GWAF_RUNTIME_DNS_SERVER = value
|
||||
break
|
||||
case "record_log_type":
|
||||
global.GWAF_RUNTIME_RECORD_LOG_TYPE = value
|
||||
break
|
||||
case "gwaf_center_enable":
|
||||
global.GWAF_CENTER_ENABLE = value
|
||||
break
|
||||
case "gwaf_center_url":
|
||||
global.GWAF_CENTER_URL = value
|
||||
break
|
||||
case "gwaf_proxy_header":
|
||||
global.GCONFIG_RECORD_PROXY_HEADER = value
|
||||
break
|
||||
case "kafka_url":
|
||||
global.GCONFIG_RECORD_KAFKA_URL = value
|
||||
break
|
||||
case "kafka_topic":
|
||||
global.GCONFIG_RECORD_KAFKA_TOPIC = value
|
||||
break
|
||||
case "debug_pwd":
|
||||
global.GCONFIG_RECORD_DEBUG_PWD = value
|
||||
|
||||
break
|
||||
case "gpt_url":
|
||||
global.GCONFIG_RECORD_GPT_URL = value
|
||||
break
|
||||
case "gpt_token":
|
||||
global.GCONFIG_RECORD_GPT_TOKEN = value
|
||||
break
|
||||
case "gpt_model":
|
||||
global.GCONFIG_RECORD_GPT_MODEL = value
|
||||
break
|
||||
default:
|
||||
zlog.Warn("Unknown config item:", name)
|
||||
}
|
||||
@@ -184,4 +207,8 @@ func TaskLoadSetting(initLoad bool) {
|
||||
updateConfigIntItem(initLoad, "debug", "enable_debug", global.GCONFIG_RECORD_DEBUG_ENABLE, "调试开关 默认关闭", "int", "")
|
||||
updateConfigStringItem(initLoad, "debug", "debug_pwd", global.GCONFIG_RECORD_DEBUG_PWD, "调试密码 如果未空则不需要密码", "string", "")
|
||||
|
||||
updateConfigStringItem(initLoad, "gpt", "gpt_url", global.GCONFIG_RECORD_GPT_URL, "GPT远程地址 默认:DeepSeek ,符合ChatGpt或者使用one-api封装好的接口都可以", "string", "")
|
||||
updateConfigStringItem(initLoad, "gpt", "gpt_token", global.GCONFIG_RECORD_GPT_TOKEN, "GPT远程授权密钥", "string", "")
|
||||
updateConfigStringItem(initLoad, "gpt", "gpt_model", global.GCONFIG_RECORD_GPT_MODEL, "GPT模型名称", "string", "")
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user