From 53ab64669aa7034eb2671886d0f8d194e63a9aea Mon Sep 17 00:00:00 2001 From: ZacharyZcR Date: Tue, 2 Sep 2025 12:01:18 +0000 Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=B8=BASMBInfo=E6=8F=92=E4=BB=B6?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=8C=85=E6=8E=A7=E5=88=B6=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在主TCP连接和SMBv2连接中添加包控制检查 - 统一错误处理和包计数逻辑 - 确保所有网络操作遵循发包限制 --- plugins/services/smbinfo.go | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/plugins/services/smbinfo.go b/plugins/services/smbinfo.go index 1b9456c..c922c40 100644 --- a/plugins/services/smbinfo.go +++ b/plugins/services/smbinfo.go @@ -41,15 +41,27 @@ func (p *SMBInfoPlugin) Scan(ctx context.Context, info *common.HostInfo) *ScanRe } } + // 检查发包限制 + if canSend, reason := common.CanSendPacket(); !canSend { + common.LogError(fmt.Sprintf("SMBInfo连接 %s 受限: %s", target, reason)) + return &ScanResult{ + Success: false, + Service: "smbinfo", + Error: fmt.Errorf("发包受限: %s", reason), + } + } + // 建立连接 conn, err := net.DialTimeout("tcp", target, time.Duration(common.Timeout)*time.Second) if err != nil { + common.IncrementTCPFailedPacketCount() return &ScanResult{ Success: false, Service: "smbinfo", Error: fmt.Errorf("连接失败: %v", err), } } + common.IncrementTCPSuccessPacketCount() defer conn.Close() conn.SetDeadline(time.Now().Add(time.Duration(common.Timeout) * time.Second)) @@ -253,11 +265,18 @@ func (p *SMBInfoPlugin) handleSMBv1(conn net.Conn, target string) (*SMBInfo, err // handleSMBv2 处理SMBv2协议信息收集 func (p *SMBInfoPlugin) handleSMBv2(target string) (*SMBInfo, error) { + // 检查发包限制 + if canSend, reason := common.CanSendPacket(); !canSend { + return nil, fmt.Errorf("SMBv2连接受限: %s", reason) + } + // 重新建立连接处理SMBv2 conn2, err := net.DialTimeout("tcp", target, time.Duration(common.Timeout)*time.Second) if err != nil { + common.IncrementTCPFailedPacketCount() return nil, fmt.Errorf("SMBv2连接失败: %v", err) } + common.IncrementTCPSuccessPacketCount() defer conn2.Close() conn2.SetDeadline(time.Now().Add(time.Duration(common.Timeout) * time.Second))