1.4.0 增加jpa的代理模块

This commit is contained in:
2022-08-26 10:33:24 +08:00
parent db0b538741
commit 8860aecdfe
36 changed files with 1257 additions and 105 deletions

11
src/gaea.jpa/build.gradle Normal file
View File

@@ -0,0 +1,11 @@
ext {
jassist_version = '3.29.0-GA'
}
dependencies {
api project(":src:gaea")
implementation("org.springframework.boot:spring-boot-starter-data-jpa:$spring_version")
implementation("org.javassist:javassist:$jassist_version")
}

View File

@@ -0,0 +1,67 @@
package com.synebula.gaea.jpa
import com.synebula.gaea.query.IQuery
import com.synebula.gaea.query.Page
import com.synebula.gaea.query.Params
import org.springframework.data.jpa.repository.support.SimpleJpaRepository
import javax.persistence.EntityManager
class JpaQuery<TView, ID>(override var clazz: Class<TView>, entityManager: EntityManager) : IQuery<TView, ID> {
protected var repo: SimpleJpaRepository<TView, ID>
init {
repo = SimpleJpaRepository<TView, ID>(clazz, entityManager)
}
override operator fun get(id: ID): TView? {
val view = this.repo.findById(id)
return if (view.isPresent) view.get() else null
}
/**
* 根据实体类条件查询所有符合条件记录
*`
* @param params 查询条件。
* @return 视图列表
*/
override fun list(params: Map<String, String>?): List<TView> {
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
return emptyList()
}
/**
* 根据条件查询符合条件记录的数量
*
* @param params 查询条件。
* @return 数量
*/
override fun count(params: Map<String, String>?): Int {
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
return -1
}
/**
* 根据实体类条件查询所有符合条件记录(分页查询)
*
* @param params 分页条件
* @return 分页数据
*/
override fun paging(params: Params): Page<TView> {
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
return Page()
}
/**
* 查询条件范围内数据。
* @param field 查询字段
* @param params 查询条件
*
* @return 视图列表
*/
override fun range(field: String, params: List<Any>): List<TView> {
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
return emptyList()
}
}

View File

@@ -0,0 +1,91 @@
package com.synebula.gaea.jpa
import com.synebula.gaea.domain.model.IAggregateRoot
import com.synebula.gaea.domain.repository.IRepository
import org.springframework.data.jpa.repository.JpaRepository
import org.springframework.data.jpa.repository.support.SimpleJpaRepository
import javax.persistence.EntityManager
class JpaRepository<TAggregateRoot : IAggregateRoot<ID>, ID>(
override var clazz: Class<TAggregateRoot>,
entityManager: EntityManager
) : IRepository<TAggregateRoot, ID> {
protected var repo: JpaRepository<TAggregateRoot, ID>? = null
init {
repo = SimpleJpaRepository(clazz, entityManager)
}
/**
* 插入单个对象。
*
* @param obj 需要插入的对象。
* @return 返回原对象如果对象ID为自增则补充自增ID。
*/
override fun add(obj: TAggregateRoot) {
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
}
/**
* 插入多个个对象。
*
* @param list 需要插入的对象。
* @return 返回原对象如果对象ID为自增则补充自增ID。
*/
override fun add(list: List<TAggregateRoot>) {
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
}
/**
* 更新对象。
*
* @param obj 需要更新的对象。
* @return
*/
override fun update(obj: TAggregateRoot) {
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
}
/**
* 更新多个个对象。
*
* @param list 需要新的对象。
*/
override fun update(list: List<TAggregateRoot>) {
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
}
/**
* 通过id删除该条数据
*
* @param id 对象ID。
* @return
*/
override fun remove(id: ID) {
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
}
/**
* 根据ID获取对象。
*
* @param id 对象ID。
* @return
*/
override fun get(id: ID): TAggregateRoot? {
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
return null
}
/**
* 根据条件查询符合条件记录的数量
*
* @param params 查询条件。
* @return int
*/
override fun count(params: Map<String, String>?): Int {
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
return -1
}
}

View File

