mirror of
https://github.com/shadow1ng/fscan.git
synced 2025-09-14 05:56:46 +08:00
210 lines
4.9 KiB
Go
210 lines
4.9 KiB
Go
package mssql
|
||
|
||
import (
	"context"
	"database/sql"
	"fmt"
	"net"
	"net/url"
	"strconv"
	"strings"
	"time"

	mssqlDriver "github.com/denisenkom/go-mssqldb"
	"github.com/shadow1ng/fscan/common"
	"github.com/shadow1ng/fscan/plugins/base"
)
|
||
|
||
// MSSQLProxyDialer 自定义MSSQL代理拨号器
|
||
type MSSQLProxyDialer struct {
|
||
timeout time.Duration
|
||
}
|
||
|
||
// DialContext 实现mssql.Dialer接口,支持socks代理
|
||
func (d *MSSQLProxyDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||
return common.WrapperTcpWithContext(ctx, network, addr)
|
||
}
|
||
|
||
// MSSQLConnection wraps the state kept between Connect, Authenticate and Close.
type MSSQLConnection struct {
	// db is the open database handle; it is nil when Connect only detected
	// the service without a successful login.
	db *sql.DB
	// target is the "host:port" string built by Connect; info is the
	// human-readable server/version description.
	target, info string
}
|
||
|
||
// MSSQLConnector implements the connector used by the MSSQL plugin. It is
// stateless; all per-target state lives in MSSQLConnection.
type MSSQLConnector struct{}

