BPE.py 1.0 KB

12345678910111213141516171819202122232425262728293031323334
  1. import re, collections
  2. def get_stats(vocab):
  3. """统计词元对频率"""
  4. pairs = collections.defaultdict(int)
  5. for word, freq in vocab.items():
  6. symbols = word.split()
  7. for i in range(len(symbols)-1):
  8. pairs[symbols[i],symbols[i+1]] += freq
  9. return pairs
  10. def merge_vocab(pair, v_in):
  11. """合并词元对"""
  12. v_out = {}
  13. bigram = re.escape(' '.join(pair))
  14. p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
  15. for word in v_in:
  16. w_out = p.sub(''.join(pair), word)
  17. v_out[w_out] = v_in[word]
  18. return v_out
  19. # 准备语料库,每个词末尾加上</w>表示结束,并切分好字符
  20. vocab = {'h u g </w>': 1, 'p u g </w>': 1, 'p u n </w>': 1, 'b u n </w>': 1}
  21. num_merges = 4 # 设置合并次数
  22. for i in range(num_merges):
  23. pairs = get_stats(vocab)
  24. if not pairs:
  25. break
  26. best = max(pairs, key=pairs.get)
  27. vocab = merge_vocab(best, vocab)
  28. print(f"第{i+1}次合并: {best} -> {''.join(best)}")
  29. print(f"新词表(部分): {list(vocab.keys())}")
  30. print("-" * 20)