fix spring jpa proxy bug
This commit is contained in:
@@ -3,10 +3,7 @@ package com.synebula.gaea.jpa
|
|||||||
import com.synebula.gaea.data.date.DateTime
|
import com.synebula.gaea.data.date.DateTime
|
||||||
import com.synebula.gaea.query.Operator
|
import com.synebula.gaea.query.Operator
|
||||||
import com.synebula.gaea.query.Where
|
import com.synebula.gaea.query.Where
|
||||||
import jakarta.persistence.criteria.CriteriaBuilder
|
import jakarta.persistence.criteria.*
|
||||||
import jakarta.persistence.criteria.CriteriaQuery
|
|
||||||
import jakarta.persistence.criteria.Predicate
|
|
||||||
import jakarta.persistence.criteria.Root
|
|
||||||
import org.springframework.data.jpa.domain.Specification
|
import org.springframework.data.jpa.domain.Specification
|
||||||
import java.lang.reflect.Field
|
import java.lang.reflect.Field
|
||||||
import java.util.*
|
import java.util.*
|
||||||
@@ -69,79 +66,209 @@ fun String.tryToDigital(field: Field): Double {
|
|||||||
* @param clazz 类
|
* @param clazz 类
|
||||||
* @return Specification
|
* @return Specification
|
||||||
*/
|
*/
|
||||||
fun Map<String, String>?.toSpecification(clazz: Class<*>): Specification<*> {
|
fun Map<String, String>.toSpecification(clazz: Class<*>): Specification<*> {
|
||||||
val rangeStartSuffix = "[0]" //范围查询开始后缀
|
val rangeStartSuffix = "[0]" //范围查询开始后缀
|
||||||
val rangeEndSuffix = "[1]" //范围查询结束后缀
|
val rangeEndSuffix = "[1]" //范围查询结束后缀
|
||||||
return Specification<Any?> { root: Root<Any?>, _: CriteriaQuery<*>?, criteriaBuilder: CriteriaBuilder ->
|
return Specification<Any?> { root: Root<Any?>, _: CriteriaQuery<*>?, criteriaBuilder: CriteriaBuilder ->
|
||||||
val predicates: MutableList<Predicate> = ArrayList()
|
val predicates = mutableListOf<Predicate>()
|
||||||
for (argumentName in this!!.keys) {
|
for (argumentName in this.keys) {
|
||||||
if (this[argumentName] == null) continue
|
|
||||||
var fieldName = argumentName
|
|
||||||
var operator: Operator
|
|
||||||
|
|
||||||
// 判断是否为range类型(范围内查询)
|
|
||||||
var start = true
|
|
||||||
if (fieldName.endsWith(rangeStartSuffix) || fieldName.endsWith(rangeEndSuffix)) {
|
|
||||||
fieldName = fieldName.substring(fieldName.length - 3)
|
|
||||||
if (fieldName.endsWith(rangeEndSuffix)) start = false
|
|
||||||
}
|
|
||||||
val field = clazz.getDeclaredField(fieldName)
|
|
||||||
val where: Where = field.getDeclaredAnnotation(Where::class.java)
|
|
||||||
operator = where.operator
|
|
||||||
|
|
||||||
// 如果是范围内容, 判断是数值类型还是时间类型
|
|
||||||
if (operator === Operator.Range) {
|
|
||||||
if (clazz.getDeclaredField(fieldName).type != Date::class.java) {
|
|
||||||
operator = if (start) Operator.Gte else Operator.Lte
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var predicate: Predicate
|
|
||||||
var digitalValue: Double
|
|
||||||
try {
|
try {
|
||||||
|
var fieldName = argumentName
|
||||||
|
val fieldValue = this[argumentName]!!
|
||||||
|
var operator: Operator = Operator.Default
|
||||||
|
|
||||||
|
// 判断是否为range类型(范围内查询)
|
||||||
|
var start = true
|
||||||
|
if (fieldName.endsWith(rangeStartSuffix) || fieldName.endsWith(rangeEndSuffix)) {
|
||||||
|
fieldName = fieldName.substring(fieldName.length - 3)
|
||||||
|
if (fieldName.endsWith(rangeEndSuffix)) start = false
|
||||||
|
}
|
||||||
|
val fieldTree = fieldName.split("\\.".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()
|
||||||
|
//查找是否是嵌入字段, 找到最深的类型
|
||||||
|
var field: Field
|
||||||
|
if (fieldTree.isNotEmpty()) {
|
||||||
|
var hostClass = clazz //需要查找字段所在的class
|
||||||
|
var i = 0
|
||||||
|
do {
|
||||||
|
field = hostClass.getDeclaredField(fieldTree[i])
|
||||||
|
hostClass = field.type
|
||||||
|
i++
|
||||||
|
} while (i < fieldTree.size)
|
||||||
|
} else {
|
||||||
|
field = clazz.getDeclaredField(fieldName)
|
||||||
|
}
|
||||||
|
val where = field.getDeclaredAnnotation(Where::class.java)
|
||||||
|
if (where != null) operator = where.operator
|
||||||
|
|
||||||
|
// 如果是范围内容, 判断是数值类型还是时间类型
|
||||||
|
if (operator === Operator.Range) {
|
||||||
|
if (field.type != Date::class.java) {
|
||||||
|
operator = if (start) Operator.Gte else Operator.Lte
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var predicate: Predicate
|
||||||
|
var digitalValue: Double
|
||||||
when (operator) {
|
when (operator) {
|
||||||
Operator.Ne -> predicate =
|
Operator.Ne -> predicate = criteriaBuilder.notEqual(
|
||||||
criteriaBuilder.notEqual(root.get<Any>(fieldName), this[fieldName]!!.toFieldType(field))
|
getFieldPath<Any>(root, fieldName),
|
||||||
|
typeConvert(field, fieldValue)
|
||||||
|
)
|
||||||
|
|
||||||
Operator.Lt -> {
|
Operator.Lt -> try {
|
||||||
digitalValue = this[fieldName]!!.tryToDigital(field)
|
digitalValue = parseDigital(field, fieldValue)
|
||||||
predicate = criteriaBuilder.lessThan(root.get(fieldName), digitalValue)
|
predicate = criteriaBuilder.lessThan(getFieldPath(root, fieldName), digitalValue)
|
||||||
|
} catch (e: Exception) {
|
||||||
|
throw RuntimeException(
|
||||||
|
String.format(
|
||||||
|
"class [%s] field [%s] can not use annotation Where(Operator.lt)",
|
||||||
|
field.declaringClass.name,
|
||||||
|
field.name
|
||||||
|
), e
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator.Gt -> {
|
Operator.Gt -> try {
|
||||||
digitalValue = this[fieldName]!!.tryToDigital(field)
|
digitalValue = parseDigital(field, fieldValue)
|
||||||
predicate = criteriaBuilder.greaterThan(root.get(fieldName), digitalValue)
|
predicate = criteriaBuilder.greaterThan(
|
||||||
|
getFieldPath(
|
||||||
|
root,
|
||||||
|
fieldName
|
||||||
|
), digitalValue
|
||||||
|
)
|
||||||
|
} catch (e: Exception) {
|
||||||
|
throw RuntimeException(
|
||||||
|
String.format(
|
||||||
|
"class [%s] field [%s] can not use annotation Where(Operator.gt)",
|
||||||
|
field.declaringClass.name,
|
||||||
|
field.name
|
||||||
|
), e
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator.Lte -> {
|
Operator.Lte -> try {
|
||||||
digitalValue = this[fieldName]!!.tryToDigital(field)
|
digitalValue = parseDigital(field, fieldValue)
|
||||||
predicate = criteriaBuilder.lessThanOrEqualTo(root.get(fieldName), digitalValue)
|
predicate = criteriaBuilder.lessThanOrEqualTo(
|
||||||
|
getFieldPath(
|
||||||
|
root,
|
||||||
|
fieldName
|
||||||
|
), digitalValue
|
||||||
|
)
|
||||||
|
} catch (e: Exception) {
|
||||||
|
throw RuntimeException(
|
||||||
|
String.format(
|
||||||
|
"class [%s] field [%s] can not use annotation Where(Operator.lte)",
|
||||||
|
field.declaringClass.name,
|
||||||
|
field.name
|
||||||
|
), e
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator.Gte -> {
|
Operator.Gte -> try {
|
||||||
digitalValue = this[fieldName]!!.tryToDigital(field)
|
digitalValue = parseDigital(field, fieldValue)
|
||||||
predicate = criteriaBuilder.greaterThanOrEqualTo(root.get(fieldName), digitalValue)
|
predicate = criteriaBuilder.greaterThanOrEqualTo(
|
||||||
|
getFieldPath(
|
||||||
|
root,
|
||||||
|
fieldName
|
||||||
|
), digitalValue
|
||||||
|
)
|
||||||
|
} catch (e: Exception) {
|
||||||
|
throw RuntimeException(
|
||||||
|
String.format(
|
||||||
|
"class [%s] field [%s] can not use annotation Where(Operator.gte)",
|
||||||
|
field.declaringClass.name,
|
||||||
|
field.name
|
||||||
|
), e
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator.Like -> predicate = criteriaBuilder.like(root.get(fieldName), "%${this[fieldName]}%")
|
Operator.Like -> predicate = criteriaBuilder.like(
|
||||||
Operator.Range -> {
|
getFieldPath(root, fieldName),
|
||||||
predicate = if (start) {
|
String.format("%%%s%%", fieldValue)
|
||||||
criteriaBuilder.greaterThanOrEqualTo(root.get(fieldName), this[argumentName]!!)
|
)
|
||||||
} else {
|
|
||||||
criteriaBuilder.lessThanOrEqualTo(root.get(fieldName), this[argumentName]!!)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
else -> predicate =
|
Operator.Range -> predicate = if (start) criteriaBuilder.greaterThanOrEqualTo(
|
||||||
criteriaBuilder.equal(root.get<Any>(fieldName), this[fieldName]!!.toFieldType(field))
|
getFieldPath(root, fieldName), this[argumentName]!!
|
||||||
|
) else criteriaBuilder.lessThanOrEqualTo(
|
||||||
|
getFieldPath(root, fieldName), this[argumentName]!!
|
||||||
|
)
|
||||||
|
|
||||||
|
else -> predicate = criteriaBuilder.equal(
|
||||||
|
getFieldPath<Any>(root, fieldName),
|
||||||
|
typeConvert(field, fieldValue)
|
||||||
|
)
|
||||||
}
|
}
|
||||||
predicates.add(predicate)
|
predicates.add(predicate)
|
||||||
} catch (e: NoSuchFieldException) {
|
} catch (e: NoSuchFieldException) {
|
||||||
throw Error(
|
throw RuntimeException(e)
|
||||||
"class [${field.declaringClass.name}] field [${field.name}] can't annotation [@Where(${operator.declaringJavaClass.simpleName}.${operator.name})]",
|
|
||||||
e
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
criteriaBuilder.and(*predicates.toTypedArray())
|
criteriaBuilder.and(*predicates.toTypedArray())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取字段在
|
||||||
|
*/
|
||||||
|
fun <Y> getFieldPath(root: Root<Any?>, field: String): Path<Y>? {
|
||||||
|
val fieldTree = field.split("\\.".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()
|
||||||
|
var path: Path<Y>
|
||||||
|
if (fieldTree.isNotEmpty()) {
|
||||||
|
path = root.get(fieldTree[0])
|
||||||
|
for (i in 1 until fieldTree.size) {
|
||||||
|
path = path.get(fieldTree[i])
|
||||||
|
}
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
return root.get(field)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 类型转换
|
||||||
|
*
|
||||||
|
* @param field 字段对象
|
||||||
|
* @param value 值
|
||||||
|
* @return object
|
||||||
|
*/
|
||||||
|
fun typeConvert(field: Field, value: String): Any {
|
||||||
|
var result: Any = value
|
||||||
|
val fieldType = field.type
|
||||||
|
if (fieldType != value.javaClass) {
|
||||||
|
if (Int::class.java == fieldType || Int::class.javaPrimitiveType == fieldType) {
|
||||||
|
result = value.toInt()
|
||||||
|
}
|
||||||
|
if (Double::class.java == fieldType || Double::class.javaPrimitiveType == fieldType) {
|
||||||
|
result = value.toDouble()
|
||||||
|
}
|
||||||
|
if (Float::class.java == fieldType || Float::class.javaPrimitiveType == fieldType) {
|
||||||
|
result = value.toFloat()
|
||||||
|
}
|
||||||
|
if (Date::class.java == fieldType) {
|
||||||
|
result = DateTime(value, "yyyy-MM-dd HH:mm:ss").date
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 格式化数值类型
|
||||||
|
*
|
||||||
|
* @param field 字段对象
|
||||||
|
* @param value 值
|
||||||
|
* @return double
|
||||||
|
*/
|
||||||
|
fun parseDigital(field: Field, value: String): Double {
|
||||||
|
val result: Double
|
||||||
|
val fieldType = field.type
|
||||||
|
result =
|
||||||
|
if (Int::class.java == fieldType || Int::class.javaPrimitiveType == fieldType || Double::class.java == fieldType || Double::class.javaPrimitiveType == fieldType || Float::class.java == fieldType || Float::class.javaPrimitiveType == fieldType) {
|
||||||
|
value.toDouble()
|
||||||
|
} else throw java.lang.RuntimeException(
|
||||||
|
String.format(
|
||||||
|
"class [%s] field [%s] is not digital",
|
||||||
|
field.declaringClass.name,
|
||||||
|
field.name
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ class FindMethodResolver(targetMethodName: String, clazz: Class<*>) : AbstractMe
|
|||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
override fun mappingArguments(args: Array<Any>): Array<Any> {
|
override fun mappingArguments(args: Array<Any>): Array<Any> {
|
||||||
val params = args[0] as Map<String, String>?
|
val params = args[0] as Map<String, String>?
|
||||||
val specification = params.toSpecification(entityClazz)
|
val specification = params?.toSpecification(entityClazz)
|
||||||
return arrayOf(specification)
|
return if (specification != null) arrayOf(specification) else arrayOf()
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun mappingResult(result: Any): Any {
|
override fun mappingResult(result: Any): Any {
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import org.springframework.data.domain.PageRequest
|
|||||||
import org.springframework.data.domain.Pageable
|
import org.springframework.data.domain.Pageable
|
||||||
import org.springframework.data.domain.Sort
|
import org.springframework.data.domain.Sort
|
||||||
import java.util.*
|
import java.util.*
|
||||||
|
import com.synebula.gaea.query.Page as QueryPage
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 分页方法参数映射
|
* 分页方法参数映射
|
||||||
@@ -29,8 +30,7 @@ class PageMethodResolver(targetMethodName: String, clazz: Class<*>) : AbstractMe
|
|||||||
val fields = entityClazz.declaredFields
|
val fields = entityClazz.declaredFields
|
||||||
for (field in fields) {
|
for (field in fields) {
|
||||||
val isId = Arrays.stream(field.declaredAnnotations).anyMatch { annotation: Annotation ->
|
val isId = Arrays.stream(field.declaredAnnotations).anyMatch { annotation: Annotation ->
|
||||||
(annotation.annotationClass.java == Id::class.java
|
(annotation.annotationClass.java == Id::class.java || annotation.annotationClass.java == EmbeddedId::class.java)
|
||||||
|| annotation.annotationClass.java == EmbeddedId::class.java)
|
|
||||||
}
|
}
|
||||||
if (isId) {
|
if (isId) {
|
||||||
sort = Sort.by(Sort.Direction.ASC, field.name)
|
sort = Sort.by(Sort.Direction.ASC, field.name)
|
||||||
@@ -50,7 +50,7 @@ class PageMethodResolver(targetMethodName: String, clazz: Class<*>) : AbstractMe
|
|||||||
override fun mappingResult(result: Any): Any {
|
override fun mappingResult(result: Any): Any {
|
||||||
val page = result as Page<*>
|
val page = result as Page<*>
|
||||||
|
|
||||||
// Page 页面从0开始
|
// Page 页面从0开始 [com.synebula.gaea.query.Page as QueryPage]
|
||||||
return com.synebula.gaea.query.Page(page.number + 1, page.size, page.totalElements.toInt(), page.content)
|
return QueryPage(page.number + 1, page.size, page.totalElements.toInt(), page.content)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user