package proxyproto

import (
	"bufio"
	"bytes"
	"encoding/binary"
	"io"
)

// Minimum sizes, in bytes, of the version 2 address block for each address
// family, together with their pre-computed big-endian encodings as they
// appear in the header's 2-byte length field.
var (
	// lengthV4 covers two 4-byte addresses and two 2-byte ports.
	lengthV4 = uint16(12)
	// lengthV6 covers two 16-byte addresses and two 2-byte ports.
	lengthV6 = uint16(36)
	// lengthUnix covers two 108-byte socket paths. The PROXY protocol
	// specification defines no ports for AF_UNIX, so the block is 216 bytes.
	lengthUnix = uint16(216)

	lengthV4Bytes = func() []byte {
		a := make([]byte, 2)
		binary.BigEndian.PutUint16(a, lengthV4)
		return a
	}()
	lengthV6Bytes = func() []byte {
		a := make([]byte, 2)
		binary.BigEndian.PutUint16(a, lengthV6)
		return a
	}()
	lengthUnixBytes = func() []byte {
		a := make([]byte, 2)
		binary.BigEndian.PutUint16(a, lengthUnix)
		return a
	}()
)

// _ports is the source/destination port pair of an address block. It is
// embedded in _addr6 so that both ports are decoded, in this order, right
// after the addresses.
type _ports struct {
	SrcPort uint16
	DstPort uint16
}

// _addr4 is the IPv4 address block that parseVersion2 decodes with
// binary.Read using big-endian byte order. The field order is the decoding
// order: source address, destination address, source port, destination
// port (12 bytes in total).
type _addr4 struct {
	Src     [4]byte
	Dst     [4]byte
	SrcPort uint16
	DstPort uint16
}

// _addr6 is the IPv6 address block that parseVersion2 decodes with
// binary.Read using big-endian byte order: two 16-byte addresses followed
// by the embedded port pair (36 bytes in total).
type _addr6 struct {
	Src [16]byte
	Dst [16]byte
	_ports
}

// _addrUnix holds a Unix socket address block: two fixed 108-byte fields
// for the source and destination paths (216 bytes in total, no ports).
// In the code visible here it is referenced only from a commented-out
// branch of parseVersion2.
type _addrUnix struct {
	Src [108]byte
	Dst [108]byte
}

// parseVersion2 decodes a PROXY protocol version 2 header from reader.
//
// The first 12 bytes (the signature) are skipped without being inspected.
// The function then reads the version/command byte, the address-family/
// transport byte and the 2-byte big-endian payload length, and decodes the
// IPv4 or IPv6 addresses and ports. Any payload bytes beyond the address
// block (and the whole payload for Unix families, which are not decoded
// yet) are consumed and dropped so the reader is left positioned right
// after the header.
//
// On failure the returned header is nil and err is one of the package's
// sentinel errors: ErrCantReadProtocolVersionAndCommand,
// ErrUnsupportedProtocolVersionAndCommand,
// ErrCantReadAddressFamilyAndProtocol,
// ErrUnsupportedAddressFamilyAndProtocol, ErrCantReadLength,
// ErrInvalidLength or ErrInvalidAddress.
func parseVersion2(reader *bufio.Reader) (header *Header, err error) {
	// Skip the 12-byte signature; its content is not verified here.
	if _, err = reader.Discard(12); err != nil {
		return nil, ErrCantReadProtocolVersionAndCommand
	}

	header = new(Header)
	header.Version = 2

	// Read the 13th byte, protocol version and command
	b13, err := reader.ReadByte()
	if err != nil {
		return nil, ErrCantReadProtocolVersionAndCommand
	}
	header.Command = ProtocolVersionAndCommand(b13)
	if _, ok := supportedCommand[header.Command]; !ok {
		return nil, ErrUnsupportedProtocolVersionAndCommand
	}
	// If command is LOCAL, header ends here.
	// NOTE(review): this mirrors writeVersion2, which emits nothing after
	// the command byte for LOCAL. The PROXY protocol specification appears
	// to require the family byte and length field for LOCAL as well, so
	// headers from other senders may leave bytes unread here. Confirm, and
	// change the reader and the writer together.
	if header.Command.IsLocal() {
		return header, nil
	}

	// Read the 14th byte, address family and protocol
	b14, err := reader.ReadByte()
	if err != nil {
		return nil, ErrCantReadAddressFamilyAndProtocol
	}
	header.TransportProtocol = AddressFamilyAndProtocol(b14)
	if _, ok := supportedTransportProtocol[header.TransportProtocol]; !ok {
		return nil, ErrUnsupportedAddressFamilyAndProtocol
	}

	// Read the payload length and check it against the family's minimum.
	var length uint16
	if err := binary.Read(io.LimitReader(reader, 2), binary.BigEndian, &length); err != nil {
		return nil, ErrCantReadLength
	}
	if !header.validateLength(length) {
		return nil, ErrInvalidLength
	}

	// Make sure the whole payload is available before decoding it. Peek
	// also fails when length exceeds the reader's buffer size.
	if _, err := reader.Peek(int(length)); err != nil {
		return nil, ErrInvalidLength
	}

	// Length-limited reader for the payload section. Its N field tracks how
	// many payload bytes are still unread once the addresses are decoded.
	payloadReader := &io.LimitedReader{R: reader, N: int64(length)}

	// Read addresses and ports
	switch {
	case header.TransportProtocol.IsIPv4():
		var addr _addr4
		if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
			return nil, ErrInvalidAddress
		}
		header.SourceAddress = addr.Src[:]
		header.DestinationAddress = addr.Dst[:]
		header.SourcePort = addr.SrcPort
		header.DestinationPort = addr.DstPort
	case header.TransportProtocol.IsIPv6():
		var addr _addr6
		if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
			return nil, ErrInvalidAddress
		}
		header.SourceAddress = addr.Src[:]
		header.DestinationAddress = addr.Dst[:]
		header.SourcePort = addr.SrcPort
		header.DestinationPort = addr.DstPort
	}
	// TODO fully support Unix addresses
	//	else if header.TransportProtocol.IsUnix() {
	//		var addr _addrUnix
	//		if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
	//			return nil, ErrInvalidAddress
	//		}
	//
	//		if header.SourceAddress, err = net.ResolveUnixAddr("unix", string(addr.Src[:])); err != nil {
	//			return nil, ErrCantResolveSourceUnixAddress
	//		}
	//		if header.DestinationAddress, err = net.ResolveUnixAddr("unix", string(addr.Dst[:])); err != nil {
	//			return nil, ErrCantResolveDestinationUnixAddress
	//		}
	//	}

	// TODO add encapsulated TLV support

	// Drop whatever is left of the payload (padding, TLVs, undecoded Unix
	// addresses) without allocating a scratch buffer. The bytes were
	// already buffered by the Peek above, so this is not expected to fail.
	if _, err := reader.Discard(int(payloadReader.N)); err != nil {
		return nil, ErrInvalidLength
	}

	return header, nil
}

