fscan/Common/proxy/TLSDialer.go
ZacharyZcR 84b0bb1e28 refactor: 完成common包常量提取和代码重构优化
- 新增constants.go文件统一管理各包常量定义
- 提取logging、output、parsers、proxy包中的硬编码值
- 将30+个魔法数字替换为语义化常量
- 统一错误代码和消息格式
- 清理死代码和未使用变量
- 优化代码可维护性和可读性
- 保持完全向后兼容性

涉及包:
- common/logging: 日志级别和格式常量
- common/output: 输出配置和格式常量
- common/parsers: 解析器配置和验证常量
- common/proxy: 代理协议和错误常量
2025-08-06 21:29:30 +08:00

157 lines
3.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package proxy
import (
"context"
"crypto/tls"
"net"
"sync/atomic"
"time"
)
// tlsDialerWrapper wraps a Dialer so it can also produce TLS connections.
// Plain dials (Dial, DialContext) are delegated to the wrapped dialer unchanged.
// For TLS dials, the wrapper opens the raw connection through the wrapped
// dialer, performs the TLS client handshake itself, and records failure and
// connect-time statistics.
type tlsDialerWrapper struct {
	dialer Dialer       // underlying dialer used to open the raw connection
	config *ProxyConfig // proxy settings; Timeout bounds the TLS handshake when the ctx has no deadline
	stats  *ProxyStats  // statistics updated by TLS dials and handed on to the returned connections
}
// Dial opens a plain (non-TLS) connection by handing the request straight to
// the wrapped dialer; no statistics are recorded here.
func (t *tlsDialerWrapper) Dial(network, address string) (net.Conn, error) {
	conn, err := t.dialer.Dial(network, address)
	return conn, err
}
// DialContext opens a plain (non-TLS) connection through the wrapped dialer,
// passing ctx along unchanged; no statistics are recorded here.
func (t *tlsDialerWrapper) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	inner := t.dialer
	return inner.DialContext(ctx, network, address)
}
// DialTLS is DialTLSContext with a background context. The TLS handshake is
// therefore bounded only by the wrapper's configured timeout, not by any
// caller-supplied deadline.
func (t *tlsDialerWrapper) DialTLS(network, address string, config *tls.Config) (net.Conn, error) {
	bg := context.Background()
	return t.DialTLSContext(bg, network, address, config)
}
// DialTLSContext opens a connection to address through the wrapped dialer and
// performs a TLS client handshake over it.
//
// Server name: if tlsConfig is nil or leaves ServerName empty, a clone is used
// whose ServerName is the host part of address. This is the same defaulting
// that crypto/tls.Dial applies. It means SNI is sent and certificate
// verification has a name to check, instead of the handshake failing for lack
// of ServerName. The caller's config is never modified.
//
// Timeout: the handshake is bounded by ctx's deadline when it has one.
// Otherwise it is bounded by t.config.Timeout when that is positive; a zero
// Timeout means no handshake deadline, rather than one that has already
// expired. The deadline is cleared afterwards so upper layers manage I/O
// timeouts themselves.
// NOTE: cancelling a ctx that has no deadline does not interrupt an
// in-progress handshake.
//
// Errors: if the raw dial fails, the error is built with
// ErrCodeTLSTCPConnFailed. If the handshake fails, the raw connection is
// closed, stats.FailedConnections is incremented, stats.LastError is recorded,
// and the error is built with ErrCodeTLSHandshakeFailed.
//
// On success the result is a *trackedTLSConn, and the elapsed time is folded
// into the connect-time average.
func (t *tlsDialerWrapper) DialTLSContext(ctx context.Context, network, address string, tlsConfig *tls.Config) (net.Conn, error) {
	start := time.Now()

	// Establish the raw TCP connection first.
	tcpConn, err := t.dialer.DialContext(ctx, network, address)
	if err != nil {
		return nil, NewProxyError(ErrTypeConnection, ErrMsgTLSTCPConnFailed, ErrCodeTLSTCPConnFailed, err)
	}

	// Default the server name from address, mirroring crypto/tls.Dial.
	cfg := tlsConfig
	if cfg == nil {
		cfg = &tls.Config{}
	}
	if cfg.ServerName == "" {
		host, _, splitErr := net.SplitHostPort(address)
		if splitErr != nil {
			// No port in address: use it verbatim as the host name.
			host = address
		}
		cfg = cfg.Clone()
		cfg.ServerName = host
	}

	tlsConn := tls.Client(tcpConn, cfg)

	// Bound the handshake. SetDeadline errors are deliberately ignored:
	// losing the timeout is not fatal, and a broken connection will surface
	// from Handshake anyway.
	if deadline, ok := ctx.Deadline(); ok {
		_ = tlsConn.SetDeadline(deadline)
	} else if t.config != nil && t.config.Timeout > 0 {
		_ = tlsConn.SetDeadline(time.Now().Add(t.config.Timeout))
	}

	if err := tlsConn.Handshake(); err != nil {
		// Best-effort cleanup; the handshake error is what gets reported.
		_ = tcpConn.Close()
		atomic.AddInt64(&t.stats.FailedConnections, 1)
		t.stats.LastError = err.Error()
		return nil, NewProxyError(ErrTypeConnection, ErrMsgTLSHandshakeFailed, ErrCodeTLSHandshakeFailed, err)
	}

	// Clear the handshake deadline so upper layers manage I/O timeouts.
	_ = tlsConn.SetDeadline(time.Time{})

	t.updateAverageConnectTime(time.Since(start))

	return &trackedTLSConn{
		trackedConn: &trackedConn{
			Conn:  tlsConn,
			stats: t.stats,
		},
		isTLS: true,
	}, nil
}
// updateAverageConnectTime folds one connect duration into
// t.stats.AverageConnectTime.
//
// This is an exponential moving average with weight 1/2, not a simple
// arithmetic mean: each new sample halves the influence of all earlier ones.
// A stored value of zero is treated as "no samples yet", so the sample is
// taken as-is.
//
// NOTE(review): unlike the counter updates elsewhere in this file, this
// read-modify-write is not atomic; concurrent dials may race on it.
func (t *tlsDialerWrapper) updateAverageConnectTime(duration time.Duration) {
	previous := t.stats.AverageConnectTime
	if previous != 0 {
		duration = (previous + duration) / 2
	}
	t.stats.AverageConnectTime = duration
}
// trackedConn wraps a net.Conn and counts the bytes read from and written to
// it. Closing it decrements stats.ActiveConnections exactly once, however many
// times Close is called.
type trackedConn struct {
	// The int64 counters come first. On 32-bit platforms (386, ARM, 32-bit
	// MIPS), sync/atomic requires 64-bit alignment, and only the first word of
	// an allocated struct is guaranteed to have it. Placed after the embedded
	// interface (two words) and the stats pointer, they would sit at offset 12
	// and make atomic.AddInt64 panic.
	bytesSent int64 // bytes written via Write; updated atomically
	bytesRecv int64 // bytes read via Read; updated atomically

	net.Conn
	stats  *ProxyStats
	closed int32 // set to 1 by the first Close; accessed atomically
}

