fix spring jpa proxy bug

This commit is contained in:
2023-04-18 10:50:47 +08:00
parent 230ceea0fa
commit eff39eb7f8
3 changed files with 191 additions and 64 deletions

View File

@@ -3,10 +3,7 @@ package com.synebula.gaea.jpa
import com.synebula.gaea.data.date.DateTime import com.synebula.gaea.data.date.DateTime
import com.synebula.gaea.query.Operator import com.synebula.gaea.query.Operator
import com.synebula.gaea.query.Where import com.synebula.gaea.query.Where
import jakarta.persistence.criteria.CriteriaBuilder import jakarta.persistence.criteria.*
import jakarta.persistence.criteria.CriteriaQuery
import jakarta.persistence.criteria.Predicate
import jakarta.persistence.criteria.Root
import org.springframework.data.jpa.domain.Specification import org.springframework.data.jpa.domain.Specification
import java.lang.reflect.Field import java.lang.reflect.Field
import java.util.* import java.util.*
@@ -69,79 +66,209 @@ fun String.tryToDigital(field: Field): Double {
* @param clazz 类 * @param clazz 类
* @return Specification * @return Specification
*/ */
fun Map<String, String>?.toSpecification(clazz: Class<*>): Specification<*> { fun Map<String, String>.toSpecification(clazz: Class<*>): Specification<*> {
val rangeStartSuffix = "[0]" //范围查询开始后缀 val rangeStartSuffix = "[0]" //范围查询开始后缀
val rangeEndSuffix = "[1]" //范围查询结束后缀 val rangeEndSuffix = "[1]" //范围查询结束后缀
return Specification<Any?> { root: Root<Any?>, _: CriteriaQuery<*>?, criteriaBuilder: CriteriaBuilder -> return Specification<Any?> { root: Root<Any?>, _: CriteriaQuery<*>?, criteriaBuilder: CriteriaBuilder ->
val predicates: MutableList<Predicate> = ArrayList() val predicates = mutableListOf<Predicate>()
for (argumentName in this!!.keys) { for (argumentName in this.keys) {
if (this[argumentName] == null) continue
var fieldName = argumentName
var operator: Operator
// 判断是否为range类型(范围内查询)
var start = true
if (fieldName.endsWith(rangeStartSuffix) || fieldName.endsWith(rangeEndSuffix)) {
fieldName = fieldName.substring(fieldName.length - 3)
if (fieldName.endsWith(rangeEndSuffix)) start = false
}
val field = clazz.getDeclaredField(fieldName)
val where: Where = field.getDeclaredAnnotation(Where::class.java)
operator = where.operator
// 如果是范围内容, 判断是数值类型还是时间类型
if (operator === Operator.Range) {
if (clazz.getDeclaredField(fieldName).type != Date::class.java) {
operator = if (start) Operator.Gte else Operator.Lte
}
}
var predicate: Predicate
var digitalValue: Double
try { try {
var fieldName = argumentName
val fieldValue = this[argumentName]!!
var operator: Operator = Operator.Default
// 判断是否为range类型(范围内查询)
var start = true
if (fieldName.endsWith(rangeStartSuffix) || fieldName.endsWith(rangeEndSuffix)) {
fieldName = fieldName.substring(fieldName.length - 3)
if (fieldName.endsWith(rangeEndSuffix)) start = false
}
val fieldTree = fieldName.split("\\.".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()
//查找是否是嵌入字段, 找到最深的类型
var field: Field
if (fieldTree.isNotEmpty()) {
var hostClass = clazz //需要查找字段所在的class
var i = 0
do {
field = hostClass.getDeclaredField(fieldTree[i])
hostClass = field.type
i++
} while (i < fieldTree.size)
} else {
field = clazz.getDeclaredField(fieldName)
}
val where = field.getDeclaredAnnotation(Where::class.java)
if (where != null) operator = where.operator
// 如果是范围内容, 判断是数值类型还是时间类型
if (operator === Operator.Range) {
if (field.type != Date::class.java) {
operator = if (start) Operator.Gte else Operator.Lte
}
}
var predicate: Predicate
var digitalValue: Double
when (operator) { when (operator) {
Operator.Ne -> predicate = Operator.Ne -> predicate = criteriaBuilder.notEqual(
criteriaBuilder.notEqual(root.get<Any>(fieldName), this[fieldName]!!.toFieldType(field)) getFieldPath<Any>(root, fieldName),
typeConvert(field, fieldValue)
)
Operator.Lt -> { Operator.Lt -> try {
digitalValue = this[fieldName]!!.tryToDigital(field) digitalValue = parseDigital(field, fieldValue)
predicate = criteriaBuilder.lessThan(root.get(fieldName), digitalValue) predicate = criteriaBuilder.lessThan(getFieldPath(root, fieldName), digitalValue)
} catch (e: Exception) {
throw RuntimeException(
String.format(
"class [%s] field [%s] can not use annotation Where(Operator.lt)",
field.declaringClass.name,
field.name
), e
)
} }
Operator.Gt -> { Operator.Gt -> try {
digitalValue = this[fieldName]!!.tryToDigital(field) digitalValue = parseDigital(field, fieldValue)
predicate = criteriaBuilder.greaterThan(root.get(fieldName), digitalValue) predicate = criteriaBuilder.greaterThan(
getFieldPath(
root,
fieldName
), digitalValue
)
} catch (e: Exception) {
throw RuntimeException(
String.format(
"class [%s] field [%s] can not use annotation Where(Operator.gt)",
field.declaringClass.name,
field.name
), e
)
} }
Operator.Lte -> { Operator.Lte -> try {
digitalValue = this[fieldName]!!.tryToDigital(field) digitalValue = parseDigital(field, fieldValue)
predicate = criteriaBuilder.lessThanOrEqualTo(root.get(fieldName), digitalValue) predicate = criteriaBuilder.lessThanOrEqualTo(
getFieldPath(
root,
fieldName
), digitalValue
)
} catch (e: Exception) {
throw RuntimeException(
String.format(
"class [%s] field [%s] can not use annotation Where(Operator.lte)",
field.declaringClass.name,
field.name
), e
)
} }
Operator.Gte -> { Operator.Gte -> try {
digitalValue = this[fieldName]!!.tryToDigital(field) digitalValue = parseDigital(field, fieldValue)
predicate = criteriaBuilder.greaterThanOrEqualTo(root.get(fieldName), digitalValue) predicate = criteriaBuilder.greaterThanOrEqualTo(
getFieldPath(
root,
fieldName
), digitalValue
)
} catch (e: Exception) {
throw RuntimeException(
String.format(
"class [%s] field [%s] can not use annotation Where(Operator.gte)",
field.declaringClass.name,
field.name
), e
)
} }
Operator.Like -> predicate = criteriaBuilder.like(root.get(fieldName), "%${this[fieldName]}%") Operator.Like -> predicate = criteriaBuilder.like(
Operator.Range -> { getFieldPath(root, fieldName),
predicate = if (start) { String.format("%%%s%%", fieldValue)
criteriaBuilder.greaterThanOrEqualTo(root.get(fieldName), this[argumentName]!!) )
} else {
criteriaBuilder.lessThanOrEqualTo(root.get(fieldName), this[argumentName]!!)
}
}
else -> predicate = Operator.Range -> predicate = if (start) criteriaBuilder.greaterThanOrEqualTo(
criteriaBuilder.equal(root.get<Any>(fieldName), this[fieldName]!!.toFieldType(field)) getFieldPath(root, fieldName), this[argumentName]!!
) else criteriaBuilder.lessThanOrEqualTo(
getFieldPath(root, fieldName), this[argumentName]!!
)
else -> predicate = criteriaBuilder.equal(
getFieldPath<Any>(root, fieldName),
typeConvert(field, fieldValue)
)
} }
predicates.add(predicate) predicates.add(predicate)
} catch (e: NoSuchFieldException) { } catch (e: NoSuchFieldException) {
throw Error( throw RuntimeException(e)
"class [${field.declaringClass.name}] field [${field.name}] can't annotation [@Where(${operator.declaringJavaClass.simpleName}.${operator.name})]",
e
)
} }
} }
criteriaBuilder.and(*predicates.toTypedArray()) criteriaBuilder.and(*predicates.toTypedArray())
} }
} }
/**
* 获取字段在
*/
fun <Y> getFieldPath(root: Root<Any?>, field: String): Path<Y>? {
val fieldTree = field.split("\\.".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()
var path: Path<Y>
if (fieldTree.isNotEmpty()) {
path = root.get(fieldTree[0])
for (i in 1 until fieldTree.size) {
path = path.get(fieldTree[i])
}
return path
}
return root.get(field)
}
/**
* 类型转换
*
* @param field 字段对象
* @param value 值
* @return object
*/
fun typeConvert(field: Field, value: String): Any {
var result: Any = value
val fieldType = field.type
if (fieldType != value.javaClass) {
if (Int::class.java == fieldType || Int::class.javaPrimitiveType == fieldType) {
result = value.toInt()
}
if (Double::class.java == fieldType || Double::class.javaPrimitiveType == fieldType) {
result = value.toDouble()
}
if (Float::class.java == fieldType || Float::class.javaPrimitiveType == fieldType) {
result = value.toFloat()
}
if (Date::class.java == fieldType) {
result = DateTime(value, "yyyy-MM-dd HH:mm:ss").date
}
}
return result
}
/**
* 格式化数值类型
*
* @param field 字段对象
* @param value 值
* @return double
*/
fun parseDigital(field: Field, value: String): Double {
val result: Double
val fieldType = field.type
result =
if (Int::class.java == fieldType || Int::class.javaPrimitiveType == fieldType || Double::class.java == fieldType || Double::class.javaPrimitiveType == fieldType || Float::class.java == fieldType || Float::class.javaPrimitiveType == fieldType) {
value.toDouble()
} else throw java.lang.RuntimeException(
String.format(
"class [%s] field [%s] is not digital",
field.declaringClass.name,
field.name
)
)
return result
}

View File

@@ -10,8 +10,8 @@ class FindMethodResolver(targetMethodName: String, clazz: Class<*>) : AbstractMe
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
override fun mappingArguments(args: Array<Any>): Array<Any> { override fun mappingArguments(args: Array<Any>): Array<Any> {
val params = args[0] as Map<String, String>? val params = args[0] as Map<String, String>?
val specification = params.toSpecification(entityClazz) val specification = params?.toSpecification(entityClazz)
return arrayOf(specification) return if (specification != null) arrayOf(specification) else arrayOf()
} }
override fun mappingResult(result: Any): Any { override fun mappingResult(result: Any): Any {

View File

@@ -10,6 +10,7 @@ import org.springframework.data.domain.PageRequest
import org.springframework.data.domain.Pageable import org.springframework.data.domain.Pageable
import org.springframework.data.domain.Sort import org.springframework.data.domain.Sort
import java.util.* import java.util.*
import com.synebula.gaea.query.Page as QueryPage
/** /**
* 分页方法参数映射 * 分页方法参数映射
@@ -29,8 +30,7 @@ class PageMethodResolver(targetMethodName: String, clazz: Class<*>) : AbstractMe
val fields = entityClazz.declaredFields val fields = entityClazz.declaredFields
for (field in fields) { for (field in fields) {
val isId = Arrays.stream(field.declaredAnnotations).anyMatch { annotation: Annotation -> val isId = Arrays.stream(field.declaredAnnotations).anyMatch { annotation: Annotation ->
(annotation.annotationClass.java == Id::class.java (annotation.annotationClass.java == Id::class.java || annotation.annotationClass.java == EmbeddedId::class.java)
|| annotation.annotationClass.java == EmbeddedId::class.java)
} }
if (isId) { if (isId) {
sort = Sort.by(Sort.Direction.ASC, field.name) sort = Sort.by(Sort.Direction.ASC, field.name)
@@ -50,7 +50,7 @@ class PageMethodResolver(targetMethodName: String, clazz: Class<*>) : AbstractMe
override fun mappingResult(result: Any): Any { override fun mappingResult(result: Any): Any {
val page = result as Page<*> val page = result as Page<*>
// Page 页面从0开始 // Page 页面从0开始 [com.synebula.gaea.query.Page as QueryPage]
return com.synebula.gaea.query.Page(page.number + 1, page.size, page.totalElements.toInt(), page.content) return QueryPage(page.number + 1, page.size, page.totalElements.toInt(), page.content)
} }
} }