refactor: 重构main.go入口文件并引入依赖注入架构

- 创建app包实现依赖注入容器和初始化器模式
- 重构main.go为六阶段清晰的初始化流程
- 新增结构化错误处理替代简陋的os.Exit调用
- 为HostInfo添加辅助函数增强功能但保持向后兼容
- 引入TargetInfo包装器支持上下文和元数据管理
- 优化代码组织提升可维护性和可测试性
This commit is contained in:
ZacharyZcR 2025-08-12 14:37:28 +08:00
parent 338dd60c3e
commit b463984e78
7 changed files with 515 additions and 72 deletions

View File

@ -30,12 +30,8 @@ type ScanPlugin = base.ScanPlugin
// 插件类型常量
const (
PluginTypeService = base.PluginTypeService
PluginTypeWeb = base.PluginTypeWeb
PluginTypeLocal = base.PluginTypeLocal
PluginTypeBrute = base.PluginTypeBrute
PluginTypePoc = base.PluginTypePoc
PluginTypeScan = base.PluginTypeScan
)
// 全局插件管理器

109
Common/hostinfo_ext.go Normal file
View File

@ -0,0 +1,109 @@
package common
import (
"fmt"
"strconv"
)
// HostInfoHelper 提供HostInfo的辅助方法
// 使用函数而不是方法,保持向后兼容
// GetHost 获取主机地址
func GetHost(h *HostInfo) string {
return h.Host
}
// GetPort 获取端口号(转换为整数)
func GetPort(h *HostInfo) (int, error) {
if h.Ports == "" {
return 0, fmt.Errorf("端口未设置")
}
return strconv.Atoi(h.Ports)
}
// GetURL 获取URL地址
func GetURL(h *HostInfo) string {
return h.Url
}
// IsWebTarget 判断是否为Web目标
func IsWebTarget(h *HostInfo) bool {
return h.Url != ""
}
// HasPort 检查是否设置了端口
func HasPort(h *HostInfo) bool {
return h.Ports != ""
}
// HasHost 检查是否设置了主机
func HasHost(h *HostInfo) bool {
return h.Host != ""
}
// CloneHostInfo 克隆HostInfo深拷贝
func CloneHostInfo(h *HostInfo) HostInfo {
cloned := HostInfo{
Host: h.Host,
Ports: h.Ports,
Url: h.Url,
}
// 深拷贝Infostr切片
if h.Infostr != nil {
cloned.Infostr = make([]string, len(h.Infostr))
copy(cloned.Infostr, h.Infostr)
}
return cloned
}
// ValidateHostInfo 验证HostInfo的有效性
func ValidateHostInfo(h *HostInfo) error {
if h.Host == "" && h.Url == "" {
return fmt.Errorf("主机地址或URL必须至少指定一个")
}
// 验证端口格式(如果指定了)
if h.Ports != "" {
if _, err := GetPort(h); err != nil {
return fmt.Errorf("端口格式无效: %v", err)
}
}
return nil
}
// HostInfoString 返回HostInfo的字符串表示
func HostInfoString(h *HostInfo) string {
if IsWebTarget(h) {
return h.Url
}
if HasPort(h) {
return fmt.Sprintf("%s:%s", h.Host, h.Ports)
}
return h.Host
}
// AddInfo 添加附加信息
func AddInfo(h *HostInfo, info string) {
if h.Infostr == nil {
h.Infostr = make([]string, 0)
}
h.Infostr = append(h.Infostr, info)
}
// GetInfo 获取所有附加信息
func GetInfo(h *HostInfo) []string {
if h.Infostr == nil {
return []string{}
}
return h.Infostr
}
// HasInfo 检查是否有附加信息
func HasInfo(h *HostInfo) bool {
return len(h.Infostr) > 0
}

118
Common/target.go Normal file
View File

