Skip to content

Commit

Permalink
Merge pull request #532 from pywebio/reliable-http-session
Browse files Browse the repository at this point in the history
Reliable http session
  • Loading branch information
wang0618 authored Jan 15, 2023
2 parents 75b4773 + 3168544 commit 732ebc5
Show file tree
Hide file tree
Showing 4 changed files with 248 additions and 89 deletions.
4 changes: 2 additions & 2 deletions pywebio/__version__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
__package__ = 'pywebio'
__description__ = 'Write interactive web app in script way.'
__url__ = 'https://pywebio.readthedocs.io'
__version__ = "1.7.0"
__version_info__ = (1, 7, 0, 0)
__version__ = "1.7.1"
__version_info__ = (1, 7, 1, 0)
__author__ = 'WangWeimin'
__author_email__ = '[email protected]'
__license__ = 'MIT'
129 changes: 88 additions & 41 deletions pywebio/platform/adaptor/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
import threading
import time
from contextlib import contextmanager
from typing import Dict, Optional
from typing import Dict, Optional, List
from collections import deque

from ..page import make_applications, render_page
from ..utils import deserialize_binary_event
from ...session import CoroutineBasedSession, ThreadBasedSession, register_session_implement_for_target
from ...session.base import get_session_info_from_headers
from ...session.base import get_session_info_from_headers, Session
from ...utils import random_str, LRUDict, isgeneratorfunction, iscoroutinefunction, check_webio_js


Expand All @@ -35,7 +36,7 @@ def request_obj(self):
Return the current request object"""
pass

def request_method(self):
def request_method(self) -> str:
"""返回当前请求的方法,大写
Return the HTTP method of the current request, uppercase"""
pass
Expand All @@ -45,29 +46,19 @@ def request_headers(self) -> Dict:
Return the header dictionary of the current request"""
pass

def request_url_parameter(self, name, default=None):
def request_url_parameter(self, name, default=None) -> str:
"""返回当前请求的URL参数
Returns the value of the given URL parameter of the current request"""
pass

def request_body(self):
def request_body(self) -> bytes:
"""返回当前请求的body数据
Returns the data of the current request body
:return: bytes/bytearray
"""
return b''

def request_json(self) -> Optional[Dict]:
"""返回当前请求的json反序列化后的内容,若请求数据不为json格式,返回None
Return the data (json deserialization) of the currently requested, if the data is not in json format, return None"""
try:
if self.request_headers().get('content-type') == 'application/octet-stream':
return deserialize_binary_event(self.request_body())
return json.loads(self.request_body())
except Exception:
return None

def set_header(self, name, value):
"""为当前响应设置header
Set a header for the current response"""
Expand All @@ -92,7 +83,7 @@ def get_response(self):
Get the current response object"""
pass

def get_client_ip(self):
def get_client_ip(self) -> str:
"""获取用户的ip
Get the user's ip"""
pass
Expand All @@ -102,6 +93,56 @@ def get_client_ip(self):
_event_loop = None