@@ -0,0 +1,147 @@
package com.synebula.gaea.jpa
import com.synebula.gaea.data.date.DateTime
import com.synebula.gaea.query.Operator
import com.synebula.gaea.query.Where
import org.springframework.data.jpa.domain.Specification
import java.lang.reflect.Field
import java.util.*
import javax.persistence.criteria.CriteriaBuilder
import javax.persistence.criteria.CriteriaQuery
import javax.persistence.criteria.Predicate
import javax.persistence.criteria.Root
/**
* 类型转换
*
* @param field 字段对象
* @return object
*/
fun String.toFieldType(field: Field): Any? {
var result: Any? = this
val fieldType = field.type
if (fieldType != this.javaClass) {
if (Int::class.java == fieldType || Int::class.javaPrimitiveType == fieldType) {
result = this.toInt()
}
if (Double::class.java == fieldType || Double::class.javaPrimitiveType == fieldType) {
result = this.toDouble()
}
if (Float::class.java == fieldType || Float::class.javaPrimitiveType == fieldType) {
result = this.toFloat()
}
if (Date::class.java == fieldType) {
result = DateTime(this, "yyyy-MM-dd HH:mm:ss").date
}
}
return result
}
/**
* 格式化数值类型
*
* @param field 字段对象
* @return double
*/
fun String.tryToDigital(field: Field): Double {
val result: Double
val fieldType = field.type
result =
if (Int::class.java == fieldType || Int::class.javaPrimitiveType == fieldType
|| Double::class.java == fieldType || Double::class.javaPrimitiveType == fieldType
|| Float::class.java == fieldType || Float::class.javaPrimitiveType == fieldType
) {
this.toDouble()
} else throw RuntimeException(
String.format(
"class [%s] field [%s] is not digital",
field.declaringClass.name,
field.name
)
)
return result
}
/**
* 参数 Map 转换查询 Specification
*
* @param clazz 类
* @return Specification
*/
fun Map<String, String>?.toSpecification(clazz: Class<*>): Specification<*> {
val rangeStartSuffix = "[0]" //范围查询开始后缀
val rangeEndSuffix = "[1]" //范围查询结束后缀
return Specification<Any?> { root: Root<Any?>, _: CriteriaQuery<*>?, criteriaBuilder: CriteriaBuilder ->
val predicates: MutableList<Predicate> = ArrayList()
for (argumentName in this!!.keys) {
if (this[argumentName] == null) continue
var fieldName = argumentName
var operator: Operator
// 判断是否为range类型(范围内查询)
var start = true
if (fieldName.endsWith(rangeStartSuffix) || fieldName.endsWith(rangeEndSuffix)) {
fieldName = fieldName.substring(fieldName.length - 3)
if (fieldName.endsWith(rangeEndSuffix)) start = false
}
val field = clazz.getDeclaredField(fieldName)
val where: Where = field.getDeclaredAnnotation(Where::class.java)
operator = where.operator
// 如果是范围内容, 判断是数值类型还是时间类型
if (operator === Operator.Range) {
if (clazz.getDeclaredField(fieldName).type != Date::class.java) {
operator = if (start) Operator.Gte else Operator.Lte
}
}
var predicate: Predicate
var digitalValue: Double
try {
when (operator) {
Operator.Ne -> predicate =
criteriaBuilder.notEqual(root.get<Any>(fieldName), this[fieldName]!!.toFieldType(field))
Operator.Lt -> {
digitalValue = this[fieldName]!!.tryToDigital(field)
predicate = criteriaBuilder.lessThan(root.get(fieldName), digitalValue)
}
Operator.Gt -> {
digitalValue = this[fieldName]!!.tryToDigital(field)
predicate = criteriaBuilder.greaterThan(root.get(fieldName), digitalValue)
}
Operator.Lte -> {
digitalValue = this[fieldName]!!.tryToDigital(field)
predicate = criteriaBuilder.lessThanOrEqualTo(root.get(fieldName), digitalValue)
}
Operator.Gte -> {
digitalValue = this[fieldName]!!.tryToDigital(field)
predicate = criteriaBuilder.greaterThanOrEqualTo(root.get(fieldName), digitalValue)
}
Operator.Like -> predicate = criteriaBuilder.like(root.get(fieldName), "%${this[fieldName]}%")
Operator.Range -> {
predicate = if (start) {
criteriaBuilder.greaterThanOrEqualTo(root.get(fieldName), this[argumentName]!!)
} else {
criteriaBuilder.lessThanOrEqualTo(root.get(fieldName), this[argumentName]!!)
}
}
else -> predicate =
criteriaBuilder.equal(root.get<Any>(fieldName), this[fieldName]!!.toFieldType(field))
}
predicates.add(predicate)
} catch (e: NoSuchFieldException) {
throw Error(
"class [${field.declaringClass.name}] field [${field.name}] can't annotation [@Where(${operator.declaringClass.simpleName}.${operator.name})]",
e
)
}
}
criteriaBuilder.and(*predicates.toTypedArray())
}
}

View File