@ -0,0 +1,118 @@
package common
import (
"context"
)
// TargetInfo 包装HostInfo提供更丰富的功能
type TargetInfo struct {
*HostInfo // 嵌入HostInfo保持向后兼容
context context.Context
metadata map[string]interface{}
}
// NewTargetInfo 创建新的目标信息
func NewTargetInfo(hostInfo HostInfo) *TargetInfo {
return &TargetInfo{
HostInfo: &hostInfo,
context: context.Background(),
metadata: make(map[string]interface{}),
}
}
// NewTargetInfoFromPtr 从HostInfo指针创建目标信息
func NewTargetInfoFromPtr(hostInfo *HostInfo) *TargetInfo {
return &TargetInfo{
HostInfo: hostInfo,
context: context.Background(),
metadata: make(map[string]interface{}),
}
}
// WithContext 设置上下文
func (t *TargetInfo) WithContext(ctx context.Context) *TargetInfo {
t.context = ctx
return t
}
// GetContext 获取上下文
func (t *TargetInfo) GetContext() context.Context {
if t.context == nil {
t.context = context.Background()
}
return t.context
}
// SetMetadata 设置元数据
func (t *TargetInfo) SetMetadata(key string, value interface{}) *TargetInfo {
if t.metadata == nil {
t.metadata = make(map[string]interface{})
}
t.metadata[key] = value
return t
}
// GetMetadata 获取元数据
func (t *TargetInfo) GetMetadata(key string) (interface{}, bool) {
if t.metadata == nil {
return nil, false
}
value, exists := t.metadata[key]
return value, exists
}
// GetAllMetadata 获取所有元数据
func (t *TargetInfo) GetAllMetadata() map[string]interface{} {
if t.metadata == nil {
return make(map[string]interface{})
}
// 返回副本,防止外部修改
result := make(map[string]interface{})
for k, v := range t.metadata {
result[k] = v
}
return result
}
// Clone 克隆目标信息
func (t *TargetInfo) Clone() *TargetInfo {
clonedHost := CloneHostInfo(t.HostInfo)
cloned := &TargetInfo{
HostInfo: &clonedHost,
context: t.context,
metadata: make(map[string]interface{}),
}
// 复制元数据
for k, v := range t.metadata {
cloned.metadata[k] = v
}
return cloned
}
// GetHostInfo 获取原始HostInfo向后兼容
func (t *TargetInfo) GetHostInfo() HostInfo {
return *t.HostInfo
}
// Validate 验证目标信息
func (t *TargetInfo) Validate() error {
return ValidateHostInfo(t.HostInfo)
}
// String 返回字符串表示
func (t *TargetInfo) String() string {
return HostInfoString(t.HostInfo)
}
// IsValid 检查目标是否有效
func (t *TargetInfo) IsValid() bool {
return t.Validate() == nil
}
// HasMetadata 检查是否有指定的元数据
func (t *TargetInfo) HasMetadata(key string) bool {
_, exists := t.GetMetadata(key)
return exists
}

107
app/container.go Normal file
View File

@ -0,0 +1,107 @@
package app
import (
"context"
"fmt"
"sync"
"github.com/shadow1ng/fscan/common"
"github.com/shadow1ng/fscan/core"
)
// Container 依赖注入容器
type Container struct {
services map[string]interface{}
initializers []Initializer
mu sync.RWMutex
initialized bool
}
// NewContainer 创建新的容器
func NewContainer() *Container {
container := &Container{
services: make(map[string]interface{}),
}
// 注册默认初始化器
container.AddInitializer(&PluginInitializer{})
return container
}
// AddInitializer 添加初始化器
func (c *Container) AddInitializer(init Initializer) {
c.initializers = append(c.initializers, init)
}
// Register 注册服务
func (c *Container) Register(name string, service interface{}) {
c.mu.Lock()
defer c.mu.Unlock()
c.services[name] = service
}
// Get 获取服务
func (c *Container) Get(name string) (interface{}, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
service, exists := c.services[name]
return service, exists
}
// Initialize 初始化容器和所有服务
func (c *Container) Initialize() error {
if c.initialized {
return nil
}
// 执行所有初始化器
for _, initializer := range c.initializers {
if err := initializer.Initialize(); err != nil {
return WrapError(ErrInitFailed, err)
}
}
c.initialized = true
return nil
}
// RunScan 执行扫描包装现有的core.RunScan
func (c *Container) RunScan(ctx context.Context, info common.HostInfo) error {
// 使用新的验证函数
if err := common.ValidateHostInfo(&info); err != nil {
return WrapError(ErrScanFailed, err)
}
// 创建目标信息(展示新功能,但保持兼容)
target := common.NewTargetInfo(info)
target.WithContext(ctx)
target.SetMetadata("container_managed", true)
target.SetMetadata("validation_passed", true)
// 记录扫描信息
c.logScanInfo(target)
// 调用现有的扫描逻辑
core.RunScan(info)
return nil
}
// logScanInfo 记录扫描信息
func (c *Container) logScanInfo(target *common.TargetInfo) {
targetStr := target.String()
if targetStr != "" {
common.LogDebug(fmt.Sprintf("容器管理的扫描目标: %s", targetStr))
}
if target.HasMetadata("validation_passed") {
common.LogDebug("目标验证通过")
}
}
// Cleanup 清理资源
func (c *Container) Cleanup() {
// 清理输出资源
common.CloseOutput()
}

43
app/errors.go Normal file
View File

@ -0,0 +1,43 @@
package app
import "fmt"
// AppError 应用程序错误类型
type AppError struct {
Code int
Message string
Cause error
}
func (e *AppError) Error() string {
if e.Cause != nil {
return fmt.Sprintf("[%d] %s: %v", e.Code, e.Message, e.Cause)
}
return fmt.Sprintf("[%d] %s", e.Code, e.Message)
}
// 预定义错误类型
var (
ErrInitFailed = &AppError{Code: 1, Message: "初始化失败"}
ErrParseFailed = &AppError{Code: 2, Message: "参数解析失败"}
ErrOutputFailed = &AppError{Code: 3, Message: "输出初始化失败"}
ErrScanFailed = &AppError{Code: 4, Message: "扫描执行失败"}
)
// NewAppError 创建新的应用程序错误
func NewAppError(code int, message string, cause error) *AppError {
return &AppError{
Code: code,
Message: message,
Cause: cause,
}
}
// WrapError 包装错误为应用程序错误
func WrapError(baseErr *AppError, cause error) *AppError {
return &AppError{
Code: baseErr.Code,
Message: baseErr.Message,
Cause: cause,
}
}

