fscan/plugins/services/postgresql/connector.go
ZacharyZcR 4a3f281b6b refactor: 统一Plugins目录大小写为小写
- 将所有Plugins路径重命名为plugins
- 修复Git索引与实际文件系统大小写不一致问题
- 确保跨平台兼容性和路径一致性
2025-08-12 13:08:06 +08:00

229 lines
5.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package postgresql
import (
	"context"
	"database/sql"
	"database/sql/driver"
	"fmt"
	"net"
	"net/url"
	"strings"
	"time"

	"github.com/lib/pq"
	"github.com/shadow1ng/fscan/common"
	"github.com/shadow1ng/fscan/plugins/base"
)
// PostgreSQLProxyDialer is a custom dialer used for PostgreSQL connections
// that must go through a proxy.
type PostgreSQLProxyDialer struct {
// timeout is the dial timeout that Dial hands to common.WrapperTcpWithTimeout.
timeout time.Duration
}
// Dial implements the pq.Dialer interface so that pq can open its connection
// through the project's proxy-aware TCP helper. The timeout configured on the
// dialer is applied to the dial.
func (d *PostgreSQLProxyDialer) Dial(network, address string) (net.Conn, error) {
	// Same helper call as DialTimeout, with the dialer's own timeout.
	return d.DialTimeout(network, address, d.timeout)
}
// DialTimeout opens a connection with an explicit timeout. The timeout
// argument is used as given; the dialer's own timeout field is not consulted.
func (d *PostgreSQLProxyDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
	conn, err := common.WrapperTcpWithTimeout(network, address, timeout)
	return conn, err
}
// PostgreSQLConnection wraps a PostgreSQL connection together with the
// target it belongs to.
type PostgreSQLConnection struct {
// db is the open database handle; it is nil when the service was only
// detected from an error and no session could be established.
db *sql.DB
// target is the "host:port" string built in Connect.
target string
// info is a short description of the server (version banner or a
// generic placeholder).
info string
}
// PostgreSQLConnector is the PostgreSQL connector implementation. It carries
// no state of its own.
type PostgreSQLConnector struct{}
// NewPostgreSQLConnector creates a new PostgreSQL connector.
func NewPostgreSQLConnector() *PostgreSQLConnector {
	return new(PostgreSQLConnector)
}
// Connect probes the PostgreSQL server described by info without real
// credentials (an empty user name and password are used).
//
// It returns a *PostgreSQLConnection in two cases: when the connection
// succeeds, and when it fails with an error that isPostgreSQLError recognises
// (the service is then considered detected and the returned connection has a
// nil db). Any other error is returned unchanged.
func (c *PostgreSQLConnector) Connect(ctx context.Context, info *common.HostInfo) (interface{}, error) {
	result := &PostgreSQLConnection{
		target: fmt.Sprintf("%s:%s", info.Host, info.Ports),
	}

	db, dbInfo, err := c.createConnection(ctx, info.Host, info.Ports, "", "")
	switch {
	case err == nil:
		result.db = db
		result.info = dbInfo
		return result, nil
	case c.isPostgreSQLError(err):
		// The connection failed, but the error still identifies the
		// service as PostgreSQL.
		result.info = "PostgreSQL Database (Service Detected)"
		return result, nil
	default:
		return nil, err
	}
}
// Authenticate opens a new connection to the target stored in conn using the
// supplied credentials.
//
// conn must be a non-nil *PostgreSQLConnection and cred must be non-nil;
// otherwise an error is returned. On success the previous database handle
// held by conn (if any) is closed and replaced by the authenticated one, and
// the connection's info string is refreshed. On failure conn is left
// untouched and the connection error is returned.
func (c *PostgreSQLConnector) Authenticate(ctx context.Context, conn interface{}, cred *base.Credential) error {
	pgConn, ok := conn.(*PostgreSQLConnection)
	if !ok || pgConn == nil {
		return fmt.Errorf("invalid connection type")
	}
	if cred == nil {
		return fmt.Errorf("invalid credential")
	}

	// Split on the LAST colon: the port never contains one, while the host
	// may (for example an IPv6 literal), so a plain Split would wrongly
	// reject such targets as malformed.
	sep := strings.LastIndex(pgConn.target, ":")
	if sep < 0 {
		return fmt.Errorf("invalid target format")
	}
	host, port := pgConn.target[:sep], pgConn.target[sep+1:]

	// Open a fresh connection with the supplied credentials.
	db, info, err := c.createConnection(ctx, host, port, cred.Username, cred.Password)
	if err != nil {
		return err
	}

	// Swap in the authenticated handle, releasing the old one first. A
	// failure to close the stale handle is not actionable here.
	if pgConn.db != nil {
		_ = pgConn.db.Close()
	}
	pgConn.db = db
	pgConn.info = info
	return nil
}
// Close releases the database handle held by conn. Values that are not a
// *PostgreSQLConnection, and connections without an open handle, are ignored
// and yield a nil error.
func (c *PostgreSQLConnector) Close(conn interface{}) error {
	pgConn, ok := conn.(*PostgreSQLConnection)
	if !ok || pgConn.db == nil {
		return nil
	}
	return pgConn.db.Close()
}
// createConnection opens a connection to the "postgres" database on
// host:port with the given credentials, verifies it with a ping and returns
// the handle together with a short server description.
//
// The timeout is taken from common.Timeout (seconds) and is used for the
// driver's connect_timeout, the connection lifetime and the ping/info
// queries. When common.Socks5Proxy is set, the connection is dialled through
// PostgreSQLProxyDialer; otherwise the registered "postgres" driver is used
// directly. On any failure a nil handle, an empty description and the
// underlying error are returned.
func (c *PostgreSQLConnector) createConnection(ctx context.Context, host, port, username, password string) (*sql.DB, string, error) {
	timeout := time.Duration(common.Timeout) * time.Second

	// Build the connection URL with net/url instead of string formatting so
	// that credentials containing reserved characters ('/', '#', '?', '%',
	// spaces, ...) are percent-encoded rather than corrupting the URL, and
	// so that hosts containing colons (IPv6 literals) are bracketed.
	// Brackets already present on the host are dropped first to avoid
	// doubling them.
	dsn := url.URL{
		Scheme:   "postgres",
		User:     url.UserPassword(username, password),
		Host:     net.JoinHostPort(strings.Trim(host, "[]"), port),
		Path:     "/postgres",
		RawQuery: fmt.Sprintf("sslmode=disable&connect_timeout=%d", int(timeout.Seconds())),
	}
	connStr := dsn.String()

	var db *sql.DB
	if common.Socks5Proxy != "" {
		// A proxy is configured: dial through the custom dialer and wrap
		// the resulting driver connection so database/sql can use it.
		dialer := &PostgreSQLProxyDialer{timeout: timeout}
		conn, err := pq.DialOpen(dialer, connStr)
		if err != nil {
			return nil, "", err
		}
		db = sql.OpenDB(&postgresConnector{conn: conn})
	} else {
		// No proxy: use the standard driver path.
		var err error
		db, err = sql.Open("postgres", connStr)
		if err != nil {
			return nil, "", err
		}
	}

	// Keep the pool minimal: a single, short-lived connection.
	db.SetConnMaxLifetime(timeout)
	db.SetMaxOpenConns(1)
	db.SetMaxIdleConns(0)

	// The ping and the follow-up info query share one deadline.
	pingCtx, pingCancel := context.WithTimeout(ctx, timeout)
	defer pingCancel()

	if err := db.PingContext(pingCtx); err != nil {
		// Best-effort cleanup; the ping error is the one worth reporting.
		_ = db.Close()
		return nil, "", err
	}

	info := c.getDatabaseInfo(db, pingCtx)
	return db, info, nil
}
// getDatabaseInfo asks the server for its version banner via
// "SELECT version()" and condenses it into a short description.
//
// If the query fails, the generic string "PostgreSQL Database" is returned.
// If the banner mentions "PostgreSQL" and has at least two
// whitespace-separated fields, only the first two are returned; otherwise
// the banner is returned unchanged.
func (c *PostgreSQLConnector) getDatabaseInfo(db *sql.DB, ctx context.Context) string {
	var banner string
	if err := db.QueryRowContext(ctx, "SELECT version()").Scan(&banner); err != nil {
		return "PostgreSQL Database"
	}

	if !strings.Contains(banner, "PostgreSQL") {
		return banner
	}

	// Keep only the leading two fields of the banner.
	tokens := strings.Fields(banner)
	if len(tokens) < 2 {
		return banner
	}
	return fmt.Sprintf("%s %s", tokens[0], tokens[1])
}
// isPostgreSQLError reports whether err looks like it originated from a
// PostgreSQL server or its driver, judged by case-insensitive substring
// matching on the error text. A nil error yields false.
//
// This is a heuristic: several indicators ("database", "role",
// "does not exist") are generic words and may match unrelated errors.
func (c *PostgreSQLConnector) isPostgreSQLError(err error) bool {
	if err == nil {
		return false
	}
	msg := strings.ToLower(err.Error())

	// "connection refused" is deliberately not an indicator: it is a
	// transport-level failure meaning nothing accepted the connection, so
	// treating it as PostgreSQL would report a service on closed ports.
	indicators := []string{
		// Also matches "postgresql".
		"postgres",
		// Also matches "password authentication failed".
		"authentication failed",
		"database",
		"pq:",
		"invalid authorization specification",
		"role",
		"does not exist",
	}
	for _, indicator := range indicators {
		if strings.Contains(msg, indicator) {
			return true
		}
	}
	return false
}
// postgresConnector wraps an already-established driver.Conn so that it can
// be handed to sql.OpenDB as a driver.Connector.
type postgresConnector struct {
// conn is the pre-dialled driver connection returned by Connect.
conn driver.Conn
}
// Connect returns the wrapped driver connection. The context is not used and
// the same connection value is returned on every call; the error is always
// nil.
func (c *postgresConnector) Connect(_ context.Context) (driver.Conn, error) {
	wrapped := c.conn
	return wrapped, nil
}
// Driver returns a fresh pq driver value, as required by the
// driver.Connector interface.
func (c *postgresConnector) Driver() driver.Driver {
	return new(pq.Driver)
}