diff --git a/api/common_api.py b/api/common_api.py index 8a31a60..ed328d8 100644 --- a/api/common_api.py +++ b/api/common_api.py @@ -4,7 +4,7 @@ from api.common_response import success_response, exception_response, error_response from config import current_version import logging -from module.common import get_firmware_infos +from module.firmware import get_firmware_infos logger = logging.getLogger(__name__) diff --git a/api/common_response.py b/api/common_response.py index 5968c9d..9887513 100644 --- a/api/common_response.py +++ b/api/common_response.py @@ -1,6 +1,6 @@ import logging from module.msg_notifier import send_notify -from exception.common_exception import VersionNotFoundException +from exception.common_exception import * logger = logging.getLogger(__name__) @@ -16,6 +16,10 @@ def exception_response(ex): logger.error(f'{str(ex)}') send_notify(f'无法获取 {ex.branch} 分支的 [{ex.target_version}] 版本信息') return error_response(404, str(ex)) + elif isinstance(ex, Md5NotMatchException): + logger.error(f'{str(ex)}') + send_notify(f'固件文件 md5 不匹配, 请重新下载') + return error_response(501, str(ex)) logger.error(ex, exc_info=True) traceback_str = "\n".join(traceback.format_exception(ex)) send_notify(f'出现异常, {traceback_str}') diff --git a/changelog.md b/changelog.md index 9b51980..c1f0aea 100644 --- a/changelog.md +++ b/changelog.md @@ -1,5 +1,13 @@ # Change Log +## 0.3.0 + - 在 api 请求发生超时错误时进行重试 + - 当 IPv6 启用时 DoH 尝试查询 AAAA 记录 + - 安装固件时对下载文件的 md5 进行校验 + - 修复某些情况下 aria2 进程没有正常关闭的问题 + - 更正检测固件版本时固件文件解密失败的错误文本 + - 修复某些代理软件错误配置 localhost 代理导致无法调用 aria2 api 的问题 + ## 0.2.9 - 添加新 GitHub 下载源 nuaa.cf, 并更新在其它 GitHub 下载源中使用的 UA - 更正尝试下载一个不存在的 Ryujinx 版本时所展示的文本 diff --git a/config.py b/config.py index 4adae6c..1f14bbe 100644 --- a/config.py +++ b/config.py @@ -9,7 +9,7 @@ import sys -current_version = '0.2.9' +current_version = '0.3.0' user_agent = f'ns-emu-tools/{current_version}' @@ -72,6 +72,7 @@ class DownloadSetting: autoDeleteAfterInstall: Optional[bool] = True disableAria2Ipv6: Optional[bool] = True removeOldAria2LogFile: Optional[bool] = True + verifyFirmwareMd5: Optional[bool] = True @dataclass_json diff --git a/exception/common_exception.py b/exception/common_exception.py index 4389c39..df3316e 100644 --- a/exception/common_exception.py +++ b/exception/common_exception.py @@ -12,3 +12,8 @@ def __init__(self, target_version, branch, emu_type): self.emu_type = emu_type self.msg = f'Fail to get release info of version [{target_version}] on branch [{branch}]' super().__init__(self.msg) + + +class Md5NotMatchException(Exception): + def __init__(self): + super().__init__('MD5 not match') diff --git a/module/common.py b/module/common.py index 08445bc..f652721 100644 --- a/module/common.py +++ b/module/common.py @@ -1,43 +1,11 @@ +import logging import os -import shutil import subprocess -from functools import lru_cache from pathlib import Path -from module.msg_notifier import send_notify -from config import config -import bs4 -from utils.network import get_finial_url, session -import logging -from module.downloader import download - -logger = logging.getLogger(__name__) +from utils.network import get_finial_url -@lru_cache(1) -def get_firmware_infos(): - base_url = 'https://archive.org/download/nintendo-switch-global-firmwares/' - resp = session.get(get_finial_url(base_url)) - soup = bs4.BeautifulSoup(resp.text, features="html.parser") - a_tags = soup.select('#maincontent > div > div > pre > table > tbody > tr > td > a') - archive_versions = [] - for a in a_tags: - name = a.text - if name.startswith('Firmware ') and name.endswith('.zip'): - size = a.parent.next_sibling.next_sibling.next_sibling.next_sibling.text - version = name[9:-4] - version_num = 0 - for num in version.split('.'): - version_num *= 100 - version_num += int(''.join(ch for ch in num if ch.isdigit())) - archive_versions.append({ - 'name': name, - 'version': version, - 'size': size, - 'url': base_url + a.attrs['href'], - 'version_num': version_num, - }) - archive_versions = sorted(archive_versions, key=lambda x: x['version_num'], reverse=True) - return archive_versions +logger = logging.getLogger(__name__) def check_and_install_msvc(): @@ -57,36 +25,6 @@ def check_and_install_msvc(): # process.wait() -def install_firmware(firmware_version, target_firmware_path): - send_notify('正在获取固件信息...') - firmware_infos = get_firmware_infos() - target_info = None - if firmware_version: - firmware_map = {fi['version']: fi for fi in firmware_infos} - target_info = firmware_map.get(firmware_version) - if not target_info: - logger.info(f'Target firmware version [{firmware_version}] not found, skip install.') - send_notify(f'Target firmware version [{firmware_version}] not found, skip install.') - return - url = get_finial_url(target_info['url']) - send_notify(f'开始下载固件...') - logger.info(f"downloading firmware of [{firmware_version}] from {url}") - info = download(url) - file = info.files[0] - import zipfile - with zipfile.ZipFile(file.path, 'r') as zf: - firmware_path = target_firmware_path - shutil.rmtree(firmware_path, ignore_errors=True) - firmware_path.mkdir(parents=True, exist_ok=True) - send_notify(f'开始解压安装固件...') - logger.info(f'Unzipping firmware files to {firmware_path}') - zf.extractall(firmware_path) - logger.info(f'Firmware of [{firmware_version}] install successfully.') - if config.setting.download.autoDeleteAfterInstall: - os.remove(file.path) - return firmware_version - - if __name__ == '__main__': # infos = get_firmware_infos() # for info in infos: diff --git a/module/downloader.py b/module/downloader.py index a94a5ca..385668c 100644 --- a/module/downloader.py +++ b/module/downloader.py @@ -27,12 +27,15 @@ def init_aria2(): send_notify(f'starting aria2 daemon at port {port}') logger.info(f'starting aria2 daemon at port {port}') if config.setting.download.removeOldAria2LogFile and os.path.exists('aria2.log'): - logger.info('removing old aria2 logs.') - os.remove('aria2.log') + try: + logger.info('removing old aria2 logs.') + os.remove('aria2.log') + except: + pass st_inf = subprocess.STARTUPINFO() st_inf.dwFlags = st_inf.dwFlags | subprocess.STARTF_USESHOWWINDOW cli = [aria2_path, '--enable-rpc', '--rpc-listen-port', str(port), '--async-dns=true', - '--rpc-secret', '123456', '--log', 'aria2.log', '--log-level=info'] + '--rpc-secret', '123456', '--log', 'aria2.log', '--log-level=info', f'--stop-with-process={os.getpid()}'] if config.setting.download.disableAria2Ipv6: cli.append('--disable-ipv6=true') cli.append('--async-dns-server=223.5.5.5,119.29.29.29') @@ -51,11 +54,21 @@ def init_aria2(): global_options = get_global_options() logger.info(f'aria2 global options: {global_options}') aria2.set_global_options(global_options) - import atexit - atexit.register(shutdown_aria2) def download(url, save_dir=None, options=None, download_in_background=False): + origin_no_proxy = os.environ.get('no_proxy') + os.environ['no_proxy'] = '127.0.0.1,localhost' + try: + return _download(url, save_dir, options, download_in_background) + finally: + if origin_no_proxy is None: + del os.environ['no_proxy'] + else: + os.environ['no_proxy'] = origin_no_proxy + + +def _download(url, save_dir=None, options=None, download_in_background=False): init_aria2() tmp = init_download_options_with_proxy(url) tmp['auto-file-renaming'] = 'false' @@ -101,12 +114,6 @@ def download(url, save_dir=None, options=None, download_in_background=False): return info -def shutdown_aria2(): - if aria2_process: - # logger.info('Shutdown aria2...') - aria2_process.kill() - - if __name__ == '__main__': info = download('http://ipv4.download.thinkbroadband.com/200MB.zip') os.remove(info.files[0].path) diff --git a/module/firmware.py b/module/firmware.py index 98d1767..3dad3c6 100644 --- a/module/firmware.py +++ b/module/firmware.py @@ -5,6 +5,11 @@ from config import config, dump_config import shutil from module.msg_notifier import send_notify +import xmltodict +from functools import lru_cache +from config import config +from module.downloader import download +from utils.network import get_finial_url, session logger = logging.getLogger(__name__) hactool_path = Path(os.path.realpath(os.path.dirname(__file__))).joinpath('hactool.exe') @@ -84,6 +89,10 @@ def extract_version(target_file, key_path): f'--romfsdir="{str(tmp_path)}"', shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) process.wait() + if not tmp_path.exists(): + logger.info(f'Fail to decrypt file.') + send_notify(f'无法解析固件文件, 可能是当前使用的密钥与固件版本不匹配') + return if tmp_path.joinpath('file').exists(): with open(tmp_path.joinpath('file'), 'rb') as f: f.seek(0x68) @@ -94,5 +103,84 @@ def extract_version(target_file, key_path): return version +@lru_cache(1) +def get_firmware_infos(): + import urllib.parse + base_url = 'https://archive.org/download/nintendo-switch-global-firmwares/' + url = base_url + 'nintendo-switch-global-firmwares_files.xml' + resp = session.get(get_finial_url(url), timeout=5) + data = xmltodict.parse(resp.text) + files = data['files']['file'] + res = [] + for info in files: + if 'ZIP' != info['format']: + continue + info['name'] = info['@name'] + del info['@name'] + info['url'] = base_url + urllib.parse.quote(info['name']) + version = info['name'][9:-4] + info['version'] = version + version_num = 0 + for num in version.split('.'): + version_num *= 100 + version_num += int(''.join(ch for ch in num if ch.isdigit())) + info['version_num'] = version_num + res.append(info) + res = sorted(res, key=lambda x: x['version_num'], reverse=True) + return res + + +def check_file_md5(file: Path, target_md5: str): + if not file.exists() or not file.is_file(): + return None + import hashlib + logger.debug(f'calculating md5 of file: {file}') + send_notify('开始校验文件 md5...') + hash_md5 = hashlib.md5() + with file.open('rb') as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + file_md5 = hash_md5.hexdigest() + send_notify(f'本地文件 md5: {file_md5}') + send_notify(f'远端文件 md5: {target_md5}') + logger.debug(f'file md5: {file_md5}, target md5: {target_md5}') + return file_md5.lower() == target_md5.lower() + + +def install_firmware(firmware_version, target_firmware_path): + send_notify('正在获取固件信息...') + firmware_infos = get_firmware_infos() + target_info = None + if firmware_version: + firmware_map = {fi['version']: fi for fi in firmware_infos} + target_info = firmware_map.get(firmware_version) + if not target_info: + logger.info(f'Target firmware version [{firmware_version}] not found, skip install.') + send_notify(f'Target firmware version [{firmware_version}] not found, skip install.') + return + url = get_finial_url(target_info['url']) + send_notify(f'开始下载固件...') + logger.info(f"downloading firmware of [{firmware_version}] from {url}") + info = download(url) + file = info.files[0] + if config.setting.download.verifyFirmwareMd5 and not check_file_md5(file.path, target_info['md5']): + logger.info(f'firmware md5 not match, removing file [{file}]...') + os.remove(file) + from exception.common_exception import Md5NotMatchException + raise Md5NotMatchException() + import zipfile + with zipfile.ZipFile(file.path, 'r') as zf: + firmware_path = target_firmware_path + shutil.rmtree(firmware_path, ignore_errors=True) + firmware_path.mkdir(parents=True, exist_ok=True) + send_notify(f'开始解压安装固件...') + logger.info(f'Unzipping firmware files to {firmware_path}') + zf.extractall(firmware_path) + logger.info(f'Firmware of [{firmware_version}] install successfully.') + if config.setting.download.autoDeleteAfterInstall: + os.remove(file.path) + return firmware_version + + if __name__ == '__main__': detect_firmware_version('yuzu') diff --git a/module/ryujinx.py b/module/ryujinx.py index f6de414..12cb334 100644 --- a/module/ryujinx.py +++ b/module/ryujinx.py @@ -94,7 +94,7 @@ def install_firmware_to_ryujinx(firmware_version=None): shutil.rmtree(firmware_path, ignore_errors=True) firmware_path.mkdir(parents=True, exist_ok=True) tmp_dir = firmware_path.joinpath('tmp/') - from module.common import install_firmware + from module.firmware import install_firmware new_version = install_firmware(firmware_version, tmp_dir) if new_version: for path in tmp_dir.glob('*.nca'): diff --git a/module/sentry.py b/module/sentry.py index e5991b0..26ff5c0 100644 --- a/module/sentry.py +++ b/module/sentry.py @@ -4,7 +4,7 @@ def sampler(sample_data): if 'wsgi_environ' in sample_data and sample_data['wsgi_environ']['PATH_INFO'] == '/index.html': - return 1 + return 0.1 return 0 diff --git a/module/yuzu.py b/module/yuzu.py index ff15f1b..f9b8abb 100644 --- a/module/yuzu.py +++ b/module/yuzu.py @@ -117,7 +117,7 @@ def install_firmware_to_yuzu(firmware_version=None): logger.info(f'Current firmware are same as target version [{firmware_version}], skip install.') send_notify(f'当前的 固件 就是 [{firmware_version}], 跳过安装.') return - from module.common import install_firmware + from module.firmware import install_firmware new_version = install_firmware(firmware_version, get_yuzu_nand_path().joinpath(r'system\Contents\registered')) if new_version: config.yuzu.yuzu_firmware = new_version diff --git a/requirements.txt b/requirements.txt index 011b450..5ca6146 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,4 @@ requests-cache dnspython[doh] pywebview sentry-sdk +xmltodict diff --git a/utils/doh.py b/utils/doh.py index 4f5c96d..c880aba 100644 --- a/utils/doh.py +++ b/utils/doh.py @@ -1,8 +1,10 @@ import ipaddress import logging +import socket import sys import time from typing import Dict, List +from config import config import dns.message import dns.query @@ -24,6 +26,9 @@ resolver.nameservers = ["223.5.5.5", '119.29.29.29'] +try_ipv6 = connection.HAS_IPV6 and not config.setting.download.disableAria2Ipv6 + + class DnsCacheItem: expire_at: float = 0 answer = None @@ -56,16 +61,17 @@ def update_dns_cache(name: str, answer): dns_cache[name] = available_items -def _get_available_items(name: str): +def _get_available_items(name: str, record_type: str = 'A'): + rdtype = dns.rdatatype.from_text(record_type) now = time.time() cached_items = dns_cache.get(name, []) - available_items = [item for item in cached_items if item.expire_at > now] + available_items = [item for item in cached_items if item.expire_at > now and item.answer.rdtype == rdtype] return available_items -def take_from_dns_cache(name: str): +def take_from_dns_cache(name: str, record_type: str = 'A'): res = [] - available_items = _get_available_items(name) + available_items = _get_available_items(name, record_type) available_answers = [item.answer for item in available_items] for answer in available_answers: for ip in answer: @@ -73,7 +79,7 @@ def take_from_dns_cache(name: str): return res -def query_address(name, record_type='A', server=DOH_SERVER, path="/dns-query", fallback=True, verbose=False): +def query_address(name, record_type='A', server=DOH_SERVER, path="/dns-query", fallback=True, verbose=True): """ Returns domain name query results retrieved by using DNS over HTTPS protocol @@ -87,7 +93,7 @@ def query_address(name, record_type='A', server=DOH_SERVER, path="/dns-query", f if is_ip_address(name): return [name] - retval = take_from_dns_cache(name) + retval = take_from_dns_cache(name, record_type) if retval: logger.debug(f'use dns answer from cache: {retval}') return retval @@ -97,11 +103,13 @@ def query_address(name, record_type='A', server=DOH_SERVER, path="/dns-query", f q = dns.message.make_query(name, dns.rdatatype.from_text(record_type)) resp = dns.query.https(q, server, session=session) # print(f'[{name}] doh answer: {resp.answer}') - logger.debug(f'doh answer: {resp.answer}') + logger.debug(f'doh answer of [{name} in {record_type}]: {resp.answer}') if not resp.answer: return [] retval = [] for answer in resp.answer: + if answer.rdtype not in {dns.rdatatype.AAAA, dns.rdatatype.A}: + continue update_dns_cache(name, answer) for item in answer: retval.append(item.address) @@ -122,18 +130,38 @@ def query_address(name, record_type='A', server=DOH_SERVER, path="/dns-query", f return retval +def _try_connect(addresses, port, *args, **kwargs): + global try_ipv6 + for ip in addresses: + try: + sock: socket.socket = _orig_create_connection((ip, port), *args, **kwargs) + # logger.debug(f'connected to {sock.getpeername()}') + return sock + except: + pass + + def patched_create_connection(address, *args, **kwargs): """Wrap urllib3's create_connection to resolve the name elsewhere""" # resolve hostname to an ip address; use your own # resolver here, as otherwise the system resolver will be used. + global try_ipv6 host, port = address if host.strip() == DOH_SERVER: return _orig_create_connection((DOH_SERVER, port), *args, **kwargs) + if try_ipv6: + addresses = query_address(host, 'AAAA') + sock = _try_connect(addresses, port, *args, **kwargs) + if sock: + return sock + elif addresses: + logger.debug(f'IPv6 disabled in DoH.') + try_ipv6 = False addresses = query_address(host) - if not addresses: - return _orig_create_connection(address, *args, **kwargs) - hostname = addresses[0] - return _orig_create_connection((hostname, port), *args, **kwargs) + sock = _try_connect(addresses, port, *args, **kwargs) + if sock: + return sock + return _orig_create_connection(address, *args, **kwargs) def install_doh(): @@ -147,4 +175,5 @@ def install_doh(): # time.sleep(60) # print(query_address('google.com')) install_doh() - print(requests.get('http://t.tt')) + print(requests.get('https://nsarchive.e6ex.com', timeout=5).text) + print(requests.get('https://cfrp.e6ex.com', timeout=5).text) diff --git a/utils/network.py b/utils/network.py index ecf9af5..bf83ffa 100644 --- a/utils/network.py +++ b/utils/network.py @@ -3,6 +3,7 @@ import logging import os import requests_cache +from requests.adapters import HTTPAdapter logger = logging.getLogger(__name__) @@ -26,6 +27,9 @@ session = requests_cache.CachedSession(cache_control=True) session.headers.update({'User-Agent': user_agent}) +session.mount('https://cfrp.e6ex.com', HTTPAdapter(max_retries=5)) +session.mount('https://nsarchive.e6ex.com', HTTPAdapter(max_retries=5)) +session.mount('https://api.github.com', HTTPAdapter(max_retries=5)) options_on_proxy = { 'split': '16', diff --git a/vue/src/pages/Ryujinx.vue b/vue/src/pages/Ryujinx.vue index 1404486..8d5bc81 100644 --- a/vue/src/pages/Ryujinx.vue +++ b/vue/src/pages/Ryujinx.vue @@ -14,7 +14,7 @@ - @@ -105,7 +105,7 @@ - + - diff --git a/vue/src/pages/Settings.vue b/vue/src/pages/Settings.vue index 0190bb7..c1f5f11 100644 --- a/vue/src/pages/Settings.vue +++ b/vue/src/pages/Settings.vue @@ -58,6 +58,7 @@ + diff --git a/vue/src/pages/Yuzu.vue b/vue/src/pages/Yuzu.vue index 1cc86e9..748879e 100644 --- a/vue/src/pages/Yuzu.vue +++ b/vue/src/pages/Yuzu.vue @@ -99,7 +99,7 @@ - + - diff --git a/vue/src/store/index.js b/vue/src/store/index.js index 437743e..aff23e0 100644 --- a/vue/src/store/index.js +++ b/vue/src/store/index.js @@ -109,13 +109,14 @@ const state = { network: { firmwareSource: 'auto-detect', githubApiMode: 'direct', - githubDownloadSource: "self" + githubDownloadSource: "self", + useDoh: true, }, download: { autoDeleteAfterInstall: true, disableAria2Ipv6: true, removeOldAria2LogFile: true, - useDoh: true, + verifyFirmwareMd5: true, } }, },