class ReliableTransport:
    """Retransmission-tolerant message channel between one session and its HTTP client.

    Outgoing command batches are kept in a bounded window until the client
    acknowledges them, so a lost HTTP response can simply be re-sent.
    Incoming events carry consecutive ids, so a re-sent event is recognised
    and handed to the session only once.

    A per-instance lock serialises ``push_event`` and ``get_response``: both
    are check-then-act sequences, and two overlapping requests for the same
    session could otherwise deliver an event twice or drop a message that was
    never acknowledged.
    """

    def __init__(self, session: 'Session', message_window: int = 4):
        """
        :param session: object providing ``send_client_event(event)`` and
            ``get_task_commands()``.
        :param message_window: maximum number of unacknowledged command
            batches kept for retransmission.
        """
        self.session = session
        self.messages = deque()  # unacknowledged command batches, oldest first
        self.window_size = message_window
        self.min_msg_id = 0  # the id of the first message in the window
        self.finished_event_id = -1  # the id of the last finished event
        self._lock = threading.Lock()  # guards every attribute above

    @staticmethod
    def close_message(ack):
        """Build the response telling the client that the session is closed.

        The message is numbered ``ack + 1``, i.e. directly after the id the
        client reported in ``ack``.
        """
        return dict(
            commands=[[dict(command='close_session')]],
            seq=ack + 1
        )

    def push_event(self, events: List[Dict], seq: int) -> int:
        """Forward client events to the session, skipping those already handled.

        :param events: events sent by the client, in order.
        :param seq: id of the first event in ``events``; following events are
            numbered consecutively.
        :return: the number of events actually handed to the session.
        """
        if not events:
            return 0

        submit_cnt = 0
        # The "is this id new?" test and the update of ``finished_event_id``
        # must be atomic, otherwise a retransmitted request racing with the
        # original one would submit the same event twice.
        with self._lock:
            for eid, event in enumerate(events, start=seq):
                if eid > self.finished_event_id:
                    self.finished_event_id = eid
                    self.session.send_client_event(event)
                    submit_cnt += 1

        return submit_cnt

    def get_response(self, ack=0):
        """Return the messages the client still has to receive.

        :param ack: acknowledgement reported by the client. Every buffered
            message whose id is ``<= ack`` is treated as delivered and removed
            from the window.
        :return: dict with ``commands`` (the buffered command batches, oldest
            first), ``seq`` (the id of the first batch in ``commands``) and
            ``ack`` (the id of the last client event that has been handled).
        """
        with self._lock:
            # Drop acknowledged messages from the front of the window.
            while ack >= self.min_msg_id and self.messages:
                self.messages.popleft()
                self.min_msg_id += 1

            # Fetch at most one new batch per call, and only while the window
            # has room for it.
            if len(self.messages) < self.window_size:
                msgs = self.session.get_task_commands()
                if msgs:
                    self.messages.append(msgs)

            return dict(
                commands=list(self.messages),
                seq=self.min_msg_id,
                ack=self.finished_event_id
            )


# todo: use lock to avoid thread race condition
class HttpHandler:
"""基于HTTP的后端Handler实现
Expand All @@ -112,7 +153,7 @@ class HttpHandler:
"""
_webio_sessions = {} # WebIOSessionID -> WebIOSession()
_webio_last_commands = {} # WebIOSessionID -> (last commands, commands sequence id)
_webio_transports = {} # WebIOSessionID -> ReliableTransport(), type: Dict[str, ReliableTransport]
_webio_expire = LRUDict() # WebIOSessionID -> last active timestamp. In increasing order of last active time
_webio_expire_lock = threading.Lock()

Expand Down Expand Up @@ -143,23 +184,13 @@ def _remove_expired_sessions(cls, session_expire_seconds):
if session:
session.close(nonblock=True)
del cls._webio_sessions[sid]
del cls._webio_transports[sid]

@classmethod
def _remove_webio_session(cls, sid):
    """Forget every per-session record kept for ``sid``.

    Removes the session object, its ReliableTransport and its expiry entry.
    Unknown ids are ignored, so the call is safe to repeat.

    :param sid: WebIO session id.
    """
    cls._webio_sessions.pop(sid, None)
    # The transport references the session and buffers its unacknowledged
    # messages; leaving it behind would leak both for every closed session.
    cls._webio_transports.pop(sid, None)
    cls._webio_expire.pop(sid, None)

@classmethod
def get_response(cls, sid, ack=0):
commands, seq = cls._webio_last_commands.get(sid, ([], 0))
if ack == seq:
webio_session = cls._webio_sessions[sid]
commands = webio_session.get_task_commands()
seq += 1
cls._webio_last_commands[sid] = (commands, seq)

return {'commands': commands, 'seq': seq}

def _process_cors(self, context: HttpContext):
"""Handling cross-domain requests: check the source of the request and set headers"""
origin = context.request_headers().get('Origin', '')
Expand Down Expand Up @@ -209,6 +240,14 @@ def get_cdn(self, context):
return False
return self.cdn

