mirror of
https://gitee.com/samwaf/SamWaf.git
synced 2025-12-06 06:58:54 +08:00
216 lines
7.4 KiB
Go
216 lines
7.4 KiB
Go
package waftunnelengine
|
||
|
||
import (
|
||
"SamWaf/common/zlog"
|
||
"SamWaf/model/waftunnelmodel"
|
||
"context"
|
||
"fmt"
|
||
"io"
|
||
"net"
|
||
"strconv"
|
||
"sync"
|
||
"time"
|
||
)
|
||
|
||
// startTCPServer 启动TCP服务器
|
||
func (waf *WafTunnelEngine) startTCPServer(netRuntime waftunnelmodel.NetRunTime) {
|
||
addr := ":" + strconv.Itoa(netRuntime.Port)
|
||
listener, err := net.Listen("tcp", addr)
|
||
if err != nil {
|
||
serverPort := strconv.Itoa(netRuntime.Port)
|
||
zlog.Error(fmt.Sprintf("TCP服务器启动失败 [服务端口:%s 错误:%s]", serverPort, err.Error()))
|
||
return
|
||
}
|
||
|
||
key := "tcp" + strconv.Itoa(netRuntime.Port)
|
||
// 更新状态
|
||
netClone, _ := waf.NetListerOnline.Get(key)
|
||
netClone.Status = 0
|
||
netClone.Svr = listener
|
||
waf.NetListerOnline.Set(key, netClone)
|
||
|
||
serverPort := strconv.Itoa(netRuntime.Port)
|
||
zlog.Info(fmt.Sprintf("启动TCP服务器 [服务端口:%s]", serverPort))
|
||
|
||
for {
|
||
conn, err := listener.Accept()
|
||
if err != nil {
|
||
zlog.Error(fmt.Sprintf("TCP连接接收失败 [服务端口:%s 错误:%s]", serverPort, err.Error()))
|
||
break
|
||
}
|
||
|
||
// 获取隧道配置
|
||
tunnelInfo, ok := waf.TunnelTarget.Get("tcp" + strconv.Itoa(netRuntime.Port))
|
||
if !ok {
|
||
// 获取客户端信息用于日志
|
||
clientAddr := conn.RemoteAddr().String()
|
||
clientIP, clientPort, _ := net.SplitHostPort(clientAddr)
|
||
if clientIP == "" {
|
||
clientIP = clientAddr
|
||
clientPort = "unknown"
|
||
}
|
||
serverPort := strconv.Itoa(netRuntime.Port)
|
||
zlog.Error(fmt.Sprintf("未找到端口对应的隧道配置 [客户端IP:%s 客户端端口:%s 服务端口:%s]",
|
||
clientIP, clientPort, serverPort))
|
||
conn.Close()
|
||
continue
|
||
}
|
||
|
||
// 检查入站连接数限制
|
||
if tunnelInfo.Tunnel.MaxInConnect > 0 {
|
||
inConnCount := waf.TCPConnections.GetPortConnsCountByType(netRuntime.Port, waftunnelmodel.ConnTypeSource)
|
||
if inConnCount >= tunnelInfo.Tunnel.MaxInConnect {
|
||
// 获取客户端信息用于日志
|
||
clientAddr := conn.RemoteAddr().String()
|
||
clientIP, clientPort, _ := net.SplitHostPort(clientAddr)
|
||
if clientIP == "" {
|
||
clientIP = clientAddr
|
||
clientPort = "unknown"
|
||
}
|
||
serverPort := strconv.Itoa(netRuntime.Port)
|
||
zlog.Warn(fmt.Sprintf("TCP入站连接数超过限制 [客户端IP:%s 客户端端口:%s 服务端口:%s 当前连接数:%d 最大限制:%d]",
|
||
clientIP, clientPort, serverPort, inConnCount, tunnelInfo.Tunnel.MaxInConnect))
|
||
conn.Close()
|
||
continue
|
||
}
|
||
}
|
||
|
||
// 处理连接
|
||
go waf.handleTCPConnection(conn, netRuntime.Port)
|
||
}
|
||
|
||
zlog.Info(fmt.Sprintf("TCP服务器关闭 [服务端口:%s]", serverPort))
|
||
}
|
||
|
||
// handleTCPConnection 处理TCP连接
|
||
func (waf *WafTunnelEngine) handleTCPConnection(clientConn net.Conn, port int) {
|
||
// 获取客户端IP和端口
|
||
clientAddr := clientConn.RemoteAddr().String()
|
||
clientIP, clientPort, _ := net.SplitHostPort(clientAddr)
|
||
if clientIP == "" {
|
||
clientIP = clientAddr
|
||
clientPort = "unknown"
|
||
}
|
||
serverPort := strconv.Itoa(port)
|
||
|
||
// 获取隧道配置
|
||
tunnelInfo, ok := waf.TunnelTarget.Get("tcp" + strconv.Itoa(port))
|
||
if !ok {
|
||
zlog.Error(fmt.Sprintf("未找到端口对应的隧道配置 [客户端IP:%s 客户端端口:%s 服务端口:%s]", clientIP, clientPort, serverPort))
|
||
clientConn.Close()
|
||
return
|
||
}
|
||
|
||
// 检查IP访问权限
|
||
if !CheckIPAccess("TCP", clientIP, clientPort, serverPort, tunnelInfo.Tunnel) {
|
||
zlog.Warn(fmt.Sprintf("TCP连接被拒绝 [客户端IP:%s 客户端端口:%s 服务端口:%s]", clientIP, clientPort, serverPort))
|
||
clientConn.Close()
|
||
return
|
||
}
|
||
|
||
// 将客户端连接添加到活动连接列表,标记为来源连接
|
||
waf.TCPConnections.AddConn(port, clientConn, waftunnelmodel.ConnTypeSource)
|
||
defer func() {
|
||
clientConn.Close()
|
||
waf.TCPConnections.RemoveConn(port, clientConn)
|
||
}()
|
||
|
||
// 检查出站连接数限制
|
||
if tunnelInfo.Tunnel.MaxOutConnect > 0 {
|
||
outConnCount := waf.TCPConnections.GetPortConnsCountByType(port, waftunnelmodel.ConnTypeTarget)
|
||
if outConnCount >= tunnelInfo.Tunnel.MaxOutConnect {
|
||
zlog.Warn(fmt.Sprintf("TCP出站连接数超过限制 [客户端IP:%s 客户端端口:%s 服务端口:%s 当前连接数:%d 最大限制:%d]",
|
||
clientIP, clientPort, serverPort, outConnCount, tunnelInfo.Tunnel.MaxOutConnect))
|
||
return
|
||
}
|
||
}
|
||
|
||
// 连接到目标服务器
|
||
targetAddr := tunnelInfo.Tunnel.RemoteIp + ":" + strconv.Itoa(tunnelInfo.Tunnel.RemotePort)
|
||
targetConn, err := net.Dial("tcp", targetAddr)
|
||
if err != nil {
|
||
zlog.Error(fmt.Sprintf("连接目标服务器失败 [客户端IP:%s 客户端端口:%s 服务端口:%s 目标地址:%s 错误:%s]",
|
||
clientIP, clientPort, serverPort, targetAddr, err.Error()))
|
||
return
|
||
}
|
||
|
||
// 将目标连接也添加到活动连接列表,标记为目标连接
|
||
waf.TCPConnections.AddConn(port, targetConn, waftunnelmodel.ConnTypeTarget)
|
||
defer func() {
|
||
targetConn.Close()
|
||
waf.TCPConnections.RemoveConn(port, targetConn)
|
||
}()
|
||
|
||
// 设置超时
|
||
if tunnelInfo.Tunnel.ConnTimeout > 0 {
|
||
clientConn.SetDeadline(time.Now().Add(time.Duration(tunnelInfo.Tunnel.ConnTimeout) * time.Second))
|
||
targetConn.SetDeadline(time.Now().Add(time.Duration(tunnelInfo.Tunnel.ConnTimeout) * time.Second))
|
||
}
|
||
|
||
// 设置读取超时
|
||
if tunnelInfo.Tunnel.ReadTimeout > 0 {
|
||
clientConn.SetReadDeadline(time.Now().Add(time.Duration(tunnelInfo.Tunnel.ReadTimeout) * time.Second))
|
||
targetConn.SetReadDeadline(time.Now().Add(time.Duration(tunnelInfo.Tunnel.ReadTimeout) * time.Second))
|
||
}
|
||
|
||
// 设置写入超时
|
||
if tunnelInfo.Tunnel.WriteTimeout > 0 {
|
||
clientConn.SetWriteDeadline(time.Now().Add(time.Duration(tunnelInfo.Tunnel.WriteTimeout) * time.Second))
|
||
targetConn.SetWriteDeadline(time.Now().Add(time.Duration(tunnelInfo.Tunnel.WriteTimeout) * time.Second))
|
||
}
|
||
|
||
// 使用context和WaitGroup管理连接生命周期
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
defer cancel()
|
||
|
||
var wg sync.WaitGroup
|
||
wg.Add(2)
|
||
|
||
// 客户端到目标服务器的数据转发
|
||
go func() {
|
||
defer wg.Done()
|
||
defer cancel() // 任一方向断开时取消context
|
||
|
||
// 使用io.Copy,当连接断开时会自动返回
|
||
_, err := io.Copy(targetConn, clientConn)
|
||
if err != nil {
|
||
zlog.Error(fmt.Sprintf("客户端->目标 数据转发结束 [客户端IP:%s 客户端端口:%s 服务端口:%s 错误:%s]",
|
||
clientIP, clientPort, serverPort, err.Error()))
|
||
} else {
|
||
zlog.Debug(fmt.Sprintf("客户端->目标 数据转发正常结束 [客户端IP:%s 客户端端口:%s 服务端口:%s]",
|
||
clientIP, clientPort, serverPort))
|
||
}
|
||
}()
|
||
|
||
// 目标服务器到客户端的数据转发
|
||
go func() {
|
||
defer wg.Done()
|
||
defer cancel() // 任一方向断开时取消context
|
||
|
||
// 使用io.Copy,当连接断开时会自动返回
|
||
_, err := io.Copy(clientConn, targetConn)
|
||
if err != nil {
|
||
zlog.Error(fmt.Sprintf("目标->客户端 数据转发结束 [客户端IP:%s 客户端端口:%s 服务端口:%s 错误:%s]",
|
||
clientIP, clientPort, serverPort, err.Error()))
|
||
} else {
|
||
zlog.Debug(fmt.Sprintf("目标->客户端 数据转发正常结束 [客户端IP:%s 客户端端口:%s 服务端口:%s]",
|
||
clientIP, clientPort, serverPort))
|
||
}
|
||
}()
|
||
|
||
// 监控context取消,主动关闭连接
|
||
go func() {
|
||
<-ctx.Done()
|
||
// context被取消时,主动关闭两个连接
|
||
clientConn.Close()
|
||
targetConn.Close()
|
||
zlog.Debug(fmt.Sprintf("Context取消,主动关闭连接 [客户端IP:%s 客户端端口:%s 服务端口:%s]",
|
||
clientIP, clientPort, serverPort))
|
||
}()
|
||
|
||
// 等待任一方向的连接断开
|
||
wg.Wait()
|
||
zlog.Info(fmt.Sprintf("TCP连接处理完成 [客户端IP:%s 客户端端口:%s 服务端口:%s]",
|
||
clientIP, clientPort, serverPort))
|
||
}
|