package com.guwan.backend.MergeStrategy;

import jakarta.persistence.EntityManager;
import jakarta.persistence.criteria.*;
import org.springframework.data.jpa.domain.Specification;
import org.springframework.data.jpa.repository.query.QueryUtils;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
import com.guwan.backend.MergeStrategy.LambdaUtils.PropertyFunction;

/**
 * Fluent builder for multi-entity queries on top of the JPA Criteria API.
 *
 * <p>Every "joined" entity is added to the query as an additional query root and tied to the
 * root entity with an equality predicate ({@code root.sourceField = joined.targetField}), so no
 * mapped association between the entities is required. The builder can also restrict the
 * projection to specific fields, optionally under an alias, and map result rows to DTOs through
 * a constructor call.
 *
 * <p>All registrations (joins, conditions, selections) keep their insertion order, so the shape
 * of the generated query and the column order of the result rows are reproducible.
 *
 * <p>Instances are mutable and carry no synchronization.
 *
 * @param <T> the root entity type
 */
public class JoinBuilder<T> {

    private final Class<T> rootClass;

    /** Joined entity class -> join columns; insertion-ordered. */
    private final Map<Class<?>, JoinInfo> joins = new LinkedHashMap<>();

    /** Entity class -> conditions on that entity; insertion-ordered. */
    private final Map<Class<?>, List<Condition>> conditions = new LinkedHashMap<>();

    /** Entity class -> plain (un-aliased) selected fields; {@code null} until something is selected. */
    private Map<Class<?>, String[]> selectedFields;

    /** Alias -> aliased selection; when non-empty it takes precedence over {@link #selectedFields}. */
    private final Map<String, SelectionInfo> namedSelections = new LinkedHashMap<>();

    private JoinBuilder(Class<T> rootClass) {
        this.rootClass = rootClass;
    }

    /**
     * Starts a builder whose root entity is {@code rootClass}.
     *
     * @param rootClass the root entity class
     * @param <T> the root entity type
     * @return a new builder
     * @throws NullPointerException if {@code rootClass} is {@code null}
     */
    public static <T> JoinBuilder<T> from(Class<T> rootClass) {
        return new JoinBuilder<>(Objects.requireNonNull(rootClass, "rootClass"));
    }

    /**
     * Registers an entity to be joined to the root entity.
     *
     * <p>Registering the same {@code joinClass} again replaces the previous registration.
     *
     * @param joinClass the entity class to join
     * @param sourceField attribute name on the root entity
     * @param targetField attribute name on the joined entity that must equal {@code sourceField}
     * @param <J> the joined entity type
     * @return this builder
     * @throws NullPointerException if {@code joinClass} is {@code null}
     */
    public <J> JoinBuilder<T> join(Class<J> joinClass, String sourceField, String targetField) {
        Objects.requireNonNull(joinClass, "joinClass");
        joins.put(joinClass, new JoinInfo(sourceField, targetField));
        return this;
    }

    /**
     * Registers an entity to be joined, naming the join attributes through getter references.
     *
     * @param joinClass the entity class to join
     * @param sourceFieldGetter getter reference for the attribute on the root entity
     * @param targetFieldGetter getter reference for the attribute on the joined entity
     * @param <J> the joined entity type
     * @return this builder
     */
    public <J> JoinBuilder<T> join(Class<J> joinClass,
                                   PropertyFunction<T, ?> sourceFieldGetter,
                                   PropertyFunction<J, ?> targetFieldGetter) {
        String sourceField = LambdaUtils.getPropertyName(sourceFieldGetter);
        String targetField = LambdaUtils.getPropertyName(targetFieldGetter);
        return join(joinClass, sourceField, targetField);
    }

    /**
     * Adds an equality condition on an attribute of the root entity or of a joined entity.
     *
     * <p>The entity must be the root or be registered through {@code join(...)} by the time the
     * query is built; otherwise building the query fails instead of dropping the condition.
     *
     * @param entityClass the entity the attribute belongs to
     * @param fieldName the attribute name
     * @param value the value the attribute must equal
     * @param <E> the entity type
     * @param <V> the value type
     * @return this builder
     * @throws NullPointerException if {@code entityClass} is {@code null}
     */
    public <E, V> JoinBuilder<T> equal(Class<E> entityClass, String fieldName, V value) {
        Objects.requireNonNull(entityClass, "entityClass");
        conditions.computeIfAbsent(entityClass, key -> new ArrayList<>())
                .add(new Condition(fieldName, value, Operator.EQUAL));
        return this;
    }

