1.1.0 add service / repository / query spring auto config function

This commit is contained in:
2022-08-18 14:35:01 +08:00
parent b7cfd0a7f9
commit 07684e814d
42 changed files with 1001 additions and 272 deletions

View File

@@ -0,0 +1,13 @@
dependencies {
api project(":src:gaea")
api project(":src:gaea.spring")
api("org.springframework.boot:spring-boot-starter-data-mongodb:$spring_version")
}
publishing {
publications {
publish(MavenPublication) {
from components.java
}
}
}

View File

@@ -0,0 +1,135 @@
package com.synebula.gaea.mongodb
import com.synebula.gaea.data.date.DateTime
import com.synebula.gaea.query.Operator
import com.synebula.gaea.query.Order
import com.synebula.gaea.query.Where
import org.springframework.data.domain.Sort
import org.springframework.data.mongodb.core.query.Criteria
import org.springframework.data.mongodb.core.query.Query
import java.lang.reflect.Field
import java.util.*
/**
* 获取查询字段列表
*
* @param fields 字段列表
*/
fun Query.select(fields: Array<String>): Query {
fields.forEach {
this.fields().include(it)
}
return this
}
/**
* 根据参数获取查询条件
*
* @param params 参数列表
* @param onWhere 获取字段查询方式的方法
*/
fun Query.where(
params: Map<String, Any>?,
onWhere: ((v: String) -> Where?) = { null },
onFieldType: ((v: String) -> Class<*>?) = { null }
): Query {
val list = arrayListOf<Criteria>()
if (params != null) {
for (param in params) {
val key = param.key
var value = param.value
//日期类型特殊处理为String类型
val fieldType = onFieldType(key)
if (fieldType != null && value.javaClass != fieldType) {
when (fieldType) {
Date::class.java -> value = DateTime(value.toString(), "yyyy-MM-ddTHH:mm:ss").date
Int::class.java -> value = value.toString().toInt()
Integer::class.java -> value = value.toString().toInt()
}
}
val where = onWhere(key)
if (where == null) {
list.add(tryRangeWhere(param.key, value, onFieldType))
} else {
//判断执行查询子元素还是本字段
val field = where.children.ifEmpty { key }
var criteria = Criteria.where(field)
criteria = when (where.operator) {
Operator.eq -> criteria.`is`(value)
Operator.ne -> criteria.ne(value)
Operator.lt -> criteria.lt(value)
Operator.gt -> criteria.gt(value)
Operator.lte -> criteria.lte(value)
Operator.gte -> criteria.gte(value)
Operator.like -> criteria.regex(value.toString(), if (where.sensitiveCase) "" else "i")
Operator.default -> tryRangeWhere(param.key, value, onFieldType)
}
list.add(if (where.children.isEmpty()) criteria else Criteria.where(key).elemMatch(criteria))
}
}
}
val criteria = Criteria()
if (list.isNotEmpty()) criteria.andOperator(*list.toTypedArray())
return this.addCriteria(criteria)
}
/**
* 尝试范围查询,失败则返回正常查询条件。
*/
private fun tryRangeWhere(key: String, value: Any, onFieldType: ((v: String) -> Class<*>?) = { null }): Criteria {
val rangeStartSuffix = "[0]" //范围查询开始后缀
val rangeEndSuffix = "[1]" //范围查询结束后缀
var condition = value
val realKey = key.removeSuffix(rangeStartSuffix).removeSuffix(rangeEndSuffix)
val fieldType = onFieldType(realKey)
if (fieldType != null && value.javaClass != fieldType && fieldType == Date::class.java) {
condition = DateTime(value.toString(), "yyyy-MM-dd HH:mm:ss").date
}
return when {
//以范围查询开始后缀结尾表示要用大于或等于查询方式
key.endsWith(rangeStartSuffix) -> Criteria.where(realKey).gte(condition)
//以范围查询结束后缀结尾表示要用小于或等于查询方式
key.endsWith(rangeEndSuffix) -> Criteria.where(realKey).lte(condition)
else -> Criteria.where(key).`is`(value)
}
}
/**
* 根据参数获取查询条件
*
* @param params 参数列表
*/
fun Query.where(params: Map<String, Any>?, clazz: Class<*>): Query {
var field: Field?
return this.where(params, { name ->
field = clazz.declaredFields.find { it.name == name }
field?.getDeclaredAnnotation(Where::class.java)
}, { name -> clazz.declaredFields.find { it.name == name }?.type })
}
/**
* 获取ID查询条件
*
* @param id 业务ID
*/
fun <ID> whereId(id: ID): Query = Query.query(Criteria.where("_id").`is`(id))
/**
* 获取排序对象
*
* @param orders 排序条件字段
*/
fun order(orders: Map<String, Order>?): Sort {
val orderList = mutableListOf<Sort.Order>()
orders?.forEach {
orderList.add(Sort.Order(Sort.Direction.valueOf(it.value.name), it.key))
}
return if (orderList.size == 0)
Sort.by(Sort.Direction.DESC, "_id")
else
Sort.by(orderList)
}