// writeVersion2 serializes header as a PROXY protocol version 2 header and
// writes it to w.
//
// The header is assembled in memory first and handed to w in one WriteTo
// call, so nothing is written when the header cannot be built. It returns
// the number of bytes written to w and any write error.
//
// It returns ErrInvalidAddress when the source or destination address
// cannot be expressed in the family selected by TransportProtocol, and
// ErrUnsupportedAddressFamilyAndProtocol when TransportProtocol is not an
// IPv4, IPv6 or Unix family. Both situations would otherwise produce a
// header whose address block does not match its length field.
func (header *Header) writeVersion2(w io.Writer) (int64, error) {
	var buf bytes.Buffer
	buf.Write(SIGV2)
	buf.WriteByte(header.Command.toByte())

	// NOTE(review): for LOCAL nothing follows the command byte, which is
	// what parseVersion2 expects. The PROXY protocol specification appears
	// to require the family byte and length field for LOCAL too. Confirm,
	// and change the writer and the reader together.
	if header.Command.IsLocal() {
		return buf.WriteTo(w)
	}

	buf.WriteByte(header.TransportProtocol.toByte())

	// TODO add encapsulated TLV length
	var addrSrc, addrDst []byte
	switch {
	case header.TransportProtocol.IsIPv4():
		addrSrc = header.SourceAddress.To4()
		addrDst = header.DestinationAddress.To4()
		// An empty result means the address is not representable as IPv4.
		if len(addrSrc) == 0 || len(addrDst) == 0 {
			return 0, ErrInvalidAddress
		}
		buf.Write(lengthV4Bytes)
	case header.TransportProtocol.IsIPv6():
		addrSrc = header.SourceAddress.To16()
		addrDst = header.DestinationAddress.To16()
		// An empty result means the address is not a valid IP address.
		if len(addrSrc) == 0 || len(addrDst) == 0 {
			return 0, ErrInvalidAddress
		}
		buf.Write(lengthV6Bytes)
	case header.TransportProtocol.IsUnix():
		buf.Write(lengthUnixBytes)
		// TODO is below right?
		// NOTE(review): the textual addresses are written unpadded and are
		// followed by ports, so the bytes emitted here do not generally
		// add up to the advertised length. Output is left unchanged until
		// Unix addresses are fully supported.
		addrSrc = []byte(header.SourceAddress.String())
		addrDst = []byte(header.DestinationAddress.String())
	default:
		return 0, ErrUnsupportedAddressFamilyAndProtocol
	}
	buf.Write(addrSrc)
	buf.Write(addrDst)

	// Ports follow the addresses, each as a big-endian uint16.
	var ports [4]byte
	binary.BigEndian.PutUint16(ports[0:2], header.SourcePort)
	binary.BigEndian.PutUint16(ports[2:4], header.DestinationPort)
	buf.Write(ports[:])

	return buf.WriteTo(w)
}

// validateLength reports whether length, the payload size announced in the
// header, is large enough to hold the address block of the header's
// transport protocol. Families other than IPv4, IPv6 and Unix are never
// considered valid.
func (header *Header) validateLength(length uint16) bool {
	var minimum uint16
	switch {
	case header.TransportProtocol.IsIPv4():
		minimum = lengthV4
	case header.TransportProtocol.IsIPv6():
		minimum = lengthV6
	case header.TransportProtocol.IsUnix():
		minimum = lengthUnix
	default:
		return false
	}
	return length >= minimum
}