// NewMSSQLConnector returns a ready-to-use MSSQLConnector.
func NewMSSQLConnector() *MSSQLConnector {
	return new(MSSQLConnector)
}
|
||
|
||
// Connect 连接到MSSQL服务器(不进行认证)
|
||
func (c *MSSQLConnector) Connect(ctx context.Context, info *common.HostInfo) (interface{}, error) {
|
||
target := fmt.Sprintf("%s:%s", info.Host, info.Ports)
|
||
timeout := time.Duration(common.Timeout) * time.Second
|
||
|
||
// 尝试建立连接但不进行认证,使用空凭据进行连接尝试
|
||
db, dbInfo, err := c.createConnection(ctx, info.Host, info.Ports, "", "", timeout)
|
||
if err != nil {
|
||
// 检查是否是MSSQL服务相关错误
|
||
if c.isMSSQLError(err) {
|
||
// 即使连接失败,但可以识别为MSSQL服务
|
||
return &MSSQLConnection{
|
||
db: nil,
|
||
target: target,
|
||
info: "Microsoft SQL Server (Service Detected)",
|
||
}, nil
|
||
}
|
||
return nil, err
|
||
}
|
||
|
||
return &MSSQLConnection{
|
||
db: db,
|
||
target: target,
|
||
info: dbInfo,
|
||
}, nil
|
||
}
|
||
|
||
// Authenticate 使用凭据进行认证
|
||
func (c *MSSQLConnector) Authenticate(ctx context.Context, conn interface{}, cred *base.Credential) error {
|
||
mssqlConn, ok := conn.(*MSSQLConnection)
|
||
if !ok {
|
||
return fmt.Errorf("invalid connection type")
|
||
}
|
||
|
||
// 解析目标地址
|
||
parts := strings.Split(mssqlConn.target, ":")
|
||
if len(parts) != 2 {
|
||
return fmt.Errorf("invalid target format")
|
||
}
|
||
|
||
host := parts[0]
|
||
port := parts[1]
|
||
timeout := time.Duration(common.Timeout) * time.Second
|
||
|
||
// 使用提供的凭据创建新连接
|
||
db, info, err := c.createConnection(ctx, host, port, cred.Username, cred.Password, timeout)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 更新连接信息
|
||
if mssqlConn.db != nil {
|
||
mssqlConn.db.Close()
|
||
}
|
||
mssqlConn.db = db
|
||
mssqlConn.info = info
|
||
|
||
return nil
|
||
}
|
||
|
||
// Close 关闭连接
|
||
func (c *MSSQLConnector) Close(conn interface{}) error {
|
||
if mssqlConn, ok := conn.(*MSSQLConnection); ok && mssqlConn.db != nil {
|
||
return mssqlConn.db.Close()
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// createConnection 创建MSSQL数据库连接
|
||
func (c *MSSQLConnector) createConnection(ctx context.Context, host, port, username, password string, timeout time.Duration) (*sql.DB, string, error) {
|
||
// 构造连接字符串
|
||
connStr := fmt.Sprintf(
|
||
"server=%s;user id=%s;password=%s;port=%s;encrypt=disable;timeout=%d",
|
||
host, username, password, port, int(timeout.Seconds()),
|
||
)
|
||
|
||
var db *sql.DB
|
||
var err error
|
||
|
||
// 检查是否需要使用socks代理
|
||
if common.Socks5Proxy != "" {
|
||
connector, connErr := mssqlDriver.NewConnector(connStr)
|
||
if connErr != nil {
|
||
return nil, "", connErr
|
||
}
|
||
|
||
connector.Dialer = &MSSQLProxyDialer{
|
||
timeout: timeout,
|
||
}
|
||
|
||
db = sql.OpenDB(connector)
|
||
} else {
|
||
db, err = sql.Open("mssql", connStr)
|
||
if err != nil {
|
||
return nil, "", err
|
||
}
|
||
}
|
||
|
||
// 设置连接参数
|
||
db.SetConnMaxLifetime(timeout)
|
||
db.SetConnMaxIdleTime(timeout)
|
||
db.SetMaxIdleConns(0)
|
||
db.SetMaxOpenConns(1)
|
||
|
||
// 创建ping上下文
|
||
pingCtx, pingCancel := context.WithTimeout(ctx, timeout)
|
||
defer pingCancel()
|
||
|
||
// 执行ping测试连接
|
||
err = db.PingContext(pingCtx)
|
||
if err != nil {
|
||
db.Close()
|
||
return nil, "", err
|
||
}
|
||
|
||
// 获取数据库信息
|
||
info := c.getDatabaseInfo(db)
|
||
return db, info, nil
|
||
}
|
||
|
||
// getDatabaseInfo 获取数据库版本信息
|
||
func (c *MSSQLConnector) getDatabaseInfo(db *sql.DB) string {
|
||
query := "SELECT @@VERSION"
|
||
var version string
|
||
|
||
err := db.QueryRow(query).Scan(&version)
|
||
if err != nil {
|
||
return "Microsoft SQL Server"
|
||
}
|
||
|
||
// 提取版本信息的关键部分
|
||
if strings.Contains(version, "Microsoft SQL Server") {
|
||
lines := strings.Split(version, "\n")
|
||
if len(lines) > 0 {
|
||
return strings.TrimSpace(lines[0])
|
||
}
|
||
}
|
||
|
||
return fmt.Sprintf("Microsoft SQL Server - %s", version)
|
||
}
|
||
|
||
// isMSSQLError 检查是否是MSSQL相关错误
|
||
func (c *MSSQLConnector) isMSSQLError(err error) bool {
|
||
if err == nil {
|
||
return false
|
||
}
|
||
|
||
errorStr := strings.ToLower(err.Error())
|
||
mssqlErrorIndicators := []string{
|
||
"login failed",
|
||
"cannot open database",
|
||
"invalid object",
|
||
"mssql:",
|
||
"sql server",
|
||
"sqlserver:",
|
||
"database",
|
||
"authentication failed",
|
||
"server principal",
|
||
"user does not have permission",
|
||
"the login is from an untrusted domain",
|
||
}
|
||
|
||
for _, indicator := range mssqlErrorIndicators {
|
||
if strings.Contains(errorStr, indicator) {
|
||
return true
|
||
}
|
||
}
|
||
|
||
return false
|
||
} |