    /**
     * Adds an equality condition, naming the attribute through a getter reference.
     *
     * @param entityClass the entity the attribute belongs to
     * @param fieldGetter getter reference for the attribute
     * @param value the value the attribute must equal
     * @param <E> the entity type
     * @param <V> the value type
     * @return this builder
     */
    public <E, V> JoinBuilder<T> equal(Class<E> entityClass, PropertyFunction<E, V> fieldGetter, V value) {
        String fieldName = LambdaUtils.getPropertyName(fieldGetter);
        return equal(entityClass, fieldName, value);
    }

    /**
     * Replaces the plain field selection with the given mapping.
     *
     * <p>The map is copied, so later {@code select(...)} / {@code selectAs(...)} calls never
     * modify the caller's map and immutable maps are accepted. Passing {@code null} clears the
     * plain field selection.
     *
     * @param selectedFields attribute names to return, per entity class
     * @return this builder
     */
    public JoinBuilder<T> selectFields(Map<Class<?>, String[]> selectedFields) {
        this.selectedFields = selectedFields == null ? null : new LinkedHashMap<>(selectedFields);
        return this;
    }

    /**
     * Selects attributes of one entity through getter references, replacing any attributes
     * previously selected for that entity.
     *
     * @param entityClass the entity the attributes belong to
     * @param fieldGetters getter references for the attributes to return
     * @param <E> the entity type
     * @return this builder
     */
    @SafeVarargs
    public final <E> JoinBuilder<T> select(Class<E> entityClass, PropertyFunction<E, ?>... fieldGetters) {
        String[] fieldNames = new String[fieldGetters.length];
        for (int i = 0; i < fieldGetters.length; i++) {
            fieldNames[i] = LambdaUtils.getPropertyName(fieldGetters[i]);
        }
        mutableSelectedFields().put(entityClass, fieldNames);
        return this;
    }

    /**
     * Selects one attribute under an alias, naming the attribute through a getter reference.
     *
     * @param alias alias of the selected column
     * @param entityClass the entity the attribute belongs to
     * @param fieldGetter getter reference for the attribute
     * @param <E> the entity type
     * @param <R> the attribute type
     * @return this builder
     */
    public <E, R> JoinBuilder<T> selectAs(String alias, Class<E> entityClass, PropertyFunction<E, R> fieldGetter) {
        return selectAs(alias, entityClass, LambdaUtils.getPropertyName(fieldGetter));
    }

    /**
     * Selects one attribute under an alias.
     *
     * <p>Aliased selections are emitted in registration order; registering an alias again
     * replaces its attribute but keeps its original position.
     *
     * @param alias alias of the selected column
     * @param entityClass the entity the attribute belongs to
     * @param fieldName the attribute name
     * @param <E> the entity type
     * @return this builder
     * @throws NullPointerException if {@code entityClass} is {@code null}
     */
    public <E> JoinBuilder<T> selectAs(String alias, Class<E> entityClass, String fieldName) {
        registerNamedSelection(alias, Objects.requireNonNull(entityClass, "entityClass"), fieldName);
        return this;
    }

    /**
     * Registers aliased selections from DTO-field-to-entity-field mappings; each DTO field name
     * becomes the alias of the mapped entity attribute.
     *
     * @param dtoClass the DTO class (used only to tie the mappings to one DTO type)
     * @param mappings the field mappings
     * @param <D> the DTO type
     * @return this builder
     * @throws IllegalArgumentException if a mapping carries no entity class, because the column
     *         could then not be attached to any query root
     */
    @SafeVarargs
    public final <D> JoinBuilder<T> selectDto(Class<D> dtoClass, FieldMapping<D, ?>... mappings) {
        for (FieldMapping<D, ?> mapping : mappings) {
            Class<?> entityClass = mapping.getEntityClass();
            if (entityClass == null) {
                // FieldMapping.of resolves the class best-effort; nothing later can fill it in.
                throw new IllegalArgumentException(
                        "Could not resolve the entity class for DTO field '" + mapping.getDtoFieldName() + "'");
            }
            registerNamedSelection(mapping.getDtoFieldName(), entityClass, mapping.getEntityFieldName());
        }
        return this;
    }

    /**
     * Builds a {@link Specification} that applies the registered joins and conditions.
     *
     * <p>Besides returning the predicate, the specification also sets the projection of the
     * query it is applied to (through {@code multiselect}) when selections are registered.
     *
     * @return the specification
     */
    public Specification<T> build() {
        return (root, query, criteriaBuilder) -> {
            Map<Class<?>, From<?, ?>> fromMap = createJoins(root, query);
            List<Predicate> predicates = applyAllConditions(root, fromMap, criteriaBuilder);
            configureSelections(root, fromMap, query);
            return criteriaBuilder.and(predicates.toArray(new Predicate[0]));
        };
    }

