book-python/text_similarity_scorer.py

71 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
class TextSimilarityScorer:
    """Grade a student answer against a reference ("standard") answer.

    The final score blends three signals: the cosine similarity of the two
    answers' sentence embeddings, the student/reference length ratio, and the
    share of the reference's whitespace-separated tokens that also appear in
    the student answer.
    """

    # Relative weights of the three signals. They sum to 1.0, so with every
    # signal limited to [0, 1] the final score stays within [0, max_score].
    SIMILARITY_WEIGHT = 0.6
    LENGTH_WEIGHT = 0.2
    KEYWORD_WEIGHT = 0.2

    def __init__(self, model_name='shibing624/text2vec-base-chinese'):
        """Load the sentence-embedding model.

        Args:
            model_name: Model identifier passed to ``SentenceTransformer``.
                The default is a model tuned specifically for Chinese text.
        """
        self.model = SentenceTransformer(model_name)

    def calculate_similarity(self, standard_answer, student_answer):
        """Return the cosine similarity of the two answers' embeddings.

        Args:
            standard_answer: Reference answer text.
            student_answer: Student answer text.

        Returns:
            The raw (unclamped) cosine similarity as a plain ``float``.
        """
        # Each answer is encoded as a one-element batch.
        standard_embedding = self.model.encode([standard_answer])
        student_embedding = self.model.encode([student_answer])
        # Two one-row inputs give a 1x1 similarity matrix; unwrap it and
        # convert the NumPy scalar to a built-in float for callers.
        return float(
            cosine_similarity(standard_embedding, student_embedding)[0][0]
        )

    def score(self, standard_answer, student_answer, max_score=100):
        """Compute a rounded score for ``student_answer``.

        Args:
            standard_answer: Reference answer text; must contain at least one
                non-whitespace character.
            student_answer: Student answer text (may be empty).
            max_score: Score awarded to a perfect answer.

        Returns:
            The weighted score, rounded to the nearest integer.

        Raises:
            ValueError: If ``standard_answer`` is empty or whitespace-only.
                Both the length ratio and the keyword coverage divide by a
                quantity derived from it, so there is nothing to grade
                against.
        """
        if not standard_answer or not standard_answer.strip():
            raise ValueError(
                "standard_answer must contain non-whitespace text"
            )

        # Semantic similarity. Cosine similarity can be negative (or drift
        # marginally past 1.0 through float error); clamp it so the score
        # cannot leave the [0, max_score] range.
        similarity = self.calculate_similarity(standard_answer, student_answer)
        similarity = min(max(float(similarity), 0.0), 1.0)

        # Length ratio, capped at 1.0 so longer answers earn no extra credit.
        len_ratio = min(len(student_answer) / len(standard_answer), 1.0)

        # Keyword coverage (simple implementation).
        # NOTE: tokens are split on whitespace only. Text written without
        # spaces (typical for Chinese) becomes a single token, so this term
        # is all-or-nothing for such input.
        standard_keywords = set(standard_answer.split())
        student_keywords = set(student_answer.split())
        keyword_coverage = (
            len(student_keywords & standard_keywords) / len(standard_keywords)
        )

        # Weighted blend of the three signals.
        final_score = (
            similarity * self.SIMILARITY_WEIGHT
            + len_ratio * self.LENGTH_WEIGHT
            + keyword_coverage * self.KEYWORD_WEIGHT
        ) * max_score
        return round(final_score)
# Usage example
def main():
    """Score several sample student answers against one reference answer.

    Prints each answer together with its raw similarity and its final score.
    """
    # Build the scorer with its default model.
    grader = TextSimilarityScorer()

    # Reference answer and the student answers to grade against it.
    reference = "机器学习是人工智能的一个子领域,它使用统计学方法让计算机系统能够从数据中学习和改进。"
    candidates = [
        "机器学习是AI的分支通过统计方法让计算机从数据中学习。",  # similar but shorter
        "机器学习是计算机科学的一个领域,使用统计方法从数据中学习模式。",  # fairly similar
        "人工智能是计算机科学的重要领域。",
        "哈哈哈哈哈哈哈哈。",  # not very relevant
    ]

    # Grade every student answer, numbering them from 1.
    for index, candidate in enumerate(candidates, start=1):
        cosine = grader.calculate_similarity(reference, candidate)
        points = grader.score(reference, candidate)
        report = (
            f"\n学生答案 {index}:",
            f"答案: {candidate}",
            f"相似度: {cosine:.2f}",
            f"得分: {points}",
        )
        print("\n".join(report))
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()