package com.guwan.backend.MergeStrategy; import jakarta.persistence.EntityManager; import jakarta.persistence.criteria.*; import org.springframework.data.jpa.domain.Specification; import org.springframework.data.jpa.repository.query.QueryUtils; import java.lang.reflect.Constructor; import java.util.*; import java.util.function.Function; import java.util.stream.Collectors; /** * 自定义连表查询构建器 * 用于构建复杂的多表连接查询,支持指定返回字段和条件 */ public class JoinBuilder { private final Class rootClass; private final Map, JoinInfo> joins = new HashMap<>(); private final Map, List> conditions = new HashMap<>(); private Map, String[]> selectedFields; private Map namedSelections = new LinkedHashMap<>(); private JoinBuilder(Class rootClass) { this.rootClass = rootClass; } public static JoinBuilder from(Class rootClass) { return new JoinBuilder<>(rootClass); } /** * 添加连接表 * @param joinClass 要连接的表实体类 * @param sourceField 主表的外键字段 * @param targetField 连接表的目标字段 */ public JoinBuilder join(Class joinClass, String sourceField, String targetField) { joins.put(joinClass, new JoinInfo(sourceField, targetField)); return this; } /** * 添加连接表(使用方法引用) * @param joinClass 要连接的表实体类 * @param sourceFieldGetter 主表外键字段的getter方法引用 * @param targetFieldGetter 连接表目标字段的getter方法引用 */ public JoinBuilder join(Class joinClass, Function sourceFieldGetter, Function targetFieldGetter) { String sourceField = getFieldNameFromGetter(sourceFieldGetter); String targetField = getFieldNameFromGetter(targetFieldGetter); return join(joinClass, sourceField, targetField); } /** * 添加等值条件 * @param entityClass 实体类 * @param fieldName 字段名 * @param value 字段值 */ public JoinBuilder equal(Class entityClass, String fieldName, V value) { if (!conditions.containsKey(entityClass)) { conditions.put(entityClass, new ArrayList<>()); } conditions.get(entityClass).add(new Condition(fieldName, value, Operator.EQUAL)); return this; } /** * 添加等值条件(使用方法引用) * @param entityClass 实体类 * @param fieldGetter 字段的getter方法引用 * @param value 字段值 */ public JoinBuilder equal(Class entityClass, Function fieldGetter, V value) { String fieldName = getFieldNameFromGetter(fieldGetter); return equal(entityClass, fieldName, value); } /** * 选择返回的字段 * @param selectedFields 每个实体类需要返回的字段映射 */ public JoinBuilder selectFields(Map, String[]> selectedFields) { this.selectedFields = selectedFields; return this; } /** * 选择返回的字段(使用方法引用) * @param fieldGetters 方法引用列表 */ @SafeVarargs public final JoinBuilder select(Class entityClass, Function... fieldGetters) { if (selectedFields == null) { selectedFields = new HashMap<>(); } String[] fieldNames = Arrays.stream(fieldGetters) .map(this::getFieldNameFromGetter) .toArray(String[]::new); selectedFields.put(entityClass, fieldNames); return this; } /** * 添加一个具有别名的字段选择(优先级高于select方法) * @param alias 字段别名 * @param entityClass 实体类 * @param fieldGetter 字段的getter方法引用 */ public JoinBuilder selectAs(String alias, Class entityClass, Function fieldGetter) { String fieldName = getFieldNameFromGetter(fieldGetter); // 存储字段信息,稍后在configureSelections处理 if (selectedFields == null) { selectedFields = new HashMap<>(); } if (!selectedFields.containsKey(entityClass)) { selectedFields.put(entityClass, new String[0]); } namedSelections.put(alias, new SelectionInfo(entityClass, fieldName)); return this; } /** * 添加一个具有别名的字段选择(字符串版本) * @param alias 字段别名 * @param entityClass 实体类 * @param fieldName 字段名称 */ public JoinBuilder selectAs(String alias, Class entityClass, String fieldName) { // 存储字段信息,稍后在configureSelections处理 if (selectedFields == null) { selectedFields = new HashMap<>(); } if (!selectedFields.containsKey(entityClass)) { selectedFields.put(entityClass, new String[0]); } namedSelections.put(alias, new SelectionInfo(entityClass, fieldName)); return this; } /** * 设置实体间字段映射关系的别名,以便对应DTO字段 * @param dtoClass DTO类 * @param mappings 字段映射 */ @SafeVarargs public final JoinBuilder selectDto(Class dtoClass, FieldMapping... mappings) { for (FieldMapping mapping : mappings) { String dtoFieldName = getFieldNameFromGetter(mapping.getDtoFieldGetter()); String entityFieldName = getFieldNameFromGetter(mapping.getEntityFieldGetter()); Class entityClass = mapping.getEntityClass(); // 存储字段信息,稍后在configureSelections处理 if (selectedFields == null) { selectedFields = new HashMap<>(); } if (!selectedFields.containsKey(entityClass)) { selectedFields.put(entityClass, new String[0]); } namedSelections.put(dtoFieldName, new SelectionInfo(entityClass, entityFieldName)); } return this; } /** * 构建规范查询 * @return Specification 查询规范 */ public Specification build() { return (root, query, criteriaBuilder) -> { // 创建连接关系 Map, From> joinMap = createJoins(root, query); // 应用连表条件和其他条件 List predicates = applyAllConditions(root, joinMap, criteriaBuilder); // 设置查询结果投影 configureSelections(root, joinMap, query); return criteriaBuilder.and(predicates.toArray(new Predicate[0])); }; } /** * 执行查询并返回结果 * @param entityManager EntityManager实例 * @return 查询结果列表 */ public List execute(EntityManager entityManager) { // 创建CriteriaBuilder CriteriaBuilder criteriaBuilder = entityManager.getCriteriaBuilder(); // 创建CriteriaQuery CriteriaQuery criteriaQuery = criteriaBuilder.createQuery(Object[].class); // 设置根实体 Root root = criteriaQuery.from(rootClass); // 创建连接关系 Map, From> fromMap = createJoins(root, criteriaQuery); // 应用所有条件,包括表之间的连接条件和额外查询条件 List predicates = applyAllConditions(root, fromMap, criteriaBuilder); // 添加条件到查询 if (!predicates.isEmpty()) { criteriaQuery.where(criteriaBuilder.and(predicates.toArray(new Predicate[0]))); } // 设置查询结果投影 configureSelections(root, fromMap, criteriaQuery); // 执行查询并返回结果 return entityManager.createQuery(criteriaQuery).getResultList(); } /** * 创建连接 - 修改为使用from而不是join */ private Map, From> createJoins(Root root, CriteriaQuery query) { Map, From> fromMap = new HashMap<>(); fromMap.put(rootClass, root); // 对于每个需要连接的表,创建一个单独的From for (Map.Entry, JoinInfo> entry : joins.entrySet()) { Class joinClass = entry.getKey(); From joinRoot = query.from(joinClass); fromMap.put(joinClass, joinRoot); } return fromMap; } /** * 应用所有条件,包括表之间的连接条件和额外查询条件 */ private List applyAllConditions(Root root, Map, From> fromMap, CriteriaBuilder criteriaBuilder) { List predicates = new ArrayList<>(); // 首先添加表连接条件 for (Map.Entry, JoinInfo> entry : joins.entrySet()) { Class joinClass = entry.getKey(); JoinInfo joinInfo = entry.getValue(); if (fromMap.containsKey(joinClass)) { From joinFrom = fromMap.get(joinClass); Predicate joinCondition = criteriaBuilder.equal( root.get(joinInfo.sourceField), joinFrom.get(joinInfo.targetField) ); predicates.add(joinCondition); } } // 然后添加普通查询条件 for (Class entityClass : conditions.keySet()) { for (Condition condition : conditions.get(entityClass)) { if (fromMap.containsKey(entityClass)) { From from = fromMap.get(entityClass); predicates.add(applyCondition(condition, from, criteriaBuilder)); } } } return predicates; } /** * 执行查询并直接映射到DTO对象 * @param entityManager EntityManager实例 * @param dtoClass DTO类 * @param DTO类型 * @return DTO对象列表 */ public List executeAndMap(EntityManager entityManager, Class dtoClass) { List results = execute(entityManager); return mapToDto(results, dtoClass); } /** * 将查询结果映射到DTO对象 * @param results 查询结果 * @param dtoClass DTO类 * @param DTO类型 * @return DTO对象列表 */ public List mapToDto(List results, Class dtoClass) { try { // 获取所有字段名(顺序与查询结果一致) List fieldNames = new ArrayList<>(namedSelections.keySet()); // 获取构造函数 Class[] paramTypes = new Class[fieldNames.size()]; Arrays.fill(paramTypes, Object.class); Constructor constructor = dtoClass.getDeclaredConstructor(paramTypes); // 映射结果 return results.stream() .map(row -> { try { return constructor.newInstance(row); } catch (Exception e) { throw new RuntimeException("Failed to map result to DTO: " + e.getMessage(), e); } }) .collect(Collectors.toList()); } catch (Exception e) { throw new RuntimeException("Failed to create DTO mapper: " + e.getMessage(), e); } } /** * 配置查询结果投影 */ private void configureSelections(Root root, Map, From> fromMap, CriteriaQuery query) { if (!namedSelections.isEmpty()) { // 处理命名选择 List> selections = new ArrayList<>(); for (Map.Entry entry : namedSelections.entrySet()) { String alias = entry.getKey(); Object value = entry.getValue(); if (value instanceof SelectionInfo) { SelectionInfo info = (SelectionInfo) value; Path path; if (fromMap.containsKey(info.entityClass)) { path = fromMap.get(info.entityClass).get(info.fieldName); selections.add(path.alias(alias)); } } } if (!selections.isEmpty()) { query.multiselect(selections); return; } } // 如果没有命名选择或处理失败,回退到原来的处理方式 if (selectedFields != null && !selectedFields.isEmpty()) { List> selections = new ArrayList<>(); // 添加根实体的选择字段 if (selectedFields.containsKey(rootClass)) { for (String field : selectedFields.get(rootClass)) { selections.add(root.get(field)); } } // 添加关联实体的选择字段 for (Class entityClass : selectedFields.keySet()) { if (entityClass != rootClass && fromMap.containsKey(entityClass)) { From from = fromMap.get(entityClass); for (String field : selectedFields.get(entityClass)) { selections.add(from.get(field)); } } } if (!selections.isEmpty()) { query.multiselect(selections); } } } private Predicate applyCondition(Condition condition, From from, CriteriaBuilder criteriaBuilder) { switch (condition.operator) { case EQUAL: return criteriaBuilder.equal(from.get(condition.fieldName), condition.value); // 可以添加更多操作符支持,如LIKE, IN, GREATER_THAN等 default: throw new UnsupportedOperationException("Unsupported operator: " + condition.operator); } } /** * 从getter方法引用中提取属性名 */ private String getFieldNameFromGetter(Function getter) { // 直接尝试从方法引用的toString中获取 try { String methodRef = getter.toString(); if (methodRef.contains("::")) { // 处理方法引用格式: com.example.Entity::getName String methodName = methodRef.substring(methodRef.lastIndexOf("::") + 2); return extractFieldNameFromMethod(methodName); } else if (methodRef.contains("->")) { // 处理Lambda表达式格式: (Entity e) -> e.getName() String methodCall = methodRef.substring(methodRef.indexOf("->") + 2).trim(); String methodName = methodCall.substring(methodCall.lastIndexOf(".") + 1); if (methodName.endsWith(")")) { methodName = methodName.substring(0, methodName.indexOf("(")); } return extractFieldNameFromMethod(methodName); } } catch (Exception ignored) { // 忽略异常,尝试下一种方法 } try { String getterName = getter.getClass().getName(); if (getterName.contains("$$Lambda$")) { // 处理Lambda表达式 if (getterName.contains("get")) { // 尝试从名称中提取字段名 int getIndex = getterName.lastIndexOf("get"); if (getIndex >= 0 && getIndex + 3 < getterName.length()) { String fieldName = getterName.substring(getIndex + 3); if (fieldName.indexOf('$') > 0) { fieldName = fieldName.substring(0, fieldName.indexOf('$')); } return fieldName.substring(0, 1).toLowerCase() + fieldName.substring(1); } } } } catch (Exception ignored) { // 忽略异常,fallback到基于字符串的方法 } // Fallback到基于字符串的方法 String methodReference = getter.toString(); return getFieldNameFromGetterString(methodReference); } /** * 从getter方法字符串中提取属性名 */ private String getFieldNameFromGetterString(String methodReference) { // 方法引用的格式通常是:类名::get字段名 或 Lambda表达式 if (methodReference.contains("::")) { // 处理方法引用格式: com.example.Entity::getName String methodName = methodReference.substring(methodReference.lastIndexOf("::") + 2); return extractFieldNameFromMethod(methodName); } else if (methodReference.contains("->")) { // 处理Lambda表达式格式: (Entity e) -> e.getName() String methodCall = methodReference.substring(methodReference.indexOf("->") + 2).trim(); String methodName = methodCall.substring(methodCall.lastIndexOf(".") + 1); if (methodName.endsWith(")")) { methodName = methodName.substring(0, methodName.indexOf("(")); } return extractFieldNameFromMethod(methodName); } // 无法识别的格式,返回原始字符串 return methodReference; } /** * 从方法名中提取字段名 */ private String extractFieldNameFromMethod(String methodName) { if (methodName.startsWith("get") && methodName.length() > 3) { // getName -> name (首字母小写) String fieldName = methodName.substring(3); return fieldName.substring(0, 1).toLowerCase() + fieldName.substring(1); } else if (methodName.startsWith("is") && methodName.length() > 2) { // isActive -> active (首字母小写) String fieldName = methodName.substring(2); return fieldName.substring(0, 1).toLowerCase() + fieldName.substring(1); } return methodName; } /** * 首字母大写 */ private String capitalize(String str) { if (str == null || str.isEmpty()) { return str; } return str.substring(0, 1).toUpperCase() + str.substring(1); } private enum Operator { EQUAL // 可以添加更多操作符支持,如LIKE, IN, GREATER_THAN等 } private static class Condition { private final String fieldName; private final Object value; private final Operator operator; Condition(String fieldName, Object value, Operator operator) { this.fieldName = fieldName; this.value = value; this.operator = operator; } } /** * 连接信息 */ private static class JoinInfo { private final String sourceField; private final String targetField; JoinInfo(String sourceField, String targetField) { this.sourceField = sourceField; this.targetField = targetField; } } /** * 选择字段信息 */ private static class SelectionInfo { private final Class entityClass; private final String fieldName; SelectionInfo(Class entityClass, String fieldName) { this.entityClass = entityClass; this.fieldName = fieldName; } } /** * 字段映射定义,用于DTO映射 */ public static class FieldMapping { private final Function dtoFieldGetter; private final Function entityFieldGetter; private final Class entityClass; private FieldMapping(Function dtoFieldGetter, Function entityFieldGetter, Class entityClass) { this.dtoFieldGetter = dtoFieldGetter; this.entityFieldGetter = entityFieldGetter; this.entityClass = entityClass; } public static FieldMapping of(Function dtoFieldGetter, Function entityFieldGetter) { // 通过反射或其他方式获取entityClass Class entityClass = null; try { // 尝试从方法引用中提取类信息 String methodRef = entityFieldGetter.toString(); if (methodRef.contains("::")) { String className = methodRef.substring(0, methodRef.indexOf("::")); entityClass = (Class) Class.forName(className); } } catch (Exception e) { // 忽略异常,entityClass将在运行时确定 } return new FieldMapping<>(dtoFieldGetter, entityFieldGetter, entityClass); } public Function getDtoFieldGetter() { return dtoFieldGetter; } public Function getEntityFieldGetter() { return entityFieldGetter; } public Class getEntityClass() { return entityClass; } } }