From 5f7669a53772e9d547cdba3c00eb5ec8b1917f00 Mon Sep 17 00:00:00 2001 From: ZacharyZcR Date: Tue, 2 Sep 2025 11:35:46 +0000 Subject: [PATCH] =?UTF-8?q?refactor:=20=E5=BC=95=E5=85=A5=E7=BB=9F?= =?UTF-8?q?=E4=B8=80=E7=BD=91=E7=BB=9C=E5=8C=85=E8=A3=85=E5=99=A8=EF=BC=8C?= =?UTF-8?q?=E6=8F=90=E5=8D=87=E5=8F=91=E5=8C=85=E6=8E=A7=E5=88=B6=E4=B8=80?= =?UTF-8?q?=E8=87=B4=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增common/network.go统一网络操作包装器 - 重构MySQL/FTP/SSH/SNMP插件使用统一包装器 - 简化发包控制逻辑,避免重复代码 - 为未来代理、重试等功能扩展奠定基础 --- common/network.go | 90 +++++++++++++++++++++++++++++++++++++++ plugins/services/ftp.go | 30 +++++++++++++ plugins/services/mysql.go | 15 ++++++- plugins/services/snmp.go | 19 ++------- plugins/services/ssh.go | 14 +----- 5 files changed, 140 insertions(+), 28 deletions(-) create mode 100644 common/network.go diff --git a/common/network.go b/common/network.go new file mode 100644 index 0000000..93a14d0 --- /dev/null +++ b/common/network.go @@ -0,0 +1,90 @@ +package common + +import ( + "fmt" + "net" + "net/http" + "time" +) + +/* +network.go - 统一网络操作包装器 + +提供带发包控制、统计和代理支持的统一网络操作接口。 +消除在每个插件中重复添加发包控制检查的需要。 +*/ + +// ============================================================================= +// 统一网络连接包装器 +// ============================================================================= + +// SafeDialTimeout 带发包控制的TCP连接 +// 自动处理发包限制检查和统计计数 +func SafeDialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { + // 检查发包限制 + if canSend, reason := CanSendPacket(); !canSend { + LogError(fmt.Sprintf("网络连接 %s 受限: %s", address, reason)) + return nil, fmt.Errorf("发包受限: %s", reason) + } + + // 建立连接 - 支持SOCKS5代理 + // 注意:WrapperTcpWithTimeout内部已经有计数逻辑,这里不重复计数 + conn, err := WrapperTcpWithTimeout(network, address, timeout) + + return conn, err +} + +// SafeUDPDial 带发包控制的UDP连接 +func SafeUDPDial(network, address string, timeout time.Duration) (net.Conn, error) { + // 检查发包限制 + if canSend, reason := CanSendPacket(); !canSend { + LogError(fmt.Sprintf("UDP连接 %s 受限: %s", address, reason)) + return nil, fmt.Errorf("发包受限: %s", reason) + } + + // 建立UDP连接 + conn, err := net.DialTimeout(network, address, timeout) + + // 统计UDP包数量 + if err == nil { + IncrementUDPPacketCount() + } + + return conn, err +} + +// SafeHTTPDo 带发包控制的HTTP请求 +func SafeHTTPDo(client *http.Client, req *http.Request) (*http.Response, error) { + // 检查发包限制 + if canSend, reason := CanSendPacket(); !canSend { + LogError(fmt.Sprintf("HTTP请求 %s 受限: %s", req.URL.String(), reason)) + return nil, fmt.Errorf("发包受限: %s", reason) + } + + // 执行HTTP请求 + resp, err := client.Do(req) + + // 统计TCP包数量 (HTTP本质上是TCP) + if err != nil { + IncrementTCPFailedPacketCount() + } else { + IncrementTCPSuccessPacketCount() + } + + return resp, err +} + +// ============================================================================= +// 便捷函数封装 +// ============================================================================= + +// SafeTCPDial TCP连接的便捷封装 +func SafeTCPDial(address string, timeout time.Duration) (net.Conn, error) { + return SafeDialTimeout("tcp", address, timeout) +} + +// SafeTCPDialContext 带Context的TCP连接 +func SafeTCPDialContext(network, address string, timeout time.Duration) (net.Conn, error) { + // 这个函数为将来扩展Context支持预留 + return SafeDialTimeout(network, address, timeout) +} \ No newline at end of file diff --git a/plugins/services/ftp.go b/plugins/services/ftp.go index 058f525..85f1966 100644 --- a/plugins/services/ftp.go +++ b/plugins/services/ftp.go @@ -63,7 +63,20 @@ func (p *FTPPlugin) testCredential(ctx context.Context, info *common.HostInfo, c target := fmt.Sprintf("%s:%s", info.Host, info.Ports) timeout := time.Duration(common.Timeout) * time.Second + // 检查发包限制 + if canSend, reason := common.CanSendPacket(); !canSend { + common.LogError(fmt.Sprintf("FTP连接 %s 受限: %s", target, reason)) + return nil + } + conn, err := ftplib.DialTimeout(target, timeout) + if err == nil { + // 计数成功连接 + common.IncrementTCPSuccessPacketCount() + } else { + // 计数失败连接 + common.IncrementTCPFailedPacketCount() + } if err != nil { return nil } @@ -83,7 +96,24 @@ func (p *FTPPlugin) identifyService(info *common.HostInfo) *ScanResult { target := fmt.Sprintf("%s:%s", info.Host, info.Ports) timeout := time.Duration(common.Timeout) * time.Second + // 检查发包限制 + if canSend, reason := common.CanSendPacket(); !canSend { + common.LogError(fmt.Sprintf("FTP服务识别 %s 受限: %s", target, reason)) + return &ScanResult{ + Success: false, + Service: "ftp", + Error: fmt.Errorf("发包受限: %s", reason), + } + } + conn, err := ftplib.DialTimeout(target, timeout) + if err == nil { + // 计数成功连接 + common.IncrementTCPSuccessPacketCount() + } else { + // 计数失败连接 + common.IncrementTCPFailedPacketCount() + } if err != nil { return &ScanResult{ Success: false, diff --git a/plugins/services/mysql.go b/plugins/services/mysql.go index ef0ff1a..4786e24 100644 --- a/plugins/services/mysql.go +++ b/plugins/services/mysql.go @@ -61,6 +61,13 @@ func (p *MySQLPlugin) Scan(ctx context.Context, info *common.HostInfo) *ScanResu } func (p *MySQLPlugin) testCredential(ctx context.Context, info *common.HostInfo, cred Credential) bool { + // 检查发包限制 + if canSend, reason := common.CanSendPacket(); !canSend { + common.LogError(fmt.Sprintf("MySQL连接 %s:%s 受限: %s", info.Host, info.Ports, reason)) + return false + } + + // 建立MySQL连接 connStr := fmt.Sprintf("%s:%s@tcp(%s:%s)/mysql?charset=utf8&timeout=%ds", cred.Username, cred.Password, info.Host, info.Ports, common.Timeout) @@ -75,6 +82,12 @@ func (p *MySQLPlugin) testCredential(ctx context.Context, info *common.HostInfo, db.SetMaxIdleConns(0) err = db.PingContext(ctx) + // 统计MySQL连接结果 + if err == nil { + common.IncrementTCPSuccessPacketCount() + } else { + common.IncrementTCPFailedPacketCount() + } return err == nil } @@ -82,7 +95,7 @@ func (p *MySQLPlugin) testCredential(ctx context.Context, info *common.HostInfo, func (p *MySQLPlugin) identifyService(info *common.HostInfo) *ScanResult { target := fmt.Sprintf("%s:%s", info.Host, info.Ports) - conn, err := common.WrapperTcpWithTimeout("tcp", target, time.Duration(common.Timeout)*time.Second) + conn, err := common.SafeTCPDial(target, time.Duration(common.Timeout)*time.Second) if err != nil { return &ScanResult{ Success: false, diff --git a/plugins/services/snmp.go b/plugins/services/snmp.go index 15788a0..1b556fa 100644 --- a/plugins/services/snmp.go +++ b/plugins/services/snmp.go @@ -4,7 +4,6 @@ import ( "context" "encoding/hex" "fmt" - "net" "time" "github.com/shadow1ng/fscan/common" @@ -62,20 +61,12 @@ func (p *SNMPPlugin) Scan(ctx context.Context, info *common.HostInfo) *ScanResul func (p *SNMPPlugin) testCredential(ctx context.Context, info *common.HostInfo, cred Credential) bool { target := fmt.Sprintf("%s:%s", info.Host, info.Ports) - // 检查发包限制 - if canSend, reason := common.CanSendPacket(); !canSend { - common.LogError(fmt.Sprintf("SNMP请求 %s 受限: %s", target, reason)) - return false - } - - conn, err := net.DialTimeout("udp", target, time.Duration(common.Timeout)*time.Second) + // 使用统一UDP包装器 + conn, err := common.SafeUDPDial("udp", target, time.Duration(common.Timeout)*time.Second) if err != nil { return false } defer conn.Close() - - // 计数UDP连接包 - common.IncrementUDPPacketCount() packet := p.buildSNMPGetRequest(cred.Username, "1.3.6.1.2.1.1.1.0") @@ -129,7 +120,8 @@ func (p *SNMPPlugin) identifyService(ctx context.Context, info *common.HostInfo) } } - conn, err := net.DialTimeout("udp", target, time.Duration(common.Timeout)*time.Second) + // 使用统一UDP包装器 + conn, err := common.SafeUDPDial("udp", target, time.Duration(common.Timeout)*time.Second) if err != nil { return &ScanResult{ Success: false, @@ -138,9 +130,6 @@ func (p *SNMPPlugin) identifyService(ctx context.Context, info *common.HostInfo) } } defer conn.Close() - - // 计数UDP连接包 - common.IncrementUDPPacketCount() banner := "SNMP网络管理服务" common.LogSuccess(fmt.Sprintf("SNMP %s %s", target, banner)) diff --git a/plugins/services/ssh.go b/plugins/services/ssh.go index 7e05389..d73aaf0 100644 --- a/plugins/services/ssh.go +++ b/plugins/services/ssh.go @@ -207,18 +207,8 @@ func (p *SSHPlugin) executeCommand(client *ssh.Client, cmd string) (string, erro func (p *SSHPlugin) identifyService(info *common.HostInfo) *ScanResult { target := fmt.Sprintf("%s:%s", info.Host, info.Ports) - // 检查发包限制 - if canSend, reason := common.CanSendPacket(); !canSend { - common.LogError(fmt.Sprintf("SSH服务识别 %s 受限: %s", target, reason)) - return &ScanResult{ - Success: false, - Service: "ssh", - Error: fmt.Errorf("发包受限: %s", reason), - } - } - - // 尝试连接获取SSH Banner - conn, err := common.WrapperTcpWithTimeout("tcp", target, time.Duration(common.Timeout)*time.Second) + // 使用统一TCP包装器获取SSH Banner + conn, err := common.SafeTCPDial(target, time.Duration(common.Timeout)*time.Second) if err != nil { return &ScanResult{ Success: false,