Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Implement pluggable transport interface #340

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
19 changes: 19 additions & 0 deletions pkg/core/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# A pluggable transport implementation based on Hysteria

## Hysteria
[Hysteria](https://github.com/HyNetwork/hysteria) uses a custom version of QUIC protocol ([RFC 9000 - QUIC: A UDP-Based Multiplexed and Secure Transport](https://www.rfc-editor.org/rfc/rfc9000.html)):

* a custom congestion control ([RFC 9002 - QUIC Loss Detection and Congestion Control](https://www.rfc-editor.org/rfc/rfc9002.html))
* tweaked QUIC parameters
* an obfuscation layer
* non-standard transports (e.g. [faketcp](https://github.com/wangyu-/udp2raw))

## Usage

* Follow [Custom CA](https://hysteria.network/docs/custom-ca/) doc to generate certificates
* See [server side implementation example](https://github.com/apernet/hysteria/pull/340/files#diff-8a9b6ccee2487fc2b424d9f4b3cad2ebde2cc27b1cf1aa078e0de084872edbaaR62-R155) in the `transport_test.go` file
* See [client side implementation example](https://github.com/apernet/hysteria/pull/340/files#diff-8a9b6ccee2487fc2b424d9f4b3cad2ebde2cc27b1cf1aa078e0de084872edbaaR157-R229) in the `transport_test.go` file

## Implementation

The implementation uses [Pluggable Transport Specification v3.0 - Go Transport API](https://github.com/Pluggable-Transports/Pluggable-Transports-spec/blob/main/releases/PTSpecV3.0/Pluggable%20Transport%20Specification%20v3.0%20-%20Go%20Transport%20API%20v3.0.md)
25 changes: 20 additions & 5 deletions pkg/core/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,19 @@ import (
"crypto/tls"
"errors"
"fmt"
"math/rand"
"net"
"strconv"
"sync"
"time"

"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/lunixbochs/struc"
"github.com/tobyxdd/hysteria/pkg/obfs"
"github.com/tobyxdd/hysteria/pkg/pmtud_fix"
"github.com/tobyxdd/hysteria/pkg/transport"
"github.com/tobyxdd/hysteria/pkg/utils"
"math/rand"
"net"
"strconv"
"sync"
"time"
)

var (
Expand Down Expand Up @@ -183,6 +184,20 @@ func (c *Client) openStreamWithReconnect() (quic.Connection, quic.Stream, error)
return c.quicSession, &wrappedQUICStream{stream}, err
}

// Implement Pluggable Transport Client interface
func (c *Client) Dial() (net.Conn, error) {
session, stream, err := c.openStreamWithReconnect()
if err != nil {
return nil, err
}

return &quicConn{
Orig: stream,
PseudoLocalAddr: session.LocalAddr(),
PseudoRemoteAddr: session.RemoteAddr(),
}, nil
}

func (c *Client) DialTCP(addr string) (net.Conn, error) {
host, port, err := utils.SplitHostPort(addr)
if err != nil {
Expand Down
187 changes: 186 additions & 1 deletion pkg/core/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@ import (
"crypto/tls"
"errors"
"fmt"
"net"

"github.com/lucas-clemente/quic-go"
"github.com/lunixbochs/struc"
"github.com/prometheus/client_golang/prometheus"
"github.com/tobyxdd/hysteria/pkg/acl"
"github.com/tobyxdd/hysteria/pkg/obfs"
"github.com/tobyxdd/hysteria/pkg/pmtud_fix"
"github.com/tobyxdd/hysteria/pkg/transport"
"net"
)

type ConnectFunc func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string)
Expand Down Expand Up @@ -42,6 +43,40 @@ type Server struct {
listener quic.Listener
}

type HysteriaTransport struct {
addr string
protocol string
tlsConfig *tls.Config
quicConfig *quic.Config
transport *transport.ServerTransport
sendBPS uint64
recvBPS uint64
congestionFactory CongestionFactory
disableUDP bool
obfuscator obfs.Obfuscator
connectFunc ConnectFunc
disconnectFunc DisconnectFunc
}

type TransportServer struct {
transport *transport.ServerTransport
sendBPS, recvBPS uint64
congestionFactory CongestionFactory
disableUDP bool
aclEngine *acl.Engine

connectFunc ConnectFunc
disconnectFunc DisconnectFunc
tcpRequestFunc TCPRequestFunc
tcpErrorFunc TCPErrorFunc
udpRequestFunc UDPRequestFunc
udpErrorFunc UDPErrorFunc

listener quic.Listener
allStreams chan *quicConn
isListening bool
}

func NewServer(addr string, protocol string, tlsConfig *tls.Config, quicConfig *quic.Config, transport *transport.ServerTransport,
sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, disableUDP bool, aclEngine *acl.Engine,
obfuscator obfs.Obfuscator, connectFunc ConnectFunc, disconnectFunc DisconnectFunc,
Expand Down Expand Up @@ -92,6 +127,8 @@ func (s *Server) Serve() error {
}
}

// Close closes the listener.
// Any blocked Accept operations will be unblocked and return errors.
func (s *Server) Close() error {
return s.listener.Close()
}
Expand Down Expand Up @@ -173,3 +210,151 @@ func (s *Server) handleControlStream(cs quic.Connection, stream quic.Stream) ([]
}
return ch.Auth, ok, vb[0] == protocolVersionV2, nil
}

// Implement Pluggable Transport Server interface
func (t *HysteriaTransport) Listen() (net.Listener, error) {
listener, err := t.transport.QUICListen(t.protocol, t.addr, t.tlsConfig, t.quicConfig, t.obfuscator)
if err != nil {
return nil, err
}
s := &TransportServer{
listener: listener,
transport: t.transport,
sendBPS: t.sendBPS,
recvBPS: t.recvBPS,
congestionFactory: t.congestionFactory,
disableUDP: t.disableUDP,
connectFunc: t.connectFunc,
disconnectFunc: t.disconnectFunc,
allStreams: make(chan *quicConn),
isListening: false,
}

return s, nil
}

// Addr returns the listener's network address.
func (s *TransportServer) Addr() net.Addr {
return s.listener.Addr()
}

func (s *TransportServer) Close() error {
s.isListening = false
return s.listener.Close()
}

func (s *TransportServer) Accept() (net.Conn, error) {
if !s.isListening {
s.isListening = true
go acceptConn(s)
}
// Return the next stream
select {
case stream := <-s.allStreams:
return stream, nil
}
}

// An internal goroutine for accepting connections. Then for each accepted
// connection, start a goroutine for handling the control stream & accepting
// streams. Put those streams into a channel
func acceptConn(s *TransportServer) {
for {
cs, err := s.listener.Accept(context.Background())
if err != nil {
_ = cs.CloseWithError(closeErrorCodeProtocol, "protocol error")
return
}
go acceptStream(cs, s)
}
}

func acceptStream(cs quic.Connection, s *TransportServer) {
// Expect the client to create a control stream to send its own information
ctx, ctxCancel := context.WithTimeout(context.Background(), protocolTimeout)
stream, err := cs.AcceptStream(ctx)
ctxCancel()
if err != nil {
_ = cs.CloseWithError(closeErrorCodeProtocol, "protocol error")
return
}
// Handle the control stream
_, ok, _, err := s.handleControlStream(cs, stream)
if err != nil {
_ = cs.CloseWithError(closeErrorCodeProtocol, "protocol error")
return
}
if !ok {
_ = cs.CloseWithError(closeErrorCodeAuth, "auth error")
return
}
// Close the control stream
stream.Close()

for {
// Accept the next stream
stream, err = cs.AcceptStream(context.Background())
if err != nil {
_ = cs.CloseWithError(closeErrorCodeProtocol, "protocol error")
return
}

conn := &quicConn{
Orig: stream,
PseudoLocalAddr: cs.LocalAddr(),
PseudoRemoteAddr: cs.RemoteAddr(),
}
s.allStreams <- conn
}
}

// Auth & negotiate speed
// Copy from (s *Server) handleControlStream, TODO: refactor
func (s *TransportServer) handleControlStream(cs quic.Connection, stream quic.Stream) ([]byte, bool, bool, error) {
// Check version
vb := make([]byte, 1)
_, err := stream.Read(vb)
if err != nil {
return nil, false, false, err
}
if vb[0] != protocolVersion && vb[0] != protocolVersionV2 {
return nil, false, false, fmt.Errorf("unsupported protocol version %d, expecting %d/%d",
vb[0], protocolVersionV2, protocolVersion)
}
// Parse client hello
var ch clientHello
err = struc.Unpack(stream, &ch)
if err != nil {
return nil, false, false, err
}
// Speed
if ch.Rate.SendBPS == 0 || ch.Rate.RecvBPS == 0 {
return nil, false, false, errors.New("invalid rate from client")
}
serverSendBPS, serverRecvBPS := ch.Rate.RecvBPS, ch.Rate.SendBPS
if s.sendBPS > 0 && serverSendBPS > s.sendBPS {
serverSendBPS = s.sendBPS
}
if s.recvBPS > 0 && serverRecvBPS > s.recvBPS {
serverRecvBPS = s.recvBPS
}
// Auth
ok, msg := s.connectFunc(cs.RemoteAddr(), ch.Auth, serverSendBPS, serverRecvBPS)
// Response
err = struc.Pack(stream, &serverHello{
OK: ok,
Rate: transmissionRate{
SendBPS: serverSendBPS,
RecvBPS: serverRecvBPS,
},
Message: msg,
})
if err != nil {
return nil, false, false, err
}
// Set the congestion accordingly
if ok && s.congestionFactory != nil {
cs.SetCongestionControl(s.congestionFactory(serverSendBPS))
}
return ch.Auth, ok, vb[0] == protocolVersionV2, nil
}
Loading