diff --git a/src/gaea.jpa/src/main/java/com/synebula/gaea/jpa/Jpas.kt b/src/gaea.jpa/src/main/java/com/synebula/gaea/jpa/Jpas.kt index f9a2905..1346eb2 100644 --- a/src/gaea.jpa/src/main/java/com/synebula/gaea/jpa/Jpas.kt +++ b/src/gaea.jpa/src/main/java/com/synebula/gaea/jpa/Jpas.kt @@ -3,10 +3,7 @@ package com.synebula.gaea.jpa import com.synebula.gaea.data.date.DateTime import com.synebula.gaea.query.Operator import com.synebula.gaea.query.Where -import jakarta.persistence.criteria.CriteriaBuilder -import jakarta.persistence.criteria.CriteriaQuery -import jakarta.persistence.criteria.Predicate -import jakarta.persistence.criteria.Root +import jakarta.persistence.criteria.* import org.springframework.data.jpa.domain.Specification import java.lang.reflect.Field import java.util.* @@ -69,79 +66,209 @@ fun String.tryToDigital(field: Field): Double { * @param clazz 类 * @return Specification */ -fun Map?.toSpecification(clazz: Class<*>): Specification<*> { +fun Map.toSpecification(clazz: Class<*>): Specification<*> { val rangeStartSuffix = "[0]" //范围查询开始后缀 val rangeEndSuffix = "[1]" //范围查询结束后缀 return Specification { root: Root, _: CriteriaQuery<*>?, criteriaBuilder: CriteriaBuilder -> - val predicates: MutableList = ArrayList() - for (argumentName in this!!.keys) { - if (this[argumentName] == null) continue - var fieldName = argumentName - var operator: Operator - - // 判断是否为range类型(范围内查询) - var start = true - if (fieldName.endsWith(rangeStartSuffix) || fieldName.endsWith(rangeEndSuffix)) { - fieldName = fieldName.substring(fieldName.length - 3) - if (fieldName.endsWith(rangeEndSuffix)) start = false - } - val field = clazz.getDeclaredField(fieldName) - val where: Where = field.getDeclaredAnnotation(Where::class.java) - operator = where.operator - - // 如果是范围内容, 判断是数值类型还是时间类型 - if (operator === Operator.Range) { - if (clazz.getDeclaredField(fieldName).type != Date::class.java) { - operator = if (start) Operator.Gte else Operator.Lte - } - } - var predicate: Predicate - var digitalValue: Double + val predicates = mutableListOf() + for (argumentName in this.keys) { try { + var fieldName = argumentName + val fieldValue = this[argumentName]!! + var operator: Operator = Operator.Default + + // 判断是否为range类型(范围内查询) + var start = true + if (fieldName.endsWith(rangeStartSuffix) || fieldName.endsWith(rangeEndSuffix)) { + fieldName = fieldName.substring(fieldName.length - 3) + if (fieldName.endsWith(rangeEndSuffix)) start = false + } + val fieldTree = fieldName.split("\\.".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray() + //查找是否是嵌入字段, 找到最深的类型 + var field: Field + if (fieldTree.isNotEmpty()) { + var hostClass = clazz //需要查找字段所在的class + var i = 0 + do { + field = hostClass.getDeclaredField(fieldTree[i]) + hostClass = field.type + i++ + } while (i < fieldTree.size) + } else { + field = clazz.getDeclaredField(fieldName) + } + val where = field.getDeclaredAnnotation(Where::class.java) + if (where != null) operator = where.operator + + // 如果是范围内容, 判断是数值类型还是时间类型 + if (operator === Operator.Range) { + if (field.type != Date::class.java) { + operator = if (start) Operator.Gte else Operator.Lte + } + } + var predicate: Predicate + var digitalValue: Double when (operator) { - Operator.Ne -> predicate = - criteriaBuilder.notEqual(root.get(fieldName), this[fieldName]!!.toFieldType(field)) + Operator.Ne -> predicate = criteriaBuilder.notEqual( + getFieldPath(root, fieldName), + typeConvert(field, fieldValue) + ) - Operator.Lt -> { - digitalValue = this[fieldName]!!.tryToDigital(field) - predicate = criteriaBuilder.lessThan(root.get(fieldName), digitalValue) + Operator.Lt -> try { + digitalValue = parseDigital(field, fieldValue) + predicate = criteriaBuilder.lessThan(getFieldPath(root, fieldName), digitalValue) + } catch (e: Exception) { + throw RuntimeException( + String.format( + "class [%s] field [%s] can not use annotation Where(Operator.lt)", + field.declaringClass.name, + field.name + ), e + ) } - Operator.Gt -> { - digitalValue = this[fieldName]!!.tryToDigital(field) - predicate = criteriaBuilder.greaterThan(root.get(fieldName), digitalValue) + Operator.Gt -> try { + digitalValue = parseDigital(field, fieldValue) + predicate = criteriaBuilder.greaterThan( + getFieldPath( + root, + fieldName + ), digitalValue + ) + } catch (e: Exception) { + throw RuntimeException( + String.format( + "class [%s] field [%s] can not use annotation Where(Operator.gt)", + field.declaringClass.name, + field.name + ), e + ) } - Operator.Lte -> { - digitalValue = this[fieldName]!!.tryToDigital(field) - predicate = criteriaBuilder.lessThanOrEqualTo(root.get(fieldName), digitalValue) + Operator.Lte -> try { + digitalValue = parseDigital(field, fieldValue) + predicate = criteriaBuilder.lessThanOrEqualTo( + getFieldPath( + root, + fieldName + ), digitalValue + ) + } catch (e: Exception) { + throw RuntimeException( + String.format( + "class [%s] field [%s] can not use annotation Where(Operator.lte)", + field.declaringClass.name, + field.name + ), e + ) } - Operator.Gte -> { - digitalValue = this[fieldName]!!.tryToDigital(field) - predicate = criteriaBuilder.greaterThanOrEqualTo(root.get(fieldName), digitalValue) + Operator.Gte -> try { + digitalValue = parseDigital(field, fieldValue) + predicate = criteriaBuilder.greaterThanOrEqualTo( + getFieldPath( + root, + fieldName + ), digitalValue + ) + } catch (e: Exception) { + throw RuntimeException( + String.format( + "class [%s] field [%s] can not use annotation Where(Operator.gte)", + field.declaringClass.name, + field.name + ), e + ) } - Operator.Like -> predicate = criteriaBuilder.like(root.get(fieldName), "%${this[fieldName]}%") - Operator.Range -> { - predicate = if (start) { - criteriaBuilder.greaterThanOrEqualTo(root.get(fieldName), this[argumentName]!!) - } else { - criteriaBuilder.lessThanOrEqualTo(root.get(fieldName), this[argumentName]!!) - } - } + Operator.Like -> predicate = criteriaBuilder.like( + getFieldPath(root, fieldName), + String.format("%%%s%%", fieldValue) + ) - else -> predicate = - criteriaBuilder.equal(root.get(fieldName), this[fieldName]!!.toFieldType(field)) + Operator.Range -> predicate = if (start) criteriaBuilder.greaterThanOrEqualTo( + getFieldPath(root, fieldName), this[argumentName]!! + ) else criteriaBuilder.lessThanOrEqualTo( + getFieldPath(root, fieldName), this[argumentName]!! + ) + + else -> predicate = criteriaBuilder.equal( + getFieldPath(root, fieldName), + typeConvert(field, fieldValue) + ) } predicates.add(predicate) } catch (e: NoSuchFieldException) { - throw Error( - "class [${field.declaringClass.name}] field [${field.name}] can't annotation [@Where(${operator.declaringJavaClass.simpleName}.${operator.name})]", - e - ) + throw RuntimeException(e) } } criteriaBuilder.and(*predicates.toTypedArray()) } -} \ No newline at end of file +} + +/** + * 获取字段在 + */ +fun getFieldPath(root: Root, field: String): Path? { + val fieldTree = field.split("\\.".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray() + var path: Path + if (fieldTree.isNotEmpty()) { + path = root.get(fieldTree[0]) + for (i in 1 until fieldTree.size) { + path = path.get(fieldTree[i]) + } + return path + } + return root.get(field) +} + + +/** + * 类型转换 + * + * @param field 字段对象 + * @param value 值 + * @return object + */ +fun typeConvert(field: Field, value: String): Any { + var result: Any = value + val fieldType = field.type + if (fieldType != value.javaClass) { + if (Int::class.java == fieldType || Int::class.javaPrimitiveType == fieldType) { + result = value.toInt() + } + if (Double::class.java == fieldType || Double::class.javaPrimitiveType == fieldType) { + result = value.toDouble() + } + if (Float::class.java == fieldType || Float::class.javaPrimitiveType == fieldType) { + result = value.toFloat() + } + if (Date::class.java == fieldType) { + result = DateTime(value, "yyyy-MM-dd HH:mm:ss").date + } + } + return result +} + +/** + * 格式化数值类型 + * + * @param field 字段对象 + * @param value 值 + * @return double + */ +fun parseDigital(field: Field, value: String): Double { + val result: Double + val fieldType = field.type + result = + if (Int::class.java == fieldType || Int::class.javaPrimitiveType == fieldType || Double::class.java == fieldType || Double::class.javaPrimitiveType == fieldType || Float::class.java == fieldType || Float::class.javaPrimitiveType == fieldType) { + value.toDouble() + } else throw java.lang.RuntimeException( + String.format( + "class [%s] field [%s] is not digital", + field.declaringClass.name, + field.name + ) + ) + return result +} diff --git a/src/gaea.jpa/src/main/java/com/synebula/gaea/jpa/proxy/method/resolver/FindMethodResolver.kt b/src/gaea.jpa/src/main/java/com/synebula/gaea/jpa/proxy/method/resolver/FindMethodResolver.kt index cb5f73f..817160b 100644 --- a/src/gaea.jpa/src/main/java/com/synebula/gaea/jpa/proxy/method/resolver/FindMethodResolver.kt +++ b/src/gaea.jpa/src/main/java/com/synebula/gaea/jpa/proxy/method/resolver/FindMethodResolver.kt @@ -10,8 +10,8 @@ class FindMethodResolver(targetMethodName: String, clazz: Class<*>) : AbstractMe @Suppress("UNCHECKED_CAST") override fun mappingArguments(args: Array): Array { val params = args[0] as Map? - val specification = params.toSpecification(entityClazz) - return arrayOf(specification) + val specification = params?.toSpecification(entityClazz) + return if (specification != null) arrayOf(specification) else arrayOf() } override fun mappingResult(result: Any): Any { diff --git a/src/gaea.jpa/src/main/java/com/synebula/gaea/jpa/proxy/method/resolver/PageMethodResolver.kt b/src/gaea.jpa/src/main/java/com/synebula/gaea/jpa/proxy/method/resolver/PageMethodResolver.kt index 2675c32..c02aedb 100644 --- a/src/gaea.jpa/src/main/java/com/synebula/gaea/jpa/proxy/method/resolver/PageMethodResolver.kt +++ b/src/gaea.jpa/src/main/java/com/synebula/gaea/jpa/proxy/method/resolver/PageMethodResolver.kt @@ -10,6 +10,7 @@ import org.springframework.data.domain.PageRequest import org.springframework.data.domain.Pageable import org.springframework.data.domain.Sort import java.util.* +import com.synebula.gaea.query.Page as QueryPage /** * 分页方法参数映射 @@ -29,8 +30,7 @@ class PageMethodResolver(targetMethodName: String, clazz: Class<*>) : AbstractMe val fields = entityClazz.declaredFields for (field in fields) { val isId = Arrays.stream(field.declaredAnnotations).anyMatch { annotation: Annotation -> - (annotation.annotationClass.java == Id::class.java - || annotation.annotationClass.java == EmbeddedId::class.java) + (annotation.annotationClass.java == Id::class.java || annotation.annotationClass.java == EmbeddedId::class.java) } if (isId) { sort = Sort.by(Sort.Direction.ASC, field.name) @@ -50,7 +50,7 @@ class PageMethodResolver(targetMethodName: String, clazz: Class<*>) : AbstractMe override fun mappingResult(result: Any): Any { val page = result as Page<*> - // Page 页面从0开始 - return com.synebula.gaea.query.Page(page.number + 1, page.size, page.totalElements.toInt(), page.content) + // Page 页面从0开始 [com.synebula.gaea.query.Page as QueryPage] + return QueryPage(page.number + 1, page.size, page.totalElements.toInt(), page.content) } } \ No newline at end of file