// Source: JPA-Test/src/main/java/com/guwan/backend/MergeStrategy/JoinBuilder.java
// (repository view: 570 lines, 21 KiB, Java; dated 2025-04-17 22:38:25 +08:00)
package com.guwan.backend.MergeStrategy;
import jakarta.persistence.EntityManager;
import jakarta.persistence.criteria.*;
import org.springframework.data.jpa.domain.Specification;
import org.springframework.data.jpa.repository.query.QueryUtils;
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
* 自定义连表查询构建器
* 用于构建复杂的多表连接查询支持指定返回字段和条件
*/
public class JoinBuilder<T> {
private final Class<T> rootClass;
private final Map<Class<?>, JoinInfo> joins = new HashMap<>();
private final Map<Class<?>, List<Condition>> conditions = new HashMap<>();
private Map<Class<?>, String[]> selectedFields;
private Map<String, Object> namedSelections = new LinkedHashMap<>();
private JoinBuilder(Class<T> rootClass) {
this.rootClass = rootClass;
}
public static <T> JoinBuilder<T> from(Class<T> rootClass) {
return new JoinBuilder<>(rootClass);
}
/**
* 添加连接表
* @param joinClass 要连接的表实体类
* @param sourceField 主表的外键字段
* @param targetField 连接表的目标字段
*/
public JoinBuilder<T> join(Class<?> joinClass, String sourceField, String targetField) {
joins.put(joinClass, new JoinInfo(sourceField, targetField));
return this;
}
/**
* 添加连接表使用方法引用
* @param joinClass 要连接的表实体类
* @param sourceFieldGetter 主表外键字段的getter方法引用
* @param targetFieldGetter 连接表目标字段的getter方法引用
*/
public <J, V> JoinBuilder<T> join(Class<J> joinClass,
Function<T, V> sourceFieldGetter,
Function<J, V> targetFieldGetter) {
String sourceField = getFieldNameFromGetter(sourceFieldGetter);
String targetField = getFieldNameFromGetter(targetFieldGetter);
return join(joinClass, sourceField, targetField);
}
/**
* 添加等值条件
* @param entityClass 实体类
* @param fieldName 字段名
* @param value 字段值
*/
public <V> JoinBuilder<T> equal(Class<?> entityClass, String fieldName, V value) {
if (!conditions.containsKey(entityClass)) {
conditions.put(entityClass, new ArrayList<>());
}
conditions.get(entityClass).add(new Condition(fieldName, value, Operator.EQUAL));
return this;
}
/**
* 添加等值条件使用方法引用
* @param entityClass 实体类
* @param fieldGetter 字段的getter方法引用
* @param value 字段值
*/
public <E, V> JoinBuilder<T> equal(Class<E> entityClass, Function<E, V> fieldGetter, V value) {
String fieldName = getFieldNameFromGetter(fieldGetter);
return equal(entityClass, fieldName, value);
}
/**
* 选择返回的字段
* @param selectedFields 每个实体类需要返回的字段映射
*/
public JoinBuilder<T> selectFields(Map<Class<?>, String[]> selectedFields) {
this.selectedFields = selectedFields;
return this;
}
/**
* 选择返回的字段使用方法引用
* @param fieldGetters 方法引用列表
*/
@SafeVarargs
public final <E> JoinBuilder<T> select(Class<E> entityClass, Function<E, ?>... fieldGetters) {
if (selectedFields == null) {
selectedFields = new HashMap<>();
}
String[] fieldNames = Arrays.stream(fieldGetters)
.map(this::getFieldNameFromGetter)
.toArray(String[]::new);
selectedFields.put(entityClass, fieldNames);
return this;
}
/**
* 添加一个具有别名的字段选择优先级高于select方法
* @param alias 字段别名
* @param entityClass 实体类
* @param fieldGetter 字段的getter方法引用
*/
public <E, V> JoinBuilder<T> selectAs(String alias, Class<E> entityClass, Function<E, V> fieldGetter) {
String fieldName = getFieldNameFromGetter(fieldGetter);
// 存储字段信息稍后在configureSelections处理
if (selectedFields == null) {
selectedFields = new HashMap<>();
}
if (!selectedFields.containsKey(entityClass)) {
selectedFields.put(entityClass, new String[0]);
}
namedSelections.put(alias, new SelectionInfo(entityClass, fieldName));
return this;
}
/**
* 添加一个具有别名的字段选择字符串版本
* @param alias 字段别名
* @param entityClass 实体类
* @param fieldName 字段名称
*/
public <E> JoinBuilder<T> selectAs(String alias, Class<E> entityClass, String fieldName) {
// 存储字段信息稍后在configureSelections处理
if (selectedFields == null) {
selectedFields = new HashMap<>();
}
if (!selectedFields.containsKey(entityClass)) {
selectedFields.put(entityClass, new String[0]);
}
namedSelections.put(alias, new SelectionInfo(entityClass, fieldName));
return this;
}
/**
* 设置实体间字段映射关系的别名以便对应DTO字段
* @param dtoClass DTO类
* @param mappings 字段映射
*/
@SafeVarargs
public final <R> JoinBuilder<T> selectDto(Class<R> dtoClass, FieldMapping<R, ?>... mappings) {
for (FieldMapping<R, ?> mapping : mappings) {
String dtoFieldName = getFieldNameFromGetter(mapping.getDtoFieldGetter());
String entityFieldName = getFieldNameFromGetter(mapping.getEntityFieldGetter());
Class<?> entityClass = mapping.getEntityClass();
// 存储字段信息稍后在configureSelections处理
if (selectedFields == null) {
selectedFields = new HashMap<>();
}
if (!selectedFields.containsKey(entityClass)) {
selectedFields.put(entityClass, new String[0]);
}
namedSelections.put(dtoFieldName, new SelectionInfo(entityClass, entityFieldName));
}
return this;
}
/**
* 构建规范查询
* @return Specification<T> 查询规范
*/
public Specification<T> build() {
return (root, query, criteriaBuilder) -> {
// 创建连接关系
Map<Class<?>, From<?, ?>> joinMap = createJoins(root, query);
// 应用连表条件和其他条件
List<Predicate> predicates = applyAllConditions(root, joinMap, criteriaBuilder);
// 设置查询结果投影
configureSelections(root, joinMap, query);
return criteriaBuilder.and(predicates.toArray(new Predicate[0]));
};
}
/**
* 执行查询并返回结果
* @param entityManager EntityManager实例
* @return 查询结果列表
*/
public List<Object[]> execute(EntityManager entityManager) {
// 创建CriteriaBuilder
CriteriaBuilder criteriaBuilder = entityManager.getCriteriaBuilder();
// 创建CriteriaQuery
CriteriaQuery<Object[]> criteriaQuery = criteriaBuilder.createQuery(Object[].class);
// 设置根实体
Root<T> root = criteriaQuery.from(rootClass);
// 创建连接关系
Map<Class<?>, From<?, ?>> fromMap = createJoins(root, criteriaQuery);
// 应用所有条件,包括表之间的连接条件和额外查询条件
List<Predicate> predicates = applyAllConditions(root, fromMap, criteriaBuilder);
// 添加条件到查询
if (!predicates.isEmpty()) {
criteriaQuery.where(criteriaBuilder.and(predicates.toArray(new Predicate[0])));
}
// 设置查询结果投影
configureSelections(root, fromMap, criteriaQuery);
// 执行查询并返回结果
return entityManager.createQuery(criteriaQuery).getResultList();
}
/**
* 创建连接 - 修改为使用from而不是join
*/
private Map<Class<?>, From<?, ?>> createJoins(Root<T> root, CriteriaQuery<?> query) {
Map<Class<?>, From<?, ?>> fromMap = new HashMap<>();
fromMap.put(rootClass, root);
// 对于每个需要连接的表创建一个单独的From
for (Map.Entry<Class<?>, JoinInfo> entry : joins.entrySet()) {
Class<?> joinClass = entry.getKey();
From<?, ?> joinRoot = query.from(joinClass);
fromMap.put(joinClass, joinRoot);
}
return fromMap;
}
/**
* 应用所有条件包括表之间的连接条件和额外查询条件
*/
private List<Predicate> applyAllConditions(Root<T> root, Map<Class<?>, From<?, ?>> fromMap, CriteriaBuilder criteriaBuilder) {
List<Predicate> predicates = new ArrayList<>();
// 首先添加表连接条件
for (Map.Entry<Class<?>, JoinInfo> entry : joins.entrySet()) {
Class<?> joinClass = entry.getKey();
JoinInfo joinInfo = entry.getValue();
if (fromMap.containsKey(joinClass)) {
From<?, ?> joinFrom = fromMap.get(joinClass);
Predicate joinCondition = criteriaBuilder.equal(
root.get(joinInfo.sourceField),
joinFrom.get(joinInfo.targetField)
);
predicates.add(joinCondition);
}
}
// 然后添加普通查询条件
for (Class<?> entityClass : conditions.keySet()) {
for (Condition condition : conditions.get(entityClass)) {
if (fromMap.containsKey(entityClass)) {
From<?, ?> from = fromMap.get(entityClass);
predicates.add(applyCondition(condition, from, criteriaBuilder));
}
}
}
return predicates;
}
/**
* 执行查询并直接映射到DTO对象
* @param entityManager EntityManager实例
* @param dtoClass DTO类
* @param <R> DTO类型
* @return DTO对象列表
*/
public <R> List<R> executeAndMap(EntityManager entityManager, Class<R> dtoClass) {
List<Object[]> results = execute(entityManager);
return mapToDto(results, dtoClass);
}
/**
* 将查询结果映射到DTO对象
* @param results 查询结果
* @param dtoClass DTO类
* @param <R> DTO类型
* @return DTO对象列表
*/
public <R> List<R> mapToDto(List<Object[]> results, Class<R> dtoClass) {
try {
// 获取所有字段名(顺序与查询结果一致)
List<String> fieldNames = new ArrayList<>(namedSelections.keySet());
// 获取构造函数
Class<?>[] paramTypes = new Class[fieldNames.size()];
Arrays.fill(paramTypes, Object.class);
Constructor<R> constructor = dtoClass.getDeclaredConstructor(paramTypes);
// 映射结果
return results.stream()
.map(row -> {
try {
return constructor.newInstance(row);
} catch (Exception e) {
throw new RuntimeException("Failed to map result to DTO: " + e.getMessage(), e);
}
})
.collect(Collectors.toList());
} catch (Exception e) {
throw new RuntimeException("Failed to create DTO mapper: " + e.getMessage(), e);
}
}
/**
* 配置查询结果投影
*/
private void configureSelections(Root<T> root, Map<Class<?>, From<?, ?>> fromMap, CriteriaQuery<?> query) {
if (!namedSelections.isEmpty()) {
// 处理命名选择
List<Selection<?>> selections = new ArrayList<>();
for (Map.Entry<String, Object> entry : namedSelections.entrySet()) {
String alias = entry.getKey();
Object value = entry.getValue();
if (value instanceof SelectionInfo) {
SelectionInfo info = (SelectionInfo) value;
Path<?> path;
if (fromMap.containsKey(info.entityClass)) {
path = fromMap.get(info.entityClass).get(info.fieldName);
selections.add(path.alias(alias));
}
}
}
if (!selections.isEmpty()) {
query.multiselect(selections);
return;
}
}
// 如果没有命名选择或处理失败,回退到原来的处理方式
if (selectedFields != null && !selectedFields.isEmpty()) {
List<Selection<?>> selections = new ArrayList<>();
// 添加根实体的选择字段
if (selectedFields.containsKey(rootClass)) {
for (String field : selectedFields.get(rootClass)) {
selections.add(root.get(field));
}
}
// 添加关联实体的选择字段
for (Class<?> entityClass : selectedFields.keySet()) {
if (entityClass != rootClass && fromMap.containsKey(entityClass)) {
From<?, ?> from = fromMap.get(entityClass);
for (String field : selectedFields.get(entityClass)) {
selections.add(from.get(field));
}
}
}
if (!selections.isEmpty()) {
query.multiselect(selections);
}
}
}
private <X> Predicate applyCondition(Condition condition, From<?, X> from, CriteriaBuilder criteriaBuilder) {
switch (condition.operator) {
case EQUAL:
return criteriaBuilder.equal(from.get(condition.fieldName), condition.value);
// 可以添加更多操作符支持如LIKE, IN, GREATER_THAN等
default:
throw new UnsupportedOperationException("Unsupported operator: " + condition.operator);
}
}
/**
* 从getter方法引用中提取属性名
*/
private <E, R> String getFieldNameFromGetter(Function<E, R> getter) {
// 直接尝试从方法引用的toString中获取
try {
String methodRef = getter.toString();
if (methodRef.contains("::")) {
// 处理方法引用格式: com.example.Entity::getName
String methodName = methodRef.substring(methodRef.lastIndexOf("::") + 2);
return extractFieldNameFromMethod(methodName);
} else if (methodRef.contains("->")) {
// 处理Lambda表达式格式: (Entity e) -> e.getName()
String methodCall = methodRef.substring(methodRef.indexOf("->") + 2).trim();
String methodName = methodCall.substring(methodCall.lastIndexOf(".") + 1);
if (methodName.endsWith(")")) {
methodName = methodName.substring(0, methodName.indexOf("("));
}
return extractFieldNameFromMethod(methodName);
}
} catch (Exception ignored) {
// 忽略异常,尝试下一种方法
}
try {
String getterName = getter.getClass().getName();
if (getterName.contains("$$Lambda$")) {
// 处理Lambda表达式
if (getterName.contains("get")) {
// 尝试从名称中提取字段名
int getIndex = getterName.lastIndexOf("get");
if (getIndex >= 0 && getIndex + 3 < getterName.length()) {
String fieldName = getterName.substring(getIndex + 3);
if (fieldName.indexOf('$') > 0) {
fieldName = fieldName.substring(0, fieldName.indexOf('$'));
}
return fieldName.substring(0, 1).toLowerCase() + fieldName.substring(1);
}
}
}
} catch (Exception ignored) {
// 忽略异常fallback到基于字符串的方法
}
// Fallback到基于字符串的方法
String methodReference = getter.toString();
return getFieldNameFromGetterString(methodReference);
}
/**
* 从getter方法字符串中提取属性名
*/
private String getFieldNameFromGetterString(String methodReference) {
// 方法引用的格式通常是:类名::get字段名 或 Lambda表达式
if (methodReference.contains("::")) {
// 处理方法引用格式: com.example.Entity::getName
String methodName = methodReference.substring(methodReference.lastIndexOf("::") + 2);
return extractFieldNameFromMethod(methodName);
} else if (methodReference.contains("->")) {
// 处理Lambda表达式格式: (Entity e) -> e.getName()
String methodCall = methodReference.substring(methodReference.indexOf("->") + 2).trim();
String methodName = methodCall.substring(methodCall.lastIndexOf(".") + 1);
if (methodName.endsWith(")")) {
methodName = methodName.substring(0, methodName.indexOf("("));
}
return extractFieldNameFromMethod(methodName);
}
// 无法识别的格式,返回原始字符串
return methodReference;
}
/**
* 从方法名中提取字段名
*/
private String extractFieldNameFromMethod(String methodName) {
if (methodName.startsWith("get") && methodName.length() > 3) {
// getName -> name (首字母小写)
String fieldName = methodName.substring(3);
return fieldName.substring(0, 1).toLowerCase() + fieldName.substring(1);
} else if (methodName.startsWith("is") && methodName.length() > 2) {
// isActive -> active (首字母小写)
String fieldName = methodName.substring(2);
return fieldName.substring(0, 1).toLowerCase() + fieldName.substring(1);
}
return methodName;
}
/**
* 首字母大写
*/
private String capitalize(String str) {
if (str == null || str.isEmpty()) {
return str;
}
return str.substring(0, 1).toUpperCase() + str.substring(1);
}
private enum Operator {
EQUAL
// 可以添加更多操作符支持如LIKE, IN, GREATER_THAN等
}
private static class Condition {
private final String fieldName;
private final Object value;
private final Operator operator;
Condition(String fieldName, Object value, Operator operator) {
this.fieldName = fieldName;
this.value = value;
this.operator = operator;
}
}
/**
* 连接信息
*/
private static class JoinInfo {
private final String sourceField;
private final String targetField;
JoinInfo(String sourceField, String targetField) {
this.sourceField = sourceField;
this.targetField = targetField;
}
}
/**
* 选择字段信息
*/
private static class SelectionInfo {
private final Class<?> entityClass;
private final String fieldName;
SelectionInfo(Class<?> entityClass, String fieldName) {
this.entityClass = entityClass;
this.fieldName = fieldName;
}
}
/**
* 字段映射定义用于DTO映射
*/
public static class FieldMapping<D, E> {
private final Function<D, ?> dtoFieldGetter;
private final Function<E, ?> entityFieldGetter;
private final Class<E> entityClass;
private FieldMapping(Function<D, ?> dtoFieldGetter, Function<E, ?> entityFieldGetter, Class<E> entityClass) {
this.dtoFieldGetter = dtoFieldGetter;
this.entityFieldGetter = entityFieldGetter;
this.entityClass = entityClass;
}
public static <D, E, V> FieldMapping<D, E> of(Function<D, V> dtoFieldGetter, Function<E, V> entityFieldGetter) {
// 通过反射或其他方式获取entityClass
Class<E> entityClass = null;
try {
// 尝试从方法引用中提取类信息
String methodRef = entityFieldGetter.toString();
if (methodRef.contains("::")) {
String className = methodRef.substring(0, methodRef.indexOf("::"));
entityClass = (Class<E>) Class.forName(className);
}
} catch (Exception e) {
// 忽略异常entityClass将在运行时确定
}
return new FieldMapping<>(dtoFieldGetter, entityFieldGetter, entityClass);
}
public Function<D, ?> getDtoFieldGetter() {
return dtoFieldGetter;
}
public Function<E, ?> getEntityFieldGetter() {
return entityFieldGetter;
}
public Class<?> getEntityClass() {
return entityClass;
}
}
}