From 9e230bf899058582ac9de8013aaeb00892469c63 Mon Sep 17 00:00:00 2001 From: ovo Date: Sat, 11 Jan 2025 10:34:32 +0800 Subject: [PATCH] demo --- text_similarity_scorer.py | 52 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 text_similarity_scorer.py diff --git a/text_similarity_scorer.py b/text_similarity_scorer.py new file mode 100644 index 0000000..60b1608 --- /dev/null +++ b/text_similarity_scorer.py @@ -0,0 +1,52 @@ +from sentence_transformers import SentenceTransformer +from sklearn.metrics.pairwise import cosine_similarity +import numpy as np + +class TextSimilarityScorer: + def __init__(self, model_name='bert-base-chinese'): + # 初始化 BERT 模型,这里使用中文预训练模型 + self.model = SentenceTransformer(model_name) + + def calculate_similarity(self, standard_answer, student_answer): + # 获取文本嵌入 + standard_embedding = self.model.encode([standard_answer]) + student_embedding = self.model.encode([student_answer]) + + # 计算余弦相似度 + similarity = cosine_similarity(standard_embedding, student_embedding)[0][0] + return similarity + + def score(self, standard_answer, student_answer, max_score=100): + # 计算相似度 + similarity = self.calculate_similarity(standard_answer, student_answer) + + # 将相似度转换为分数 + score = similarity * max_score + + # 四舍五入到整数 + return round(score) + +# 使用示例 +def main(): + # 初始化评分器 + scorer = TextSimilarityScorer() + + # 示例标准答案和学生答案 + standard_answer = "机器学习是人工智能的一个子领域,它使用统计学方法让计算机系统能够从数据中学习和改进。" + student_answers = [ + "机器学习是AI的分支,通过统计方法让计算机从数据中学习。", # 相似但较简短 + "哈哈哈哈哈哈。", # 比较相似 + "人工智能是计算机科学的重要领域。" # 不太相关 + ] + + # 对每个学生答案进行评分 + for i, student_answer in enumerate(student_answers, 1): + similarity = scorer.calculate_similarity(standard_answer, student_answer) + score = scorer.score(standard_answer, student_answer) + print(f"\n学生答案 {i}:") + print(f"答案: {student_answer}") + print(f"相似度: {similarity:.2f}") + print(f"得分: {score}") + +if __name__ == "__main__": + main() \ No newline at end of file