feat:init deepseek chatgpt

#145
This commit is contained in:
samwaf
2025-02-17 15:51:44 +08:00
parent 77af5d4b5c
commit f542b99182
14 changed files with 411 additions and 3 deletions

View File

@@ -34,6 +34,7 @@ type APIGroup struct {
WafHttpAuthBaseApi
WafTaskApi
WafBlockingPageApi
WafGPTApi
}
var APIGroupAPP = new(APIGroup)

226
api/waf_gpt.go Normal file
View File

@@ -0,0 +1,226 @@
package api
import (
"SamWaf/global"
"SamWaf/model"
"SamWaf/model/request"
"bufio"
"bytes"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"strings"
)
type WafGPTApi struct {
}
// 新增用于解析流式响应的结构体
type StreamResponse struct {
ID string `json:"id"`
Choices []StreamChoice `json:"choices"`
Created int64 `json:"created"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Object string `json:"object"`
Usage *TokenUsage `json:"usage,omitempty"` // 只有最后一条消息包含
}
type StreamChoice struct {
Index int `json:"index"`
Delta Delta `json:"delta"`
FinishReason *string `json:"finish_reason"` // 使用指针类型处理 null
Logprobs interface{} `json:"logprobs"`
}
type Delta struct {
Content string `json:"content"` // 内容增量
Role *string `json:"role,omitempty"` // 使用指针处理 null
}
type TokenUsage struct {
CompletionTokens int `json:"completion_tokens"`
PromptTokens int `json:"prompt_tokens"`
TotalTokens int `json:"total_tokens"`
}
// SendDeltaMessage 发送信息
func SendDeltaMessage(messageChan chan<- string, content string, role ...string) {
// 设置默认角色为 assistant
r := "assistant"
if len(role) > 0 {
r = role[0]
}
// 创建消息结构
msg := Delta{
Content: content,
Role: &r,
}
// 序列化并发送
if payload, err := json.Marshal(msg); err == nil {
messageChan <- string(payload)
}
}
func (w *WafGPTApi) ChatApi(c *gin.Context) {
var req request.WafGptSendReq
err := c.ShouldBindJSON(&req)
if err == nil {
// 创建一个取消信号通道,用于触发异常退出
stopChan := make(chan bool)
messageChan := make(chan string)
// 启动一个 goroutine发送流请求并推送时间
go func() {
defer close(stopChan)
defer close(messageChan)
// 构造基础消息数组
messages := []model.GptMessage{
{
Content: "你是一位信息安全专家,输出如下格式:\n\n风险等级: 0-100\n风险类型:某种注入,跨站等\n风险说明:对风险的阐释",
Role: "user",
},
}
// 将History内容转换为消息并追加
for _, historyItem := range req.History {
if len(historyItem) < 2 {
continue // 跳过无效条目
}
if historyItem[1] == "远程服务器未返回信息,请检查配置" {
continue
}
messages = append(messages, model.GptMessage{
Role: historyItem[0], // 角色类型system/user/assistant
Content: historyItem[1], // 对话内容
})
}
gptReq := model.GPTRequest{
Messages: messages,
Model: global.GCONFIG_RECORD_GPT_MODEL,
FrequencyPenalty: 0,
MaxTokens: 2048,
PresencePenalty: 0,
ResponseFormat: model.GptResponseFormat{Type: "text"},
Stop: nil,
Stream: true,
Temperature: 1,
TopP: 1,
}
// 序列化为JSON字符串
bodyBytes, _ := json.Marshal(gptReq)
requestBody := string(bodyBytes)
// 创建请求
req, err := http.NewRequest("POST", global.GCONFIG_RECORD_GPT_URL+"/v1/chat/completions", strings.NewReader(requestBody))
if err != nil {
stopChan <- true
return
}
// 设置请求头
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Authorization", "Bearer "+global.GCONFIG_RECORD_GPT_TOKEN)
// 创建 HTTP 客户端并发送请求
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
SendDeltaMessage(messageChan, fmt.Sprintf("访问报错%v", err.Error()), "assistant")
stopChan <- true
return
}
defer resp.Body.Close()
// 读取流
// 创建带缓冲的读取器
reader := bufio.NewReader(resp.Body)
var buffer bytes.Buffer
var residual []byte
for {
// 读取数据块
chunk := make([]byte, 1024)
n, err := reader.Read(chunk)
if err != nil && err != io.EOF {
stopChan <- true
return
}
// 合并残留数据和新数据
buffer.Write(append(residual, chunk[:n]...))
residual = nil
// 分割数据包
for {
line, err := buffer.ReadBytes('\n')
if err == io.EOF {
// 判断残留数据中是否有错误信息
lineStr := strings.TrimSpace(string(line))
if strings.Contains(lineStr, `"error":`) {
SendDeltaMessage(messageChan, fmt.Sprintf("Error: %s", lineStr), "assistant")
stopChan <- true
return
}
residual = line
break
}
// 处理单行数据
lineStr := strings.TrimSpace(string(line))
if strings.HasPrefix(lineStr, "data: ") {
content := strings.TrimPrefix(lineStr, "data: ")
// 处理流结束标记
if content == "[DONE]" {
SendDeltaMessage(messageChan, "[DONE]", "assistant")
stopChan <- true
return
}
// 解析JSON数据
var response StreamResponse
if err := json.Unmarshal([]byte(content), &response); err != nil {
continue // 忽略解析错误
}
// 处理消息内容
for _, choice := range response.Choices {
// 发送内容增量
if choice.Delta.Content != "" {
SendDeltaMessage(messageChan, choice.Delta.Content, "assistant")
}
// 处理停止条件
if choice.FinishReason != nil && *choice.FinishReason == "stop" {
stopChan <- true
return
}
}
}
}
if err == io.EOF {
break
}
}
}()
c.Stream(func(w io.Writer) bool {
// 判断是否接收到停止信号
select {
case <-stopChan:
return false // 退出流式推送
case message := <-messageChan:
c.SSEvent("message", message) // 发送事件到客户端
return true // 继续推送
}
})
}
}

View File

@@ -12,6 +12,7 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"net/url"
"os"
"path/filepath"
"regexp"
@@ -191,7 +192,12 @@ func (w *WafLogAPi) GetHttpCopyMaskApi(c *gin.Context) {
}
wafLog, _ := wafLogService.GetDetailApi(req)
response.OkWithDetailed(GenerateCurlRequest(wafLog), "获取成功", c)
if req.OutputFormat == "curl" {
response.OkWithDetailed(GenerateCurlRequest(wafLog), "获取成功", c)
} else {
response.OkWithDetailed(GenerateRawHTTPRequest(wafLog), "获取成功", c)
}
} else {
response.FailWithMessage("解析失败", c)
}
@@ -229,6 +235,86 @@ func (w *WafLogAPi) GetAllIpTagApi(c *gin.Context) {
response.OkWithDetailed(ipAttackTags, "获取成功", c)
}
}
func GenerateRawHTTPRequest(weblog innerbean.WebLog) string {
parsedURL, err := url.Parse(weblog.URL)
if err != nil {
return ""
}
// 构建请求行
pathWithQuery := parsedURL.Path
if parsedURL.RawQuery != "" {
pathWithQuery += "?" + parsedURL.RawQuery
}
// 根据协议确定 HTTP 版本
httpVersion := "HTTP/1.1"
if weblog.Scheme != "" {
httpVersion = weblog.Scheme
}
// 处理敏感头信息
maskedHeaders := maskSensitiveHeader(weblog.HEADER)
headers := strings.Split(maskedHeaders, "\n")
// 处理 Cookie
maskedCookies := maskSensitiveCookies(weblog.COOKIES)
if maskedCookies != "" {
cookieHeader := fmt.Sprintf("Cookie: %s", maskedCookies)
// 替换或添加 Cookie 头
cookieFound := false
for i, h := range headers {
if strings.HasPrefix(strings.TrimSpace(h), "Cookie:") {
headers[i] = cookieHeader
cookieFound = true
break
}
}
if !cookieFound {
headers = append(headers, cookieHeader)
}
}
// 确保 Host 头存在
host := parsedURL.Host
if host != "" {
hostExists := false
for _, h := range headers {
if strings.HasPrefix(strings.TrimSpace(strings.ToLower(h)), "host:") {
hostExists = true
break
}
}
if !hostExists {
headers = append(headers, fmt.Sprintf("Host: %s", host))
}
}
// 构建最终 header
var cleanHeaders []string
for _, h := range headers {
if trimmed := strings.TrimSpace(h); trimmed != "" {
cleanHeaders = append(cleanHeaders, trimmed)
}
}
// 构建完整请求
requestLines := []string{
fmt.Sprintf("%s %s %s",
weblog.METHOD,
pathWithQuery,
httpVersion,
),
}
requestLines = append(requestLines, cleanHeaders...)
// 添加 body如果有
if weblog.BODY != "" {
requestLines = append(requestLines, "", weblog.BODY)
}
return strings.Join(requestLines, "\n")
}
func GenerateCurlRequest(weblog innerbean.WebLog) string {
headers := strings.Split(weblog.HEADER, "\n")

View File

@@ -32,4 +32,8 @@ var (
GCONFIG_RECORD_DEBUG_ENABLE int64 = 0 //调试开关 默认关闭
GCONFIG_RECORD_DEBUG_PWD string = "" //调试密码 如果未空则不需要密码
GCONFIG_RECORD_GPT_URL string = "https://api.deepseek.com" //GPT远程地址 DeepSeek ChatGpt 以及使用one-api封装好的接口
GCONFIG_RECORD_GPT_TOKEN string = "SamWaf提示请输入密钥" //GPT远程授权密钥
GCONFIG_RECORD_GPT_MODEL string = "deepseek-chat" //GPT 模型名称
)

View File

@@ -37,6 +37,7 @@ type WebLog struct {
SrcByteBody []byte `json:"src_byte_body"` //原始body信息
SrcByteResBody []byte `json:"src_byte_res_body"` //返回body bytes信息
WebLogVersion int `json:"web_log_version"` //日志版本信息早期的是空和0后期实时增加
Scheme string `json:"scheme"` //HTTP 协议
}
// 在 GORM 的 Model 方法中定义复合索引

View File

@@ -0,0 +1,13 @@
package middleware
import "github.com/gin-gonic/gin"
func StreamMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Next()
}
}

23
model/gpt.go Normal file
View File

@@ -0,0 +1,23 @@
package model
type GptMessage struct {
Content string `json:"content"`
Role string `json:"role"`
Unimportant bool `json:"unimportant"`
}
type GptResponseFormat struct {
Type string `json:"type"`
}
type GPTRequest struct {
Messages []GptMessage `json:"messages"`
Model string `json:"model"`
FrequencyPenalty float64 `json:"frequency_penalty"`
MaxTokens int `json:"max_tokens"`
PresencePenalty float64 `json:"presence_penalty"`
ResponseFormat GptResponseFormat `json:"response_format"`
Stop interface{} `json:"stop"`
Stream bool `json:"stream"`
Temperature float64 `json:"temperature"`
TopP float64 `json:"top_p"`
}

View File

@@ -5,6 +5,7 @@ import "SamWaf/model/common/request"
type WafAttackLogDetailReq struct {
CurrrentDbName string `json:"current_db_name"`
REQ_UUID string `json:"req_uuid"`
OutputFormat string `json:"output_format"` //输出格式 raw,curl
}
type WafAttackLogDoExport struct {

View File

@@ -0,0 +1,5 @@
package request
type WafGptSendReq struct {
History [][]string `json:"history"` // 定义结构体
}

View File

@@ -32,6 +32,7 @@ type ApiGroup struct {
WafHttpAuthBaseRouter
WafTaskRouter
WafBlockingPageRouter
WafGPTRouter
}
type PublicApiGroup struct {
LoginRouter

15
router/waf_gpt.go Normal file
View File

@@ -0,0 +1,15 @@
package router
import (
"SamWaf/api"
"github.com/gin-gonic/gin"
)
type WafGPTRouter struct {
}
func (receiver *WafGPTRouter) InitGPTRouter(group *gin.RouterGroup) {
api := api.APIGroupAPP.WafGPTApi
router := group.Group("")
router.POST("/samwaf/gpt/chat", api.ChatApi)
}

View File

@@ -217,6 +217,7 @@ func (waf *WafEngine) ServeHTTP(w http.ResponseWriter, r *http.Request) {
NetSrcIp: utils.GetSourceClientIP(r.RemoteAddr),
SrcByteBody: bodyByte,
WebLogVersion: global.GWEBLOG_VERSION,
Scheme: r.Proto,
}
formValues := url.Values{}
@@ -447,6 +448,7 @@ func (waf *WafEngine) ServeHTTP(w http.ResponseWriter, r *http.Request) {
TimeSpent: 0,
NetSrcIp: utils.GetSourceClientIP(r.RemoteAddr),
WebLogVersion: global.GWEBLOG_VERSION,
Scheme: r.Proto,
}
//记录响应body

View File

@@ -64,8 +64,11 @@ func (web *WafWebManager) initRouter(r *gin.Engine) {
router.ApiGroupApp.InitWafTaskRouter(RouterGroup)
router.ApiGroupApp.InitWafBlockingPageRouter(RouterGroup)
gptRouterGroup := r.Group("")
gptRouterGroup.Use(middleware.StreamMiddleware())
router.ApiGroupApp.InitGPTRouter(gptRouterGroup)
}
//r.Use(middleware.GinGlobalExceptionMiddleWare())
if global.GWAF_RELEASE == "true" {
static.Static(r, func(handlers ...gin.HandlerFunc) {
r.NoRoute(handlers...)

View File

@@ -12,21 +12,28 @@ func setConfigIntValue(name string, value int64, change int) {
switch name {
case "record_max_req_body_length":
global.GCONFIG_RECORD_MAX_BODY_LENGTH = value
break
case "record_max_res_body_length":
global.GCONFIG_RECORD_MAX_RES_BODY_LENGTH = value
break
case "record_resp":
global.GCONFIG_RECORD_RESP = value
break
case "delete_history_log_day":
global.GDATA_DELETE_INTERVAL = value
break
case "log_db_size":
global.GDATA_SHARE_DB_SIZE = value
break
case "auto_load_ssl_file":
global.GCONFIG_RECORD_AUTO_LOAD_SSL = value
break
case "kafka_enable":
if global.GCONFIG_RECORD_KAFKA_ENABLE != value && global.GNOTIFY_KAKFA_SERVICE != nil {
global.GNOTIFY_KAKFA_SERVICE.ChangeEnable(value)
}
global.GCONFIG_RECORD_KAFKA_ENABLE = value
break
case "redirect_https_code":
global.GCONFIG_RECORD_REDIRECT_HTTPS_CODE = value
break
@@ -76,21 +83,37 @@ func setConfigStringValue(name string, value string, change int) {
switch name {
case "dns_server":
global.GWAF_RUNTIME_DNS_SERVER = value
break
case "record_log_type":
global.GWAF_RUNTIME_RECORD_LOG_TYPE = value
break
case "gwaf_center_enable":
global.GWAF_CENTER_ENABLE = value
break
case "gwaf_center_url":
global.GWAF_CENTER_URL = value
break
case "gwaf_proxy_header":
global.GCONFIG_RECORD_PROXY_HEADER = value
break
case "kafka_url":
global.GCONFIG_RECORD_KAFKA_URL = value
break
case "kafka_topic":
global.GCONFIG_RECORD_KAFKA_TOPIC = value
break
case "debug_pwd":
global.GCONFIG_RECORD_DEBUG_PWD = value
break
case "gpt_url":
global.GCONFIG_RECORD_GPT_URL = value
break
case "gpt_token":
global.GCONFIG_RECORD_GPT_TOKEN = value
break
case "gpt_model":
global.GCONFIG_RECORD_GPT_MODEL = value
break
default:
zlog.Warn("Unknown config item:", name)
}
@@ -184,4 +207,8 @@ func TaskLoadSetting(initLoad bool) {
updateConfigIntItem(initLoad, "debug", "enable_debug", global.GCONFIG_RECORD_DEBUG_ENABLE, "调试开关 默认关闭", "int", "")
updateConfigStringItem(initLoad, "debug", "debug_pwd", global.GCONFIG_RECORD_DEBUG_PWD, "调试密码 如果未空则不需要密码", "string", "")
updateConfigStringItem(initLoad, "gpt", "gpt_url", global.GCONFIG_RECORD_GPT_URL, "GPT远程地址 默认DeepSeek 符合ChatGpt或者使用one-api封装好的接口都可以", "string", "")
updateConfigStringItem(initLoad, "gpt", "gpt_token", global.GCONFIG_RECORD_GPT_TOKEN, "GPT远程授权密钥", "string", "")
updateConfigStringItem(initLoad, "gpt", "gpt_model", global.GCONFIG_RECORD_GPT_MODEL, "GPT模型名称", "string", "")
}