From 622795740f94ca28dd3e6ba556fa4295956930ef Mon Sep 17 00:00:00 2001 From: ZacharyZcR Date: Tue, 2 Sep 2025 11:45:12 +0000 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84PostgreSQL?= =?UTF-8?q?=E5=92=8CMongoDB=E6=8F=92=E4=BB=B6=E4=BD=BF=E7=94=A8=E7=BB=9F?= =?UTF-8?q?=E4=B8=80=E5=8F=91=E5=8C=85=E6=8E=A7=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 修改PostgreSQL插件,在testCredential和identifyService中添加发包控制 - 修改MongoDB插件,在checkMongoAuth和testMongoCredential中添加发包控制 - 统一包计数逻辑,确保TCP连接成功和失败都正确计数 - 保持现有功能完整性,提升发包控制一致性 --- plugins/services/mongodb.go | 19 +++++++++++++++++++ plugins/services/postgresql.go | 27 +++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/plugins/services/mongodb.go b/plugins/services/mongodb.go index a8faad0..7b898ba 100644 --- a/plugins/services/mongodb.go +++ b/plugins/services/mongodb.go @@ -135,6 +135,12 @@ func (p *MongoDBPlugin) mongodbUnauth(ctx context.Context, info *common.HostInfo // checkMongoAuth 检查MongoDB认证状态 - 基于原始工作版本 func (p *MongoDBPlugin) checkMongoAuth(ctx context.Context, address string, packet []byte) (string, error) { + // 检查发包限制 + if canSend, reason := common.CanSendPacket(); !canSend { + common.LogError(fmt.Sprintf("MongoDB连接 %s 受限: %s", address, reason)) + return "", fmt.Errorf("发包受限: %s", reason) + } + // 创建连接超时上下文 connCtx, cancel := context.WithTimeout(ctx, time.Duration(common.Timeout)*time.Second) defer cancel() @@ -143,8 +149,11 @@ func (p *MongoDBPlugin) checkMongoAuth(ctx context.Context, address string, pack var d net.Dialer conn, err := d.DialContext(connCtx, "tcp", address) if err != nil { + common.IncrementTCPFailedPacketCount() return "", fmt.Errorf("连接失败: %v", err) } + // 连接成功,计数TCP成功包 + common.IncrementTCPSuccessPacketCount() defer conn.Close() // 检查上下文是否已取消 @@ -216,6 +225,12 @@ func (p *MongoDBPlugin) createOpQueryPacket() []byte { // testMongoCredential 使用官方MongoDB驱动测试凭据 func (p *MongoDBPlugin) testMongoCredential(ctx context.Context, info *common.HostInfo, cred plugins.Credential) bool { + // 检查发包限制 + if canSend, reason := common.CanSendPacket(); !canSend { + common.LogError(fmt.Sprintf("MongoDB认证测试 %s:%s 受限: %s", info.Host, info.Ports, reason)) + return false + } + // 构建MongoDB连接URI var uri string if cred.Username != "" && cred.Password != "" { @@ -242,13 +257,17 @@ func (p *MongoDBPlugin) testMongoCredential(ctx context.Context, info *common.Ho // 连接到MongoDB client, err := mongo.Connect(authCtx, clientOptions) if err != nil { + common.IncrementTCPFailedPacketCount() return false } + // 连接成功,计数TCP成功包 + common.IncrementTCPSuccessPacketCount() defer client.Disconnect(authCtx) // 测试连接 - 尝试ping数据库 err = client.Ping(authCtx, nil) if err != nil { + // ping失败,但已经连接成功了,不需要额外计数 return false } diff --git a/plugins/services/postgresql.go b/plugins/services/postgresql.go index f4353f3..fe5fbee 100644 --- a/plugins/services/postgresql.go +++ b/plugins/services/postgresql.go @@ -64,6 +64,12 @@ func (p *PostgreSQLPlugin) Scan(ctx context.Context, info *common.HostInfo) *Sca func (p *PostgreSQLPlugin) testCredential(ctx context.Context, info *common.HostInfo, cred Credential) *sql.DB { + // 检查发包限制 + if canSend, reason := common.CanSendPacket(); !canSend { + common.LogError(fmt.Sprintf("PostgreSQL连接 %s:%s 受限: %s", info.Host, info.Ports, reason)) + return nil + } + connStr := fmt.Sprintf("postgres://%s:%s@%s:%s/postgres?sslmode=disable&connect_timeout=%d", cred.Username, cred.Password, info.Host, info.Ports, common.Timeout) @@ -81,9 +87,13 @@ func (p *PostgreSQLPlugin) testCredential(ctx context.Context, info *common.Host err = db.PingContext(pingCtx) if err != nil { + common.IncrementTCPFailedPacketCount() db.Close() return nil } + + // 连接成功,计数TCP成功包 + common.IncrementTCPSuccessPacketCount() return db } @@ -96,6 +106,16 @@ func (p *PostgreSQLPlugin) testCredential(ctx context.Context, info *common.Host func (p *PostgreSQLPlugin) identifyService(ctx context.Context, info *common.HostInfo) *ScanResult { target := fmt.Sprintf("%s:%s", info.Host, info.Ports) + // 检查发包限制 + if canSend, reason := common.CanSendPacket(); !canSend { + common.LogError(fmt.Sprintf("PostgreSQL识别 %s 受限: %s", target, reason)) + return &ScanResult{ + Success: false, + Service: "postgresql", + Error: fmt.Errorf("发包受限: %s", reason), + } + } + connStr := fmt.Sprintf("postgres://invalid:invalid@%s:%s/postgres?sslmode=disable&connect_timeout=%d", info.Host, info.Ports, common.Timeout) @@ -114,6 +134,13 @@ func (p *PostgreSQLPlugin) identifyService(ctx context.Context, info *common.Hos err = db.PingContext(pingCtx) + // 统计包数量 + if err != nil { + common.IncrementTCPFailedPacketCount() + } else { + common.IncrementTCPSuccessPacketCount() + } + // 改进识别逻辑:任何PostgreSQL相关的响应都认为是有效服务 var banner string if err != nil {