From 42745a3da21d10c1ac5cb354c4d68f5567a61879 Mon Sep 17 00:00:00 2001
From: fatedier <fatedier@gmail.com>
Date: Wed, 11 Aug 2021 23:10:35 +0800
Subject: [PATCH] frpc: add disable_custom_tls_first_byte to not send first
 custom tls to frps (#2520)

---
 client/control.go               |  2 +-
 client/service.go               |  2 +-
 conf/frpc_full.ini              |  4 ++++
 pkg/config/client.go            |  3 +++
 pkg/util/net/conn.go            |  4 ++--
 pkg/util/net/tls.go             | 16 +++++++++++++---
 server/service.go               | 12 +++++++-----
 test/e2e/basic/client_server.go | 18 ++++++++++++++++++
 8 files changed, 49 insertions(+), 12 deletions(-)

diff --git a/client/control.go b/client/control.go
index 42cc0464..50cca409 100644
--- a/client/control.go
+++ b/client/control.go
@@ -228,7 +228,7 @@ func (ctl *Control) connectServer() (conn net.Conn, err error) {
 		}
 
 		address := net.JoinHostPort(ctl.clientCfg.ServerAddr, strconv.Itoa(ctl.clientCfg.ServerPort))
-		conn, err = frpNet.ConnectServerByProxyWithTLS(ctl.clientCfg.HTTPProxy, ctl.clientCfg.Protocol, address, tlsConfig)
+		conn, err = frpNet.ConnectServerByProxyWithTLS(ctl.clientCfg.HTTPProxy, ctl.clientCfg.Protocol, address, tlsConfig, ctl.clientCfg.DisableCustomTLSFirstByte)
 
 		if err != nil {
 			xl.Warn("start new connection to server error: %v", err)
diff --git a/client/service.go b/client/service.go
index 3033ee2a..63987574 100644
--- a/client/service.go
+++ b/client/service.go
@@ -232,7 +232,7 @@ func (svr *Service) login() (conn net.Conn, session *fmux.Session, err error) {
 	}
 
 	address := net.JoinHostPort(svr.cfg.ServerAddr, strconv.Itoa(svr.cfg.ServerPort))
-	conn, err = frpNet.ConnectServerByProxyWithTLS(svr.cfg.HTTPProxy, svr.cfg.Protocol, address, tlsConfig)
+	conn, err = frpNet.ConnectServerByProxyWithTLS(svr.cfg.HTTPProxy, svr.cfg.Protocol, address, tlsConfig, svr.cfg.DisableCustomTLSFirstByte)
 	if err != nil {
 		return
 	}
diff --git a/conf/frpc_full.ini b/conf/frpc_full.ini
index b07efc2b..71e5a143 100644
--- a/conf/frpc_full.ini
+++ b/conf/frpc_full.ini
@@ -105,6 +105,10 @@ udp_packet_size = 1500
 # include other config files for proxies.
 # includes = ./confd/*.ini
 
+# By default, frpc will connect frps with first custom byte if tls is enabled.
+# If DisableCustomTLSFirstByte is true, frpc will not send that custom byte.
+disable_custom_tls_first_byte = false
+
 # 'ssh' is the unique proxy name
 # if user in [common] section is not empty, it will be changed to {user}.{proxy} such as 'your_name.ssh'
 [ssh]
diff --git a/pkg/config/client.go b/pkg/config/client.go
index 6cfeb6c0..b2efb79a 100644
--- a/pkg/config/client.go
+++ b/pkg/config/client.go
@@ -124,6 +124,9 @@ type ClientCommonConf struct {
 	// TLSServerName specifices the custom server name of tls certificate. By
 	// default, server name if same to ServerAddr.
 	TLSServerName string `ini:"tls_server_name" json:"tls_server_name"`
+	// By default, frpc will connect frps with first custom byte if tls is enabled.
+	// If DisableCustomTLSFirstByte is true, frpc will not send that custom byte.
+	DisableCustomTLSFirstByte bool `ini:"disable_custom_tls_first_byte" json:"disable_custom_tls_first_byte"`
 	// HeartBeatInterval specifies at what interval heartbeats are sent to the
 	// server, in seconds. It is not recommended to change this value. By
 	// default, this value is 30.
diff --git a/pkg/util/net/conn.go b/pkg/util/net/conn.go
index 5a33dd62..ccb199e5 100644
--- a/pkg/util/net/conn.go
+++ b/pkg/util/net/conn.go
@@ -228,7 +228,7 @@ func ConnectServerByProxy(proxyURL string, protocol string, addr string) (c net.
 	}
 }
 
-func ConnectServerByProxyWithTLS(proxyURL string, protocol string, addr string, tlsConfig *tls.Config) (c net.Conn, err error) {
+func ConnectServerByProxyWithTLS(proxyURL string, protocol string, addr string, tlsConfig *tls.Config, disableCustomTLSHeadByte bool) (c net.Conn, err error) {
 	c, err = ConnectServerByProxy(proxyURL, protocol, addr)
 	if err != nil {
 		return
@@ -238,6 +238,6 @@ func ConnectServerByProxyWithTLS(proxyURL string, protocol string, addr string,
 		return
 	}
 
-	c = WrapTLSClientConn(c, tlsConfig)
+	c = WrapTLSClientConn(c, tlsConfig, disableCustomTLSHeadByte)
 	return
 }
diff --git a/pkg/util/net/tls.go b/pkg/util/net/tls.go
index 52a17787..80a98aaa 100644
--- a/pkg/util/net/tls.go
+++ b/pkg/util/net/tls.go
@@ -27,13 +27,18 @@ var (
 	FRPTLSHeadByte = 0x17
 )
 
-func WrapTLSClientConn(c net.Conn, tlsConfig *tls.Config) (out net.Conn) {
-	c.Write([]byte{byte(FRPTLSHeadByte)})
+func WrapTLSClientConn(c net.Conn, tlsConfig *tls.Config, disableCustomTLSHeadByte bool) (out net.Conn) {
+	if !disableCustomTLSHeadByte {
+		c.Write([]byte{byte(FRPTLSHeadByte)})
+	}
 	out = tls.Client(c, tlsConfig)
 	return
 }
 
-func CheckAndEnableTLSServerConnWithTimeout(c net.Conn, tlsConfig *tls.Config, tlsOnly bool, timeout time.Duration) (out net.Conn, err error) {
+func CheckAndEnableTLSServerConnWithTimeout(
+	c net.Conn, tlsConfig *tls.Config, tlsOnly bool, timeout time.Duration,
+) (out net.Conn, isTLS bool, custom bool, err error) {
+
 	sc, r := gnet.NewSharedConnSize(c, 2)
 	buf := make([]byte, 1)
 	var n int
@@ -46,6 +51,11 @@ func CheckAndEnableTLSServerConnWithTimeout(c net.Conn, tlsConfig *tls.Config, t
 
 	if n == 1 && int(buf[0]) == FRPTLSHeadByte {
 		out = tls.Server(c, tlsConfig)
+		isTLS = true
+		custom = true
+	} else if n == 1 && int(buf[0]) == 0x16 {
+		out = tls.Server(sc, tlsConfig)
+		isTLS = true
 	} else {
 		if tlsOnly {
 			err = fmt.Errorf("non-TLS connection received on a TlsOnly server")
diff --git a/server/service.go b/server/service.go
index 677bda0d..b678010f 100644
--- a/server/service.go
+++ b/server/service.go
@@ -258,8 +258,9 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) {
 	}
 
 	// frp tls listener
-	svr.tlsListener = svr.muxer.Listen(1, 1, func(data []byte) bool {
-		return int(data[0]) == frpNet.FRPTLSHeadByte
+	svr.tlsListener = svr.muxer.Listen(2, 1, func(data []byte) bool {
+		// tls first byte can be 0x16 only when vhost https port is not same with bind port
+		return int(data[0]) == frpNet.FRPTLSHeadByte || int(data[0]) == 0x16
 	})
 
 	// Create nat hole controller.
@@ -395,15 +396,16 @@ func (svr *Service) HandleListener(l net.Listener) {
 
 		log.Trace("start check TLS connection...")
 		originConn := c
-		c, err = frpNet.CheckAndEnableTLSServerConnWithTimeout(c, svr.tlsConfig, svr.cfg.TLSOnly, connReadTimeout)
+		var isTLS, custom bool
+		c, isTLS, custom, err = frpNet.CheckAndEnableTLSServerConnWithTimeout(c, svr.tlsConfig, svr.cfg.TLSOnly, connReadTimeout)
 		if err != nil {
 			log.Warn("CheckAndEnableTLSServerConnWithTimeout error: %v", err)
 			originConn.Close()
 			continue
 		}
-		log.Trace("success check TLS connection")
+		log.Trace("check TLS connection success, isTLS: %v custom: %v", isTLS, custom)
 
-		// Start a new goroutine for dealing connections.
+		// Start a new goroutine to handle connection.
 		go func(ctx context.Context, frpConn net.Conn) {
 			if svr.cfg.TCPMux {
 				fmuxCfg := fmux.DefaultConfig()
diff --git a/test/e2e/basic/client_server.go b/test/e2e/basic/client_server.go
index 56abe96e..67f1efd3 100644
--- a/test/e2e/basic/client_server.go
+++ b/test/e2e/basic/client_server.go
@@ -231,4 +231,22 @@ var _ = Describe("[Feature: Client-Server]", func() {
 			})
 		})
 	})
+
+	Describe("TLS with disable_custom_tls_first_byte", func() {
+		supportProtocols := []string{"tcp", "kcp", "websocket"}
+		for _, protocol := range supportProtocols {
+			tmp := protocol
+			defineClientServerTest("TLS over "+strings.ToUpper(tmp), f, &generalTestConfigures{
+				server: fmt.Sprintf(`
+					kcp_bind_port = {{ .%s }}
+					protocol = %s
+					`, consts.PortServerName, protocol),
+				client: fmt.Sprintf(`
+					tls_enable = true
+					protocol = %s
+					disable_custom_tls_first_byte = true
+					`, protocol),
+			})
+		}
+	})
 })