66
app/initializer.go Normal file
View File

@ -0,0 +1,66 @@
package app
import (
"sort"
"github.com/shadow1ng/fscan/common"
"github.com/shadow1ng/fscan/plugins/base"
)
// Initializer 初始化器接口
type Initializer interface {
Initialize() error
Name() string
}
// PluginInitializer 插件初始化器
type PluginInitializer struct{}
func (p *PluginInitializer) Name() string {
return "PluginInitializer"
}
func (p *PluginInitializer) Initialize() error {
var localPlugins []string
// 获取所有注册的插件
allPlugins := base.GlobalPluginRegistry.GetAll()
for _, pluginName := range allPlugins {
metadata := base.GlobalPluginRegistry.GetMetadata(pluginName)
if metadata != nil && metadata.Category == "local" {
localPlugins = append(localPlugins, pluginName)
}
}
// 排序以保持一致性
sort.Strings(localPlugins)
// 设置全局变量
common.LocalPluginsList = localPlugins
return nil
}
// LoggerInitializer 日志初始化器
type LoggerInitializer struct{}
func (l *LoggerInitializer) Name() string {
return "LoggerInitializer"
}
func (l *LoggerInitializer) Initialize() error {
common.InitLogger()
return nil
}
// OutputInitializer 输出初始化器
type OutputInitializer struct{}
func (o *OutputInitializer) Name() string {
return "OutputInitializer"
}
func (o *OutputInitializer) Initialize() error {
return common.InitOutput()
}

88
main.go
View File

@ -1,58 +1,62 @@
package main
import (
"context"
"fmt"
"os"
"sort"
"github.com/shadow1ng/fscan/app"
"github.com/shadow1ng/fscan/common"
"github.com/shadow1ng/fscan/core"
"github.com/shadow1ng/fscan/plugins/base"
)
// initLocalPlugins 初始化本地插件列表
func initLocalPlugins() {
var localPlugins []string
// 获取所有注册的插件
allPlugins := base.GlobalPluginRegistry.GetAll()
for _, pluginName := range allPlugins {
metadata := base.GlobalPluginRegistry.GetMetadata(pluginName)
if metadata != nil && metadata.Category == "local" {
localPlugins = append(localPlugins, pluginName)
}
}
// 排序以保持一致性
sort.Strings(localPlugins)
// 设置全局变量
common.LocalPluginsList = localPlugins
}
func main() {
// 初始化本地插件列表
initLocalPlugins()
// 创建应用容器
container := app.NewContainer()
var Info common.HostInfo
common.Flag(&Info)
// 第一阶段:基础初始化(插件系统)
if err := container.Initialize(); err != nil {
handleError("基础初始化失败", err)
}
defer container.Cleanup()
// 在flag解析后初始化logger确保LogLevel参数生效
common.InitLogger()
// 第二阶段:解析配置
var info common.HostInfo
common.Flag(&info)
// 解析 CLI 参数
if err := common.Parse(&Info); err != nil {
// 第三阶段日志初始化依赖于flag解析
logInit := &app.LoggerInitializer{}
if err := logInit.Initialize(); err != nil {
handleError("日志初始化失败", err)
}
// 第四阶段:参数解析和验证
if err := common.Parse(&info); err != nil {
handleError("参数解析失败", err)
}
// 第五阶段:输出系统初始化
outputInit := &app.OutputInitializer{}
if err := outputInit.Initialize(); err != nil {
handleError("输出初始化失败", err)
}
// 第六阶段:执行扫描
ctx := context.Background()
if err := container.RunScan(ctx, info); err != nil {
handleError("扫描失败", err)
}
}
func handleError(msg string, err error) {
// 检查是否是应用程序错误
if appErr, ok := err.(*app.AppError); ok {
common.LogError(fmt.Sprintf("%s: %s", msg, appErr.Message))
if appErr.Cause != nil {
common.LogError(fmt.Sprintf("详细错误: %v", appErr.Cause))
}
os.Exit(appErr.Code)
} else {
common.LogError(fmt.Sprintf("%s: %v", msg, err))
os.Exit(1)
}
// 初始化输出系统,如果失败则直接退出
if err := common.InitOutput(); err != nil {
common.LogError(fmt.Sprintf("初始化输出系统失败: %v", err))
os.Exit(1)
}
defer common.CloseOutput()
// 执行 CLI 扫描逻辑
core.RunScan(Info)
}