Files
SamWaf/waftunnelengine/tcp_handler.go
2025-11-14 14:38:04 +08:00

216 lines
7.4 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package waftunnelengine
import (
"SamWaf/common/zlog"
"SamWaf/model/waftunnelmodel"
"context"
"fmt"
"io"
"net"
"strconv"
"sync"
"time"
)
// startTCPServer 启动TCP服务器
func (waf *WafTunnelEngine) startTCPServer(netRuntime waftunnelmodel.NetRunTime) {
addr := ":" + strconv.Itoa(netRuntime.Port)
listener, err := net.Listen("tcp", addr)
if err != nil {
serverPort := strconv.Itoa(netRuntime.Port)
zlog.Error(fmt.Sprintf("TCP服务器启动失败 [服务端口:%s 错误:%s]", serverPort, err.Error()))
return
}
key := "tcp" + strconv.Itoa(netRuntime.Port)
// 更新状态
netClone, _ := waf.NetListerOnline.Get(key)
netClone.Status = 0
netClone.Svr = listener
waf.NetListerOnline.Set(key, netClone)
serverPort := strconv.Itoa(netRuntime.Port)
zlog.Info(fmt.Sprintf("启动TCP服务器 [服务端口:%s]", serverPort))
for {
conn, err := listener.Accept()
if err != nil {
zlog.Error(fmt.Sprintf("TCP连接接收失败 [服务端口:%s 错误:%s]", serverPort, err.Error()))
break
}
// 获取隧道配置
tunnelInfo, ok := waf.TunnelTarget.Get("tcp" + strconv.Itoa(netRuntime.Port))
if !ok {
// 获取客户端信息用于日志
clientAddr := conn.RemoteAddr().String()
clientIP, clientPort, _ := net.SplitHostPort(clientAddr)
if clientIP == "" {
clientIP = clientAddr
clientPort = "unknown"
}
serverPort := strconv.Itoa(netRuntime.Port)
zlog.Error(fmt.Sprintf("未找到端口对应的隧道配置 [客户端IP:%s 客户端端口:%s 服务端口:%s]",
clientIP, clientPort, serverPort))
conn.Close()
continue
}
// 检查入站连接数限制
if tunnelInfo.Tunnel.MaxInConnect > 0 {
inConnCount := waf.TCPConnections.GetPortConnsCountByType(netRuntime.Port, waftunnelmodel.ConnTypeSource)
if inConnCount >= tunnelInfo.Tunnel.MaxInConnect {
// 获取客户端信息用于日志
clientAddr := conn.RemoteAddr().String()
clientIP, clientPort, _ := net.SplitHostPort(clientAddr)
if clientIP == "" {
clientIP = clientAddr
clientPort = "unknown"
}
serverPort := strconv.Itoa(netRuntime.Port)
zlog.Warn(fmt.Sprintf("TCP入站连接数超过限制 [客户端IP:%s 客户端端口:%s 服务端口:%s 当前连接数:%d 最大限制:%d]",
clientIP, clientPort, serverPort, inConnCount, tunnelInfo.Tunnel.MaxInConnect))
conn.Close()
continue
}
}
// 处理连接
go waf.handleTCPConnection(conn, netRuntime.Port)
}
zlog.Info(fmt.Sprintf("TCP服务器关闭 [服务端口:%s]", serverPort))
}
// handleTCPConnection 处理TCP连接
func (waf *WafTunnelEngine) handleTCPConnection(clientConn net.Conn, port int) {
// 获取客户端IP和端口
clientAddr := clientConn.RemoteAddr().String()
clientIP, clientPort, _ := net.SplitHostPort(clientAddr)
if clientIP == "" {
clientIP = clientAddr
clientPort = "unknown"
}
serverPort := strconv.Itoa(port)
// 获取隧道配置
tunnelInfo, ok := waf.TunnelTarget.Get("tcp" + strconv.Itoa(port))
if !ok {
zlog.Error(fmt.Sprintf("未找到端口对应的隧道配置 [客户端IP:%s 客户端端口:%s 服务端口:%s]", clientIP, clientPort, serverPort))
clientConn.Close()
return
}
// 检查IP访问权限
if !CheckIPAccess("TCP", clientIP, clientPort, serverPort, tunnelInfo.Tunnel) {
zlog.Warn(fmt.Sprintf("TCP连接被拒绝 [客户端IP:%s 客户端端口:%s 服务端口:%s]", clientIP, clientPort, serverPort))
clientConn.Close()
return
}
// 将客户端连接添加到活动连接列表,标记为来源连接
waf.TCPConnections.AddConn(port, clientConn, waftunnelmodel.ConnTypeSource)
defer func() {
clientConn.Close()
waf.TCPConnections.RemoveConn(port, clientConn)
}()
// 检查出站连接数限制
if tunnelInfo.Tunnel.MaxOutConnect > 0 {
outConnCount := waf.TCPConnections.GetPortConnsCountByType(port, waftunnelmodel.ConnTypeTarget)
if outConnCount >= tunnelInfo.Tunnel.MaxOutConnect {
zlog.Warn(fmt.Sprintf("TCP出站连接数超过限制 [客户端IP:%s 客户端端口:%s 服务端口:%s 当前连接数:%d 最大限制:%d]",
clientIP, clientPort, serverPort, outConnCount, tunnelInfo.Tunnel.MaxOutConnect))
return
}
}
// 连接到目标服务器
targetAddr := tunnelInfo.Tunnel.RemoteIp + ":" + strconv.Itoa(tunnelInfo.Tunnel.RemotePort)
targetConn, err := net.Dial("tcp", targetAddr)
if err != nil {
zlog.Error(fmt.Sprintf("连接目标服务器失败 [客户端IP:%s 客户端端口:%s 服务端口:%s 目标地址:%s 错误:%s]",
clientIP, clientPort, serverPort, targetAddr, err.Error()))
return
}
// 将目标连接也添加到活动连接列表,标记为目标连接
waf.TCPConnections.AddConn(port, targetConn, waftunnelmodel.ConnTypeTarget)
defer func() {
targetConn.Close()
waf.TCPConnections.RemoveConn(port, targetConn)
}()
// 设置超时
if tunnelInfo.Tunnel.ConnTimeout > 0 {
clientConn.SetDeadline(time.Now().Add(time.Duration(tunnelInfo.Tunnel.ConnTimeout) * time.Second))
targetConn.SetDeadline(time.Now().Add(time.Duration(tunnelInfo.Tunnel.ConnTimeout) * time.Second))
}
// 设置读取超时
if tunnelInfo.Tunnel.ReadTimeout > 0 {
clientConn.SetReadDeadline(time.Now().Add(time.Duration(tunnelInfo.Tunnel.ReadTimeout) * time.Second))
targetConn.SetReadDeadline(time.Now().Add(time.Duration(tunnelInfo.Tunnel.ReadTimeout) * time.Second))
}
// 设置写入超时
if tunnelInfo.Tunnel.WriteTimeout > 0 {
clientConn.SetWriteDeadline(time.Now().Add(time.Duration(tunnelInfo.Tunnel.WriteTimeout) * time.Second))
targetConn.SetWriteDeadline(time.Now().Add(time.Duration(tunnelInfo.Tunnel.WriteTimeout) * time.Second))
}
// 使用context和WaitGroup管理连接生命周期
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var wg sync.WaitGroup
wg.Add(2)
// 客户端到目标服务器的数据转发
go func() {
defer wg.Done()
defer cancel() // 任一方向断开时取消context
// 使用io.Copy当连接断开时会自动返回
_, err := io.Copy(targetConn, clientConn)
if err != nil {
zlog.Error(fmt.Sprintf("客户端->目标 数据转发结束 [客户端IP:%s 客户端端口:%s 服务端口:%s 错误:%s]",
clientIP, clientPort, serverPort, err.Error()))
} else {
zlog.Debug(fmt.Sprintf("客户端->目标 数据转发正常结束 [客户端IP:%s 客户端端口:%s 服务端口:%s]",
clientIP, clientPort, serverPort))
}
}()
// 目标服务器到客户端的数据转发
go func() {
defer wg.Done()
defer cancel() // 任一方向断开时取消context
// 使用io.Copy当连接断开时会自动返回
_, err := io.Copy(clientConn, targetConn)
if err != nil {
zlog.Error(fmt.Sprintf("目标->客户端 数据转发结束 [客户端IP:%s 客户端端口:%s 服务端口:%s 错误:%s]",
clientIP, clientPort, serverPort, err.Error()))
} else {
zlog.Debug(fmt.Sprintf("目标->客户端 数据转发正常结束 [客户端IP:%s 客户端端口:%s 服务端口:%s]",
clientIP, clientPort, serverPort))
}
}()
// 监控context取消主动关闭连接
go func() {
<-ctx.Done()
// context被取消时主动关闭两个连接
clientConn.Close()
targetConn.Close()
zlog.Debug(fmt.Sprintf("Context取消主动关闭连接 [客户端IP:%s 客户端端口:%s 服务端口:%s]",
clientIP, clientPort, serverPort))
}()
// 等待任一方向的连接断开
wg.Wait()
zlog.Info(fmt.Sprintf("TCP连接处理完成 [客户端IP:%s 客户端端口:%s 服务端口:%s]",
clientIP, clientPort, serverPort))
}