package mssql

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"net"
	"net/url"
	"strings"
	"time"

	mssqlDriver "github.com/denisenkom/go-mssqldb"
	"github.com/shadow1ng/fscan/common"
	"github.com/shadow1ng/fscan/plugins/base"
)

// MSSQLProxyDialer is a custom dialer handed to the MSSQL driver so that
// connections are opened through common.WrapperTcpWithContext (SOCKS proxy
// support).
type MSSQLProxyDialer struct {
	// timeout is recorded when the dialer is built but is not currently
	// applied by DialContext; the dial is bounded only by the caller's ctx.
	// NOTE(review): applying it would need a derived context whose cancel
	// must not tear down the returned conn — confirm WrapperTcpWithContext
	// only uses ctx for the dial before wiring this in.
	timeout time.Duration
}

// DialContext implements the mssql.Dialer interface by delegating to
// common.WrapperTcpWithContext.
func (d *MSSQLProxyDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	return common.WrapperTcpWithContext(ctx, network, addr)
}

// MSSQLConnection wraps an MSSQL database handle together with the target it
// refers to.
type MSSQLConnection struct {
	db     *sql.DB // nil when the service was only detected, not logged into
	target string  // "host:port" as built by Connect
	info   string  // human-readable server description
}

// MSSQLConnector implements connecting to and authenticating against MSSQL
// servers.
type MSSQLConnector struct{}

// NewMSSQLConnector creates an MSSQL connector.
func NewMSSQLConnector() *MSSQLConnector {
	return &MSSQLConnector{}
}

// Connect probes the MSSQL server at info.Host:info.Ports using empty
// credentials (no real authentication).
//
// On success the returned *MSSQLConnection holds the open handle. If the
// attempt fails but the error looks like it came from an MSSQL server (see
// isMSSQLError), a connection with a nil db and a generic description is
// returned so the service is still reported as detected. Any other error is
// returned unchanged.
func (c *MSSQLConnector) Connect(ctx context.Context, info *common.HostInfo) (interface{}, error) {
	target := fmt.Sprintf("%s:%s", info.Host, info.Ports)
	timeout := time.Duration(common.Timeout) * time.Second

	db, dbInfo, err := c.createConnection(ctx, info.Host, info.Ports, "", "", timeout)
	if err != nil {
		if c.isMSSQLError(err) {
			// Login failed, but the server answered like MSSQL.
			return &MSSQLConnection{
				db:     nil,
				target: target,
				info:   "Microsoft SQL Server (Service Detected)",
			}, nil
		}
		return nil, err
	}

	return &MSSQLConnection{
		db:     db,
		target: target,
		info:   dbInfo,
	}, nil
}

// Authenticate opens a new session to the connection's target using cred.
//
// On success the previous handle (if any) is closed and replaced, and the
// server description is refreshed. On failure the connection is left
// untouched and the driver error is returned.
func (c *MSSQLConnector) Authenticate(ctx context.Context, conn interface{}, cred *base.Credential) error {
	mssqlConn, ok := conn.(*MSSQLConnection)
	if !ok || mssqlConn == nil {
		return fmt.Errorf("invalid connection type")
	}
	if cred == nil {
		return fmt.Errorf("nil credential")
	}

	host, port, err := c.splitTarget(mssqlConn.target)
	if err != nil {
		return err
	}

	timeout := time.Duration(common.Timeout) * time.Second

	db, info, err := c.createConnection(ctx, host, port, cred.Username, cred.Password, timeout)
	if err != nil {
		return err
	}

	if mssqlConn.db != nil {
		_ = mssqlConn.db.Close() // best-effort: the old handle is being replaced
	}
	mssqlConn.db = db
	mssqlConn.info = info
	return nil
}

// Close closes the underlying database handle, if any.
func (c *MSSQLConnector) Close(conn interface{}) error {
	if mssqlConn, ok := conn.(*MSSQLConnection); ok && mssqlConn != nil && mssqlConn.db != nil {
		return mssqlConn.db.Close()
	}
	return nil
}

// splitTarget splits a "host:port" target at its LAST colon, so bare IPv6
// hosts (e.g. "::1:1433") are handled; surrounding brackets on the host are
// removed. It fails only when the target contains no colon at all.
func (c *MSSQLConnector) splitTarget(target string) (string, string, error) {
	idx := strings.LastIndex(target, ":")
	if idx < 0 {
		return "", "", fmt.Errorf("invalid target format")
	}
	host := strings.TrimSuffix(strings.TrimPrefix(target[:idx], "["), "]")
	return host, target[idx+1:], nil
}

