apply_patch_executor.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512
  1. from __future__ import annotations
  2. import os
  3. import tempfile
  4. from dataclasses import dataclass
  5. from datetime import datetime
  6. from pathlib import Path
  7. from typing import List, Optional, Tuple
  8. class PatchApplyError(RuntimeError):
  9. """
  10. 补丁应用过程中发生的异常类。
  11. 用于封装补丁应用失败的原因,并提供额外的检查目标信息。
  12. 参数:
  13. message: 错误消息,描述补丁应用失败的原因
  14. recheck_targets: 可选的重新检查目标列表,用于辅助调试和修复补丁问题
  15. """
  16. def __init__(self, message: str, recheck_targets: Optional[List[str]] = None):
  17. super().__init__(message)
  18. self.recheck_targets = recheck_targets or []
  19. @dataclass
  20. class ApplyResult:
  21. """
  22. 补丁应用结果的数据类。
  23. 用于返回补丁应用过程中产生的变更信息和备份信息。
  24. 字段:
  25. files_changed: 被修改的文件路径列表(相对路径)
  26. backups: 创建的备份文件路径列表(绝对路径)
  27. """
  28. files_changed: List[str]
  29. backups: List[str]
  30. class ApplyPatchExecutor:
  31. """
  32. 应用 Codex 风格的 *** Begin Patch 格式补丁。
  33. 安全特性 (MVP):
  34. - repo_root 路径限制 (防止路径逃逸)
  35. - 通过临时文件 + os.replace 实现原子写入
  36. - 备份到 <repo_root>/.helloagents/backups/<timestamp>/
  37. - 大小限制 (最大文件数, 最大总变更行数)
  38. - Update File 块的冲突检测 (精确匹配)
  39. """
  40. def __init__(
  41. self,
  42. repo_root: Path,
  43. max_files: int = 10,
  44. max_total_changed_lines: int = 800,
  45. allowed_write_suffixes: Optional[List[str]] = None,
  46. ):
  47. """
  48. 初始化补丁应用执行器。
  49. 参数:
  50. repo_root: 代码仓库根目录路径,所有补丁操作都限制在此目录内
  51. max_files: 单个补丁允许修改的最大文件数量,默认10个
  52. max_total_changed_lines: 单个补丁允许修改的最大总行数,默认800行
  53. allowed_write_suffixes: 允许修改的文件后缀列表,默认只允许常见文本文件
  54. """
  55. self.repo_root = repo_root
  56. self.max_files = max_files
  57. self.max_total_changed_lines = max_total_changed_lines
  58. # 默认允许写入的文件后缀,防止意外修改二进制文件或敏感文件
  59. self.allowed_write_suffixes = allowed_write_suffixes or [
  60. ".py",
  61. ".md",
  62. ".toml",
  63. ".json",
  64. ".yml",
  65. ".yaml",
  66. ".txt",
  67. ".html",
  68. ".htm",
  69. ".css",
  70. ".js",
  71. ]
  72. # 初始化工作目录和备份目录
  73. self.root_dir = repo_root / ".helloagents"
  74. self.backups_dir = self.root_dir / "backups"
  75. self.backups_dir.mkdir(parents=True, exist_ok=True)
  76. def apply(self, patch_text: str) -> ApplyResult:
  77. """
  78. 解析并应用补丁文本。
  79. 执行流程:
  80. 1. 解析补丁操作 (Add/Update/Delete)
  81. 2. 检查安全限制 (文件数量, 变更行数)
  82. 3. 创建备份目录
  83. 4. 逐个执行操作 (先备份再修改)
  84. 参数:
  85. patch_text: 符合 Codex 风格的补丁文本,以 *** Begin Patch 开始,*** End Patch 结束
  86. 返回:
  87. ApplyResult: 包含被修改文件和备份文件信息的结果对象
  88. 异常:
  89. PatchApplyError: 当补丁不符合格式、超出限制或应用失败时抛出
  90. """
  91. # 解析补丁文本,提取操作列表
  92. ops = self._parse_patch(patch_text)
  93. # 统计受影响的文件数量,检查是否超过限制
  94. touched_files = [op[1] for op in ops if op[0] in {"add", "update", "delete"}]
  95. if len(set(touched_files)) > self.max_files:
  96. raise PatchApplyError(f"Too many files in patch: {len(set(touched_files))} > {self.max_files}")
  97. # 估算补丁修改的总行数,检查是否超过限制
  98. total_changed = self._estimate_changed_lines(ops)
  99. if total_changed > self.max_total_changed_lines:
  100. raise PatchApplyError(f"Patch too large: {total_changed} changed lines > {self.max_total_changed_lines}")
  101. # 创建本次补丁应用的专属备份目录(时间戳命名)
  102. backup_run_dir = self.backups_dir / datetime.now().strftime("%Y%m%d_%H%M%S")
  103. backup_run_dir.mkdir(parents=True, exist_ok=True)
  104. # 初始化结果收集变量
  105. files_changed: List[str] = [] # 记录被修改的文件路径
  106. backups: List[str] = [] # 记录创建的备份文件路径
  107. # 遍历所有解析出的操作,逐个执行
  108. for kind, rel_path, payload in ops:
  109. # 安全检查:确保路径在仓库内,防止路径遍历攻击
  110. target = self._safe_path(rel_path)
  111. # 安全检查:确保文件后缀在允许的列表中
  112. self._enforce_suffix(target)
  113. if kind == "add":
  114. # 添加新文件操作
  115. if target.exists():
  116. raise PatchApplyError(f"Add File target already exists: {rel_path}")
  117. # 创建父目录(如果不存在)
  118. target.parent.mkdir(parents=True, exist_ok=True)
  119. # 原子写入新文件内容
  120. self._atomic_write(target, payload)
  121. # 记录变更
  122. files_changed.append(rel_path)
  123. elif kind == "delete":
  124. # 删除文件操作
  125. if not target.exists():
  126. raise PatchApplyError(f"Delete File target missing: {rel_path}")
  127. # 删除前先备份文件
  128. b = self._backup_file(target, backup_run_dir)
  129. backups.append(str(b))
  130. # 删除文件
  131. target.unlink()
  132. # 记录变更
  133. files_changed.append(rel_path)
  134. elif kind == "update":
  135. # 更新文件操作
  136. if not target.exists():
  137. raise PatchApplyError(f"Update File target missing: {rel_path}")
  138. # 读取原始文件内容(保留换行符)
  139. original = target.read_text(encoding="utf-8").splitlines(keepends=True)
  140. # 修改前先备份文件
  141. b = self._backup_file(target, backup_run_dir)
  142. backups.append(str(b))
  143. # 应用更新补丁内容
  144. updated = self._apply_update_payload(original, payload, rel_path)
  145. # 原子写入更新后的内容
  146. self._atomic_write(target, "".join(updated))
  147. # 记录变更
  148. files_changed.append(rel_path)
  149. else:
  150. # 未知操作类型
  151. raise PatchApplyError(f"Unknown op kind: {kind}")
  152. # 返回最终的应用结果
  153. return ApplyResult(files_changed=files_changed, backups=backups)
  154. def _safe_path(self, rel_path: str) -> Path:
  155. """
  156. 验证路径安全性,防止路径遍历攻击 (Path Traversal)。
  157. 确保目标路径在 repo_root 目录下,防止访问仓库外的文件。
  158. 参数:
  159. rel_path: 相对路径字符串
  160. 返回:
  161. Path: 安全的绝对路径对象
  162. 异常:
  163. PatchApplyError: 当路径是绝对路径、包含特殊字符或试图访问仓库外时抛出
  164. """
  165. if rel_path.startswith("/") or rel_path.startswith("~"):
  166. raise PatchApplyError(f"Absolute paths are not allowed: {rel_path}")
  167. target = (self.repo_root / rel_path).resolve()
  168. # 检查解析后的路径是否以 repo_root 开头
  169. if not str(target).startswith(str(self.repo_root.resolve()) + os.sep) and target != self.repo_root.resolve():
  170. raise PatchApplyError(f"Path escapes repo_root: {rel_path}")
  171. if target.exists() and target.is_symlink():
  172. raise PatchApplyError(f"Refusing to modify symlink: {rel_path}")
  173. return target
  174. def _enforce_suffix(self, target: Path) -> None:
  175. """
  176. 检查目标文件的后缀是否在允许的列表中。
  177. 防止意外修改二进制文件、配置文件或其他敏感文件。
  178. 参数:
  179. target: 目标文件路径对象
  180. 异常:
  181. PatchApplyError: 当文件后缀不在允许列表中时抛出
  182. """
  183. if target.suffix and target.suffix not in self.allowed_write_suffixes:
  184. raise PatchApplyError(f"Disallowed file suffix for write: {target.suffix}")
  185. def _backup_file(self, target: Path, backup_run_dir: Path) -> Path:
  186. """
  187. 备份目标文件到指定的备份目录。
  188. 备份文件保持与原文件相同的相对路径结构,后缀添加 .bak。
  189. 参数:
  190. target: 要备份的目标文件路径
  191. backup_run_dir: 本次运行的备份目录
  192. 返回:
  193. Path: 创建的备份文件路径
  194. """
  195. # 获取文件相对于仓库根目录的路径
  196. rel = target.relative_to(self.repo_root)
  197. # 构建备份文件路径
  198. backup_path = backup_run_dir / (str(rel) + ".bak")
  199. # 创建备份文件的父目录(如果不存在)
  200. backup_path.parent.mkdir(parents=True, exist_ok=True)
  201. # 复制文件内容到备份文件
  202. backup_path.write_bytes(target.read_bytes())
  203. return backup_path
  204. def _atomic_write(self, target: Path, content: str) -> None:
  205. """
  206. 原子写入文件内容。
  207. 先写入临时文件,然后使用 os.replace 原子性替换目标文件,确保写入过程不会因为中断而导致文件损坏。
  208. 参数:
  209. target: 目标文件路径
  210. content: 要写入的文件内容
  211. """
  212. target.parent.mkdir(parents=True, exist_ok=True)
  213. with tempfile.NamedTemporaryFile("w", delete=False, dir=str(target.parent), encoding="utf-8") as tf:
  214. tf.write(content)
  215. tf.flush()
  216. os.fsync(tf.fileno())
  217. tmp_name = tf.name
  218. os.replace(tmp_name, target)
  219. def _parse_patch(self, text: str) -> List[Tuple[str, str, str]]:
  220. """
  221. 解析补丁文本,提取操作列表。
  222. 支持的操作:
  223. - *** Add File: <path> - 添加新文件
  224. - *** Delete File: <path> - 删除文件
  225. - *** Update File: <path> - 更新文件内容
  226. 参数:
  227. text: 补丁文本字符串
  228. 返回:
  229. List[Tuple[str, str, str]]: 操作列表,每个操作包含(操作类型, 路径, 内容)
  230. 异常:
  231. PatchApplyError: 当补丁格式不符合要求时抛出
  232. """
  233. lines = text.splitlines()
  234. # 宽容处理:跳过前置空行/代码块围栏,找到真正的开头
  235. while lines and lines[0].strip() in {"", "```", "```patch", "```diff", "```text"}:
  236. lines = lines[1:]
  237. # 如果仍未以标头开头,尝试向下寻找标头并截取
  238. if lines and lines[0].strip() != "*** Begin Patch":
  239. for idx, l in enumerate(lines):
  240. if l.strip() == "*** Begin Patch":
  241. lines = lines[idx:]
  242. break
  243. if not lines or lines[0].strip() != "*** Begin Patch":
  244. raise PatchApplyError("Patch must start with '*** Begin Patch'")
  245. # 同样跳过结尾的围栏/空行
  246. while lines and lines[-1].strip() in {"", "```"}:
  247. lines = lines[:-1]
  248. if not lines or lines[-1].strip() != "*** End Patch":
  249. # 如果末尾未对齐,尝试在中间找到最后一个 End 标记
  250. for idx in range(len(lines) - 1, -1, -1):
  251. if lines[idx].strip() == "*** End Patch":
  252. lines = lines[: idx + 1]
  253. break
  254. if not lines or lines[-1].strip() != "*** End Patch":
  255. raise PatchApplyError("Patch must end with '*** End Patch'")
  256. ops: List[Tuple[str, str, str]] = []
  257. i = 1
  258. while i < len(lines) - 1:
  259. line = lines[i]
  260. if line.startswith("*** Add File: "):
  261. path = line[len("*** Add File: ") :].strip()
  262. i += 1
  263. buf: List[str] = []
  264. while i < len(lines) - 1 and not lines[i].startswith("*** "):
  265. # 兼容两种格式:
  266. # 1) 规范形式:以 '+' 开头
  267. # 2) 宽松形式:直接给出正文(模型有时会省略 '+')
  268. if lines[i].startswith("+"):
  269. buf.append(lines[i][1:] + "\n")
  270. else:
  271. buf.append(lines[i] + "\n")
  272. i += 1
  273. ops.append(("add", path, "".join(buf)))
  274. continue
  275. if line.startswith("*** Delete File: "):
  276. path = line[len("*** Delete File: ") :].strip()
  277. ops.append(("delete", path, ""))
  278. i += 1
  279. continue
  280. if line.startswith("*** Update File: "):
  281. path = line[len("*** Update File: ") :].strip()
  282. i += 1
  283. buf: List[str] = []
  284. while i < len(lines) - 1 and not lines[i].startswith("*** "):
  285. buf.append(lines[i])
  286. i += 1
  287. ops.append(("update", path, "\n".join(buf)))
  288. continue
  289. if line.strip() == "":
  290. i += 1
  291. continue
  292. raise PatchApplyError(f"Unexpected patch line: {line}")
  293. return ops
  294. def _estimate_changed_lines(self, ops: List[Tuple[str, str, str]]) -> int:
  295. """
  296. 估算补丁操作的总变更行数。
  297. 用于检查补丁大小是否超过限制。
  298. 参数:
  299. ops: 补丁操作列表
  300. 返回:
  301. int: 估算的总变更行数
  302. """
  303. changed = 0
  304. for kind, _, payload in ops:
  305. if kind == "add":
  306. # 添加文件:按行数计算
  307. changed += payload.count("\n")
  308. elif kind == "delete":
  309. # 删除文件:按1行计算
  310. changed += 1
  311. elif kind == "update":
  312. # 更新文件:只计算+/-开头的变更行
  313. for l in payload.splitlines():
  314. if l.startswith("+") or l.startswith("-"):
  315. changed += 1
  316. return changed
  317. def _apply_update_payload(self, original: List[str], payload: str, rel_path: str) -> List[str]:
  318. """
  319. 应用 Update File 的内容。
  320. 将 payload 分割成多个 hunk (代码块),然后逐个应用。
  321. """
  322. # 兼容宽松格式:如果 payload 没有任何 + / - / 前导空格行,视为“整文件替换”
  323. raw_lines = payload.splitlines(keepends=True)
  324. if raw_lines and all(not l.startswith(("+", "-", " ")) for l in raw_lines):
  325. return raw_lines
  326. hunks = self._split_hunks(payload)
  327. current = original
  328. try:
  329. for hunk in hunks:
  330. current = self._apply_hunk(current, hunk, rel_path)
  331. return current
  332. except PatchApplyError as e:
  333. # 宽松兜底:当上下文匹配失败时,尝试将 payload 视作“新的完整文件”生成 after 版本
  334. if "context not found" not in str(e).lower():
  335. raise
  336. fallback = self._hunks_to_after(hunks)
  337. if fallback:
  338. return fallback
  339. raise
  340. def _split_hunks(self, payload: str) -> List[List[str]]:
  341. """
  342. 将 Update File 的 payload 分割成多个 hunk(代码块)。
  343. Hunk 通常由 @@ ... @@ 分隔符分隔,或者由空行分隔。
  344. 每个 hunk 代表文件的一个修改区域。
  345. 参数:
  346. payload: Update File 操作的内容
  347. 返回:
  348. List[List[str]]: hunk 列表,每个 hunk 是多行字符串的列表
  349. """
  350. lines = payload.splitlines()
  351. hunks: List[List[str]] = []
  352. buf: List[str] = []
  353. for l in lines:
  354. if l.startswith("@@"):
  355. if buf:
  356. hunks.append(buf)
  357. buf = []
  358. continue
  359. if l.strip() == "" and buf:
  360. hunks.append(buf)
  361. buf = []
  362. continue
  363. buf.append(l)
  364. if buf:
  365. hunks.append(buf)
  366. return [h for h in hunks if any(x.startswith((" ", "+", "-")) for x in h)]
  367. def _apply_hunk(self, current: List[str], hunk_lines: List[str], rel_path: str) -> List[str]:
  368. """
  369. 应用单个 hunk(代码块)到当前文件内容。
  370. 原理:
  371. 1. 解析 hunk,分离出 'before' (上下文 + 删除行) 和 'after' (上下文 + 新增行)
  372. 2. 在当前文件中查找 'before' 块的精确位置
  373. 3. 如果找到匹配的上下文,用 'after' 块替换 'before' 块
  374. 4. 如果找不到匹配的上下文,抛出异常
  375. 参数:
  376. current: 当前文件的内容行列表
  377. hunk_lines: hunk 的内容行列表
  378. rel_path: 文件的相对路径(用于错误提示)
  379. 返回:
  380. List[str]: 应用 hunk 后的文件内容行列表
  381. 异常:
  382. PatchApplyError: 当 hunk 格式错误或找不到匹配的上下文时抛出
  383. """
  384. before: List[str] = []
  385. after: List[str] = []
  386. for l in hunk_lines:
  387. if not l:
  388. continue
  389. tag = l[0]
  390. text = l[1:] + "\n"
  391. if tag == " ":
  392. before.append(text)
  393. after.append(text)
  394. elif tag == "-":
  395. before.append(text)
  396. elif tag == "+":
  397. after.append(text)
  398. if not before:
  399. raise PatchApplyError("Update hunk has no context/removals; refusing to apply")
  400. idx = self._find_subsequence(current, before)
  401. if idx is None:
  402. context_line = next((b.strip() for b in before if b.strip()), "")
  403. hint = f"{rel_path}:search:'{context_line[:80]}'"
  404. raise PatchApplyError("Patch hunk context not found; file changed?", recheck_targets=[hint])
  405. return current[:idx] + after + current[idx + len(before) :]
  406. def _find_subsequence(self, haystack: List[str], needle: List[str]) -> Optional[int]:
  407. """
  408. 在文件内容中查找代码块的起始位置。
  409. 使用简单的 O(N*M) 字符串匹配算法,在 haystack 中查找 needle 的精确匹配。
  410. 参数:
  411. haystack: 文件内容行列表
  412. needle: 要查找的代码块行列表
  413. 返回:
  414. Optional[int]: 匹配的起始行索引,如果未找到则返回 None
  415. """
  416. if len(needle) > len(haystack):
  417. return None
  418. for i in range(0, len(haystack) - len(needle) + 1):
  419. if haystack[i : i + len(needle)] == needle:
  420. return i
  421. # 宽松匹配:忽略行尾空白再尝试一次,缓解缩进/换行轻微偏差
  422. norm_hay = [h.rstrip() + "\n" for h in haystack]
  423. norm_need = [n.rstrip() + "\n" for n in needle]
  424. for i in range(0, len(norm_hay) - len(norm_need) + 1):
  425. if norm_hay[i : i + len(norm_need)] == norm_need:
  426. return i
  427. return None
  428. def _hunks_to_after(self, hunks: List[List[str]]) -> List[str]:
  429. """
  430. 将多个 hunk 的“after”部分合成为一份完整文件内容。
  431. 用于上下文匹配失败时的宽松回退:保留 + 和空格行,忽略 - 行。
  432. """
  433. out: List[str] = []
  434. for hunk in hunks:
  435. for l in hunk:
  436. if not l:
  437. continue
  438. tag = l[0]
  439. text = l[1:] + "\n" if len(l) > 1 else "\n"
  440. if tag == "-" or tag == "@":
  441. continue
  442. if tag in (" ", "+"):
  443. out.append(text)
  444. return out