Skip to content

Commit

Permalink
Merge branch 'master' into stable
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbers committed Jul 5, 2018
2 parents 93f71f5 + 6f8bfdb commit 47f7c08
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 41 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ The proxy can be launched:
- with a custom config: `python3 mtprotoproxy.py [configfile]`
- several times, clients will be automaticaly balanced between instances
- using *PyPy* interprteter
- with runtime statistics exported for [Prometheus](https://prometheus.io/): using [prometheus](https://github.com/alexbers/mtprotoproxy/tree/prometheus) branch
113 changes: 72 additions & 41 deletions mtprotoproxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,12 @@ def debug_signal(signum, frame):
# disables tg->client trafic reencryption, faster but less secure
FAST_MODE = config.get("FAST_MODE", True)
STATS_PRINT_PERIOD = config.get("STATS_PRINT_PERIOD", 600)
PROXY_INFO_UPDATE_PERIOD = config.get("PROXY_INFO_UPDATE_PERIOD", 60*60*24)
READ_BUF_SIZE = config.get("READ_BUF_SIZE", 16384)
WRITE_BUF_SIZE = config.get("WRITE_BUF_SIZE", 65536)
CLIENT_KEEPALIVE = config.get("CLIENT_KEEPALIVE", 60*30)
PROXY_INFO_UPDATE_PERIOD = config.get("PROXY_INFO_UPDATE_PERIOD", 24*60*60)
TO_CLT_BUFSIZE = config.get("TO_CLT_BUFSIZE", 8192)
TO_TG_BUFSIZE = config.get("TO_TG_BUFSIZE", 65536)
CLIENT_KEEPALIVE = config.get("CLIENT_KEEPALIVE", 10*60)
CLIENT_HANDSHAKE_TIMEOUT = config.get("CLIENT_HANDSHAKE_TIMEOUT", 10)
CLIENT_ACK_TIMEOUT = config.get("CLIENT_ACK_TIMEOUT", 5*60)

TG_DATACENTER_PORT = 443

Expand Down Expand Up @@ -203,6 +204,7 @@ def debug_signal(signum, frame):

PROTO_TAG_ABRIDGED = b"\xef\xef\xef\xef"
PROTO_TAG_INTERMEDIATE = b"\xee\xee\xee\xee"
PROTO_TAG_SECURE = b"\xdd\xdd\xdd\xdd"

CBC_PADDING = 16
PADDING_FILLER = b"\x04\x00\x00\x00"
Expand All @@ -222,14 +224,14 @@ def init_stats():
stats = {user: collections.Counter() for user in USERS}


def update_stats(user, connects=0, curr_connects=0, octets=0):
def update_stats(user, connects=0, curr_connects=0, octets=0, msgs=0):
global stats

if user not in stats:
stats[user] = collections.Counter()

stats[user].update(connects=connects, curr_connects=curr_connects,
octets=octets)
octets=octets, msgs=msgs)


class LayeredStreamReaderBase:
Expand Down Expand Up @@ -429,6 +431,10 @@ async def read(self, buf_size):

data = await self.upstream.readexactly(msg_len)

if msg_len % 4 != 0:
cut_border = msg_len - (msg_len % 4)
data = data[:cut_border]

return data, extra


Expand Down Expand Up @@ -548,7 +554,7 @@ async def handle_handshake(reader, writer):
decrypted = decryptor.decrypt(handshake)

proto_tag = decrypted[PROTO_TAG_POS:PROTO_TAG_POS+4]
if proto_tag not in (PROTO_TAG_ABRIDGED, PROTO_TAG_INTERMEDIATE):
if proto_tag not in (PROTO_TAG_ABRIDGED, PROTO_TAG_INTERMEDIATE, PROTO_TAG_SECURE):
continue

dc_idx = int.from_bytes(decrypted[DC_IDX_POS:DC_IDX_POS+2], "little", signed=True)
Expand All @@ -557,13 +563,34 @@ async def handle_handshake(reader, writer):
writer = CryptoWrappedStreamWriter(writer, encryptor)
return reader, writer, proto_tag, user, dc_idx, enc_key + enc_iv

while await reader.read(READ_BUF_SIZE):
EMPTY_READ_BUF_SIZE = 4096
while await reader.read(EMPTY_READ_BUF_SIZE):
# just consume all the data
pass

return False