// buildConnString returns a sqlserver:// URL DSN for the endpoint and
// credentials.
//
// The URL form is used rather than the ADO "key=value;" form because ADO
// values cannot contain ';' and have surrounding whitespace trimmed, which
// would silently corrupt passwords taken from a brute-force dictionary.
// url.UserPassword percent-escapes arbitrary characters instead.
func (c *MSSQLConnector) buildConnString(host, port, username, password string, timeout time.Duration) string {
	params := url.Values{}
	params.Set("encrypt", "disable")
	params.Set("timeout", fmt.Sprintf("%d", int(timeout.Seconds())))

	dsn := url.URL{
		Scheme:   "sqlserver",
		User:     url.UserPassword(username, password),
		Host:     net.JoinHostPort(host, port),
		RawQuery: params.Encode(),
	}
	return dsn.String()
}

// createConnection opens a handle to host:port, verifies it with a ping, and
// reads the server version string.
//
// When common.Socks5Proxy is set, the driver dials through MSSQLProxyDialer.
// Both the ping and the version query are bounded by timeout (and by ctx).
// On any failure the handle is closed and the error is returned.
func (c *MSSQLConnector) createConnection(ctx context.Context, host, port, username, password string, timeout time.Duration) (*sql.DB, string, error) {
	connStr := c.buildConnString(host, port, username, password, timeout)

	var db *sql.DB
	if common.Socks5Proxy != "" {
		connector, err := mssqlDriver.NewConnector(connStr)
		if err != nil {
			return nil, "", err
		}
		connector.Dialer = &MSSQLProxyDialer{
			timeout: timeout,
		}
		db = sql.OpenDB(connector)
	} else {
		var err error
		db, err = sql.Open("mssql", connStr)
		if err != nil {
			return nil, "", err
		}
	}

	// Keep the pool minimal and short-lived: each handle is a single probe.
	db.SetConnMaxLifetime(timeout)
	db.SetConnMaxIdleTime(timeout)
	db.SetMaxIdleConns(0)
	db.SetMaxOpenConns(1)

	pingCtx, pingCancel := context.WithTimeout(ctx, timeout)
	defer pingCancel()
	if err := db.PingContext(pingCtx); err != nil {
		_ = db.Close() // the ping error is the one worth reporting
		return nil, "", err
	}

	// Separate budget for the version query so a slow ping does not starve it.
	queryCtx, queryCancel := context.WithTimeout(ctx, timeout)
	defer queryCancel()
	info := c.getDatabaseInfoContext(queryCtx, db)

	return db, info, nil
}

// getDatabaseInfo returns a short server description without a deadline.
// Prefer getDatabaseInfoContext; this wrapper keeps the original signature.
func (c *MSSQLConnector) getDatabaseInfo(db *sql.DB) string {
	return c.getDatabaseInfoContext(context.Background(), db)
}

// getDatabaseInfoContext queries SELECT @@VERSION under ctx and returns a
// short server description.
//
// If the query fails, it returns the generic "Microsoft SQL Server". If the
// banner contains "Microsoft SQL Server", it returns only the banner's first
// line, trimmed. Otherwise it returns the banner prefixed with
// "Microsoft SQL Server - ".
func (c *MSSQLConnector) getDatabaseInfoContext(ctx context.Context, db *sql.DB) string {
	var version string
	if err := db.QueryRowContext(ctx, "SELECT @@VERSION").Scan(&version); err != nil {
		return "Microsoft SQL Server"
	}

	if strings.Contains(version, "Microsoft SQL Server") {
		if i := strings.IndexByte(version, '\n'); i >= 0 {
			version = version[:i]
		}
		return strings.TrimSpace(version)
	}

	return fmt.Sprintf("Microsoft SQL Server - %s", version)
}

// isMSSQLError reports whether err looks like it was produced by an MSSQL
// server, i.e. the port speaks MSSQL even though the login did not succeed.
// Classification is a case-insensitive substring match on the error text.
func (c *MSSQLConnector) isMSSQLError(err error) bool {
	if err == nil {
		return false
	}

	// Transport-level failures (DNS lookup, refused, unreachable, timeouts)
	// say nothing about the protocol spoken on the port. Their messages can
	// also embed the target host name, so a host named e.g. "database01" or
	// "sqlserver" would otherwise match the substring checks below.
	var netErr net.Error
	if errors.As(err, &netErr) {
		return false
	}

	errorStr := strings.ToLower(err.Error())

	// NOTE: go-mssqldb reports a failed TCP dial with this prefix and
	// flattens the underlying net error to text, so errors.As above cannot
	// see it (message is driver-version dependent).
	if strings.HasPrefix(errorStr, "unable to open tcp connection") {
		return false
	}

	mssqlErrorIndicators := []string{
		"login failed",
		"cannot open database",
		"invalid object",
		"mssql:",
		"sql server",
		"sqlserver:",
		"database",
		"authentication failed",
		"server principal",
		"user does not have permission",
		"the login is from an untrusted domain",
	}

	for _, indicator := range mssqlErrorIndicators {
		if strings.Contains(errorStr, indicator) {
			return true
		}
	}
	return false
}