book-python/text_similarity_scorer.py

52 lines
2.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
class TextSimilarityScorer:
    """Grade free-text answers by semantic similarity to a reference answer.

    Both texts are embedded with a ``SentenceTransformer`` model and compared
    with cosine similarity; the similarity is then mapped linearly onto an
    integer score between 0 and ``max_score``.
    """

    def __init__(self, model_name='bert-base-chinese'):
        """Load the embedding model.

        Args:
            model_name: Model name or path passed to ``SentenceTransformer``.
                Defaults to the Chinese pretrained BERT checkpoint.
        """
        # NOTE(review): 'bert-base-chinese' is a plain BERT checkpoint rather
        # than a sentence-transformers model, so sentence-transformers wraps it
        # with a default (mean) pooling layer. A dedicated sentence-embedding
        # model may give more discriminative similarities — confirm.
        self.model = SentenceTransformer(model_name)

    def calculate_similarity(self, standard_answer, student_answer):
        """Return the cosine similarity between the embeddings of two texts.

        Args:
            standard_answer: The reference answer text.
            student_answer: The answer text to compare with the reference.

        Returns:
            The cosine similarity as a plain Python ``float``, nominally in
            ``[-1, 1]`` (floating-point error can push it marginally outside).
        """
        # Encode both texts in a single batch instead of two model calls.
        embeddings = self.model.encode([standard_answer, student_answer])
        # Slices (not plain indexing) keep each embedding 2-D, as
        # cosine_similarity expects (n_samples, n_features) inputs.
        similarity = cosine_similarity(embeddings[0:1], embeddings[1:2])[0][0]
        # Convert from a NumPy scalar so callers get a builtin float.
        return float(similarity)

    def score(self, standard_answer, student_answer, max_score=100):
        """Grade ``student_answer`` against ``standard_answer``.

        Args:
            standard_answer: The reference answer text.
            student_answer: The answer text being graded.
            max_score: Score awarded for a perfect match (similarity 1).

        Returns:
            An ``int`` between 0 and ``max_score`` (for non-negative
            ``max_score``), rounded with Python's built-in ``round``.
        """
        similarity = self.calculate_similarity(standard_answer, student_answer)
        return self.similarity_to_score(similarity, max_score)

    @staticmethod
    def similarity_to_score(similarity, max_score=100):
        """Map a cosine similarity onto an integer score out of ``max_score``.

        Useful when the similarity has already been computed, to avoid
        re-encoding the texts.

        Args:
            similarity: A cosine similarity (float or NumPy scalar).
            max_score: Score corresponding to a similarity of 1.

        Returns:
            ``round(clamped_similarity * max_score)`` as a builtin ``int``.
        """
        # Cosine similarity may be negative (or exceed 1 by float error);
        # clamp so a grade never falls outside [0, max_score].
        clamped = min(max(float(similarity), 0.0), 1.0)
        return round(clamped * max_score)
# Usage example
def main():
    """Demonstrate TextSimilarityScorer on one reference answer.

    Scores three sample student answers against the reference and prints,
    for each one, the answer text, its cosine similarity to the reference,
    and the resulting score out of 100.
    """
    scorer = TextSimilarityScorer()

    # Reference answer and sample student answers to grade against it.
    standard_answer = "机器学习是人工智能的一个子领域,它使用统计学方法让计算机系统能够从数据中学习和改进。"
    student_answers = [
        "机器学习是AI的分支通过统计方法让计算机从数据中学习。",  # similar content, but shorter
        "哈哈哈哈哈哈。",  # unrelated filler ("hahaha"); a good scorer should rate it low
        "人工智能是计算机科学的重要领域。",  # related topic, but does not answer the question
    ]

    # Number the answers from 1 for display.
    for i, student_answer in enumerate(student_answers, 1):
        similarity = scorer.calculate_similarity(standard_answer, student_answer)
        score = scorer.score(standard_answer, student_answer)
        print(f"\n学生答案 {i}:")
        print(f"答案: {student_answer}")
        print(f"相似度: {similarity:.2f}")
        print(f"得分: {score}")
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()