fix spring jpa proxy bug
This commit is contained in:
@@ -3,10 +3,7 @@ package com.synebula.gaea.jpa
|
||||
import com.synebula.gaea.data.date.DateTime
|
||||
import com.synebula.gaea.query.Operator
|
||||
import com.synebula.gaea.query.Where
|
||||
import jakarta.persistence.criteria.CriteriaBuilder
|
||||
import jakarta.persistence.criteria.CriteriaQuery
|
||||
import jakarta.persistence.criteria.Predicate
|
||||
import jakarta.persistence.criteria.Root
|
||||
import jakarta.persistence.criteria.*
|
||||
import org.springframework.data.jpa.domain.Specification
|
||||
import java.lang.reflect.Field
|
||||
import java.util.*
|
||||
@@ -69,79 +66,209 @@ fun String.tryToDigital(field: Field): Double {
|
||||
* @param clazz 类
|
||||
* @return Specification
|
||||
*/
|
||||
fun Map<String, String>?.toSpecification(clazz: Class<*>): Specification<*> {
|
||||
fun Map<String, String>.toSpecification(clazz: Class<*>): Specification<*> {
|
||||
val rangeStartSuffix = "[0]" //范围查询开始后缀
|
||||
val rangeEndSuffix = "[1]" //范围查询结束后缀
|
||||
return Specification<Any?> { root: Root<Any?>, _: CriteriaQuery<*>?, criteriaBuilder: CriteriaBuilder ->
|
||||
val predicates: MutableList<Predicate> = ArrayList()
|
||||
for (argumentName in this!!.keys) {
|
||||
if (this[argumentName] == null) continue
|
||||
var fieldName = argumentName
|
||||
var operator: Operator
|
||||
|
||||
// 判断是否为range类型(范围内查询)
|
||||
var start = true
|
||||
if (fieldName.endsWith(rangeStartSuffix) || fieldName.endsWith(rangeEndSuffix)) {
|
||||
fieldName = fieldName.substring(fieldName.length - 3)
|
||||
if (fieldName.endsWith(rangeEndSuffix)) start = false
|
||||
}
|
||||
val field = clazz.getDeclaredField(fieldName)
|
||||
val where: Where = field.getDeclaredAnnotation(Where::class.java)
|
||||
operator = where.operator
|
||||
|
||||
// 如果是范围内容, 判断是数值类型还是时间类型
|
||||
if (operator === Operator.Range) {
|
||||
if (clazz.getDeclaredField(fieldName).type != Date::class.java) {
|
||||
operator = if (start) Operator.Gte else Operator.Lte
|
||||
}
|
||||
}
|
||||
var predicate: Predicate
|
||||
var digitalValue: Double
|
||||
val predicates = mutableListOf<Predicate>()
|
||||
for (argumentName in this.keys) {
|
||||
try {
|
||||
var fieldName = argumentName
|
||||
val fieldValue = this[argumentName]!!
|
||||
var operator: Operator = Operator.Default
|
||||
|
||||
// 判断是否为range类型(范围内查询)
|
||||
var start = true
|
||||
if (fieldName.endsWith(rangeStartSuffix) || fieldName.endsWith(rangeEndSuffix)) {
|
||||
fieldName = fieldName.substring(fieldName.length - 3)
|
||||
if (fieldName.endsWith(rangeEndSuffix)) start = false
|
||||
}
|
||||
val fieldTree = fieldName.split("\\.".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()
|
||||
//查找是否是嵌入字段, 找到最深的类型
|
||||
var field: Field
|
||||
if (fieldTree.isNotEmpty()) {
|
||||
var hostClass = clazz //需要查找字段所在的class
|
||||
var i = 0
|
||||
do {
|
||||
field = hostClass.getDeclaredField(fieldTree[i])
|
||||
hostClass = field.type
|
||||
i++
|
||||
} while (i < fieldTree.size)
|
||||
} else {
|
||||
field = clazz.getDeclaredField(fieldName)
|
||||
}
|
||||
val where = field.getDeclaredAnnotation(Where::class.java)
|
||||
if (where != null) operator = where.operator
|
||||
|
||||
// 如果是范围内容, 判断是数值类型还是时间类型
|
||||
if (operator === Operator.Range) {
|
||||
if (field.type != Date::class.java) {
|
||||
operator = if (start) Operator.Gte else Operator.Lte
|
||||
}
|
||||
}
|
||||
var predicate: Predicate
|
||||
var digitalValue: Double
|
||||
when (operator) {
|
||||
Operator.Ne -> predicate =
|
||||
criteriaBuilder.notEqual(root.get<Any>(fieldName), this[fieldName]!!.toFieldType(field))
|
||||
Operator.Ne -> predicate = criteriaBuilder.notEqual(
|
||||
getFieldPath<Any>(root, fieldName),
|
||||
typeConvert(field, fieldValue)
|
||||
)
|
||||
|
||||
Operator.Lt -> {
|
||||
digitalValue = this[fieldName]!!.tryToDigital(field)
|
||||
predicate = criteriaBuilder.lessThan(root.get(fieldName), digitalValue)
|
||||
Operator.Lt -> try {
|
||||
digitalValue = parseDigital(field, fieldValue)
|
||||
predicate = criteriaBuilder.lessThan(getFieldPath(root, fieldName), digitalValue)
|
||||
} catch (e: Exception) {
|
||||
throw RuntimeException(
|
||||
String.format(
|
||||
"class [%s] field [%s] can not use annotation Where(Operator.lt)",
|
||||
field.declaringClass.name,
|
||||
field.name
|
||||
), e
|
||||
)
|
||||
}
|
||||
|
||||
Operator.Gt -> {
|
||||
digitalValue = this[fieldName]!!.tryToDigital(field)
|
||||
predicate = criteriaBuilder.greaterThan(root.get(fieldName), digitalValue)
|
||||
Operator.Gt -> try {
|
||||
digitalValue = parseDigital(field, fieldValue)
|
||||
predicate = criteriaBuilder.greaterThan(
|
||||
getFieldPath(
|
||||
root,
|
||||
fieldName
|
||||
), digitalValue
|
||||
)
|
||||
} catch (e: Exception) {
|
||||
throw RuntimeException(
|
||||
String.format(
|
||||
"class [%s] field [%s] can not use annotation Where(Operator.gt)",
|
||||
field.declaringClass.name,
|
||||
field.name
|
||||
), e
|
||||
)
|
||||
}
|
||||
|
||||
Operator.Lte -> {
|
||||
digitalValue = this[fieldName]!!.tryToDigital(field)
|
||||
predicate = criteriaBuilder.lessThanOrEqualTo(root.get(fieldName), digitalValue)
|
||||
Operator.Lte -> try {
|
||||
digitalValue = parseDigital(field, fieldValue)
|
||||
predicate = criteriaBuilder.lessThanOrEqualTo(
|
||||
getFieldPath(
|
||||
root,
|
||||
fieldName
|
||||
), digitalValue
|
||||
)
|
||||
} catch (e: Exception) {
|
||||
throw RuntimeException(
|
||||
String.format(
|
||||
"class [%s] field [%s] can not use annotation Where(Operator.lte)",
|
||||
field.declaringClass.name,
|
||||
field.name
|
||||
), e
|
||||
)
|
||||
}
|
||||
|
||||
Operator.Gte -> {
|
||||
digitalValue = this[fieldName]!!.tryToDigital(field)
|
||||
predicate = criteriaBuilder.greaterThanOrEqualTo(root.get(fieldName), digitalValue)
|
||||
Operator.Gte -> try {
|
||||
digitalValue = parseDigital(field, fieldValue)
|
||||
predicate = criteriaBuilder.greaterThanOrEqualTo(
|
||||
getFieldPath(
|
||||
root,
|
||||
fieldName
|
||||
), digitalValue
|
||||
)
|
||||
} catch (e: Exception) {
|
||||
throw RuntimeException(
|
||||
String.format(
|
||||
"class [%s] field [%s] can not use annotation Where(Operator.gte)",
|
||||
field.declaringClass.name,
|
||||
field.name
|
||||
), e
|
||||
)
|
||||
}
|
||||
|
||||
Operator.Like -> predicate = criteriaBuilder.like(root.get(fieldName), "%${this[fieldName]}%")
|
||||
Operator.Range -> {
|
||||
predicate = if (start) {
|
||||
criteriaBuilder.greaterThanOrEqualTo(root.get(fieldName), this[argumentName]!!)
|
||||
} else {
|
||||
criteriaBuilder.lessThanOrEqualTo(root.get(fieldName), this[argumentName]!!)
|
||||
}
|
||||
}
|
||||
Operator.Like -> predicate = criteriaBuilder.like(
|
||||
getFieldPath(root, fieldName),
|
||||
String.format("%%%s%%", fieldValue)
|
||||
)
|
||||
|
||||
else -> predicate =
|
||||
criteriaBuilder.equal(root.get<Any>(fieldName), this[fieldName]!!.toFieldType(field))
|
||||
Operator.Range -> predicate = if (start) criteriaBuilder.greaterThanOrEqualTo(
|
||||
getFieldPath(root, fieldName), this[argumentName]!!
|
||||
) else criteriaBuilder.lessThanOrEqualTo(
|
||||
getFieldPath(root, fieldName), this[argumentName]!!
|
||||
)
|
||||
|
||||
else -> predicate = criteriaBuilder.equal(
|
||||
getFieldPath<Any>(root, fieldName),
|
||||
typeConvert(field, fieldValue)
|
||||
)
|
||||
}
|
||||
predicates.add(predicate)
|
||||
} catch (e: NoSuchFieldException) {
|
||||
throw Error(
|
||||
"class [${field.declaringClass.name}] field [${field.name}] can't annotation [@Where(${operator.declaringJavaClass.simpleName}.${operator.name})]",
|
||||
e
|
||||
)
|
||||
throw RuntimeException(e)
|
||||
}
|
||||
}
|
||||
criteriaBuilder.and(*predicates.toTypedArray())
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取字段在
|
||||
*/
|
||||
fun <Y> getFieldPath(root: Root<Any?>, field: String): Path<Y>? {
|
||||
val fieldTree = field.split("\\.".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()
|
||||
var path: Path<Y>
|
||||
if (fieldTree.isNotEmpty()) {
|
||||
path = root.get(fieldTree[0])
|
||||
for (i in 1 until fieldTree.size) {
|
||||
path = path.get(fieldTree[i])
|
||||
}
|
||||
return path
|
||||
}
|
||||
return root.get(field)
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 类型转换
|
||||
*
|
||||
* @param field 字段对象
|
||||
* @param value 值
|
||||
* @return object
|
||||
*/
|
||||
fun typeConvert(field: Field, value: String): Any {
|
||||
var result: Any = value
|
||||
val fieldType = field.type
|
||||
if (fieldType != value.javaClass) {
|
||||
if (Int::class.java == fieldType || Int::class.javaPrimitiveType == fieldType) {
|
||||
result = value.toInt()
|
||||
}
|
||||
if (Double::class.java == fieldType || Double::class.javaPrimitiveType == fieldType) {
|
||||
result = value.toDouble()
|
||||
}
|
||||
if (Float::class.java == fieldType || Float::class.javaPrimitiveType == fieldType) {
|
||||
result = value.toFloat()
|
||||
}
|
||||
if (Date::class.java == fieldType) {
|
||||
result = DateTime(value, "yyyy-MM-dd HH:mm:ss").date
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* 格式化数值类型
|
||||
*
|
||||
* @param field 字段对象
|
||||
* @param value 值
|
||||
* @return double
|
||||
*/
|
||||
fun parseDigital(field: Field, value: String): Double {
|
||||
val result: Double
|
||||
val fieldType = field.type
|
||||
result =
|
||||
if (Int::class.java == fieldType || Int::class.javaPrimitiveType == fieldType || Double::class.java == fieldType || Double::class.javaPrimitiveType == fieldType || Float::class.java == fieldType || Float::class.javaPrimitiveType == fieldType) {
|
||||
value.toDouble()
|
||||
} else throw java.lang.RuntimeException(
|
||||
String.format(
|
||||
"class [%s] field [%s] is not digital",
|
||||
field.declaringClass.name,
|
||||
field.name
|
||||
)
|
||||
)
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -10,8 +10,8 @@ class FindMethodResolver(targetMethodName: String, clazz: Class<*>) : AbstractMe
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
override fun mappingArguments(args: Array<Any>): Array<Any> {
|
||||
val params = args[0] as Map<String, String>?
|
||||
val specification = params.toSpecification(entityClazz)
|
||||
return arrayOf(specification)
|
||||
val specification = params?.toSpecification(entityClazz)
|
||||
return if (specification != null) arrayOf(specification) else arrayOf()
|
||||
}
|
||||
|
||||
override fun mappingResult(result: Any): Any {
|
||||
|
||||
@@ -10,6 +10,7 @@ import org.springframework.data.domain.PageRequest
|
||||
import org.springframework.data.domain.Pageable
|
||||
import org.springframework.data.domain.Sort
|
||||
import java.util.*
|
||||
import com.synebula.gaea.query.Page as QueryPage
|
||||
|
||||
/**
|
||||
* 分页方法参数映射
|
||||
@@ -29,8 +30,7 @@ class PageMethodResolver(targetMethodName: String, clazz: Class<*>) : AbstractMe
|
||||
val fields = entityClazz.declaredFields
|
||||
for (field in fields) {
|
||||
val isId = Arrays.stream(field.declaredAnnotations).anyMatch { annotation: Annotation ->
|
||||
(annotation.annotationClass.java == Id::class.java
|
||||
|| annotation.annotationClass.java == EmbeddedId::class.java)
|
||||
(annotation.annotationClass.java == Id::class.java || annotation.annotationClass.java == EmbeddedId::class.java)
|
||||
}
|
||||
if (isId) {
|
||||
sort = Sort.by(Sort.Direction.ASC, field.name)
|
||||
@@ -50,7 +50,7 @@ class PageMethodResolver(targetMethodName: String, clazz: Class<*>) : AbstractMe
|
||||
override fun mappingResult(result: Any): Any {
|
||||
val page = result as Page<*>
|
||||
|
||||
// Page 页面从0开始
|
||||
return com.synebula.gaea.query.Page(page.number + 1, page.size, page.totalElements.toInt(), page.content)
|
||||
// Page 页面从0开始 [com.synebula.gaea.query.Page as QueryPage]
|
||||
return QueryPage(page.number + 1, page.size, page.totalElements.toInt(), page.content)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user