View File

@@ -0,0 +1,15 @@
package com.synebula.gaea.mongodb.autoconfig
import com.synebula.gaea.spring.autoconfig.Factory
import com.synebula.gaea.spring.autoconfig.Proxy
import org.springframework.beans.factory.BeanFactory
class MongodbRepoFactory(
supertype: Class<*>,
var beanFactory: BeanFactory,
var implementBeanNames: Array<String> = arrayOf()
) : Factory(supertype) {
override fun createProxy(): Proxy {
return MongodbRepoProxy(supertype, this.beanFactory, this.implementBeanNames)
}
}

View File

@@ -0,0 +1,84 @@
package com.synebula.gaea.mongodb.autoconfig
import com.synebula.gaea.domain.repository.IRepository
import com.synebula.gaea.log.ILogger
import com.synebula.gaea.mongodb.query.MongodbQuery
import com.synebula.gaea.mongodb.repository.MongodbRepository
import com.synebula.gaea.query.IQuery
import com.synebula.gaea.spring.autoconfig.Proxy
import org.springframework.beans.factory.BeanFactory
import org.springframework.data.mongodb.core.MongoTemplate
import java.io.InvalidClassException
import java.lang.reflect.Method
import java.lang.reflect.ParameterizedType
class MongodbRepoProxy(
private var supertype: Class<*>, private var beanFactory: BeanFactory, implementBeanNames: Array<String> = arrayOf()
) : Proxy() {
private var repo: IRepository<*, *>? = null
private var query: IQuery<*, *>? = null
init {
if (this.supertype.interfaces.any { it == IRepository::class.java }) {
// 如果是IRepository子接口
if (implementBeanNames.isEmpty()) {
val genericInterfaces = this.supertype.genericInterfaces.find {
it.typeName.startsWith(IRepository::class.java.typeName)
}!!
val constructor = MongodbRepository::class.java.getConstructor(
Class::class.java, MongoTemplate::class.java
)
this.repo = constructor.newInstance(
(genericInterfaces as ParameterizedType).actualTypeArguments[0],
this.beanFactory.getBean(MongoTemplate::class.java)
)
} else {
this.repo = this.beanFactory.getBean(implementBeanNames[0]) as IRepository<*, *>
}
} else {
// 否则是IQuery子接口
if (implementBeanNames.isEmpty()) {
val genericInterfaces = this.supertype.genericInterfaces.find {
it.typeName.startsWith(IQuery::class.java.typeName)
}!!
val constructor = MongodbQuery::class.java.getConstructor(
Class::class.java, MongoTemplate::class.java, ILogger::class.java
)
this.query = constructor.newInstance(
(genericInterfaces as ParameterizedType).actualTypeArguments[0],
this.beanFactory.getBean(MongoTemplate::class.java),
this.beanFactory.getBean(ILogger::class.java),
)
} else {
this.query = this.beanFactory.getBean(implementBeanNames[0]) as IQuery<*, *>
}
}
}
/**
* 执行代理方法
*
* @param proxy 代理对象
* @param method 需要执行的方法
* @param args 参数列表
* @return 方法执行结果
*/
override fun exec(proxy: Any, method: Method, args: Array<Any>): Any? {
val proxyClazz = if (this.repo != null) {
this.repo!!.javaClass
} else if (this.query != null) {
this.query!!.javaClass
} else
throw InvalidClassException("class ${this.supertype.name} property repo and query are both null")
try {
val proxyMethod: Method = proxyClazz.getDeclaredMethod(method.name, *method.parameterTypes)
return proxyMethod.invoke(this.repo, *args)
} catch (ex: NoSuchMethodException) {
throw NoSuchMethodException("method [${method.toGenericString()}] not implements in class [${proxyClazz}], you must implements interface [${this.supertype.name}] ")
}
}
}

