fscan/plugins/services/mssql/connector.go
ZacharyZcR 4a3f281b6b refactor: 统一Plugins目录大小写为小写
- 将所有Plugins路径重命名为plugins
- 修复Git索引与实际文件系统大小写不一致问题
- 确保跨平台兼容性和路径一致性
2025-08-12 13:08:06 +08:00

210 lines
4.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package mssql
import (
	"context"
	"database/sql"
	"fmt"
	"net"
	"net/url"
	"strings"
	"time"

	mssqlDriver "github.com/denisenkom/go-mssqldb"
	"github.com/shadow1ng/fscan/common"
	"github.com/shadow1ng/fscan/plugins/base"
)
// MSSQLProxyDialer is the dialer installed on the go-mssqldb connector when a
// SOCKS5 proxy is configured (see createConnection); it routes the driver's
// TCP dials through common.WrapperTcpWithContext.
type MSSQLProxyDialer struct {
	// timeout is recorded when the dialer is built but DialContext does not
	// read it; any dial deadline comes from ctx (or from
	// common.WrapperTcpWithContext itself, which is not visible here).
	timeout time.Duration
}

// DialContext implements the go-mssqldb Dialer interface by delegating the
// dial of addr over network to common.WrapperTcpWithContext, so that SOCKS5
// proxy settings are honoured.
func (d *MSSQLProxyDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	conn, err := common.WrapperTcpWithContext(ctx, network, addr)
	return conn, err
}
// MSSQLConnection is the per-target state passed between Connect,
// Authenticate and Close.
type MSSQLConnection struct {
	db           *sql.DB // open database handle; nil when only the service was detected
	target, info string  // target address (host and port) and server description
}
// MSSQLConnector provides the MSSQL plugin's connection operations
// (Connect, Authenticate, Close). It carries no state of its own.
type MSSQLConnector struct{}

// NewMSSQLConnector returns a ready-to-use MSSQLConnector.
func NewMSSQLConnector() *MSSQLConnector {
	return new(MSSQLConnector)
}
// Connect probes info.Host:info.Ports for an MSSQL service without real
// credentials: it attempts a login with an empty user name and password.
//
// If that login succeeds, the returned *MSSQLConnection holds the open
// *sql.DB and the server version string. If it fails with an error that
// isMSSQLError recognises (for example a rejected login), the service is still
// reported as detected, with a nil db. Any other error is returned unchanged.
func (c *MSSQLConnector) Connect(ctx context.Context, info *common.HostInfo) (interface{}, error) {
	// net.JoinHostPort brackets IPv6 literals ("[::1]:1433"), so the boundary
	// between host and port in target stays unambiguous.
	target := net.JoinHostPort(info.Host, info.Ports)
	timeout := time.Duration(common.Timeout) * time.Second

	db, dbInfo, err := c.createConnection(ctx, info.Host, info.Ports, "", "", timeout)
	if err != nil {
		if !c.isMSSQLError(err) {
			return nil, err
		}
		// The endpoint answered like SQL Server even though the empty login
		// was rejected, so report the service without an open handle.
		return &MSSQLConnection{
			db:     nil,
			target: target,
			info:   "Microsoft SQL Server (Service Detected)",
		}, nil
	}

	return &MSSQLConnection{
		db:     db,
		target: target,
		info:   dbInfo,
	}, nil
}
// Authenticate logs in to the target recorded in conn using cred.
//
// conn must be a *MSSQLConnection (as returned by Connect). On success its db
// and info are replaced by the newly authenticated handle and version string,
// and any previous db is closed. On failure conn is left unchanged and the
// driver error is returned as-is so callers can inspect it.
func (c *MSSQLConnector) Authenticate(ctx context.Context, conn interface{}, cred *base.Credential) error {
	mssqlConn, ok := conn.(*MSSQLConnection)
	if !ok || mssqlConn == nil {
		return fmt.Errorf("invalid connection type")
	}
	if cred == nil {
		return fmt.Errorf("nil credential")
	}

	// SplitHostPort handles "host:port" and bracketed IPv6 "[::1]:1433";
	// a plain split on ":" rejected every IPv6 target.
	host, port, err := net.SplitHostPort(mssqlConn.target)
	if err != nil {
		// Fall back to the last colon for targets formatted as "host:port"
		// without brackets around an IPv6 literal.
		i := strings.LastIndex(mssqlConn.target, ":")
		if i < 0 || i == len(mssqlConn.target)-1 {
			return fmt.Errorf("invalid target format")
		}
		host, port = mssqlConn.target[:i], mssqlConn.target[i+1:]
	}

	timeout := time.Duration(common.Timeout) * time.Second

	// Each credential gets a fresh connection.
	db, info, err := c.createConnection(ctx, host, port, cred.Username, cred.Password, timeout)
	if err != nil {
		return err
	}

	if mssqlConn.db != nil {
		// Best effort: the old handle is being replaced either way.
		_ = mssqlConn.db.Close()
	}
	mssqlConn.db = db
	mssqlConn.info = info
	return nil
}
// Close releases the database handle held by conn, if any.
// Values that are not a *MSSQLConnection, and connections without an open
// handle (service-detected only), are ignored and yield a nil error.
// Otherwise the error from closing the handle is returned.
func (c *MSSQLConnector) Close(conn interface{}) error {
	mssqlConn, ok := conn.(*MSSQLConnection)
	if !ok || mssqlConn.db == nil {
		return nil
	}
	return mssqlConn.db.Close()
}
// createConnection opens a *sql.DB to host:port with the given credentials,
// verifies it with a ping bounded by both ctx and timeout, and returns it
// together with the short server description from getDatabaseInfo.
//
// When common.Socks5Proxy is set, the driver connector is built explicitly so
// its dialer can be replaced by MSSQLProxyDialer. If the ping fails, the
// *sql.DB is closed and nil is returned with the driver error.
func (c *MSSQLConnector) createConnection(ctx context.Context, host, port, username, password string, timeout time.Duration) (*sql.DB, string, error) {
	connStr := buildMSSQLConnString(host, port, username, password, timeout)

	var db *sql.DB
	if common.Socks5Proxy != "" {
		connector, err := mssqlDriver.NewConnector(connStr)
		if err != nil {
			return nil, "", err
		}
		connector.Dialer = &MSSQLProxyDialer{
			timeout: timeout,
		}
		db = sql.OpenDB(connector)
	} else {
		var err error
		db, err = sql.Open("mssql", connStr)
		if err != nil {
			return nil, "", err
		}
	}

	db.SetConnMaxLifetime(timeout)
	db.SetConnMaxIdleTime(timeout)
	// Keep one idle connection so the version query in getDatabaseInfo reuses
	// the connection PingContext just authenticated. With SetMaxIdleConns(0),
	// database/sql discards it after the ping, and the query would have to
	// dial and log in a second time.
	db.SetMaxIdleConns(1)
	db.SetMaxOpenConns(1)

	pingCtx, pingCancel := context.WithTimeout(ctx, timeout)
	defer pingCancel()

	if err := db.PingContext(pingCtx); err != nil {
		// Best effort cleanup; the ping error is what the caller needs.
		_ = db.Close()
		return nil, "", err
	}

	return db, c.getDatabaseInfo(db), nil
}