    /**
     * Builds and runs the query.
     *
     * @param entityManager the entity manager to run the query with
     * @return one {@code Object[]} per result row, columns in selection order
     * @throws IllegalStateException if a condition or selection references an entity that is
     *         neither the root nor joined
     */
    public List<Object[]> execute(EntityManager entityManager) {
        CriteriaBuilder criteriaBuilder = entityManager.getCriteriaBuilder();
        CriteriaQuery<Object[]> criteriaQuery = criteriaBuilder.createQuery(Object[].class);
        Root<T> root = criteriaQuery.from(rootClass);

        Map<Class<?>, From<?, ?>> fromMap = createJoins(root, criteriaQuery);

        // Join predicates and user conditions all end up in one conjunction.
        List<Predicate> predicates = applyAllConditions(root, fromMap, criteriaBuilder);
        if (!predicates.isEmpty()) {
            criteriaQuery.where(criteriaBuilder.and(predicates.toArray(new Predicate[0])));
        }

        configureSelections(root, fromMap, criteriaQuery);
        return entityManager.createQuery(criteriaQuery).getResultList();
    }

    /**
     * Adds one extra query root per joined entity (instead of a Criteria {@code join}) and
     * returns the lookup from entity class to its query root, including the root entity.
     */
    private Map<Class<?>, From<?, ?>> createJoins(Root<T> root, CriteriaQuery<?> query) {
        Map<Class<?>, From<?, ?>> fromMap = new LinkedHashMap<>();
        fromMap.put(rootClass, root);
        for (Class<?> joinClass : joins.keySet()) {
            fromMap.put(joinClass, query.from(joinClass));
        }
        return fromMap;
    }

    /**
     * Produces the join predicates followed by the user conditions.
     *
     * @throws IllegalStateException if a condition targets an entity missing from {@code fromMap}
     */
    private List<Predicate> applyAllConditions(Root<T> root,
                                               Map<Class<?>, From<?, ?>> fromMap,
                                               CriteriaBuilder criteriaBuilder) {
        List<Predicate> predicates = new ArrayList<>();

        // Join predicates first: root.sourceField = joined.targetField.
        for (Map.Entry<Class<?>, JoinInfo> entry : joins.entrySet()) {
            From<?, ?> joinFrom = fromMap.get(entry.getKey());
            if (joinFrom == null) {
                continue;
            }
            JoinInfo joinInfo = entry.getValue();
            predicates.add(criteriaBuilder.equal(
                    root.get(joinInfo.sourceField),
                    joinFrom.get(joinInfo.targetField)));
        }

        // Then the user conditions. A condition on an unknown entity must not be dropped
        // silently, because the query would then return unfiltered rows.
        for (Map.Entry<Class<?>, List<Condition>> entry : conditions.entrySet()) {
            From<?, ?> from = requireFrom(fromMap, entry.getKey(), "a condition");
            for (Condition condition : entry.getValue()) {
                predicates.add(applyCondition(condition, from, criteriaBuilder));
            }
        }
        return predicates;
    }

    /**
     * Runs the query and maps every row to a DTO.
     *
     * @param entityManager the entity manager to run the query with
     * @param dtoClass the DTO class
     * @param <D> the DTO type
     * @return the mapped DTOs
     */
    public <D> List<D> executeAndMap(EntityManager entityManager, Class<D> dtoClass) {
        List<Object[]> results = execute(entityManager);
        return mapToDto(results, dtoClass);
    }

    /**
     * Maps result rows to DTOs by passing each row's columns, in order, to a DTO constructor.
     *
     * <p>The constructor must take as many parameters as the builder selects columns. A
     * constructor whose parameters are all {@code Object} is preferred; otherwise the single
     * declared constructor with that parameter count is used.
     *
     * @param results the result rows
     * @param dtoClass the DTO class
     * @param <D> the DTO type
     * @return the mapped DTOs
     * @throws RuntimeException if no suitable constructor exists or a row cannot be passed to it
     */
    public <D> List<D> mapToDto(List<Object[]> results, Class<D> dtoClass) {
        try {
            Constructor<D> constructor = resolveDtoConstructor(dtoClass, projectedColumnCount());
            return results.stream()
                    .map(row -> instantiate(constructor, row))
                    .collect(Collectors.toList());
        } catch (Exception e) {
            throw new RuntimeException("Failed to create DTO mapper: " + e.getMessage(), e);
        }
    }