@@ -0,0 +1,36 @@
package com.synebula.gaea.jpa.proxy
import org.springframework.beans.factory.BeanFactory
import org.springframework.beans.factory.FactoryBean
import org.springframework.cglib.proxy.Enhancer
import org.springframework.data.repository.Repository
class JpaRepositoryFactory(
private val beanFactory: BeanFactory,
private val interfaceType: Class<*>,
private val implBeanNames: List<String>
) : FactoryBean<Any> {
override fun getObject(): Any {
val handler: JpaRepositoryProxy<*, *, *> = JpaRepositoryProxy<Repository<Any, Any>, Any, Any>(
beanFactory,
interfaceType, implBeanNames
)
//JDK 方式代理代码, 暂时选用cglib
//Object proxy = Proxy.newProxyInstance(this.interfaceType.getClassLoader(), new Class[]{this.interfaceType}, handler);
//cglib代理
val enhancer = Enhancer()
enhancer.setSuperclass(interfaceType)
enhancer.setCallback(handler)
return enhancer.create()
}
override fun getObjectType(): Class<*> {
return interfaceType
}
override fun isSingleton(): Boolean {
return true
}
}

View File

@@ -0,0 +1,289 @@
package com.synebula.gaea.jpa.proxy
import com.synebula.gaea.jpa.proxy.method.JpaMethodProxy
import javassist.*
import javassist.bytecode.AnnotationsAttribute
import javassist.bytecode.MethodInfo
import javassist.bytecode.SignatureAttribute
import javassist.bytecode.annotation.Annotation
import javassist.bytecode.annotation.BooleanMemberValue
import javassist.bytecode.annotation.StringMemberValue
import org.springframework.beans.BeansException
import org.springframework.beans.factory.BeanFactory
import org.springframework.beans.factory.ObjectProvider
import org.springframework.cglib.proxy.MethodInterceptor
import org.springframework.cglib.proxy.MethodProxy
import org.springframework.data.jpa.repository.Modifying
import org.springframework.data.jpa.repository.Query
import org.springframework.data.jpa.repository.support.JpaRepositoryFactoryBean
import org.springframework.data.jpa.repository.support.JpaRepositoryImplementation
import org.springframework.data.mapping.context.MappingContext
import org.springframework.data.querydsl.EntityPathResolver
import org.springframework.data.querydsl.SimpleEntityPathResolver
import org.springframework.data.repository.Repository
import java.lang.reflect.Method
import java.lang.reflect.ParameterizedType
import java.lang.reflect.Type
import javax.persistence.EntityManager
class JpaRepositoryProxy<T : Repository<S, ID>?, S, ID>(
beanFactory: BeanFactory,
interfaceType: Class<*>,
implementBeanNames: List<String>?
) : MethodInterceptor { //InvocationHandler {
//JPA 默认Entity Manager上下文, 如不用该上下文则没有事务管理器
private val EntityManagerName = "org.springframework.orm.jpa.SharedEntityManagerCreator#0"
/**
* JPA代理对象
*/
private var jpaRepository: JpaRepositoryImplementation<*, *>? = null
/**
* bean注册器
*/
protected var beanFactory: BeanFactory
/**
* 方法映射管理器
*/
protected var jpaMethodProxy: JpaMethodProxy
/**
* 需要代理的接口类型
*/
protected var interfaceType: Class<*>
/**
* 接口实现bean名称
*/
protected var implementBeanNames: List<String>
init {
try {
this.beanFactory = beanFactory
this.implementBeanNames = implementBeanNames ?: listOf()
this.interfaceType = interfaceType
// 设置方法映射查询参数类(Entity类型)
val type = interfaceType.genericInterfaces[0]
val typeArguments = (type as ParameterizedType).actualTypeArguments
jpaMethodProxy = JpaMethodProxy(Class.forName(typeArguments[0].typeName))
// 创建虚假的JpaRepository接口
val repoClazz = createJpaRepoClazz(*typeArguments)
val jpaRepositoryFactoryBean: JpaRepositoryFactoryBean<*, *, *> = createJPARepositoryFactoryBean(repoClazz)
jpaRepository = jpaRepositoryFactoryBean.getObject() as JpaRepositoryImplementation<*, *>
} catch (e: ClassNotFoundException) {
throw RuntimeException(e)
}
}
/**
* JDK 方式代理代码, 暂时选用cglib
*/
@Deprecated("")
@Throws(Throwable::class)
operator fun invoke(proxy: Any?, method: Method, args: Array<Any>): Any? {
return if (Any::class.java == method.declaringClass) {
method.invoke(proxy, *args)
} else {
execMethod(method, args)
}
}
/**
* 暂时选用cglib 方式代理代码
*/
@Throws(Throwable::class)
override fun intercept(o: Any, method: Method, args: Array<Any>, methodProxy: MethodProxy): Any {
return if (Any::class.java == method.declaringClass) {
methodProxy.invoke(this, args)
} else {
execMethod(method, args)!!
}
}
/**
* 执行代理方法
*
* @param method 需要执行的方法
* @param args 参数列表
* @return 方法执行结果
* @throws Throwable 异常
*/
@Throws(Throwable::class)
private fun execMethod(method: Method, args: Array<Any>): Any? {
// 找到对应代理方法, 代理执行
return if (jpaMethodProxy.match(method)) {
try {
jpaMethodProxy.proxyExecMethod(jpaRepository, method, args)
} catch (ex: Exception) {
throw RuntimeException(
String.format(
"对象[%s]代理执行方法[%s.%s]出错",
jpaMethodProxy.javaClass, interfaceType.name, method.name
), ex
)
}
} else {
// 找不到代理方法则查找具体实现类执行
if (implementBeanNames.isEmpty()) throw RuntimeException(
String.format(
"找不到[%s.%s]对应的代理方法",
method.declaringClass.name, method.name
)
) else {
val bean = beanFactory.getBean(implementBeanNames[0])
val proxyMethod = bean.javaClass.getMethod(method.name, *method.parameterTypes)
proxyMethod.invoke(bean, *args)
}
}
}
/**
* 使用javassist创建虚拟的jpa repo类
*
* @param typeArgs 泛型参数
* @return 虚拟的jpa repo类形
*/
@Suppress("unchecked_cast")
private fun createJpaRepoClazz(vararg typeArgs: Type): Class<T> {
return try {
val pool = ClassPool.getDefault()
val jpaRepoCt = pool[JpaRepositoryImplementation::class.java.name]
val clazzName = String.format("%sRepository", typeArgs[0].typeName)
val repoCt = pool.makeInterface(clazzName, jpaRepoCt)
val typeArguments = arrayOfNulls<SignatureAttribute.TypeArgument>(typeArgs.size)
for (i in typeArgs.indices) {
typeArguments[i] = SignatureAttribute.TypeArgument(SignatureAttribute.ClassType(typeArgs[i].typeName))
}
val ac = SignatureAttribute.ClassSignature(
null,
null,
arrayOf(SignatureAttribute.ClassType(jpaRepoCt.name, typeArguments))
)
repoCt.genericSignature = ac.encode()
addClassQueryMethod(repoCt)
repoCt.toClass() as Class<T>
} catch (ex: Exception) {
throw RuntimeException(ex)
}
}
/**
* 给虚拟接口添加Query注解方法
*
* @param ctClass jpa虚拟接口
*/
@Throws(NotFoundException::class, CannotCompileException::class, ClassNotFoundException::class)
private fun addClassQueryMethod(ctClass: CtClass) {
// 找到Query注解方法并加入到虚拟接口中
val interfaceCtClass = ClassPool.getDefault()[interfaceType.name]
for (ctMethod in interfaceCtClass.methods) {
val query = ctMethod.getAnnotation(Query::class.java)
if (query != null) {
val method = CtNewMethod.abstractMethod(
ctMethod.returnType, ctMethod.name,
ctMethod.parameterTypes, arrayOfNulls(0), ctClass
)
val methodInfo = method.methodInfo
var modifying: Modifying? = null
//查找有无@Modifing注解有的化虚拟接口也需要加上
val annotation = ctMethod.getAnnotation(Modifying::class.java)
if (annotation != null) {
modifying = annotation as Modifying
}
// 增加Query注解
val attribute = buildQueryAttribute(methodInfo, query as Query, modifying)
methodInfo.addAttribute(attribute)
ctClass.addMethod(method)
}
}
// Query注解方法加入到代理中
for (method in interfaceType.methods) {
val annotation: Any? = method.getAnnotation(Query::class.java)
if (annotation != null) {
jpaMethodProxy.addQueryMethodMapper(method)
}
}
}
/**
* 创建javassist方法Query注解
*
* @param methodInfo 方法信息
* @param query query注解实例
* @param modifying Modifying注解
* @return 注解信息
*/
private fun buildQueryAttribute(methodInfo: MethodInfo, query: Query, modifying: Modifying?): AnnotationsAttribute {
val cp = methodInfo.constPool
val attribute = AnnotationsAttribute(cp, AnnotationsAttribute.visibleTag)
val queryAnnotation = Annotation(Query::class.java.name, cp)
queryAnnotation.addMemberValue("value", StringMemberValue(query.value, cp))
queryAnnotation.addMemberValue("countQuery", StringMemberValue(query.countQuery, cp))
queryAnnotation.addMemberValue("countProjection", StringMemberValue(query.countProjection, cp))
queryAnnotation.addMemberValue("nativeQuery", BooleanMemberValue(query.nativeQuery, cp))
queryAnnotation.addMemberValue("name", StringMemberValue(query.name, cp))
queryAnnotation.addMemberValue("countName", StringMemberValue(query.countName, cp))
if (modifying != null) {
val modifyingAnnotation = Annotation(
Modifying::class.java.name, cp
)
modifyingAnnotation.addMemberValue(
"flushAutomatically",
BooleanMemberValue(modifying.flushAutomatically, cp)
)
modifyingAnnotation.addMemberValue(
"clearAutomatically",
BooleanMemberValue(modifying.clearAutomatically, cp)
)
attribute.annotations = arrayOf(queryAnnotation, modifyingAnnotation)
} else attribute.setAnnotation(queryAnnotation)
return attribute
}
/**
* 创建JpaRepositoryFactoryBean对象
*
* @param jpaRepositoryClass 需要创建的 JPA JpaRepository Class
* @return JpaRepositoryFactoryBean对象
*/
private fun createJPARepositoryFactoryBean(jpaRepositoryClass: Class<out T>): JpaRepositoryFactoryBean<T, S, ID> {
// jpa 默认使用改名成EntityManager, 若用默认则没有事务上下文
val entityManager = beanFactory.getBean(EntityManagerName) as EntityManager
val repositoryFactoryBean = JpaRepositoryFactoryBean(jpaRepositoryClass)
repositoryFactoryBean.setEntityManager(entityManager)
repositoryFactoryBean.setBeanFactory(beanFactory)
repositoryFactoryBean.setBeanClassLoader(JpaRepositoryFactoryBean::class.java.classLoader)
repositoryFactoryBean.setMappingContext(beanFactory.getBean("jpaMappingContext") as MappingContext<*, *>)
repositoryFactoryBean.setEntityPathResolver(object : ObjectProvider<EntityPathResolver> {
@Throws(BeansException::class)
override fun getObject(vararg objects: Any): EntityPathResolver {
return SimpleEntityPathResolver("")
}
@Throws(BeansException::class)
override fun getIfAvailable(): EntityPathResolver? {
return null
}
@Throws(BeansException::class)
override fun getIfUnique(): EntityPathResolver? {
return null
}
@Throws(BeansException::class)
override fun getObject(): EntityPathResolver {
return SimpleEntityPathResolver("")
}
})
repositoryFactoryBean.afterPropertiesSet()
return repositoryFactoryBean
}
}