def read_event_data(self, context: 'HttpContext') -> List[Dict]:
    """Decode the client events carried in the body of the current request.

    A body sent as ``application/octet-stream`` holds one binary-encoded
    event; any other body is parsed as JSON. A lone JSON object is treated
    as a one-event batch, mirroring the binary branch.

    :param context: the context of the current HTTP request.
    :return: the decoded events, or an empty list when the body cannot be
        decoded or is not a list of JSON objects.
    """
    try:
        if context.request_headers().get('content-type') == 'application/octet-stream':
            return [deserialize_binary_event(context.request_body())]
        payload = json.loads(context.request_body())
    except Exception:
        # Best effort: an undecodable body simply carries no events.
        return []

    if isinstance(payload, dict):
        return [payload]
    # The body is client-controlled. Anything that is not a list of objects
    # (a string, a number, null, a list of scalars...) would otherwise be
    # iterated by the caller and pushed into the session as bogus events.
    if isinstance(payload, list) and all(isinstance(item, dict) for item in payload):
        return payload
    return []

@contextmanager
def handle_request_context(self, context: HttpContext):
"""called when every http request"""
Expand Down Expand Up @@ -240,16 +279,18 @@ def handle_request_context(self, context: HttpContext):
context.set_content(html)
return context.get_response()

webio_session_id = None
ack = int(context.request_url_parameter('ack', 0))
webio_session_id = request_headers['webio-session-id']
new_request = False
if webio_session_id.startswith('NEW-'):
new_request = True
webio_session_id = webio_session_id[4:]

# 初始请求,创建新 Session
if not request_headers['webio-session-id'] or request_headers['webio-session-id'] == 'NEW':
if new_request and webio_session_id not in cls._webio_sessions: # 初始请求,创建新 Session
if context.request_method() == 'POST': # 不能在POST请求中创建Session,防止CSRF攻击
context.set_status(403)
return context.get_response()

webio_session_id = random_str(24)
context.set_header('webio-session-id', webio_session_id)
session_info = get_session_info_from_headers(context.request_headers())
session_info['user_ip'] = context.get_client_ip()
session_info['request'] = context.request_obj()
Expand All @@ -264,17 +305,23 @@ def handle_request_context(self, context: HttpContext):
session_cls = ThreadBasedSession
webio_session = session_cls(application, session_info=session_info)
cls._webio_sessions[webio_session_id] = webio_session
yield type(self).WAIT_MS_ON_POST / 1000.0 # <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <---
elif request_headers['webio-session-id'] not in cls._webio_sessions: # WebIOSession deleted
context.set_content([dict(command='close_session')], json_type=True)
cls._webio_transports[webio_session_id] = ReliableTransport(webio_session)
yield cls.WAIT_MS_ON_POST / 1000.0 # <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <---
elif webio_session_id not in cls._webio_sessions: # WebIOSession deleted
close_msg = ReliableTransport.close_message(ack)
context.set_content(close_msg, json_type=True)
return context.get_response()
else:
webio_session_id = request_headers['webio-session-id']
# In this case request_headers['webio-session-id'] may still start with 'NEW-':
# the client never received the response to its earlier new-session request,
# so it has retried with the same session id.
webio_session = cls._webio_sessions[webio_session_id]

if context.request_method() == 'POST': # client push event
if context.request_json() is not None:
webio_session.send_client_event(context.request_json())
seq = int(context.request_url_parameter('seq', 0))
event_data = self.read_event_data(context)
submit_cnt = cls._webio_transports[webio_session_id].push_event(event_data, seq)
if submit_cnt > 0:
yield type(self).WAIT_MS_ON_POST / 1000.0 # <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <---
elif context.request_method() == 'GET': # client pull messages
pass
Expand All @@ -283,8 +330,8 @@ def handle_request_context(self, context: HttpContext):

self.interval_cleaning()

ack = int(context.request_url_parameter('ack', 0))
context.set_content(type(self).get_response(webio_session_id, ack=ack), json_type=True)
resp = cls._webio_transports[webio_session_id].get_response(ack)
context.set_content(resp, json_type=True)

if webio_session.closed():
self._remove_webio_session(webio_session_id)
Expand Down
Loading

0 comments on commit 732ebc5

Please sign in to comment.