refactor: 重构PostgreSQL和MongoDB插件使用统一发包控制

- 修改PostgreSQL插件,在testCredential和identifyService中添加发包控制
- 修改MongoDB插件,在checkMongoAuth和testMongoCredential中添加发包控制
- 统一包计数逻辑,确保TCP连接成功和失败都正确计数
- 保持现有功能完整性,提升发包控制一致性
This commit is contained in:
ZacharyZcR 2025-09-02 11:45:12 +00:00
parent a23c82142d
commit 622795740f
2 changed files with 46 additions and 0 deletions

View File

@ -135,6 +135,12 @@ func (p *MongoDBPlugin) mongodbUnauth(ctx context.Context, info *common.HostInfo
// checkMongoAuth 检查MongoDB认证状态 - 基于原始工作版本 // checkMongoAuth 检查MongoDB认证状态 - 基于原始工作版本
func (p *MongoDBPlugin) checkMongoAuth(ctx context.Context, address string, packet []byte) (string, error) { func (p *MongoDBPlugin) checkMongoAuth(ctx context.Context, address string, packet []byte) (string, error) {
// 检查发包限制
if canSend, reason := common.CanSendPacket(); !canSend {
common.LogError(fmt.Sprintf("MongoDB连接 %s 受限: %s", address, reason))
return "", fmt.Errorf("发包受限: %s", reason)
}
// 创建连接超时上下文 // 创建连接超时上下文
connCtx, cancel := context.WithTimeout(ctx, time.Duration(common.Timeout)*time.Second) connCtx, cancel := context.WithTimeout(ctx, time.Duration(common.Timeout)*time.Second)
defer cancel() defer cancel()
@ -143,8 +149,11 @@ func (p *MongoDBPlugin) checkMongoAuth(ctx context.Context, address string, pack
var d net.Dialer var d net.Dialer
conn, err := d.DialContext(connCtx, "tcp", address) conn, err := d.DialContext(connCtx, "tcp", address)
if err != nil { if err != nil {
common.IncrementTCPFailedPacketCount()
return "", fmt.Errorf("连接失败: %v", err) return "", fmt.Errorf("连接失败: %v", err)
} }
// 连接成功计数TCP成功包
common.IncrementTCPSuccessPacketCount()
defer conn.Close() defer conn.Close()
// 检查上下文是否已取消 // 检查上下文是否已取消
@ -216,6 +225,12 @@ func (p *MongoDBPlugin) createOpQueryPacket() []byte {
// testMongoCredential 使用官方MongoDB驱动测试凭据 // testMongoCredential 使用官方MongoDB驱动测试凭据
func (p *MongoDBPlugin) testMongoCredential(ctx context.Context, info *common.HostInfo, cred plugins.Credential) bool { func (p *MongoDBPlugin) testMongoCredential(ctx context.Context, info *common.HostInfo, cred plugins.Credential) bool {
// 检查发包限制
if canSend, reason := common.CanSendPacket(); !canSend {
common.LogError(fmt.Sprintf("MongoDB认证测试 %s:%s 受限: %s", info.Host, info.Ports, reason))
return false
}
// 构建MongoDB连接URI // 构建MongoDB连接URI
var uri string var uri string
if cred.Username != "" && cred.Password != "" { if cred.Username != "" && cred.Password != "" {
@ -242,13 +257,17 @@ func (p *MongoDBPlugin) testMongoCredential(ctx context.Context, info *common.Ho
// 连接到MongoDB // 连接到MongoDB
client, err := mongo.Connect(authCtx, clientOptions) client, err := mongo.Connect(authCtx, clientOptions)
if err != nil { if err != nil {
common.IncrementTCPFailedPacketCount()
return false return false
} }
// 连接成功计数TCP成功包
common.IncrementTCPSuccessPacketCount()
defer client.Disconnect(authCtx) defer client.Disconnect(authCtx)
// 测试连接 - 尝试ping数据库 // 测试连接 - 尝试ping数据库
err = client.Ping(authCtx, nil) err = client.Ping(authCtx, nil)
if err != nil { if err != nil {
// ping失败但已经连接成功了不需要额外计数
return false return false
} }

View File

@ -64,6 +64,12 @@ func (p *PostgreSQLPlugin) Scan(ctx context.Context, info *common.HostInfo) *Sca
func (p *PostgreSQLPlugin) testCredential(ctx context.Context, info *common.HostInfo, cred Credential) *sql.DB { func (p *PostgreSQLPlugin) testCredential(ctx context.Context, info *common.HostInfo, cred Credential) *sql.DB {
// 检查发包限制
if canSend, reason := common.CanSendPacket(); !canSend {
common.LogError(fmt.Sprintf("PostgreSQL连接 %s:%s 受限: %s", info.Host, info.Ports, reason))
return nil
}
connStr := fmt.Sprintf("postgres://%s:%s@%s:%s/postgres?sslmode=disable&connect_timeout=%d", connStr := fmt.Sprintf("postgres://%s:%s@%s:%s/postgres?sslmode=disable&connect_timeout=%d",
cred.Username, cred.Password, info.Host, info.Ports, common.Timeout) cred.Username, cred.Password, info.Host, info.Ports, common.Timeout)
@ -81,9 +87,13 @@ func (p *PostgreSQLPlugin) testCredential(ctx context.Context, info *common.Host
err = db.PingContext(pingCtx) err = db.PingContext(pingCtx)
if err != nil { if err != nil {
common.IncrementTCPFailedPacketCount()
db.Close() db.Close()
return nil return nil
} }
// 连接成功计数TCP成功包
common.IncrementTCPSuccessPacketCount()
return db return db
} }
@ -96,6 +106,16 @@ func (p *PostgreSQLPlugin) testCredential(ctx context.Context, info *common.Host
func (p *PostgreSQLPlugin) identifyService(ctx context.Context, info *common.HostInfo) *ScanResult { func (p *PostgreSQLPlugin) identifyService(ctx context.Context, info *common.HostInfo) *ScanResult {
target := fmt.Sprintf("%s:%s", info.Host, info.Ports) target := fmt.Sprintf("%s:%s", info.Host, info.Ports)
// 检查发包限制
if canSend, reason := common.CanSendPacket(); !canSend {
common.LogError(fmt.Sprintf("PostgreSQL识别 %s 受限: %s", target, reason))
return &ScanResult{
Success: false,
Service: "postgresql",
Error: fmt.Errorf("发包受限: %s", reason),
}
}
connStr := fmt.Sprintf("postgres://invalid:invalid@%s:%s/postgres?sslmode=disable&connect_timeout=%d", connStr := fmt.Sprintf("postgres://invalid:invalid@%s:%s/postgres?sslmode=disable&connect_timeout=%d",
info.Host, info.Ports, common.Timeout) info.Host, info.Ports, common.Timeout)
@ -114,6 +134,13 @@ func (p *PostgreSQLPlugin) identifyService(ctx context.Context, info *common.Hos
err = db.PingContext(pingCtx) err = db.PingContext(pingCtx)
// 统计包数量
if err != nil {
common.IncrementTCPFailedPacketCount()
} else {
common.IncrementTCPSuccessPacketCount()
}
// 改进识别逻辑任何PostgreSQL相关的响应都认为是有效服务 // 改进识别逻辑任何PostgreSQL相关的响应都认为是有效服务
var banner string var banner string
if err != nil { if err != nil {