// buildMSSQLConnString renders a go-mssqldb connection string in URL form
// (sqlserver://user:password@host:port?...). Every component is
// percent-encoded, so credentials containing ';', '@', '%' or leading/trailing
// spaces reach the server unchanged. The semicolon-delimited ADO form
// previously used here cannot carry a ';' inside a value, and a crafted
// password could smuggle in extra parameters.
// A host of the form `name\instance` is mapped to the URL path, which the
// driver turns back into a named-instance server name.
func buildMSSQLConnString(host, port, username, password string, timeout time.Duration) string {
	instance := ""
	if i := strings.Index(host, `\`); i >= 0 {
		host, instance = host[:i], host[i+1:]
	}

	query := url.Values{}
	query.Set("encrypt", "disable")
	// NOTE(review): "timeout" is carried over unchanged from the original
	// connection string; go-mssqldb documents "connection timeout" and
	// "dial timeout" — confirm which key is intended.
	query.Set("timeout", fmt.Sprintf("%d", int(timeout.Seconds())))

	u := url.URL{
		Scheme:   "sqlserver",
		User:     url.UserPassword(username, password),
		Host:     net.JoinHostPort(host, port),
		RawQuery: query.Encode(),
	}
	if instance != "" {
		u.Path = "/" + instance
	}
	return u.String()
}
// getDatabaseInfo returns a one-line description of the server, taken from
// SELECT @@VERSION.
//
// @@VERSION is multi-line (product/version first, then build date and
// copyright). Only the trimmed first line is kept. If the full string never
// mentions "Microsoft SQL Server" (e.g. an Azure SQL banner), that first line
// is prefixed with it. If the query fails or yields an empty line, the generic
// "Microsoft SQL Server" is returned.
func (c *MSSQLConnector) getDatabaseInfo(db *sql.DB) string {
	const product = "Microsoft SQL Server"

	var version string
	// NOTE(review): this query carries no context deadline; the ping timeout
	// used by the caller does not bound it.
	if err := db.QueryRow("SELECT @@VERSION").Scan(&version); err != nil {
		return product
	}

	firstLine := version
	if i := strings.IndexByte(version, '\n'); i >= 0 {
		firstLine = version[:i]
	}
	// TrimSpace also drops a trailing '\r' from CRLF line endings.
	firstLine = strings.TrimSpace(firstLine)

	switch {
	case firstLine == "":
		return product
	case strings.Contains(version, product):
		return firstLine
	default:
		return product + " - " + firstLine
	}
}
// isMSSQLError reports whether err's message looks like it came from a SQL
// Server endpoint (for example a rejected login). The check is a
// case-insensitive substring match against a fixed marker list. A nil error
// is never treated as an MSSQL error.
func (c *MSSQLConnector) isMSSQLError(err error) bool {
	if err == nil {
		return false
	}

	msg := strings.ToLower(err.Error())

	// NOTE(review): the bare "database" marker matches any message containing
	// that word, which makes "cannot open database" redundant and may
	// over-match; confirm this breadth is intended.
	markers := [...]string{
		"login failed",
		"cannot open database",
		"invalid object",
		"mssql:",
		"sql server",
		"sqlserver:",
		"database",
		"authentication failed",
		"server principal",
		"user does not have permission",
		"the login is from an untrusted domain",
	}

	for i := range markers {
		if strings.Contains(msg, markers[i]) {
			return true
		}
	}
	return false
}