From ce366ee17f3a92f6015d57083f053256dd9501a5 Mon Sep 17 00:00:00 2001 From: fatedier Date: Fri, 23 May 2025 21:39:47 +0800 Subject: [PATCH] add proxy protocol support for UDP proxies (#4810) --- README.md | 2 +- Release.md | 3 +- client/proxy/proxy.go | 24 +--- client/proxy/sudp.go | 2 +- client/proxy/udp.go | 4 +- pkg/proto/udp/udp.go | 16 ++- pkg/util/net/proxyprotocol.go | 45 ++++++++ pkg/util/net/proxyprotocol_test.go | 178 +++++++++++++++++++++++++++++ test/e2e/v1/features/real_ip.go | 50 ++++++++ 9 files changed, 299 insertions(+), 25 deletions(-) create mode 100644 pkg/util/net/proxyprotocol.go create mode 100644 pkg/util/net/proxyprotocol_test.go diff --git a/README.md b/README.md index 25537220..f0ab4273 100644 --- a/README.md +++ b/README.md @@ -1025,7 +1025,7 @@ You can get user's real IP from HTTP request headers `X-Forwarded-For`. #### Proxy Protocol -frp supports Proxy Protocol to send user's real IP to local services. It support all types except UDP. +frp supports Proxy Protocol to send user's real IP to local services. Here is an example for https service: diff --git a/Release.md b/Release.md index 07c58d4a..19b79d64 100644 --- a/Release.md +++ b/Release.md @@ -1,3 +1,4 @@ ## Features -* Support for YAML merge functionality (anchors and references with dot-prefixed fields) in strict configuration mode without requiring `--strict-config=false` parameter. \ No newline at end of file +* Support for YAML merge functionality (anchors and references with dot-prefixed fields) in strict configuration mode without requiring `--strict-config=false` parameter. +* Support for proxy protocol in UDP proxies to preserve real client IP addresses. \ No newline at end of file diff --git a/client/proxy/proxy.go b/client/proxy/proxy.go index debda9fa..876ca579 100644 --- a/client/proxy/proxy.go +++ b/client/proxy/proxy.go @@ -20,13 +20,11 @@ import ( "net" "reflect" "strconv" - "strings" "sync" "time" libio "github.com/fatedier/golib/io" libnet "github.com/fatedier/golib/net" - pp "github.com/pires/go-proxyproto" "golang.org/x/time/rate" "github.com/fatedier/frp/pkg/config/types" @@ -35,6 +33,7 @@ import ( plugin "github.com/fatedier/frp/pkg/plugin/client" "github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/util/limit" + netpkg "github.com/fatedier/frp/pkg/util/net" "github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/vnet" ) @@ -176,24 +175,9 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor } if baseCfg.Transport.ProxyProtocolVersion != "" && m.SrcAddr != "" && m.SrcPort != 0 { - h := &pp.Header{ - Command: pp.PROXY, - SourceAddr: connInfo.SrcAddr, - DestinationAddr: connInfo.DstAddr, - } - - if strings.Contains(m.SrcAddr, ".") { - h.TransportProtocol = pp.TCPv4 - } else { - h.TransportProtocol = pp.TCPv6 - } - - if baseCfg.Transport.ProxyProtocolVersion == "v1" { - h.Version = 1 - } else if baseCfg.Transport.ProxyProtocolVersion == "v2" { - h.Version = 2 - } - connInfo.ProxyProtocolHeader = h + // Use the common proxy protocol builder function + header := netpkg.BuildProxyProtocolHeaderStruct(connInfo.SrcAddr, connInfo.DstAddr, baseCfg.Transport.ProxyProtocolVersion) + connInfo.ProxyProtocolHeader = header } connInfo.Conn = remote connInfo.UnderlyingConn = workConn diff --git a/client/proxy/sudp.go b/client/proxy/sudp.go index ad9db89a..13741d0d 100644 --- a/client/proxy/sudp.go +++ b/client/proxy/sudp.go @@ -205,5 +205,5 @@ func (pxy *SUDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) { go workConnReaderFn(workConn, readCh) go heartbeatFn(sendCh) - udp.Forwarder(pxy.localAddr, readCh, sendCh, int(pxy.clientCfg.UDPPacketSize)) + udp.Forwarder(pxy.localAddr, readCh, sendCh, int(pxy.clientCfg.UDPPacketSize), pxy.cfg.Transport.ProxyProtocolVersion) } diff --git a/client/proxy/udp.go b/client/proxy/udp.go index b08fe160..b70ffe4a 100644 --- a/client/proxy/udp.go +++ b/client/proxy/udp.go @@ -171,5 +171,7 @@ func (pxy *UDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) { go workConnSenderFn(pxy.workConn, pxy.sendCh) go workConnReaderFn(pxy.workConn, pxy.readCh) go heartbeatFn(pxy.sendCh) - udp.Forwarder(pxy.localAddr, pxy.readCh, pxy.sendCh, int(pxy.clientCfg.UDPPacketSize)) + + // Call Forwarder with proxy protocol version (empty string means no proxy protocol) + udp.Forwarder(pxy.localAddr, pxy.readCh, pxy.sendCh, int(pxy.clientCfg.UDPPacketSize), pxy.cfg.Transport.ProxyProtocolVersion) } diff --git a/pkg/proto/udp/udp.go b/pkg/proto/udp/udp.go index 7a11984b..f97b3b43 100644 --- a/pkg/proto/udp/udp.go +++ b/pkg/proto/udp/udp.go @@ -24,6 +24,7 @@ import ( "github.com/fatedier/golib/pool" "github.com/fatedier/frp/pkg/msg" + netpkg "github.com/fatedier/frp/pkg/util/net" ) func NewUDPPacket(buf []byte, laddr, raddr *net.UDPAddr) *msg.UDPPacket { @@ -69,7 +70,7 @@ func ForwardUserConn(udpConn *net.UDPConn, readCh <-chan *msg.UDPPacket, sendCh } } -func Forwarder(dstAddr *net.UDPAddr, readCh <-chan *msg.UDPPacket, sendCh chan<- msg.Message, bufSize int) { +func Forwarder(dstAddr *net.UDPAddr, readCh <-chan *msg.UDPPacket, sendCh chan<- msg.Message, bufSize int, proxyProtocolVersion string) { var mu sync.RWMutex udpConnMap := make(map[string]*net.UDPConn) @@ -110,6 +111,7 @@ func Forwarder(dstAddr *net.UDPAddr, readCh <-chan *msg.UDPPacket, sendCh chan<- if err != nil { continue } + mu.Lock() udpConn, ok := udpConnMap[udpMsg.RemoteAddr.String()] if !ok { @@ -122,6 +124,18 @@ func Forwarder(dstAddr *net.UDPAddr, readCh <-chan *msg.UDPPacket, sendCh chan<- } mu.Unlock() + // Add proxy protocol header if configured + if proxyProtocolVersion != "" && udpMsg.RemoteAddr != nil { + ppBuf, err := netpkg.BuildProxyProtocolHeader(udpMsg.RemoteAddr, dstAddr, proxyProtocolVersion) + if err == nil { + // Prepend proxy protocol header to the UDP payload + finalBuf := make([]byte, len(ppBuf)+len(buf)) + copy(finalBuf, ppBuf) + copy(finalBuf[len(ppBuf):], buf) + buf = finalBuf + } + } + _, err = udpConn.Write(buf) if err != nil { udpConn.Close() diff --git a/pkg/util/net/proxyprotocol.go b/pkg/util/net/proxyprotocol.go new file mode 100644 index 00000000..5f0cd51f --- /dev/null +++ b/pkg/util/net/proxyprotocol.go @@ -0,0 +1,45 @@ +// Copyright 2025 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package net + +import ( + "bytes" + "fmt" + "net" + + pp "github.com/pires/go-proxyproto" +) + +func BuildProxyProtocolHeaderStruct(srcAddr, dstAddr net.Addr, version string) *pp.Header { + var versionByte byte + if version == "v1" { + versionByte = 1 + } else { + versionByte = 2 // default to v2 + } + return pp.HeaderProxyFromAddrs(versionByte, srcAddr, dstAddr) +} + +func BuildProxyProtocolHeader(srcAddr, dstAddr net.Addr, version string) ([]byte, error) { + h := BuildProxyProtocolHeaderStruct(srcAddr, dstAddr, version) + + // Convert header to bytes using a buffer + var buf bytes.Buffer + _, err := h.WriteTo(&buf) + if err != nil { + return nil, fmt.Errorf("failed to write proxy protocol header: %v", err) + } + return buf.Bytes(), nil +} diff --git a/pkg/util/net/proxyprotocol_test.go b/pkg/util/net/proxyprotocol_test.go new file mode 100644 index 00000000..187801f6 --- /dev/null +++ b/pkg/util/net/proxyprotocol_test.go @@ -0,0 +1,178 @@ +package net + +import ( + "net" + "testing" + + pp "github.com/pires/go-proxyproto" + "github.com/stretchr/testify/require" +) + +func TestBuildProxyProtocolHeader(t *testing.T) { + require := require.New(t) + + tests := []struct { + name string + srcAddr net.Addr + dstAddr net.Addr + version string + expectError bool + }{ + { + name: "UDP IPv4 v2", + srcAddr: &net.UDPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345}, + dstAddr: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 3306}, + version: "v2", + expectError: false, + }, + { + name: "TCP IPv4 v1", + srcAddr: &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345}, + dstAddr: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 80}, + version: "v1", + expectError: false, + }, + { + name: "UDP IPv6 v2", + srcAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 12345}, + dstAddr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 3306}, + version: "v2", + expectError: false, + }, + { + name: "TCP IPv6 v1", + srcAddr: &net.TCPAddr{IP: net.ParseIP("::1"), Port: 12345}, + dstAddr: &net.TCPAddr{IP: net.ParseIP("2001:db8::1"), Port: 80}, + version: "v1", + expectError: false, + }, + { + name: "nil source address", + srcAddr: nil, + dstAddr: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 3306}, + version: "v2", + expectError: false, + }, + { + name: "nil destination address", + srcAddr: &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345}, + dstAddr: nil, + version: "v2", + expectError: false, + }, + { + name: "unsupported address type", + srcAddr: &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"}, + dstAddr: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 3306}, + version: "v2", + expectError: false, + }, + } + + for _, tt := range tests { + header, err := BuildProxyProtocolHeader(tt.srcAddr, tt.dstAddr, tt.version) + + if tt.expectError { + require.Error(err, "test case: %s", tt.name) + continue + } + + require.NoError(err, "test case: %s", tt.name) + require.NotEmpty(header, "test case: %s", tt.name) + } +} + +func TestBuildProxyProtocolHeaderStruct(t *testing.T) { + require := require.New(t) + + tests := []struct { + name string + srcAddr net.Addr + dstAddr net.Addr + version string + expectedProtocol pp.AddressFamilyAndProtocol + expectedVersion byte + expectedCommand pp.ProtocolVersionAndCommand + expectedSourceAddr net.Addr + expectedDestAddr net.Addr + }{ + { + name: "TCP IPv4 v2", + srcAddr: &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345}, + dstAddr: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 80}, + version: "v2", + expectedProtocol: pp.TCPv4, + expectedVersion: 2, + expectedCommand: pp.PROXY, + expectedSourceAddr: &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345}, + expectedDestAddr: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 80}, + }, + { + name: "UDP IPv6 v1", + srcAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 12345}, + dstAddr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 3306}, + version: "v1", + expectedProtocol: pp.UDPv6, + expectedVersion: 1, + expectedCommand: pp.PROXY, + expectedSourceAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 12345}, + expectedDestAddr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 3306}, + }, + { + name: "TCP IPv6 default version", + srcAddr: &net.TCPAddr{IP: net.ParseIP("::1"), Port: 12345}, + dstAddr: &net.TCPAddr{IP: net.ParseIP("2001:db8::1"), Port: 80}, + version: "", + expectedProtocol: pp.TCPv6, + expectedVersion: 2, // default to v2 + expectedCommand: pp.PROXY, + expectedSourceAddr: &net.TCPAddr{IP: net.ParseIP("::1"), Port: 12345}, + expectedDestAddr: &net.TCPAddr{IP: net.ParseIP("2001:db8::1"), Port: 80}, + }, + { + name: "nil source address", + srcAddr: nil, + dstAddr: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 3306}, + version: "v2", + expectedProtocol: pp.UNSPEC, + expectedVersion: 2, + expectedCommand: pp.LOCAL, + expectedSourceAddr: nil, // go-proxyproto sets both to nil when srcAddr is nil + expectedDestAddr: nil, + }, + { + name: "nil destination address", + srcAddr: &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345}, + dstAddr: nil, + version: "v2", + expectedProtocol: pp.UNSPEC, + expectedVersion: 2, + expectedCommand: pp.LOCAL, + expectedSourceAddr: nil, // go-proxyproto sets both to nil when dstAddr is nil + expectedDestAddr: nil, + }, + { + name: "unsupported address type", + srcAddr: &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"}, + dstAddr: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 3306}, + version: "v2", + expectedProtocol: pp.UNSPEC, + expectedVersion: 2, + expectedCommand: pp.LOCAL, + expectedSourceAddr: nil, // go-proxyproto sets both to nil for unsupported types + expectedDestAddr: nil, + }, + } + + for _, tt := range tests { + header := BuildProxyProtocolHeaderStruct(tt.srcAddr, tt.dstAddr, tt.version) + + require.NotNil(header, "test case: %s", tt.name) + + require.Equal(tt.expectedCommand, header.Command, "test case: %s", tt.name) + require.Equal(tt.expectedSourceAddr, header.SourceAddr, "test case: %s", tt.name) + require.Equal(tt.expectedDestAddr, header.DestinationAddr, "test case: %s", tt.name) + require.Equal(tt.expectedProtocol, header.TransportProtocol, "test case: %s", tt.name) + require.Equal(tt.expectedVersion, header.Version, "test case: %s", tt.name) + } +} diff --git a/test/e2e/v1/features/real_ip.go b/test/e2e/v1/features/real_ip.go index 216f531d..a52cf0a2 100644 --- a/test/e2e/v1/features/real_ip.go +++ b/test/e2e/v1/features/real_ip.go @@ -227,6 +227,56 @@ var _ = ginkgo.Describe("[Feature: Real IP]", func() { }) }) + ginkgo.It("UDP", func() { + serverConf := consts.DefaultServerConfig + clientConf := consts.DefaultClientConfig + + localPort := f.AllocPort() + localServer := streamserver.New(streamserver.UDP, streamserver.WithBindPort(localPort), + streamserver.WithCustomHandler(func(c net.Conn) { + defer c.Close() + rd := bufio.NewReader(c) + ppHeader, err := pp.Read(rd) + if err != nil { + log.Errorf("read proxy protocol error: %v", err) + return + } + + // Read the actual UDP content after proxy protocol header + if _, err := rpc.ReadBytes(rd); err != nil { + return + } + + buf := []byte(ppHeader.SourceAddr.String()) + _, _ = rpc.WriteBytes(c, buf) + })) + f.RunServer("", localServer) + + remotePort := f.AllocPort() + clientConf += fmt.Sprintf(` + [[proxies]] + name = "udp" + type = "udp" + localPort = %d + remotePort = %d + transport.proxyProtocolVersion = "v2" + `, localPort, remotePort) + + f.RunProcesses([]string{serverConf}, []string{clientConf}) + + framework.NewRequestExpect(f).Protocol("udp").Port(remotePort).Ensure(func(resp *request.Response) bool { + log.Tracef("udp proxy protocol get SourceAddr: %s", string(resp.Content)) + addr, err := net.ResolveUDPAddr("udp", string(resp.Content)) + if err != nil { + return false + } + if addr.IP.String() != "127.0.0.1" { + return false + } + return true + }) + }) + ginkgo.It("HTTP", func() { vhostHTTPPort := f.AllocPort() serverConf := consts.DefaultServerConfig + fmt.Sprintf(`