feat:support response 500 custom page

#565
This commit is contained in:
samwaf
2025-11-28 10:03:37 +08:00
parent 8a8c3057c5
commit c7d82acdbc
5 changed files with 226 additions and 80 deletions

View File

@@ -8,6 +8,7 @@ import (
"SamWaf/model/request"
"SamWaf/model/spec"
"errors"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
@@ -117,9 +118,14 @@ func (w *WafBlockingPageApi) NotifyWaf(hostCode string) {
if len(blockingPageList) > 0 {
for i := 0; i < len(blockingPageList); i++ {
if blockingPageList[i].BlockingType == "not_match_website" {
// 域名不匹配使用固定的key
blockingPageMap["not_match_website"] = blockingPageList[i]
} else if blockingPageList[i].BlockingType == "other_block" {
blockingPageMap["other_block"] = blockingPageList[i]
// other_block 类型根据 response_code 区分不同的错误页面
// 例如: 403(WAF拦截), 404, 500, 502 等
if blockingPageList[i].ResponseCode != "" {
blockingPageMap[blockingPageList[i].ResponseCode] = blockingPageList[i]
}
}
}
}

View File

@@ -44,13 +44,6 @@ func (receiver *WafBlockingPageService) CheckIsExistApi(req request.WafBlockingP
//where字段
whereField = ""
if len(req.BlockingPageName) > 0 {
if len(whereField) > 0 {
whereField = whereField + " and "
}
whereField = whereField + " blocking_page_name=? "
}
if len(req.BlockingType) > 0 {
if len(whereField) > 0 {
whereField = whereField + " and "
@@ -65,14 +58,17 @@ func (receiver *WafBlockingPageService) CheckIsExistApi(req request.WafBlockingP
whereField = whereField + " host_code=? "
}
//where字段赋值
if len(req.BlockingPageName) > 0 {
// 对于 other_block 类型,需要检查 response_code 的唯一性
// 同一个网站下,相同的 blocking_type + response_code 组合必须唯一
if len(req.ResponseCode) > 0 {
if len(whereField) > 0 {
whereValues = append(whereValues, req.BlockingPageName)
whereField = whereField + " and "
}
whereField = whereField + " response_code=? "
}
//where字段赋值
if len(req.BlockingType) > 0 {
if len(whereField) > 0 {
whereValues = append(whereValues, req.BlockingType)
@@ -85,6 +81,12 @@ func (receiver *WafBlockingPageService) CheckIsExistApi(req request.WafBlockingP
}
}
if len(req.ResponseCode) > 0 {
if len(whereField) > 0 {
whereValues = append(whereValues, req.ResponseCode)
}
}
global.GWAF_LOCAL_DB.Model(&model.BlockingPage{}).Where(whereField, whereValues...).Count(&total)
return int(total)
}
@@ -99,13 +101,6 @@ func (receiver *WafBlockingPageService) ModifyApi(req request.WafBlockingPageEdi
//where字段
whereField = ""
if len(req.BlockingPageName) > 0 {
if len(whereField) > 0 {
whereField = whereField + " and "
}
whereField = whereField + " blocking_page_name=? "
}
if len(req.BlockingType) > 0 {
if len(whereField) > 0 {
whereField = whereField + " and "
@@ -120,12 +115,17 @@ func (receiver *WafBlockingPageService) ModifyApi(req request.WafBlockingPageEdi
whereField = whereField + " host_code=? "
}
//where字段赋值
if len(req.BlockingPageName) > 0 {
whereValues = append(whereValues, req.BlockingPageName)
// 对于 other_block 类型,需要检查 response_code 的唯一性
// 同一个网站下,相同的 blocking_type + response_code 组合必须唯一
if len(req.ResponseCode) > 0 {
if len(whereField) > 0 {
whereField = whereField + " and "
}
whereField = whereField + " response_code=? "
}
//where字段赋值
if len(req.BlockingType) > 0 {
whereValues = append(whereValues, req.BlockingType)
}
@@ -134,6 +134,10 @@ func (receiver *WafBlockingPageService) ModifyApi(req request.WafBlockingPageEdi
whereValues = append(whereValues, req.HostCode)
}
if len(req.ResponseCode) > 0 {
whereValues = append(whereValues, req.ResponseCode)
}
global.GWAF_LOCAL_DB.Model(&model.BlockingPage{}).Where(whereField, whereValues...).Count(&total)
// 查询是否已存在记录
var bean model.BlockingPage

View File

@@ -43,7 +43,10 @@ func EchoErrorInfo(w http.ResponseWriter, r *http.Request, weblogbean *innerbean
}
// 处理 hostsafe 的模板
if blockingPage, ok := hostsafe.BlockingPage["other_block"]; ok {
// response_code 为 403 的配置
blockingPage, ok := hostsafe.BlockingPage["403"]
if ok {
// 设置 HTTP header
var headers []map[string]string
if err := json.Unmarshal([]byte(blockingPage.ResponseHeader), &headers); err == nil {
@@ -68,38 +71,42 @@ func EchoErrorInfo(w http.ResponseWriter, r *http.Request, weblogbean *innerbean
if code, err := strconv.Atoi(blockingPage.ResponseCode); err == nil {
responseCode = code
}
} else if globalBlockingPage, ok := globalHostSafe.BlockingPage["other_block"]; ok {
// 处理 globalHostSafe 的模板
// 设置 HTTP header
var headers []map[string]string
if err := json.Unmarshal([]byte(globalBlockingPage.ResponseHeader), &headers); err == nil {
for _, header := range headers {
if name, ok := header["name"]; ok {
if value, ok := header["value"]; ok && value != "" {
w.Header().Set(name, value)
} else {
// 检查全局配置
globalBlockingPage, ok := globalHostSafe.BlockingPage["403"]
if ok {
// 处理 globalHostSafe 的模板
// 设置 HTTP header
var headers []map[string]string
if err := json.Unmarshal([]byte(globalBlockingPage.ResponseHeader), &headers); err == nil {
for _, header := range headers {
if name, ok := header["name"]; ok {
if value, ok := header["value"]; ok && value != "" {
w.Header().Set(name, value)
}
}
}
}
}
// 渲染模板
renderedBytes, err := renderTemplate(globalBlockingPage.ResponseContent, renderData)
if err == nil {
resBytes = renderedBytes
// 渲染模板
renderedBytes, err := renderTemplate(globalBlockingPage.ResponseContent, renderData)
if err == nil {
resBytes = renderedBytes
} else {
resBytes = []byte(globalBlockingPage.ResponseContent)
}
// 设置响应码
if code, err := strconv.Atoi(globalBlockingPage.ResponseCode); err == nil {
responseCode = code
}
} else {
resBytes = []byte(globalBlockingPage.ResponseContent)
}
// 设置响应码
if code, err := strconv.Atoi(globalBlockingPage.ResponseCode); err == nil {
responseCode = code
}
} else {
// 默认的阻止页面
renderedBytes, err := renderTemplate(global.GLOBAL_DEFAULT_BLOCK_INFO, renderData)
if err == nil {
resBytes = renderedBytes
} else {
resBytes = []byte(global.GLOBAL_DEFAULT_BLOCK_INFO)
// 默认的阻止页面
renderedBytes, err := renderTemplate(global.GLOBAL_DEFAULT_BLOCK_INFO, renderData)
if err == nil {
resBytes = renderedBytes
} else {
resBytes = []byte(global.GLOBAL_DEFAULT_BLOCK_INFO)
}
}
}
@@ -146,7 +153,10 @@ func EchoResponseErrorInfo(resp *http.Response, weblogbean *innerbean.WebLog, ru
}
// 处理 hostsafe 的模板
if blockingPage, ok := hostsafe.BlockingPage["other_block"]; ok {
// 优先使用 response_code 为 403 的配置,兼容旧版本的 other_block
blockingPage, ok := hostsafe.BlockingPage["403"]
if ok {
// 设置 HTTP header
var headers []map[string]string
if err := json.Unmarshal([]byte(blockingPage.ResponseHeader), &headers); err == nil {
@@ -171,38 +181,43 @@ func EchoResponseErrorInfo(resp *http.Response, weblogbean *innerbean.WebLog, ru
if code, err := strconv.Atoi(blockingPage.ResponseCode); err == nil {
responseCode = code
}
} else if globalBlockingPage, ok := globalHostSafe.BlockingPage["other_block"]; ok {
// 处理 globalHostSafe 的模板
// 设置 HTTP header
var headers []map[string]string
if err := json.Unmarshal([]byte(globalBlockingPage.ResponseHeader), &headers); err == nil {
for _, header := range headers {
if name, ok := header["name"]; ok {
if value, ok := header["value"]; ok && value != "" {
resp.Header.Set(name, value)
} else {
// 检查全局配置
globalBlockingPage, ok := globalHostSafe.BlockingPage["403"]
if ok {
// 处理 globalHostSafe 的模板
// 设置 HTTP header
var headers []map[string]string
if err := json.Unmarshal([]byte(globalBlockingPage.ResponseHeader), &headers); err == nil {
for _, header := range headers {
if name, ok := header["name"]; ok {
if value, ok := header["value"]; ok && value != "" {
resp.Header.Set(name, value)
}
}
}
}
}
// 渲染模板
renderedBytes, err := renderTemplate(globalBlockingPage.ResponseContent, renderData)
if err == nil {
resBytes = renderedBytes
// 渲染模板
renderedBytes, err := renderTemplate(globalBlockingPage.ResponseContent, renderData)
if err == nil {
resBytes = renderedBytes
} else {
resBytes = []byte(globalBlockingPage.ResponseContent)
}
// 设置响应码
if code, err := strconv.Atoi(globalBlockingPage.ResponseCode); err == nil {
responseCode = code
}
} else {
resBytes = []byte(globalBlockingPage.ResponseContent)
}
// 设置响应码
if code, err := strconv.Atoi(globalBlockingPage.ResponseCode); err == nil {
responseCode = code
}
} else {
// 默认的阻止页面
renderedBytes, err := renderTemplate(global.GLOBAL_DEFAULT_BLOCK_INFO, renderData)
if err == nil {
resBytes = renderedBytes
} else {
resBytes = []byte(global.GLOBAL_DEFAULT_BLOCK_INFO)
// 默认的阻止页面
renderedBytes, err := renderTemplate(global.GLOBAL_DEFAULT_BLOCK_INFO, renderData)
if err == nil {
resBytes = renderedBytes
} else {
resBytes = []byte(global.GLOBAL_DEFAULT_BLOCK_INFO)
}
}
}

View File

@@ -853,6 +853,11 @@ func (waf *WafEngine) modifyResponse() func(*http.Response) error {
weblogfrist := wafHttpContext.Weblog
host := waf.HostCode[wafHttpContext.HostCode]
// 记录后端真实返回的状态码
backendStatusCode := resp.StatusCode
backendStatus := resp.Status
weblogfrist.ACTION = "放行"
weblogfrist.STATUS = resp.Status
weblogfrist.STATUS_CODE = resp.StatusCode
@@ -1089,6 +1094,117 @@ func (waf *WafEngine) modifyResponse() func(*http.Response) error {
weblogfrist.HOST, weblogfrist.URL, resp.StatusCode))
}
} else {
// 检查是否需要应用自定义错误页面(非 ACME Challenge 请求)
statusCodeKey := strconv.Itoa(backendStatusCode)
var customBlockingPage *model.BlockingPage
var useCustomPage bool
// 优先检查网站级别的自定义错误页面配置
if blockingPage, ok := waf.HostTarget[host].BlockingPage[statusCodeKey]; ok {
customBlockingPage = &blockingPage
useCustomPage = true
} else if globalBlockingPage, ok := waf.HostTarget[waf.HostCode[global.GWAF_GLOBAL_HOST_CODE]].BlockingPage[statusCodeKey]; ok {
// 检查全局级别的自定义错误页面配置
customBlockingPage = &globalBlockingPage
useCustomPage = true
}
// 如果找到自定义错误页面配置,则应用
if useCustomPage && customBlockingPage != nil {
// 先读取后端原始响应内容,用于日志记录
var backendOriginalBody []byte
if resp.Body != nil && resp.Body != http.NoBody {
backendOriginalBody, _ = io.ReadAll(resp.Body)
resp.Body.Close()
}
renderData := map[string]interface{}{
"SAMWAF_REQ_UUID": weblogfrist.REQ_UUID,
"SAMWAF_BACKEND_STATUS": backendStatus,
"SAMWAF_BACKEND_CODE": backendStatusCode,
"SAMWAF_BACKEND_BODY": string(backendOriginalBody),
}
// 渲染自定义模板
renderedBytes, err := renderTemplate(customBlockingPage.ResponseContent, renderData)
var resBytes []byte
if err == nil {
resBytes = renderedBytes
} else {
resBytes = []byte(customBlockingPage.ResponseContent)
zlog.Warn(fmt.Sprintf("模板渲染失败: %v, 使用原始内容", err))
}
// 设置自定义响应码(如果配置了)
var customResponseCode int = backendStatusCode
if customBlockingPage.ResponseCode != "" {
if code, err := strconv.Atoi(customBlockingPage.ResponseCode); err == nil {
customResponseCode = code
}
}
// 清空现有的响应头并设置自定义响应头
var headers []map[string]string
if err := json.Unmarshal([]byte(customBlockingPage.ResponseHeader), &headers); err == nil {
for _, header := range headers {
if name, ok := header["name"]; ok {
if value, ok := header["value"]; ok && value != "" {
resp.Header.Set(name, value)
}
}
}
}
// 更新响应
resp.StatusCode = customResponseCode
resp.Status = http.StatusText(customResponseCode)
resp.Body = io.NopCloser(bytes.NewBuffer(resBytes))
resp.ContentLength = int64(len(resBytes))
resp.Header.Set("Content-Length", strconv.FormatInt(int64(len(resBytes)), 10))
// 记录响应Header信息
resHeader := ""
for key, values := range resp.Header {
for _, value := range values {
resHeader += key + ": " + value + "\r\n"
}
}
weblogfrist.ResHeader = resHeader
// 记录日志信息
weblogfrist.ACTION = "放行"
weblogfrist.STATUS = resp.Status
weblogfrist.STATUS_CODE = resp.StatusCode
weblogfrist.RES_BODY = string(backendOriginalBody)
weblogfrist.TASK_FLAG = 1
weblogfrist.BackendCheckCost = time.Now().UnixNano()/1e6 - backendCheckStart
// 记录日志 - 根据配置决定是否记录
if global.GWAF_RUNTIME_RECORD_LOG_TYPE == "all" {
if waf.HostTarget[host].Host.EXCLUDE_URL_LOG == "" {
global.GQEQUE_LOG_DB.Enqueue(weblogfrist)
} else {
lines := strings.Split(waf.HostTarget[host].Host.EXCLUDE_URL_LOG, "\n")
isRecordLog := true
for _, line := range lines {
if strings.HasPrefix(weblogfrist.URL, line) {
isRecordLog = false
}
}
if isRecordLog {
global.GQEQUE_LOG_DB.Enqueue(weblogfrist)
}
}
} else if global.GWAF_RUNTIME_RECORD_LOG_TYPE == "abnormal" && weblogfrist.ACTION != "放行" {
// 自定义错误页也属于"异常"情况,需要记录
global.GQEQUE_LOG_DB.Enqueue(weblogfrist)
}
// 应用自定义错误页后直接返回
return nil
}
//记录响应Header信息
resHeader := ""
for key, values := range resp.Header {

View File

@@ -206,9 +206,14 @@ func (waf *WafEngine) LoadHost(inHost model.Hosts) []innerbean.ServerRunTime {
if len(blockingPageList) > 0 {
for i := 0; i < len(blockingPageList); i++ {
if blockingPageList[i].BlockingType == "not_match_website" {
// 域名不匹配使用固定的key
blockingPageMap["not_match_website"] = blockingPageList[i]
} else if blockingPageList[i].BlockingType == "other_block" {
blockingPageMap["other_block"] = blockingPageList[i]
// other_block 类型根据 response_code 区分不同的错误页面
// 例如: 403(WAF拦截), 404, 500, 502 等
if blockingPageList[i].ResponseCode != "" {
blockingPageMap[blockingPageList[i].ResponseCode] = blockingPageList[i]
}
}
}
}