def set_keepalive(sock, interval=40, attempts=5):
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
if hasattr(socket, "TCP_KEEPIDLE"):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, interval)
if hasattr(socket, "TCP_KEEPINTVL"):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, interval)
if hasattr(socket, "TCP_KEEPCNT"):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, attempts)


def set_ack_timeout(sock, timeout):
if hasattr(socket, "TCP_USER_TIMEOUT"):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_USER_TIMEOUT, timeout*1000)


def set_bufsizes(sock, recv_buf, send_buf):
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, recv_buf)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, send_buf)


async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None):
RESERVED_NONCE_FIRST_CHARS = [b"\xef"]
RESERVED_NONCE_BEGININGS = [b"\x48\x45\x41\x44", b"\x50\x4F\x53\x54",
Expand All @@ -583,7 +610,10 @@ async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None):

try:
reader_tgt, writer_tgt = await asyncio.open_connection(dc, TG_DATACENTER_PORT,
limit=READ_BUF_SIZE)
limit=TO_CLT_BUFSIZE)
set_keepalive(writer_tgt.get_extra_info("socket"))
set_bufsizes(writer_tgt.get_extra_info("socket"), TO_CLT_BUFSIZE, TO_TG_BUFSIZE)

except ConnectionRefusedError as E:
print_err("Got connection refused while trying to connect to", dc, TG_DATACENTER_PORT)
return False
Expand Down Expand Up @@ -653,21 +683,6 @@ def get_middleproxy_aes_key_and_iv(nonce_srv, nonce_clt, clt_ts, srv_ip, clt_por
return key, iv


def set_keepalive(sock, interval=40, attempts=5):
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
if hasattr(socket, "TCP_KEEPIDLE"):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, interval)
if hasattr(socket, "TCP_KEEPINTVL"):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, interval)
if hasattr(socket, "TCP_KEEPCNT"):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, attempts)


def set_bufsizes(sock, recv_buf=READ_BUF_SIZE, send_buf=WRITE_BUF_SIZE):
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, recv_buf)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, send_buf)


async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port):
START_SEQ_NO = -2
NONCE_LEN = 16
Expand Down Expand Up @@ -695,9 +710,9 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port):
addr, port = random.choice(TG_MIDDLE_PROXIES_V4[dc_idx])

try:
reader_tgt, writer_tgt = await asyncio.open_connection(addr, port, limit=READ_BUF_SIZE)
reader_tgt, writer_tgt = await asyncio.open_connection(addr, port, limit=TO_CLT_BUFSIZE)
set_keepalive(writer_tgt.get_extra_info("socket"))
set_bufsizes(writer_tgt.get_extra_info("socket"))
set_bufsizes(writer_tgt.get_extra_info("socket"), TO_CLT_BUFSIZE, TO_TG_BUFSIZE)
except ConnectionRefusedError as E:
print_err("Got connection refused while trying to connect to", addr, port)
return False
Expand All @@ -719,7 +734,7 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port):

old_reader = reader_tgt
reader_tgt = MTProtoFrameStreamReader(reader_tgt, START_SEQ_NO)
ans = await reader_tgt.read(READ_BUF_SIZE)
ans = await reader_tgt.read(TO_CLT_BUFSIZE)

if len(ans) != RPC_NONCE_ANS_LEN:
return False
Expand Down Expand Up @@ -800,8 +815,9 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port):


async def handle_client(reader_clt, writer_clt):
set_keepalive(writer_clt.get_extra_info("socket"), CLIENT_KEEPALIVE)
set_bufsizes(writer_clt.get_extra_info("socket"))
set_keepalive(writer_clt.get_extra_info("socket"), CLIENT_KEEPALIVE, attempts=3)
set_ack_timeout(writer_clt.get_extra_info("socket"), CLIENT_ACK_TIMEOUT)
set_bufsizes(writer_clt.get_extra_info("socket"), TO_TG_BUFSIZE, TO_CLT_BUFSIZE)