View File

@@ -0,0 +1,54 @@
package com.synebula.gaea.mongodb.autoconfig
import com.synebula.gaea.domain.repository.IRepository
import com.synebula.gaea.query.IQuery
import com.synebula.gaea.spring.autoconfig.Register
import org.springframework.beans.factory.config.BeanDefinition
import org.springframework.beans.factory.support.BeanDefinitionBuilder
import org.springframework.beans.factory.support.GenericBeanDefinition
import org.springframework.core.annotation.AnnotationAttributes
import org.springframework.core.type.AnnotationMetadata
class MongodbRepoRegister : Register() {
override fun scan(metadata: AnnotationMetadata): Map<String, BeanDefinition> {
val result = mutableMapOf<String, BeanDefinition>()
// 获取注解参数信息:basePackages
val attributes = AnnotationAttributes(
metadata.getAnnotationAttributes(
MongodbRepoScan::class.java.name
) ?: mapOf()
)
val basePackages = attributes.getStringArray("basePackages")
val beanDefinitions = this.doScan(
basePackages,
arrayOf(this.interfaceFilter(arrayOf(IRepository::class.java, IQuery::class.java)))
)
beanDefinitions.forEach { beanDefinition ->
// 获取实际的bean类型
val beanClazz: Class<*> = try {
Class.forName(beanDefinition.beanClassName)
} catch (ex: ClassNotFoundException) {
throw ex
}
// 尝试获取实际继承类型
val implBeanDefinitions = this.doScan(basePackages, arrayOf(this.interfaceFilter(arrayOf(beanClazz))))
implBeanDefinitions.forEach {
it.isAutowireCandidate = false
result[it.beanClassName!!] = it
}
// 构造BeanDefinition
val builder = BeanDefinitionBuilder.genericBeanDefinition(beanClazz)
builder.addConstructorArgValue(beanClazz)
builder.addConstructorArgValue(this._beanFactory)
builder.addConstructorArgValue(implBeanDefinitions.map { it.beanClassName })
val definition = builder.rawBeanDefinition as GenericBeanDefinition
definition.beanClass = MongodbRepoFactory::class.java
definition.autowireMode = GenericBeanDefinition.AUTOWIRE_BY_TYPE
result[beanClazz.name] = definition
}
return result
}
}

View File

@@ -0,0 +1,12 @@
package com.synebula.gaea.mongodb.autoconfig
import org.springframework.context.annotation.Import
import java.lang.annotation.Inherited
@Target(AnnotationTarget.ANNOTATION_CLASS, AnnotationTarget.CLASS)
@Retention(AnnotationRetention.RUNTIME)
@MustBeDocumented
@Inherited
@Import(MongodbRepoRegister::class)
annotation class MongodbRepoScan(val basePackages: Array<String> = [])

View File