// Read reads from the underlying connection and adds the number of bytes
// actually read (even on error) to bytesRecv.
func (tc *trackedConn) Read(b []byte) (n int, err error) {
	n, err = tc.Conn.Read(b)
	if n > 0 {
		atomic.AddInt64(&tc.bytesRecv, int64(n))
	}
	return n, err
}

// Write writes to the underlying connection and adds the number of bytes
// actually written (even on error) to bytesSent.
func (tc *trackedConn) Write(b []byte) (n int, err error) {
	n, err = tc.Conn.Write(b)
	if n > 0 {
		atomic.AddInt64(&tc.bytesSent, int64(n))
	}
	return n, err
}

// Close closes the underlying connection. Only the first call decrements
// stats.ActiveConnections, so a deferred Close combined with an explicit one
// cannot drive the gauge negative. Every call is still forwarded to the
// underlying Close, so its error semantics are unchanged.
func (tc *trackedConn) Close() error {
	if atomic.CompareAndSwapInt32(&tc.closed, 0, 1) && tc.stats != nil {
		atomic.AddInt64(&tc.stats.ActiveConnections, -1)
	}
	return tc.Conn.Close()
}
// trackedTLSConn is a trackedConn (byte counting plus active-connection
// accounting) whose embedded connection is expected to be a *tls.Conn. Its
// TLS-specific accessors type-assert the embedded connection to *tls.Conn
// before delegating.
type trackedTLSConn struct {
	*trackedConn
	isTLS bool // set to true by DialTLSContext; NOTE(review): not read anywhere in this file
}
// ConnectionState reports the TLS state of the underlying connection. It
// returns the zero tls.ConnectionState when that connection is not a *tls.Conn.
func (ttc *trackedTLSConn) ConnectionState() tls.ConnectionState {
	switch c := ttc.Conn.(type) {
	case *tls.Conn:
		return c.ConnectionState()
	default:
		return tls.ConnectionState{}
	}
}
// Handshake delegates to (*tls.Conn).Handshake on the underlying connection,
// which is a no-op once the handshake has completed. A connection that is not
// TLS has nothing to negotiate, so nil is returned.
func (ttc *trackedTLSConn) Handshake() error {
	tlsConn, isTLS := ttc.Conn.(*tls.Conn)
	if !isTLS {
		return nil
	}
	return tlsConn.Handshake()
}
// OCSPResponse returns the stapled OCSP response received from the server, if
// any. It returns nil when the underlying connection is not a *tls.Conn.
func (ttc *trackedTLSConn) OCSPResponse() []byte {
	var resp []byte
	if c, ok := ttc.Conn.(*tls.Conn); ok {
		resp = c.OCSPResponse()
	}
	return resp
}
// PeerCertificates returns the certificates presented by the peer. Each one is
// wrapped in a tls.Certificate that holds only its raw DER bytes and no
// private key. The result is nil for a non-TLS connection or when the peer
// presented no certificates.
func (ttc *trackedTLSConn) PeerCertificates() []*tls.Certificate {
	tlsConn, ok := ttc.Conn.(*tls.Conn)
	if !ok {
		return nil
	}

	peers := tlsConn.ConnectionState().PeerCertificates
	var wrapped []*tls.Certificate
	for i := range peers {
		der := peers[i].Raw
		wrapped = append(wrapped, &tls.Certificate{Certificate: [][]byte{der}})
	}
	return wrapped
}
// notTLSConnError reports that a TLS-only operation was attempted on a
// connection that is not backed by a *tls.Conn.
type notTLSConnError struct{ op string }

func (e notTLSConnError) Error() string {
	return "proxy: " + e.op + " requires a TLS connection"
}

// VerifyHostname checks that the peer's leaf certificate is valid for host by
// delegating to (*tls.Conn).VerifyHostname. That call also errors if the
// handshake has not been performed yet.
//
// A non-TLS connection cannot vouch for any hostname. It therefore yields an
// error instead of nil, so callers never treat an unverified connection as
// verified (fail closed).
func (ttc *trackedTLSConn) VerifyHostname(host string) error {
	tlsConn, ok := ttc.Conn.(*tls.Conn)
	if !ok {
		return notTLSConnError{op: "VerifyHostname"}
	}
	return tlsConn.VerifyHostname(host)
}