feat:add system firewall

#573
This commit is contained in:
samwaf
2025-12-01 16:32:42 +08:00
parent 635a94b233
commit 0b1b7acae5
22 changed files with 2157 additions and 388 deletions

View File

@@ -50,6 +50,7 @@ type APIGroup struct {
WafNotifyChannelApi
WafNotifySubscriptionApi
WafNotifyLogApi
WafFirewallIPBlockApi
}
var APIGroupAPP = new(APIGroup)
@@ -112,4 +113,6 @@ var (
wafCaServerInfoService = waf_service.WafCaServerInfoServiceApp
wafSqlQueryService = waf_service.WafSqlQueryServiceApp
wafFirewallIPBlockService = waf_service.WafFirewallIPBlockServiceApp
)

View File

@@ -256,9 +256,6 @@ func (w *WafFileApi) deleteLogDatabase(filePath string) error {
return fmt.Errorf("重新初始化数据库失败: %v", err)
}
// 5. 重新创建索引
global.GWAF_CHAN_CREATE_LOG_INDEX <- "1"
return nil
}

View File

@@ -0,0 +1,213 @@
package api
import (
"SamWaf/model/common/response"
"SamWaf/model/request"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type WafFirewallIPBlockApi struct {
}
// AddApi 添加防火墙IP封禁
func (w *WafFirewallIPBlockApi) AddApi(c *gin.Context) {
var req request.WafFirewallIPBlockAddReq
err := c.ShouldBindJSON(&req)
if err != nil {
response.FailWithMessage("解析失败: "+err.Error(), c)
return
}
err = wafFirewallIPBlockService.AddApi(req)
if err != nil {
response.FailWithMessage("添加失败: "+err.Error(), c)
return
}
response.OkWithMessage("添加成功IP已在系统防火墙层面封禁", c)
}
// GetDetailApi 获取防火墙IP封禁详情
func (w *WafFirewallIPBlockApi) GetDetailApi(c *gin.Context) {
var req request.WafFirewallIPBlockDetailReq
err := c.ShouldBind(&req)
if err != nil {
response.FailWithMessage("解析失败", c)
return
}
bean := wafFirewallIPBlockService.GetDetailApi(req)
response.OkWithDetailed(bean, "获取成功", c)
}
// GetListApi 获取防火墙IP封禁列表
func (w *WafFirewallIPBlockApi) GetListApi(c *gin.Context) {
var req request.WafFirewallIPBlockSearchReq
err := c.ShouldBindJSON(&req)
if err != nil {
response.FailWithMessage("解析失败", c)
return
}
list, total, err := wafFirewallIPBlockService.GetListApi(req)
if err != nil {
response.FailWithMessage("获取失败: "+err.Error(), c)
return
}
response.OkWithDetailed(response.PageResult{
List: list,
Total: total,
PageIndex: req.PageIndex,
PageSize: req.PageSize,
}, "获取成功", c)
}
// DelApi 删除防火墙IP封禁
func (w *WafFirewallIPBlockApi) DelApi(c *gin.Context) {
var req request.WafFirewallIPBlockDelReq
err := c.ShouldBind(&req)
if err != nil {
response.FailWithMessage("解析失败", c)
return
}
err = wafFirewallIPBlockService.DelApi(req)
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
response.FailWithMessage("请检测参数", c)
} else if err != nil {
response.FailWithMessage("删除失败: "+err.Error(), c)
} else {
response.OkWithMessage("删除成功IP已从系统防火墙解除封禁", c)
}
}
// ModifyApi 修改防火墙IP封禁
func (w *WafFirewallIPBlockApi) ModifyApi(c *gin.Context) {
var req request.WafFirewallIPBlockEditReq
err := c.ShouldBindJSON(&req)
if err != nil {
response.FailWithMessage("解析失败", c)
return
}
err = wafFirewallIPBlockService.ModifyApi(req)
if err != nil {
response.FailWithMessage("编辑失败: "+err.Error(), c)
} else {
response.OkWithMessage("编辑成功", c)
}
}
// BatchDelApi 批量删除防火墙IP封禁
func (w *WafFirewallIPBlockApi) BatchDelApi(c *gin.Context) {
var req request.WafFirewallIPBlockBatchDelReq
err := c.ShouldBindJSON(&req)
if err != nil {
response.FailWithMessage("解析失败", c)
return
}
err = wafFirewallIPBlockService.BatchDelApi(req)
if err != nil {
response.FailWithMessage("批量删除失败: "+err.Error(), c)
} else {
response.OkWithMessage(fmt.Sprintf("成功删除 %d 条记录", len(req.Ids)), c)
}
}
// BatchAddApi 批量添加防火墙IP封禁
func (w *WafFirewallIPBlockApi) BatchAddApi(c *gin.Context) {
var req request.WafFirewallIPBlockBatchAddReq
err := c.ShouldBindJSON(&req)
if err != nil {
response.FailWithMessage("解析失败", c)
return
}
successCount, failedIPs, err := wafFirewallIPBlockService.BatchAddApi(req)
if err != nil {
msg := fmt.Sprintf("批量添加完成,成功 %d 个,失败 %d 个", successCount, len(failedIPs))
if len(failedIPs) > 0 {
msg += fmt.Sprintf("失败的IP: %v", failedIPs)
}
response.FailWithMessage(msg, c)
} else {
response.OkWithMessage(fmt.Sprintf("批量添加成功,共封禁 %d 个IP", successCount), c)
}
}
// EnableApi 启用防火墙IP封禁
func (w *WafFirewallIPBlockApi) EnableApi(c *gin.Context) {
var req request.WafFirewallIPBlockEnableReq
err := c.ShouldBindJSON(&req)
if err != nil {
response.FailWithMessage("解析失败", c)
return
}
err = wafFirewallIPBlockService.EnableApi(req)
if err != nil {
response.FailWithMessage("启用失败: "+err.Error(), c)
} else {
response.OkWithMessage("启用成功IP已在系统防火墙层面封禁", c)
}
}
// DisableApi 禁用防火墙IP封禁
func (w *WafFirewallIPBlockApi) DisableApi(c *gin.Context) {
var req request.WafFirewallIPBlockDisableReq
err := c.ShouldBindJSON(&req)
if err != nil {
response.FailWithMessage("解析失败", c)
return
}
err = wafFirewallIPBlockService.DisableApi(req)
if err != nil {
response.FailWithMessage("禁用失败: "+err.Error(), c)
} else {
response.OkWithMessage("禁用成功IP已从系统防火墙解除封禁", c)
}
}
// SyncApi 同步防火墙规则(从数据库恢复到系统防火墙)
func (w *WafFirewallIPBlockApi) SyncApi(c *gin.Context) {
var req request.WafFirewallIPBlockSyncReq
err := c.ShouldBindJSON(&req)
if err != nil {
response.FailWithMessage("解析失败", c)
return
}
successCount, failedCount, err := wafFirewallIPBlockService.SyncFirewallRules(req.HostCode)
if err != nil {
response.FailWithMessage("同步失败: "+err.Error(), c)
} else {
msg := fmt.Sprintf("同步完成,成功 %d 个", successCount)
if failedCount > 0 {
msg += fmt.Sprintf(",失败 %d 个", failedCount)
}
response.OkWithMessage(msg, c)
}
}
// ClearExpiredApi 清理过期的封禁规则
func (w *WafFirewallIPBlockApi) ClearExpiredApi(c *gin.Context) {
count, err := wafFirewallIPBlockService.ClearExpiredRules()
if err != nil {
response.FailWithMessage("清理失败: "+err.Error(), c)
} else {
response.OkWithMessage(fmt.Sprintf("成功清理 %d 条过期规则", count), c)
}
}
// GetStatisticsApi 获取统计信息
func (w *WafFirewallIPBlockApi) GetStatisticsApi(c *gin.Context) {
stats := wafFirewallIPBlockService.GetStatistics()
response.OkWithDetailed(stats, "获取成功", c)
}

View File

@@ -17,9 +17,9 @@ const (
TASK_NOTICE = "task_notice" //通知信息
TASK_HEALTH = "task_health" //健康检测
TASK_CLEAR_CC_WINDOWS = "task_clear_cc_windows" //清除ccWindows记录
TASK_CREATE_DB_INDEX = "task_create_db_index" //创建索引
TASK_CLEAR_WEBCACHE = "task_clear_webcache" //清除缓存
TASK_GC = "task_gc" //GC回收
TASK_STATS_PUSH = "task_stats_push" //系统统计数据推送
TASK_DB_MONITOR = "task_db_monitor" //数据库监控
TASK_FIREWALL_CLEAN_EXPIRED = "task_firewall_clean_expired" //清理过期防火墙IP封禁规则
)

View File

