1.4.0 增加jpa的代理模块

This commit is contained in:
2022-08-26 10:33:24 +08:00
parent db0b538741
commit 8860aecdfe
36 changed files with 1257 additions and 105 deletions

View File

@@ -17,7 +17,7 @@ buildscript {
subprojects { subprojects {
group 'com.synebula' group 'com.synebula'
version '1.3.0' version '1.4.0'
buildscript { buildscript {
repositories { repositories {
@@ -53,4 +53,23 @@ subprojects {
compileTestKotlin { compileTestKotlin {
kotlinOptions.jvmTarget = "1.8" kotlinOptions.jvmTarget = "1.8"
} }
publishing {
// repositories {
// maven {
// allowInsecureProtocol = true
// url = "$nexus_url"
// credentials {
// username = "$nexus_usr"
// password = "$nexus_pwd"
// }
// }
// }
publications {
mavenJava(MavenPublication) {
from components.java
}
}
}
} }

View File

@@ -1,6 +1,7 @@
rootProject.name = 'gaea' rootProject.name = 'gaea'
include 'src:gaea' include 'src:gaea'
include 'src:gaea.app' include 'src:gaea.app'
include 'src:gaea.mongodb'
include 'src:gaea.spring' include 'src:gaea.spring'
include 'src:gaea.mongodb'
include 'src:gaea.jpa'

View File

@@ -19,11 +19,3 @@ dependencies {
api group: 'com.auth0', name: 'java-jwt', version: '3.14.0' api group: 'com.auth0', name: 'java-jwt', version: '3.14.0'
} }
publishing {
publications {
publish(MavenPublication) {
from components.java
}
}
}

View File

@@ -26,7 +26,7 @@ interface IQueryApp<TView, ID> : IApplication {
@Method("获取列表数据") @Method("获取列表数据")
@GetMapping @GetMapping
fun list(@RequestParam params: LinkedHashMap<String, Any>): HttpMessage { fun list(@RequestParam params: LinkedHashMap<String, String>): HttpMessage {
val data = this.query.list(params) val data = this.query.list(params)
return HttpMessage(data) return HttpMessage(data)
} }
@@ -36,7 +36,7 @@ interface IQueryApp<TView, ID> : IApplication {
fun paging( fun paging(
@PathVariable size: Int, @PathVariable size: Int,
@PathVariable page: Int, @PathVariable page: Int,
@RequestParam parameters: LinkedHashMap<String, Any> @RequestParam parameters: LinkedHashMap<String, String>
): HttpMessage { ): HttpMessage {
val params = Params(page, size, parameters) val params = Params(page, size, parameters)
val data = this.query.paging(params) val data = this.query.paging(params)

11
src/gaea.jpa/build.gradle Normal file
View File

@@ -0,0 +1,11 @@
ext {
jassist_version = '3.29.0-GA'
}
dependencies {
api project(":src:gaea")
implementation("org.springframework.boot:spring-boot-starter-data-jpa:$spring_version")
implementation("org.javassist:javassist:$jassist_version")
}

View File

@@ -0,0 +1,67 @@
package com.synebula.gaea.jpa
import com.synebula.gaea.query.IQuery
import com.synebula.gaea.query.Page
import com.synebula.gaea.query.Params
import org.springframework.data.jpa.repository.support.SimpleJpaRepository
import javax.persistence.EntityManager
class JpaQuery<TView, ID>(override var clazz: Class<TView>, entityManager: EntityManager) : IQuery<TView, ID> {
protected var repo: SimpleJpaRepository<TView, ID>
init {
repo = SimpleJpaRepository<TView, ID>(clazz, entityManager)
}
override operator fun get(id: ID): TView? {
val view = this.repo.findById(id)
return if (view.isPresent) view.get() else null
}
/**
* 根据实体类条件查询所有符合条件记录
*`
* @param params 查询条件。
* @return 视图列表
*/
override fun list(params: Map<String, String>?): List<TView> {
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
return emptyList()
}
/**
* 根据条件查询符合条件记录的数量
*
* @param params 查询条件。
* @return 数量
*/
override fun count(params: Map<String, String>?): Int {
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
return -1
}
/**
* 根据实体类条件查询所有符合条件记录(分页查询)
*
* @param params 分页条件
* @return 分页数据
*/
override fun paging(params: Params): Page<TView> {
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
return Page()
}
/**
* 查询条件范围内数据。
* @param field 查询字段
* @param params 查询条件
*
* @return 视图列表
*/
override fun range(field: String, params: List<Any>): List<TView> {
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
return emptyList()
}
}

View File

@@ -0,0 +1,91 @@
package com.synebula.gaea.jpa
import com.synebula.gaea.domain.model.IAggregateRoot
import com.synebula.gaea.domain.repository.IRepository
import org.springframework.data.jpa.repository.JpaRepository
import org.springframework.data.jpa.repository.support.SimpleJpaRepository
import javax.persistence.EntityManager
class JpaRepository<TAggregateRoot : IAggregateRoot<ID>, ID>(
override var clazz: Class<TAggregateRoot>,
entityManager: EntityManager
) : IRepository<TAggregateRoot, ID> {
protected var repo: JpaRepository<TAggregateRoot, ID>? = null
init {
repo = SimpleJpaRepository(clazz, entityManager)
}
/**
* 插入单个对象。
*
* @param obj 需要插入的对象。
* @return 返回原对象如果对象ID为自增则补充自增ID。
*/
override fun add(obj: TAggregateRoot) {
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
}
/**
* 插入多个个对象。
*
* @param list 需要插入的对象。
* @return 返回原对象如果对象ID为自增则补充自增ID。
*/
override fun add(list: List<TAggregateRoot>) {
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
}
/**
* 更新对象。
*
* @param obj 需要更新的对象。
* @return
*/
override fun update(obj: TAggregateRoot) {
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
}
/**
* 更新多个个对象。
*
* @param list 需要新的对象。
*/
override fun update(list: List<TAggregateRoot>) {
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
}
/**
* 通过id删除该条数据
*
* @param id 对象ID。
* @return
*/
override fun remove(id: ID) {
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
}
/**
* 根据ID获取对象。
*
* @param id 对象ID。
* @return
*/
override fun get(id: ID): TAggregateRoot? {
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
return null
}
/**
* 根据条件查询符合条件记录的数量
*
* @param params 查询条件。
* @return int
*/
override fun count(params: Map<String, String>?): Int {
// method proxy in JpaRepositoryProxy [SimpleJpaRepository]
return -1
}
}

View File

@@ -0,0 +1,147 @@
package com.synebula.gaea.jpa
import com.synebula.gaea.data.date.DateTime
import com.synebula.gaea.query.Operator
import com.synebula.gaea.query.Where
import org.springframework.data.jpa.domain.Specification
import java.lang.reflect.Field
import java.util.*
import javax.persistence.criteria.CriteriaBuilder
import javax.persistence.criteria.CriteriaQuery
import javax.persistence.criteria.Predicate
import javax.persistence.criteria.Root
/**
* 类型转换
*
* @param field 字段对象
* @return object
*/
fun String.toFieldType(field: Field): Any? {
var result: Any? = this
val fieldType = field.type
if (fieldType != this.javaClass) {
if (Int::class.java == fieldType || Int::class.javaPrimitiveType == fieldType) {
result = this.toInt()
}
if (Double::class.java == fieldType || Double::class.javaPrimitiveType == fieldType) {
result = this.toDouble()
}
if (Float::class.java == fieldType || Float::class.javaPrimitiveType == fieldType) {
result = this.toFloat()
}
if (Date::class.java == fieldType) {
result = DateTime(this, "yyyy-MM-dd HH:mm:ss").date
}
}
return result
}
/**
* 格式化数值类型
*
* @param field 字段对象
* @return double
*/
fun String.tryToDigital(field: Field): Double {
val result: Double
val fieldType = field.type
result =
if (Int::class.java == fieldType || Int::class.javaPrimitiveType == fieldType
|| Double::class.java == fieldType || Double::class.javaPrimitiveType == fieldType
|| Float::class.java == fieldType || Float::class.javaPrimitiveType == fieldType
) {
this.toDouble()
} else throw RuntimeException(
String.format(
"class [%s] field [%s] is not digital",
field.declaringClass.name,
field.name
)
)
return result
}
/**
* 参数 Map 转换查询 Specification
*
* @param clazz 类
* @return Specification
*/
fun Map<String, String>?.toSpecification(clazz: Class<*>): Specification<*> {
val rangeStartSuffix = "[0]" //范围查询开始后缀
val rangeEndSuffix = "[1]" //范围查询结束后缀
return Specification<Any?> { root: Root<Any?>, _: CriteriaQuery<*>?, criteriaBuilder: CriteriaBuilder ->
val predicates: MutableList<Predicate> = ArrayList()
for (argumentName in this!!.keys) {
if (this[argumentName] == null) continue
var fieldName = argumentName
var operator: Operator
// 判断是否为range类型(范围内查询)
var start = true
if (fieldName.endsWith(rangeStartSuffix) || fieldName.endsWith(rangeEndSuffix)) {
fieldName = fieldName.substring(fieldName.length - 3)
if (fieldName.endsWith(rangeEndSuffix)) start = false
}
val field = clazz.getDeclaredField(fieldName)
val where: Where = field.getDeclaredAnnotation(Where::class.java)
operator = where.operator
// 如果是范围内容, 判断是数值类型还是时间类型
if (operator === Operator.Range) {
if (clazz.getDeclaredField(fieldName).type != Date::class.java) {
operator = if (start) Operator.Gte else Operator.Lte
}
}
var predicate: Predicate
var digitalValue: Double
try {
when (operator) {
Operator.Ne -> predicate =
criteriaBuilder.notEqual(root.get<Any>(fieldName), this[fieldName]!!.toFieldType(field))
Operator.Lt -> {
digitalValue = this[fieldName]!!.tryToDigital(field)
predicate = criteriaBuilder.lessThan(root.get(fieldName), digitalValue)
}
Operator.Gt -> {
digitalValue = this[fieldName]!!.tryToDigital(field)
predicate = criteriaBuilder.greaterThan(root.get(fieldName), digitalValue)
}
Operator.Lte -> {
digitalValue = this[fieldName]!!.tryToDigital(field)
predicate = criteriaBuilder.lessThanOrEqualTo(root.get(fieldName), digitalValue)
}
Operator.Gte -> {
digitalValue = this[fieldName]!!.tryToDigital(field)
predicate = criteriaBuilder.greaterThanOrEqualTo(root.get(fieldName), digitalValue)
}
Operator.Like -> predicate = criteriaBuilder.like(root.get(fieldName), "%${this[fieldName]}%")
Operator.Range -> {
predicate = if (start) {
criteriaBuilder.greaterThanOrEqualTo(root.get(fieldName), this[argumentName]!!)
} else {
criteriaBuilder.lessThanOrEqualTo(root.get(fieldName), this[argumentName]!!)
}
}
else -> predicate =
criteriaBuilder.equal(root.get<Any>(fieldName), this[fieldName]!!.toFieldType(field))
}
predicates.add(predicate)
} catch (e: NoSuchFieldException) {
throw Error(
"class [${field.declaringClass.name}] field [${field.name}] can't annotation [@Where(${operator.declaringClass.simpleName}.${operator.name})]",
e
)
}
}
criteriaBuilder.and(*predicates.toTypedArray())
}
}

View File

@@ -0,0 +1,36 @@
package com.synebula.gaea.jpa.proxy
import org.springframework.beans.factory.BeanFactory
import org.springframework.beans.factory.FactoryBean
import org.springframework.cglib.proxy.Enhancer
import org.springframework.data.repository.Repository
class JpaRepositoryFactory(
private val beanFactory: BeanFactory,
private val interfaceType: Class<*>,
private val implBeanNames: List<String>
) : FactoryBean<Any> {
override fun getObject(): Any {
val handler: JpaRepositoryProxy<*, *, *> = JpaRepositoryProxy<Repository<Any, Any>, Any, Any>(
beanFactory,
interfaceType, implBeanNames
)
//JDK 方式代理代码, 暂时选用cglib
//Object proxy = Proxy.newProxyInstance(this.interfaceType.getClassLoader(), new Class[]{this.interfaceType}, handler);
//cglib代理
val enhancer = Enhancer()
enhancer.setSuperclass(interfaceType)
enhancer.setCallback(handler)
return enhancer.create()
}
override fun getObjectType(): Class<*> {
return interfaceType
}
override fun isSingleton(): Boolean {
return true
}
}

View File

@@ -0,0 +1,289 @@
package com.synebula.gaea.jpa.proxy
import com.synebula.gaea.jpa.proxy.method.JpaMethodProxy
import javassist.*
import javassist.bytecode.AnnotationsAttribute
import javassist.bytecode.MethodInfo
import javassist.bytecode.SignatureAttribute
import javassist.bytecode.annotation.Annotation
import javassist.bytecode.annotation.BooleanMemberValue
import javassist.bytecode.annotation.StringMemberValue
import org.springframework.beans.BeansException
import org.springframework.beans.factory.BeanFactory
import org.springframework.beans.factory.ObjectProvider
import org.springframework.cglib.proxy.MethodInterceptor
import org.springframework.cglib.proxy.MethodProxy
import org.springframework.data.jpa.repository.Modifying
import org.springframework.data.jpa.repository.Query
import org.springframework.data.jpa.repository.support.JpaRepositoryFactoryBean
import org.springframework.data.jpa.repository.support.JpaRepositoryImplementation
import org.springframework.data.mapping.context.MappingContext
import org.springframework.data.querydsl.EntityPathResolver
import org.springframework.data.querydsl.SimpleEntityPathResolver
import org.springframework.data.repository.Repository
import java.lang.reflect.Method
import java.lang.reflect.ParameterizedType
import java.lang.reflect.Type
import javax.persistence.EntityManager
class JpaRepositoryProxy<T : Repository<S, ID>?, S, ID>(
beanFactory: BeanFactory,
interfaceType: Class<*>,
implementBeanNames: List<String>?
) : MethodInterceptor { //InvocationHandler {
//JPA 默认Entity Manager上下文, 如不用该上下文则没有事务管理器
private val EntityManagerName = "org.springframework.orm.jpa.SharedEntityManagerCreator#0"
/**
* JPA代理对象
*/
private var jpaRepository: JpaRepositoryImplementation<*, *>? = null
/**
* bean注册器
*/
protected var beanFactory: BeanFactory
/**
* 方法映射管理器
*/
protected var jpaMethodProxy: JpaMethodProxy
/**
* 需要代理的接口类型
*/
protected var interfaceType: Class<*>
/**
* 接口实现bean名称
*/
protected var implementBeanNames: List<String>
init {
try {
this.beanFactory = beanFactory
this.implementBeanNames = implementBeanNames ?: listOf()
this.interfaceType = interfaceType
// 设置方法映射查询参数类(Entity类型)
val type = interfaceType.genericInterfaces[0]
val typeArguments = (type as ParameterizedType).actualTypeArguments
jpaMethodProxy = JpaMethodProxy(Class.forName(typeArguments[0].typeName))
// 创建虚假的JpaRepository接口
val repoClazz = createJpaRepoClazz(*typeArguments)
val jpaRepositoryFactoryBean: JpaRepositoryFactoryBean<*, *, *> = createJPARepositoryFactoryBean(repoClazz)
jpaRepository = jpaRepositoryFactoryBean.getObject() as JpaRepositoryImplementation<*, *>
} catch (e: ClassNotFoundException) {
throw RuntimeException(e)
}
}
/**
* JDK 方式代理代码, 暂时选用cglib
*/
@Deprecated("")
@Throws(Throwable::class)
operator fun invoke(proxy: Any?, method: Method, args: Array<Any>): Any? {
return if (Any::class.java == method.declaringClass) {
method.invoke(proxy, *args)
} else {
execMethod(method, args)
}
}
/**
* 暂时选用cglib 方式代理代码
*/
@Throws(Throwable::class)
override fun intercept(o: Any, method: Method, args: Array<Any>, methodProxy: MethodProxy): Any {
return if (Any::class.java == method.declaringClass) {
methodProxy.invoke(this, args)
} else {
execMethod(method, args)!!
}
}
/**
* 执行代理方法
*
* @param method 需要执行的方法
* @param args 参数列表
* @return 方法执行结果
* @throws Throwable 异常
*/
@Throws(Throwable::class)
private fun execMethod(method: Method, args: Array<Any>): Any? {
// 找到对应代理方法, 代理执行
return if (jpaMethodProxy.match(method)) {
try {
jpaMethodProxy.proxyExecMethod(jpaRepository, method, args)
} catch (ex: Exception) {
throw RuntimeException(
String.format(
"对象[%s]代理执行方法[%s.%s]出错",
jpaMethodProxy.javaClass, interfaceType.name, method.name
), ex
)
}
} else {
// 找不到代理方法则查找具体实现类执行
if (implementBeanNames.isEmpty()) throw RuntimeException(
String.format(
"找不到[%s.%s]对应的代理方法",
method.declaringClass.name, method.name
)
) else {
val bean = beanFactory.getBean(implementBeanNames[0])
val proxyMethod = bean.javaClass.getMethod(method.name, *method.parameterTypes)
proxyMethod.invoke(bean, *args)
}
}
}
/**
* 使用javassist创建虚拟的jpa repo类
*
* @param typeArgs 泛型参数
* @return 虚拟的jpa repo类形
*/
@Suppress("unchecked_cast")
private fun createJpaRepoClazz(vararg typeArgs: Type): Class<T> {
return try {
val pool = ClassPool.getDefault()
val jpaRepoCt = pool[JpaRepositoryImplementation::class.java.name]
val clazzName = String.format("%sRepository", typeArgs[0].typeName)
val repoCt = pool.makeInterface(clazzName, jpaRepoCt)
val typeArguments = arrayOfNulls<SignatureAttribute.TypeArgument>(typeArgs.size)
for (i in typeArgs.indices) {
typeArguments[i] = SignatureAttribute.TypeArgument(SignatureAttribute.ClassType(typeArgs[i].typeName))
}
val ac = SignatureAttribute.ClassSignature(
null,
null,
arrayOf(SignatureAttribute.ClassType(jpaRepoCt.name, typeArguments))
)
repoCt.genericSignature = ac.encode()
addClassQueryMethod(repoCt)
repoCt.toClass() as Class<T>
} catch (ex: Exception) {
throw RuntimeException(ex)
}
}
/**
* 给虚拟接口添加Query注解方法
*
* @param ctClass jpa虚拟接口
*/
@Throws(NotFoundException::class, CannotCompileException::class, ClassNotFoundException::class)
private fun addClassQueryMethod(ctClass: CtClass) {
// 找到Query注解方法并加入到虚拟接口中
val interfaceCtClass = ClassPool.getDefault()[interfaceType.name]
for (ctMethod in interfaceCtClass.methods) {
val query = ctMethod.getAnnotation(Query::class.java)
if (query != null) {
val method = CtNewMethod.abstractMethod(
ctMethod.returnType, ctMethod.name,
ctMethod.parameterTypes, arrayOfNulls(0), ctClass
)
val methodInfo = method.methodInfo
var modifying: Modifying? = null
//查找有无@Modifing注解有的化虚拟接口也需要加上
val annotation = ctMethod.getAnnotation(Modifying::class.java)
if (annotation != null) {
modifying = annotation as Modifying
}
// 增加Query注解
val attribute = buildQueryAttribute(methodInfo, query as Query, modifying)
methodInfo.addAttribute(attribute)
ctClass.addMethod(method)
}
}
// Query注解方法加入到代理中
for (method in interfaceType.methods) {
val annotation: Any? = method.getAnnotation(Query::class.java)
if (annotation != null) {
jpaMethodProxy.addQueryMethodMapper(method)
}
}
}
/**
* 创建javassist方法Query注解
*
* @param methodInfo 方法信息
* @param query query注解实例
* @param modifying Modifying注解
* @return 注解信息
*/
private fun buildQueryAttribute(methodInfo: MethodInfo, query: Query, modifying: Modifying?): AnnotationsAttribute {
val cp = methodInfo.constPool
val attribute = AnnotationsAttribute(cp, AnnotationsAttribute.visibleTag)
val queryAnnotation = Annotation(Query::class.java.name, cp)
queryAnnotation.addMemberValue("value", StringMemberValue(query.value, cp))
queryAnnotation.addMemberValue("countQuery", StringMemberValue(query.countQuery, cp))
queryAnnotation.addMemberValue("countProjection", StringMemberValue(query.countProjection, cp))
queryAnnotation.addMemberValue("nativeQuery", BooleanMemberValue(query.nativeQuery, cp))
queryAnnotation.addMemberValue("name", StringMemberValue(query.name, cp))
queryAnnotation.addMemberValue("countName", StringMemberValue(query.countName, cp))
if (modifying != null) {
val modifyingAnnotation = Annotation(
Modifying::class.java.name, cp
)
modifyingAnnotation.addMemberValue(
"flushAutomatically",
BooleanMemberValue(modifying.flushAutomatically, cp)
)
modifyingAnnotation.addMemberValue(
"clearAutomatically",
BooleanMemberValue(modifying.clearAutomatically, cp)
)
attribute.annotations = arrayOf(queryAnnotation, modifyingAnnotation)
} else attribute.setAnnotation(queryAnnotation)
return attribute
}
/**
* 创建JpaRepositoryFactoryBean对象
*
* @param jpaRepositoryClass 需要创建的 JPA JpaRepository Class
* @return JpaRepositoryFactoryBean对象
*/
private fun createJPARepositoryFactoryBean(jpaRepositoryClass: Class<out T>): JpaRepositoryFactoryBean<T, S, ID> {
// jpa 默认使用改名成EntityManager, 若用默认则没有事务上下文
val entityManager = beanFactory.getBean(EntityManagerName) as EntityManager
val repositoryFactoryBean = JpaRepositoryFactoryBean(jpaRepositoryClass)
repositoryFactoryBean.setEntityManager(entityManager)
repositoryFactoryBean.setBeanFactory(beanFactory)
repositoryFactoryBean.setBeanClassLoader(JpaRepositoryFactoryBean::class.java.classLoader)
repositoryFactoryBean.setMappingContext(beanFactory.getBean("jpaMappingContext") as MappingContext<*, *>)
repositoryFactoryBean.setEntityPathResolver(object : ObjectProvider<EntityPathResolver> {
@Throws(BeansException::class)
override fun getObject(vararg objects: Any): EntityPathResolver {
return SimpleEntityPathResolver("")
}
@Throws(BeansException::class)
override fun getIfAvailable(): EntityPathResolver? {
return null
}
@Throws(BeansException::class)
override fun getIfUnique(): EntityPathResolver? {
return null
}
@Throws(BeansException::class)
override fun getObject(): EntityPathResolver {
return SimpleEntityPathResolver("")
}
})
repositoryFactoryBean.afterPropertiesSet()
return repositoryFactoryBean
}
}

View File

@@ -0,0 +1,12 @@
package com.synebula.gaea.jpa.proxy
import org.springframework.context.annotation.Import
import java.lang.annotation.Inherited
import kotlin.reflect.KClass
@Target(AnnotationTarget.ANNOTATION_CLASS, AnnotationTarget.CLASS)
@Retention(AnnotationRetention.RUNTIME)
@MustBeDocumented
@Inherited
@Import(JpaRepositoryRegister::class)
annotation class JpaRepositoryProxyScan(val basePackages: Array<String> = [], val scanInterfaces: Array<KClass<*>> = [])

View File

@@ -0,0 +1,141 @@
package com.synebula.gaea.jpa.proxy
import org.springframework.beans.BeansException
import org.springframework.beans.factory.BeanClassLoaderAware
import org.springframework.beans.factory.BeanFactory
import org.springframework.beans.factory.BeanFactoryAware
import org.springframework.beans.factory.annotation.AnnotatedBeanDefinition
import org.springframework.beans.factory.config.BeanDefinition
import org.springframework.beans.factory.support.BeanDefinitionBuilder
import org.springframework.beans.factory.support.BeanDefinitionRegistry
import org.springframework.beans.factory.support.GenericBeanDefinition
import org.springframework.context.EnvironmentAware
import org.springframework.context.ResourceLoaderAware
import org.springframework.context.annotation.ClassPathScanningCandidateComponentProvider
import org.springframework.context.annotation.ImportBeanDefinitionRegistrar
import org.springframework.core.annotation.AnnotationAttributes
import org.springframework.core.env.Environment
import org.springframework.core.io.ResourceLoader
import org.springframework.core.type.AnnotationMetadata
import org.springframework.core.type.classreading.MetadataReader
import org.springframework.core.type.classreading.MetadataReaderFactory
import org.springframework.core.type.filter.TypeFilter
import org.springframework.util.ClassUtils
import java.util.*
import java.util.stream.Collectors
class JpaRepositoryRegister : ImportBeanDefinitionRegistrar, ResourceLoaderAware, BeanClassLoaderAware,
EnvironmentAware,
BeanFactoryAware {
private lateinit var environment: Environment
private lateinit var resourceLoader: ResourceLoader
private var classLoader: ClassLoader? = null
private var beanFactory: BeanFactory? = null
override fun registerBeanDefinitions(metadata: AnnotationMetadata, registry: BeanDefinitionRegistry) {
val attributes = AnnotationAttributes(
metadata.getAnnotationAttributes(
JpaRepositoryProxyScan::class.java.name
) ?: mapOf()
)
val basePackages = attributes.getStringArray("basePackages")
val scanInterfaces = attributes.getClassArray("scanInterfaces")
// 过滤scanInterfaces接口内容
val filter = getSubObjectTypeFilter(scanInterfaces)
val beanDefinitions = scan(basePackages, arrayOf(filter))
// 遍历处理接口
for (beanDefinition in beanDefinitions) {
// 获取RepositoryFor注解信息
val beanClazz: Class<*> = try {
Class.forName(beanDefinition.beanClassName)
} catch (e: ClassNotFoundException) {
throw RuntimeException(e)
}
val beanClazzTypeFilter = getSubObjectTypeFilter(arrayOf(beanClazz))
val implClazzDefinitions = scan(basePackages, arrayOf(beanClazzTypeFilter))
for (definition in implClazzDefinitions) {
definition.isAutowireCandidate = false
registry.registerBeanDefinition(Objects.requireNonNull(definition.beanClassName), definition)
}
// 构建bean定义
// 1 bean参数
val implBeanNames = implClazzDefinitions.stream().map { obj: BeanDefinition -> obj.beanClassName }
.collect(Collectors.toList())
val builder = BeanDefinitionBuilder.genericBeanDefinition(beanClazz)
builder.addConstructorArgValue(beanFactory)
builder.addConstructorArgValue(beanClazz)
builder.addConstructorArgValue(implBeanNames)
val definition = builder.rawBeanDefinition as GenericBeanDefinition
definition.beanClass = JpaRepositoryFactory::class.java
definition.autowireMode = GenericBeanDefinition.AUTOWIRE_BY_TYPE
registry.registerBeanDefinition(beanClazz.name, definition)
}
}
/**
* 根据过滤器扫描直接包下bean
*
* @param packages 指定的扫描包
* @param filters 过滤器
* @return 扫描后的bean定义
*/
private fun scan(packages: Array<String>?, filters: Array<TypeFilter>): List<BeanDefinition> {
val scanner: ClassPathScanningCandidateComponentProvider =
object : ClassPathScanningCandidateComponentProvider() {
override fun isCandidateComponent(beanDefinition: AnnotatedBeanDefinition): Boolean {
try {
val metadata = beanDefinition.metadata
val target = ClassUtils.forName(metadata.className, classLoader)
return !target.isAnnotation
} catch (ignored: Exception) {
}
return false
}
}
scanner.environment = environment
scanner.resourceLoader = resourceLoader
for (filter in filters) {
scanner.addIncludeFilter(filter)
}
val beanDefinitions: MutableList<BeanDefinition> = LinkedList()
for (basePackage in packages!!) {
beanDefinitions.addAll(scanner.findCandidateComponents(basePackage))
}
return beanDefinitions
}
/**
* 获取父接口实现对象的类型过滤器
*
* @param interfaces 父接口
* @return 类型过滤器
*/
private fun getSubObjectTypeFilter(interfaces: Array<Class<*>>?): TypeFilter {
return TypeFilter { metadataReader: MetadataReader, _: MetadataReaderFactory? ->
val interfaceNames = metadataReader.classMetadata.interfaceNames
var matched = false
for (interfaceName in interfaceNames) {
matched = Arrays.stream(interfaces)
.anyMatch { clazz: Class<*> -> clazz.name == interfaceName }
}
matched
}
}
override fun setResourceLoader(resourceLoader: ResourceLoader) {
this.resourceLoader = resourceLoader
}
override fun setBeanClassLoader(classLoader: ClassLoader) {
this.classLoader = classLoader
}
override fun setEnvironment(environment: Environment) {
this.environment = environment
}
@Throws(BeansException::class)
override fun setBeanFactory(beanFactory: BeanFactory) {
this.beanFactory = beanFactory
}
}

View File

@@ -0,0 +1,141 @@
package com.synebula.gaea.jpa.proxy.method
import com.synebula.gaea.domain.model.IAggregateRoot
import com.synebula.gaea.jpa.proxy.method.resolver.AbstractMethodResolver
import com.synebula.gaea.jpa.proxy.method.resolver.DefaultMethodResolver
import com.synebula.gaea.jpa.proxy.method.resolver.FindMethodResolver
import com.synebula.gaea.jpa.proxy.method.resolver.PageMethodResolver
import com.synebula.gaea.query.Params
import org.springframework.data.domain.Pageable
import org.springframework.data.jpa.domain.Specification
import java.lang.reflect.InvocationTargetException
import java.lang.reflect.Method
/**
* Jpa 方法映射包装类
*/
class JpaMethodProxy(
/**
* 方法需要实现的实体类
*/
private var entityClazz: Class<*>
) {
/**
* 默认的方法映射配置(IRepository, IQuery 接口中定义的方法)
*/
private val defaultMethodMapper: MutableMap<String, AbstractMethodResolver?> = LinkedHashMap()
/**
* 用户自定义的query注解方法处理
*/
private val queryMethodMapper: MutableMap<String, AbstractMethodResolver> = LinkedHashMap()
/**
* 方法参数映射
*/
var argumentResolver: AbstractMethodResolver? = null
private set
init {
initDefaultMethodMapper()
}
/**
* 匹配方法是否需要代理
*
* @param method 方法
* @return ture/false
*/
fun match(method: Method): Boolean {
var isMatch = (defaultMethodMapper.containsKey(method.name)
&& defaultMethodMapper[method.name]!!.match(method, AbstractMethodResolver.MethodType.SourceMethod))
// 如果默认代理方法没有匹配,则查找Query方法映射
if (!isMatch) {
isMatch = queryMethodMapper.containsKey(method.toString())
}
return isMatch
}
/**
* 解析代理方法
*
* @param proxy 代理对象
* @param method 源方法
* @param args 参数列表
* @return 执行结果
*/
@Throws(NoSuchMethodException::class, InvocationTargetException::class, IllegalAccessException::class)
fun proxyExecMethod(proxy: Any?, method: Method, args: Array<Any>): Any? {
// 匹配方法是否需要代理
if (defaultMethodMapper.containsKey(method.name)) {
val resolver = defaultMethodMapper[method.name]
// 匹配参数是否相同
if (resolver!!.match(method, AbstractMethodResolver.MethodType.SourceMethod)) {
//遍历代理对象, 找到合适的代理方法
val targetMethod =
proxy!!.javaClass.getMethod(resolver.targetMethodName, *resolver.targetMethodParameters)
return try {
// 开始执行代理方法
val mappingArguments = resolver.mappingArguments(args)
val result = targetMethod.invoke(proxy, *mappingArguments)
resolver.mappingResult(result)
} catch (e: IllegalAccessException) {
throw RuntimeException(e)
} catch (e: InvocationTargetException) {
throw RuntimeException(e)
}
}
}
// 如果默认代理方法没有匹配,则查找Query方法映射
if (queryMethodMapper.containsKey(method.toString())) {
val targetMethod = proxy!!.javaClass.getMethod(method.name, *method.parameterTypes)
return targetMethod.invoke(proxy, *args)
}
throw RuntimeException(
String.format(
"方法[%s,%s]没有匹配的代理配置信息, 执行该方法前请先执行match方法判断",
method.declaringClass.name, method.name
)
)
}
/**
* 初始化默认的方法映射列表
*/
private fun initDefaultMethodMapper() {
defaultMethodMapper["add"] = DefaultMethodResolver("saveAndFlush")
.sourceMethodParameters(IAggregateRoot::class.java).targetMethodParameters(Any::class.java)
defaultMethodMapper["update"] = DefaultMethodResolver("saveAndFlush")
.sourceMethodParameters(IAggregateRoot::class.java).targetMethodParameters(Any::class.java)
defaultMethodMapper["remove"] = DefaultMethodResolver("deleteById")
.sourceMethodParameters(Any::class.java).targetMethodParameters(Any::class.java)
defaultMethodMapper["get"] = DefaultMethodResolver("findById")
.sourceMethodParameters(Any::class.java).targetMethodParameters(Any::class.java)
defaultMethodMapper["list"] = FindMethodResolver("findAll", entityClazz)
.sourceMethodParameters(MutableMap::class.java).targetMethodParameters(Specification::class.java)
defaultMethodMapper["count"] = FindMethodResolver("count", entityClazz)
.sourceMethodParameters(MutableMap::class.java).targetMethodParameters(Specification::class.java)
defaultMethodMapper["paging"] = PageMethodResolver("findAll", entityClazz)
.sourceMethodParameters(Params::class.java)
.targetMethodParameters(Specification::class.java, Pageable::class.java)
}
/**
* 增加用户自定义Query注解方法映射信息
*
* @param method 需要添加的方法
*/
fun addQueryMethodMapper(method: Method) {
queryMethodMapper[method.toString()] = DefaultMethodResolver(method.name)
}
fun setArgumentResolver(argumentResolver: AbstractMethodResolver?): JpaMethodProxy {
this.argumentResolver = argumentResolver
return this
}
fun setEntityClazz(entityClazz: Class<*>) {
this.entityClazz = entityClazz
}
}

View File

@@ -0,0 +1,100 @@
package com.synebula.gaea.jpa.proxy.method.resolver
import java.lang.reflect.Method
/**
* 解决JPA方法参数的映射
*
* @param targetMethodName 目标方法名称
*/
abstract class AbstractMethodResolver(var targetMethodName: String) {
/**
* 方法相关实体类型
*/
lateinit var entityClazz: Class<*>
/**
* 目标方法形参类型列表
*
*/
lateinit var targetMethodParameters: Array<out Class<*>>
/**
* 源方法形参类型列表
*
*/
lateinit var sourceMethodParameters: Array<out Class<*>>
constructor(targetMethodName: String, entityClazz: Class<*>) : this(targetMethodName) {
this.entityClazz = entityClazz
}
/**
* 解析映射实参
*
* @param args 实参列表
* @return 映射后的实参列表
*/
abstract fun mappingArguments(args: Array<Any>): Array<Any>
/**
* 解析映射方法结果
*
* @param result 方法结果
* @return 映射后的方法结果
*/
abstract fun mappingResult(result: Any): Any
/**
* 获取源方法形参类型列表
*/
open fun sourceMethodParameters(vararg params: Class<*>): AbstractMethodResolver {
this.sourceMethodParameters = params
return this
}
/**
* 设置目标方法形参类型列表
*/
open fun targetMethodParameters(vararg params: Class<*>): AbstractMethodResolver {
this.targetMethodParameters = params
return this
}
/**
* 匹配方法名(目标方法)/参数是否复合
*
* @param method 需要匹配的方法
* @param methodType 需要匹配的方法类型
* @return ture/false
*/
fun match(method: Method, methodType: MethodType): Boolean {
var methodParameters = sourceMethodParameters
var matched = true
// 匹配目标方法的时候额外匹配下方法名
if (methodType == MethodType.TargetMethod) {
methodParameters = targetMethodParameters
matched = method.name == targetMethodName
}
// 如果[目标]方法名匹配, 判断参数是否匹配
matched = matched && method.parameterCount == methodParameters.size
if (matched) {
for (i in methodParameters.indices) {
val parameterTypes = method.parameterTypes
if (methodParameters[i] != parameterTypes[i]) {
matched = false
break
}
}
}
return matched
}
enum class MethodType {
SourceMethod, TargetMethod
}
}

View File

@@ -0,0 +1,20 @@
package com.synebula.gaea.jpa.proxy.method.resolver
import java.util.*
/**
* 默认返回全部
*/
class DefaultMethodResolver(targetMethodName: String) : AbstractMethodResolver(targetMethodName) {
override fun mappingArguments(args: Array<Any>): Array<Any> {
return args
}
override fun mappingResult(result: Any): Any {
if (result is Optional<*>) {
return result.orElse(null)
}
return result
}
}

View File

@@ -0,0 +1,20 @@
package com.synebula.gaea.jpa.proxy.method.resolver
import com.synebula.gaea.jpa.toSpecification
/**
* 查询方法参数映射
*/
class FindMethodResolver(targetMethodName: String, clazz: Class<*>) : AbstractMethodResolver(targetMethodName, clazz) {
@Suppress("UNCHECKED_CAST")
override fun mappingArguments(args: Array<Any>): Array<Any> {
val params = args[0] as Map<String, String>?
val specification = params.toSpecification(entityClazz)
return arrayOf(specification)
}
override fun mappingResult(result: Any): Any {
return result
}
}

View File

@@ -0,0 +1,56 @@
package com.synebula.gaea.jpa.proxy.method.resolver
import com.synebula.gaea.jpa.toSpecification
import com.synebula.gaea.query.Order
import com.synebula.gaea.query.Params
import org.springframework.data.domain.Page
import org.springframework.data.domain.PageRequest
import org.springframework.data.domain.Pageable
import org.springframework.data.domain.Sort
import java.util.*
import javax.persistence.EmbeddedId
import javax.persistence.Id
/**
* 分页方法参数映射
*/
class PageMethodResolver(targetMethodName: String, clazz: Class<*>) : AbstractMethodResolver(targetMethodName, clazz) {
override fun mappingArguments(args: Array<Any>): Array<Any> {
return try {
val params: Params? = args[0] as Params?
val specification = params!!.parameters.toSpecification(entityClazz)
var sort = Sort.unsorted()
for (key in params.orders.keys) {
val direction = if (params.orders[key] === Order.ASC) Sort.Direction.ASC else Sort.Direction.DESC
sort = sort.and(Sort.by(direction, key))
}
if (sort.isEmpty) {
val fields = entityClazz.declaredFields
for (field in fields) {
val isId = Arrays.stream(field.declaredAnnotations).anyMatch { annotation: Annotation ->
(annotation.annotationClass.java == Id::class.java
|| annotation.annotationClass.java == EmbeddedId::class.java)
}
if (isId) {
sort = Sort.by(Sort.Direction.ASC, field.name)
break
}
}
}
// Pageable 页面从0开始
val pageable: Pageable = PageRequest.of(params.page - 1, params.size, sort)
arrayOf(specification, pageable)
} catch (e: Exception) {
throw RuntimeException(e)
}
}
override fun mappingResult(result: Any): Any {
val page = result as Page<*>
// Page 页面从0开始
return com.synebula.gaea.query.Page(page.number + 1, page.size, page.totalElements.toInt(), page.content)
}
}

View File

@@ -3,11 +3,3 @@ dependencies {
api project(":src:gaea.spring") api project(":src:gaea.spring")
api("org.springframework.boot:spring-boot-starter-data-mongodb:$spring_version") api("org.springframework.boot:spring-boot-starter-data-mongodb:$spring_version")
} }
publishing {
publications {
publish(MavenPublication) {
from components.java
}
}
}

View File

@@ -30,7 +30,7 @@ fun Query.select(fields: Array<String>): Query {
* @param onWhere 获取字段查询方式的方法 * @param onWhere 获取字段查询方式的方法
*/ */
fun Query.where( fun Query.where(
params: Map<String, Any>?, params: Map<String, String>?,
onWhere: ((v: String) -> Where?) = { null }, onWhere: ((v: String) -> Where?) = { null },
onFieldType: ((v: String) -> Class<*>?) = { null } onFieldType: ((v: String) -> Class<*>?) = { null }
): Query { ): Query {
@@ -38,15 +38,15 @@ fun Query.where(
if (params != null) { if (params != null) {
for (param in params) { for (param in params) {
val key = param.key val key = param.key
var value = param.value var value: Any = param.value
//日期类型特殊处理为String类型 //日期类型特殊处理为String类型
val fieldType = onFieldType(key) val fieldType = onFieldType(key)
if (fieldType != null && value.javaClass != fieldType) { if (fieldType != null && value.javaClass != fieldType) {
when (fieldType) { when (fieldType) {
Date::class.java -> value = DateTime(value.toString(), "yyyy-MM-ddTHH:mm:ss").date Date::class.java -> value = DateTime(param.value, "yyyy-MM-ddTHH:mm:ss").date
Int::class.java -> value = value.toString().toInt() Int::class.java -> value = param.value.toInt()
Integer::class.java -> value = value.toString().toInt() Integer::class.java -> value = param.value.toInt()
} }
} }
@@ -58,14 +58,15 @@ fun Query.where(
val field = where.children.ifEmpty { key } val field = where.children.ifEmpty { key }
var criteria = Criteria.where(field) var criteria = Criteria.where(field)
criteria = when (where.operator) { criteria = when (where.operator) {
Operator.eq -> criteria.`is`(value) Operator.Eq -> criteria.`is`(value)
Operator.ne -> criteria.ne(value) Operator.Ne -> criteria.ne(value)
Operator.lt -> criteria.lt(value) Operator.Lt -> criteria.lt(value)
Operator.gt -> criteria.gt(value) Operator.Gt -> criteria.gt(value)
Operator.lte -> criteria.lte(value) Operator.Lte -> criteria.lte(value)
Operator.gte -> criteria.gte(value) Operator.Gte -> criteria.gte(value)
Operator.like -> criteria.regex(value.toString(), if (where.sensitiveCase) "" else "i") Operator.Like -> criteria.regex(value.toString(), if (where.sensitiveCase) "" else "i")
Operator.default -> tryRangeWhere(param.key, value, onFieldType) Operator.Range -> tryRangeWhere(param.key, value, onFieldType)
Operator.Default -> tryRangeWhere(param.key, value, onFieldType)
} }
list.add(if (where.children.isEmpty()) criteria else Criteria.where(key).elemMatch(criteria)) list.add(if (where.children.isEmpty()) criteria else Criteria.where(key).elemMatch(criteria))
} }
@@ -103,7 +104,7 @@ private fun tryRangeWhere(key: String, value: Any, onFieldType: ((v: String) ->
* *
* @param params 参数列表 * @param params 参数列表
*/ */
fun Query.where(params: Map<String, Any>?, clazz: Class<*>): Query { fun Query.where(params: Map<String, String>?, clazz: Class<*>): Query {
var field: Field? var field: Field?
return this.where(params, { name -> return this.where(params, { name ->
field = clazz.declaredFields.find { it.name == name } field = clazz.declaredFields.find { it.name == name }

View File

@@ -4,11 +4,11 @@ import com.synebula.gaea.spring.autoconfig.Factory
import com.synebula.gaea.spring.autoconfig.Proxy import com.synebula.gaea.spring.autoconfig.Proxy
import org.springframework.beans.factory.BeanFactory import org.springframework.beans.factory.BeanFactory
class MongodbRepoFactory( class MongodbRepositoryFactory(
supertype: Class<*>, supertype: Class<*>,
var beanFactory: BeanFactory, var beanFactory: BeanFactory,
) : Factory(supertype) { ) : Factory(supertype) {
override fun createProxy(): Proxy { override fun createProxy(): Proxy {
return MongodbRepoProxy(supertype, this.beanFactory) return MongodbRepositoryProxy(supertype, this.beanFactory)
} }
} }

View File

@@ -10,7 +10,7 @@ import org.springframework.beans.factory.BeanFactory
import org.springframework.data.mongodb.core.MongoTemplate import org.springframework.data.mongodb.core.MongoTemplate
import java.lang.reflect.Method import java.lang.reflect.Method
class MongodbRepoProxy( class MongodbRepositoryProxy(
private var supertype: Class<*>, private var beanFactory: BeanFactory private var supertype: Class<*>, private var beanFactory: BeanFactory
) : Proxy() { ) : Proxy() {

View File

@@ -9,14 +9,14 @@ import org.springframework.beans.factory.support.GenericBeanDefinition
import org.springframework.core.annotation.AnnotationAttributes import org.springframework.core.annotation.AnnotationAttributes
import org.springframework.core.type.AnnotationMetadata import org.springframework.core.type.AnnotationMetadata
class MongodbRepoRegister : Register() { class MongodbRepositoryRegister : Register() {
override fun scan(metadata: AnnotationMetadata): Map<String, BeanDefinition> { override fun scan(metadata: AnnotationMetadata): Map<String, BeanDefinition> {
val result = mutableMapOf<String, BeanDefinition>() val result = mutableMapOf<String, BeanDefinition>()
// 获取注解参数信息:basePackages // 获取注解参数信息:basePackages
val attributes = AnnotationAttributes( val attributes = AnnotationAttributes(
metadata.getAnnotationAttributes( metadata.getAnnotationAttributes(
MongodbRepoScan::class.java.name MongodbRepositoryScan::class.java.name
) ?: mapOf() ) ?: mapOf()
) )
val basePackages = attributes.getStringArray("basePackages") val basePackages = attributes.getStringArray("basePackages")
@@ -44,7 +44,7 @@ class MongodbRepoRegister : Register() {
builder.addConstructorArgValue(beanClazz) builder.addConstructorArgValue(beanClazz)
builder.addConstructorArgValue(this._beanFactory) builder.addConstructorArgValue(this._beanFactory)
val definition = builder.rawBeanDefinition as GenericBeanDefinition val definition = builder.rawBeanDefinition as GenericBeanDefinition
definition.beanClass = MongodbRepoFactory::class.java definition.beanClass = MongodbRepositoryFactory::class.java
definition.autowireMode = GenericBeanDefinition.AUTOWIRE_BY_TYPE definition.autowireMode = GenericBeanDefinition.AUTOWIRE_BY_TYPE
result[beanClazz.name] = definition result[beanClazz.name] = definition
} }
@@ -59,7 +59,7 @@ class MongodbRepoRegister : Register() {
builder.addConstructorArgValue(this._beanFactory) builder.addConstructorArgValue(this._beanFactory)
builder.addConstructorArgValue(emptyArray<String>()) builder.addConstructorArgValue(emptyArray<String>())
val definition = builder.rawBeanDefinition as GenericBeanDefinition val definition = builder.rawBeanDefinition as GenericBeanDefinition
definition.beanClass = MongodbRepoFactory::class.java definition.beanClass = MongodbRepositoryFactory::class.java
definition.autowireMode = GenericBeanDefinition.AUTOWIRE_BY_TYPE definition.autowireMode = GenericBeanDefinition.AUTOWIRE_BY_TYPE
result[IRepository::class.java.name] = definition result[IRepository::class.java.name] = definition
} }

View File

@@ -8,5 +8,5 @@ import java.lang.annotation.Inherited
@Retention(AnnotationRetention.RUNTIME) @Retention(AnnotationRetention.RUNTIME)
@MustBeDocumented @MustBeDocumented
@Inherited @Inherited
@Import(MongodbRepoRegister::class) @Import(MongodbRepositoryRegister::class)
annotation class MongodbRepoScan(val basePackages: Array<String> = []) annotation class MongodbRepositoryScan(val basePackages: Array<String> = [])

View File

@@ -31,7 +31,7 @@ open class MongodbQuery<TView, ID>(override var clazz: Class<TView>, var templat
return this.template.findOne(whereId(id), clazz, this.collection(clazz)) return this.template.findOne(whereId(id), clazz, this.collection(clazz))
} }
override fun list(params: Map<String, Any>?): List<TView> { override fun list(params: Map<String, String>?): List<TView> {
val fields = this.fields(clazz) val fields = this.fields(clazz)
val query = Query() val query = Query()
query.where(params, clazz) query.where(params, clazz)
@@ -39,7 +39,7 @@ open class MongodbQuery<TView, ID>(override var clazz: Class<TView>, var templat
return this.find(query, clazz) return this.find(query, clazz)
} }
override fun count(params: Map<String, Any>?): Int { override fun count(params: Map<String, String>?): Int {
val query = Query() val query = Query()
return this.template.count(query.where(params, clazz), this.collection(clazz)).toInt() return this.template.count(query.where(params, clazz), this.collection(clazz)).toInt()
} }

View File

@@ -30,7 +30,7 @@ open class MongodbUniversalQuery(var template: MongoTemplate) : IUniversalQuery
return this.template.findOne(whereId(id), clazz, this.collection(clazz)) return this.template.findOne(whereId(id), clazz, this.collection(clazz))
} }
override fun <TView> list(params: Map<String, Any>?, clazz: Class<TView>): List<TView> { override fun <TView> list(params: Map<String, String>?, clazz: Class<TView>): List<TView> {
val fields = this.fields(clazz) val fields = this.fields(clazz)
val query = Query() val query = Query()
query.where(params, clazz) query.where(params, clazz)
@@ -38,7 +38,7 @@ open class MongodbUniversalQuery(var template: MongoTemplate) : IUniversalQuery
return this.find(query, clazz) return this.find(query, clazz)
} }
override fun <TView> count(params: Map<String, Any>?, clazz: Class<TView>): Int { override fun <TView> count(params: Map<String, String>?, clazz: Class<TView>): Int {
val query = Query() val query = Query()
return this.template.count(query.where(params, clazz), this.collection(clazz)).toInt() return this.template.count(query.where(params, clazz), this.collection(clazz)).toInt()
} }

View File

@@ -40,7 +40,7 @@ open class MongodbRepository<TAggregateRoot : IAggregateRoot<ID>, ID>(
this.repo.save(list) this.repo.save(list)
} }
override fun count(params: Map<String, Any>?): Int { override fun count(params: Map<String, String>?): Int {
val query = Query() val query = Query()
return this.repo.count(query.where(params, clazz), clazz).toInt() return this.repo.count(query.where(params, clazz), clazz).toInt()
} }

View File

@@ -54,7 +54,7 @@ open class MongodbUniversalRepository(private var repo: MongoTemplate) : IUniver
this.repo.insert(roots, clazz) this.repo.insert(roots, clazz)
} }
override fun <TAggregateRoot> count(params: Map<String, Any>?, clazz: Class<TAggregateRoot>): Int { override fun <TAggregateRoot> count(params: Map<String, String>?, clazz: Class<TAggregateRoot>): Int {
val query = Query() val query = Query()
return this.repo.count(query.where(params, clazz), clazz).toInt() return this.repo.count(query.where(params, clazz), clazz).toInt()
} }

View File

@@ -1,10 +0,0 @@
publishing {
publications {
publish(MavenPublication) {
group "${project.group}"
artifactId "${project.name}"
version "$version"
from components.java
}
}
}

View File

@@ -35,7 +35,7 @@ import java.util.logging.Logger
* them together as a set ([Dagger](https://dagger.dev/dev-guide/multibindings), [Guice](https://github.com/google/guice/wiki/Multibindings), [Spring](https://docs.spring.io/spring-framework/docs/current/reference/html/core.html#beans-autowired-annotation)). * them together as a set ([Dagger](https://dagger.dev/dev-guide/multibindings), [Guice](https://github.com/google/guice/wiki/Multibindings), [Spring](https://docs.spring.io/spring-framework/docs/current/reference/html/core.html#beans-autowired-annotation)).
* *
* *
* To react to messages, we recommend a reactive-streams framework like [RxJava](https://github.com/ReactiveX/RxJava/wiki) (supplemented with its [RxAndroid](https://github.com/ReactiveX/RxAndroid) extension if you are building for * To react to messages, we recommend a reactive-streams framework Like [RxJava](https://github.com/ReactiveX/RxJava/wiki) (supplemented with its [RxAndroid](https://github.com/ReactiveX/RxAndroid) extension if you are building for
* Android) or [Project Reactor](https://projectreactor.io/). (For the basics of * Android) or [Project Reactor](https://projectreactor.io/). (For the basics of
* translating code from using a message bus to using a reactive-streams framework, see these two * translating code from using a message bus to using a reactive-streams framework, see these two
* guides: [1](https://blog.jkl.gg/implementing-an-message-bus-with-rxjava-rxbus/), [2](https://lorentzos.com/rxjava-as-message-bus-the-right-way-10a36bdd49ba).) Some usages * guides: [1](https://blog.jkl.gg/implementing-an-message-bus-with-rxjava-rxbus/), [2](https://lorentzos.com/rxjava-as-message-bus-the-right-way-10a36bdd49ba).) Some usages
@@ -49,7 +49,7 @@ import java.util.logging.Logger
* * It makes the cross-references between producer and subscriber harder to find. This can * * It makes the cross-references between producer and subscriber harder to find. This can
* complicate debugging, lead to unintentional reentrant calls, and force apps to eagerly * complicate debugging, lead to unintentional reentrant calls, and force apps to eagerly
* initialize all possible subscribers at startup time. * initialize all possible subscribers at startup time.
* * It uses reflection in ways that break when code is processed by optimizers/minimizer like * * It uses reflection in ways that break when code is processed by optimizers/minimizer Like
* [R8 and Proguard](https://developer.android.com/studio/build/shrink-code). * [R8 and Proguard](https://developer.android.com/studio/build/shrink-code).
* * It doesn't offer a way to wait for multiple messages before taking action. For example, it * * It doesn't offer a way to wait for multiple messages before taking action. For example, it
* doesn't offer a way to wait for multiple producers to all report that they're "ready," nor * doesn't offer a way to wait for multiple producers to all report that they're "ready," nor
@@ -137,7 +137,7 @@ import java.util.logging.Logger
* @author Cliff * @author Cliff
* @since 10.0 * @since 10.0
* @param identifier a brief name for this bus, for logging purposes. Should/home/alex/privacy/project/myths/gaea be a valid Java * @param identifier a brief name for this bus, for logging purposes. Should/home/alex/privacy/project/myths/gaea be a valid Java
* @param executor the default executor this event bus uses for dispatching events to subscribers. * @param executor the Default executor this event bus uses for dispatching events to subscribers.
* @param dispatcher message dispatcher. * @param dispatcher message dispatcher.
* @param exceptionHandler Handler for subscriber exceptions. * @param exceptionHandler Handler for subscriber exceptions.
*/ */
@@ -159,7 +159,7 @@ open class Bus<T : Any>(
* identifier. * identifier.
*/ */
@JvmOverloads @JvmOverloads
constructor(identifier: String = "default") : this( constructor(identifier: String = "Default") : this(
identifier, identifier,
Executor { it.run() }, Executor { it.run() },
Dispatcher.perThreadDispatchQueue(), Dispatcher.perThreadDispatchQueue(),
@@ -173,7 +173,7 @@ open class Bus<T : Any>(
* @since 16.0 * @since 16.0
*/ */
constructor(exceptionHandler: SubscriberExceptionHandler<T>) : this( constructor(exceptionHandler: SubscriberExceptionHandler<T>) : this(
"default", "Default",
Executor { it.run() }, Executor { it.run() },
Dispatcher.perThreadDispatchQueue(), Dispatcher.perThreadDispatchQueue(),
exceptionHandler exceptionHandler
@@ -203,7 +203,7 @@ open class Bus<T : Any>(
* @since 16.0 * @since 16.0
*/ */
constructor(executor: Executor, subscriberExceptionHandler: SubscriberExceptionHandler<T>) : this( constructor(executor: Executor, subscriberExceptionHandler: SubscriberExceptionHandler<T>) : this(
"default", "Default",
executor, executor,
Dispatcher.legacyAsync(), Dispatcher.legacyAsync(),
subscriberExceptionHandler subscriberExceptionHandler
@@ -216,7 +216,7 @@ open class Bus<T : Any>(
* down the executor after the last message has been posted to this message bus. * down the executor after the last message has been posted to this message bus.
*/ */
constructor(executor: Executor) : this( constructor(executor: Executor) : this(
"default", "Default",
executor, executor,
Dispatcher.legacyAsync(), Dispatcher.legacyAsync(),
LoggingHandler() LoggingHandler()

View File

@@ -68,7 +68,7 @@ interface IRepository<TAggregateRoot : IAggregateRoot<ID>, ID> {
* @param params 查询条件。 * @param params 查询条件。
* @return int * @return int
*/ */
fun count(params: Map<String, Any>?): Int fun count(params: Map<String, String>?): Int
} }

View File

@@ -61,5 +61,5 @@ interface IUniversalRepository {
* @param params 查询条件。 * @param params 查询条件。
* @return int * @return int
*/ */
fun <TAggregateRoot> count(params: Map<String, Any>?, clazz: Class<TAggregateRoot>): Int fun <TAggregateRoot> count(params: Map<String, String>?, clazz: Class<TAggregateRoot>): Int
} }

View File

@@ -25,7 +25,7 @@ interface IQuery<TView, ID> {
* @param params 查询条件。 * @param params 查询条件。
* @return 视图列表 * @return 视图列表
*/ */
fun list(params: Map<String, Any>?): List<TView> fun list(params: Map<String, String>?): List<TView>
/** /**
* 根据条件查询符合条件记录的数量 * 根据条件查询符合条件记录的数量
@@ -33,7 +33,7 @@ interface IQuery<TView, ID> {
* @param params 查询条件。 * @param params 查询条件。
* @return 数量 * @return 数量
*/ */
fun count(params: Map<String, Any>?): Int fun count(params: Map<String, String>?): Int
/** /**
* 根据实体类条件查询所有符合条件记录(分页查询) * 根据实体类条件查询所有符合条件记录(分页查询)

View File

@@ -20,7 +20,7 @@ interface IUniversalQuery {
* @param params 查询条件。 * @param params 查询条件。
* @return 视图列表 * @return 视图列表
*/ */
fun <TView> list(params: Map<String, Any>?, clazz: Class<TView>): List<TView> fun <TView> list(params: Map<String, String>?, clazz: Class<TView>): List<TView>
/** /**
* 根据条件查询符合条件记录的数量 * 根据条件查询符合条件记录的数量
@@ -28,7 +28,7 @@ interface IUniversalQuery {
* @param params 查询条件。 * @param params 查询条件。
* @return 数量 * @return 数量
*/ */
fun <TView> count(params: Map<String, Any>?, clazz: Class<TView>): Int fun <TView> count(params: Map<String, String>?, clazz: Class<TView>): Int
/** /**
* 根据实体类条件查询所有符合条件记录(分页查询) * 根据实体类条件查询所有符合条件记录(分页查询)

View File

@@ -4,40 +4,45 @@ enum class Operator {
/** /**
* 等于 * 等于
*/ */
eq, Eq,
/** /**
* 不等于 * 不等于
*/ */
ne, Ne,
/** /**
* 小于 * 小于
*/ */
lt, Lt,
/** /**
* 大于 * 大于
*/ */
gt, Gt,
/** /**
* 小于或等于 * 小于或等于
*/ */
lte, Lte,
/** /**
* 大于或等于 * 大于或等于
*/ */
gte, Gte,
/** /**
* 模糊匹配 * 模糊匹配
*/ */
like, Like,
/**
* 范围内
*/
Range,
/** /**
* 默认查询, 未定义查询方式, 业务人员自己实现查询方式 * 默认查询, 未定义查询方式, 业务人员自己实现查询方式
*/ */
default Default
} }

View File

@@ -9,7 +9,7 @@ package com.synebula.gaea.query
*/ */
data class Params(var page: Int = 1, var size: Int = 10) { data class Params(var page: Int = 1, var size: Int = 10) {
private var _parameters = linkedMapOf<String, Any>() private var _parameters = linkedMapOf<String, String>()
private var _orders = linkedMapOf<String, Order>() private var _orders = linkedMapOf<String, Order>()
/** /**
@@ -27,41 +27,23 @@ data class Params(var page: Int = 1, var size: Int = 10) {
this._orders = value this._orders = value
} }
get() { get() {
if (this._parameters.keys.count { it.startsWith("@") } > 0) { this.filterOrderParams()
val params = linkedMapOf<String, Any>()
this._parameters.forEach {
if (it.key.startsWith("@")) {
this._orders[it.key.removePrefix("@")] = Order.valueOf(it.value.toString())
} else
params[it.key] = it.value
}
this._parameters = params
}
return this._orders return this._orders
} }
/** /**
* 查询条件。 * 查询条件。
*/ */
var parameters: LinkedHashMap<String, Any> var parameters: LinkedHashMap<String, String>
set(value) { set(value) {
this._parameters = value this._parameters = value
} }
get() { get() {
if (this._parameters.keys.count { it.startsWith("@") } > 0) { this.filterOrderParams()
val params = linkedMapOf<String, Any>()
this._parameters.forEach {
if (it.key.startsWith("@")) {
this._orders[it.key.removePrefix("@")] = Order.valueOf(it.value.toString())
} else
params[it.key] = it.value
}
this._parameters = params
}
return this._parameters return this._parameters
} }
constructor(page: Int, size: Int, parameters: LinkedHashMap<String, Any>) : this(page, size) { constructor(page: Int, size: Int, parameters: LinkedHashMap<String, String>) : this(page, size) {
this.page = page this.page = page
this.size = size this.size = size
this._parameters = parameters this._parameters = parameters
@@ -70,7 +52,7 @@ data class Params(var page: Int = 1, var size: Int = 10) {
/** /**
* 添加查询条件 * 添加查询条件
*/ */
fun where(field: String, value: Any): Params { fun where(field: String, value: String): Params {
_parameters[field] = value _parameters[field] = value
return this return this
} }
@@ -82,4 +64,20 @@ data class Params(var page: Int = 1, var size: Int = 10) {
_orders[field] = order _orders[field] = order
return this return this
} }
/**
* 过滤参数中的排序字段
*/
private fun filterOrderParams() {
if (this._parameters.keys.count { it.startsWith("@") } > 0) {
val params = linkedMapOf<String, String>()
this._parameters.forEach {
if (it.key.startsWith("@")) {
this._orders[it.key.removePrefix("@")] = Order.valueOf(it.value)
} else
params[it.key] = it.value
}
this._parameters = params
}
}
} }

View File

@@ -1,5 +1,6 @@
package com.synebula.gaea.reflect package com.synebula.gaea.reflect
import java.lang.reflect.Field
import java.lang.reflect.ParameterizedType import java.lang.reflect.ParameterizedType
/** /**
@@ -24,5 +25,27 @@ fun Class<*>.getGenericInterface(interfaceClazz: Class<*>): ParameterizedType? {
val type = this.genericInterfaces.find { it.typeName.startsWith(interfaceClazz.typeName) } val type = this.genericInterfaces.find { it.typeName.startsWith(interfaceClazz.typeName) }
return if (type == null) null return if (type == null) null
else type as ParameterizedType else type as ParameterizedType
} }
/**
* 查找类字段, 可以查找包括继承类的私有字段
*
* @param name 字段名称
* @return 字段类型
*/
fun Class<*>.findField(name: String): Field? {
var field: Field? = null
for (f in this.declaredFields) {
if (f.name == name) {
field = f
}
}
if (field == null) {
val superclass = this.superclass
if (superclass != Any::class.java) {
field = superclass.findField(name)
}
}
return field
}