# data_exploration.py (3.8 KB)
# data_exploration.py
import os
from typing import Optional

import numpy as np
import pandas as pd
from hello_agents import ToolRegistry
  6. # 读取数据集
  7. work_path = os.path.dirname(os.path.abspath(__file__))
  8. df = pd.read_csv(f"{work_path}/../data/shopping_behavior_updated.csv")
  9. def get_basic_metadata(input: str) -> dict:
  10. """获取基本元数据"""
  11. metadata = {
  12. "shape": df.shape,
  13. "columns": list(df.columns),
  14. "dtypes": df.dtypes.astype(str).to_dict(),
  15. "memory_usage": df.memory_usage(deep=True).sum()
  16. }
  17. return metadata
  18. def assess_data_quality(input: str) -> dict:
  19. """综合数据质量评估"""
  20. quality_report = {
  21. "completeness": {},
  22. "consistency": {},
  23. "validity": {},
  24. "anomalies": {}
  25. }
  26. for col in df.columns:
  27. # 完整性
  28. missing_rate = df[col].isna().mean()
  29. quality_report["completeness"][col] = {
  30. "missing_rate": missing_rate,
  31. "level": "high" if missing_rate < 0.05 else "medium" if missing_rate < 0.2 else "low"
  32. }
  33. # 有效性(基于数据类型)
  34. if pd.api.types.is_numeric_dtype(df[col]):
  35. # 数值型检查
  36. quality_report["anomalies"][col] = {
  37. "min": float(df[col].min()),
  38. "max": float(df[col].max())
  39. }
  40. elif pd.api.types.is_datetime64_any_dtype(df[col]):
  41. # 时间型检查
  42. future_dates = df[col] > pd.Timestamp.now()
  43. quality_report["validity"][col] = {
  44. "future_dates_count": future_dates.sum(),
  45. "date_range": [df[col].min().strftime('%Y-%m-%d'),
  46. df[col].max().strftime('%Y-%m-%d')]
  47. }
  48. return quality_report
  49. def get_statistical_summary(input: str) -> dict:
  50. """核心数据统计摘要"""
  51. summary = {}
  52. for col in df.select_dtypes(include=[np.number]).columns:
  53. series = df[col].dropna()
  54. summary[col] = {
  55. "basic": {
  56. "count": int(series.count()),
  57. "mean": float(series.mean()),
  58. "std": float(series.std()),
  59. "min": float(series.min()),
  60. "25%": float(series.quantile(0.25)),
  61. "50%": float(series.quantile(0.50)),
  62. "75%": float(series.quantile(0.75)),
  63. "max": float(series.max())
  64. },
  65. "advanced": {
  66. "skewness": float(series.skew()),
  67. "kurtosis": float(series.kurtosis()),
  68. "cv": float(series.std() / series.mean()) if series.mean() != 0 else None,
  69. "zeros_count": int((series == 0).sum()),
  70. "negative_count": int((series < 0).sum())
  71. }
  72. }
  73. return summary
  74. def create_data_exploration_registry():
  75. """创建包含数据探查工具的注册表"""
  76. registry = ToolRegistry()
  77. # 注册获取基本元数据函数
  78. registry.register_function(
  79. name="get_basic_metadata",
  80. description="获取基本元数据,包括形状、列名、数据类型和内存使用情况",
  81. func=get_basic_metadata
  82. )
  83. # 注册数据质量评估函数
  84. registry.register_function(
  85. name="assess_data_quality",
  86. description="综合评估数据质量,包括完整性、一致性、有效性和异常检测",
  87. func=assess_data_quality
  88. )
  89. # 注册统计摘要函数
  90. registry.register_function(
  91. name="get_statistical_summary",
  92. description="获取数值型列的核心统计摘要,包括基本统计量和高级统计量",
  93. func=get_statistical_summary
  94. )
  95. return registry
  96. if __name__ == "__main__":
  97. registry = create_data_exploration_registry()
  98. result = registry.execute_tool("get_basic_metadata", input_text=None)
  99. print(result)