feat: improved LaTeX check workflows #17

Draft · wants to merge 9 commits into base: main

219 changes: 208 additions & 11 deletions writer.py
@@ -7,6 +7,8 @@
import os
import glob
import yaml
import re
from typing import List, Tuple, Dict

path_to = f'src/content/blog/{datetime.datetime.now().strftime("%Y-%m-%d")}'

@@ -95,19 +97,202 @@ def summary(article):
{"role": "user", "content": f"给这篇文章写一个15字的简短介绍:\n\n{article}"}
], deepseek, "deepseek-chat")

# LaTeX error handling
def extract_latex_segments(markdown_text: str) -> List[Tuple[str, int, int]]:
segments: List[Tuple[str,int,int]] = []
block_pattern = re.compile(r'(\$\$[\s\S]+?\$\$)', re.DOTALL)
for m in block_pattern.finditer(markdown_text):
segments.append((m.group(1), m.start(), m.end()))

inline_pattern = re.compile(r'(?<!\\)(\$(?:\\.|[^$])+?\$)', re.DOTALL)
for m in inline_pattern.finditer(markdown_text):
if any(start <= m.start() < end for _, start, end in segments):
continue
segments.append((m.group(1), m.start(), m.end()))

return segments
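
# Illustrative behaviour of the extractor above: for a string such as
# "设 $a+b$,而且 $$E = mc^2$$" it returns the $$...$$ block first, then the
# inline $a+b$ span (inline matches that start inside an already-collected
# block are skipped), each entry being (segment_text, start, end).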

def latex_checks(latex_str: str) -> List[str]:
errors: List[str] = []

# Extra space after a command (ignoring \tt, \it, \bf)
for m in re.finditer(r"\\([a-zA-Z]+)(\s+)", latex_str):
cmd = m.group(1)
if cmd not in ('tt', 'it', 'bf'):
errors.append(f"命令 '\\{cmd}' 后跟有空格,建议去掉空格。")

# Extra space before a reference; suggest a non-breaking '~'
if re.search(r"\s+\\ref\{", latex_str):
errors.append("'\\ref' 前有空格,应使用 '~\\ref{...}' 保持断开。")

# Literal ellipsis '...' instead of \dots or \ldots
if re.search(r'(?<!\\)(?:\.\.\.|…)', latex_str):
errors.append("检测到省略号,建议使用 '\\dots'、'\\cdots' 或 '\\ldots'。")

# Abbreviation not followed by a special space
for m in re.finditer(r"\b(e\.g|i\.e|etc)\.(\s+)", latex_str):
errors.append(f"缩写 '{m.group(1)}.' 后应使用 '\\ ' 或 '~' 保持空格。")

# A capital letter before a sentence-ending '.' should be followed by two spaces
for m in re.finditer(r"([A-Z])\.(\s)(?=[A-Z])", latex_str):
errors.append(f"句子结尾 '{m.group(1)}.' 后只有单个空格,建议使用两个空格。")

# Re-check the math-mode '$' delimiters
# Display (block) math
block_marks = re.findall(r'\$\$', latex_str)
if len(block_marks) % 2 != 0:
errors.append("块级数学模式 '$$' 不成对。")
# Strip all $$...$$ blocks
no_block = re.sub(r'\$\$[\s\S]+?\$\$', '', latex_str)
# Inline math
inline_marks = len(re.findall(r'(?<!\\)\$', no_block))
if inline_marks % 2 != 0:
errors.append("行内数学模式 '$' 不成对。")

# Straight quotes: prefer LaTeX `` ... ''
if '"' in latex_str and not re.search(r"``.*?''", latex_str, re.DOTALL):
errors.append("检测到直引号 '\"',建议使用 LaTeX 引号 ``...'' 。")

# Space before \label
if re.search(r"\s+\\label\{", latex_str):
errors.append("'\\label' 前有空格,应紧贴前文。")

# Space before \footnote
if re.search(r"\s+\\footnote\{", latex_str):
errors.append("'\\footnote' 前有空格,应紧贴前文。")

# Letter 'x' used instead of \times between numbers
for m in re.finditer(r"(?<!\\)\b(\d+)\s*x\s*(\d+)\b", latex_str):
errors.append(f"'{m.group(1)} x {m.group(2)}' 建议用 '$\\times$'。")

# Multiple consecutive spaces
if re.search(r" {2,}", latex_str):
errors.append("检测到连续多个空格,可能要删掉")

# Curly-brace matching
stack: List[int] = []
for pos, ch in enumerate(latex_str):
if ch == '{': stack.append(pos)
elif ch == '}':
if not stack:
errors.append(f"位置 {pos}: 多余 '}}' 。")
else:
stack.pop()
for pos in stack:
errors.append(f"位置 {pos}: 多余 '{{' 。")

# \begin / \end matching (fixes the raw-string error around '\end')
env_stack: List[Tuple[str, int]] = []
for m in re.finditer(r"\\(begin|end)\s*\{([^}]+)\}", latex_str):
cmd, env = m.group(1), m.group(2)
pos = m.start()
if cmd == 'begin':
env_stack.append((env, pos))
else: # cmd == 'end'
if not env_stack or env_stack[-1][0] != env:
# Note: a doubled backslash is used so the message renders '\end' literally
errors.append(f"位置 {pos}: '\\end{{{env}}}' 无匹配或顺序错误。")
else:
env_stack.pop()
# Any \begin left unclosed
for env, pos in env_stack:
errors.append(f"位置 {pos}: '\\begin{{{env}}}' 未关闭。")

# Extra space before an opening parenthesis
if re.search(r"\s+\(", latex_str):
errors.append("左括号 '(' 前有空格,应去除。")

# Punctuation should not appear inside math mode
for m in re.finditer(r"\$(?:[^$]*?)[.,;:!?]+(?:[^$]*?)\$", latex_str):
errors.append("数学模式中包含标点符号,建议放在模式外。")

return errors
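
# Rough example of what the heuristics above report:
# latex_checks(r"$a  x  b ... \ref{eq:foo}$") flags, among other things, the
# literal '...', the space before \ref, and the run of consecutive spaces;
# these are plain regex rules, so the occasional benign finding is expected.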

def latex_errors(markdown_text: str) -> Dict[Tuple[str, int], List[str]]:
report = {}
for seg, start_idx, _ in extract_latex_segments(markdown_text):
errs = latex_checks(seg)
if errs:
report[(seg, start_idx)] = errs
return report

def modify_latex(markdown_text: str, error_report: Dict[Tuple[str,int], List[str]]) -> str:
"""
遍历 error_report,按 start_idx 从大到小替换,
保证后面的替换不影响前面的 start_idx。
"""
corrected = markdown_text
items = sorted(error_report.items(), key=lambda x: x[0][1], reverse=True)

for (seg, start_idx), errs in items:
end_idx = start_idx + len(seg)
context = corrected[max(0, start_idx-50): end_idx+50]
user_msg = (
f"修正此 LaTeX 片段(包含 $ 定界符):\n{seg}\n\n"
"检测到错误:\n- " + "\n- ".join(errs) +
"\n\n上下文:\n" + context +
"\n\n请只返回修正后的完整片段,不要添加其它标记。"
)
fixed = generate([
{"role":"system","content":"你是 LaTeX 专家,负责修正以下代码:"},
{"role":"user","content":user_msg}
], deepseek, "deepseek-reasoner").strip()

# Strip ``` fences if the model accidentally added them
if fixed.startswith("```") and fixed.endswith("```"):
fixed = "\n".join(fixed.splitlines()[1:-1]).strip()

# Re-add the $/$$ delimiters if DeepSeek dropped them from the regenerated segment
if not fixed.startswith('$'):
if seg.startswith('$$') and seg.endswith('$$'):
fixed = '$$' + fixed + '$$'
elif seg.startswith('$') and seg.endswith('$'):
fixed = '$' + fixed + '$'

# Final in-place replacement
corrected = corrected[:start_idx] + fixed + corrected[end_idx:]

return corrected
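
# Typical flow (mirrors the loop further down in this file): build a report
# with latex_errors(article) and, while it is non-empty, feed it back through
# modify_latex; replacing from the largest start_idx downward keeps the
# earlier offsets valid while segments change length.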

is_latin = lambda ch: '\u0000' <= ch <= '\u007F' or '\u00A0' <= ch <= '\u024F'
is_nonspace_latin = lambda ch: is_latin(ch) and not ch.isspace() and not ch in """*()[]{}"'/-@#"""
is_nonpunct_cjk = lambda ch: not is_latin(ch) and ch not in "·!¥…()—【】、;:‘’“”,。《》?「」"
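
# Rough character classes implied by the ranges above: is_latin covers Basic
# Latin plus U+00A0–U+024F (Latin-1 Supplement / Extended-A/B), so
# is_nonspace_latin('a') is True, is_nonpunct_cjk('汉') is True, and
# full-width punctuation such as ',' is excluded from the CJK class.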

def beautify_string(text):
res = ""
for idx in range(len(text)):
if idx and (
(is_nonspace_latin(text[idx]) and is_nonpunct_cjk(text[idx - 1])) or
(is_nonspace_latin(text[idx - 1]) and is_nonpunct_cjk(text[idx]))
): res += " "
res += text[idx]
return res
# Skip LaTeX segments when beautifying
def beautify_string(text: str) -> str:
segments = extract_latex_segments(text)
segments.sort(key=lambda x: x[1])

result_parts = []
last_end = 0

for seg_content, seg_start, seg_end in segments:
non_latex_part = text[last_end:seg_start]
processed_part = ""
for i, char in enumerate(non_latex_part):
if i > 0 and (
(is_nonspace_latin(char) and is_nonpunct_cjk(non_latex_part[i-1])) or
(is_nonspace_latin(non_latex_part[i-1]) and is_nonpunct_cjk(char))
):
processed_part += " "
processed_part += char
result_parts.append(processed_part)

result_parts.append(seg_content)
last_end = seg_end

final_part = text[last_end:]
processed_final_part = ""
for i, char in enumerate(final_part):
if i > 0 and (
(is_nonspace_latin(char) and is_nonpunct_cjk(final_part[i-1])) or
(is_nonspace_latin(final_part[i-1]) and is_nonpunct_cjk(char))
):
processed_final_part += " "
processed_final_part += char
result_parts.append(processed_final_part)

return "".join(result_parts)

start = time.time()
print(" Generating topic:")
@@ -121,9 +306,21 @@ def beautify_string(text):

start = time.time()
print(" Generating article:")
article = beautify_string(write_from_outline(outline_result))
article = write_from_outline(outline_result)
print(f" Article written: time spent {time.time() - start:.1f} s")

start = time.time()
while (report := latex_errors(article)):
print("latex_errors still exist")
article = modify_latex(article, report)

print(f" LaTeX errors fixed: time spent {time.time() - start:.1f} s")

start = time.time()
article = beautify_string(article)
print(f" Article beautified: time spent {time.time() - start:.1f} s")


start = time.time()
print(" Generating summary:")
summary_result = beautify_string(summary(article))
@@ -160,4 +357,4 @@ def beautify_string(text):
with open(f"{path_to}/index.md", "w", encoding="utf-8") as f:
f.write(markdown_file)

print(f" Composed article: {path_to}/index.md")
print(f" Composed article: {path_to}/index.md")