From 5c8ea51eb574dfcc5efc0bec788d321588d4c01b Mon Sep 17 00:00:00 2001
From: Zeyu Dong <dzy201415@gmail.com>
Date: Mon, 18 Sep 2023 02:28:05 -0400
Subject: [PATCH] return ssl alert unrecognized_name when https domain not
 registered (#3620)

---
 pkg/util/tcpmux/httpconnect.go | 12 +++++++++++-
 pkg/util/vhost/http.go         |  2 +-
 pkg/util/vhost/https.go        |  7 +++++++
 pkg/util/vhost/resource.go     |  2 +-
 pkg/util/vhost/vhost.go        | 14 ++++++++------
 5 files changed, 28 insertions(+), 9 deletions(-)

diff --git a/pkg/util/tcpmux/httpconnect.go b/pkg/util/tcpmux/httpconnect.go
index 650891f4..17989adc 100644
--- a/pkg/util/tcpmux/httpconnect.go
+++ b/pkg/util/tcpmux/httpconnect.go
@@ -40,7 +40,8 @@ func NewHTTPConnectTCPMuxer(listener net.Listener, passthrough bool, timeout tim
 	ret := &HTTPConnectTCPMuxer{passthrough: passthrough}
 	mux, err := vhost.NewMuxer(listener, ret.getHostFromHTTPConnect, timeout)
 	mux.SetCheckAuthFunc(ret.auth).
-		SetSuccessHookFunc(ret.sendConnectResponse)
+		SetSuccessHookFunc(ret.sendConnectResponse).
+		SetFailHookFunc(vhostFailed)
 	ret.Muxer = mux
 	return ret, err
 }
@@ -92,6 +93,15 @@ func (muxer *HTTPConnectTCPMuxer) auth(c net.Conn, username, password string, re
 	return false, nil
 }
 
+func vhostFailed(c net.Conn) {
+	res := vhost.NotFoundResponse()
+	if res.Body != nil {
+		defer res.Body.Close()
+	}
+	_ = res.Write(c)
+	_ = c.Close()
+}
+
 func (muxer *HTTPConnectTCPMuxer) getHostFromHTTPConnect(c net.Conn) (net.Conn, map[string]string, error) {
 	reqInfoMap := make(map[string]string, 0)
 	sc, rd := libnet.NewSharedConn(c)
diff --git a/pkg/util/vhost/http.go b/pkg/util/vhost/http.go
index af3a4ab5..7b914ce9 100644
--- a/pkg/util/vhost/http.go
+++ b/pkg/util/vhost/http.go
@@ -251,7 +251,7 @@ func (rp *HTTPReverseProxy) connectHandler(rw http.ResponseWriter, req *http.Req
 
 	remote, err := rp.CreateConnection(req.Context().Value(RouteInfoKey).(*RequestRouteInfo), false)
 	if err != nil {
-		_ = notFoundResponse().Write(client)
+		_ = NotFoundResponse().Write(client)
 		client.Close()
 		return
 	}
diff --git a/pkg/util/vhost/https.go b/pkg/util/vhost/https.go
index e15c1901..bcfdb81e 100644
--- a/pkg/util/vhost/https.go
+++ b/pkg/util/vhost/https.go
@@ -29,6 +29,7 @@ type HTTPSMuxer struct {
 
 func NewHTTPSMuxer(listener net.Listener, timeout time.Duration) (*HTTPSMuxer, error) {
 	mux, err := NewMuxer(listener, GetHTTPSHostname, timeout)
+	mux.SetFailHookFunc(vhostFailed)
 	if err != nil {
 		return nil, err
 	}
@@ -69,6 +70,12 @@ func readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) {
 	return hello, nil
 }
 
+func vhostFailed(c net.Conn) {
+	// Alert with alertUnrecognizedName
+	_ = tls.Server(c, &tls.Config{}).Handshake()
+	c.Close()
+}
+
 type readOnlyConn struct {
 	reader io.Reader
 }
diff --git a/pkg/util/vhost/resource.go b/pkg/util/vhost/resource.go
index e09edf21..d78082b2 100644
--- a/pkg/util/vhost/resource.go
+++ b/pkg/util/vhost/resource.go
@@ -67,7 +67,7 @@ func getNotFoundPageContent() []byte {
 	return buf
 }
 
-func notFoundResponse() *http.Response {
+func NotFoundResponse() *http.Response {
 	header := make(http.Header)
 	header.Set("server", "frp/"+version.Full())
 	header.Set("Content-Type", "text/html")
diff --git a/pkg/util/vhost/vhost.go b/pkg/util/vhost/vhost.go
index 6051a217..29123b69 100644
--- a/pkg/util/vhost/vhost.go
+++ b/pkg/util/vhost/vhost.go
@@ -46,6 +46,7 @@ type (
 	authFunc        func(conn net.Conn, username, password string, reqInfoMap map[string]string) (bool, error)
 	hostRewriteFunc func(net.Conn, string) (net.Conn, error)
 	successHookFunc func(net.Conn, map[string]string) error
+	failHookFunc    func(net.Conn)
 )
 
 // Muxer is a functional component used for https and tcpmux proxies.
@@ -58,6 +59,7 @@ type Muxer struct {
 	vhostFunc      muxFunc
 	checkAuth      authFunc
 	successHook    successHookFunc
+	failHook       failHookFunc
 	rewriteHost    hostRewriteFunc
 	registryRouter *Routers
 }
@@ -87,6 +89,11 @@ func (v *Muxer) SetSuccessHookFunc(f successHookFunc) *Muxer {
 	return v
 }
 
+func (v *Muxer) SetFailHookFunc(f failHookFunc) *Muxer {
+	v.failHook = f
+	return v
+}
+
 func (v *Muxer) SetRewriteHostFunc(f hostRewriteFunc) *Muxer {
 	v.rewriteHost = f
 	return v
@@ -206,13 +213,8 @@ func (v *Muxer) handle(c net.Conn) {
 	httpUser := reqInfoMap["HTTPUser"]
 	l, ok := v.getListener(name, path, httpUser)
 	if !ok {
-		res := notFoundResponse()
-		if res.Body != nil {
-			defer res.Body.Close()
-		}
-		_ = res.Write(c)
 		log.Debug("http request for host [%s] path [%s] httpUser [%s] not found", name, path, httpUser)
-		_ = c.Close()
+		v.failHook(sConn)
 		return
 	}