package conn

import (
	"bufio"
	"fmt"
	"io"
	"net"
	"sync"

	"github.com/fatedier/frp/pkg/utils/log"
)

// Listener exposes accepted TCP connections through a channel.
//
// Addr is the address the underlying TCP listener is bound to, and Conns is
// the channel on which accepted connections are delivered. Listen creates
// Conns unbuffered, so its accept loop waits for a receiver before accepting
// the next connection.
type Listener struct {
	Addr  net.Addr
	Conns chan *Conn
}

// GetConn blocks until a connection is available on l.Conns and returns it.
// If the channel has been closed, the receive yields nil.
func (l *Listener) GetConn() (conn *Conn) {
	return <-l.Conns
}

// Conn pairs a TCP connection with a buffered reader over that connection.
//
// TcpConn is the underlying connection; it stays nil until ConnectServer
// succeeds or Listen accepts a connection. Reader is a bufio.Reader that both
// of those set up to wrap TcpConn; ReadLine reads through it, while Write
// sends on TcpConn directly.
type Conn struct {
	TcpConn *net.TCPConn
	Reader  *bufio.Reader
}

// ConnectServer opens a TCP connection to host:port and stores it in c.
//
// The address is resolved as IPv4 ("tcp4"). On success c.TcpConn holds the new
// connection and c.Reader is a fresh buffered reader wrapping it. If resolving
// or dialing fails, the error is returned and c is left unchanged.
func (c *Conn) ConnectServer(host string, port int64) (err error) {
	target := fmt.Sprintf("%s:%d", host, port)

	var raddr *net.TCPAddr
	if raddr, err = net.ResolveTCPAddr("tcp4", target); err != nil {
		return err
	}

	var tcp *net.TCPConn
	if tcp, err = net.DialTCP("tcp", nil, raddr); err != nil {
		return err
	}

	c.TcpConn, c.Reader = tcp, bufio.NewReader(tcp)
	return nil
}

// GetRemoteAddr returns the peer address of the underlying TCP connection in
// string form. c.TcpConn must already be connected.
func (c *Conn) GetRemoteAddr() (addr string) {
	remote := c.TcpConn.RemoteAddr()
	addr = remote.String()
	return addr
}

// GetLocalAddr returns the local address of the underlying TCP connection in
// string form. c.TcpConn must already be connected.
func (c *Conn) GetLocalAddr() (addr string) {
	local := c.TcpConn.LocalAddr()
	addr = local.String()
	return addr
}

// ReadLine reads from the buffered reader up to and including the next '\n'
// and returns that data as a string; the delimiter is kept. When the read
// fails before a newline is seen (for example at EOF), the data read so far is
// returned together with the error.
func (c *Conn) ReadLine() (buff string, err error) {
	return c.Reader.ReadString('\n')
}

// Write sends content over the underlying TCP connection, bypassing any
// buffering, and returns the error reported by the connection, if any.
func (c *Conn) Write(content string) (err error) {
	payload := []byte(content)
	if _, err = c.TcpConn.Write(payload); err != nil {
		return err
	}
	return nil
}

// Close closes the underlying TCP connection when one is set. It does nothing
// on a Conn that was never connected, and the error from closing is discarded.
func (c *Conn) Close() {
	if c.TcpConn == nil {
		return
	}
	c.TcpConn.Close()
}

// Listen starts a TCP listener on bindAddr:bindPort and returns a Listener
// whose Conns channel delivers every accepted connection.
//
// The address is resolved as IPv4 ("tcp4"). If it cannot be resolved or the
// listener cannot be created, Listen returns a nil Listener and the error.
//
// Accepting happens in a background goroutine that is never stopped: the
// underlying net listener is not exposed, so it lives for the rest of the
// process. Conns is unbuffered, so the goroutine waits for a receiver (see
// GetConn) before accepting the next connection.
func Listen(bindAddr string, bindPort int64) (l *Listener, err error) {
	tcpAddr, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", bindAddr, bindPort))
	if err != nil {
		// Without this check a nil tcpAddr would reach ListenTCP, which
		// treats it as "any address, any port" and binds successfully
		// instead of reporting the bad configuration.
		return nil, err
	}

	listener, err := net.ListenTCP("tcp", tcpAddr)
	if err != nil {
		return nil, err
	}

	l = &Listener{
		Addr:  listener.Addr(),
		Conns: make(chan *Conn),
	}

	go func() {
		for {
			conn, err := listener.AcceptTCP()
			if err != nil {
				// NOTE(review): a persistent accept error is retried
				// immediately, with no delay between attempts.
				log.Error("Accept new tcp connection error, %v", err)
				continue
			}

			l.Conns <- &Conn{
				TcpConn: conn,
				Reader:  bufio.NewReader(conn),
			}
		}
	}()
	return l, nil
}

// Join relays data between c1 and c2 in both directions. It blocks until both
// directions have finished and both connections have been closed.
//
// As soon as one direction ends, whether by EOF or error, both connections are
// closed; that unblocks the copy running in the opposite direction. Copy
// errors are logged and not returned.
func Join(c1 *Conn, c2 *Conn) {
	var wait sync.WaitGroup
	pipe := func(to *Conn, from *Conn) {
		// Deferred calls run last-in-first-out, so registering Done first
		// makes it run after both Close calls; Join therefore returns only
		// once the connections are really closed.
		defer wait.Done()
		defer to.Close()
		defer from.Close()

		// Read through the buffered reader when there is one: bytes that an
		// earlier ReadLine pulled into its buffer would otherwise be lost.
		// Once the buffer is drained the copy continues from the connection.
		var src io.Reader = from.TcpConn
		if from.Reader != nil {
			src = from.Reader
		}

		if _, err := io.Copy(to.TcpConn, src); err != nil {
			log.Warn("join conns error, %v", err)
		}
	}

	wait.Add(2)
	go pipe(c1, c2)
	go pipe(c2, c1)
	wait.Wait()
}