@@ -5,10 +5,12 @@ package firewall
import (
"bufio"
"fmt"
"golang.org/x/text/encoding/simplifiedchinese"
"os/exec"
"runtime"
"strings"
"time"
"golang.org/x/text/encoding/simplifiedchinese"
)
type Charset string
@@ -22,9 +24,29 @@ const ACTION_ALLOW string = "allow" //allow 表示允许连接block 表示
const ACTION_BLOCK string = "block"
const ACTION_BYPASS string = "bypass"
const (
RULE_PREFIX = "SamWAF_Block_" // 规则名称前缀
PROTOCOL_ANY = "any" // 任意协议
PROTOCOL_TCP = "TCP" // TCP 协议
PROTOCOL_UDP = "UDP" // UDP 协议
DIRECTION_IN = "in" // 入站
DIRECTION_OUT = "out" // 出站
DIRECTION_BOTH = "both" // 双向
)
type FireWallEngine struct {
}
// IPBlockInfo IP封禁信息结构
type IPBlockInfo struct {
IP string // IP地址
RuleName string // 规则名称
Reason string // 封禁原因(预留字段,实际存储在数据库)
BlockTime time.Time // 封禁时间(预留字段,实际存储在数据库)
Protocol string // 协议类型
Direction string // 方向
}
func (fw *FireWallEngine) IsFirewallEnabled() bool {
if runtime.GOOS == "linux" {
out, err := exec.Command("iptables", "-L").CombinedOutput()
@@ -55,8 +77,19 @@ func (fw *FireWallEngine) executeCommand(cmd *exec.Cmd) (error error, printstr s
}
func (fw *FireWallEngine) AddRule(ruleName, ipToAdd, action, proc, localport string) error {
cmd := exec.Command("iptables", "-A", "INPUT", ipToAdd)
err, _ := fw.executeCommand(cmd)
// iptables -A INPUT -s <ip> -j DROP
fmt.Printf("[DEBUG] 添加防火墙规则: ip=%s\n", ipToAdd)
cmd := exec.Command("iptables", "-A", "INPUT", "-s", ipToAdd, "-j", "DROP")
fmt.Printf("[DEBUG] 执行命令: iptables -A INPUT -s %s -j DROP\n", ipToAdd)
err, output := fw.executeCommand(cmd)
if err != nil {
fmt.Printf("[ERROR] 添加规则失败: %v, 输出: %s\n", err, output)
return err
}
fmt.Printf("[DEBUG] 添加规则成功, 输出: %s\n", output)
return err
}
@@ -65,18 +98,56 @@ func (fw *FireWallEngine) EditRule(ruleNum int, newRule string) error {
}
func (fw *FireWallEngine) DeleteRule(ruleName string) (bool, error) {
var cmd *exec.Cmd
cmd = exec.Command("iptables", "-D", "INPUT", fmt.Sprintf("%s", ruleName))
err, _ := fw.executeCommand(cmd)
return false, err
// iptables -D INPUT -s <ip> -j DROP
// 从规则名中提取IP
fmt.Printf("[DEBUG] 删除防火墙规则: name=%s\n", ruleName)
ip := extractIPFromRuleName(ruleName)
if ip == "" {
fmt.Printf("[ERROR] 无效的规则名: %s\n", ruleName)
return false, fmt.Errorf("invalid rule name: %s", ruleName)
}
cmd := exec.Command("iptables", "-D", "INPUT", "-s", ip, "-j", "DROP")
fmt.Printf("[DEBUG] 执行命令: iptables -D INPUT -s %s -j DROP\n", ip)
err, output := fw.executeCommand(cmd)
if err != nil {
fmt.Printf("[ERROR] 删除规则失败: %v, 输出: %s\n", err, output)
return false, fmt.Errorf("failed to delete rule: %s, output: %s", err, output)
}
fmt.Printf("[DEBUG] 删除规则成功\n")
return true, nil
}
func (fw *FireWallEngine) IsRuleExists(ruleName string) (bool, error) {
// 从规则名中提取IP
fmt.Printf("[DEBUG] 检查规则是否存在: name=%s\n", ruleName)
ip := extractIPFromRuleName(ruleName)
if ip == "" {
fmt.Printf("[ERROR] 无效的规则名: %s\n", ruleName)
return false, fmt.Errorf("invalid rule name: %s", ruleName)
}
cmd := exec.Command("iptables-save")
output, err := cmd.CombinedOutput()
if err != nil {
fmt.Printf("[ERROR] 获取iptables规则失败: %v\n", err)
return false, fmt.Errorf("failed to list iptables rules: %s, output: %s", err, string(output))
}
return strings.Contains(string(output), "-A INPUT -s "+ruleName+" -j ACCEPT"), nil
// 查找DROP规则
exists := strings.Contains(string(output), "-A INPUT -s "+ip+" -j DROP") ||
strings.Contains(string(output), "-A INPUT -s "+ip+"/32 -j DROP")
if exists {
fmt.Printf("[DEBUG] 规则存在: %s (IP: %s)\n", ruleName, ip)
} else {
fmt.Printf("[DEBUG] 规则不存在: %s (IP: %s)\n", ruleName, ip)
}
return exists, nil
}
func ConvertByte2String(byte []byte, charset Charset) string {
var str string
@@ -91,3 +162,181 @@ func ConvertByte2String(byte []byte, charset Charset) string {
}
return str
}
// BlockIP 封禁指定IP地址
// ip: 要封禁的IP地址支持单个IP或CIDR格式
// reason: 封禁原因(可选,后续会存储到数据库)
func (fw *FireWallEngine) BlockIP(ip string, reason string) error {
fmt.Printf("[INFO] 开始封禁IP: %s, 原因: %s\n", ip, reason)
// 生成规则名称
ruleName := generateRuleName(ip)
fmt.Printf("[DEBUG] 生成规则名称: %s\n", ruleName)
// 检查规则是否已存在
exists, _ := fw.IsRuleExists(ruleName)
if exists {
fmt.Printf("[WARN] IP %s 已经被封禁\n", ip)
return fmt.Errorf("IP %s already blocked", ip)
}
// 添加iptables规则: iptables -A INPUT -s <ip> -j DROP
cmd := exec.Command("iptables", "-A", "INPUT", "-s", ip, "-j", "DROP")
err, output := fw.executeCommand(cmd)
if err != nil {
fmt.Printf("[ERROR] 封禁IP失败: %s, error: %v, output: %s\n", ip, err, output)
return fmt.Errorf("failed to block IP %s: %v, output: %s", ip, err, output)
}
fmt.Printf("[INFO] 成功封禁IP: %s\n", ip)
return nil
}
// UnblockIP 解除对指定IP的封禁
func (fw *FireWallEngine) UnblockIP(ip string) error {
fmt.Printf("[INFO] 开始解除IP封禁: %s\n", ip)
ruleName := generateRuleName(ip)
// 检查规则是否存在
exists, _ := fw.IsRuleExists(ruleName)
if !exists {
fmt.Printf("[WARN] IP %s 未被封禁\n", ip)
return fmt.Errorf("IP %s is not blocked", ip)
}
// 删除iptables规则: iptables -D INPUT -s <ip> -j DROP
cmd := exec.Command("iptables", "-D", "INPUT", "-s", ip, "-j", "DROP")
err, output := fw.executeCommand(cmd)
if err != nil {
fmt.Printf("[ERROR] 解除IP封禁失败: %s, error: %v, output: %s\n", ip, err, output)
return fmt.Errorf("failed to unblock IP %s: %v, output: %s", ip, err, output)
}
fmt.Printf("[INFO] 成功解除IP封禁: %s\n", ip)
return nil
}
// IsIPBlocked 检查IP是否已被封禁
func (fw *FireWallEngine) IsIPBlocked(ip string) (bool, error) {
fmt.Printf("[DEBUG] 检查IP是否被封禁: %s\n", ip)
ruleName := generateRuleName(ip)
blocked, err := fw.IsRuleExists(ruleName)
if blocked {
fmt.Printf("[DEBUG] IP %s 已被封禁\n", ip)
} else {
fmt.Printf("[DEBUG] IP %s 未被封禁\n", ip)
}
return blocked, err
}
// BlockIPList 批量封禁IP列表
// ips: IP地址列表
// 返回成功数量、失败的IP列表和错误信息
func (fw *FireWallEngine) BlockIPList(ips []string) (successCount int, failedIPs []string, err error) {
successCount = 0
failedIPs = []string{}
for _, ip := range ips {
err := fw.BlockIP(ip, "")
if err != nil {
failedIPs = append(failedIPs, ip)
} else {
successCount++
}
}
if len(failedIPs) > 0 {
return successCount, failedIPs, fmt.Errorf("failed to block %d IPs", len(failedIPs))
}
return successCount, failedIPs, nil
}
// UnblockIPList 批量解除IP封禁
func (fw *FireWallEngine) UnblockIPList(ips []string) (successCount int, failedIPs []string, err error) {
successCount = 0
failedIPs = []string{}
for _, ip := range ips {
err := fw.UnblockIP(ip)
if err != nil {
failedIPs = append(failedIPs, ip)
} else {
successCount++
}
}
if len(failedIPs) > 0 {
return successCount, failedIPs, fmt.Errorf("failed to unblock %d IPs", len(failedIPs))
}
return successCount, failedIPs, nil
}
// GetBlockedIPList 获取所有已封禁的IP列表
func (fw *FireWallEngine) GetBlockedIPList() ([]string, error) {
cmd := exec.Command("iptables-save")
output, err := cmd.CombinedOutput()
if err != nil {
return nil, fmt.Errorf("failed to get blocked IP list: %v", err)
}
blockedIPs := []string{}
lines := strings.Split(string(output), "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
// 查找DROP规则: -A INPUT -s <ip> -j DROP
if strings.Contains(line, "-A INPUT -s") && strings.Contains(line, "-j DROP") {
parts := strings.Fields(line)
for i, part := range parts {
if part == "-s" && i+1 < len(parts) {
ip := parts[i+1]
// 移除CIDR后缀如果有
ip = strings.TrimSuffix(ip, "/32")
blockedIPs = append(blockedIPs, ip)
break
}
}
}
}
return blockedIPs, nil
}
// ClearAllBlockedIPs 清除所有封禁规则(谨慎使用)
func (fw *FireWallEngine) ClearAllBlockedIPs() (int, error) {
blockedIPs, err := fw.GetBlockedIPList()
if err != nil {
return 0, err
}
count := 0
for _, ip := range blockedIPs {
err := fw.UnblockIP(ip)
if err == nil {
count++
}
}
return count, nil
}
// generateRuleName 生成规则名称
func generateRuleName(ip string) string {
// 将IP中的点替换为下划线避免命令行解析问题
safeName := strings.ReplaceAll(ip, ".", "_")
safeName = strings.ReplaceAll(safeName, "/", "_")
return RULE_PREFIX + safeName
}
// extractIPFromRuleName 从规则名中提取IP
func extractIPFromRuleName(ruleName string) string {
if !strings.HasPrefix(ruleName, RULE_PREFIX) {
return ""
}
safeName := strings.TrimPrefix(ruleName, RULE_PREFIX)
ip := strings.ReplaceAll(safeName, "_", ".")
return ip
}

422
firewall/firewall_darwin.go Normal file
View File

@@ -0,0 +1,422 @@
//go:build darwin
package firewall
import (
"bufio"
"fmt"
"os/exec"
"strings"
"time"
"golang.org/x/text/encoding/simplifiedchinese"
)
type Charset string
const (
UTF8 = Charset("UTF-8")
GB18030 = Charset("GB18030")
)
const ACTION_ALLOW string = "allow" //allow 表示允许连接block 表示阻止连接bypass 表示只允许安全连接。 =
const ACTION_BLOCK string = "block"
const ACTION_BYPASS string = "bypass"
const (
RULE_PREFIX = "SamWAF_Block_" // 规则名称前缀
PROTOCOL_ANY = "any" // 任意协议
PROTOCOL_TCP = "TCP" // TCP 协议
PROTOCOL_UDP = "UDP" // UDP 协议
DIRECTION_IN = "in" // 入站
DIRECTION_OUT = "out" // 出站
DIRECTION_BOTH = "both" // 双向
// Mac 特有常量
PF_TABLE_NAME = "samwaf_blocked" // pf table 名称
)
type FireWallEngine struct {
}
// IPBlockInfo IP封禁信息结构
type IPBlockInfo struct {
IP string // IP地址
RuleName string // 规则名称
Reason string // 封禁原因(预留字段,实际存储在数据库)
BlockTime time.Time // 封禁时间(预留字段,实际存储在数据库)
Protocol string // 协议类型
Direction string // 方向
}
func (fw *FireWallEngine) IsFirewallEnabled() bool {
// 检查 pf 是否启用
cmd := exec.Command("pfctl", "-s", "info")
output, err := cmd.CombinedOutput()
if err != nil {
fmt.Printf("[WARN] 检查防火墙状态失败: %v\n", err)
return false
}
// 查找 "Status: Enabled"
return strings.Contains(string(output), "Status: Enabled")
}
func (fw *FireWallEngine) executeCommand(cmd *exec.Cmd) (error error, printstr string) {
stdout, err := cmd.StdoutPipe()
if err != nil {
fmt.Println(err)
return err, err.Error()
}
cmd.Start()
in := bufio.NewScanner(stdout)
printstr = ""
for in.Scan() {
cmdRe := ConvertByte2String(in.Bytes(), "UTF-8")
printstr += cmdRe + "\n"
}
cmd.Wait()
return nil, printstr
}
func (fw *FireWallEngine) AddRule(ruleName, ipToAdd, action, proc, localport string) error {
// Mac 使用 pfctl 添加 IP 到 table
fmt.Printf("[DEBUG] 添加防火墙规则 (Mac): ip=%s\n", ipToAdd)
// 确保 table 存在
err := fw.ensureTableExists()
if err != nil {
return fmt.Errorf("failed to ensure table exists: %v", err)
}
// 添加 IP 到 table
cmd := exec.Command("pfctl", "-t", PF_TABLE_NAME, "-T", "add", ipToAdd)
fmt.Printf("[DEBUG] 执行命令: pfctl -t %s -T add %s\n", PF_TABLE_NAME, ipToAdd)
output, err := cmd.CombinedOutput()
if err != nil {
fmt.Printf("[ERROR] 添加规则失败: %v, 输出: %s\n", err, string(output))
return err
}
fmt.Printf("[DEBUG] 添加规则成功, 输出: %s\n", string(output))
return nil
}
func (fw *FireWallEngine) EditRule(ruleNum int, newRule string) error {
return fmt.Errorf("editRule is not supported on macOS")
}
func (fw *FireWallEngine) DeleteRule(ruleName string) (bool, error) {
fmt.Printf("[DEBUG] 删除防火墙规则 (Mac): name=%s\n", ruleName)
// 从规则名中提取IP
ip := extractIPFromRuleName(ruleName)
if ip == "" {
fmt.Printf("[ERROR] 无效的规则名: %s\n", ruleName)
return false, fmt.Errorf("invalid rule name: %s", ruleName)
}
// 从 table 中删除 IP
cmd := exec.Command("pfctl", "-t", PF_TABLE_NAME, "-T", "delete", ip)
fmt.Printf("[DEBUG] 执行命令: pfctl -t %s -T delete %s\n", PF_TABLE_NAME, ip)
output, err := cmd.CombinedOutput()
outputStr := string(output)
fmt.Printf("[DEBUG] 删除规则输出: %s\n", outputStr)
if err != nil {
// 如果IP不在表中也算成功幂等性
if strings.Contains(outputStr, "no addresses deleted") {
fmt.Printf("[WARN] IP不在表中: %s\n", ip)
return false, fmt.Errorf("IP %s not in table", ip)
}
fmt.Printf("[ERROR] 删除规则失败: %v\n", err)
return false, err
}
fmt.Printf("[DEBUG] 删除规则成功\n")
return true, nil
}
func (fw *FireWallEngine) IsRuleExists(ruleName string) (bool, error) {
fmt.Printf("[DEBUG] 检查规则是否存在 (Mac): name=%s\n", ruleName)
// 从规则名中提取IP
ip := extractIPFromRuleName(ruleName)
if ip == "" {
fmt.Printf("[ERROR] 无效的规则名: %s\n", ruleName)
return false, fmt.Errorf("invalid rule name: %s", ruleName)
}
// 检查 IP 是否在 table 中
cmd := exec.Command("pfctl", "-t", PF_TABLE_NAME, "-T", "show")
output, err := cmd.CombinedOutput()
if err != nil {
// table 可能不存在
fmt.Printf("[DEBUG] 获取table失败 (table可能不存在): %v\n", err)
return false, nil
}
outputStr := string(output)
exists := strings.Contains(outputStr, ip)
if exists {
fmt.Printf("[DEBUG] 规则存在: %s (IP: %s)\n", ruleName, ip)
} else {
fmt.Printf("[DEBUG] 规则不存在: %s (IP: %s)\n", ruleName, ip)
}
return exists, nil
}
func ConvertByte2String(byte []byte, charset Charset) string {
var str string
switch charset {
case GB18030:
var decodeBytes, _ = simplifiedchinese.GB18030.NewDecoder().Bytes(byte)
str = string(decodeBytes)
case UTF8:
fallthrough
default:
str = string(byte)
}
return str
}
// BlockIP 封禁指定IP地址
// ip: 要封禁的IP地址支持单个IP或CIDR格式
// reason: 封禁原因(可选,后续会存储到数据库)
func (fw *FireWallEngine) BlockIP(ip string, reason string) error {
fmt.Printf("[INFO] 开始封禁IP (Mac): %s, 原因: %s\n", ip, reason)
// 生成规则名称
ruleName := generateRuleName(ip)
fmt.Printf("[DEBUG] 生成规则名称: %s\n", ruleName)
// 检查规则是否已存在
exists, _ := fw.IsRuleExists(ruleName)
if exists {
fmt.Printf("[WARN] IP %s 已经被封禁\n", ip)
return fmt.Errorf("IP %s already blocked", ip)
}
// 确保 pf 已启用和 table 存在
if err := fw.ensureTableExists(); err != nil {
return fmt.Errorf("failed to ensure table exists: %v", err)
}
// 添加 IP 到 pf table
cmd := exec.Command("pfctl", "-t", PF_TABLE_NAME, "-T", "add", ip)
output, err := cmd.CombinedOutput()
if err != nil {
fmt.Printf("[ERROR] 封禁IP失败: %s, error: %v, output: %s\n", ip, err, string(output))
return fmt.Errorf("failed to block IP %s: %v, output: %s", ip, err, string(output))
}
fmt.Printf("[INFO] 成功封禁IP: %s\n", ip)
return nil
}
// UnblockIP 解除对指定IP的封禁
func (fw *FireWallEngine) UnblockIP(ip string) error {
fmt.Printf("[INFO] 开始解除IP封禁 (Mac): %s\n", ip)
ruleName := generateRuleName(ip)
// 检查规则是否存在
exists, _ := fw.IsRuleExists(ruleName)
if !exists {
fmt.Printf("[WARN] IP %s 未被封禁\n", ip)
return fmt.Errorf("IP %s is not blocked", ip)
}
// 从 table 中删除 IP
cmd := exec.Command("pfctl", "-t", PF_TABLE_NAME, "-T", "delete", ip)
output, err := cmd.CombinedOutput()
if err != nil {
fmt.Printf("[ERROR] 解除IP封禁失败: %s, error: %v, output: %s\n", ip, err, string(output))
return fmt.Errorf("failed to unblock IP %s: %v, output: %s", ip, err, string(output))
}
fmt.Printf("[INFO] 成功解除IP封禁: %s\n", ip)
return nil
}
// IsIPBlocked 检查IP是否已被封禁
func (fw *FireWallEngine) IsIPBlocked(ip string) (bool, error) {
fmt.Printf("[DEBUG] 检查IP是否被封禁 (Mac): %s\n", ip)
ruleName := generateRuleName(ip)
blocked, err := fw.IsRuleExists(ruleName)
if blocked {
fmt.Printf("[DEBUG] IP %s 已被封禁\n", ip)
} else {
fmt.Printf("[DEBUG] IP %s 未被封禁\n", ip)
}
return blocked, err
}
// BlockIPList 批量封禁IP列表
// ips: IP地址列表
// 返回成功数量、失败的IP列表和错误信息
func (fw *FireWallEngine) BlockIPList(ips []string) (successCount int, failedIPs []string, err error) {
successCount = 0
failedIPs = []string{}
for _, ip := range ips {
err := fw.BlockIP(ip, "")
if err != nil {
failedIPs = append(failedIPs, ip)
} else {
successCount++
}
}
if len(failedIPs) > 0 {
return successCount, failedIPs, fmt.Errorf("failed to block %d IPs", len(failedIPs))
}
return successCount, failedIPs, nil
}
// UnblockIPList 批量解除IP封禁
func (fw *FireWallEngine) UnblockIPList(ips []string) (successCount int, failedIPs []string, err error) {
successCount = 0
failedIPs = []string{}
for _, ip := range ips {
err := fw.UnblockIP(ip)
if err != nil {
failedIPs = append(failedIPs, ip)
} else {
successCount++
}
}
if len(failedIPs) > 0 {
return successCount, failedIPs, fmt.Errorf("failed to unblock %d IPs", len(failedIPs))
}
return successCount, failedIPs, nil
}
// GetBlockedIPList 获取所有已封禁的IP列表
func (fw *FireWallEngine) GetBlockedIPList() ([]string, error) {
fmt.Printf("[DEBUG] 获取已封禁IP列表 (Mac)\n")
cmd := exec.Command("pfctl", "-t", PF_TABLE_NAME, "-T", "show")
output, err := cmd.CombinedOutput()
if err != nil {
fmt.Printf("[ERROR] 获取IP列表失败: %v\n", err)
return nil, fmt.Errorf("failed to get blocked IP list: %v", err)
}
blockedIPs := []string{}
lines := strings.Split(string(output), "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if line != "" && !strings.HasPrefix(line, "#") {
blockedIPs = append(blockedIPs, line)
}
}
fmt.Printf("[DEBUG] 找到 %d 个已封禁的IP\n", len(blockedIPs))
return blockedIPs, nil
}
// ClearAllBlockedIPs 清除所有封禁规则(谨慎使用)
func (fw *FireWallEngine) ClearAllBlockedIPs() (int, error) {
fmt.Printf("[INFO] 清除所有封禁规则 (Mac)\n")
blockedIPs, err := fw.GetBlockedIPList()
if err != nil {
return 0, err
}
count := 0
for _, ip := range blockedIPs {
err := fw.UnblockIP(ip)
if err == nil {
count++
}
}
fmt.Printf("[INFO] 成功清除 %d 条规则\n", count)
return count, nil
}
// generateRuleName 生成规则名称
func generateRuleName(ip string) string {
// 将IP中的点替换为下划线避免命令行解析问题
safeName := strings.ReplaceAll(ip, ".", "_")
safeName = strings.ReplaceAll(safeName, "/", "_")
safeName = strings.ReplaceAll(safeName, ":", "_") // IPv6 支持
return RULE_PREFIX + safeName
}
// extractIPFromRuleName 从规则名中提取IP
func extractIPFromRuleName(ruleName string) string {
if !strings.HasPrefix(ruleName, RULE_PREFIX) {
return ""
}
safeName := strings.TrimPrefix(ruleName, RULE_PREFIX)
ip := strings.ReplaceAll(safeName, "_", ".")
return ip
}
// ensureTableExists 确保 pf table 存在并配置规则
func (fw *FireWallEngine) ensureTableExists() error {
fmt.Printf("[DEBUG] 确保pf table存在\n")
// 检查 table 是否存在
cmd := exec.Command("pfctl", "-t", PF_TABLE_NAME, "-T", "show")
_, err := cmd.CombinedOutput()
if err == nil {
// table 已存在
fmt.Printf("[DEBUG] pf table %s 已存在\n", PF_TABLE_NAME)
return nil
}
// 创建 table
// 注意:在 macOS 上,需要通过配置文件或 pfctl 的 anchor 来创建持久化的 table
// 这里我们先尝试直接添加一个空IP来初始化 table
fmt.Printf("[DEBUG] 初始化pf table %s\n", PF_TABLE_NAME)
// 尝试添加并立即删除一个临时IP来初始化表
tempIP := "127.0.0.254"
cmd = exec.Command("pfctl", "-t", PF_TABLE_NAME, "-T", "add", tempIP)
_, err = cmd.CombinedOutput()
if err != nil {
fmt.Printf("[WARN] 初始化table失败: %v\n", err)
fmt.Printf("[INFO] 这是正常的table会在第一次使用时自动创建\n")
}
// 删除临时IP
cmd = exec.Command("pfctl", "-t", PF_TABLE_NAME, "-T", "delete", tempIP)
cmd.CombinedOutput()
return nil
}
// SetupPFRule 设置 pf 规则(需要手动调用一次来配置基础规则)
// 这个方法需要在系统启动时或首次使用时调用
func (fw *FireWallEngine) SetupPFRule() error {
fmt.Printf("[INFO] 设置pf基础规则 (Mac)\n")
// 创建一个临时的 pf 规则文件
ruleContent := fmt.Sprintf(`
# SamWAF IP Block Table
table <%s> persist
# Block incoming traffic from blocked IPs
block in quick from <%s> to any
`, PF_TABLE_NAME, PF_TABLE_NAME)
fmt.Printf("[INFO] pf规则内容:\n%s\n", ruleContent)
fmt.Printf("[INFO] 请手动将以上规则添加到 /etc/pf.conf 并执行 'sudo pfctl -f /etc/pf.conf'\n")
fmt.Printf("[INFO] 或者使用 anchor 功能动态加载规则\n")
return nil
}

View File

@@ -2,61 +2,60 @@ package firewall
import (
"fmt"
"os"
"testing"
"time"
)
// ================== 基础功能测试 ==================
func TestFireWallEngine_AddRule(t *testing.T) {
fw := FireWallEngine{}
// Add a new firewall rule
//ruleToAdd := "-p tcp --dport 8080 -j ACCEPT"
ruleName := "testwaf1"
ipToAdd := "192.168.1.12"
action := ACTION_BLOCK
proc := "TCP"
localport := "8989"
if err := fw.AddRule(ruleName, ipToAdd, action, proc, localport); err != nil {
fmt.Println("Failed to add firewall rule:", err)
t.Logf("Failed to add firewall rule: %v", err)
} else {
fmt.Println("Firewall rule added successfully.")
t.Log("Firewall rule added successfully.")
}
}
func TestFireWallEngine_DeleteRule(t *testing.T) {
fw := FireWallEngine{}
ruleName := "testwaf1"
exists, err := fw.IsRuleExists(ruleName)
if err != nil {
fmt.Println("Error:", err)
os.Exit(1)
t.Fatalf("Error checking rule existence: %v", err)
}
if exists {
if excuteResult, err := fw.DeleteRule(ruleName); err != nil {
fmt.Println("Failed to delete firewall rule:", err)
t.Logf("Failed to delete firewall rule: %v", err)
} else {
if excuteResult {
fmt.Println("Firewall rule deleted successfully.")
t.Log("Firewall rule deleted successfully.")
} else {
fmt.Println("Firewall rule deleted failed.", err)
t.Logf("Firewall rule deleted failed: %v", err)
}
}
} else {
fmt.Println("Rule does not exist.")
t.Log("Rule does not exist.")
}
}
func TestFireWallEngine_EditRule(t *testing.T) {
fw := FireWallEngine{}
// Edit an existing firewall rule (not supported on Windows)
// Edit an existing firewall rule (not supported)
ruleNum := 1
newRule := "-p tcp --dport 8080 -j DROP"
if err := fw.EditRule(ruleNum, newRule); err != nil {
fmt.Println("Failed to edit firewall rule:", err)
t.Logf("Expected error: %v", err)
} else {
fmt.Println("Firewall rule edited successfully.")
t.Log("Firewall rule edited successfully.")
}
}
@@ -65,9 +64,9 @@ func TestFireWallEngine_IsFirewallEnabled(t *testing.T) {
// Check if the firewall is enabled
if fw.IsFirewallEnabled() {
fmt.Println("Firewall is enabled.")
t.Log("Firewall is enabled.")
} else {
fmt.Println("Firewall is not enabled.")
t.Log("Firewall is not enabled.")
}
}
@@ -78,13 +77,366 @@ func TestFireWallEngine_IsRuleExists(t *testing.T) {
ruleName := "testwaf"
exists, err := fw.IsRuleExists(ruleName)
if err != nil {
fmt.Println("Error:", err)
os.Exit(1)
t.Logf("Error checking rule existence: %v", err)
}
if exists {
fmt.Println("Rule exists.")
t.Log("Rule exists.")
} else {
fmt.Println("Rule does not exist.")
t.Log("Rule does not exist.")
}
}
// ================== IP封禁功能测试 ==================
// TestBlockIP 测试单个IP封禁
func TestBlockIP(t *testing.T) {
fw := FireWallEngine{}
testIP := "192.168.100.100"
t.Logf("开始测试封禁IP: %s", testIP)
// 先确保IP未被封禁
fw.UnblockIP(testIP)
// 封禁IP
err := fw.BlockIP(testIP, "测试封禁")
if err != nil {
t.Fatalf("封禁IP失败: %v", err)
}
t.Logf("✓ 成功封禁IP: %s", testIP)
// 等待规则生效
time.Sleep(100 * time.Millisecond)
// 验证IP已被封禁
blocked, err := fw.IsIPBlocked(testIP)
if err != nil {
t.Fatalf("检查IP封禁状态失败: %v", err)
}
if !blocked {
t.Fatalf("IP应该已被封禁但检查结果为未封禁")
}
t.Logf("✓ 验证IP已被封禁")
// 测试重复封禁
err = fw.BlockIP(testIP, "重复测试")
if err == nil {
t.Logf("注意: 重复封禁应该返回错误,但未返回")
}
// 清理:解除封禁
err = fw.UnblockIP(testIP)
if err != nil {
t.Fatalf("解除封禁失败: %v", err)
}
t.Logf("✓ 成功解除封禁")
// 验证IP已解除封禁
blocked, err = fw.IsIPBlocked(testIP)
if err != nil {
t.Fatalf("检查IP封禁状态失败: %v", err)
}
if blocked {
t.Fatalf("IP应该已解除封禁但检查结果为已封禁")
}
t.Logf("✓ 验证IP已解除封禁")
}
// TestUnblockIP 测试解除IP封禁
func TestUnblockIP(t *testing.T) {
fw := FireWallEngine{}
testIP := "192.168.100.101"
t.Logf("开始测试解除封禁IP: %s", testIP)
// 先封禁IP
fw.BlockIP(testIP, "测试")
time.Sleep(100 * time.Millisecond)
// 解除封禁
err := fw.UnblockIP(testIP)
if err != nil {
t.Fatalf("解除封禁失败: %v", err)
}
t.Logf("✓ 成功解除封禁")
// 测试解除未封禁的IP
err = fw.UnblockIP(testIP)
if err == nil {
t.Logf("注意: 解除未封禁的IP应该返回错误但未返回")
}
}
// TestIsIPBlocked 测试检查IP封禁状态
func TestIsIPBlocked(t *testing.T) {
fw := FireWallEngine{}
testIP := "192.168.100.102"
t.Logf("开始测试检查IP封禁状态: %s", testIP)
// 先确保IP未被封禁
fw.UnblockIP(testIP)
time.Sleep(100 * time.Millisecond)
// 检查未封禁状态
blocked, err := fw.IsIPBlocked(testIP)
if err != nil {
t.Fatalf("检查IP封禁状态失败: %v", err)
}
if blocked {
t.Fatalf("IP应该未被封禁但检查结果为已封禁")
}
t.Logf("✓ 验证IP未被封禁")
// 封禁IP
fw.BlockIP(testIP, "测试")
time.Sleep(100 * time.Millisecond)
// 检查已封禁状态
blocked, err = fw.IsIPBlocked(testIP)
if err != nil {
t.Fatalf("检查IP封禁状态失败: %v", err)
}
if !blocked {
t.Fatalf("IP应该已被封禁但检查结果为未封禁")
}
t.Logf("✓ 验证IP已被封禁")
// 清理
fw.UnblockIP(testIP)
}
// TestBlockIPList 测试批量封禁IP
func TestBlockIPList(t *testing.T) {
fw := FireWallEngine{}
testIPs := []string{
"192.168.100.110",
"192.168.100.111",
"192.168.100.112",
"192.168.100.113",
"192.168.100.114",
}
t.Logf("开始测试批量封禁IP共 %d 个", len(testIPs))
// 先清理可能存在的规则
for _, ip := range testIPs {
fw.UnblockIP(ip)
}
time.Sleep(100 * time.Millisecond)
// 批量封禁
successCount, failedIPs, err := fw.BlockIPList(testIPs)
if err != nil && successCount == 0 {
t.Fatalf("批量封禁完全失败: %v", err)
}
t.Logf("✓ 成功封禁 %d 个IP", successCount)
if len(failedIPs) > 0 {
t.Logf("失败的IP: %v", failedIPs)
}
// 验证封禁结果
time.Sleep(100 * time.Millisecond)
for _, ip := range testIPs {
blocked, err := fw.IsIPBlocked(ip)
if err != nil {
t.Logf("检查IP %s 封禁状态失败: %v", ip, err)
continue
}
if blocked {
t.Logf("✓ IP %s 已被封禁", ip)
} else {
t.Logf("× IP %s 未被封禁", ip)
}
}
// 清理:批量解除封禁
successCount, failedIPs, err = fw.UnblockIPList(testIPs)
t.Logf("✓ 批量解除封禁完成,成功 %d 个", successCount)
if len(failedIPs) > 0 {
t.Logf("解除失败的IP: %v", failedIPs)
}
}
// TestUnblockIPList 测试批量解除封禁
func TestUnblockIPList(t *testing.T) {
fw := FireWallEngine{}
testIPs := []string{
"192.168.100.120",
"192.168.100.121",
"192.168.100.122",
}
t.Logf("开始测试批量解除封禁IP共 %d 个", len(testIPs))
// 先批量封禁
fw.BlockIPList(testIPs)
time.Sleep(100 * time.Millisecond)
// 批量解除封禁
successCount, failedIPs, err := fw.UnblockIPList(testIPs)
if err != nil && successCount == 0 {
t.Fatalf("批量解除封禁完全失败: %v", err)
}
t.Logf("✓ 成功解除 %d 个IP的封禁", successCount)
if len(failedIPs) > 0 {
t.Logf("失败的IP: %v", failedIPs)
}
}
// TestGetBlockedIPList 测试获取已封禁IP列表
func TestGetBlockedIPList(t *testing.T) {
fw := FireWallEngine{}
testIPs := []string{
"192.168.100.130",
"192.168.100.131",
"192.168.100.132",
}
t.Log("开始测试获取已封禁IP列表")
// 先清理
fw.UnblockIPList(testIPs)
time.Sleep(100 * time.Millisecond)
// 批量封禁测试IP
fw.BlockIPList(testIPs)
time.Sleep(100 * time.Millisecond)
// 获取已封禁IP列表
blockedIPs, err := fw.GetBlockedIPList()
if err != nil {
t.Fatalf("获取已封禁IP列表失败: %v", err)
}
t.Logf("✓ 当前已封禁IP数量: %d", len(blockedIPs))
t.Logf("已封禁IP列表: %v", blockedIPs)
// 验证测试IP是否在列表中
for _, testIP := range testIPs {
found := false
for _, blockedIP := range blockedIPs {
if blockedIP == testIP {
found = true
break
}
}
if found {
t.Logf("✓ IP %s 在已封禁列表中", testIP)
} else {
t.Logf("× IP %s 不在已封禁列表中", testIP)
}
}
// 清理
fw.UnblockIPList(testIPs)
}
// TestClearAllBlockedIPs 测试清除所有封禁
func TestClearAllBlockedIPs(t *testing.T) {
fw := FireWallEngine{}
testIPs := []string{
"192.168.100.140",
"192.168.100.141",
"192.168.100.142",
}
t.Log("开始测试清除所有封禁规则")
// 先添加一些测试规则
fw.BlockIPList(testIPs)
time.Sleep(100 * time.Millisecond)
// 清除所有封禁
count, err := fw.ClearAllBlockedIPs()
if err != nil {
t.Fatalf("清除所有封禁失败: %v", err)
}
t.Logf("✓ 成功清除 %d 条封禁规则", count)
// 验证是否清除干净
blockedIPs, err := fw.GetBlockedIPList()
if err != nil {
t.Fatalf("获取已封禁IP列表失败: %v", err)
}
t.Logf("清除后剩余封禁规则数量: %d", len(blockedIPs))
}
// TestCIDRNotation 测试CIDR格式的IP封禁
func TestCIDRNotation(t *testing.T) {
fw := FireWallEngine{}
testCIDR := "192.168.100.0/24"
t.Logf("开始测试CIDR格式封禁: %s", testCIDR)
// 先清理
fw.UnblockIP(testCIDR)
time.Sleep(100 * time.Millisecond)
// 封禁CIDR
err := fw.BlockIP(testCIDR, "测试CIDR封禁")
if err != nil {
t.Fatalf("封禁CIDR失败: %v", err)
}
t.Logf("✓ 成功封禁CIDR: %s", testCIDR)
// 验证封禁状态
time.Sleep(100 * time.Millisecond)
blocked, err := fw.IsIPBlocked(testCIDR)
if err != nil {
t.Fatalf("检查CIDR封禁状态失败: %v", err)
}
if !blocked {
t.Logf("注意: CIDR应该已被封禁但检查结果为未封禁")
} else {
t.Logf("✓ 验证CIDR已被封禁")
}
// 清理
err = fw.UnblockIP(testCIDR)
if err != nil {
t.Fatalf("解除CIDR封禁失败: %v", err)
}
t.Logf("✓ 成功解除CIDR封禁")
}
// ================== 性能测试 ==================
// BenchmarkBlockIP 性能测试封禁单个IP
func BenchmarkBlockIP(b *testing.B) {
fw := FireWallEngine{}
baseIP := "10.0.0."
b.ResetTimer()
for i := 0; i < b.N; i++ {
ip := fmt.Sprintf("%s%d", baseIP, i%254+1)
fw.BlockIP(ip, "benchmark test")
}
// 清理
b.StopTimer()
for i := 0; i < b.N; i++ {
ip := fmt.Sprintf("%s%d", baseIP, i%254+1)
fw.UnblockIP(ip)
}
}
// BenchmarkIsIPBlocked 性能测试检查IP封禁状态
func BenchmarkIsIPBlocked(b *testing.B) {
fw := FireWallEngine{}
testIP := "10.0.0.100"
// 准备先封禁一个IP
fw.BlockIP(testIP, "benchmark test")
time.Sleep(100 * time.Millisecond)
b.ResetTimer()
for i := 0; i < b.N; i++ {
fw.IsIPBlocked(testIP)
}
// 清理
b.StopTimer()
fw.UnblockIP(testIP)
}

View File

@@ -5,10 +5,12 @@ package firewall
import (
"bufio"
"fmt"
"golang.org/x/sys/windows/registry"
"golang.org/x/text/encoding/simplifiedchinese"
"os/exec"
"strings"
"time"
"golang.org/x/sys/windows/registry"
"golang.org/x/text/encoding/simplifiedchinese"
)
type Charset string
@@ -22,9 +24,29 @@ const ACTION_ALLOW string = "allow" //allow 表示允许连接block 表示
const ACTION_BLOCK string = "block"
const ACTION_BYPASS string = "bypass"
const (
RULE_PREFIX = "SamWAF_Block_" // 规则名称前缀
PROTOCOL_ANY = "any" // 任意协议
PROTOCOL_TCP = "TCP" // TCP 协议
PROTOCOL_UDP = "UDP" // UDP 协议
DIRECTION_IN = "in" // 入站
DIRECTION_OUT = "out" // 出站
DIRECTION_BOTH = "both" // 双向
)
type FireWallEngine struct {
}
// IPBlockInfo IP封禁信息结构
type IPBlockInfo struct {
IP string // IP地址
RuleName string // 规则名称
Reason string // 封禁原因(预留字段,实际存储在数据库)
BlockTime time.Time // 封禁时间(预留字段,实际存储在数据库)
Protocol string // 协议类型
Direction string // 方向
}
func (fw *FireWallEngine) IsFirewallEnabled() bool {
const firewallRegistryPath = `SYSTEM\CurrentControlSet\Services\SharedAccess\Parameters\FirewallPolicy\StandardProfile`
key, err := registry.OpenKey(registry.LOCAL_MACHINE, firewallRegistryPath, registry.QUERY_VALUE)
@@ -59,18 +81,35 @@ func (fw *FireWallEngine) executeCommand(cmd *exec.Cmd) (error error, printstr s
}
func (fw *FireWallEngine) AddRule(ruleName, ipToAdd, action, proc, localport string) error {
// 构建命令参数
args := []string{"advfirewall", "firewall", "add", "rule", "name=" + ruleName, "dir=in", "action=" + action}
/*s := fmt.Sprintf(`netsh advfirewall firewall add rule name="%s" dir=in action=allow protocol=TCP localport=8080 remoteip=%s`, ruleName, ipToAdd)
cmd = exec.Command("netsh", s)*/
/*cmd = exec.Command("netsh", "advfirewall", "firewall", "add", "rule",
fmt.Sprintf(`name="%s"`, ruleName),
fmt.Sprintf(`dir=in action=allow protocol=TCP localport=8080 remoteip=%s`, ipToAdd),
)*/
cmd := exec.Command("netsh", "advfirewall", "firewall", "add", "rule",
"name="+ruleName, "dir=in", "action="+action, "protocol="+proc, "localport="+localport,
"remoteip="+ipToAdd,
)
err, _ := fw.executeCommand(cmd)
// 处理协议参数 - Windows netsh 不支持 "any" 协议,需要分别创建规则或使用 any
if proc == PROTOCOL_ANY || proc == "any" {
// 对于 any 协议,不指定 protocol 和 localport 参数,这样会匹配所有协议
args = append(args, "remoteip="+ipToAdd)
fmt.Printf("[DEBUG] 添加防火墙规则 (ANY协议): name=%s, action=%s, remoteip=%s\n", ruleName, action, ipToAdd)
} else {
// 指定具体协议
args = append(args, "protocol="+proc)
if localport != "" && localport != "any" {
args = append(args, "localport="+localport)
}
args = append(args, "remoteip="+ipToAdd)
fmt.Printf("[DEBUG] 添加防火墙规则: name=%s, action=%s, protocol=%s, localport=%s, remoteip=%s\n",
ruleName, action, proc, localport, ipToAdd)
}
cmd := exec.Command("netsh", args...)
fmt.Printf("[DEBUG] 执行命令: netsh %s\n", strings.Join(args, " "))
err, output := fw.executeCommand(cmd)
if err != nil {
fmt.Printf("[ERROR] 添加规则失败: %v, 输出: %s\n", err, output)
return err
}
fmt.Printf("[DEBUG] 添加规则成功, 输出: %s\n", output)
return err
}
@@ -79,37 +118,55 @@ func (fw *FireWallEngine) EditRule(ruleNum int, newRule string) error {
}
func (fw *FireWallEngine) DeleteRule(ruleName string) (bool, error) {
fmt.Printf("[DEBUG] 删除防火墙规则: name=%s\n", ruleName)
cmd := exec.Command("netsh", "advfirewall", "firewall", "delete", "rule", fmt.Sprintf("name=%s", ruleName))
err, output := fw.executeCommand(cmd)
fmt.Println(output)
fmt.Printf("[DEBUG] 删除规则输出: %s\n", output)
//已删除 1 规则。确定。
if err == nil {
if strings.Contains(output, "No rules match the specified criteria") {
fmt.Printf("[WARN] 规则不存在: %s\n", ruleName)
return false, fmt.Errorf("error:delete firewall rule: %s, output: %s", ruleName, output)
}
if strings.Contains(output, "没有与指定标准相匹配的规则。") {
fmt.Printf("[WARN] 规则不存在: %s\n", ruleName)
return false, fmt.Errorf("error:delete firewall rule: %s, output: %s", ruleName, output)
}
if strings.Contains(output, "已删除") {
if strings.Contains(output, "已删除") || strings.Contains(output, "Ok") {
fmt.Printf("[DEBUG] 删除规则成功: %s\n", ruleName)
return true, nil
}
}
fmt.Printf("[ERROR] 删除规则失败: %s, error: %v\n", ruleName, err)
return false, fmt.Errorf("error:delete firewall rule: %s, output: %s", ruleName, output)
}
func (fw *FireWallEngine) IsRuleExists(ruleName string) (bool, error) {
fmt.Printf("[DEBUG] 检查规则是否存在: name=%s\n", ruleName)
cmd := exec.Command("netsh", "advfirewall", "firewall", "show", "rule", "name="+ruleName)
err, output := fw.executeCommand(cmd)
if err == nil {
if strings.Contains(output, "No rules match the specified criteria") {
fmt.Printf("[DEBUG] 规则不存在 (EN): %s\n", ruleName)
return false, nil
}
if strings.Contains(output, "没有与指定标准相匹配的规则。") {
fmt.Printf("[DEBUG] 规则不存在 (CN): %s\n", ruleName)
return false, nil
}
if strings.Contains(output, " "+ruleName+"-----") {
// 改进规则存在的判断逻辑 - 只要输出中包含规则名就认为存在
if strings.Contains(output, ruleName) {
fmt.Printf("[DEBUG] 规则存在: %s\n", ruleName)
return true, nil
}
}
fmt.Printf("[WARN] 检查规则失败: %s, error: %v, output: %s\n", ruleName, err, output)
return false, fmt.Errorf("failed to show firewall rule: %s, output: %s", err, string(output))
}
func ConvertByte2String(byte []byte, charset Charset) string {
@@ -125,3 +182,173 @@ func ConvertByte2String(byte []byte, charset Charset) string {
}
return str
}
// BlockIP 封禁指定IP地址入站+出站双向封禁)
// ip: 要封禁的IP地址支持单个IP或CIDR格式
// reason: 封禁原因(可选,后续会存储到数据库)
func (fw *FireWallEngine) BlockIP(ip string, reason string) error {
fmt.Printf("[INFO] 开始封禁IP: %s, 原因: %s\n", ip, reason)
// 生成规则名称
ruleName := generateRuleName(ip)
fmt.Printf("[DEBUG] 生成规则名称: %s\n", ruleName)
// 检查规则是否已存在
exists, _ := fw.IsRuleExists(ruleName)
if exists {
fmt.Printf("[WARN] IP %s 已经被封禁\n", ip)
return fmt.Errorf("IP %s already blocked", ip)
}
// 添加入站阻止规则 - 使用 any 协议会匹配所有协议
err := fw.AddRule(ruleName, ip, ACTION_BLOCK, PROTOCOL_ANY, "")
if err != nil {
fmt.Printf("[ERROR] 封禁IP失败: %s, error: %v\n", ip, err)
return fmt.Errorf("failed to block IP %s: %v", ip, err)
}
fmt.Printf("[INFO] 成功封禁IP: %s\n", ip)
return nil
}
// UnblockIP 解除对指定IP的封禁
func (fw *FireWallEngine) UnblockIP(ip string) error {
fmt.Printf("[INFO] 开始解除IP封禁: %s\n", ip)
ruleName := generateRuleName(ip)
// 检查规则是否存在
exists, _ := fw.IsRuleExists(ruleName)
if !exists {
fmt.Printf("[WARN] IP %s 未被封禁\n", ip)
return fmt.Errorf("IP %s is not blocked", ip)
}
// 删除规则
success, err := fw.DeleteRule(ruleName)
if err != nil {
fmt.Printf("[ERROR] 解除IP封禁失败: %s, error: %v\n", ip, err)
return fmt.Errorf("failed to unblock IP %s: %v", ip, err)
}
if !success {
fmt.Printf("[ERROR] 删除规则失败: %s\n", ip)
return fmt.Errorf("failed to unblock IP %s: rule deletion failed", ip)
}
fmt.Printf("[INFO] 成功解除IP封禁: %s\n", ip)
return nil
}
// IsIPBlocked 检查IP是否已被封禁
func (fw *FireWallEngine) IsIPBlocked(ip string) (bool, error) {
fmt.Printf("[DEBUG] 检查IP是否被封禁: %s\n", ip)
ruleName := generateRuleName(ip)
blocked, err := fw.IsRuleExists(ruleName)
if blocked {
fmt.Printf("[DEBUG] IP %s 已被封禁\n", ip)
} else {
fmt.Printf("[DEBUG] IP %s 未被封禁\n", ip)
}
return blocked, err
}
// BlockIPList 批量封禁IP列表
// ips: IP地址列表
// 返回成功数量、失败的IP列表和错误信息
func (fw *FireWallEngine) BlockIPList(ips []string) (successCount int, failedIPs []string, err error) {
successCount = 0
failedIPs = []string{}
for _, ip := range ips {
err := fw.BlockIP(ip, "")
if err != nil {
failedIPs = append(failedIPs, ip)
} else {
successCount++
}
}
if len(failedIPs) > 0 {
return successCount, failedIPs, fmt.Errorf("failed to block %d IPs", len(failedIPs))
}
return successCount, failedIPs, nil
}
// UnblockIPList 批量解除IP封禁
func (fw *FireWallEngine) UnblockIPList(ips []string) (successCount int, failedIPs []string, err error) {
successCount = 0
failedIPs = []string{}
for _, ip := range ips {
err := fw.UnblockIP(ip)
if err != nil {
failedIPs = append(failedIPs, ip)
} else {
successCount++
}
}
if len(failedIPs) > 0 {
return successCount, failedIPs, fmt.Errorf("failed to unblock %d IPs", len(failedIPs))
}
return successCount, failedIPs, nil
}
// GetBlockedIPList 获取所有已封禁的IP列表
func (fw *FireWallEngine) GetBlockedIPList() ([]string, error) {
cmd := exec.Command("netsh", "advfirewall", "firewall", "show", "rule", "name=all")
err, output := fw.executeCommand(cmd)
if err != nil {
return nil, fmt.Errorf("failed to get blocked IP list: %v", err)
}
blockedIPs := []string{}
lines := strings.Split(output, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
// 查找包含规则前缀的规则名
if strings.Contains(line, RULE_PREFIX) {
// 提取IP地址
parts := strings.Split(line, RULE_PREFIX)
if len(parts) > 1 {
ip := strings.TrimSpace(strings.Split(parts[1], " ")[0])
ip = strings.ReplaceAll(ip, "-", ".")
if ip != "" {
blockedIPs = append(blockedIPs, ip)
}
}
}
}
return blockedIPs, nil
}
// ClearAllBlockedIPs 清除所有封禁规则(谨慎使用)
func (fw *FireWallEngine) ClearAllBlockedIPs() (int, error) {
blockedIPs, err := fw.GetBlockedIPList()
if err != nil {
return 0, err
}
count := 0
for _, ip := range blockedIPs {
err := fw.UnblockIP(ip)
if err == nil {
count++
}
}
return count, nil
}
// generateRuleName 生成规则名称
func generateRuleName(ip string) string {
// 将IP中的点替换为下划线避免命令行解析问题
safeName := strings.ReplaceAll(ip, ".", "_")
safeName = strings.ReplaceAll(safeName, "/", "_")
return RULE_PREFIX + safeName
}

View File

@@ -107,7 +107,6 @@ var (
GWAF_CHAN_CLEAR_CC_WINDOWS = make(chan int, 10) //清除cc缓存信息
GWAF_CHAN_CLEAR_CC_IP = make(chan string, 10) //清除cc缓存信息IP
GWAF_QUEUE_SHUTDOWN_SIGNAL chan struct{} = make(chan struct{}) // 队列关闭信号
GWAF_CHAN_CREATE_LOG_INDEX = make(chan string, 10) // 创建日志索引
GWAF_CHAN_MANAGER_RESTART = make(chan int, 1) // 管理端重启信号
GWAF_SHUTDOWN_SIGNAL bool = false // 系统关闭信号

27
main.go
View File

@@ -247,24 +247,9 @@ func (m *wafSystenService) run() {
}
//初始化本地数据库
isNewMainDb, err := wafdb.InitCoreDb("")
if err == nil {
if isNewMainDb {
waftask.TaskCreateIndexByDbName(enums.DB_MAIN)
}
}
isNewLogDb, err := wafdb.InitLogDb("")
if err == nil {
if isNewLogDb {
waftask.TaskCreateIndexByDbName(enums.DB_LOG)
}
}
isNewStatsDb, err := wafdb.InitStatsDb("")
if err == nil {
if isNewStatsDb {
waftask.TaskCreateIndexByDbName(enums.DB_STATS)
}
}
wafdb.InitCoreDb("")
wafdb.InitLogDb("")
wafdb.InitStatsDb("")
//初始化队列引擎
wafqueue.InitDequeEngine()
@@ -332,11 +317,11 @@ func (m *wafSystenService) run() {
globalobj.GWAF_RUNTIME_OBJ_WAF_TaskRegistry.RegisterTask(enums.TASK_NOTICE, waftask.TaskStatusNotify)
globalobj.GWAF_RUNTIME_OBJ_WAF_TaskRegistry.RegisterTask(enums.TASK_HEALTH, waftask.TaskHealth)
globalobj.GWAF_RUNTIME_OBJ_WAF_TaskRegistry.RegisterTask(enums.TASK_CLEAR_CC_WINDOWS, waftask.TaskCC)
globalobj.GWAF_RUNTIME_OBJ_WAF_TaskRegistry.RegisterTask(enums.TASK_CREATE_DB_INDEX, waftask.TaskCreateIndex)
globalobj.GWAF_RUNTIME_OBJ_WAF_TaskRegistry.RegisterTask(enums.TASK_CLEAR_WEBCACHE, waftask.TaskClearWebcache)
globalobj.GWAF_RUNTIME_OBJ_WAF_TaskRegistry.RegisterTask(enums.TASK_GC, waftask.TaskGC)
globalobj.GWAF_RUNTIME_OBJ_WAF_TaskRegistry.RegisterTask(enums.TASK_STATS_PUSH, waftask.TaskStatsPush)
globalobj.GWAF_RUNTIME_OBJ_WAF_TaskRegistry.RegisterTask(enums.TASK_DB_MONITOR, waftask.TaskDatabaseMonitor)
globalobj.GWAF_RUNTIME_OBJ_WAF_TaskRegistry.RegisterTask(enums.TASK_FIREWALL_CLEAN_EXPIRED, waftask.TaskFirewallCleanExpired)
go waftask.TaskShareDbInfo()
@@ -661,10 +646,6 @@ func (m *wafSystenService) run() {
zlog.Debug("定时清空CCip", clearCcWindowsIp)
globalobj.GWAF_RUNTIME_OBJ_WAF_ENGINE.ClearCcWindowsForIP(clearCcWindowsIp)
break
case createLogIndex := <-global.GWAF_CHAN_CREATE_LOG_INDEX:
zlog.Debug("定时创建日志索引", createLogIndex)
waftask.TaskCreateIndexByDbName(enums.DB_LOG)
break
}
}

View File

@@ -0,0 +1,22 @@
package model
import (
"SamWaf/model/baseorm"
)
// FirewallIPBlock 防火墙IP封禁记录表操作系统级别的防火墙封禁
type FirewallIPBlock struct {
baseorm.BaseOrm
HostCode string `json:"host_code"` // 网站唯一码(主要键)
IP string `json:"ip"` // 被封禁的IP地址支持单个IP或CIDR格式
Reason string `json:"reason"` // 封禁原因
BlockType string `json:"block_type"` // 封禁类型manual-手动封禁, auto-自动封禁, temp-临时封禁
Status string `json:"status"` // 状态active-已生效, inactive-已失效, pending-待生效
ExpireTime int64 `json:"expire_time"` // 过期时间时间戳0表示永久
Remarks string `json:"remarks"` // 备注
}
// TableName 表名
func (FirewallIPBlock) TableName() string {
return "firewall_ip_block"
}

View File

@@ -0,0 +1,78 @@
package request
import "SamWaf/model/common/request"
// WafFirewallIPBlockAddReq 添加防火墙IP封禁请求
type WafFirewallIPBlockAddReq struct {
HostCode string `json:"host_code"` // 网站唯一码(主要键)
IP string `json:"ip" binding:"required"` // 被封禁的IP地址
Reason string `json:"reason"` // 封禁原因
BlockType string `json:"block_type"` // 封禁类型manual-手动封禁, auto-自动封禁, temp-临时封禁
ExpireTime int64 `json:"expire_time"` // 过期时间时间戳0表示永久
Remarks string `json:"remarks"` // 备注
}
// WafFirewallIPBlockEditReq 编辑防火墙IP封禁请求
type WafFirewallIPBlockEditReq struct {
Id string `json:"id" binding:"required"` // 唯一键
HostCode string `json:"host_code"` // 网站唯一码(主要键)
IP string `json:"ip" binding:"required"` // 被封禁的IP地址
Reason string `json:"reason"` // 封禁原因
BlockType string `json:"block_type"` // 封禁类型
Status string `json:"status"` // 状态
ExpireTime int64 `json:"expire_time"` // 过期时间时间戳0表示永久
Remarks string `json:"remarks"` // 备注
}
// WafFirewallIPBlockDelReq 删除防火墙IP封禁请求
type WafFirewallIPBlockDelReq struct {
Id string `json:"id" form:"id" binding:"required"` // 唯一键
}
// WafFirewallIPBlockSearchReq 搜索防火墙IP封禁请求
type WafFirewallIPBlockSearchReq struct {
HostCode string `json:"host_code"` // 主机码
IP string `json:"ip"` // IP地址
Reason string `json:"reason"` // 封禁原因
BlockType string `json:"block_type"` // 封禁类型
Status string `json:"status"` // 状态
request.PageInfo
}
// WafFirewallIPBlockDetailReq 获取防火墙IP封禁详情请求
type WafFirewallIPBlockDetailReq struct {
Id string `json:"id" form:"id" binding:"required"` // 唯一键
}
// WafFirewallIPBlockBatchDelReq 批量删除防火墙IP封禁请求
type WafFirewallIPBlockBatchDelReq struct {
Ids []string `json:"ids" binding:"required"` // 唯一键数组
}
// WafFirewallIPBlockBatchAddReq 批量添加防火墙IP封禁请求
type WafFirewallIPBlockBatchAddReq struct {
HostCode string `json:"host_code"` // 网站唯一码
IPs []string `json:"ips" binding:"required"` // IP地址列表
Reason string `json:"reason"` // 封禁原因
BlockType string `json:"block_type"` // 封禁类型
Remarks string `json:"remarks"` // 备注
}
// WafFirewallIPBlockEnableReq 启用防火墙IP封禁请求
type WafFirewallIPBlockEnableReq struct {
Id string `json:"id" binding:"required"` // 唯一键
}
// WafFirewallIPBlockDisableReq 禁用防火墙IP封禁请求
type WafFirewallIPBlockDisableReq struct {
Id string `json:"id" binding:"required"` // 唯一键
}
// WafFirewallIPBlockSyncReq 同步防火墙规则请求
type WafFirewallIPBlockSyncReq struct {
HostCode string `json:"host_code"` // 网站唯一码(可选,为空则同步所有)
}
// WafFirewallIPBlockClearExpiredReq 清理过期规则请求
type WafFirewallIPBlockClearExpiredReq struct {
}

View File

@@ -48,6 +48,7 @@ type ApiGroup struct {
NotifyChannelRouter
NotifySubscriptionRouter
NotifyLogRouter
FirewallIPBlockRouter
}
type PublicApiGroup struct {
LoginRouter

View File

@@ -0,0 +1,35 @@
package router
import (
"SamWaf/api"
"github.com/gin-gonic/gin"
)
type FirewallIPBlockRouter struct {
}
func (receiver *FirewallIPBlockRouter) InitFirewallIPBlockRouter(group *gin.RouterGroup) {
api := api.APIGroupAPP.WafFirewallIPBlockApi
router := group.Group("")
// 基础CRUD
router.POST("/samwaf/firewall/ipblock/list", api.GetListApi) // 获取列表
router.GET("/samwaf/firewall/ipblock/detail", api.GetDetailApi) // 获取详情
router.POST("/samwaf/firewall/ipblock/add", api.AddApi) // 添加
router.GET("/samwaf/firewall/ipblock/del", api.DelApi) // 删除
router.POST("/samwaf/firewall/ipblock/edit", api.ModifyApi) // 编辑
// 批量操作
router.POST("/samwaf/firewall/ipblock/batch/add", api.BatchAddApi) // 批量添加
router.POST("/samwaf/firewall/ipblock/batch/del", api.BatchDelApi) // 批量删除
// 启用/禁用
router.POST("/samwaf/firewall/ipblock/enable", api.EnableApi) // 启用
router.POST("/samwaf/firewall/ipblock/disable", api.DisableApi) // 禁用
// 高级功能
router.POST("/samwaf/firewall/ipblock/sync", api.SyncApi) // 同步规则
router.POST("/samwaf/firewall/ipblock/clear/expired", api.ClearExpiredApi) // 清理过期
router.GET("/samwaf/firewall/ipblock/statistics", api.GetStatisticsApi) // 统计信息
}

View File

@@ -0,0 +1,433 @@
package waf_service
import (
"SamWaf/common/uuid"
"SamWaf/customtype"
"SamWaf/firewall"
"SamWaf/global"
"SamWaf/model"
"SamWaf/model/baseorm"
"SamWaf/model/request"
"errors"
"fmt"
"time"
)
type WafFirewallIPBlockService struct {
fw firewall.FireWallEngine
}
var WafFirewallIPBlockServiceApp = new(WafFirewallIPBlockService)
// AddApi 添加防火墙IP封禁自动调用系统防火墙
func (receiver *WafFirewallIPBlockService) AddApi(req request.WafFirewallIPBlockAddReq) error {
// 1. 检查IP是否已存在
var existBean model.FirewallIPBlock
err := global.GWAF_LOCAL_DB.Where("ip = ? AND status = ?", req.IP, "active").First(&existBean).Error
if err == nil && existBean.Id != "" {
return errors.New("该IP已经被封禁")
}
// 2. 调用防火墙封禁IP
err = receiver.fw.BlockIP(req.IP, req.Reason)
if err != nil {
return fmt.Errorf("防火墙封禁失败: %v", err)
}
// 3. 保存到数据库
var bean = &model.FirewallIPBlock{
BaseOrm: baseorm.BaseOrm{
Id: uuid.GenUUID(),
USER_CODE: global.GWAF_USER_CODE,
Tenant_ID: global.GWAF_TENANT_ID,
CREATE_TIME: customtype.JsonTime(time.Now()),
UPDATE_TIME: customtype.JsonTime(time.Now()),
},
HostCode: req.HostCode,
IP: req.IP,
Reason: req.Reason,
BlockType: req.BlockType,
Status: "active",
ExpireTime: req.ExpireTime,
Remarks: req.Remarks,
}
// 如果没有指定封禁类型,默认为手动封禁
if bean.BlockType == "" {
bean.BlockType = "manual"
}
err = global.GWAF_LOCAL_DB.Create(bean).Error
if err != nil {
// 如果数据库保存失败,回滚防火墙规则
receiver.fw.UnblockIP(req.IP)
return fmt.Errorf("保存到数据库失败: %v", err)
}
return nil
}
// CheckIsExistApi 检查IP是否已存在
func (receiver *WafFirewallIPBlockService) CheckIsExistApi(req request.WafFirewallIPBlockAddReq) error {
return global.GWAF_LOCAL_DB.First(&model.FirewallIPBlock{}, "ip = ? AND status = ?", req.IP, "active").Error
}
// ModifyApi 修改防火墙IP封禁
func (receiver *WafFirewallIPBlockService) ModifyApi(req request.WafFirewallIPBlockEditReq) error {
// 1. 获取原记录
var oldBean model.FirewallIPBlock
err := global.GWAF_LOCAL_DB.Where("id = ?", req.Id).First(&oldBean).Error
if err != nil {
return errors.New("记录不存在")
}
// 2. 如果IP地址改变了需要更新防火墙规则
if oldBean.IP != req.IP {
// 删除旧的防火墙规则
if oldBean.Status == "active" {
receiver.fw.UnblockIP(oldBean.IP)
}
// 添加新的防火墙规则
err = receiver.fw.BlockIP(req.IP, req.Reason)
if err != nil {
// 如果新规则添加失败,恢复旧规则
if oldBean.Status == "active" {
receiver.fw.BlockIP(oldBean.IP, oldBean.Reason)
}
return fmt.Errorf("更新防火墙规则失败: %v", err)
}
}
// 3. 更新数据库
beanMap := map[string]interface{}{
"HostCode": req.HostCode,
"IP": req.IP,
"Reason": req.Reason,
"BlockType": req.BlockType,
"Status": req.Status,
"ExpireTime": req.ExpireTime,
"Remarks": req.Remarks,
"UPDATE_TIME": customtype.JsonTime(time.Now()),
}
err = global.GWAF_LOCAL_DB.Model(&model.FirewallIPBlock{}).Where("id = ?", req.Id).Updates(beanMap).Error
if err != nil {
// 如果数据库更新失败,回滚防火墙规则
if oldBean.IP != req.IP {
receiver.fw.UnblockIP(req.IP)
if oldBean.Status == "active" {
receiver.fw.BlockIP(oldBean.IP, oldBean.Reason)
}
}
return fmt.Errorf("更新数据库失败: %v", err)
}
// 4. 如果状态改变,更新防火墙
if oldBean.Status != req.Status {
if req.Status == "active" {
receiver.fw.BlockIP(req.IP, req.Reason)
} else {
receiver.fw.UnblockIP(req.IP)
}
}
return nil
}
// GetDetailApi 获取防火墙IP封禁详情
func (receiver *WafFirewallIPBlockService) GetDetailApi(req request.WafFirewallIPBlockDetailReq) model.FirewallIPBlock {
var bean model.FirewallIPBlock
global.GWAF_LOCAL_DB.Where("id=?", req.Id).Find(&bean)
return bean
}
// GetDetailByIdApi 根据ID获取详情
func (receiver *WafFirewallIPBlockService) GetDetailByIdApi(id string) model.FirewallIPBlock {
var bean model.FirewallIPBlock
global.GWAF_LOCAL_DB.Where("id=?", id).Find(&bean)
return bean
}
// GetListApi 获取防火墙IP封禁列表
func (receiver *WafFirewallIPBlockService) GetListApi(req request.WafFirewallIPBlockSearchReq) ([]model.FirewallIPBlock, int64, error) {
var list []model.FirewallIPBlock
var total int64 = 0
// 构建查询条件
query := global.GWAF_LOCAL_DB.Model(&model.FirewallIPBlock{})
if len(req.HostCode) > 0 {
query = query.Where("host_code = ?", req.HostCode)
}
if len(req.IP) > 0 {
query = query.Where("ip LIKE ?", "%"+req.IP+"%")
}
if len(req.Reason) > 0 {
query = query.Where("reason LIKE ?", "%"+req.Reason+"%")
}
if len(req.BlockType) > 0 {
query = query.Where("block_type = ?", req.BlockType)
}
if len(req.Status) > 0 {
query = query.Where("status = ?", req.Status)
}
// 统计总数
query.Count(&total)
// 分页查询
err := query.Order("create_time DESC").
Limit(req.PageSize).
Offset(req.PageSize * (req.PageIndex - 1)).
Find(&list).Error
return list, total, err
}
// DelApi 删除防火墙IP封禁同时删除系统防火墙规则
func (receiver *WafFirewallIPBlockService) DelApi(req request.WafFirewallIPBlockDelReq) error {
// 1. 获取记录
var bean model.FirewallIPBlock
err := global.GWAF_LOCAL_DB.Where("id = ?", req.Id).First(&bean).Error
if err != nil {
return errors.New("记录不存在")
}
// 2. 删除防火墙规则
if bean.Status == "active" {
err = receiver.fw.UnblockIP(bean.IP)
if err != nil {
return fmt.Errorf("删除防火墙规则失败: %v", err)
}
}
// 3. 删除数据库记录
err = global.GWAF_LOCAL_DB.Where("id = ?", req.Id).Delete(&model.FirewallIPBlock{}).Error
return err
}
// BatchDelApi 批量删除防火墙IP封禁
func (receiver *WafFirewallIPBlockService) BatchDelApi(req request.WafFirewallIPBlockBatchDelReq) error {
if len(req.Ids) == 0 {
return errors.New("删除ID列表不能为空")
}
// 1. 获取所有要删除的记录
var beans []model.FirewallIPBlock
err := global.GWAF_LOCAL_DB.Where("id IN ?", req.Ids).Find(&beans).Error
if err != nil {
return err
}
// 2. 批量删除防火墙规则
var ipsToUnblock []string
for _, bean := range beans {
if bean.Status == "active" {
ipsToUnblock = append(ipsToUnblock, bean.IP)
}
}
if len(ipsToUnblock) > 0 {
successCount, failedIPs, _ := receiver.fw.UnblockIPList(ipsToUnblock)
if len(failedIPs) > 0 {
return fmt.Errorf("部分IP解除封禁失败成功%d个失败%d个", successCount, len(failedIPs))
}
}
// 3. 批量删除数据库记录
err = global.GWAF_LOCAL_DB.Where("id IN ?", req.Ids).Delete(&model.FirewallIPBlock{}).Error
return err
}
// BatchAddApi 批量添加防火墙IP封禁
func (receiver *WafFirewallIPBlockService) BatchAddApi(req request.WafFirewallIPBlockBatchAddReq) (successCount int, failedIPs []string, err error) {
if len(req.IPs) == 0 {
return 0, nil, errors.New("IP列表不能为空")
}
successCount = 0
failedIPs = []string{}
for _, ip := range req.IPs {
addReq := request.WafFirewallIPBlockAddReq{
HostCode: req.HostCode,
IP: ip,
Reason: req.Reason,
BlockType: req.BlockType,
Remarks: req.Remarks,
}
err := receiver.AddApi(addReq)
if err != nil {
failedIPs = append(failedIPs, ip)
} else {
successCount++
}
}
if len(failedIPs) > 0 {
return successCount, failedIPs, fmt.Errorf("部分IP封禁失败")
}
return successCount, failedIPs, nil
}
// EnableApi 启用防火墙IP封禁
func (receiver *WafFirewallIPBlockService) EnableApi(req request.WafFirewallIPBlockEnableReq) error {
// 1. 获取记录
var bean model.FirewallIPBlock
err := global.GWAF_LOCAL_DB.Where("id = ?", req.Id).First(&bean).Error
if err != nil {
return errors.New("记录不存在")
}
if bean.Status == "active" {
return errors.New("该IP已经处于封禁状态")
}
// 2. 添加防火墙规则
err = receiver.fw.BlockIP(bean.IP, bean.Reason)
if err != nil {
return fmt.Errorf("启用防火墙规则失败: %v", err)
}
// 3. 更新数据库状态
err = global.GWAF_LOCAL_DB.Model(&model.FirewallIPBlock{}).
Where("id = ?", req.Id).
Updates(map[string]interface{}{
"Status": "active",
"UPDATE_TIME": customtype.JsonTime(time.Now()),
}).Error
return err
}
// DisableApi 禁用防火墙IP封禁
func (receiver *WafFirewallIPBlockService) DisableApi(req request.WafFirewallIPBlockDisableReq) error {
// 1. 获取记录
var bean model.FirewallIPBlock
err := global.GWAF_LOCAL_DB.Where("id = ?", req.Id).First(&bean).Error
if err != nil {
return errors.New("记录不存在")
}
if bean.Status == "inactive" {
return errors.New("该IP已经处于未封禁状态")
}
// 2. 删除防火墙规则
err = receiver.fw.UnblockIP(bean.IP)
if err != nil {
return fmt.Errorf("禁用防火墙规则失败: %v", err)
}
// 3. 更新数据库状态
err = global.GWAF_LOCAL_DB.Model(&model.FirewallIPBlock{}).
Where("id = ?", req.Id).
Updates(map[string]interface{}{
"Status": "inactive",
"UPDATE_TIME": customtype.JsonTime(time.Now()),
}).Error
return err
}
// SyncFirewallRules 同步防火墙规则(从数据库恢复到系统防火墙)
func (receiver *WafFirewallIPBlockService) SyncFirewallRules(hostCode string) (successCount int, failedCount int, err error) {
// 1. 获取所有active状态的记录
var beans []model.FirewallIPBlock
query := global.GWAF_LOCAL_DB.Where("status = ?", "active")
if hostCode != "" {
query = query.Where("host_code = ?", hostCode)
}
err = query.Find(&beans).Error
if err != nil {
return 0, 0, err
}
// 2. 批量添加到防火墙
var ips []string
for _, bean := range beans {
ips = append(ips, bean.IP)
}
if len(ips) > 0 {
successCount, failedIPs, _ := receiver.fw.BlockIPList(ips)
failedCount = len(failedIPs)
return successCount, failedCount, nil
}
return 0, 0, nil
}
// ClearExpiredRules 清理过期的封禁规则
func (receiver *WafFirewallIPBlockService) ClearExpiredRules() (int, error) {
// 1. 查找所有过期的记录ExpireTime > 0 且 < 当前时间)
currentTime := time.Now().Unix()
var beans []model.FirewallIPBlock
err := global.GWAF_LOCAL_DB.Where("expire_time > 0 AND expire_time < ? AND status = ?", currentTime, "active").
Find(&beans).Error
if err != nil {
return 0, err
}
count := 0
for _, bean := range beans {
// 删除防火墙规则
err := receiver.fw.UnblockIP(bean.IP)
if err == nil {
// 更新数据库状态为inactive
global.GWAF_LOCAL_DB.Model(&model.FirewallIPBlock{}).
Where("id = ?", bean.Id).
Updates(map[string]interface{}{
"Status": "inactive",
"UPDATE_TIME": customtype.JsonTime(time.Now()),
})
count++
}
}
return count, nil
}
// GetHostCodesByIds 根据ID列表获取对应的HostCode列表
func (receiver *WafFirewallIPBlockService) GetHostCodesByIds(ids []string) ([]string, error) {
var hostCodes []string
err := global.GWAF_LOCAL_DB.Model(&model.FirewallIPBlock{}).
Where("id IN ?", ids).
Distinct("host_code").
Pluck("host_code", &hostCodes).Error
return hostCodes, err
}
// GetAllActiveIPs 获取所有active状态的IP列表
func (receiver *WafFirewallIPBlockService) GetAllActiveIPs() ([]string, error) {
var ips []string
err := global.GWAF_LOCAL_DB.Model(&model.FirewallIPBlock{}).
Where("status = ?", "active").
Pluck("ip", &ips).Error
return ips, err
}
// GetStatistics 获取统计信息
func (receiver *WafFirewallIPBlockService) GetStatistics() map[string]interface{} {
var total, active, inactive, expired int64
global.GWAF_LOCAL_DB.Model(&model.FirewallIPBlock{}).Count(&total)
global.GWAF_LOCAL_DB.Model(&model.FirewallIPBlock{}).Where("status = ?", "active").Count(&active)
global.GWAF_LOCAL_DB.Model(&model.FirewallIPBlock{}).Where("status = ?", "inactive").Count(&inactive)
currentTime := time.Now().Unix()
global.GWAF_LOCAL_DB.Model(&model.FirewallIPBlock{}).
Where("expire_time > 0 AND expire_time < ?", currentTime).Count(&expired)
return map[string]interface{}{
"total": total,
"active": active,
"inactive": inactive,
"expired": expired,
}
}

View File

@@ -187,6 +187,37 @@ func RunCoreDBMigrations(db *gorm.DB) error {
)
},
},
// 迁移4: 创建防火墙IP封禁表
{
ID: "202511280001_add_firewall_ip_block_table",
Migrate: func(tx *gorm.DB) error {
zlog.Info("迁移 202511280001: 创建防火墙IP封禁表")
// 创建防火墙IP封禁表
if err := tx.AutoMigrate(
&model.FirewallIPBlock{},
); err != nil {
return fmt.Errorf("创建防火墙IP封禁表失败: %w", err)
}
// 创建索引
if err := tx.Exec("CREATE INDEX IF NOT EXISTS idx_firewall_ip_block_ip ON firewall_ip_block(ip)").Error; err != nil {
zlog.Warn("创建索引 idx_firewall_ip_block_ip 失败", "error", err.Error())
}
if err := tx.Exec("CREATE INDEX IF NOT EXISTS idx_firewall_ip_block_status ON firewall_ip_block(status)").Error; err != nil {
zlog.Warn("创建索引 idx_firewall_ip_block_status 失败", "error", err.Error())
}
if err := tx.Exec("CREATE INDEX IF NOT EXISTS idx_firewall_ip_block_expire_time ON firewall_ip_block(expire_time)").Error; err != nil {
zlog.Warn("创建索引 idx_firewall_ip_block_expire_time 失败", "error", err.Error())
}
zlog.Info("防火墙IP封禁表创建成功")
return nil
},
Rollback: func(tx *gorm.DB) error {
zlog.Info("回滚 202511280001: 删除防火墙IP封禁表")
return tx.Migrator().DropTable(&model.FirewallIPBlock{})
},
},
})
// 执行迁移

View File

@@ -90,6 +90,7 @@ func (web *WafWebManager) initRouter(r *gin.Engine) {
router.ApiGroupApp.InitNotifyChannelRouter(RouterGroup)
router.ApiGroupApp.InitNotifySubscriptionRouter(RouterGroup)
router.ApiGroupApp.InitNotifyLogRouter(RouterGroup)
router.ApiGroupApp.InitFirewallIPBlockRouter(RouterGroup)
}
if global.GWAF_RELEASE == "true" {

View File

@@ -71,6 +71,13 @@ func InitTaskDb() []model.Task {
TaskAt: "",
TaskMethod: enums.TASK_CLEAR_CC_WINDOWS,
})
syncTaskToDb(model.Task{
TaskName: "每5分钟进行防火墙IP封禁规则清理",
TaskUnit: enums.TASK_MIN,
TaskValue: 5,
TaskAt: "",
TaskMethod: enums.TASK_FIREWALL_CLEAN_EXPIRED,
})
syncTaskToDb(model.Task{
TaskName: "每天30分钟删除历史下载文件",
TaskUnit: enums.TASK_MIN,
@@ -133,13 +140,7 @@ func InitTaskDb() []model.Task {
TaskAt: "03:00",
TaskMethod: enums.TASK_SSL_PATH_LOAD,
})
syncTaskToDb(model.Task{
TaskName: "每天04:00进行索引创建",
TaskUnit: enums.TASK_DAY,
TaskValue: 1,
TaskAt: "04:00",
TaskMethod: enums.TASK_CREATE_DB_INDEX,
})
syncTaskToDb(model.Task{
TaskName: "每天05:00进行批量任务",
TaskUnit: enums.TASK_DAY,

View File

@@ -1,60 +0,0 @@
package waftask
import (
"SamWaf/common/zlog"
"SamWaf/enums"
)
// TaskCreateIndex 创建索引
func TaskCreateIndex() {
//主库索引创建
createMainDbIndex()
//日志库索引创建
createLogDbIndex()
//统计库索引创建
createStatDbIndex()
}
// TaskCreateIndexByDbName 创建索引通过数据库名称
func TaskCreateIndexByDbName(dbName string) {
//主库索引创建
if dbName == enums.DB_MAIN {
createMainDbIndex()
}
//日志库索引创建
if dbName == enums.DB_LOG {
createLogDbIndex()
}
//统计库索引创建
if dbName == enums.DB_STATS {
createStatDbIndex()
}
}
func createMainDbIndex() {
// ============ 已废弃:索引创建已迁移到 gormigrate ============
// 从 2025-11-14 开始core 数据库索引通过 gormigrate 在数据库初始化时自动创建
// ============================================================
zlog.Info("createMainDbIndex 已废弃,索引由 gormigrate 自动管理")
return
}
func createLogDbIndex() {
// ============ 已废弃:索引创建已迁移到 gormigrate ============
// 从 2025-11-14 开始log 数据库索引通过 gormigrate 在数据库初始化时自动创建
// ============================================================
zlog.Info("createLogDbIndex 已废弃,索引由 gormigrate 自动管理")
return
}
func createStatDbIndex() {
// ============ 已废弃:索引创建已迁移到 gormigrate ============
// 从 2025-11-11 开始stats 数据库索引通过 gormigrate 在数据库初始化时自动创建
// ============================================================
zlog.Info("createStatDbIndex 已废弃,索引由 gormigrate 自动管理")
return
}

View File

@@ -1,247 +0,0 @@
package waftask
import (
"SamWaf/common/zlog"
"SamWaf/global"
"SamWaf/wafdb"
"sync"
"testing"
"time"
)
// TestCreateIndexWithConcurrentOperations 测试在创建索引的同时进行读写操作
func TestCreateIndexWithConcurrentOperations(t *testing.T) {
//初始化日志
zlog.InitZLog(global.GWAF_LOG_DEBUG_ENABLE, "json")
//初始化本地数据库
wafdb.InitCoreDb("../")
wafdb.InitLogDb("../")
wafdb.InitStatsDb("../")
// 确保数据库连接已初始化
if global.GWAF_LOCAL_DB == nil || global.GWAF_LOCAL_LOG_DB == nil || global.GWAF_LOCAL_STATS_DB == nil {
t.Skip("数据库连接未初始化,跳过测试")
}
var wg sync.WaitGroup
// 用于标记测试是否通过
success := true
errChan := make(chan error, 10)
// 启动索引创建任务
wg.Add(1)
go func() {
defer wg.Done()
t.Log("开始创建索引...")
startTime := time.Now()
// 执行索引创建
TaskCreateIndex()
duration := time.Since(startTime)
t.Logf("索引创建完成,耗时: %s", duration.String())
}()
// 同时进行主库写入操作
wg.Add(1)
go func() {
defer wg.Done()
db := global.GWAF_LOCAL_DB
if db == nil {
errChan <- nil
return
}
// 模拟多次写入操作
for i := 0; i < 10; i++ {
// 插入测试数据
err := db.Exec("INSERT INTO ip_tags (user_code, tenant_id, ip, ip_tag) VALUES (?, ?, ?, ?)",
"test_user", "test_tenant", "192.168.1."+time.Now().Format("15.04.05.000"), "test_tag_"+time.Now().Format("15.04.05.000")).Error
if err != nil {
t.Logf("主库写入失败: %v", err)
errChan <- err
return
}
time.Sleep(100 * time.Millisecond)
}
t.Log("主库写入测试完成")
}()
// 同时进行主库读取操作
wg.Add(1)
go func() {
defer wg.Done()
db := global.GWAF_LOCAL_DB
if db == nil {
errChan <- nil
return
}
// 模拟多次读取操作
for i := 0; i < 10; i++ {
var count int64
err := db.Table("ip_tags").Where("user_code = ?", "test_user").Count(&count).Error
if err != nil {
t.Logf("主库读取失败: %v", err)
errChan <- err
return
}
t.Logf("主库读取成功,记录数: %d", count)
time.Sleep(100 * time.Millisecond)
}
t.Log("主库读取测试完成")
}()
// 同时进行日志库写入操作
wg.Add(1)
go func() {
defer wg.Done()
db := global.GWAF_LOCAL_LOG_DB
if db == nil {
errChan <- nil
return
}
// 模拟多次写入操作
for i := 0; i < 10; i++ {
// 插入测试数据
err := db.Exec("INSERT INTO web_logs (REQ_UUID, tenant_id, user_code, src_ip, unix_add_time, task_flag) VALUES (?, ?, ?, ?, ?, ?)",
"test_uuid_"+time.Now().Format("15.04.05.000"), "test_tenant", "test_user", "192.168.1.1",
time.Now().Unix(), 0).Error
if err != nil {
t.Logf("日志库写入失败: %v", err)
errChan <- err
return
}
time.Sleep(100 * time.Millisecond)
}
t.Log("日志库写入测试完成")
}()
// 同时进行日志库读取操作
wg.Add(1)
go func() {
defer wg.Done()
db := global.GWAF_LOCAL_LOG_DB
if db == nil {
errChan <- nil
return
}
// 模拟多次读取操作
for i := 0; i < 10; i++ {
var count int64
err := db.Table("web_logs").Where("user_code = ?", "test_user").Count(&count).Error
if err != nil {
t.Logf("日志库读取失败: %v", err)
errChan <- err
return
}
t.Logf("日志库读取成功,记录数: %d", count)
time.Sleep(100 * time.Millisecond)
}
t.Log("日志库读取测试完成")
}()
// 同时进行统计库写入操作
wg.Add(1)
go func() {
defer wg.Done()
db := global.GWAF_LOCAL_STATS_DB
if db == nil {
errChan <- nil
return
}
// 模拟多次写入操作
for i := 0; i < 10; i++ {
// 插入测试数据
err := db.Exec("INSERT INTO stats_days (tenant_id, user_code, host_code, type, day, count) VALUES (?, ?, ?, ?, ?, ?)",
"test_tenant", "test_user", "test_host", 1, time.Now().Format("20060102"), i).Error
if err != nil {
t.Logf("统计库写入失败: %v", err)
errChan <- err
return
}
time.Sleep(100 * time.Millisecond)
}
t.Log("统计库写入测试完成")
}()
// 同时进行统计库读取操作
wg.Add(1)
go func() {
defer wg.Done()
db := global.GWAF_LOCAL_STATS_DB
if db == nil {
errChan <- nil
return
}
// 模拟多次读取操作
for i := 0; i < 10; i++ {
var count int64
err := db.Table("stats_days").Where("user_code = ?", "test_user").Count(&count).Error
if err != nil {
t.Logf("统计库读取失败: %v", err)
errChan <- err
return
}
t.Logf("统计库读取成功,记录数: %d", count)
time.Sleep(100 * time.Millisecond)
}
t.Log("统计库读取测试完成")
}()
// 等待所有操作完成
wg.Wait()
close(errChan)
// 检查是否有错误发生
for err := range errChan {
if err != nil {
success = false
t.Errorf("测试过程中发生错误: %v", err)
}
}
if success {
t.Log("测试通过:创建索引过程不影响数据库的读写操作")
} else {
t.Error("测试失败:创建索引过程影响了数据库的读写操作")
}
// 清理测试数据
cleanupTestData(t)
}
// cleanupTestData 清理测试数据
func cleanupTestData(t *testing.T) {
t.Log("开始清理测试数据...")
// 清理主库测试数据
if global.GWAF_LOCAL_DB != nil {
err := global.GWAF_LOCAL_DB.Exec("DELETE FROM ip_tags WHERE user_code = ?", "test_user").Error
if err != nil {
t.Logf("清理主库测试数据失败: %v", err)
}
}
// 清理日志库测试数据
if global.GWAF_LOCAL_LOG_DB != nil {
err := global.GWAF_LOCAL_LOG_DB.Exec("DELETE FROM web_logs WHERE user_code = ?", "test_user").Error
if err != nil {
t.Logf("清理日志库测试数据失败: %v", err)
}
}
// 清理统计库测试数据
if global.GWAF_LOCAL_STATS_DB != nil {
err := global.GWAF_LOCAL_STATS_DB.Exec("DELETE FROM stats_days WHERE user_code = ?", "test_user").Error
if err != nil {
t.Logf("清理统计库测试数据失败: %v", err)
}
}
t.Log("测试数据清理完成")
}

View File

@@ -134,7 +134,6 @@ func TaskShareDbInfo() {
global.GWAF_LOCAL_DB.Create(sharDbBean)
global.GWAF_LOCAL_LOG_DB = nil
wafdb.InitLogDb("")
createLogDbIndex() //重新创建索引
global.GDATA_CURRENT_CHANGE = false
zlog.Info(innerLogName, "切库完成...")
}

View File

@@ -0,0 +1,32 @@
package waftask
import (
"SamWaf/common/zlog"
"SamWaf/service/waf_service"
)
var (
wafFirewallIPBlockService = waf_service.WafFirewallIPBlockServiceApp
)
// TaskFirewallCleanExpired 清理过期的防火墙IP封禁规则
func TaskFirewallCleanExpired() {
innerLogName := "TaskFirewallCleanExpired"
zlog.Info(innerLogName, "开始清理过期的防火墙IP封禁规则")
// 调用清理服务
count, err := wafFirewallIPBlockService.ClearExpiredRules()
if err != nil {
zlog.Error(innerLogName, "清理过期规则失败", "error", err.Error())
return
}
if count > 0 {
zlog.Info(innerLogName, "清理过期规则完成",
"清理数量", count,
"说明", "已自动清理过期的防火墙IP封禁规则并从系统防火墙中移除")
} else {
zlog.Debug(innerLogName, "无过期规则需要清理")
}
}