View File

@@ -0,0 +1,12 @@
package com.synebula.gaea.jpa.proxy
import org.springframework.context.annotation.Import
import java.lang.annotation.Inherited
import kotlin.reflect.KClass
@Target(AnnotationTarget.ANNOTATION_CLASS, AnnotationTarget.CLASS)
@Retention(AnnotationRetention.RUNTIME)
@MustBeDocumented
@Inherited
@Import(JpaRepositoryRegister::class)
annotation class JpaRepositoryProxyScan(val basePackages: Array<String> = [], val scanInterfaces: Array<KClass<*>> = [])

View File

@@ -0,0 +1,141 @@
package com.synebula.gaea.jpa.proxy
import org.springframework.beans.BeansException
import org.springframework.beans.factory.BeanClassLoaderAware
import org.springframework.beans.factory.BeanFactory
import org.springframework.beans.factory.BeanFactoryAware
import org.springframework.beans.factory.annotation.AnnotatedBeanDefinition
import org.springframework.beans.factory.config.BeanDefinition
import org.springframework.beans.factory.support.BeanDefinitionBuilder
import org.springframework.beans.factory.support.BeanDefinitionRegistry
import org.springframework.beans.factory.support.GenericBeanDefinition
import org.springframework.context.EnvironmentAware
import org.springframework.context.ResourceLoaderAware
import org.springframework.context.annotation.ClassPathScanningCandidateComponentProvider
import org.springframework.context.annotation.ImportBeanDefinitionRegistrar
import org.springframework.core.annotation.AnnotationAttributes
import org.springframework.core.env.Environment
import org.springframework.core.io.ResourceLoader
import org.springframework.core.type.AnnotationMetadata
import org.springframework.core.type.classreading.MetadataReader
import org.springframework.core.type.classreading.MetadataReaderFactory
import org.springframework.core.type.filter.TypeFilter
import org.springframework.util.ClassUtils
import java.util.*
import java.util.stream.Collectors
class JpaRepositoryRegister : ImportBeanDefinitionRegistrar, ResourceLoaderAware, BeanClassLoaderAware,
EnvironmentAware,
BeanFactoryAware {
private lateinit var environment: Environment
private lateinit var resourceLoader: ResourceLoader
private var classLoader: ClassLoader? = null
private var beanFactory: BeanFactory? = null
override fun registerBeanDefinitions(metadata: AnnotationMetadata, registry: BeanDefinitionRegistry) {
val attributes = AnnotationAttributes(
metadata.getAnnotationAttributes(
JpaRepositoryProxyScan::class.java.name
) ?: mapOf()
)
val basePackages = attributes.getStringArray("basePackages")
val scanInterfaces = attributes.getClassArray("scanInterfaces")
// 过滤scanInterfaces接口内容
val filter = getSubObjectTypeFilter(scanInterfaces)
val beanDefinitions = scan(basePackages, arrayOf(filter))
// 遍历处理接口
for (beanDefinition in beanDefinitions) {
// 获取RepositoryFor注解信息
val beanClazz: Class<*> = try {
Class.forName(beanDefinition.beanClassName)
} catch (e: ClassNotFoundException) {
throw RuntimeException(e)
}
val beanClazzTypeFilter = getSubObjectTypeFilter(arrayOf(beanClazz))
val implClazzDefinitions = scan(basePackages, arrayOf(beanClazzTypeFilter))
for (definition in implClazzDefinitions) {
definition.isAutowireCandidate = false
registry.registerBeanDefinition(Objects.requireNonNull(definition.beanClassName), definition)
}
// 构建bean定义
// 1 bean参数
val implBeanNames = implClazzDefinitions.stream().map { obj: BeanDefinition -> obj.beanClassName }
.collect(Collectors.toList())
val builder = BeanDefinitionBuilder.genericBeanDefinition(beanClazz)
builder.addConstructorArgValue(beanFactory)
builder.addConstructorArgValue(beanClazz)
builder.addConstructorArgValue(implBeanNames)
val definition = builder.rawBeanDefinition as GenericBeanDefinition
definition.beanClass = JpaRepositoryFactory::class.java
definition.autowireMode = GenericBeanDefinition.AUTOWIRE_BY_TYPE
registry.registerBeanDefinition(beanClazz.name, definition)
}
}
/**
* 根据过滤器扫描直接包下bean
*
* @param packages 指定的扫描包
* @param filters 过滤器
* @return 扫描后的bean定义
*/
private fun scan(packages: Array<String>?, filters: Array<TypeFilter>): List<BeanDefinition> {
val scanner: ClassPathScanningCandidateComponentProvider =
object : ClassPathScanningCandidateComponentProvider() {
override fun isCandidateComponent(beanDefinition: AnnotatedBeanDefinition): Boolean {
try {
val metadata = beanDefinition.metadata
val target = ClassUtils.forName(metadata.className, classLoader)
return !target.isAnnotation
} catch (ignored: Exception) {
}
return false
}
}
scanner.environment = environment
scanner.resourceLoader = resourceLoader
for (filter in filters) {
scanner.addIncludeFilter(filter)
}
val beanDefinitions: MutableList<BeanDefinition> = LinkedList()
for (basePackage in packages!!) {
beanDefinitions.addAll(scanner.findCandidateComponents(basePackage))
}
return beanDefinitions
}
/**
* 获取父接口实现对象的类型过滤器
*
* @param interfaces 父接口
* @return 类型过滤器
*/
private fun getSubObjectTypeFilter(interfaces: Array<Class<*>>?): TypeFilter {
return TypeFilter { metadataReader: MetadataReader, _: MetadataReaderFactory? ->
val interfaceNames = metadataReader.classMetadata.interfaceNames
var matched = false
for (interfaceName in interfaceNames) {
matched = Arrays.stream(interfaces)
.anyMatch { clazz: Class<*> -> clazz.name == interfaceName }
}
matched
}
}
override fun setResourceLoader(resourceLoader: ResourceLoader) {
this.resourceLoader = resourceLoader
}
override fun setBeanClassLoader(classLoader: ClassLoader) {
this.classLoader = classLoader
}
override fun setEnvironment(environment: Environment) {
this.environment = environment
}
@Throws(BeansException::class)
override fun setBeanFactory(beanFactory: BeanFactory) {
this.beanFactory = beanFactory
}
}