@@ -0,0 +1,102 @@
package com.synebula.gaea.mongodb.query
import com.synebula.gaea.ext.fieldNames
import com.synebula.gaea.ext.firstCharLowerCase
import com.synebula.gaea.log.ILogger
import com.synebula.gaea.mongodb.order
import com.synebula.gaea.mongodb.select
import com.synebula.gaea.mongodb.where
import com.synebula.gaea.mongodb.whereId
import com.synebula.gaea.query.IQuery
import com.synebula.gaea.query.Page
import com.synebula.gaea.query.Params
import com.synebula.gaea.query.Table
import org.springframework.data.mongodb.core.MongoTemplate
import org.springframework.data.mongodb.core.query.Criteria
import org.springframework.data.mongodb.core.query.Query
/**
* 实现IQuery的Mongodb查询类
* @param template MongodbRepo对象
*/
open class MongodbQuery<TView, ID>(override var clazz: Class<TView>, var template: MongoTemplate, var logger: ILogger) :
IQuery<TView, ID> {
/**
* 使用View解析是collection时是否校验存在默认不校验
*/
var validViewCollection = false
override fun get(id: ID): TView? {
return this.template.findOne(whereId(id), clazz, this.collection(clazz))
}
override fun list(params: Map<String, Any>?): List<TView> {
val fields = this.fields(clazz)
val query = Query()
query.where(params, clazz)
query.select(fields)
return this.find(query, clazz)
}
override fun count(params: Map<String, Any>?): Int {
val query = Query()
return this.template.count(query.where(params, clazz), this.collection(clazz)).toInt()
}
override fun paging(params: Params): Page<TView> {
val query = Query()
val fields = this.fields(clazz)
val result = Page<TView>(params.page, params.size)
result.total = this.count(params.parameters)
//如果总数和索引相同,说明该页没有数据,直接跳到上一页
if (result.total == result.index) {
params.page -= 1
result.page -= 1
}
query.select(fields)
query.where(params.parameters, clazz)
query.with(order(params.orders))
query.skip(params.index).limit(params.size)
result.data = this.find(query, clazz)
return result
}
override fun range(field: String, params: List<Any>): List<TView> {
return this.find(Query.query(Criteria.where(field).`in`(params)), clazz)
}
protected fun find(query: Query, clazz: Class<TView>): List<TView> {
return this.template.find(query, clazz, this.collection(clazz))
}
protected fun fields(clazz: Class<TView>): Array<String> {
val fields = mutableListOf<String>()
fields.addAll(clazz.fieldNames())
var parent = clazz.superclass
while (parent != Any::class.java) {
fields.addAll(clazz.superclass.fieldNames())
parent = parent.superclass
}
return fields.toTypedArray()
}
/**
* 获取collection
*/
protected fun collection(clazz: Class<TView>): String {
val table: Table? = clazz.getDeclaredAnnotation(
Table::class.java
)
return if (table != null) return table.name
else {
this.logger.info(this, "视图类没有标记[Collection]注解无法获取Collection名称。尝试使用View<${clazz.name}>名称解析集合")
val name = clazz.simpleName.removeSuffix("View").firstCharLowerCase()
if (!validViewCollection || this.template.collectionExists(name)) name
else {
throw RuntimeException("找不到名为[${clazz.name}]的集合")
}
}
}
}

View File

@@ -0,0 +1,103 @@
package com.synebula.gaea.mongodb.query
import com.synebula.gaea.ext.fieldNames
import com.synebula.gaea.ext.firstCharLowerCase
import com.synebula.gaea.log.ILogger
import com.synebula.gaea.mongodb.order
import com.synebula.gaea.mongodb.select
import com.synebula.gaea.mongodb.where
import com.synebula.gaea.mongodb.whereId
import com.synebula.gaea.query.IUniversalQuery
import com.synebula.gaea.query.Page
import com.synebula.gaea.query.Params
import com.synebula.gaea.query.Table
import org.springframework.data.mongodb.core.MongoTemplate
import org.springframework.data.mongodb.core.query.Criteria
import org.springframework.data.mongodb.core.query.Query
/**
* 实现IQuery的Mongodb查询类
* @param template MongodbRepo对象
*/
open class MongodbUniversalQuery(var template: MongoTemplate, var logger: ILogger) : IUniversalQuery {
/**
* 使用View解析是collection时是否校验存在默认不校验
*/
var validViewCollection = false
override fun <TView, ID> get(id: ID, clazz: Class<TView>): TView? {
return this.template.findOne(whereId(id), clazz, this.collection(clazz))
}
override fun <TView> list(params: Map<String, Any>?, clazz: Class<TView>): List<TView> {
val fields = this.fields(clazz)
val query = Query()
query.where(params, clazz)
query.select(fields)
return this.find(query, clazz)
}
override fun <TView> count(params: Map<String, Any>?, clazz: Class<TView>): Int {
val query = Query()
return this.template.count(query.where(params, clazz), this.collection(clazz)).toInt()
}
override fun <TView> paging(params: Params, clazz: Class<TView>): Page<TView> {
val query = Query()
val fields = this.fields(clazz)
val result = Page<TView>(params.page, params.size)
result.total = this.count(params.parameters, clazz)
//如果总数和索引相同,说明该页没有数据,直接跳到上一页
if (result.total == result.index) {
params.page -= 1
result.page -= 1
}
query.select(fields)
query.where(params.parameters, clazz)
query.with(order(params.orders))
query.skip(params.index).limit(params.size)
result.data = this.find(query, clazz)
return result
}
override fun <TView> range(field: String, params: List<Any>, clazz: Class<TView>): List<TView> {
return this.find(Query.query(Criteria.where(field).`in`(params)), clazz)
}
protected fun <TView> find(query: Query, clazz: Class<TView>): List<TView> {
return this.template.find(query, clazz, this.collection(clazz))
}
fun <TView> fields(clazz: Class<TView>): Array<String> {
val fields = mutableListOf<String>()
fields.addAll(clazz.fieldNames())
var parent = clazz.superclass
while (parent != Any::class.java) {
fields.addAll(clazz.superclass.fieldNames())
parent = parent.superclass
}
return fields.toTypedArray()
}
/**
* 获取collection
*/
fun <TView> collection(clazz: Class<TView>): String {
val table: Table? = clazz.getDeclaredAnnotation(
Table::class.java
)
return if (table != null)
return table.name
else {
this.logger.info(this, "视图类没有标记[Collection]注解无法获取Collection名称。尝试使用View<${clazz.name}>名称解析集合")
val name = clazz.simpleName.removeSuffix("View").firstCharLowerCase()
if (!validViewCollection || this.template.collectionExists(name))
name
else {
throw RuntimeException("找不到名为[${clazz.name}]的集合")
}
}
}
}