    /** Number of columns the configured projection yields per row. */
    private int projectedColumnCount() {
        if (!namedSelections.isEmpty()) {
            return namedSelections.size();
        }
        int count = 0;
        if (selectedFields != null) {
            for (String[] fields : selectedFields.values()) {
                if (fields != null) {
                    count += fields.length;
                }
            }
        }
        return count;
    }

    /** Finds the DTO constructor to call for rows with {@code arity} columns. */
    private static <D> Constructor<D> resolveDtoConstructor(Class<D> dtoClass, int arity)
            throws NoSuchMethodException {
        Class<?>[] objectParams = new Class<?>[arity];
        Arrays.fill(objectParams, Object.class);
        try {
            return dtoClass.getDeclaredConstructor(objectParams);
        } catch (NoSuchMethodException noAllObjectConstructor) {
            // Fall back to a typed constructor, but only when the choice is unambiguous.
            Constructor<D> match = null;
            for (Constructor<?> candidate : dtoClass.getDeclaredConstructors()) {
                if (candidate.getParameterCount() != arity) {
                    continue;
                }
                if (match != null) {
                    throw new NoSuchMethodException(dtoClass.getName() + " declares several constructors with "
                            + arity + " parameters and none taking only Object parameters");
                }
                @SuppressWarnings("unchecked") // constructors of Class<D> construct D
                Constructor<D> typed = (Constructor<D>) candidate;
                match = typed;
            }
            if (match == null) {
                throw noAllObjectConstructor;
            }
            return match;
        }
    }

    /** Calls the constructor with the row's columns as arguments. */
    private static <D> D instantiate(Constructor<D> constructor, Object[] row) {
        try {
            return constructor.newInstance((Object[]) row);
        } catch (Exception e) {
            throw new RuntimeException("Failed to map result to DTO: " + e.getMessage(), e);
        }
    }

    /**
     * Sets the query projection: the aliased selections when any exist, otherwise the plain
     * field selection (root entity fields first, then the other entities in registration
     * order). Leaves the projection untouched when nothing is selected.
     *
     * @throws IllegalStateException if a selected entity is missing from {@code fromMap};
     *         skipping it would shift the remaining columns
     */
    private void configureSelections(Root<T> root, Map<Class<?>, From<?, ?>> fromMap, CriteriaQuery<?> query) {
        if (!namedSelections.isEmpty()) {
            List<Selection<?>> aliased = new ArrayList<>(namedSelections.size());
            for (Map.Entry<String, SelectionInfo> entry : namedSelections.entrySet()) {
                String alias = entry.getKey();
                SelectionInfo info = entry.getValue();
                From<?, ?> from = requireFrom(fromMap, info.entityClass, "selection '" + alias + "'");
                Path<?> path = from.get(info.fieldName);
                aliased.add(path.alias(alias));
            }
            query.multiselect(aliased);
            return;
        }

        if (selectedFields == null || selectedFields.isEmpty()) {
            return;
        }

        List<Selection<?>> selections = new ArrayList<>();

        String[] rootFields = selectedFields.get(rootClass);
        if (rootFields != null) {
            for (String field : rootFields) {
                selections.add(root.get(field));
            }
        }

        for (Map.Entry<Class<?>, String[]> entry : selectedFields.entrySet()) {
            Class<?> entityClass = entry.getKey();
            String[] fields = entry.getValue();
            if (entityClass == rootClass || fields == null || fields.length == 0) {
                continue;
            }
            From<?, ?> from = requireFrom(fromMap, entityClass, "the field selection");
            for (String field : fields) {
                selections.add(from.get(field));
            }
        }

        if (!selections.isEmpty()) {
            query.multiselect(selections);
        }
    }

    /** Translates one condition into a Criteria predicate on the given query root. */
    private Predicate applyCondition(Condition condition, From<?, ?> from, CriteriaBuilder criteriaBuilder) {
        switch (condition.operator) {
            case EQUAL:
                return criteriaBuilder.equal(from.get(condition.fieldName), condition.value);
            // Further operators (LIKE, IN, GREATER_THAN, ...) can be added here.
            default:
                throw new UnsupportedOperationException("Unsupported operator: " + condition.operator);
        }
    }

    /** Returns the plain field selection map, creating it on first use. */
    private Map<Class<?>, String[]> mutableSelectedFields() {
        if (selectedFields == null) {
            selectedFields = new LinkedHashMap<>();
        }
        return selectedFields;
    }