View File

@@ -0,0 +1,141 @@
package com.synebula.gaea.jpa.proxy.method
import com.synebula.gaea.domain.model.IAggregateRoot
import com.synebula.gaea.jpa.proxy.method.resolver.AbstractMethodResolver
import com.synebula.gaea.jpa.proxy.method.resolver.DefaultMethodResolver
import com.synebula.gaea.jpa.proxy.method.resolver.FindMethodResolver
import com.synebula.gaea.jpa.proxy.method.resolver.PageMethodResolver
import com.synebula.gaea.query.Params
import org.springframework.data.domain.Pageable
import org.springframework.data.jpa.domain.Specification
import java.lang.reflect.InvocationTargetException
import java.lang.reflect.Method
/**
* Jpa 方法映射包装类
*/
class JpaMethodProxy(
/**
* 方法需要实现的实体类
*/
private var entityClazz: Class<*>
) {
/**
* 默认的方法映射配置(IRepository, IQuery 接口中定义的方法)
*/
private val defaultMethodMapper: MutableMap<String, AbstractMethodResolver?> = LinkedHashMap()
/**
* 用户自定义的query注解方法处理
*/
private val queryMethodMapper: MutableMap<String, AbstractMethodResolver> = LinkedHashMap()
/**
* 方法参数映射
*/
var argumentResolver: AbstractMethodResolver? = null
private set
init {
initDefaultMethodMapper()
}
/**
* 匹配方法是否需要代理
*
* @param method 方法
* @return ture/false
*/
fun match(method: Method): Boolean {
var isMatch = (defaultMethodMapper.containsKey(method.name)
&& defaultMethodMapper[method.name]!!.match(method, AbstractMethodResolver.MethodType.SourceMethod))
// 如果默认代理方法没有匹配,则查找Query方法映射
if (!isMatch) {
isMatch = queryMethodMapper.containsKey(method.toString())
}
return isMatch
}
/**
* 解析代理方法
*
* @param proxy 代理对象
* @param method 源方法
* @param args 参数列表
* @return 执行结果
*/
@Throws(NoSuchMethodException::class, InvocationTargetException::class, IllegalAccessException::class)
fun proxyExecMethod(proxy: Any?, method: Method, args: Array<Any>): Any? {
// 匹配方法是否需要代理
if (defaultMethodMapper.containsKey(method.name)) {
val resolver = defaultMethodMapper[method.name]
// 匹配参数是否相同
if (resolver!!.match(method, AbstractMethodResolver.MethodType.SourceMethod)) {
//遍历代理对象, 找到合适的代理方法
val targetMethod =
proxy!!.javaClass.getMethod(resolver.targetMethodName, *resolver.targetMethodParameters)
return try {
// 开始执行代理方法
val mappingArguments = resolver.mappingArguments(args)
val result = targetMethod.invoke(proxy, *mappingArguments)
resolver.mappingResult(result)
} catch (e: IllegalAccessException) {
throw RuntimeException(e)
} catch (e: InvocationTargetException) {
throw RuntimeException(e)
}
}
}
// 如果默认代理方法没有匹配,则查找Query方法映射
if (queryMethodMapper.containsKey(method.toString())) {
val targetMethod = proxy!!.javaClass.getMethod(method.name, *method.parameterTypes)
return targetMethod.invoke(proxy, *args)
}
throw RuntimeException(
String.format(
"方法[%s,%s]没有匹配的代理配置信息, 执行该方法前请先执行match方法判断",
method.declaringClass.name, method.name
)
)
}
/**
* 初始化默认的方法映射列表
*/
private fun initDefaultMethodMapper() {
defaultMethodMapper["add"] = DefaultMethodResolver("saveAndFlush")
.sourceMethodParameters(IAggregateRoot::class.java).targetMethodParameters(Any::class.java)
defaultMethodMapper["update"] = DefaultMethodResolver("saveAndFlush")
.sourceMethodParameters(IAggregateRoot::class.java).targetMethodParameters(Any::class.java)
defaultMethodMapper["remove"] = DefaultMethodResolver("deleteById")
.sourceMethodParameters(Any::class.java).targetMethodParameters(Any::class.java)
defaultMethodMapper["get"] = DefaultMethodResolver("findById")
.sourceMethodParameters(Any::class.java).targetMethodParameters(Any::class.java)
defaultMethodMapper["list"] = FindMethodResolver("findAll", entityClazz)
.sourceMethodParameters(MutableMap::class.java).targetMethodParameters(Specification::class.java)
defaultMethodMapper["count"] = FindMethodResolver("count", entityClazz)
.sourceMethodParameters(MutableMap::class.java).targetMethodParameters(Specification::class.java)
defaultMethodMapper["paging"] = PageMethodResolver("findAll", entityClazz)
.sourceMethodParameters(Params::class.java)
.targetMethodParameters(Specification::class.java, Pageable::class.java)
}
/**
* 增加用户自定义Query注解方法映射信息
*
* @param method 需要添加的方法
*/
fun addQueryMethodMapper(method: Method) {
queryMethodMapper[method.toString()] = DefaultMethodResolver(method.name)
}
fun setArgumentResolver(argumentResolver: AbstractMethodResolver?): JpaMethodProxy {
this.argumentResolver = argumentResolver
return this
}
fun setEntityClazz(entityClazz: Class<*>) {
this.entityClazz = entityClazz
}
}

