This commit is contained in:
ovo 2025-01-11 20:27:41 +08:00
parent 9e230bf899
commit 1ba85c4bd4
1 changed files with 35 additions and 16 deletions

View File

@ -1,44 +1,62 @@
import os
os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
class TextSimilarityScorer:
def __init__(self, model_name='bert-base-chinese'):
# 初始化 BERT 模型,这里使用中文预训练模型
def __init__(self, model_name='shibing624/text2vec-base-chinese'):
# 使用专门针对中文优化的模型
self.model = SentenceTransformer(model_name)
def calculate_similarity(self, standard_answer, student_answer):
# 获取文本嵌入
standard_embedding = self.model.encode([standard_answer])
student_embedding = self.model.encode([student_answer])
# 计算余弦相似度
similarity = cosine_similarity(standard_embedding, student_embedding)[0][0]
return similarity
def score(self, standard_answer, student_answer, max_score=100):
# 计算相似度
# 基础相似度
similarity = self.calculate_similarity(standard_answer, student_answer)
# 将相似度转换为分数
score = similarity * max_score
# 四舍五入到整数
return round(score)
# 长度比例
len_ratio = min(len(student_answer) / len(standard_answer), 1.0)
# 关键词覆盖度(简单实现)
standard_keywords = set(standard_answer.split())
student_keywords = set(student_answer.split())
keyword_coverage = len(student_keywords.intersection(standard_keywords)) / len(standard_keywords)
# 综合评分
final_score = (
similarity * 0.6 + # 语义相似度权重
len_ratio * 0.2 + # 长度比例权重
keyword_coverage * 0.2 # 关键词覆盖权重
) * max_score
return round(final_score)
# 使用示例
def main():
# 初始化评分器
scorer = TextSimilarityScorer()
# 示例标准答案和学生答案
standard_answer = "机器学习是人工智能的一个子领域,它使用统计学方法让计算机系统能够从数据中学习和改进。"
student_answers = [
"机器学习是AI的分支通过统计方法让计算机从数据中学习。", # 相似但较简短
"哈哈哈哈哈哈。", # 比较相似
"人工智能是计算机科学的重要领域。" # 不太相关
"机器学习是计算机科学的一个领域,使用统计方法从数据中学习模式。", # 比较相似
"人工智能是计算机科学的重要领域。",
"哈哈哈哈哈哈哈哈。"# 不太相关
]
# 对每个学生答案进行评分
for i, student_answer in enumerate(student_answers, 1):
similarity = scorer.calculate_similarity(standard_answer, student_answer)
@ -48,5 +66,6 @@ def main():
print(f"相似度: {similarity:.2f}")
print(f"得分: {score}")
if __name__ == "__main__":
main()