From 5496f0d6ec7685b34c72a3af44f6153ba34f2154 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A3rebe=20-=20Romain=20GERARD?= Date: Sun, 1 Oct 2023 17:16:23 +0200 Subject: [PATCH] ground 1 --- .cargo/config.toml | 8 + .gitignore | 27 +- Cargo.lock | 811 ++++++++++++++------------------------------- Cargo.toml | 20 +- src/main.rs | 508 ++++++++++++++++++++++++---- src/stdio.rs | 19 ++ src/tcp.rs | 110 ++++++ src/tls.rs | 82 +++++ src/transport.rs | 371 +++++++++++++++++++++ src/udp.rs | 241 ++++++++++++++ 10 files changed, 1557 insertions(+), 640 deletions(-) create mode 100644 .cargo/config.toml create mode 100644 src/stdio.rs create mode 100644 src/tcp.rs create mode 100644 src/tls.rs create mode 100644 src/transport.rs create mode 100644 src/udp.rs diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 00000000..b2b48c6b --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,8 @@ +[build] +#target = "x86_64-unknown-linux-musl" +rustflags = ["--cfg", "uuid_unstable"] + +[target.'cfg(target_os = "linux")'] +#rustflags = ["-C", "linker=ld.lld", "-C", "relocation-model=static", "-C", "strip=symbols", "--cfg", "uuid_unstable"] + +#[build] diff --git a/.gitignore b/.gitignore index 9b3eda0a..73fab072 100644 --- a/.gitignore +++ b/.gitignore @@ -1,21 +1,10 @@ -dist -cabal-dev -*.o -*.hi -*.chi -*.chs.h -.virtualenv -.hsenv -.cabal-sandbox/ -cabal.sandbox.config -cabal.config -*.log -tags -bin/ -*~ -.stack-work +# Generated by Cargo +# will have compiled files and executables +debug/ +target/ +# These are backup files generated by rustfmt +**/*.rs.bk -# Added by cargo - -/target +# MSVC Windows builds of rustc generate these, which store debugging information +*.pdb diff --git a/Cargo.lock b/Cargo.lock index d6ef4f8f..b33716c5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,18 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "ahash" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" +dependencies = [ + "cfg-if", + "getrandom", + "once_cell", + "version_check", +] + [[package]] name = "aho-corasick" version = "1.1.1" @@ -75,15 +87,16 @@ dependencies = [ ] [[package]] -name = "async-trait" -version = "0.1.73" +name = "anyhow" +version = "1.0.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc00ceb34980c03614e35a3a4e218276a0a824e911d07651cd0d858a51e8c0f0" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.37", -] +checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" + +[[package]] +name = "atomic" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c59bdb34bc650a32731b31bd8f0829cc15d24a708ee31559e0bb34f2bc320cba" [[package]] name = "autocfg" @@ -119,10 +132,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] -name = "bitflags" -version = "2.4.0" +name = "block-buffer" +version = "0.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] [[package]] name = "bumpalo" @@ -182,7 +198,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.37", + "syn", ] [[package]] @@ -214,64 +230,55 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" [[package]] -name = "data-encoding" -version = "2.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2e66c9d817f1720209181c316d28635c050fa304f9c79e47a520882661b7308" - -[[package]] -name = "deranged" -version = "0.3.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2696e8a945f658fd14dc3b87242e6b80cd0f36ff04ea560fa39082368847946" - -[[package]] -name = "encoding_rs" -version = "0.8.33" +name = "cpufeatures" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7268b386296a025e474d5140678f75d6de9493ae55a5d709eeb9dd08149945e1" +checksum = "a17b76ff3a4162b0b27f354a0c87015ddad39d35f9c0c36607a3bdd175dde1f1" dependencies = [ - "cfg-if", + "libc", ] [[package]] -name = "enum-as-inner" -version = "0.5.1" +name = "crypto-common" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9720bba047d567ffc8a3cba48bf19126600e249ab7f128e9233e6376976a116" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" dependencies = [ - "heck", - "proc-macro2", - "quote", - "syn 1.0.109", + "generic-array", + "typenum", ] [[package]] -name = "errno" -version = "0.3.3" +name = "deranged" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "136526188508e25c6fef639d7927dfb3e0e3084488bf202267829cf7fc23dbdd" -dependencies = [ - "errno-dragonfly", - "libc", - "windows-sys", -] +checksum = "f2696e8a945f658fd14dc3b87242e6b80cd0f36ff04ea560fa39082368847946" [[package]] -name = "errno-dragonfly" -version = "0.1.2" +name = "digest" +version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ - "cc", - "libc", + "block-buffer", + "crypto-common", ] [[package]] -name = "fastrand" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" +name = "fastwebsockets" +version = "0.4.4" +source = "git+https://github.com/mmastrac/fastwebsockets?branch=split#6433d9d60ac498959d6354d7e0c49e08037b3a37" +dependencies = [ + "base64", + "hyper", + "pin-project", + "rand", + "sha1", + "simdutf8", + "thiserror", + "tokio", + "utf-8", +] [[package]] name = "fnv" @@ -279,21 +286,6 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" -[[package]] -name = "foreign-types" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" -dependencies = [ - "foreign-types-shared", -] - -[[package]] -name = "foreign-types-shared" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" - [[package]] name = "form_urlencoded" version = "1.2.0" @@ -318,12 +310,6 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" -[[package]] -name = "futures-io" -version = "0.3.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" - [[package]] name = "futures-macro" version = "0.3.28" @@ -332,15 +318,9 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", "quote", - "syn 2.0.37", + "syn", ] -[[package]] -name = "futures-sink" -version = "0.3.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" - [[package]] name = "futures-task" version = "0.3.28" @@ -354,16 +334,23 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" dependencies = [ "futures-core", - "futures-io", "futures-macro", - "futures-sink", "futures-task", - "memchr", "pin-project-lite", "pin-utils", "slab", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.10" @@ -381,31 +368,6 @@ version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0" -[[package]] -name = "h2" -version = "0.3.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91fc23aa11be92976ef4729127f1a74adf36d8436f7816b185d18df956790833" -dependencies = [ - "bytes", - "fnv", - "futures-core", - "futures-sink", - "futures-util", - "http", - "indexmap", - "slab", - "tokio", - "tokio-util", - "tracing", -] - -[[package]] -name = "hashbrown" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" - [[package]] name = "heck" version = "0.4.1" @@ -418,17 +380,6 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" -[[package]] -name = "hostname" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c731c3e10504cc8ed35cfe2f1db4c9274c3d35fa486e3b31df46f068ef3e867" -dependencies = [ - "libc", - "match_cfg", - "winapi", -] - [[package]] name = "http" version = "0.2.9" @@ -451,16 +402,6 @@ dependencies = [ "pin-project-lite", ] -[[package]] -name = "http-body" -version = "1.0.0-rc.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "951dfc2e32ac02d67c90c0d65bd27009a635dc9b381a2cc7d284ab01e3a0150d" -dependencies = [ - "bytes", - "http", -] - [[package]] name = "httparse" version = "1.8.0" @@ -483,9 +424,8 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "h2", "http", - "http-body 0.4.5", + "http-body", "httparse", "httpdate", "itoa", @@ -497,69 +437,6 @@ dependencies = [ "want", ] -[[package]] -name = "hyper" -version = "1.0.0-rc.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d280a71f348bcc670fc55b02b63c53a04ac0bf2daff2980795aeaf53edae10e6" -dependencies = [ - "bytes", - "futures-channel", - "futures-util", - "h2", - "http", - "http-body 1.0.0-rc.2", - "httparse", - "httpdate", - "itoa", - "pin-project-lite", - "tokio", - "tracing", - "want", -] - -[[package]] -name = "hyper-openssl" -version = "0.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6ee5d7a8f718585d1c3c61dfde28ef5b0bb14734b4db13f5ada856cdc6c612b" -dependencies = [ - "http", - "hyper 0.14.27", - "linked_hash_set", - "once_cell", - "openssl", - "openssl-sys", - "parking_lot", - "tokio", - "tokio-openssl", - "tower-layer", -] - -[[package]] -name = "hyper-tls" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" -dependencies = [ - "bytes", - "hyper 0.14.27", - "native-tls", - "tokio", - "tokio-native-tls", -] - -[[package]] -name = "idna" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "418a0a6fab821475f634efe3ccc45c013f742efe03d853e8d3355d5cb850ecf8" -dependencies = [ - "matches", - "unicode-bidi", - "unicode-normalization", -] - [[package]] name = "idna" version = "0.4.0" @@ -570,34 +447,6 @@ dependencies = [ "unicode-normalization", ] -[[package]] -name = "indexmap" -version = "1.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" -dependencies = [ - "autocfg", - "hashbrown", -] - -[[package]] -name = "ipconfig" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b58db92f96b720de98181bbbe63c831e87005ab460c1bf306eb2622b4707997f" -dependencies = [ - "socket2 0.5.4", - "widestring", - "windows-sys", - "winreg", -] - -[[package]] -name = "ipnet" -version = "2.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6" - [[package]] name = "itoa" version = "1.0.9" @@ -625,27 +474,6 @@ version = "0.2.148" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9cdc71e17332e86d2e1d38c1f99edcb6288ee11b815fb1a4b049eaa2114d369b" -[[package]] -name = "linked-hash-map" -version = "0.5.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" - -[[package]] -name = "linked_hash_set" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47186c6da4d81ca383c7c47c1bfc80f4b95f4720514d860a5407aaf4233f9588" -dependencies = [ - "linked-hash-map", -] - -[[package]] -name = "linux-raw-sys" -version = "0.4.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a9bad9f94746442c783ca431b22403b519cd7fbeed0533fdd6328b2f2212128" - [[package]] name = "lock_api" version = "0.4.10" @@ -662,21 +490,6 @@ version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" -[[package]] -name = "lru-cache" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31e24f1ad8321ca0e8a1e0ac13f23cb668e6f5466c2c57319f6a5cf1cc8e3b1c" -dependencies = [ - "linked-hash-map", -] - -[[package]] -name = "match_cfg" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffbee8634e0d45d258acb448e7eaab3fce7a0a467395d4d9f228e3c1f01fb2e4" - [[package]] name = "matchers" version = "0.1.0" @@ -686,24 +499,12 @@ dependencies = [ "regex-automata 0.1.10", ] -[[package]] -name = "matches" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2532096657941c2fea9c289d370a250971c689d4f143798ff67113ec042024a5" - [[package]] name = "memchr" version = "2.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f232d6ef707e1956a43342693d2a31e72989554d58299d7a88738cc95b0d35c" -[[package]] -name = "mime" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" - [[package]] name = "miniz_oxide" version = "0.7.1" @@ -724,24 +525,6 @@ dependencies = [ "windows-sys", ] -[[package]] -name = "native-tls" -version = "0.2.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e" -dependencies = [ - "lazy_static", - "libc", - "log", - "openssl", - "openssl-probe", - "openssl-sys", - "schannel", - "security-framework", - "security-framework-sys", - "tempfile", -] - [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -786,50 +569,12 @@ version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" -[[package]] -name = "openssl" -version = "0.10.57" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bac25ee399abb46215765b1cb35bc0212377e58a061560d8b29b024fd0430e7c" -dependencies = [ - "bitflags 2.4.0", - "cfg-if", - "foreign-types", - "libc", - "once_cell", - "openssl-macros", - "openssl-sys", -] - -[[package]] -name = "openssl-macros" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.37", -] - [[package]] name = "openssl-probe" version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" -[[package]] -name = "openssl-sys" -version = "0.9.93" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db4d56a4c0478783083cfafcc42493dd4a981d41669da64b4572a2a089b51b1d" -dependencies = [ - "cc", - "libc", - "pkg-config", - "vcpkg", -] - [[package]] name = "overload" version = "0.1.1" @@ -865,6 +610,26 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" +[[package]] +name = "pin-project" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fda4ed1c6c173e3fc7a83629421152e01d7b1f9b7f65fb301e490e8cfc656422" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "pin-project-lite" version = "0.2.13" @@ -877,12 +642,6 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" -[[package]] -name = "pkg-config" -version = "0.3.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" - [[package]] name = "ppv-lite86" version = "0.2.17" @@ -898,12 +657,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "quick-error" -version = "1.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" - [[package]] name = "quote" version = "1.0.33" @@ -949,7 +702,7 @@ version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" dependencies = [ - "bitflags 1.3.2", + "bitflags", ] [[package]] @@ -997,79 +750,74 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" [[package]] -name = "reqwest" -version = "0.11.20" +name = "ring" +version = "0.16.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e9ad3fe7488d7e34558a2033d45a0c90b72d97b4f80705666fea71472e2e6a1" +checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" dependencies = [ - "base64", - "bytes", - "encoding_rs", - "futures-core", - "futures-util", - "h2", - "http", - "http-body 0.4.5", - "hyper 0.14.27", - "hyper-tls", - "ipnet", - "js-sys", - "log", - "mime", - "native-tls", + "cc", + "libc", "once_cell", - "percent-encoding", - "pin-project-lite", - "serde", - "serde_json", - "serde_urlencoded", - "tokio", - "tokio-native-tls", - "tokio-util", - "tower-service", - "trust-dns-resolver", - "url", - "wasm-bindgen", - "wasm-bindgen-futures", - "wasm-streams", + "spin", + "untrusted", "web-sys", - "winreg", + "winapi", ] [[package]] -name = "resolv-conf" -version = "0.7.0" +name = "rustc-demangle" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" + +[[package]] +name = "rustls" +version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52e44394d2086d010551b14b53b1f24e31647570cd1deb0379e2c21b329aba00" +checksum = "cd8d6c9f025a446bc4d18ad9632e69aec8f287aa84499ee335599fabd20c3fd8" dependencies = [ - "hostname", - "quick-error", + "log", + "ring", + "rustls-webpki", + "sct", ] [[package]] -name = "rustc-demangle" -version = "0.1.23" +name = "rustls-native-certs" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" +checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" +dependencies = [ + "openssl-probe", + "rustls-pemfile", + "schannel", + "security-framework", +] [[package]] -name = "rustix" -version = "0.38.14" +name = "rustls-pemfile" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "747c788e9ce8e92b12cd485c49ddf90723550b654b32508f979b71a7b1ecda4f" +checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2" dependencies = [ - "bitflags 2.4.0", - "errno", - "libc", - "linux-raw-sys", - "windows-sys", + "base64", ] [[package]] -name = "ryu" -version = "1.0.15" +name = "rustls-pki-types" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a47003264dea418db67060fa420ad16d0d2f8f0a0360d825c00e177ac52cb5d8" + +[[package]] +name = "rustls-webpki" +version = "0.101.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" +checksum = "3c7d5dece342910d9ba34d259310cae3e0154b873b35408b787b59bce53d34fe" +dependencies = [ + "ring", + "untrusted", +] [[package]] name = "schannel" @@ -1086,13 +834,23 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "sct" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "security-framework" version = "2.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05b64fb303737d99b81884b2c63433e9ae28abebe5eb5045dcdd175dc2ecf4de" dependencies = [ - "bitflags 1.3.2", + "bitflags", "core-foundation", "core-foundation-sys", "libc", @@ -1126,30 +884,18 @@ checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.37", + "syn", ] [[package]] -name = "serde_json" -version = "1.0.107" +name = "sha1" +version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" dependencies = [ - "itoa", - "ryu", - "serde", -] - -[[package]] -name = "serde_urlencoded" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" -dependencies = [ - "form_urlencoded", - "itoa", - "ryu", - "serde", + "cfg-if", + "cpufeatures", + "digest", ] [[package]] @@ -1170,6 +916,12 @@ dependencies = [ "libc", ] +[[package]] +name = "simdutf8" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a" + [[package]] name = "slab" version = "0.4.9" @@ -1206,21 +958,16 @@ dependencies = [ ] [[package]] -name = "strsim" -version = "0.10.0" +name = "spin" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" +checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" [[package]] -name = "syn" -version = "1.0.109" +name = "strsim" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" [[package]] name = "syn" @@ -1233,37 +980,24 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "tempfile" -version = "3.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb94d2f3cc536af71caac6b6fcebf65860b347e7ce0cc9ebe8f70d3e521054ef" -dependencies = [ - "cfg-if", - "fastrand", - "redox_syscall", - "rustix", - "windows-sys", -] - [[package]] name = "thiserror" -version = "1.0.48" +version = "1.0.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d6d7a740b8a666a7e828dd00da9c0dc290dff53154ea77ac109281de90589b7" +checksum = "1177e8c6d7ede7afde3585fd2513e611227efd6481bd78d2e82ba1ce16557ed4" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.48" +version = "1.0.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35" +checksum = "10712f02019e9288794769fba95cd6847df9874d49d871d062172f9dd41bc4cc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.37", + "syn", ] [[package]] @@ -1341,58 +1075,47 @@ dependencies = [ ] [[package]] -name = "tokio-macros" -version = "2.1.0" +name = "tokio-fd" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" +checksum = "5cedf0b897610a4baff98bf6116c060c5cfe7574d4339c50e9d23fe09377641d" dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.37", + "libc", + "tokio", ] [[package]] -name = "tokio-native-tls" -version = "0.3.1" +name = "tokio-macros" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" dependencies = [ - "native-tls", - "tokio", + "proc-macro2", + "quote", + "syn", ] [[package]] -name = "tokio-openssl" -version = "0.6.3" +name = "tokio-rustls" +version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08f9ffb7809f1b20c1b398d92acf4cc719874b3b2b2d9ea2f09b4a80350878a" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" dependencies = [ - "futures-util", - "openssl", - "openssl-sys", + "rustls", "tokio", ] [[package]] -name = "tokio-util" -version = "0.7.9" +name = "tokio-stream" +version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d68074620f57a0b21594d9735eb2e98ab38b17f80d3fcb189fca266771ca60d" +checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842" dependencies = [ - "bytes", "futures-core", - "futures-sink", "pin-project-lite", "tokio", - "tracing", ] -[[package]] -name = "tower-layer" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" - [[package]] name = "tower-service" version = "0.3.2" @@ -1420,7 +1143,7 @@ checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" dependencies = [ "proc-macro2", "quote", - "syn 2.0.37", + "syn", ] [[package]] @@ -1463,57 +1186,18 @@ dependencies = [ "tracing-log", ] -[[package]] -name = "trust-dns-proto" -version = "0.22.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f7f83d1e4a0e4358ac54c5c3681e5d7da5efc5a7a632c90bb6d6669ddd9bc26" -dependencies = [ - "async-trait", - "cfg-if", - "data-encoding", - "enum-as-inner", - "futures-channel", - "futures-io", - "futures-util", - "idna 0.2.3", - "ipnet", - "lazy_static", - "rand", - "smallvec", - "thiserror", - "tinyvec", - "tokio", - "tracing", - "url", -] - -[[package]] -name = "trust-dns-resolver" -version = "0.22.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aff21aa4dcefb0a1afbfac26deb0adc93888c7d295fb63ab273ef276ba2b7cfe" -dependencies = [ - "cfg-if", - "futures-util", - "ipconfig", - "lazy_static", - "lru-cache", - "parking_lot", - "resolv-conf", - "smallvec", - "thiserror", - "tokio", - "tracing", - "trust-dns-proto", -] - [[package]] name = "try-lock" version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + [[package]] name = "unicode-bidi" version = "0.3.13" @@ -1535,6 +1219,12 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "untrusted" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" + [[package]] name = "url" version = "2.4.1" @@ -1542,16 +1232,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "143b538f18257fac9cad154828a57c6bf5157e1aa604d4816b5995bf6de87ae5" dependencies = [ "form_urlencoded", - "idna 0.4.0", + "idna", "percent-encoding", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8parse" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +[[package]] +name = "uuid" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79daa5ed5740825c40b389c5e50312b9c86df53fccd33f281df655642b43869d" +dependencies = [ + "atomic", + "getrandom", +] + [[package]] name = "valuable" version = "0.1.0" @@ -1559,10 +1265,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" [[package]] -name = "vcpkg" -version = "0.2.15" +name = "version_check" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" [[package]] name = "want" @@ -1600,22 +1306,10 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.37", + "syn", "wasm-bindgen-shared", ] -[[package]] -name = "wasm-bindgen-futures" -version = "0.4.37" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c02dbc21516f9f1f04f187958890d7e6026df8d16540b7ad9492bc34a67cea03" -dependencies = [ - "cfg-if", - "js-sys", - "wasm-bindgen", - "web-sys", -] - [[package]] name = "wasm-bindgen-macro" version = "0.2.87" @@ -1634,7 +1328,7 @@ checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.37", + "syn", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -1645,19 +1339,6 @@ version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" -[[package]] -name = "wasm-streams" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4609d447824375f43e1ffbc051b50ad8f4b3ae8219680c94452ea05eb240ac7" -dependencies = [ - "futures-util", - "js-sys", - "wasm-bindgen", - "wasm-bindgen-futures", - "web-sys", -] - [[package]] name = "web-sys" version = "0.3.64" @@ -1669,10 +1350,13 @@ dependencies = [ ] [[package]] -name = "widestring" -version = "1.0.2" +name = "webpki-roots" +version = "0.26.0-alpha.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "653f141f39ec16bba3c5abe400a0c60da7468261cc2cbf36805022876bc721a8" +checksum = "42157929d7ca9c353222a4d1763c52ef86d25d0fd2eca66076df5975fd4e25ed" +dependencies = [ + "rustls-pki-types", +] [[package]] name = "winapi" @@ -1762,26 +1446,29 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" -[[package]] -name = "winreg" -version = "0.50.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" -dependencies = [ - "cfg-if", - "windows-sys", -] - [[package]] name = "wstunnel" version = "0.1.0" dependencies = [ + "ahash", + "anyhow", + "base64", "clap", - "hyper 1.0.0-rc.4", - "hyper-openssl", - "reqwest", + "fastwebsockets", + "futures-util", + "hyper", + "libc", + "once_cell", + "pin-project", + "rustls-native-certs", + "scopeguard", "tokio", + "tokio-fd", + "tokio-rustls", + "tokio-stream", "tracing", "tracing-subscriber", "url", + "uuid", + "webpki-roots", ] diff --git a/Cargo.toml b/Cargo.toml index 50887986..98d78ebf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,12 +8,26 @@ edition = "2021" [dependencies] clap = { version = "4.4.5", features = ["derive"]} url = "2.4.1" +anyhow = "1.0.75" -reqwest = { version = "0.11.20", features = ["stream", "trust-dns"] } -hyper = { version = "1.0.0-rc.4", features = ["client", "http2"] } -hyper-openssl = {version = "0.9.2", features = []} +hyper = { version = "0.14.27", features = ["client", "runtime"] } +#fastwebsockets = { version = "0.4.4", features = ["upgrade"]} +fastwebsockets = { git = "https://github.com/mmastrac/fastwebsockets", branch = "split", features = ["upgrade", "simd"]} +libc = { version = "0.2.148", features = []} +once_cell = { version = "1.18.0", features = [] } +ahash = { version = "0.8.3", features = []} +pin-project = "1" +scopeguard = "1.2.0" +uuid = { version = "1.4.1", features = ["v7"] } +rustls-native-certs = { version = "0.6.3", features = [] } tokio = { version = "1.32.0", features = ["full"] } +tokio-rustls = { version = "0.24.1", features = ["tls12", "dangerous_configuration", "early-data"] } +tokio-stream = { version = "0.1.14", features = ["net"] } +tokio-fd = "0.3.0" +futures-util = { version = "0.3.28" } tracing = { version = "0.1.37", features = ["log"] } tracing-subscriber = { version = "0.3.17", features = ["env-filter", "fmt", "local-time"] } +webpki-roots = "0.26.0-alpha.1" +base64 = "0.21.4" diff --git a/src/main.rs b/src/main.rs index 0065e1ba..570f87f9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,20 +1,36 @@ +mod stdio; +mod tcp; +mod tls; +mod transport; +mod udp; + +use base64::Engine; +use clap::Parser; +use futures_util::{pin_mut, stream, Stream, StreamExt, TryStreamExt}; +use hyper::http::HeaderValue; use std::borrow::Cow; -use std::collections::BTreeMap; +use std::collections::{BTreeMap, HashMap}; +use std::fmt::{Display, Formatter}; +use std::io; use std::io::ErrorKind; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::str::FromStr; +use std::sync::Arc; use std::time::Duration; -use clap::Parser; -use hyper::body::Body; -use hyper::Request; -use hyper_openssl::HttpsConnector; -use url::{Host, Url, UrlQuery}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::TcpStream; + +use tokio_rustls::rustls::server::DnsName; +use tokio_rustls::rustls::ServerName; +use tracing::{error, field, instrument, Instrument, Span}; + +use tracing_subscriber::EnvFilter; +use url::{Host, Url}; /// Simple program to greet a person #[derive(clap::Parser, Debug)] #[command(author, version, about, long_about = None)] struct Wstunnel { - #[command(subcommand)] commands: Commands, } @@ -22,139 +38,519 @@ struct Wstunnel { #[derive(clap::Subcommand, Debug)] enum Commands { Client(Client), - Server(Server) + Server(Server), } #[derive(clap::Args, Debug)] struct Client { /// Name of the person to greet #[arg(short='L', long, value_name = "[BIND:]PORT:HOST:PORT", value_parser = parse_env_var)] local_to_remote: Vec, + + /// (linux only) Mark network packet with SO_MARK sockoption with the specified value. + /// You need to use {root, sudo, capabilities} to run wstunnel when using this option + #[arg(long, value_name = "INT", verbatim_doc_comment)] + socket_so_mark: Option, + + /// Domain name that will be use as SNI during TLS handshake + /// Warning: If you are behind a CDN (i.e: Cloudflare) you must set this domain also in the http HOST header. + /// or it will be flag as fishy as your request rejected + #[arg(long, value_name = "DOMAIN_NAME", value_parser = parse_sni_override, verbatim_doc_comment)] + tls_sni_override: Option, + + /// Enable TLS certificate verification. + /// Disabled by default. The client will happily connect to any server with self signed certificate. + #[arg(long, verbatim_doc_comment)] + tls_verify_certificate: bool, + + /// Use a specific prefix that will show up in the http path during the upgrade request. + /// Useful if you need to route requests server side but don't have vhosts + #[arg(long, default_value = "wstunnel", verbatim_doc_comment)] + http_upgrade_path_prefix: String, + + /// Pass authorization header with basic auth credentials during the upgrade request. + /// If you need more customization, you can use the http_headers option. + #[arg(long, value_name = "USER[:PASS]", value_parser = parse_http_credentials, verbatim_doc_comment)] + http_upgrade_credentials: Option, + + /// Frequency at which the client will send websocket ping to the server. + #[arg(long, value_name = "seconds", default_value = "30", value_parser = parse_duration_sec, verbatim_doc_comment)] + websocket_ping_frequency_sec: Option, + + /// Enable the masking of websocket frames. Default is false + /// Enable this option only if you use unsecure (non TLS) websocket server and you see some issues. Otherwise, it is just overhead. + #[arg(long, default_value = "false", verbatim_doc_comment)] + websocket_mask_frame: bool, + + #[arg(short='H', long, value_name = "HEADER_NAME: HEADER_VALUE", value_parser = parse_http_headers, verbatim_doc_comment)] + http_headers: Vec<(String, HeaderValue)>, + + /// Address of the wstunnel server + /// Example: With TLS wss://wstunnel.example.com or without ws://wstunnel.example.com + #[arg(value_name = "ws[s]://wstunnel.server.com[:port]", value_parser = parse_server_url, verbatim_doc_comment)] + remote_addr: Url, } #[derive(clap::Args, Debug)] struct Server { - /// Name of the person to greet - #[arg(short='L', long, value_name = "[BIND:]PORT:HOST:PORT", value_parser = parse_env_var)] - local_to_remote: String, + /// Address of the wstunnel server to bind to + /// Example: With TLS wss://0.0.0.0:8080 or without ws://[::]:8080 + #[arg(value_name = "ws[s]://0.0.0.0[:port]", value_parser = parse_server_url, verbatim_doc_comment)] + remote_addr: Url, + + /// (linux only) Mark network packet with SO_MARK sockoption with the specified value. + /// You need to use {root, sudo, capabilities} to run wstunnel when using this option + #[arg(long, value_name = "INT", verbatim_doc_comment)] + socket_so_mark: Option, + + /// Frequency at which the server will send websocket ping to client. + #[arg(long, value_name = "seconds", default_value = "30", value_parser = parse_duration_sec, verbatim_doc_comment)] + websocket_ping_frequency_sec: Option, } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, Eq, PartialEq)] enum L4Protocol { - TCP, UDP { timeout: Duration } + Tcp, + Udp { timeout: Option }, + Stdio, +} + +impl Display for L4Protocol { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + L4Protocol::Tcp => f.write_str("tcp"), + L4Protocol::Udp { .. } => f.write_str("udp"), + L4Protocol::Stdio => f.write_str("tcp"), + } + } } impl L4Protocol { fn new_udp() -> L4Protocol { - L4Protocol::UDP { timeout: Duration::from_secs(30) } + L4Protocol::Udp { + timeout: Some(Duration::from_secs(30)), + } } } #[derive(Clone, Debug)] -struct LocalToRemote { +pub struct LocalToRemote { + socket_so_mark: Option, protocol: L4Protocol, local: SocketAddr, remote: (Host, u16), } -fn parse_env_var(arg: &str) -> Result { +fn parse_duration_sec(arg: &str) -> Result { + use std::io::Error; + + let Ok(secs) = arg.parse::() else { + return Err(Error::new( + ErrorKind::InvalidInput, + format!("cannot duration of seconds from {}", arg), + )); + }; + + Ok(Duration::from_secs(secs)) +} + +fn parse_env_var(arg: &str) -> Result { use std::io::Error; let (mut protocol, arg) = match &arg[..6] { - "tcp://" => (L4Protocol::TCP, &arg[6..]), + "tcp://" => (L4Protocol::Tcp, &arg[6..]), "udp://" => (L4Protocol::new_udp(), &arg[6..]), - _ => (L4Protocol::TCP, arg) + _ => match &arg[..8] { + "stdio://" => (L4Protocol::Stdio, &arg[8..]), + _ => (L4Protocol::Tcp, arg), + }, }; let (bind, remaining) = if arg.starts_with('[') { // ipv6 bind let Some((ipv6_str, remaining)) = arg.split_once(']') else { - return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse IPv6 bind from {}", arg))); + return Err(Error::new( + ErrorKind::InvalidInput, + format!("cannot parse IPv6 bind from {}", arg), + )); }; let Ok(ipv6_addr) = Ipv6Addr::from_str(&ipv6_str[1..]) else { - return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse IPv6 bind from {}", ipv6_str))); + return Err(Error::new( + ErrorKind::InvalidInput, + format!("cannot parse IPv6 bind from {}", ipv6_str), + )); }; (IpAddr::V6(ipv6_addr), remaining) } else { // Maybe ipv4 addr let Some((ipv4_str, remaining)) = arg.split_once(':') else { - return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse IPv4 bind from {}", arg))); + return Err(Error::new( + ErrorKind::InvalidInput, + format!("cannot parse IPv4 bind from {}", arg), + )); }; - match Ipv4Addr::from_str(ipv4_str) { + match Ipv4Addr::from_str(ipv4_str) { Ok(ip4_addr) => (IpAddr::V4(ip4_addr), remaining), // Must be the port, so we default to ipv6 bind - Err(_) => (IpAddr::V6(Ipv6Addr::from_str("::1").unwrap()), arg) + Err(_) => (IpAddr::V4(Ipv4Addr::from_str("127.0.0.1").unwrap()), arg), } }; let Some((port_str, remaining)) = remaining.trim_start_matches(':').split_once(':') else { - return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse bind port from {}", remaining))); + return Err(Error::new( + ErrorKind::InvalidInput, + format!("cannot parse bind port from {}", remaining), + )); }; let Ok(bind_port): Result = port_str.parse() else { - return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse bind port from {}", port_str))); + return Err(Error::new( + ErrorKind::InvalidInput, + format!("cannot parse bind port from {}", port_str), + )); }; - let Ok(remote) = Url::parse(&format!("fake://{}", remaining)) else { - return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse remote from {}", remaining))); + return Err(Error::new( + ErrorKind::InvalidInput, + format!("cannot parse remote from {}", remaining), + )); }; let Some(remote_host) = remote.host() else { - return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse remote host from {}", remaining))); + return Err(Error::new( + ErrorKind::InvalidInput, + format!("cannot parse remote host from {}", remaining), + )); }; let Some(remote_port) = remote.port() else { - return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse remote port from {}", remaining))); + return Err(Error::new( + ErrorKind::InvalidInput, + format!("cannot parse remote port from {}", remaining), + )); }; + let options: BTreeMap, Cow<'_, str>> = remote.query_pairs().collect(); match &mut protocol { - L4Protocol::TCP => {} - L4Protocol::UDP { ref mut timeout, .. } => { - let options: BTreeMap, Cow<'_, str>> = remote.query_pairs().collect(); - if let Some(duration) = options.get("timeout_sec") + L4Protocol::Stdio => {} + L4Protocol::Tcp => {} + L4Protocol::Udp { + ref mut timeout, .. + } => { + if let Some(duration) = options + .get("timeout_sec") .and_then(|x| x.parse::().ok()) - .map(|x| Duration::from_secs(x)) { + .map(|d| { + if d == 0 { + None + } else { + Some(Duration::from_secs(d)) + } + }) + { *timeout = duration; } } }; Ok(LocalToRemote { + socket_so_mark: options + .get("socket_so_mark") + .and_then(|x| x.parse::().ok()), protocol, local: SocketAddr::new(bind, bind_port), - remote: (remote_host.to_owned(), remote_port) + remote: (remote_host.to_owned(), remote_port), }) } -fn main() { - println!("Hello, world!"); +fn parse_sni_override(arg: &str) -> Result { + match DnsName::try_from(arg.to_string()) { + Ok(val) => Ok(val), + Err(err) => Err(io::Error::new( + ErrorKind::InvalidInput, + format!("Invalid sni override: {}", err), + )), + } +} + +fn parse_http_headers(arg: &str) -> Result<(String, HeaderValue), io::Error> { + let Some((key, value)) = arg.split_once(':') else { + return Err(io::Error::new( + ErrorKind::InvalidInput, + format!("cannot parse http header from {}", arg), + )); + }; + + let value = match HeaderValue::from_str(value.trim()) { + Ok(value) => value, + Err(err) => { + return Err(io::Error::new( + ErrorKind::InvalidInput, + format!( + "cannot parse http header value from {} due to {:?}", + value, err + ), + )) + } + }; + + Ok((key.to_owned(), value)) +} + +fn parse_http_credentials(arg: &str) -> Result { + let encoded = base64::engine::general_purpose::STANDARD.encode(arg.trim().as_bytes()); + let Ok(header) = HeaderValue::from_str(&format!("Basic {}", encoded)) else { + return Err(io::Error::new( + ErrorKind::InvalidInput, + format!("cannot parse http credentials {}", arg), + )); + }; + + Ok(header) +} + +fn parse_server_url(arg: &str) -> Result { + let Ok(url) = Url::parse(arg) else { + return Err(io::Error::new( + ErrorKind::InvalidInput, + format!("cannot parse server url {}", arg), + )); + }; + + if url.scheme() != "ws" && url.scheme() != "wss" { + return Err(io::Error::new( + ErrorKind::InvalidInput, + format!("invalid scheme {}", url.scheme()), + )); + } + + if url.host().is_none() { + return Err(io::Error::new( + ErrorKind::InvalidInput, + format!("invalid server host {}", arg), + )); + } + + Ok(url) +} + +#[derive(Clone, Debug)] +pub struct TlsConfig { + pub tls_sni_override: Option, + pub tls_verify_certificate: bool, +} + +#[derive(Clone, Debug)] +pub struct WsServerConfig { + pub socket_so_mark: Option, + pub bind: SocketAddr, + pub restrict_to: Option>, + pub websocket_ping_frequency: Option, + pub timeout_connect: Duration, +} + +#[derive(Clone, Debug)] +pub struct WsClientConfig { + pub remote_addr: (Host, u16), + pub tls: Option, + pub http_upgrade_path_prefix: String, + pub http_upgrade_credentials: Option, + pub http_headers: HashMap, + pub timeout_connect: Duration, + pub websocket_ping_frequency: Duration, + pub websocket_mask_frame: bool, +} + +impl WsClientConfig { + pub fn websocket_scheme(&self) -> &'static str { + match self.tls { + None => "ws", + Some(_) => "wss", + } + } + + pub fn websocket_host_url(&self) -> String { + format!("{}:{}", self.remote_addr.0, self.remote_addr.1) + } + + pub fn tls_server_name(&self) -> ServerName { + match self + .tls + .as_ref() + .and_then(|tls| tls.tls_sni_override.as_ref()) + { + None => match &self.remote_addr.0 { + Host::Domain(domain) => { + ServerName::DnsName(DnsName::try_from(domain.clone()).unwrap()) + } + Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(*ip)), + Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(*ip)), + }, + Some(sni_override) => ServerName::DnsName(sni_override.clone()), + } + } +} + +#[tokio::main] +async fn main() { let args = Wstunnel::parse(); - println!("Hello {:?}!", args) + // Setup logging + match &args.commands { + // Disable logging if there is a stdio tunnel + Commands::Client(args) + if args + .local_to_remote + .iter() + .filter(|x| x.protocol == L4Protocol::Stdio) + .count() + > 0 => {} + _ => { + tracing_subscriber::fmt() + .with_ansi(true) + .with_env_filter(EnvFilter::from_default_env()) + .init(); + } + } + + match args.commands { + Commands::Client(args) => { + let tls = match args.remote_addr.scheme() { + "ws" => None, + "wss" => Some(TlsConfig { + tls_sni_override: args.tls_sni_override, + tls_verify_certificate: args.tls_verify_certificate, + }), + _ => panic!("invalid scheme in server url {}", args.remote_addr.scheme()), + }; + + let server_config = Arc::new(WsClientConfig { + remote_addr: ( + args.remote_addr.host().unwrap().to_owned(), + args.remote_addr.port_or_known_default().unwrap(), + ), + tls, + http_upgrade_path_prefix: args.http_upgrade_path_prefix, + http_upgrade_credentials: args.http_upgrade_credentials, + http_headers: args.http_headers.into_iter().collect(), + timeout_connect: Duration::from_secs(10), + websocket_ping_frequency: args + .websocket_ping_frequency_sec + .unwrap_or(Duration::from_secs(30)), + websocket_mask_frame: args.websocket_mask_frame, + }); + + // Start tunnels + for tunnel in args.local_to_remote.into_iter() { + let server_config = server_config.clone(); + + match &tunnel.protocol { + L4Protocol::Tcp => { + let server = tcp::run_server(tunnel.local) + .await + .unwrap_or_else(|err| { + panic!("Cannot start TCP server on {}: {}", tunnel.local, err) + }) + .map_ok(TcpStream::into_split); + + tokio::spawn(async move { + if let Err(err) = run_tunnel(server_config, tunnel, server).await { + error!("{:?}", err); + } + }); + } + L4Protocol::Udp { timeout } => { + let server = udp::run_server(tunnel.local, *timeout) + .await + .unwrap_or_else(|err| { + panic!("Cannot start UDP server on {}: {}", tunnel.local, err) + }) + .map_ok(tokio::io::split); - let client = reqwest::Client::builder() - .timeout(Duration::from_secs(10)) - .connect_timeout(Duration::from_secs(10)) - .danger_accept_invalid_certs(true) - .build().unwrap(); + tokio::spawn(async move { + if let Err(err) = run_tunnel(server_config, tunnel, server).await { + error!("{:?}", err); + } + }); + } + L4Protocol::Stdio => { + let server = stdio::run_server().await.unwrap_or_else(|err| { + panic!("Cannot start STDIO server: {}", err); + }); + tokio::spawn(async move { + if let Err(err) = run_tunnel( + server_config, + tunnel, + stream::once(async move { Ok(server) }), + ) + .await + { + error!("{:?}", err); + } + }); + } + } + } + } + Commands::Server(args) => { + let server_config = WsServerConfig { + socket_so_mark: args.socket_so_mark, + bind: args.remote_addr.socket_addrs(|| Some(8080)).unwrap()[0], + restrict_to: None, + websocket_ping_frequency: args.websocket_ping_frequency_sec, + timeout_connect: Duration::from_secs(10), + }; + + transport::run_server(Arc::new(server_config)) + .await + .unwrap_or_else(|err| { + panic!("Cannot start wstunnel server: {}", err); + }); + } + } + tokio::signal::ctrl_c().await.unwrap(); +} - let mut conn = HttpsConnector::new()?; - conn.set_callback(move |c, _| { - // Prevent native TLS lib from inferring and verifying a default SNI. - c.set_use_server_name_indication(false); - c.set_verify_hostname(false); +#[instrument(name="tunnel", level="info", skip_all, fields(id=%uuid::Uuid::now_v7(), remote=field::Empty))] +async fn run_tunnel( + server_config: Arc, + tunnel: LocalToRemote, + incoming_cnx: T, +) -> anyhow::Result<()> +where + T: Stream>, + R: AsyncRead + Send + 'static, + W: AsyncWrite + Send + 'static, +{ + let span = Span::current(); + span.record( + "remote", + &format!("{}:{}", tunnel.remote.0, tunnel.remote.1), + ); - // And set a custom SNI instead. - c.set_hostname("somewhere.com") - }); - Client::builder() - .build::<_, Body>(conn) - .request(Request::get("somewhere-else.com").body(())?) - .await?; + let tunnel = Arc::new(tunnel); + pin_mut!(incoming_cnx); - reqwest::Proxy::all("https://google.com").unwrap().basic_auth("", "") + while let Some(Ok(cnx_stream)) = incoming_cnx.next().await { + let server_config = server_config.clone(); + let tunnel = tunnel.clone(); + + tokio::spawn( + async move { + let ret = transport::connect_to_server(&server_config, &tunnel, cnx_stream).await; + + if let Err(ret) = ret { + error!("{:?}", ret); + } + + anyhow::Ok(()) + } + .instrument(span.clone()), + ); + } + Ok(()) } diff --git a/src/stdio.rs b/src/stdio.rs new file mode 100644 index 00000000..7403ea00 --- /dev/null +++ b/src/stdio.rs @@ -0,0 +1,19 @@ +#![allow(unused_imports)] + +use libc::STDIN_FILENO; +use std::os::fd::{AsRawFd, FromRawFd}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::fs::File; +use tokio::io::{stdout, AsyncRead, ReadBuf, Stdout}; +use tokio_fd::AsyncFd; +use tracing::info; + +pub async fn run_server() -> Result<(AsyncFd, AsyncFd), anyhow::Error> { + info!("Starting STDIO server"); + + let stdin = AsyncFd::try_from(libc::STDIN_FILENO)?; + let stdout = AsyncFd::try_from(libc::STDOUT_FILENO)?; + + Ok((stdin, stdout)) +} diff --git a/src/tcp.rs b/src/tcp.rs new file mode 100644 index 00000000..f7884c0a --- /dev/null +++ b/src/tcp.rs @@ -0,0 +1,110 @@ +use anyhow::{anyhow, Context}; +use std::{io, vec}; + +use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::os::fd::AsRawFd; +use std::time::Duration; +use tokio::net::{TcpListener, TcpSocket, TcpStream}; +use tokio::time::timeout; +use tokio_stream::wrappers::TcpListenerStream; +use tracing::debug; +use tracing::log::info; +use url::Host; + +fn configure_socket(socket: &mut TcpSocket, so_mark: &Option) -> Result<(), anyhow::Error> { + socket.set_nodelay(true).with_context(|| { + format!( + "cannot set no_delay on socket: {}", + io::Error::last_os_error() + ) + })?; + + if let Some(so_mark) = so_mark { + unsafe { + let optval: libc::c_int = *so_mark; + let ret = libc::setsockopt( + socket.as_raw_fd(), + libc::SOL_SOCKET, + libc::SO_MARK, + &optval as *const _ as *const libc::c_void, + std::mem::size_of_val(&optval) as libc::socklen_t, + ); + + if ret != 0 { + return Err(anyhow!( + "Cannot set SO_MARK on the connection {:?}", + io::Error::last_os_error() + )); + } + } + } + + Ok(()) +} +pub async fn connect( + host: &Host, + port: u16, + so_mark: &Option, + connect_timeout: Duration, +) -> Result { + info!("Opening TCP connection to {}:{}", host, port); + + // TODO: Avoid allocation of vec by extracting the code that does the connection in a separate function + let socket_addrs: Vec = match host { + Host::Domain(domain) => tokio::net::lookup_host(format!("{}:{}", domain, port)) + .await + .with_context(|| format!("cannot resolve domain: {}", domain))? + .collect(), + Host::Ipv4(ip) => vec![SocketAddr::V4(SocketAddrV4::new(*ip, port))], + Host::Ipv6(ip) => vec![SocketAddr::V6(SocketAddrV6::new(*ip, port, 0, 0))], + }; + + let mut cnx = None; + let mut last_err = None; + for addr in socket_addrs { + debug!("connecting to {}", addr); + + let mut socket = match &addr { + SocketAddr::V4(_) => TcpSocket::new_v4()?, + SocketAddr::V6(_) => TcpSocket::new_v6()?, + }; + + configure_socket(&mut socket, so_mark)?; + match timeout(connect_timeout, socket.connect(addr)).await { + Ok(Ok(stream)) => { + cnx = Some(stream); + break; + } + Ok(Err(err)) => { + debug!("Cannot connect to tcp endpoint {addr} reason {err}"); + last_err = Some(err); + } + Err(_) => { + debug!( + "Cannot connect to tcp endpoint {addr} due to timeout of {}s elapsed", + connect_timeout.as_secs() + ); + } + } + } + + if let Some(cnx) = cnx { + Ok(cnx) + } else { + Err(anyhow!( + "Cannot connect to tcp endpoint {}:{} reason {:?}", + host, + port, + last_err + )) + } +} + +pub async fn run_server(bind: SocketAddr) -> Result { + info!("Starting TCP server listening cnx on {}", bind); + + let listener = TcpListener::bind(bind) + .await + .with_context(|| format!("Cannot create TCP server {:?}", bind))?; + Ok(TcpListenerStream::new(listener)) +} diff --git a/src/tls.rs b/src/tls.rs new file mode 100644 index 00000000..e2a56251 --- /dev/null +++ b/src/tls.rs @@ -0,0 +1,82 @@ +use crate::{TlsConfig, WsClientConfig}; +use anyhow::Context; +use std::sync::Arc; +use std::time::SystemTime; +use tokio::net::TcpStream; +use tokio_rustls::client::TlsStream; +use tokio_rustls::rustls::client::{ServerCertVerified, ServerCertVerifier}; +use tokio_rustls::rustls::{Certificate, ClientConfig, ServerName}; +use tokio_rustls::{rustls, TlsConnector}; +use tracing::info; + +pub struct NullVerifier; +impl ServerCertVerifier for NullVerifier { + fn verify_server_cert( + &self, + _end_entity: &Certificate, + _intermediates: &[Certificate], + _server_name: &ServerName, + _scts: &mut dyn Iterator, + _ocsp_response: &[u8], + _now: SystemTime, + ) -> Result { + Ok(ServerCertVerified::assertion()) + } +} + +fn tls_connector( + tls_cfg: &TlsConfig, + alpn_protocols: Option>>, +) -> anyhow::Result { + let mut root_store = rustls::RootCertStore::empty(); + + // Load system certificates and add them to the root store + let certs = rustls_native_certs::load_native_certs() + .with_context(|| "Cannot load system certificates")?; + for cert in certs { + root_store.add(&Certificate(cert.0)).unwrap(); + } + + let mut config = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth(); + + // To bypass certificate verification + if !tls_cfg.tls_verify_certificate { + config + .dangerous() + .set_certificate_verifier(Arc::new(NullVerifier)); + } + + if let Some(alpn_protocols) = alpn_protocols { + config.alpn_protocols = alpn_protocols; + } + let tls_connector = TlsConnector::from(Arc::new(config)); + Ok(tls_connector) +} + +pub async fn connect( + server_cfg: &WsClientConfig, + tls_cfg: &TlsConfig, + tcp_stream: TcpStream, +) -> anyhow::Result> { + let sni = server_cfg.tls_server_name(); + info!( + "Doing TLS handshake using sni {sni:?} with the server {}:{}", + server_cfg.remote_addr.0, server_cfg.remote_addr.1 + ); + + let tls_connector = tls_connector(tls_cfg, Some(vec![b"http/1.1".to_vec()]))?; + let tls_stream = tls_connector + .connect(sni, tcp_stream) + .await + .with_context(|| { + format!( + "failed to do TLS handshake with the server {:?}", + server_cfg.remote_addr + ) + })?; + + Ok(tls_stream) +} diff --git a/src/transport.rs b/src/transport.rs new file mode 100644 index 00000000..22afb227 --- /dev/null +++ b/src/transport.rs @@ -0,0 +1,371 @@ +#![allow(unused_imports)] + +use std::future::Future; +use std::net::Ipv4Addr; +use std::pin::Pin; +use std::sync::Arc; +use std::time::Duration; + +use crate::{tcp, tls, L4Protocol, LocalToRemote, WsClientConfig, WsServerConfig}; +use anyhow::Context; +use fastwebsockets::upgrade::UpgradeFut; +use fastwebsockets::{ + Frame, OpCode, Payload, WebSocket, WebSocketError, WebSocketRead, WebSocketWrite, +}; +use futures_util::{pin_mut, StreamExt}; +use hyper::header::{AUTHORIZATION, SEC_WEBSOCKET_VERSION, UPGRADE}; +use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY}; +use hyper::server::conn::Http; +use hyper::service::service_fn; +use hyper::upgrade::Upgraded; +use hyper::{http, Body, Request, Response, StatusCode}; +use tokio::io::{ + AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, Interest, ReadHalf, WriteHalf, +}; +use tokio::net::{TcpListener, TcpStream, UdpSocket}; +use tokio::select; +use tokio::sync::oneshot; +use tokio::time::error::Elapsed; +use tokio::time::timeout; + +use crate::udp::{MyUdpSocket, UdpStream}; +use crate::L4Protocol::{Tcp, Udp}; +use tracing::log::debug; +use tracing::{error, field, info, instrument, trace, warn, Instrument, Span}; +use url::quirks::host; +use url::Host; + +struct SpawnExecutor; + +impl hyper::rt::Executor for SpawnExecutor +where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, +{ + fn execute(&self, fut: Fut) { + tokio::task::spawn(fut); + } +} + +pub async fn connect( + server_cfg: &WsClientConfig, + tunnel_cfg: &LocalToRemote, +) -> anyhow::Result> { + let (host, port) = &server_cfg.remote_addr; + let tcp_stream = tcp::connect( + host, + *port, + &tunnel_cfg.socket_so_mark, + server_cfg.timeout_connect, + ) + .await?; + + let mut req = Request::builder() + .method("GET") + .uri(format!( + "/{}/{}/{}/{}", + &server_cfg.http_upgrade_path_prefix, + tunnel_cfg.protocol, + tunnel_cfg.remote.0, + tunnel_cfg.remote.1, + )) //stream we want to subscribe to + .header(HOST, server_cfg.remote_addr.0.to_string()) + .header(UPGRADE, "websocket") + .header(CONNECTION, "upgrade") + .header(SEC_WEBSOCKET_KEY, fastwebsockets::handshake::generate_key()) + .header(SEC_WEBSOCKET_VERSION, "13") + .version(hyper::Version::HTTP_11); + + for (k, v) in &server_cfg.http_headers { + req = req.header(k.clone(), v.clone()); + } + if let Some(auth) = &server_cfg.http_upgrade_credentials { + req = req.header(AUTHORIZATION, auth.clone()); + } + + let req = req.body(Body::empty()).with_context(|| { + format!( + "failed to build HTTP request to contact the server {:?}", + server_cfg.remote_addr + ) + })?; + debug!("with HTTP upgrade request {:?}", req); + + let ws_handshake = match &server_cfg.tls { + None => fastwebsockets::handshake::client(&SpawnExecutor, req, tcp_stream).await, + Some(tls_cfg) => { + let tls_stream = tls::connect(server_cfg, tls_cfg, tcp_stream).await?; + fastwebsockets::handshake::client(&SpawnExecutor, req, tls_stream).await + } + }; + + let (ws, _) = ws_handshake.with_context(|| { + format!( + "failed to do websocket handshake with the server {:?}", + server_cfg.remote_addr + ) + })?; + + Ok(ws) +} + +pub async fn connect_to_server( + server_config: &WsClientConfig, + remote_cfg: &LocalToRemote, + duplex_stream: (R, W), +) -> anyhow::Result<()> +where + R: AsyncRead + Send + 'static, + W: AsyncWrite + Send + 'static, +{ + let mut ws = connect(server_config, remote_cfg).await?; + ws.set_auto_apply_mask(server_config.websocket_mask_frame); + + let (ws_rx, ws_tx) = ws.split(tokio::io::split); + let (local_rx, local_tx) = duplex_stream; + let (close_tx, close_rx) = oneshot::channel::<()>(); + + // Forward local tx to websocket tx + let ping_frequency = server_config.websocket_ping_frequency; + tokio::spawn( + propagate_read(local_rx, ws_tx, close_tx, ping_frequency).instrument(Span::current()), + ); + + // Forward websocket rx to local rx + let _ = propagate_write(local_tx, ws_rx, close_rx, server_config.timeout_connect).await; + + Ok(()) +} + +async fn from_path( + server_config: &WsServerConfig, + paths: &[&str], +) -> anyhow::Result<( + L4Protocol, + Host, + u16, + Pin>, + Pin>, +)> { + match paths { + [_, _, "udp", dest, port, ..] => { + let host = Host::parse(dest)?; + let port = port.parse::()?; + let cnx = Arc::new(UdpSocket::bind("0").await?); + cnx.connect((host.to_string(), port)).await?; + Ok(( + Udp { timeout: None }, + host, + port, + Box::pin(MyUdpSocket::new(cnx.clone())), + Box::pin(MyUdpSocket::new(cnx)), + )) + } + [_, _, "tcp", dest, port, ..] => { + let host = Host::parse(dest)?; + let port = port.parse::()?; + let (rx, tx) = tcp::connect( + &host, + port, + &server_config.socket_so_mark, + Duration::from_secs(10), + ) + .await? + .into_split(); + Ok((Tcp, host, port, Box::pin(rx), Box::pin(tx))) + } + _ => Err(anyhow::anyhow!("Invalid upgrade request")), + } +} + +#[instrument(name="tunnel", level="info", skip_all, fields(id=field::Empty, remote=field::Empty))] +async fn server_upgrade( + server_config: Arc, + mut req: Request, +) -> Result, anyhow::Error> { + let paths: Vec<&str> = req.uri().path().split('/').collect(); + info!("path {:?} {}", paths, req.uri()); + let (protocol, dest, port, local_rx, local_tx) = match from_path(&server_config, &paths).await { + Ok(ret) => ret, + Err(err) => { + return Ok(http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::from(format!("Error: {:?}", err))) + .unwrap_or_default()); + } + }; + + Span::current().record("remote", format!("{}:{}", dest, port)); + info!("connected to {:?} {:?} {:?}", protocol, dest, port); + let (response, fut) = match fastwebsockets::upgrade::upgrade(&mut req) { + Ok(ret) => ret, + Err(err) => { + return Ok(http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::from(format!("Invalid upgrade request: {:?}", err))) + .unwrap_or_default()) + } + }; + + tokio::spawn( + async move { + let (ws_rx, ws_tx) = fut.await.unwrap().split(tokio::io::split); + let (close_tx, close_rx) = oneshot::channel::<()>(); + let connect_timeout = server_config.timeout_connect; + let ping_frequency = server_config + .websocket_ping_frequency + .unwrap_or(Duration::MAX); + + tokio::task::spawn( + propagate_write(local_tx, ws_rx, close_rx, connect_timeout) + .instrument(Span::current()), + ); + + let _ = propagate_read(local_rx, ws_tx, close_tx, ping_frequency).await; + } + .instrument(Span::current()), + ); + + Ok(response) +} + +pub async fn run_server(server_config: Arc) -> anyhow::Result<()> { + info!( + "Starting wstunnel server listening on {}", + server_config.bind + ); + + let config = server_config.clone(); + let upgrade_fn = move |req: Request| server_upgrade(config.clone(), req); + + let listener = TcpListener::bind(&server_config.bind).await?; + loop { + let (stream, peer_addr) = listener.accept().await?; + info!("Accepting connection from {}", peer_addr); + let upgrade_fn = upgrade_fn.clone(); + let _ = stream.set_nodelay(true); + let conn_fut = Http::new() + .http1_only(true) + .serve_connection(stream, service_fn(upgrade_fn)) + .with_upgrades(); + + tokio::spawn(async move { + if let Err(e) = conn_fut.await { + error!("An error occurred: {:?}", e); + } + }); + } +} + +async fn propagate_read( + local_rx: impl AsyncRead, + mut ws_tx: WebSocketWrite>, + mut close_tx: oneshot::Sender<()>, + ping_frequency: Duration, +) -> Result<(), WebSocketError> { + let _guard = scopeguard::guard((), |_| { + info!("Closing local tx ==> websocket tx tunnel"); + }); + + let mut buffer = vec![0u8; 8 * 1024]; + pin_mut!(local_rx); + loop { + let read = select! { + biased; + + read_len = local_rx.read(buffer.as_mut_slice()) => read_len, + + _ = close_tx.closed() => break, + + _ = timeout(ping_frequency, futures_util::future::pending::<()>()) => { + debug!("sending ping to keep websocket connection alive"); + ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::Borrowed(&[]))).await?; + continue; + } + }; + + let read_len = match read { + Ok(read_len) if read_len > 0 => read_len, + Ok(_) => break, + Err(err) => { + warn!( + "error while reading incoming bytes from local tx tunnel {}", + err + ); + break; + } + }; + + trace!("read {} bytes", read_len); + match ws_tx + .write_frame(Frame::binary(Payload::Borrowed(&buffer[0..read_len]))) + .await + { + Ok(_) => {} + Err(err) => { + warn!("error while writing to websocket tx tunnel {}", err); + break; + } + } + } + + Ok(()) +} + +async fn propagate_write( + local_tx: impl AsyncWrite, + mut ws_rx: WebSocketRead>, + mut close_rx: oneshot::Receiver<()>, + timeout_connect: Duration, +) -> Result<(), WebSocketError> { + let _guard = scopeguard::guard((), |_| { + info!("Closing local rx <== websocket rx tunnel"); + }); + let mut x = |x: Frame<'_>| { + debug!("frame {:?} {:?}", x.opcode, x.payload); + futures_util::future::ready(anyhow::Ok(())) + }; + + pin_mut!(local_tx); + loop { + let ret = select! { + biased; + ret = timeout(timeout_connect, ws_rx.read_frame(&mut x)) => ret, + + _ = &mut close_rx => break, + }; + + let msg = match ret { + Ok(Ok(msg)) => msg, + Ok(Err(err)) => { + error!("error while reading from websocket rx {}", err); + break; + } + Err(err) => { + trace!("frame {:?}", err); + // TODO: Check that the connection is not closed (no easy method to know if a tx is closed ...) + continue; + } + }; + + trace!("frame {:?} {:?}", msg.opcode, msg.payload); + let ret = match msg.opcode { + OpCode::Continuation | OpCode::Text | OpCode::Binary => { + local_tx.write_all(msg.payload.as_ref()).await + } + OpCode::Close => break, + OpCode::Ping => Ok(()), + OpCode::Pong => Ok(()), + }; + + match ret { + Ok(_) => {} + Err(err) => { + error!("error while writing bytes to local for rx tunnel {}", err); + break; + } + } + } + + Ok(()) +} diff --git a/src/udp.rs b/src/udp.rs new file mode 100644 index 00000000..73a5d35d --- /dev/null +++ b/src/udp.rs @@ -0,0 +1,241 @@ +#![allow(unused_imports)] + +use anyhow::Context; +use futures_util::future::join; +use futures_util::{stream, FutureExt, Stream}; +use hyper::server; +use libc::poll; +use pin_project::{pin_project, pinned_drop}; +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::future::Future; +use std::io; +use std::io::{Error, ErrorKind, IoSlice}; +use std::net::SocketAddr; +use std::pin::{pin, Pin}; +use std::sync::{Arc, RwLock, Weak}; +use std::task::Poll; +use std::time::{Duration, Instant}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, DuplexStream, ReadBuf}; +use tokio::net::UdpSocket; +use tokio::time::Sleep; +use tracing::{debug, error, info}; + +const DEFAULT_UDP_BUFFER_SIZE: usize = 8 * 1024; + +struct UdpServer { + listener: UdpSocket, + std_socket: std::net::UdpSocket, + buffer: Vec, + peers: HashMap, + keys_to_delete: Arc>>, + pub cnx_timeout: Option, +} + +impl UdpServer { + pub fn new(listener: UdpSocket, timeout: Option) -> Self { + let socket = listener.into_std().unwrap(); + let listener = UdpSocket::from_std(socket.try_clone().unwrap()).unwrap(); + Self { + listener, + std_socket: socket, + peers: HashMap::with_hasher(ahash::RandomState::new()), + buffer: vec![0u8; DEFAULT_UDP_BUFFER_SIZE], + keys_to_delete: Default::default(), + cnx_timeout: timeout, + } + } + + fn clean_dead_keys(&mut self) { + let nb_key_to_delete = self.keys_to_delete.read().unwrap().len(); + if nb_key_to_delete == 0 { + return; + } + + debug!("Cleaning {} dead udp peers", nb_key_to_delete); + let mut keys_to_delete = self.keys_to_delete.write().unwrap(); + for key in keys_to_delete.iter() { + self.peers.remove(key); + } + keys_to_delete.clear(); + } + + fn clone_socket(&self) -> UdpSocket { + UdpSocket::from_std(self.std_socket.try_clone().unwrap()).unwrap() + } +} + +#[pin_project(PinnedDrop)] +pub struct UdpStream { + socket: UdpSocket, + peer: SocketAddr, + #[pin] + deadline: Option, + #[pin] + io: DuplexStream, + keys_to_delete: Weak>>, +} + +impl AsMut for UdpStream { + fn as_mut(&mut self) -> &mut DuplexStream { + &mut self.io + } +} + +#[pinned_drop] +impl PinnedDrop for UdpStream { + fn drop(self: Pin<&mut Self>) { + if let Some(keys_to_delete) = self.keys_to_delete.upgrade() { + keys_to_delete.write().unwrap().push(self.peer); + } + } +} + +impl AsyncRead for UdpStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let project = self.project(); + if let Some(deadline) = project.deadline.as_pin_mut() { + if deadline.poll(cx).is_ready() { + return Poll::Ready(Err(Error::new( + ErrorKind::TimedOut, + format!("UDP stream timeout with {}", project.peer), + ))); + } + } + + project.io.poll_read(cx, buf) + } +} + +impl AsyncWrite for UdpStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + self.socket.poll_send_to(cx, buf, self.peer) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.socket.poll_send_ready(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } +} + +pub async fn run_server( + bind: SocketAddr, + timeout: Option, +) -> Result>, anyhow::Error> { + info!( + "Starting UDP server listening cnx on {} with cnx timeout of {}s", + bind, + timeout.unwrap_or(Duration::from_secs(0)).as_secs() + ); + + let listener = UdpSocket::bind(bind) + .await + .with_context(|| format!("Cannot create UDP server {:?}", bind))?; + + let udp_server = UdpServer::new(listener, timeout); + let stream = stream::unfold(udp_server, |mut server| async { + loop { + server.clean_dead_keys(); + let (nb_bytes, peer_addr) = match server.listener.recv_from(&mut server.buffer).await { + Ok(ret) => ret, + Err(err) => { + error!("Cannot read from UDP server. Closing server: {}", err); + return None; + } + }; + + match server.peers.entry(peer_addr) { + Entry::Occupied(mut peer) => { + let ret = peer.get_mut().write_all(&server.buffer[0..nb_bytes]).await; + if let Err(err) = ret { + info!("Peer {:?} disconnected {:?}", peer_addr, err); + peer.remove(); + } + } + Entry::Vacant(peer) => { + let (mut rx, tx) = tokio::io::duplex(DEFAULT_UDP_BUFFER_SIZE); + rx.write_all(&server.buffer[0..nb_bytes]) + .await + .unwrap_or_default(); // should never fail + peer.insert(rx); + let udp_client = UdpStream { + socket: server.clone_socket(), + peer: peer_addr, + deadline: server + .cnx_timeout + .and_then(|timeout| tokio::time::Instant::now().checked_add(timeout)) + .map(tokio::time::sleep_until), + keys_to_delete: Arc::downgrade(&server.keys_to_delete), + io: tx, + }; + return Some((Ok(udp_client), (server))); + } + } + } + }); + + Ok(stream) +} + +pub struct MyUdpSocket { + socket: Arc, +} + +impl MyUdpSocket { + pub fn new(socket: Arc) -> Self { + Self { socket } + } +} + +impl AsyncRead for MyUdpSocket { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + unsafe { self.map_unchecked_mut(|x| &mut x.socket) } + .poll_recv_from(cx, buf) + .map(|x| x.map(|_| ())) + } +} + +impl AsyncWrite for MyUdpSocket { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + unsafe { self.map_unchecked_mut(|x| &mut x.socket) }.poll_send(cx, buf) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } +}