# kmeanscsharp.py — cluster Chinese complaint texts with TF-IDF + KMeans (2.4 KB)
import ast
import re
import sys
from collections import defaultdict

import jieba
from sklearn.cluster import KMeans
from sklearn.feature_extraction.text import TfidfVectorizer
  7. def preprocess(text):
  8. # 去除标点符号
  9. text = re.sub(r'[^\w\s]', '', text)
  10. # 分词
  11. words = jieba.cut(text)
  12. # 去除停用词
  13. stopwords = ['的', '是', '我', '你', '他', '不满', '举报', '市民', '投诉', '家长', '教育局', '学校', '学生'] # 自定义停用词表
  14. words = [word for word in words if word not in stopwords]
  15. return ' '.join(words)
  16. def extract_features(texts):
  17. #vectorizer = CountVectorizer()
  18. #transform = TfidfTransformer()
  19. #features = transform.fit_transform(vectorizer.fit_transform(texts))
  20. tfidf_vectorizer = TfidfVectorizer()#token_pattern=r"(?u)\b\w+\b"
  21. features = tfidf_vectorizer.fit_transform(texts)
  22. return features.toarray()
  23. def cluster(features, n_clusters):
  24. kmeans = KMeans(n_clusters=n_clusters)
  25. kmeans.fit(features)
  26. labels = kmeans.labels_
  27. return labels
  28. # def visualize(labels):
  29. # unique_labels = set(labels)
  30. # #sizes = [labels.unique(label) for label in unique_labels]
  31. # counts = np.unique(labels, return_counts=True)
  32. # sizes = counts[1]
  33. # labels = [str(label) for label in unique_labels]
  34. # plt.pie(sizes, labels=labels, autopct='%1.1f%%')
  35. # plt.axis('equal')
  36. # plt.show()
  37. def getgroups(liststr , groupnum):
  38. listarray = eval(liststr)
  39. processed_texts1 = [preprocess(str(text)) for text in listarray]
  40. # 特征提取
  41. features = extract_features(processed_texts1)
  42. # 聚类算法
  43. n_clusters = groupnum
  44. labels = cluster(features, n_clusters)
  45. counts = defaultdict(str)
  46. i = 0
  47. for label in labels:
  48. counts[label] += "|" + str(listarray[i])
  49. i = i + 1
  50. listarray2 = []
  51. for item in counts.items():
  52. listarray2.append(item[1])
  53. return listarray2
# CLI entry point: argv[1] is a stringified Python list of texts,
# argv[2] the number of clusters. Prints the grouped result, or — as a
# top-level boundary — prints the exception message instead of crashing
# (the caller consumes stdout either way).
if __name__ == '__main__':
    try:
        result = getgroups(sys.argv[1],int(sys.argv[2]))
        print(result)
    except Exception as e:
        print(e)
    # Example invocation kept for manual testing:
    # result = getgroups("['学校食堂需要整改','学校食堂饭菜难吃','违规补课还是存在','学生在食堂吃不饱','学校要提高食堂饭菜质量','学校霸凌学生受欺','学校食堂饭菜质量很差','学生在食堂吃了拉肚子','举报违规补课','宁波优质学校还是太少,需要提高整体教学质量']", 4)