package com.guwan.backend.handler; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.guwan.backend.service.CourseService; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.handler.TextWebSocketHandler; import java.io.IOException; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; @Slf4j @Component public class CoursesWebSocketHandler extends TextWebSocketHandler { @Autowired private CourseService courseService; private static final Set sessions = ConcurrentHashMap.newKeySet(); @Override public void afterConnectionEstablished(WebSocketSession session) { log.info("Home page WebSocket connection established: {}", session.getId()); sessions.add(session); // 连接建立时,发送所有课程的当前学习人数 sendAllCourseCounts(session); } @Override protected void handleTextMessage(WebSocketSession session, TextMessage message) { log.info("Home page received message from {}: {}", session.getId(), message.getPayload()); try { JsonNode jsonNode = new ObjectMapper().readTree(message.getPayload()); String type = jsonNode.get("type").asText(); if ("STUDENT_COUNT_UPDATE".equals(type)) { // 转发消息给所有首页连接的客户端 broadcastMessage(message.getPayload()); } } catch (Exception e) { log.error("Error handling home page message", e); } } @Override public void afterConnectionClosed(WebSocketSession session, CloseStatus status) { log.info("Home page WebSocket connection closed: {}, status: {}", session.getId(), status); sessions.remove(session); } public void broadcastMessage(String message) { sessions.forEach(session -> { try { if (session.isOpen()) { session.sendMessage(new TextMessage(message)); } } catch (IOException e) { log.error("Error broadcasting to home page", e); } }); } private void sendAllCourseCounts(WebSocketSession session) { try { // 获取所有课程的学习人数 Map allCounts = courseService.getAllCourseCounts(); // 发送每个课程的学习人数 allCounts.forEach((courseId, count) -> { String message = String.format( "{\"type\":\"STUDENT_COUNT_UPDATE\",\"courseId\":\"%s\",\"count\":%d}", courseId, count ); try { if (session.isOpen()) { session.sendMessage(new TextMessage(message)); } } catch (IOException e) { log.error("Error sending course count", e); } }); } catch (Exception e) { log.error("Error sending all course counts", e); } } }