package postgresql import ( "context" "database/sql" "database/sql/driver" "fmt" "net" "strings" "time" "github.com/lib/pq" "github.com/shadow1ng/fscan/common" "github.com/shadow1ng/fscan/plugins/base" ) // PostgreSQLProxyDialer 自定义PostgreSQL代理拨号器 type PostgreSQLProxyDialer struct { timeout time.Duration } // Dial 实现pq.Dialer接口,支持socks代理 func (d *PostgreSQLProxyDialer) Dial(network, address string) (net.Conn, error) { return common.WrapperTcpWithTimeout(network, address, d.timeout) } // DialTimeout 实现具有超时的连接 func (d *PostgreSQLProxyDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { return common.WrapperTcpWithTimeout(network, address, timeout) } // PostgreSQLConnection PostgreSQL连接包装器 type PostgreSQLConnection struct { db *sql.DB target string info string } // PostgreSQLConnector PostgreSQL连接器实现 type PostgreSQLConnector struct{} // NewPostgreSQLConnector 创建PostgreSQL连接器 func NewPostgreSQLConnector() *PostgreSQLConnector { return &PostgreSQLConnector{} } // Connect 连接到PostgreSQL服务器(不进行认证) func (c *PostgreSQLConnector) Connect(ctx context.Context, info *common.HostInfo) (interface{}, error) { target := fmt.Sprintf("%s:%s", info.Host, info.Ports) // 尝试建立连接但不进行认证,使用空凭据进行连接尝试 db, dbInfo, err := c.createConnection(ctx, info.Host, info.Ports, "", "") if err != nil { // 检查是否是PostgreSQL服务相关错误 if c.isPostgreSQLError(err) { // 即使连接失败,但可以识别为PostgreSQL服务 return &PostgreSQLConnection{ db: nil, target: target, info: "PostgreSQL Database (Service Detected)", }, nil } return nil, err } return &PostgreSQLConnection{ db: db, target: target, info: dbInfo, }, nil } // Authenticate 使用凭据进行认证 func (c *PostgreSQLConnector) Authenticate(ctx context.Context, conn interface{}, cred *base.Credential) error { pgConn, ok := conn.(*PostgreSQLConnection) if !ok { return fmt.Errorf("invalid connection type") } // 解析目标地址 parts := strings.Split(pgConn.target, ":") if len(parts) != 2 { return fmt.Errorf("invalid target format") } host := parts[0] port := parts[1] // 使用提供的凭据创建新连接 db, info, err := c.createConnection(ctx, host, port, cred.Username, cred.Password) if err != nil { return err } // 更新连接信息 if pgConn.db != nil { pgConn.db.Close() } pgConn.db = db pgConn.info = info return nil } // Close 关闭连接 func (c *PostgreSQLConnector) Close(conn interface{}) error { if pgConn, ok := conn.(*PostgreSQLConnection); ok && pgConn.db != nil { return pgConn.db.Close() } return nil } // createConnection 创建PostgreSQL数据库连接 func (c *PostgreSQLConnector) createConnection(ctx context.Context, host, port, username, password string) (*sql.DB, string, error) { timeout := time.Duration(common.Timeout) * time.Second // 构造连接字符串 connStr := fmt.Sprintf( "postgres://%s:%s@%s:%s/postgres?sslmode=disable&connect_timeout=%d", username, password, host, port, int(timeout.Seconds()), ) var db *sql.DB var err error // 检查是否需要使用socks代理 if common.Socks5Proxy != "" { // 使用自定义dialer通过socks代理连接 dialer := &PostgreSQLProxyDialer{ timeout: timeout, } // 使用pq.DialOpen通过自定义dialer建立连接 conn, err := pq.DialOpen(dialer, connStr) if err != nil { return nil, "", err } // 转换为sql.DB进行测试 db = sql.OpenDB(&postgresConnector{conn: conn}) } else { // 使用标准连接方式 db, err = sql.Open("postgres", connStr) if err != nil { return nil, "", err } } // 设置连接参数 db.SetConnMaxLifetime(timeout) db.SetMaxOpenConns(1) db.SetMaxIdleConns(0) // 创建ping上下文 pingCtx, pingCancel := context.WithTimeout(ctx, timeout) defer pingCancel() // 使用上下文测试连接 err = db.PingContext(pingCtx) if err != nil { db.Close() return nil, "", err } // 获取数据库信息 info := c.getDatabaseInfo(db, pingCtx) return db, info, nil } // getDatabaseInfo 获取PostgreSQL数据库信息 func (c *PostgreSQLConnector) getDatabaseInfo(db *sql.DB, ctx context.Context) string { var version string err := db.QueryRowContext(ctx, "SELECT version()").Scan(&version) if err != nil { return "PostgreSQL Database" } // 提取版本信息的关键部分 if strings.Contains(version, "PostgreSQL") { parts := strings.Fields(version) if len(parts) >= 2 { return fmt.Sprintf("%s %s", parts[0], parts[1]) } } return version } // isPostgreSQLError 检查是否是PostgreSQL相关错误 func (c *PostgreSQLConnector) isPostgreSQLError(err error) bool { if err == nil { return false } errorStr := strings.ToLower(err.Error()) postgresErrorIndicators := []string{ "postgres", "postgresql", "authentication failed", "password authentication failed", "database", "connection refused", "pq:", "invalid authorization specification", "role", "does not exist", } for _, indicator := range postgresErrorIndicators { if strings.Contains(errorStr, indicator) { return true } } return false } // postgresConnector 封装driver.Conn为sql.driver.Connector type postgresConnector struct { conn driver.Conn } func (c *postgresConnector) Connect(ctx context.Context) (driver.Conn, error) { return c.conn, nil } func (c *postgresConnector) Driver() driver.Driver { return &pq.Driver{} }