View File

@@ -0,0 +1,47 @@
package com.synebula.gaea.mongodb.repository
import com.synebula.gaea.domain.model.IAggregateRoot
import com.synebula.gaea.domain.repository.IRepository
import com.synebula.gaea.mongodb.where
import com.synebula.gaea.mongodb.whereId
import org.springframework.data.mongodb.core.MongoTemplate
import org.springframework.data.mongodb.core.query.Query
/**
* 实现[IRepository]的Mongodb仓储类
* @param repo MongodbRepo对象
*/
open class MongodbRepository<TAggregateRoot : IAggregateRoot<ID>, ID>(
override var clazz: Class<TAggregateRoot>,
protected var repo: MongoTemplate
) : IRepository<TAggregateRoot, ID> {
override fun add(obj: TAggregateRoot) {
this.repo.save(obj)
}
override fun add(list: List<TAggregateRoot>) {
this.repo.insert(list, clazz)
}
override fun remove(id: ID) {
this.repo.remove(whereId(id), clazz)
}
override fun get(id: ID): TAggregateRoot? {
return this.repo.findOne(whereId(id), clazz)
}
override fun update(obj: TAggregateRoot) {
this.repo.save(obj)
}
override fun update(list: List<TAggregateRoot>) {
this.repo.save(list)
}
override fun <TAggregateRoot> count(params: Map<String, Any>?): Int {
val query = Query()
return this.repo.count(query.where(params, clazz), clazz).toInt()
}
}

View File

@@ -0,0 +1,61 @@
package com.synebula.gaea.mongodb.repository
import com.synebula.gaea.domain.model.IAggregateRoot
import com.synebula.gaea.domain.repository.IUniversalRepository
import com.synebula.gaea.mongodb.where
import com.synebula.gaea.mongodb.whereId
import org.springframework.data.mongodb.core.MongoTemplate
import org.springframework.data.mongodb.core.query.Query
/**
* 实现ITypedRepository的Mongodb仓储类
* @param repo MongodbRepo对象
*/
open class MongodbUniversalRepository(private var repo: MongoTemplate) : IUniversalRepository {
override fun <TAggregateRoot : IAggregateRoot<ID>, ID> remove(id: ID, clazz: Class<TAggregateRoot>) {
this.repo.remove(whereId(id), clazz)
}
override fun <TAggregateRoot : IAggregateRoot<ID>, ID> get(
id: ID,
clazz: Class<TAggregateRoot>,
): TAggregateRoot? {
return this.repo.findOne(whereId(id), clazz)
}
override fun <TAggregateRoot : IAggregateRoot<ID>, ID> update(
root: TAggregateRoot,
clazz: Class<TAggregateRoot>,
) {
this.repo.save(root)
}
/**
* 更新多个个对象。
*
* @param roots 需要更新的对象。
*/
override fun <TAggregateRoot : IAggregateRoot<ID>, ID> update(
roots: List<TAggregateRoot>,
clazz: Class<TAggregateRoot>
) {
this.repo.save(roots)
}
override fun <TAggregateRoot : IAggregateRoot<ID>, ID> add(root: TAggregateRoot, clazz: Class<TAggregateRoot>) {
this.repo.save(root)
}
override fun <TAggregateRoot : IAggregateRoot<ID>, ID> add(
roots: List<TAggregateRoot>,
clazz: Class<TAggregateRoot>,
) {
this.repo.insert(roots, clazz)
}
override fun <TAggregateRoot> count(params: Map<String, Any>?, clazz: Class<TAggregateRoot>): Int {
val query = Query()
return this.repo.count(query.where(params, clazz), clazz).toInt()
}
}