try:
clt_data = await asyncio.wait_for(handle_handshake(reader_clt, writer_clt),
Expand Down Expand Up @@ -846,16 +862,16 @@ def decrypt(self, data):
if proto_tag == PROTO_TAG_ABRIDGED:
reader_clt = MTProtoCompactFrameStreamReader(reader_clt)
writer_clt = MTProtoCompactFrameStreamWriter(writer_clt)
elif proto_tag == PROTO_TAG_INTERMEDIATE:
elif proto_tag in (PROTO_TAG_INTERMEDIATE, PROTO_TAG_SECURE):
reader_clt = MTProtoIntermediateFrameStreamReader(reader_clt)
writer_clt = MTProtoIntermediateFrameStreamWriter(writer_clt)
else:
return

async def connect_reader_to_writer(rd, wr, user):
async def connect_reader_to_writer(rd, wr, user, rd_buf_size):
try:
while True:
data = await rd.read(READ_BUF_SIZE)
data = await rd.read(rd_buf_size)
if isinstance(data, tuple):
data, extra = data
else:
Expand All @@ -866,15 +882,17 @@ async def connect_reader_to_writer(rd, wr, user):
await wr.drain()
return
else:
update_stats(user, octets=len(data))
update_stats(user, octets=len(data), msgs=1)
wr.write(data, extra)
await wr.drain()
except (OSError, asyncio.streams.IncompleteReadError) as e:
# print_err(e)
pass

task_tg_to_clt = asyncio.ensure_future(connect_reader_to_writer(reader_tg, writer_clt, user))
task_clt_to_tg = asyncio.ensure_future(connect_reader_to_writer(reader_clt, writer_tg, user))
tg_to_clt = connect_reader_to_writer(reader_tg, writer_clt, user, TO_CLT_BUFSIZE)
clt_to_tg = connect_reader_to_writer(reader_clt, writer_tg, user, TO_TG_BUFSIZE)
task_tg_to_clt = asyncio.ensure_future(tg_to_clt)
task_clt_to_tg = asyncio.ensure_future(clt_to_tg)

update_stats(user, curr_connects=1)
await asyncio.wait([task_tg_to_clt, task_clt_to_tg], return_when=asyncio.FIRST_COMPLETED)
Expand Down Expand Up @@ -902,9 +920,9 @@ async def stats_printer():

print("Stats for", time.strftime("%d.%m.%Y %H:%M:%S"))
for user, stat in stats.items():
print("%s: %d connects (%d current), %.2f MB" % (
print("%s: %d connects (%d current), %.2f MB, %d msgs" % (
user, stat["connects"], stat["curr_connects"],
stat["octets"] / 1000000))
stat["octets"] / 1000000, stat["msgs"]))
print(flush=True)


Expand Down Expand Up @@ -1032,22 +1050,35 @@ def print_tg_info():
params_encodeded = urllib.parse.urlencode(params, safe=':')
print("{}: tg://proxy?{}".format(user, params_encodeded), flush=True)

params = {"server": ip, "port": PORT, "secret": "dd" + secret}
params_encodeded = urllib.parse.urlencode(params, safe=':')
print("{}: tg://proxy?{} (beta)".format(user, params_encodeded), flush=True)


def loop_exception_handler(loop, context):
exception = context.get("exception")
transport = context.get("transport")
if exception:
if isinstance(exception, TimeoutError):
if transport:
print_err("Timeout, killing transport")
transport.abort()
return
if isinstance(exception, OSError):
IGNORE_ERRNO = {
10038 # operation on non-socket on Windows, likely because fd == -1
10038, # operation on non-socket on Windows, likely because fd == -1
121, # the semaphore timeout period has expired on Windows
}

FORCE_CLOSE_ERRNO = {
113, # no route to host

}
if exception.errno in IGNORE_ERRNO:
return
elif exception.errno in FORCE_CLOSE_ERRNO:
if transport:
transport.abort()
return

loop.default_exception_handler(context)

Expand All @@ -1072,12 +1103,12 @@ def main():
reuse_port = hasattr(socket, "SO_REUSEPORT")

task_v4 = asyncio.start_server(handle_client_wrapper, '0.0.0.0', PORT,
limit=READ_BUF_SIZE, reuse_port=reuse_port, loop=loop)
limit=TO_TG_BUFSIZE, reuse_port=reuse_port, loop=loop)
server_v4 = loop.run_until_complete(task_v4)

if socket.has_ipv6:
task_v6 = asyncio.start_server(handle_client_wrapper, '::', PORT,
limit=READ_BUF_SIZE, reuse_port=reuse_port, loop=loop)
limit=TO_TG_BUFSIZE, reuse_port=reuse_port, loop=loop)
server_v6 = loop.run_until_complete(task_v6)

try:
Expand Down

0 comments on commit 47f7c08

Please sign in to comment.