Skip to content

Commit

Permalink
Support udp c/s
Browse files Browse the repository at this point in the history
  • Loading branch information
wzshiming committed Jan 21, 2021
1 parent 6a70ce3 commit 19f1315
Show file tree
Hide file tree
Showing 13 changed files with 756 additions and 18 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ The following is the implementation of other proxy protocols
## Features

- [x] Support TCP proxy
- [ ] Support UDP proxy
- [x] Support UDP proxy

## Encrypto method

Expand Down
46 changes: 42 additions & 4 deletions aead/cipher.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"golang.org/x/crypto/hkdf"
)

var _zerononce [128]byte // read-only. 128 bytes is more than enough.

func RegisterCipher(method string, keyLen int, cipher func(key []byte) (cipher.AEAD, error)) {
shadowsocks.RegisterCipher(method, func(password string) (shadowsocks.ConnCipher, error) {
return &Cipher{Rand: rand.Reader, Key: shadowsocks.KDF(password, keyLen), NewAEAD: cipher}, nil
Expand Down Expand Up @@ -46,13 +48,13 @@ func (c *Cipher) SaltSize() int {

var sssubkey = []byte("ss-subkey")

func (c *Cipher) Encrypt(salt []byte) (cipher.AEAD, error) {
func (c *Cipher) newEncrypt(salt []byte) (cipher.AEAD, error) {
subkey := make([]byte, c.KeySize())
hkdfSHA1(c.Key, salt, sssubkey, subkey)
return c.NewAEAD(subkey)
}

func (c *Cipher) Decrypt(salt []byte) (cipher.AEAD, error) {
func (c *Cipher) newDecrypt(salt []byte) (cipher.AEAD, error) {
subkey := make([]byte, c.KeySize())
hkdfSHA1(c.Key, salt, sssubkey, subkey)
return c.NewAEAD(subkey)
Expand All @@ -64,7 +66,7 @@ func (c *Cipher) initReader(r io.Reader) (*cipherReader, error) {
if err != nil {
return nil, err
}
aead, err := c.Decrypt(salt)
aead, err := c.newDecrypt(salt)
if err != nil {
return nil, err
}
Expand All @@ -77,7 +79,7 @@ func (c *Cipher) initWriter(w io.Writer) (*cipherWriter, error) {
if err != nil {
return nil, err
}
aead, err := c.Encrypt(salt)
aead, err := c.newEncrypt(salt)
if err != nil {
return nil, err
}
Expand All @@ -88,6 +90,42 @@ func (c *Cipher) initWriter(w io.Writer) (*cipherWriter, error) {
return newCipherWriter(w, aead), nil
}

func (c *Cipher) Encrypt(dest, src []byte) (int, error) {
saltSize := c.SaltSize()
salt := dest[:saltSize]
_, err := io.ReadFull(c.Rand, salt)
if err != nil {
return 0, err
}
aead, err := c.newEncrypt(salt)
if err != nil {
return 0, err
}
if len(dest) < saltSize+len(src)+aead.Overhead() {
return 0, io.ErrShortBuffer
}
b := aead.Seal(dest[saltSize:saltSize], _zerononce[:aead.NonceSize()], src, nil)
return saltSize + len(b), nil
}

func (c *Cipher) Decrypt(dest, src []byte) (int, error) {
saltSize := c.SaltSize()
if len(src) < saltSize {
return 0, io.ErrShortBuffer
}
salt := src[:saltSize]
aead, err := c.newDecrypt(salt)
if err != nil {
return 0, err
}
head := len(src) - (saltSize + aead.Overhead())
if head < 0 || head >= len(dest) {
return 0, io.ErrShortBuffer
}
b, err := aead.Open(dest[:0], _zerononce[:aead.NonceSize()], src[saltSize:], nil)
return len(b), err
}

// payloadSizeMask is the maximum size of payload in bytes.
const payloadSizeMask = 0x3FFF // 16*1024 - 1

Expand Down
86 changes: 86 additions & 0 deletions all_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package shadowsocks_test

import (
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
Expand All @@ -12,6 +13,7 @@ import (
)

var list = []string{
"dummy",
"aes-128-cfb",
"aes-128-ctr",
"aes-128-gcm",
Expand Down Expand Up @@ -78,3 +80,87 @@ func TestAll(t *testing.T) {
})
}
}

func TestEncryptor(t *testing.T) {
var tmp1 [255]byte
var tmp2 [255]byte
for _, c := range list {
t.Run(c, func(t *testing.T) {
cipher, err := shadowsocks.NewCipher(c, "pwd")
if err != nil {
t.Fatal(err)
}

n1, err := cipher.Encrypt(tmp1[:], []byte(c))
if err != nil {
t.Fatal(err)
}

n2, err := cipher.Decrypt(tmp2[:], tmp1[:n1])
if err != nil {
t.Fatal(err)
}
if string(tmp2[:n2]) != c {
t.Errorf("%q %q %q", c, tmp1[:n1], tmp2[:n2])
}
})
}
}

func TestPacket(t *testing.T) {
// echo server
p, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
go func() {
var buf [1024 * 32]byte
for {
i, addr, err := p.ReadFrom(buf[:])
if err != nil {
t.Fatal(err)
}
tmp := append([]byte("echo "), buf[:i]...)
_, err = p.WriteTo(tmp, addr)
if err != nil {
t.Fatal(err)
}
}
}()

remote, err := shadowsocks.NewSimplePacketServer("ss://aes-128-cfb:[email protected]:0")
if err != nil {
t.Fatal(err)
}

err = remote.Start(context.Background())
if err != nil {
t.Fatal(err)
}

t.Log(remote.ProxyURL())
local, err := shadowsocks.NewPacketClient(remote.ProxyURL())
if err != nil {
t.Fatal(err)
}
client, err := local.ListenPacket(context.Background(), "udp", ":0")
if err != nil {
t.Fatal(err)
}

for i := 0; i != 10; i++ {
tmp := fmt.Sprintf("hello %d", i)
_, err = client.WriteTo([]byte(tmp), p.LocalAddr())
if err != nil {
t.Fatal(err)
}
var buf [1024 * 32]byte
i, addr, err := client.ReadFrom(buf[:])
if err != nil {
t.Fatal(err)
}
if "echo "+tmp != string(buf[:i]) {
t.Error("resp", i, string(buf[:i]), addr)
}
}
}
2 changes: 2 additions & 0 deletions cipher.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ var (

type ConnCipher interface {
StreamConn(net.Conn) net.Conn
Decrypt(dist, src []byte) (n int, err error)
Encrypt(dist, src []byte) (n int, err error)
}

var registerCipher = map[string]func(password string) (ConnCipher, error){}
Expand Down
35 changes: 26 additions & 9 deletions cmd/shadowsocks/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,31 @@ func init() {

func main() {
logger := log.New(os.Stderr, "[shadowsocks] ", log.LstdFlags)
svc := &shadowsocks.Server{
Logger: logger,
Cipher: cipher,
Password: password,
}
go func() {
svc := &shadowsocks.Server{
Logger: logger,
Cipher: cipher,
Password: password,
}

err := svc.ListenAndServe("tcp", address)
if err != nil {
logger.Println(err)
}
err := svc.ListenAndServe("tcp", address)
if err != nil {
logger.Println(err)
}
os.Exit(1 )
}()
go func() {
svc := &shadowsocks.PacketServer{
Logger: logger,
Cipher: cipher,
Password: password,
}

err := svc.ListenAndServe("udp", address)
if err != nil {
logger.Println(err)
}
os.Exit(1 )
}()
<-make(chan struct{})
}
13 changes: 13 additions & 0 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,16 @@ type BytesPool interface {
Get() []byte
Put([]byte)
}

func getBytes(p BytesPool) []byte {
if p != nil {
return p.Get()
}
return make([]byte, 32*1024)
}

func putBytes(p BytesPool, d []byte) {
if p != nil {
p.Put(d)
}
}
27 changes: 27 additions & 0 deletions dummy/cipher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package dummy

import (
"net"

"github.com/wzshiming/shadowsocks"
)

func init() {
shadowsocks.RegisterCipher("dummy", func(password string) (shadowsocks.ConnCipher, error) {
return &cipher{}, nil
})
}

type cipher struct {
}

func (cipher) StreamConn(conn net.Conn) net.Conn {
return conn
}

func (cipher) Decrypt(dist, src []byte) (n int, err error) {
return copy(dist, src), nil
}
func (cipher) Encrypt(dist, src []byte) (n int, err error) {
return copy(dist, src), nil
}
1 change: 1 addition & 0 deletions init/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package init
import (
_ "github.com/wzshiming/shadowsocks/aead/aes-x-gcm"
_ "github.com/wzshiming/shadowsocks/aead/chacha20-ietf-poly1305"
_ "github.com/wzshiming/shadowsocks/dummy"
_ "github.com/wzshiming/shadowsocks/stream/aes-x-cfb"
_ "github.com/wzshiming/shadowsocks/stream/aes-x-ctr"
_ "github.com/wzshiming/shadowsocks/stream/bf-cfb"
Expand Down
63 changes: 63 additions & 0 deletions packet.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package shadowsocks

import (
"bytes"
"context"
"net"
)

type ListenPacket interface {
ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error)
}

func decryptPacket(c ConnCipher, p BytesPool, dist, src []byte) (n int, addr net.Addr, err error) {
i, err := c.Decrypt(dist, src)
if err != nil {
return 0, nil, err
}
buf := bytes.NewBuffer(dist[:i])
a, err := readAddress(buf)
if err != nil {
return 0, nil, err
}
i = copy(dist, buf.Bytes())
return i, a, nil
}

func encryptPacket(c ConnCipher, p BytesPool, dist, src []byte, addr net.Addr) (n int, err error) {
a, err := parseAddress(addr.String())
if err != nil {
return 0, err
}
buf := getBytes(p)
defer putBytes(p, buf)
b := bytes.NewBuffer(buf[:0])
err = writeAddress(b, a)
if err != nil {
return 0, err
}
b.Write(src)
i, err := c.Encrypt(dist, b.Bytes())
if err != nil {
return 0, err
}
return i, nil
}

func toUDPAddr(addr net.Addr) (net.Addr, error) {
switch a := addr.(type) {
case *net.UDPAddr:
return addr, nil
case *address:
return &net.UDPAddr{
IP: a.IP,
Port: a.Port,
}, nil
default:
a, err := net.ResolveUDPAddr("udp", addr.String())
if err != nil {
return nil, err
}
return a, nil
}
}
Loading

0 comments on commit 19f1315

Please sign in to comment.