fix spring jpa proxy bug

This commit is contained in:
2023-04-18 10:50:47 +08:00
parent 230ceea0fa
commit eff39eb7f8
3 changed files with 191 additions and 64 deletions

View File

@@ -3,10 +3,7 @@ package com.synebula.gaea.jpa
import com.synebula.gaea.data.date.DateTime
import com.synebula.gaea.query.Operator
import com.synebula.gaea.query.Where
import jakarta.persistence.criteria.CriteriaBuilder
import jakarta.persistence.criteria.CriteriaQuery
import jakarta.persistence.criteria.Predicate
import jakarta.persistence.criteria.Root
import jakarta.persistence.criteria.*
import org.springframework.data.jpa.domain.Specification
import java.lang.reflect.Field
import java.util.*
@@ -69,15 +66,16 @@ fun String.tryToDigital(field: Field): Double {
* @param clazz 类
* @return Specification
*/
fun Map<String, String>?.toSpecification(clazz: Class<*>): Specification<*> {
fun Map<String, String>.toSpecification(clazz: Class<*>): Specification<*> {
val rangeStartSuffix = "[0]" //范围查询开始后缀
val rangeEndSuffix = "[1]" //范围查询结束后缀
return Specification<Any?> { root: Root<Any?>, _: CriteriaQuery<*>?, criteriaBuilder: CriteriaBuilder ->
val predicates: MutableList<Predicate> = ArrayList()
for (argumentName in this!!.keys) {
if (this[argumentName] == null) continue
val predicates = mutableListOf<Predicate>()
for (argumentName in this.keys) {
try {
var fieldName = argumentName
var operator: Operator
val fieldValue = this[argumentName]!!
var operator: Operator = Operator.Default
// 判断是否为range类型(范围内查询)
var start = true
@@ -85,63 +83,192 @@ fun Map<String, String>?.toSpecification(clazz: Class<*>): Specification<*> {
fieldName = fieldName.substring(fieldName.length - 3)
if (fieldName.endsWith(rangeEndSuffix)) start = false
}
val field = clazz.getDeclaredField(fieldName)
val where: Where = field.getDeclaredAnnotation(Where::class.java)
operator = where.operator
val fieldTree = fieldName.split("\\.".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()
//查找是否是嵌入字段, 找到最深的类型
var field: Field
if (fieldTree.isNotEmpty()) {
var hostClass = clazz //需要查找字段所在的class
var i = 0
do {
field = hostClass.getDeclaredField(fieldTree[i])
hostClass = field.type
i++
} while (i < fieldTree.size)
} else {
field = clazz.getDeclaredField(fieldName)
}
val where = field.getDeclaredAnnotation(Where::class.java)
if (where != null) operator = where.operator
// 如果是范围内容, 判断是数值类型还是时间类型
if (operator === Operator.Range) {
if (clazz.getDeclaredField(fieldName).type != Date::class.java) {
if (field.type != Date::class.java) {
operator = if (start) Operator.Gte else Operator.Lte
}
}
var predicate: Predicate
var digitalValue: Double
try {
when (operator) {
Operator.Ne -> predicate =
criteriaBuilder.notEqual(root.get<Any>(fieldName), this[fieldName]!!.toFieldType(field))
Operator.Ne -> predicate = criteriaBuilder.notEqual(
getFieldPath<Any>(root, fieldName),
typeConvert(field, fieldValue)
)
Operator.Lt -> {
digitalValue = this[fieldName]!!.tryToDigital(field)
predicate = criteriaBuilder.lessThan(root.get(fieldName), digitalValue)
Operator.Lt -> try {
digitalValue = parseDigital(field, fieldValue)
predicate = criteriaBuilder.lessThan(getFieldPath(root, fieldName), digitalValue)
} catch (e: Exception) {
throw RuntimeException(
String.format(
"class [%s] field [%s] can not use annotation Where(Operator.lt)",
field.declaringClass.name,
field.name
), e
)
}
Operator.Gt -> {
digitalValue = this[fieldName]!!.tryToDigital(field)
predicate = criteriaBuilder.greaterThan(root.get(fieldName), digitalValue)
Operator.Gt -> try {
digitalValue = parseDigital(field, fieldValue)
predicate = criteriaBuilder.greaterThan(
getFieldPath(
root,
fieldName
), digitalValue
)
} catch (e: Exception) {
throw RuntimeException(
String.format(
"class [%s] field [%s] can not use annotation Where(Operator.gt)",
field.declaringClass.name,
field.name
), e
)
}
Operator.Lte -> {
digitalValue = this[fieldName]!!.tryToDigital(field)
predicate = criteriaBuilder.lessThanOrEqualTo(root.get(fieldName), digitalValue)
Operator.Lte -> try {
digitalValue = parseDigital(field, fieldValue)
predicate = criteriaBuilder.lessThanOrEqualTo(
getFieldPath(
root,
fieldName
), digitalValue
)
} catch (e: Exception) {
throw RuntimeException(
String.format(
"class [%s] field [%s] can not use annotation Where(Operator.lte)",
field.declaringClass.name,
field.name
), e
)
}
Operator.Gte -> {
digitalValue = this[fieldName]!!.tryToDigital(field)
predicate = criteriaBuilder.greaterThanOrEqualTo(root.get(fieldName), digitalValue)
Operator.Gte -> try {
digitalValue = parseDigital(field, fieldValue)
predicate = criteriaBuilder.greaterThanOrEqualTo(
getFieldPath(
root,
fieldName
), digitalValue
)
} catch (e: Exception) {
throw RuntimeException(
String.format(
"class [%s] field [%s] can not use annotation Where(Operator.gte)",
field.declaringClass.name,
field.name
), e
)
}
Operator.Like -> predicate = criteriaBuilder.like(root.get(fieldName), "%${this[fieldName]}%")
Operator.Range -> {
predicate = if (start) {
criteriaBuilder.greaterThanOrEqualTo(root.get(fieldName), this[argumentName]!!)
} else {
criteriaBuilder.lessThanOrEqualTo(root.get(fieldName), this[argumentName]!!)
}
}
Operator.Like -> predicate = criteriaBuilder.like(
getFieldPath(root, fieldName),
String.format("%%%s%%", fieldValue)
)
else -> predicate =
criteriaBuilder.equal(root.get<Any>(fieldName), this[fieldName]!!.toFieldType(field))
Operator.Range -> predicate = if (start) criteriaBuilder.greaterThanOrEqualTo(
getFieldPath(root, fieldName), this[argumentName]!!
) else criteriaBuilder.lessThanOrEqualTo(
getFieldPath(root, fieldName), this[argumentName]!!
)
else -> predicate = criteriaBuilder.equal(
getFieldPath<Any>(root, fieldName),
typeConvert(field, fieldValue)
)
}
predicates.add(predicate)
} catch (e: NoSuchFieldException) {
throw Error(
"class [${field.declaringClass.name}] field [${field.name}] can't annotation [@Where(${operator.declaringJavaClass.simpleName}.${operator.name})]",
e
)
throw RuntimeException(e)
}
}
criteriaBuilder.and(*predicates.toTypedArray())
}
}
/**
* 获取字段在
*/
fun <Y> getFieldPath(root: Root<Any?>, field: String): Path<Y>? {
val fieldTree = field.split("\\.".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()
var path: Path<Y>
if (fieldTree.isNotEmpty()) {
path = root.get(fieldTree[0])
for (i in 1 until fieldTree.size) {
path = path.get(fieldTree[i])
}
return path
}
return root.get(field)
}
/**
* 类型转换
*
* @param field 字段对象
* @param value 值
* @return object
*/
fun typeConvert(field: Field, value: String): Any {
var result: Any = value
val fieldType = field.type
if (fieldType != value.javaClass) {
if (Int::class.java == fieldType || Int::class.javaPrimitiveType == fieldType) {
result = value.toInt()
}
if (Double::class.java == fieldType || Double::class.javaPrimitiveType == fieldType) {
result = value.toDouble()
}
if (Float::class.java == fieldType || Float::class.javaPrimitiveType == fieldType) {
result = value.toFloat()
}
if (Date::class.java == fieldType) {
result = DateTime(value, "yyyy-MM-dd HH:mm:ss").date
}
}
return result
}
/**
* 格式化数值类型
*
* @param field 字段对象
* @param value 值
* @return double
*/
fun parseDigital(field: Field, value: String): Double {
val result: Double
val fieldType = field.type
result =
if (Int::class.java == fieldType || Int::class.javaPrimitiveType == fieldType || Double::class.java == fieldType || Double::class.javaPrimitiveType == fieldType || Float::class.java == fieldType || Float::class.javaPrimitiveType == fieldType) {
value.toDouble()
} else throw java.lang.RuntimeException(
String.format(
"class [%s] field [%s] is not digital",
field.declaringClass.name,
field.name
)
)
return result
}

View File

@@ -10,8 +10,8 @@ class FindMethodResolver(targetMethodName: String, clazz: Class<*>) : AbstractMe
@Suppress("UNCHECKED_CAST")
override fun mappingArguments(args: Array<Any>): Array<Any> {
val params = args[0] as Map<String, String>?
val specification = params.toSpecification(entityClazz)
return arrayOf(specification)
val specification = params?.toSpecification(entityClazz)
return if (specification != null) arrayOf(specification) else arrayOf()
}
override fun mappingResult(result: Any): Any {

View File

@@ -10,6 +10,7 @@ import org.springframework.data.domain.PageRequest
import org.springframework.data.domain.Pageable
import org.springframework.data.domain.Sort
import java.util.*
import com.synebula.gaea.query.Page as QueryPage
/**
* 分页方法参数映射
@@ -29,8 +30,7 @@ class PageMethodResolver(targetMethodName: String, clazz: Class<*>) : AbstractMe
val fields = entityClazz.declaredFields
for (field in fields) {
val isId = Arrays.stream(field.declaredAnnotations).anyMatch { annotation: Annotation ->
(annotation.annotationClass.java == Id::class.java
|| annotation.annotationClass.java == EmbeddedId::class.java)
(annotation.annotationClass.java == Id::class.java || annotation.annotationClass.java == EmbeddedId::class.java)
}
if (isId) {
sort = Sort.by(Sort.Direction.ASC, field.name)
@@ -50,7 +50,7 @@ class PageMethodResolver(targetMethodName: String, clazz: Class<*>) : AbstractMe
override fun mappingResult(result: Any): Any {
val page = result as Page<*>
// Page 页面从0开始
return com.synebula.gaea.query.Page(page.number + 1, page.size, page.totalElements.toInt(), page.content)
// Page 页面从0开始 [com.synebula.gaea.query.Page as QueryPage]
return QueryPage(page.number + 1, page.size, page.totalElements.toInt(), page.content)
}
}