View File

@@ -0,0 +1,100 @@
package com.synebula.gaea.jpa.proxy.method.resolver
import java.lang.reflect.Method
/**
* 解决JPA方法参数的映射
*
* @param targetMethodName 目标方法名称
*/
abstract class AbstractMethodResolver(var targetMethodName: String) {
/**
* 方法相关实体类型
*/
lateinit var entityClazz: Class<*>
/**
* 目标方法形参类型列表
*
*/
lateinit var targetMethodParameters: Array<out Class<*>>
/**
* 源方法形参类型列表
*
*/
lateinit var sourceMethodParameters: Array<out Class<*>>
constructor(targetMethodName: String, entityClazz: Class<*>) : this(targetMethodName) {
this.entityClazz = entityClazz
}
/**
* 解析映射实参
*
* @param args 实参列表
* @return 映射后的实参列表
*/
abstract fun mappingArguments(args: Array<Any>): Array<Any>
/**
* 解析映射方法结果
*
* @param result 方法结果
* @return 映射后的方法结果
*/
abstract fun mappingResult(result: Any): Any
/**
* 获取源方法形参类型列表
*/
open fun sourceMethodParameters(vararg params: Class<*>): AbstractMethodResolver {
this.sourceMethodParameters = params
return this
}
/**
* 设置目标方法形参类型列表
*/
open fun targetMethodParameters(vararg params: Class<*>): AbstractMethodResolver {
this.targetMethodParameters = params
return this
}
/**
* 匹配方法名(目标方法)/参数是否复合
*
* @param method 需要匹配的方法
* @param methodType 需要匹配的方法类型
* @return ture/false
*/
fun match(method: Method, methodType: MethodType): Boolean {
var methodParameters = sourceMethodParameters
var matched = true
// 匹配目标方法的时候额外匹配下方法名
if (methodType == MethodType.TargetMethod) {
methodParameters = targetMethodParameters
matched = method.name == targetMethodName
}
// 如果[目标]方法名匹配, 判断参数是否匹配
matched = matched && method.parameterCount == methodParameters.size
if (matched) {
for (i in methodParameters.indices) {
val parameterTypes = method.parameterTypes
if (methodParameters[i] != parameterTypes[i]) {
matched = false
break
}
}
}
return matched
}
enum class MethodType {
SourceMethod, TargetMethod
}
}