    /** Records an aliased selection and makes sure its entity has a plain-selection entry. */
    private void registerNamedSelection(String alias, Class<?> entityClass, String fieldName) {
        mutableSelectedFields().putIfAbsent(entityClass, new String[0]);
        namedSelections.put(alias, new SelectionInfo(entityClass, fieldName));
    }

    /** Looks up the query root of an entity, failing loudly when the entity is not in the query. */
    private static From<?, ?> requireFrom(Map<Class<?>, From<?, ?>> fromMap, Class<?> entityClass, String usage) {
        From<?, ?> from = fromMap.get(entityClass);
        if (from == null) {
            String entityName = entityClass == null ? "null" : entityClass.getName();
            throw new IllegalStateException("Entity " + entityName + " is referenced by " + usage
                    + " but is neither the root entity nor registered via join(...)");
        }
        return from;
    }

    /**
     * Pairs a DTO field with the entity attribute that feeds it.
     *
     * @param <D> the DTO type
     * @param <E> the entity type
     */
    public static class FieldMapping<D, E> {
        private final String dtoFieldName;
        private final String entityFieldName;
        private final Class<E> entityClass;

        private FieldMapping(String dtoFieldName, String entityFieldName, Class<E> entityClass) {
            this.dtoFieldName = dtoFieldName;
            this.entityFieldName = entityFieldName;
            this.entityClass = entityClass;
        }

        /**
         * Creates a mapping from two getter references.
         *
         * <p>The entity class is read best-effort from the serialized form of
         * {@code entityFieldGetter}; when that fails, {@link #getEntityClass()} returns
         * {@code null}.
         *
         * @param dtoFieldGetter getter reference for the DTO field
         * @param entityFieldGetter getter reference for the entity attribute
         * @param <D> the DTO type
         * @param <E> the entity type
         * @return the mapping
         */
        public static <D, E> FieldMapping<D, E> of(PropertyFunction<D, ?> dtoFieldGetter,
                                                   PropertyFunction<E, ?> entityFieldGetter) {
            String dtoFieldName = LambdaUtils.getPropertyName(dtoFieldGetter);
            String entityFieldName = LambdaUtils.getPropertyName(entityFieldGetter);

            Class<E> entityClass = null;
            try {
                Method writeReplace = entityFieldGetter.getClass().getDeclaredMethod("writeReplace");
                writeReplace.setAccessible(true);
                Object serialized = writeReplace.invoke(entityFieldGetter);
                // NOTE(review): getImplClass names the class recorded as owner of the referenced
                // method; confirm this is the entity class when the getter is inherited.
                String implClass = ((java.lang.invoke.SerializedLambda) serialized).getImplClass();
                @SuppressWarnings("unchecked") // the getter reference is declared on E
                Class<E> resolved = (Class<E>) Class.forName(implClass.replace('/', '.'));
                entityClass = resolved;
            } catch (Exception ignored) {
                // Deliberately best-effort: the mapping is still created, with a null entity
                // class that selectDto rejects.
            }

            return new FieldMapping<>(dtoFieldName, entityFieldName, entityClass);
        }

        /** @return the DTO field name, used as the selection alias */
        public String getDtoFieldName() {
            return dtoFieldName;
        }

        /** @return the entity attribute name */
        public String getEntityFieldName() {
            return entityFieldName;
        }

        /** @return the entity class, or {@code null} when it could not be resolved */
        public Class<E> getEntityClass() {
            return entityClass;
        }
    }

    /** Attribute pair that ties a joined entity to the root entity. */
    private static final class JoinInfo {
        private final String sourceField;
        private final String targetField;

        JoinInfo(String sourceField, String targetField) {
            this.sourceField = sourceField;
            this.targetField = targetField;
        }
    }

    /** One aliased selection: the entity and the attribute to read from it. */
    private static final class SelectionInfo {
        private final Class<?> entityClass;
        private final String fieldName;

        SelectionInfo(Class<?> entityClass, String fieldName) {
            this.entityClass = entityClass;
            this.fieldName = fieldName;
        }
    }

    /** Supported comparison operators. */
    private enum Operator {
        EQUAL
        // Further operators (LIKE, IN, GREATER_THAN, ...) can be added here.
    }

    /** One comparison of an attribute against a value. */
    private static final class Condition {
        private final String fieldName;
        private final Object value;
        private final Operator operator;

        Condition(String fieldName, Object value, Operator operator) {
            this.fieldName = fieldName;
            this.value = value;
            this.operator = operator;
        }
    }
}