1
0

aime_generator.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461
  1. """
  2. AIME数学题目生成器
  3. 使用HelloAgents框架生成AIME风格的数学题目
  4. """
  5. import json
  6. import os
  7. import time
  8. import random
  9. from typing import List, Dict, Any, Optional
  10. from datetime import datetime
  11. from tqdm import tqdm
  12. from hello_agents import SimpleAgent
  13. from hello_agents import HelloAgentsLLM
  14. from datasets import load_dataset
  15. class AIMEGenerator:
  16. """AIME题目生成器"""
  17. # AIME题目生成提示词(英文)
  18. GENERATION_PROMPT = """You are a professional mathematics competition problem designer, skilled in creating AIME (American Invitational Mathematics Examination) style problems.
  19. AIME Problem Characteristics:
  20. 1. Answer: An integer between 0 and 999
  21. 2. Topics: Algebra, Geometry, Number Theory, Combinatorics, Probability, etc.
  22. 3. Style: Requires multi-step reasoning, but no advanced theory
  23. 4. Difficulty: Medium to hard (similar to AIME problems 6-9)
  24. Please generate an AIME-style mathematics problem, including:
  25. 1. Problem statement (clear and complete)
  26. 2. Answer (an integer between 0 and 999)
  27. 3. Detailed solution (including all reasoning steps)
  28. 4. Topic classification (Algebra/Geometry/Number Theory/Combinatorics/Probability)
  29. Please output in the following JSON format, avoid using special escape characters in JSON:
  30. ```json
  31. {
  32. "problem": "Problem statement in English",
  33. "answer": 123,
  34. "solution": "Detailed solution steps in English",
  35. "topic": "Algebra"
  36. }
  37. ```
  38. """
  39. def __init__(
  40. self,
  41. llm: HelloAgentsLLM = None,
  42. delay_seconds: float = 1.0,
  43. use_reference_examples: bool = True,
  44. reference_dataset: str = "TianHongZXY/aime-1983-2025"
  45. ):
  46. """
  47. 初始化生成器
  48. Args:
  49. llm: LLM实例(可选)
  50. delay_seconds: 每次生成之间的延迟(秒),避免API速率限制
  51. use_reference_examples: 是否使用真题作为参考样例
  52. reference_dataset: 参考数据集名称,默认使用TianHongZXY/aime-1983-2025(900+道题)
  53. """
  54. # 如果没有提供llm,创建默认的HelloAgentsLLM
  55. if llm is None:
  56. self.llm = HelloAgentsLLM()
  57. else:
  58. self.llm = llm
  59. self.agent = SimpleAgent(
  60. name="AIME Generator",
  61. llm=self.llm,
  62. system_prompt="你是一位专业的数学竞赛题目设计专家。"
  63. )
  64. self.delay_seconds = delay_seconds
  65. self.use_reference_examples = use_reference_examples
  66. self.reference_examples = []
  67. # 加载参考样例
  68. if use_reference_examples:
  69. try:
  70. print(f"📚 加载AIME真题数据集: {reference_dataset}")
  71. # 尝试不同的split
  72. try:
  73. dataset = load_dataset(reference_dataset, split="train")
  74. except:
  75. dataset = load_dataset(reference_dataset, split="test")
  76. # 加载所有题目作为参考
  77. self.reference_examples = list(dataset)
  78. print(f" ✓ 已加载 {len(self.reference_examples)} 道参考题目")
  79. # 统计年份分布(如果有year字段)
  80. year_counts = {}
  81. for item in self.reference_examples:
  82. year = item.get('year')
  83. if year:
  84. year_counts[year] = year_counts.get(year, 0) + 1
  85. if year_counts:
  86. year_range = f"{min(year_counts.keys())}-{max(year_counts.keys())}"
  87. print(f" ℹ️ 年份范围: {year_range}")
  88. except Exception as e:
  89. print(f" ⚠️ 加载参考样例失败: {e}")
  90. print(f" ℹ️ 将使用默认提示词生成")
  91. self.use_reference_examples = False
  92. def generate_single(self, max_retries: int = 3) -> Dict[str, Any]:
  93. """
  94. 生成单个题目
  95. Args:
  96. max_retries: 最大重试次数
  97. Returns:
  98. 题目数据
  99. """
  100. # 构建提示词
  101. prompt = self._build_prompt()
  102. for attempt in range(max_retries):
  103. try:
  104. response = self.agent.run(prompt)
  105. return self._parse_response(response)
  106. except Exception as e:
  107. if attempt < max_retries - 1:
  108. tqdm.write(f"⚠️ 生成失败(尝试 {attempt + 1}/{max_retries}),{self.delay_seconds}秒后重试...")
  109. time.sleep(self.delay_seconds)
  110. else:
  111. tqdm.write(f"❌ 生成失败,已达最大重试次数: {e}")
  112. return self._get_default_problem()
  113. def _build_prompt(self) -> str:
  114. """构建生成提示词"""
  115. if not self.use_reference_examples or not self.reference_examples:
  116. return self.GENERATION_PROMPT
  117. # 随机选择一个参考样例
  118. example = random.choice(self.reference_examples)
  119. example_problem = example.get('problem', 'Example problem')
  120. example_answer = example.get('answer', 0)
  121. # 构建带参考样例的提示词(英文)
  122. prompt = f"""You are a professional mathematics competition problem designer, skilled in creating AIME (American Invitational Mathematics Examination) style problems.
  123. 【Reference Example】(For style reference only, please generate a completely different problem)
  124. Problem: {example_problem}
  125. Answer: {example_answer}
  126. AIME Problem Characteristics:
  127. 1. Answer: An integer between 0 and 999
  128. 2. Topics: Algebra, Geometry, Number Theory, Combinatorics, Probability, etc.
  129. 3. Style: Requires multi-step reasoning, but no advanced theory
  130. 4. Difficulty: Medium to hard (similar to AIME problems 6-9)
  131. Please generate a **completely different** AIME-style mathematics problem, including:
  132. 1. Problem statement (clear and complete, different from the reference)
  133. 2. Answer (an integer between 0 and 999, different from the reference)
  134. 3. Detailed solution (including all reasoning steps)
  135. 4. Topic classification (Algebra/Geometry/Number Theory/Combinatorics/Probability)
  136. Please output in the following JSON format, avoid using special escape characters in JSON:
  137. ```json
  138. {{
  139. "problem": "Problem statement in English",
  140. "answer": 123,
  141. "solution": "Detailed solution steps in English",
  142. "topic": "Algebra"
  143. }}
  144. ```
  145. Important Notes:
  146. - **Must generate a completely different problem from the reference**
  147. - You can reference the style, but do not copy the content
  148. - Ensure the problem is creative and original
  149. """
  150. return prompt
  151. def _parse_response(self, response: str) -> Dict[str, Any]:
  152. """解析LLM响应(支持LaTeX数学公式)"""
  153. import re
  154. # 提取JSON部分
  155. if "```json" in response:
  156. json_str = response.split("```json")[1].split("```")[0].strip()
  157. elif "```" in response:
  158. json_str = response.split("```")[1].split("```")[0].strip()
  159. else:
  160. json_str = response.strip()
  161. # 使用json.loads的strict=False来处理转义字符
  162. # 但这还不够,我们需要更智能的处理
  163. try:
  164. problem_data = json.loads(json_str)
  165. except json.JSONDecodeError as e:
  166. # 如果解析失败,尝试修复常见的LaTeX转义问题
  167. # 方法:先将字符串中的单个反斜杠替换为双反斜杠(但保留已经转义的)
  168. # 这样LaTeX的 \frac 会变成 \\frac,在JSON中是合法的
  169. # 使用正则表达式:找到所有未转义的反斜杠(不是\\的\)
  170. # 并将其替换为\\
  171. fixed_json_str = re.sub(r'(?<!\\)\\(?!["\\/bfnrtu])', r'\\\\', json_str)
  172. try:
  173. problem_data = json.loads(fixed_json_str)
  174. except json.JSONDecodeError:
  175. # 如果还是失败,打印错误信息并抛出
  176. print(f"❌ JSON解析失败:")
  177. print(f"原始响应: {response[:500]}...")
  178. print(f"提取的JSON: {json_str[:500]}...")
  179. raise
  180. # 验证必需字段
  181. if "problem" not in problem_data or "answer" not in problem_data:
  182. raise ValueError("缺少必需字段: problem 或 answer")
  183. # 验证答案范围
  184. answer = int(problem_data.get("answer", 0))
  185. if not (0 <= answer <= 999):
  186. print(f"⚠️ 答案超出范围: {answer},调整为0-999范围内")
  187. answer = max(0, min(999, answer))
  188. problem_data["answer"] = answer
  189. # 确保有默认值
  190. problem_data.setdefault("solution", "No solution provided")
  191. problem_data.setdefault("topic", "Uncategorized")
  192. return problem_data
  193. def _get_default_problem(self) -> Dict[str, Any]:
  194. """获取默认题目(生成失败时使用)"""
  195. return {
  196. "problem": "生成失败,请重新生成",
  197. "answer": 0,
  198. "solution": "N/A",
  199. "topic": "未知"
  200. }
  201. def generate_batch(
  202. self,
  203. num_problems: int = 30,
  204. checkpoint_path: str = None
  205. ) -> List[Dict[str, Any]]:
  206. """
  207. 批量生成题目
  208. Args:
  209. num_problems: 生成题目数量
  210. checkpoint_path: 检查点文件路径(用于保存进度)
  211. Returns:
  212. 题目列表
  213. """
  214. print(f"\n🎯 开始生成AIME题目")
  215. print(f" 目标数量: {num_problems}")
  216. print(f" 生成模型: {self.llm.model}")
  217. print(f" 延迟设置: {self.delay_seconds}秒/题")
  218. # 尝试从检查点恢复
  219. problems = []
  220. start_index = 0
  221. if checkpoint_path and os.path.exists(checkpoint_path):
  222. print(f"\n📂 发现检查点文件,尝试恢复...")
  223. try:
  224. with open(checkpoint_path, 'r', encoding='utf-8') as f:
  225. problems = json.load(f)
  226. start_index = len(problems)
  227. print(f" ✓ 已恢复 {start_index} 个题目,从第 {start_index + 1} 个继续")
  228. except Exception as e:
  229. print(f" ⚠️ 恢复失败: {e},从头开始")
  230. problems = []
  231. start_index = 0
  232. # 生成题目(使用tqdm显示进度)
  233. with tqdm(total=num_problems, initial=start_index, desc="生成AIME题目", unit="题") as pbar:
  234. last_call_time = 0 # 上次API调用的时间
  235. for i in range(start_index, num_problems):
  236. # 计算距离上次调用的时间
  237. if last_call_time > 0:
  238. elapsed = time.time() - last_call_time
  239. # 如果距离上次调用不足delay_seconds,则等待
  240. if elapsed < self.delay_seconds:
  241. wait_time = self.delay_seconds - elapsed
  242. tqdm.write(f"⏳ 等待 {wait_time:.1f} 秒以避免速率限制...")
  243. time.sleep(wait_time)
  244. # 记录开始时间
  245. start_time = time.time()
  246. # 生成题目
  247. problem = self.generate_single()
  248. problem["id"] = f"gen_aime_{i + 1}"
  249. problem["generated_at"] = datetime.now().isoformat()
  250. # 记录结束时间
  251. last_call_time = time.time()
  252. generation_time = last_call_time - start_time
  253. problems.append(problem)
  254. # 更新进度条描述
  255. pbar.set_postfix({
  256. "主题": problem.get('topic', 'N/A'),
  257. "答案": problem.get('answer', 'N/A'),
  258. "耗时": f"{generation_time:.1f}s"
  259. })
  260. pbar.update(1)
  261. # 保存检查点
  262. if checkpoint_path:
  263. try:
  264. with open(checkpoint_path, 'w', encoding='utf-8') as f:
  265. json.dump(problems, f, ensure_ascii=False, indent=2)
  266. except Exception as e:
  267. tqdm.write(f"⚠️ 保存检查点失败: {e}")
  268. print(f"\n✅ 生成完成!共 {len(problems)} 个题目")
  269. return problems
  270. def save_problems(
  271. self,
  272. problems: List[Dict[str, Any]],
  273. output_path: str
  274. ):
  275. """保存题目到文件"""
  276. # 确保目录存在
  277. os.makedirs(os.path.dirname(output_path), exist_ok=True)
  278. with open(output_path, 'w', encoding='utf-8') as f:
  279. json.dump(problems, f, ensure_ascii=False, indent=2)
  280. print(f"\n💾 题目已保存: {output_path}")
  281. def generate_and_save(
  282. self,
  283. num_problems: int = 30,
  284. output_dir: str = "data_generation/generated_data"
  285. ) -> str:
  286. """生成并保存题目"""
  287. # 创建输出目录
  288. os.makedirs(output_dir, exist_ok=True)
  289. # 清理旧的检查点文件
  290. for file in os.listdir(output_dir):
  291. if file.startswith("checkpoint_") and file.endswith(".json"):
  292. old_checkpoint = os.path.join(output_dir, file)
  293. try:
  294. os.remove(old_checkpoint)
  295. print(f"🗑️ 已删除旧检查点文件: {file}")
  296. except Exception as e:
  297. print(f"⚠️ 删除旧检查点失败: {e}")
  298. # 设置检查点路径
  299. timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  300. checkpoint_path = os.path.join(output_dir, f"checkpoint_{timestamp}.json")
  301. # 生成题目(带检查点)
  302. problems = self.generate_batch(num_problems, checkpoint_path=checkpoint_path)
  303. # 保存题目
  304. output_path = os.path.join(output_dir, f"aime_generated_{timestamp}.json")
  305. self.save_problems(problems, output_path)
  306. # 生成统计报告
  307. self._generate_statistics_report(problems, output_dir, timestamp)
  308. # 删除检查点文件
  309. if os.path.exists(checkpoint_path):
  310. try:
  311. os.remove(checkpoint_path)
  312. print(f"\n🗑️ 已删除检查点文件")
  313. except Exception as e:
  314. print(f"\n⚠️ 删除检查点文件失败: {e}")
  315. return output_path
  316. def _generate_statistics_report(
  317. self,
  318. problems: List[Dict[str, Any]],
  319. output_dir: str,
  320. timestamp: str
  321. ):
  322. """生成统计报告"""
  323. # 统计主题分布
  324. topics = {}
  325. answers = []
  326. for problem in problems:
  327. topic = problem.get("topic", "未知")
  328. topics[topic] = topics.get(topic, 0) + 1
  329. if "answer" in problem:
  330. answers.append(problem["answer"])
  331. # 生成报告
  332. report = f"""# AIME题目生成统计报告
  333. ## 基本信息
  334. - **生成时间**: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
  335. - **题目数量**: {len(problems)}
  336. ## 主题分布
  337. | 主题 | 数量 | 占比 |
  338. |------|------|------|
  339. """
  340. for topic, count in sorted(topics.items(), key=lambda x: x[1], reverse=True):
  341. percentage = count / len(problems) * 100
  342. report += f"| {topic} | {count} | {percentage:.1f}% |\n"
  343. if answers:
  344. report += f"""
  345. ## 答案分析
  346. - **平均答案**: {sum(answers) / len(answers):.2f}
  347. - **最小答案**: {min(answers)}
  348. - **最大答案**: {max(answers)}
  349. - **答案范围**: {min(answers)}-{max(answers)}
  350. """
  351. report += f"""
  352. ## 题目列表
  353. | ID | 主题 | 答案 |
  354. |-----|------|------|
  355. """
  356. for problem in problems[:10]: # 只显示前10个
  357. report += f"| {problem.get('id', 'N/A')} | {problem.get('topic', 'N/A')} | {problem.get('answer', 'N/A')} |\n"
  358. if len(problems) > 10:
  359. report += f"\n*(仅显示前10个题目,完整列表请查看JSON文件)*\n"
  360. report += f"""
  361. ---
  362. *报告生成时间: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}*
  363. """
  364. # 保存报告
  365. report_path = os.path.join(output_dir, f"generation_report_{timestamp}.md")
  366. with open(report_path, 'w', encoding='utf-8') as f:
  367. f.write(report)
  368. print(f"📊 统计报告已保存: {report_path}")
  369. if __name__ == "__main__":
  370. # 创建生成器
  371. generator = AIMEGenerator()
  372. # 生成30个题目
  373. output_path = generator.generate_and_save(num_problems=30)
  374. print(f"\n✅ 完成!生成的题目保存在: {output_path}")