View File

@@ -0,0 +1,20 @@
package com.synebula.gaea.jpa.proxy.method.resolver
import java.util.*
/**
* 默认返回全部
*/
class DefaultMethodResolver(targetMethodName: String) : AbstractMethodResolver(targetMethodName) {
override fun mappingArguments(args: Array<Any>): Array<Any> {
return args
}
override fun mappingResult(result: Any): Any {
if (result is Optional<*>) {
return result.orElse(null)
}
return result
}
}

View File

@@ -0,0 +1,20 @@
package com.synebula.gaea.jpa.proxy.method.resolver
import com.synebula.gaea.jpa.toSpecification
/**
* 查询方法参数映射
*/
class FindMethodResolver(targetMethodName: String, clazz: Class<*>) : AbstractMethodResolver(targetMethodName, clazz) {
@Suppress("UNCHECKED_CAST")
override fun mappingArguments(args: Array<Any>): Array<Any> {
val params = args[0] as Map<String, String>?
val specification = params.toSpecification(entityClazz)
return arrayOf(specification)
}
override fun mappingResult(result: Any): Any {
return result
}
}

View File

@@ -0,0 +1,56 @@
package com.synebula.gaea.jpa.proxy.method.resolver
import com.synebula.gaea.jpa.toSpecification
import com.synebula.gaea.query.Order
import com.synebula.gaea.query.Params
import org.springframework.data.domain.Page
import org.springframework.data.domain.PageRequest
import org.springframework.data.domain.Pageable
import org.springframework.data.domain.Sort
import java.util.*
import javax.persistence.EmbeddedId
import javax.persistence.Id
/**
* 分页方法参数映射
*/
class PageMethodResolver(targetMethodName: String, clazz: Class<*>) : AbstractMethodResolver(targetMethodName, clazz) {
override fun mappingArguments(args: Array<Any>): Array<Any> {
return try {
val params: Params? = args[0] as Params?
val specification = params!!.parameters.toSpecification(entityClazz)
var sort = Sort.unsorted()
for (key in params.orders.keys) {
val direction = if (params.orders[key] === Order.ASC) Sort.Direction.ASC else Sort.Direction.DESC
sort = sort.and(Sort.by(direction, key))
}
if (sort.isEmpty) {
val fields = entityClazz.declaredFields
for (field in fields) {
val isId = Arrays.stream(field.declaredAnnotations).anyMatch { annotation: Annotation ->
(annotation.annotationClass.java == Id::class.java
|| annotation.annotationClass.java == EmbeddedId::class.java)
}
if (isId) {
sort = Sort.by(Sort.Direction.ASC, field.name)
break
}
}
}
// Pageable 页面从0开始
val pageable: Pageable = PageRequest.of(params.page - 1, params.size, sort)
arrayOf(specification, pageable)
} catch (e: Exception) {
throw RuntimeException(e)
}
}
override fun mappingResult(result: Any): Any {
val page = result as Page<*>
// Page 页面从0开始
return com.synebula.gaea.query.Page(page.number + 1, page.size, page.totalElements.toInt(), page.content)
}
}