1.4.0 增加jpa的代理模块
This commit is contained in:
11
src/gaea.jpa/build.gradle
Normal file
11
src/gaea.jpa/build.gradle
Normal file
@@ -0,0 +1,11 @@
|
||||
ext {
|
||||
jassist_version = '3.29.0-GA'
|
||||
}
|
||||
|
||||
dependencies {
|
||||
api project(":src:gaea")
|
||||
|
||||
implementation("org.springframework.boot:spring-boot-starter-data-jpa:$spring_version")
|
||||
implementation("org.javassist:javassist:$jassist_version")
|
||||
}
|
||||
|
||||
67
src/gaea.jpa/src/main/java/com/synebula/gaea/jpa/JpaQuery.kt
Normal file
67
src/gaea.jpa/src/main/java/com/synebula/gaea/jpa/JpaQuery.kt
Normal file
@@ -0,0 +1,67 @@
|
||||
package com.synebula.gaea.jpa
|
||||
|
||||
import com.synebula.gaea.query.IQuery
|
||||
import com.synebula.gaea.query.Page
|
||||
import com.synebula.gaea.query.Params
|
||||
import org.springframework.data.jpa.repository.support.SimpleJpaRepository
|
||||
import javax.persistence.EntityManager
|
||||
|
||||
class JpaQuery<TView, ID>(override var clazz: Class<TView>, entityManager: EntityManager) : IQuery<TView, ID> {
|
||||
protected var repo: SimpleJpaRepository<TView, ID>
|
||||
|
||||
init {
|
||||
repo = SimpleJpaRepository<TView, ID>(clazz, entityManager)
|
||||
}
|
||||
|
||||
override operator fun get(id: ID): TView? {
|
||||
val view = this.repo.findById(id)
|
||||
return if (view.isPresent) view.get() else null
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 根据实体类条件查询所有符合条件记录
|
||||
*`
|
||||
* @param params 查询条件。
|
||||
* @return 视图列表
|
||||
*/
|
||||
override fun list(params: Map<String, String>?): List<TView> {
|
||||
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
|
||||
return emptyList()
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据条件查询符合条件记录的数量
|
||||
*
|
||||
* @param params 查询条件。
|
||||
* @return 数量
|
||||
*/
|
||||
override fun count(params: Map<String, String>?): Int {
|
||||
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
|
||||
return -1
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据实体类条件查询所有符合条件记录(分页查询)
|
||||
*
|
||||
* @param params 分页条件
|
||||
* @return 分页数据
|
||||
*/
|
||||
override fun paging(params: Params): Page<TView> {
|
||||
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
|
||||
return Page()
|
||||
}
|
||||
|
||||
/**
|
||||
* 查询条件范围内数据。
|
||||
* @param field 查询字段
|
||||
* @param params 查询条件
|
||||
*
|
||||
* @return 视图列表
|
||||
*/
|
||||
override fun range(field: String, params: List<Any>): List<TView> {
|
||||
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
|
||||
return emptyList()
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package com.synebula.gaea.jpa
|
||||
|
||||
import com.synebula.gaea.domain.model.IAggregateRoot
|
||||
import com.synebula.gaea.domain.repository.IRepository
|
||||
import org.springframework.data.jpa.repository.JpaRepository
|
||||
import org.springframework.data.jpa.repository.support.SimpleJpaRepository
|
||||
import javax.persistence.EntityManager
|
||||
|
||||
|
||||
class JpaRepository<TAggregateRoot : IAggregateRoot<ID>, ID>(
|
||||
override var clazz: Class<TAggregateRoot>,
|
||||
entityManager: EntityManager
|
||||
) : IRepository<TAggregateRoot, ID> {
|
||||
protected var repo: JpaRepository<TAggregateRoot, ID>? = null
|
||||
|
||||
init {
|
||||
repo = SimpleJpaRepository(clazz, entityManager)
|
||||
}
|
||||
|
||||
/**
|
||||
* 插入单个对象。
|
||||
*
|
||||
* @param obj 需要插入的对象。
|
||||
* @return 返回原对象,如果对象ID为自增,则补充自增ID。
|
||||
*/
|
||||
override fun add(obj: TAggregateRoot) {
|
||||
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
|
||||
}
|
||||
|
||||
/**
|
||||
* 插入多个个对象。
|
||||
*
|
||||
* @param list 需要插入的对象。
|
||||
* @return 返回原对象,如果对象ID为自增,则补充自增ID。
|
||||
*/
|
||||
override fun add(list: List<TAggregateRoot>) {
|
||||
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
|
||||
}
|
||||
|
||||
/**
|
||||
* 更新对象。
|
||||
*
|
||||
* @param obj 需要更新的对象。
|
||||
* @return
|
||||
*/
|
||||
override fun update(obj: TAggregateRoot) {
|
||||
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
|
||||
}
|
||||
|
||||
/**
|
||||
* 更新多个个对象。
|
||||
*
|
||||
* @param list 需要新的对象。
|
||||
*/
|
||||
override fun update(list: List<TAggregateRoot>) {
|
||||
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过id删除该条数据
|
||||
*
|
||||
* @param id 对象ID。
|
||||
* @return
|
||||
*/
|
||||
override fun remove(id: ID) {
|
||||
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据ID获取对象。
|
||||
*
|
||||
* @param id 对象ID。
|
||||
* @return
|
||||
*/
|
||||
override fun get(id: ID): TAggregateRoot? {
|
||||
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据条件查询符合条件记录的数量
|
||||
*
|
||||
* @param params 查询条件。
|
||||
* @return int
|
||||
*/
|
||||
override fun count(params: Map<String, String>?): Int {
|
||||
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
|
||||
return -1
|
||||
}
|
||||
|
||||
}
|
||||
147
src/gaea.jpa/src/main/java/com/synebula/gaea/jpa/Jpas.kt
Normal file
147
src/gaea.jpa/src/main/java/com/synebula/gaea/jpa/Jpas.kt
Normal file
@@ -0,0 +1,147 @@
|
||||
package com.synebula.gaea.jpa
|
||||
|
||||
import com.synebula.gaea.data.date.DateTime
|
||||
import com.synebula.gaea.query.Operator
|
||||
import com.synebula.gaea.query.Where
|
||||
import org.springframework.data.jpa.domain.Specification
|
||||
import java.lang.reflect.Field
|
||||
import java.util.*
|
||||
import javax.persistence.criteria.CriteriaBuilder
|
||||
import javax.persistence.criteria.CriteriaQuery
|
||||
import javax.persistence.criteria.Predicate
|
||||
import javax.persistence.criteria.Root
|
||||
|
||||
|
||||
/**
|
||||
* 类型转换
|
||||
*
|
||||
* @param field 字段对象
|
||||
* @return object
|
||||
*/
|
||||
fun String.toFieldType(field: Field): Any? {
|
||||
var result: Any? = this
|
||||
val fieldType = field.type
|
||||
if (fieldType != this.javaClass) {
|
||||
if (Int::class.java == fieldType || Int::class.javaPrimitiveType == fieldType) {
|
||||
result = this.toInt()
|
||||
}
|
||||
if (Double::class.java == fieldType || Double::class.javaPrimitiveType == fieldType) {
|
||||
result = this.toDouble()
|
||||
}
|
||||
if (Float::class.java == fieldType || Float::class.javaPrimitiveType == fieldType) {
|
||||
result = this.toFloat()
|
||||
}
|
||||
if (Date::class.java == fieldType) {
|
||||
result = DateTime(this, "yyyy-MM-dd HH:mm:ss").date
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* 格式化数值类型
|
||||
*
|
||||
* @param field 字段对象
|
||||
* @return double
|
||||
*/
|
||||
fun String.tryToDigital(field: Field): Double {
|
||||
val result: Double
|
||||
val fieldType = field.type
|
||||
result =
|
||||
if (Int::class.java == fieldType || Int::class.javaPrimitiveType == fieldType
|
||||
|| Double::class.java == fieldType || Double::class.javaPrimitiveType == fieldType
|
||||
|| Float::class.java == fieldType || Float::class.javaPrimitiveType == fieldType
|
||||
) {
|
||||
this.toDouble()
|
||||
} else throw RuntimeException(
|
||||
String.format(
|
||||
"class [%s] field [%s] is not digital",
|
||||
field.declaringClass.name,
|
||||
field.name
|
||||
)
|
||||
)
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* 参数 Map 转换查询 Specification
|
||||
*
|
||||
* @param clazz 类
|
||||
* @return Specification
|
||||
*/
|
||||
fun Map<String, String>?.toSpecification(clazz: Class<*>): Specification<*> {
|
||||
val rangeStartSuffix = "[0]" //范围查询开始后缀
|
||||
val rangeEndSuffix = "[1]" //范围查询结束后缀
|
||||
return Specification<Any?> { root: Root<Any?>, _: CriteriaQuery<*>?, criteriaBuilder: CriteriaBuilder ->
|
||||
val predicates: MutableList<Predicate> = ArrayList()
|
||||
for (argumentName in this!!.keys) {
|
||||
if (this[argumentName] == null) continue
|
||||
var fieldName = argumentName
|
||||
var operator: Operator
|
||||
|
||||
// 判断是否为range类型(范围内查询)
|
||||
var start = true
|
||||
if (fieldName.endsWith(rangeStartSuffix) || fieldName.endsWith(rangeEndSuffix)) {
|
||||
fieldName = fieldName.substring(fieldName.length - 3)
|
||||
if (fieldName.endsWith(rangeEndSuffix)) start = false
|
||||
}
|
||||
val field = clazz.getDeclaredField(fieldName)
|
||||
val where: Where = field.getDeclaredAnnotation(Where::class.java)
|
||||
operator = where.operator
|
||||
|
||||
// 如果是范围内容, 判断是数值类型还是时间类型
|
||||
if (operator === Operator.Range) {
|
||||
if (clazz.getDeclaredField(fieldName).type != Date::class.java) {
|
||||
operator = if (start) Operator.Gte else Operator.Lte
|
||||
}
|
||||
}
|
||||
var predicate: Predicate
|
||||
var digitalValue: Double
|
||||
try {
|
||||
when (operator) {
|
||||
Operator.Ne -> predicate =
|
||||
criteriaBuilder.notEqual(root.get<Any>(fieldName), this[fieldName]!!.toFieldType(field))
|
||||
|
||||
Operator.Lt -> {
|
||||
digitalValue = this[fieldName]!!.tryToDigital(field)
|
||||
predicate = criteriaBuilder.lessThan(root.get(fieldName), digitalValue)
|
||||
}
|
||||
|
||||
Operator.Gt -> {
|
||||
digitalValue = this[fieldName]!!.tryToDigital(field)
|
||||
predicate = criteriaBuilder.greaterThan(root.get(fieldName), digitalValue)
|
||||
}
|
||||
|
||||
Operator.Lte -> {
|
||||
digitalValue = this[fieldName]!!.tryToDigital(field)
|
||||
predicate = criteriaBuilder.lessThanOrEqualTo(root.get(fieldName), digitalValue)
|
||||
}
|
||||
|
||||
Operator.Gte -> {
|
||||
digitalValue = this[fieldName]!!.tryToDigital(field)
|
||||
predicate = criteriaBuilder.greaterThanOrEqualTo(root.get(fieldName), digitalValue)
|
||||
}
|
||||
|
||||
Operator.Like -> predicate = criteriaBuilder.like(root.get(fieldName), "%${this[fieldName]}%")
|
||||
Operator.Range -> {
|
||||
predicate = if (start) {
|
||||
criteriaBuilder.greaterThanOrEqualTo(root.get(fieldName), this[argumentName]!!)
|
||||
} else {
|
||||
criteriaBuilder.lessThanOrEqualTo(root.get(fieldName), this[argumentName]!!)
|
||||
}
|
||||
}
|
||||
|
||||
else -> predicate =
|
||||
criteriaBuilder.equal(root.get<Any>(fieldName), this[fieldName]!!.toFieldType(field))
|
||||
}
|
||||
predicates.add(predicate)
|
||||
} catch (e: NoSuchFieldException) {
|
||||
throw Error(
|
||||
"class [${field.declaringClass.name}] field [${field.name}] can't annotation [@Where(${operator.declaringClass.simpleName}.${operator.name})]",
|
||||
e
|
||||
)
|
||||
}
|
||||
}
|
||||
criteriaBuilder.and(*predicates.toTypedArray())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
package com.synebula.gaea.jpa.proxy
|
||||
|
||||
import org.springframework.beans.factory.BeanFactory
|
||||
import org.springframework.beans.factory.FactoryBean
|
||||
import org.springframework.cglib.proxy.Enhancer
|
||||
import org.springframework.data.repository.Repository
|
||||
|
||||
class JpaRepositoryFactory(
|
||||
private val beanFactory: BeanFactory,
|
||||
private val interfaceType: Class<*>,
|
||||
private val implBeanNames: List<String>
|
||||
) : FactoryBean<Any> {
|
||||
override fun getObject(): Any {
|
||||
val handler: JpaRepositoryProxy<*, *, *> = JpaRepositoryProxy<Repository<Any, Any>, Any, Any>(
|
||||
beanFactory,
|
||||
interfaceType, implBeanNames
|
||||
)
|
||||
|
||||
//JDK 方式代理代码, 暂时选用cglib
|
||||
//Object proxy = Proxy.newProxyInstance(this.interfaceType.getClassLoader(), new Class[]{this.interfaceType}, handler);
|
||||
|
||||
//cglib代理
|
||||
val enhancer = Enhancer()
|
||||
enhancer.setSuperclass(interfaceType)
|
||||
enhancer.setCallback(handler)
|
||||
return enhancer.create()
|
||||
}
|
||||
|
||||
override fun getObjectType(): Class<*> {
|
||||
return interfaceType
|
||||
}
|
||||
|
||||
override fun isSingleton(): Boolean {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,289 @@
|
||||
package com.synebula.gaea.jpa.proxy
|
||||
|
||||
import com.synebula.gaea.jpa.proxy.method.JpaMethodProxy
|
||||
import javassist.*
|
||||
import javassist.bytecode.AnnotationsAttribute
|
||||
import javassist.bytecode.MethodInfo
|
||||
import javassist.bytecode.SignatureAttribute
|
||||
import javassist.bytecode.annotation.Annotation
|
||||
import javassist.bytecode.annotation.BooleanMemberValue
|
||||
import javassist.bytecode.annotation.StringMemberValue
|
||||
import org.springframework.beans.BeansException
|
||||
import org.springframework.beans.factory.BeanFactory
|
||||
import org.springframework.beans.factory.ObjectProvider
|
||||
import org.springframework.cglib.proxy.MethodInterceptor
|
||||
import org.springframework.cglib.proxy.MethodProxy
|
||||
import org.springframework.data.jpa.repository.Modifying
|
||||
import org.springframework.data.jpa.repository.Query
|
||||
import org.springframework.data.jpa.repository.support.JpaRepositoryFactoryBean
|
||||
import org.springframework.data.jpa.repository.support.JpaRepositoryImplementation
|
||||
import org.springframework.data.mapping.context.MappingContext
|
||||
import org.springframework.data.querydsl.EntityPathResolver
|
||||
import org.springframework.data.querydsl.SimpleEntityPathResolver
|
||||
import org.springframework.data.repository.Repository
|
||||
import java.lang.reflect.Method
|
||||
import java.lang.reflect.ParameterizedType
|
||||
import java.lang.reflect.Type
|
||||
import javax.persistence.EntityManager
|
||||
|
||||
class JpaRepositoryProxy<T : Repository<S, ID>?, S, ID>(
|
||||
beanFactory: BeanFactory,
|
||||
interfaceType: Class<*>,
|
||||
implementBeanNames: List<String>?
|
||||
) : MethodInterceptor { //InvocationHandler {
|
||||
|
||||
//JPA 默认Entity Manager上下文, 如不用该上下文则没有事务管理器
|
||||
private val EntityManagerName = "org.springframework.orm.jpa.SharedEntityManagerCreator#0"
|
||||
|
||||
/**
|
||||
* JPA代理对象
|
||||
*/
|
||||
private var jpaRepository: JpaRepositoryImplementation<*, *>? = null
|
||||
|
||||
/**
|
||||
* bean注册器
|
||||
*/
|
||||
protected var beanFactory: BeanFactory
|
||||
|
||||
/**
|
||||
* 方法映射管理器
|
||||
*/
|
||||
protected var jpaMethodProxy: JpaMethodProxy
|
||||
|
||||
/**
|
||||
* 需要代理的接口类型
|
||||
*/
|
||||
protected var interfaceType: Class<*>
|
||||
|
||||
/**
|
||||
* 接口实现bean名称
|
||||
*/
|
||||
protected var implementBeanNames: List<String>
|
||||
|
||||
init {
|
||||
try {
|
||||
this.beanFactory = beanFactory
|
||||
this.implementBeanNames = implementBeanNames ?: listOf()
|
||||
this.interfaceType = interfaceType
|
||||
|
||||
// 设置方法映射查询参数类(Entity类型)
|
||||
val type = interfaceType.genericInterfaces[0]
|
||||
val typeArguments = (type as ParameterizedType).actualTypeArguments
|
||||
jpaMethodProxy = JpaMethodProxy(Class.forName(typeArguments[0].typeName))
|
||||
|
||||
|
||||
// 创建虚假的JpaRepository接口
|
||||
val repoClazz = createJpaRepoClazz(*typeArguments)
|
||||
val jpaRepositoryFactoryBean: JpaRepositoryFactoryBean<*, *, *> = createJPARepositoryFactoryBean(repoClazz)
|
||||
jpaRepository = jpaRepositoryFactoryBean.getObject() as JpaRepositoryImplementation<*, *>
|
||||
} catch (e: ClassNotFoundException) {
|
||||
throw RuntimeException(e)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* JDK 方式代理代码, 暂时选用cglib
|
||||
*/
|
||||
@Deprecated("")
|
||||
@Throws(Throwable::class)
|
||||
operator fun invoke(proxy: Any?, method: Method, args: Array<Any>): Any? {
|
||||
return if (Any::class.java == method.declaringClass) {
|
||||
method.invoke(proxy, *args)
|
||||
} else {
|
||||
execMethod(method, args)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 暂时选用cglib 方式代理代码
|
||||
*/
|
||||
@Throws(Throwable::class)
|
||||
override fun intercept(o: Any, method: Method, args: Array<Any>, methodProxy: MethodProxy): Any {
|
||||
return if (Any::class.java == method.declaringClass) {
|
||||
methodProxy.invoke(this, args)
|
||||
} else {
|
||||
execMethod(method, args)!!
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行代理方法
|
||||
*
|
||||
* @param method 需要执行的方法
|
||||
* @param args 参数列表
|
||||
* @return 方法执行结果
|
||||
* @throws Throwable 异常
|
||||
*/
|
||||
@Throws(Throwable::class)
|
||||
private fun execMethod(method: Method, args: Array<Any>): Any? {
|
||||
// 找到对应代理方法, 代理执行
|
||||
return if (jpaMethodProxy.match(method)) {
|
||||
try {
|
||||
jpaMethodProxy.proxyExecMethod(jpaRepository, method, args)
|
||||
} catch (ex: Exception) {
|
||||
throw RuntimeException(
|
||||
String.format(
|
||||
"对象[%s]代理执行方法[%s.%s]出错",
|
||||
jpaMethodProxy.javaClass, interfaceType.name, method.name
|
||||
), ex
|
||||
)
|
||||
}
|
||||
} else {
|
||||
// 找不到代理方法则查找具体实现类执行
|
||||
if (implementBeanNames.isEmpty()) throw RuntimeException(
|
||||
String.format(
|
||||
"找不到[%s.%s]对应的代理方法",
|
||||
method.declaringClass.name, method.name
|
||||
)
|
||||
) else {
|
||||
val bean = beanFactory.getBean(implementBeanNames[0])
|
||||
val proxyMethod = bean.javaClass.getMethod(method.name, *method.parameterTypes)
|
||||
proxyMethod.invoke(bean, *args)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用javassist创建虚拟的jpa repo类
|
||||
*
|
||||
* @param typeArgs 泛型参数
|
||||
* @return 虚拟的jpa repo类形
|
||||
*/
|
||||
@Suppress("unchecked_cast")
|
||||
private fun createJpaRepoClazz(vararg typeArgs: Type): Class<T> {
|
||||
return try {
|
||||
val pool = ClassPool.getDefault()
|
||||
val jpaRepoCt = pool[JpaRepositoryImplementation::class.java.name]
|
||||
val clazzName = String.format("%sRepository", typeArgs[0].typeName)
|
||||
val repoCt = pool.makeInterface(clazzName, jpaRepoCt)
|
||||
val typeArguments = arrayOfNulls<SignatureAttribute.TypeArgument>(typeArgs.size)
|
||||
for (i in typeArgs.indices) {
|
||||
typeArguments[i] = SignatureAttribute.TypeArgument(SignatureAttribute.ClassType(typeArgs[i].typeName))
|
||||
}
|
||||
val ac = SignatureAttribute.ClassSignature(
|
||||
null,
|
||||
null,
|
||||
arrayOf(SignatureAttribute.ClassType(jpaRepoCt.name, typeArguments))
|
||||
)
|
||||
repoCt.genericSignature = ac.encode()
|
||||
addClassQueryMethod(repoCt)
|
||||
repoCt.toClass() as Class<T>
|
||||
} catch (ex: Exception) {
|
||||
throw RuntimeException(ex)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 给虚拟接口添加Query注解方法
|
||||
*
|
||||
* @param ctClass jpa虚拟接口
|
||||
*/
|
||||
@Throws(NotFoundException::class, CannotCompileException::class, ClassNotFoundException::class)
|
||||
private fun addClassQueryMethod(ctClass: CtClass) {
|
||||
// 找到Query注解方法并加入到虚拟接口中
|
||||
val interfaceCtClass = ClassPool.getDefault()[interfaceType.name]
|
||||
for (ctMethod in interfaceCtClass.methods) {
|
||||
val query = ctMethod.getAnnotation(Query::class.java)
|
||||
if (query != null) {
|
||||
val method = CtNewMethod.abstractMethod(
|
||||
ctMethod.returnType, ctMethod.name,
|
||||
ctMethod.parameterTypes, arrayOfNulls(0), ctClass
|
||||
)
|
||||
val methodInfo = method.methodInfo
|
||||
var modifying: Modifying? = null
|
||||
//查找有无@Modifing注解,有的化虚拟接口也需要加上
|
||||
val annotation = ctMethod.getAnnotation(Modifying::class.java)
|
||||
if (annotation != null) {
|
||||
modifying = annotation as Modifying
|
||||
}
|
||||
|
||||
// 增加Query注解
|
||||
val attribute = buildQueryAttribute(methodInfo, query as Query, modifying)
|
||||
methodInfo.addAttribute(attribute)
|
||||
ctClass.addMethod(method)
|
||||
}
|
||||
}
|
||||
|
||||
// Query注解方法加入到代理中
|
||||
for (method in interfaceType.methods) {
|
||||
val annotation: Any? = method.getAnnotation(Query::class.java)
|
||||
if (annotation != null) {
|
||||
jpaMethodProxy.addQueryMethodMapper(method)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建javassist方法Query注解
|
||||
*
|
||||
* @param methodInfo 方法信息
|
||||
* @param query query注解实例
|
||||
* @param modifying Modifying注解
|
||||
* @return 注解信息
|
||||
*/
|
||||
private fun buildQueryAttribute(methodInfo: MethodInfo, query: Query, modifying: Modifying?): AnnotationsAttribute {
|
||||
val cp = methodInfo.constPool
|
||||
val attribute = AnnotationsAttribute(cp, AnnotationsAttribute.visibleTag)
|
||||
val queryAnnotation = Annotation(Query::class.java.name, cp)
|
||||
queryAnnotation.addMemberValue("value", StringMemberValue(query.value, cp))
|
||||
queryAnnotation.addMemberValue("countQuery", StringMemberValue(query.countQuery, cp))
|
||||
queryAnnotation.addMemberValue("countProjection", StringMemberValue(query.countProjection, cp))
|
||||
queryAnnotation.addMemberValue("nativeQuery", BooleanMemberValue(query.nativeQuery, cp))
|
||||
queryAnnotation.addMemberValue("name", StringMemberValue(query.name, cp))
|
||||
queryAnnotation.addMemberValue("countName", StringMemberValue(query.countName, cp))
|
||||
if (modifying != null) {
|
||||
val modifyingAnnotation = Annotation(
|
||||
Modifying::class.java.name, cp
|
||||
)
|
||||
modifyingAnnotation.addMemberValue(
|
||||
"flushAutomatically",
|
||||
BooleanMemberValue(modifying.flushAutomatically, cp)
|
||||
)
|
||||
modifyingAnnotation.addMemberValue(
|
||||
"clearAutomatically",
|
||||
BooleanMemberValue(modifying.clearAutomatically, cp)
|
||||
)
|
||||
attribute.annotations = arrayOf(queryAnnotation, modifyingAnnotation)
|
||||
} else attribute.setAnnotation(queryAnnotation)
|
||||
return attribute
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建JpaRepositoryFactoryBean对象
|
||||
*
|
||||
* @param jpaRepositoryClass 需要创建的 JPA JpaRepository Class
|
||||
* @return JpaRepositoryFactoryBean对象
|
||||
*/
|
||||
private fun createJPARepositoryFactoryBean(jpaRepositoryClass: Class<out T>): JpaRepositoryFactoryBean<T, S, ID> {
|
||||
// jpa 默认使用改名成EntityManager, 若用默认则没有事务上下文
|
||||
val entityManager = beanFactory.getBean(EntityManagerName) as EntityManager
|
||||
val repositoryFactoryBean = JpaRepositoryFactoryBean(jpaRepositoryClass)
|
||||
repositoryFactoryBean.setEntityManager(entityManager)
|
||||
repositoryFactoryBean.setBeanFactory(beanFactory)
|
||||
repositoryFactoryBean.setBeanClassLoader(JpaRepositoryFactoryBean::class.java.classLoader)
|
||||
repositoryFactoryBean.setMappingContext(beanFactory.getBean("jpaMappingContext") as MappingContext<*, *>)
|
||||
repositoryFactoryBean.setEntityPathResolver(object : ObjectProvider<EntityPathResolver> {
|
||||
@Throws(BeansException::class)
|
||||
override fun getObject(vararg objects: Any): EntityPathResolver {
|
||||
return SimpleEntityPathResolver("")
|
||||
}
|
||||
|
||||
@Throws(BeansException::class)
|
||||
override fun getIfAvailable(): EntityPathResolver? {
|
||||
return null
|
||||
}
|
||||
|
||||
@Throws(BeansException::class)
|
||||
override fun getIfUnique(): EntityPathResolver? {
|
||||
return null
|
||||
}
|
||||
|
||||
@Throws(BeansException::class)
|
||||
override fun getObject(): EntityPathResolver {
|
||||
return SimpleEntityPathResolver("")
|
||||
}
|
||||
})
|
||||
repositoryFactoryBean.afterPropertiesSet()
|
||||
return repositoryFactoryBean
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package com.synebula.gaea.jpa.proxy
|
||||
|
||||
import org.springframework.context.annotation.Import
|
||||
import java.lang.annotation.Inherited
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
@Target(AnnotationTarget.ANNOTATION_CLASS, AnnotationTarget.CLASS)
|
||||
@Retention(AnnotationRetention.RUNTIME)
|
||||
@MustBeDocumented
|
||||
@Inherited
|
||||
@Import(JpaRepositoryRegister::class)
|
||||
annotation class JpaRepositoryProxyScan(val basePackages: Array<String> = [], val scanInterfaces: Array<KClass<*>> = [])
|
||||
@@ -0,0 +1,141 @@
|
||||
package com.synebula.gaea.jpa.proxy
|
||||
|
||||
import org.springframework.beans.BeansException
|
||||
import org.springframework.beans.factory.BeanClassLoaderAware
|
||||
import org.springframework.beans.factory.BeanFactory
|
||||
import org.springframework.beans.factory.BeanFactoryAware
|
||||
import org.springframework.beans.factory.annotation.AnnotatedBeanDefinition
|
||||
import org.springframework.beans.factory.config.BeanDefinition
|
||||
import org.springframework.beans.factory.support.BeanDefinitionBuilder
|
||||
import org.springframework.beans.factory.support.BeanDefinitionRegistry
|
||||
import org.springframework.beans.factory.support.GenericBeanDefinition
|
||||
import org.springframework.context.EnvironmentAware
|
||||
import org.springframework.context.ResourceLoaderAware
|
||||
import org.springframework.context.annotation.ClassPathScanningCandidateComponentProvider
|
||||
import org.springframework.context.annotation.ImportBeanDefinitionRegistrar
|
||||
import org.springframework.core.annotation.AnnotationAttributes
|
||||
import org.springframework.core.env.Environment
|
||||
import org.springframework.core.io.ResourceLoader
|
||||
import org.springframework.core.type.AnnotationMetadata
|
||||
import org.springframework.core.type.classreading.MetadataReader
|
||||
import org.springframework.core.type.classreading.MetadataReaderFactory
|
||||
import org.springframework.core.type.filter.TypeFilter
|
||||
import org.springframework.util.ClassUtils
|
||||
import java.util.*
|
||||
import java.util.stream.Collectors
|
||||
|
||||
class JpaRepositoryRegister : ImportBeanDefinitionRegistrar, ResourceLoaderAware, BeanClassLoaderAware,
|
||||
EnvironmentAware,
|
||||
BeanFactoryAware {
|
||||
private lateinit var environment: Environment
|
||||
private lateinit var resourceLoader: ResourceLoader
|
||||
private var classLoader: ClassLoader? = null
|
||||
private var beanFactory: BeanFactory? = null
|
||||
override fun registerBeanDefinitions(metadata: AnnotationMetadata, registry: BeanDefinitionRegistry) {
|
||||
val attributes = AnnotationAttributes(
|
||||
metadata.getAnnotationAttributes(
|
||||
JpaRepositoryProxyScan::class.java.name
|
||||
) ?: mapOf()
|
||||
)
|
||||
val basePackages = attributes.getStringArray("basePackages")
|
||||
val scanInterfaces = attributes.getClassArray("scanInterfaces")
|
||||
// 过滤scanInterfaces接口内容
|
||||
val filter = getSubObjectTypeFilter(scanInterfaces)
|
||||
val beanDefinitions = scan(basePackages, arrayOf(filter))
|
||||
|
||||
// 遍历处理接口
|
||||
for (beanDefinition in beanDefinitions) {
|
||||
// 获取RepositoryFor注解信息
|
||||
val beanClazz: Class<*> = try {
|
||||
Class.forName(beanDefinition.beanClassName)
|
||||
} catch (e: ClassNotFoundException) {
|
||||
throw RuntimeException(e)
|
||||
}
|
||||
val beanClazzTypeFilter = getSubObjectTypeFilter(arrayOf(beanClazz))
|
||||
val implClazzDefinitions = scan(basePackages, arrayOf(beanClazzTypeFilter))
|
||||
for (definition in implClazzDefinitions) {
|
||||
definition.isAutowireCandidate = false
|
||||
registry.registerBeanDefinition(Objects.requireNonNull(definition.beanClassName), definition)
|
||||
}
|
||||
// 构建bean定义
|
||||
// 1 bean参数
|
||||
val implBeanNames = implClazzDefinitions.stream().map { obj: BeanDefinition -> obj.beanClassName }
|
||||
.collect(Collectors.toList())
|
||||
val builder = BeanDefinitionBuilder.genericBeanDefinition(beanClazz)
|
||||
builder.addConstructorArgValue(beanFactory)
|
||||
builder.addConstructorArgValue(beanClazz)
|
||||
builder.addConstructorArgValue(implBeanNames)
|
||||
val definition = builder.rawBeanDefinition as GenericBeanDefinition
|
||||
definition.beanClass = JpaRepositoryFactory::class.java
|
||||
definition.autowireMode = GenericBeanDefinition.AUTOWIRE_BY_TYPE
|
||||
registry.registerBeanDefinition(beanClazz.name, definition)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据过滤器扫描直接包下bean
|
||||
*
|
||||
* @param packages 指定的扫描包
|
||||
* @param filters 过滤器
|
||||
* @return 扫描后的bean定义
|
||||
*/
|
||||
private fun scan(packages: Array<String>?, filters: Array<TypeFilter>): List<BeanDefinition> {
|
||||
val scanner: ClassPathScanningCandidateComponentProvider =
|
||||
object : ClassPathScanningCandidateComponentProvider() {
|
||||
override fun isCandidateComponent(beanDefinition: AnnotatedBeanDefinition): Boolean {
|
||||
try {
|
||||
val metadata = beanDefinition.metadata
|
||||
val target = ClassUtils.forName(metadata.className, classLoader)
|
||||
return !target.isAnnotation
|
||||
} catch (ignored: Exception) {
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
scanner.environment = environment
|
||||
scanner.resourceLoader = resourceLoader
|
||||
for (filter in filters) {
|
||||
scanner.addIncludeFilter(filter)
|
||||
}
|
||||
val beanDefinitions: MutableList<BeanDefinition> = LinkedList()
|
||||
for (basePackage in packages!!) {
|
||||
beanDefinitions.addAll(scanner.findCandidateComponents(basePackage))
|
||||
}
|
||||
return beanDefinitions
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取父接口实现对象的类型过滤器
|
||||
*
|
||||
* @param interfaces 父接口
|
||||
* @return 类型过滤器
|
||||
*/
|
||||
private fun getSubObjectTypeFilter(interfaces: Array<Class<*>>?): TypeFilter {
|
||||
return TypeFilter { metadataReader: MetadataReader, _: MetadataReaderFactory? ->
|
||||
val interfaceNames = metadataReader.classMetadata.interfaceNames
|
||||
var matched = false
|
||||
for (interfaceName in interfaceNames) {
|
||||
matched = Arrays.stream(interfaces)
|
||||
.anyMatch { clazz: Class<*> -> clazz.name == interfaceName }
|
||||
}
|
||||
matched
|
||||
}
|
||||
}
|
||||
|
||||
override fun setResourceLoader(resourceLoader: ResourceLoader) {
|
||||
this.resourceLoader = resourceLoader
|
||||
}
|
||||
|
||||
override fun setBeanClassLoader(classLoader: ClassLoader) {
|
||||
this.classLoader = classLoader
|
||||
}
|
||||
|
||||
override fun setEnvironment(environment: Environment) {
|
||||
this.environment = environment
|
||||
}
|
||||
|
||||
@Throws(BeansException::class)
|
||||
override fun setBeanFactory(beanFactory: BeanFactory) {
|
||||
this.beanFactory = beanFactory
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
package com.synebula.gaea.jpa.proxy.method
|
||||
|
||||
import com.synebula.gaea.domain.model.IAggregateRoot
|
||||
import com.synebula.gaea.jpa.proxy.method.resolver.AbstractMethodResolver
|
||||
import com.synebula.gaea.jpa.proxy.method.resolver.DefaultMethodResolver
|
||||
import com.synebula.gaea.jpa.proxy.method.resolver.FindMethodResolver
|
||||
import com.synebula.gaea.jpa.proxy.method.resolver.PageMethodResolver
|
||||
import com.synebula.gaea.query.Params
|
||||
import org.springframework.data.domain.Pageable
|
||||
import org.springframework.data.jpa.domain.Specification
|
||||
import java.lang.reflect.InvocationTargetException
|
||||
import java.lang.reflect.Method
|
||||
|
||||
/**
|
||||
* Jpa 方法映射包装类
|
||||
*/
|
||||
class JpaMethodProxy(
|
||||
/**
|
||||
* 方法需要实现的实体类
|
||||
*/
|
||||
private var entityClazz: Class<*>
|
||||
) {
|
||||
/**
|
||||
* 默认的方法映射配置(IRepository, IQuery 接口中定义的方法)
|
||||
*/
|
||||
private val defaultMethodMapper: MutableMap<String, AbstractMethodResolver?> = LinkedHashMap()
|
||||
|
||||
/**
|
||||
* 用户自定义的query注解方法处理
|
||||
*/
|
||||
private val queryMethodMapper: MutableMap<String, AbstractMethodResolver> = LinkedHashMap()
|
||||
|
||||
/**
|
||||
* 方法参数映射
|
||||
*/
|
||||
var argumentResolver: AbstractMethodResolver? = null
|
||||
private set
|
||||
|
||||
init {
|
||||
initDefaultMethodMapper()
|
||||
}
|
||||
|
||||
/**
|
||||
* 匹配方法是否需要代理
|
||||
*
|
||||
* @param method 方法
|
||||
* @return ture/false
|
||||
*/
|
||||
fun match(method: Method): Boolean {
|
||||
var isMatch = (defaultMethodMapper.containsKey(method.name)
|
||||
&& defaultMethodMapper[method.name]!!.match(method, AbstractMethodResolver.MethodType.SourceMethod))
|
||||
|
||||
// 如果默认代理方法没有匹配,则查找Query方法映射
|
||||
if (!isMatch) {
|
||||
isMatch = queryMethodMapper.containsKey(method.toString())
|
||||
}
|
||||
return isMatch
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析代理方法
|
||||
*
|
||||
* @param proxy 代理对象
|
||||
* @param method 源方法
|
||||
* @param args 参数列表
|
||||
* @return 执行结果
|
||||
*/
|
||||
@Throws(NoSuchMethodException::class, InvocationTargetException::class, IllegalAccessException::class)
|
||||
fun proxyExecMethod(proxy: Any?, method: Method, args: Array<Any>): Any? {
|
||||
// 匹配方法是否需要代理
|
||||
if (defaultMethodMapper.containsKey(method.name)) {
|
||||
val resolver = defaultMethodMapper[method.name]
|
||||
// 匹配参数是否相同
|
||||
if (resolver!!.match(method, AbstractMethodResolver.MethodType.SourceMethod)) {
|
||||
//遍历代理对象, 找到合适的代理方法
|
||||
val targetMethod =
|
||||
proxy!!.javaClass.getMethod(resolver.targetMethodName, *resolver.targetMethodParameters)
|
||||
return try {
|
||||
// 开始执行代理方法
|
||||
val mappingArguments = resolver.mappingArguments(args)
|
||||
val result = targetMethod.invoke(proxy, *mappingArguments)
|
||||
resolver.mappingResult(result)
|
||||
} catch (e: IllegalAccessException) {
|
||||
throw RuntimeException(e)
|
||||
} catch (e: InvocationTargetException) {
|
||||
throw RuntimeException(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
// 如果默认代理方法没有匹配,则查找Query方法映射
|
||||
if (queryMethodMapper.containsKey(method.toString())) {
|
||||
val targetMethod = proxy!!.javaClass.getMethod(method.name, *method.parameterTypes)
|
||||
return targetMethod.invoke(proxy, *args)
|
||||
}
|
||||
throw RuntimeException(
|
||||
String.format(
|
||||
"方法[%s,%s]没有匹配的代理配置信息, 执行该方法前请先执行match方法判断",
|
||||
method.declaringClass.name, method.name
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* 初始化默认的方法映射列表
|
||||
*/
|
||||
private fun initDefaultMethodMapper() {
|
||||
defaultMethodMapper["add"] = DefaultMethodResolver("saveAndFlush")
|
||||
.sourceMethodParameters(IAggregateRoot::class.java).targetMethodParameters(Any::class.java)
|
||||
defaultMethodMapper["update"] = DefaultMethodResolver("saveAndFlush")
|
||||
.sourceMethodParameters(IAggregateRoot::class.java).targetMethodParameters(Any::class.java)
|
||||
defaultMethodMapper["remove"] = DefaultMethodResolver("deleteById")
|
||||
.sourceMethodParameters(Any::class.java).targetMethodParameters(Any::class.java)
|
||||
defaultMethodMapper["get"] = DefaultMethodResolver("findById")
|
||||
.sourceMethodParameters(Any::class.java).targetMethodParameters(Any::class.java)
|
||||
defaultMethodMapper["list"] = FindMethodResolver("findAll", entityClazz)
|
||||
.sourceMethodParameters(MutableMap::class.java).targetMethodParameters(Specification::class.java)
|
||||
defaultMethodMapper["count"] = FindMethodResolver("count", entityClazz)
|
||||
.sourceMethodParameters(MutableMap::class.java).targetMethodParameters(Specification::class.java)
|
||||
defaultMethodMapper["paging"] = PageMethodResolver("findAll", entityClazz)
|
||||
.sourceMethodParameters(Params::class.java)
|
||||
.targetMethodParameters(Specification::class.java, Pageable::class.java)
|
||||
}
|
||||
|
||||
/**
|
||||
* 增加用户自定义Query注解方法映射信息
|
||||
*
|
||||
* @param method 需要添加的方法
|
||||
*/
|
||||
fun addQueryMethodMapper(method: Method) {
|
||||
queryMethodMapper[method.toString()] = DefaultMethodResolver(method.name)
|
||||
}
|
||||
|
||||
fun setArgumentResolver(argumentResolver: AbstractMethodResolver?): JpaMethodProxy {
|
||||
this.argumentResolver = argumentResolver
|
||||
return this
|
||||
}
|
||||
|
||||
fun setEntityClazz(entityClazz: Class<*>) {
|
||||
this.entityClazz = entityClazz
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,100 @@
|
||||
package com.synebula.gaea.jpa.proxy.method.resolver
|
||||
|
||||
import java.lang.reflect.Method
|
||||
|
||||
/**
|
||||
* 解决JPA方法参数的映射
|
||||
*
|
||||
* @param targetMethodName 目标方法名称
|
||||
*/
|
||||
abstract class AbstractMethodResolver(var targetMethodName: String) {
|
||||
|
||||
|
||||
/**
|
||||
* 方法相关实体类型
|
||||
*/
|
||||
lateinit var entityClazz: Class<*>
|
||||
|
||||
/**
|
||||
* 目标方法形参类型列表
|
||||
*
|
||||
*/
|
||||
lateinit var targetMethodParameters: Array<out Class<*>>
|
||||
|
||||
/**
|
||||
* 源方法形参类型列表
|
||||
*
|
||||
*/
|
||||
lateinit var sourceMethodParameters: Array<out Class<*>>
|
||||
|
||||
constructor(targetMethodName: String, entityClazz: Class<*>) : this(targetMethodName) {
|
||||
this.entityClazz = entityClazz
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析映射实参
|
||||
*
|
||||
* @param args 实参列表
|
||||
* @return 映射后的实参列表
|
||||
*/
|
||||
abstract fun mappingArguments(args: Array<Any>): Array<Any>
|
||||
|
||||
/**
|
||||
* 解析映射方法结果
|
||||
*
|
||||
* @param result 方法结果
|
||||
* @return 映射后的方法结果
|
||||
*/
|
||||
abstract fun mappingResult(result: Any): Any
|
||||
|
||||
/**
|
||||
* 获取源方法形参类型列表
|
||||
*/
|
||||
open fun sourceMethodParameters(vararg params: Class<*>): AbstractMethodResolver {
|
||||
this.sourceMethodParameters = params
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置目标方法形参类型列表
|
||||
*/
|
||||
open fun targetMethodParameters(vararg params: Class<*>): AbstractMethodResolver {
|
||||
this.targetMethodParameters = params
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 匹配方法名(目标方法)/参数是否复合
|
||||
*
|
||||
* @param method 需要匹配的方法
|
||||
* @param methodType 需要匹配的方法类型
|
||||
* @return ture/false
|
||||
*/
|
||||
fun match(method: Method, methodType: MethodType): Boolean {
|
||||
var methodParameters = sourceMethodParameters
|
||||
var matched = true
|
||||
|
||||
// 匹配目标方法的时候额外匹配下方法名
|
||||
if (methodType == MethodType.TargetMethod) {
|
||||
methodParameters = targetMethodParameters
|
||||
matched = method.name == targetMethodName
|
||||
}
|
||||
|
||||
// 如果[目标]方法名匹配, 判断参数是否匹配
|
||||
matched = matched && method.parameterCount == methodParameters.size
|
||||
if (matched) {
|
||||
for (i in methodParameters.indices) {
|
||||
val parameterTypes = method.parameterTypes
|
||||
if (methodParameters[i] != parameterTypes[i]) {
|
||||
matched = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return matched
|
||||
}
|
||||
|
||||
enum class MethodType {
|
||||
SourceMethod, TargetMethod
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package com.synebula.gaea.jpa.proxy.method.resolver
|
||||
|
||||
import java.util.*
|
||||
|
||||
/**
|
||||
* 默认返回全部
|
||||
*/
|
||||
class DefaultMethodResolver(targetMethodName: String) : AbstractMethodResolver(targetMethodName) {
|
||||
|
||||
override fun mappingArguments(args: Array<Any>): Array<Any> {
|
||||
return args
|
||||
}
|
||||
|
||||
override fun mappingResult(result: Any): Any {
|
||||
if (result is Optional<*>) {
|
||||
return result.orElse(null)
|
||||
}
|
||||
return result
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package com.synebula.gaea.jpa.proxy.method.resolver
|
||||
|
||||
import com.synebula.gaea.jpa.toSpecification
|
||||
|
||||
/**
|
||||
* 查询方法参数映射
|
||||
*/
|
||||
class FindMethodResolver(targetMethodName: String, clazz: Class<*>) : AbstractMethodResolver(targetMethodName, clazz) {
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
override fun mappingArguments(args: Array<Any>): Array<Any> {
|
||||
val params = args[0] as Map<String, String>?
|
||||
val specification = params.toSpecification(entityClazz)
|
||||
return arrayOf(specification)
|
||||
}
|
||||
|
||||
override fun mappingResult(result: Any): Any {
|
||||
return result
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package com.synebula.gaea.jpa.proxy.method.resolver
|
||||
|
||||
import com.synebula.gaea.jpa.toSpecification
|
||||
import com.synebula.gaea.query.Order
|
||||
import com.synebula.gaea.query.Params
|
||||
import org.springframework.data.domain.Page
|
||||
import org.springframework.data.domain.PageRequest
|
||||
import org.springframework.data.domain.Pageable
|
||||
import org.springframework.data.domain.Sort
|
||||
import java.util.*
|
||||
import javax.persistence.EmbeddedId
|
||||
import javax.persistence.Id
|
||||
|
||||
/**
|
||||
* 分页方法参数映射
|
||||
*/
|
||||
class PageMethodResolver(targetMethodName: String, clazz: Class<*>) : AbstractMethodResolver(targetMethodName, clazz) {
|
||||
|
||||
override fun mappingArguments(args: Array<Any>): Array<Any> {
|
||||
return try {
|
||||
val params: Params? = args[0] as Params?
|
||||
val specification = params!!.parameters.toSpecification(entityClazz)
|
||||
var sort = Sort.unsorted()
|
||||
for (key in params.orders.keys) {
|
||||
val direction = if (params.orders[key] === Order.ASC) Sort.Direction.ASC else Sort.Direction.DESC
|
||||
sort = sort.and(Sort.by(direction, key))
|
||||
}
|
||||
if (sort.isEmpty) {
|
||||
val fields = entityClazz.declaredFields
|
||||
for (field in fields) {
|
||||
val isId = Arrays.stream(field.declaredAnnotations).anyMatch { annotation: Annotation ->
|
||||
(annotation.annotationClass.java == Id::class.java
|
||||
|| annotation.annotationClass.java == EmbeddedId::class.java)
|
||||
}
|
||||
if (isId) {
|
||||
sort = Sort.by(Sort.Direction.ASC, field.name)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pageable 页面从0开始
|
||||
val pageable: Pageable = PageRequest.of(params.page - 1, params.size, sort)
|
||||
arrayOf(specification, pageable)
|
||||
} catch (e: Exception) {
|
||||
throw RuntimeException(e)
|
||||
}
|
||||
}
|
||||
|
||||
override fun mappingResult(result: Any): Any {
|
||||
val page = result as Page<*>
|
||||
|
||||
// Page 页面从0开始
|
||||
return com.synebula.gaea.query.Page(page.number + 1, page.size, page.totalElements.toInt(), page.content)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user