统一为config *DataSourceConfig参数

This commit is contained in:
springrain
2023-09-13 12:00:14 +08:00
parent 421e42af55
commit cf938f108d
5 changed files with 37 additions and 43 deletions

View File

@@ -501,11 +501,9 @@ var queryRow = func(ctx context.Context, finder *Finder, entity interface{}) (ha
FuncLogError(ctx, errConfig)
return has, errConfig
}
dialect := config.Dialect
// 获取到sql语句
// Get the sql statement
sqlstr, errSQL := wrapQuerySQL(ctx, dialect, finder, nil)
sqlstr, errSQL := wrapQuerySQL(ctx, config, finder, nil)
if errSQL != nil {
errSQL = fmt.Errorf("->QueryRow-->wrapQuerySQL获取查询SQL语句错误:%w", errSQL)
FuncLogError(ctx, errSQL)
@@ -601,10 +599,10 @@ var queryRow = func(ctx context.Context, finder *Finder, entity interface{}) (ha
return has, errQueryRow
}
if oneColumnScanner {
err = sqlRowsValues(ctx, dialect, nil, typeOf, rows, &driverValue, columnTypes, entity, dbColumnFieldMap, exportFieldMap)
err = sqlRowsValues(ctx, config, nil, typeOf, rows, &driverValue, columnTypes, entity, dbColumnFieldMap, exportFieldMap)
} else {
pv := reflect.ValueOf(entity)
err = sqlRowsValues(ctx, dialect, &pv, typeOf, rows, &driverValue, columnTypes, nil, dbColumnFieldMap, exportFieldMap)
err = sqlRowsValues(ctx, config, &pv, typeOf, rows, &driverValue, columnTypes, nil, dbColumnFieldMap, exportFieldMap)
}
// pv = pv.Elem()
@@ -689,9 +687,7 @@ var query = func(ctx context.Context, finder *Finder, rowsSlicePtr interface{},
FuncLogError(ctx, errConfig)
return errConfig
}
dialect := config.Dialect
sqlstr, errSQL := wrapQuerySQL(ctx, dialect, finder, page)
sqlstr, errSQL := wrapQuerySQL(ctx, config, finder, page)
if errSQL != nil {
errSQL = fmt.Errorf("->Query-->wrapQuerySQL获取查询SQL语句错误:%w", errSQL)
FuncLogError(ctx, errSQL)
@@ -781,9 +777,9 @@ var query = func(ctx context.Context, finder *Finder, rowsSlicePtr interface{},
for rows.Next() {
pv := reflect.New(sliceElementType)
if oneColumnScanner {
err = sqlRowsValues(ctx, dialect, nil, &sliceElementType, rows, &driverValue, columnTypes, pv.Interface(), dbColumnFieldMap, exportFieldMap)
err = sqlRowsValues(ctx, config, nil, &sliceElementType, rows, &driverValue, columnTypes, pv.Interface(), dbColumnFieldMap, exportFieldMap)
} else {
err = sqlRowsValues(ctx, dialect, &pv, &sliceElementType, rows, &driverValue, columnTypes, nil, dbColumnFieldMap, exportFieldMap)
err = sqlRowsValues(ctx, config, &pv, &sliceElementType, rows, &driverValue, columnTypes, nil, dbColumnFieldMap, exportFieldMap)
}
// err = sqlRowsValues(ctx, dialect, &pv, rows, &driverValue, columnTypes, oneColumnScanner, structType, &dbColumnFieldMap, &exportFieldMap)
@@ -811,7 +807,7 @@ var query = func(ctx context.Context, finder *Finder, rowsSlicePtr interface{},
// 查询总条数
// Query total number
if finder.SelectTotalCount && page != nil {
count, errCount := selectCount(ctx, finder)
count, errCount := selectCount(ctx, config, finder)
if errCount != nil {
errCount = fmt.Errorf("->Query-->selectCount查询总条数错误:%w", errCount)
FuncLogError(ctx, errCount)
@@ -895,8 +891,7 @@ var queryMap = func(ctx context.Context, finder *Finder, page *Page) (resultMapL
FuncLogError(ctx, errConfig)
return nil, errConfig
}
dialect := config.Dialect
sqlstr, errSQL := wrapQuerySQL(ctx, dialect, finder, page)
sqlstr, errSQL := wrapQuerySQL(ctx, config, finder, page)
if errSQL != nil {
errSQL = fmt.Errorf("->QueryMap -->wrapQuerySQL查询SQL语句错误:%w", errSQL)
FuncLogError(ctx, errSQL)
@@ -984,7 +979,7 @@ var queryMap = func(ctx context.Context, finder *Finder, page *Page) (resultMapL
databaseTypeName := strings.ToUpper(columnType.DatabaseTypeName())
// 判断是否有自定义扩展,避免无意义的反射
if iscdvm {
customDriverValueConver, converOK = customDriverValueMap[dialect+"."+databaseTypeName]
customDriverValueConver, converOK = customDriverValueMap[config.Dialect+"."+databaseTypeName]
if !converOK {
customDriverValueConver, converOK = customDriverValueMap[databaseTypeName]
}
@@ -1031,7 +1026,7 @@ var queryMap = func(ctx context.Context, finder *Finder, page *Page) (resultMapL
case "NUMBER":
precision, scale, isDecimal := columnType.DecimalSize()
if isDecimal || precision > 18 || precision-scale > 18 { // 如果是Decimal类型
values[i] = FuncDecimalValue(ctx, dialect)
values[i] = FuncDecimalValue(ctx, config)
} else if scale > 0 { // 有小数位,默认使用float64接收
values[i] = new(float64)
} else if precision-scale > 9 { // 超过9位,使用int64
@@ -1041,7 +1036,7 @@ var queryMap = func(ctx context.Context, finder *Finder, page *Page) (resultMapL
}
case "DECIMAL", "NUMERIC", "DEC":
values[i] = FuncDecimalValue(ctx, dialect)
values[i] = FuncDecimalValue(ctx, config)
case "BOOLEAN", "BOOL", "BIT":
values[i] = new(bool)
default:
@@ -1098,7 +1093,7 @@ var queryMap = func(ctx context.Context, finder *Finder, page *Page) (resultMapL
// 查询总条数
// Query total number
if finder.SelectTotalCount && page != nil {
count, errCount := selectCount(ctx, finder)
count, errCount := selectCount(ctx, config, finder)
if errCount != nil {
errCount = fmt.Errorf("->QueryMap-->selectCount查询总条数错误:%w", errCount)
FuncLogError(ctx, errCount)
@@ -1180,7 +1175,7 @@ var insert = func(ctx context.Context, entity IEntityStruct) (int, error) {
// SQL语句
// SQL statement
sqlstr, autoIncrement, pktype, err := wrapInsertSQL(ctx, typeOf, entity, columns, values)
sqlstr, autoIncrement, pktype, err := wrapInsertSQL(ctx, dbConnection.config, typeOf, entity, columns, values)
if err != nil {
err = fmt.Errorf("->Insert-->wrapInsertSQL获取保存语句错误:%w", err)
FuncLogError(ctx, err)
@@ -1197,8 +1192,7 @@ var insert = func(ctx context.Context, entity IEntityStruct) (int, error) {
if errConfig != nil {
return affected, errConfig
}
dialect := config.Dialect
lastInsertID, zormSQLOutReturningID = wrapAutoIncrementInsertSQL(ctx, entity.GetPKColumnName(), sqlstr, dialect, values)
lastInsertID, zormSQLOutReturningID = wrapAutoIncrementInsertSQL(ctx, config, entity.GetPKColumnName(), sqlstr, values)
}
@@ -1420,7 +1414,7 @@ var insertEntityMap = func(ctx context.Context, entity IEntityMap) (int, error)
}
// SQL语句
sqlstr, values, autoIncrement, err := wrapInsertEntityMapSQL(ctx, entity)
sqlstr, values, autoIncrement, err := wrapInsertEntityMapSQL(ctx, dbConnection.config, entity)
if err != nil {
err = fmt.Errorf("->InsertEntityMap-->wrapInsertEntityMapSQL获取SQL语句错误:%w", err)
FuncLogError(ctx, err)
@@ -1435,8 +1429,7 @@ var insertEntityMap = func(ctx context.Context, entity IEntityMap) (int, error)
if errConfig != nil {
return affected, errConfig
}
dialect := config.Dialect
lastInsertID, zormSQLOutReturningID = wrapAutoIncrementInsertSQL(ctx, entity.GetPKColumnName(), &sqlstr, dialect, values)
lastInsertID, zormSQLOutReturningID = wrapAutoIncrementInsertSQL(ctx, config, entity.GetPKColumnName(), &sqlstr, values)
}
// 包装update执行,赋值给影响的函数指针变量,返回*sql.Result
@@ -1625,7 +1618,7 @@ func WrapUpdateStructFinder(ctx context.Context, entity IEntityStruct, onlyUpdat
// context必须传入,不能为空
// selectCount Query the total number of items according to finder
// context must be passed in and cannot be empty
var selectCount = func(ctx context.Context, finder *Finder) (int, error) {
var selectCount = func(ctx context.Context, config *DataSourceConfig, finder *Finder) (int, error) {
if finder == nil {
return -1, errors.New("->selectCount-->finder参数为nil")
}

View File

@@ -127,33 +127,34 @@ func OverrideFunc(funcName string, funcObject interface{}) (bool, interface{}, e
}
case "wrapQuerySQL": //查询的SQL
newFunc, ok := funcObject.(func(ctx context.Context, dialect string, finder *Finder, page *Page) (string, error))
newFunc, ok := funcObject.(func(ctx context.Context, config *DataSourceConfig, finder *Finder, page *Page) (string, error))
if ok {
oldFunc = wrapQuerySQL
wrapQuerySQL = newFunc
}
case "selectCount": //查询总条数
newFunc, ok := funcObject.(func(ctx context.Context, finder *Finder) (int, error))
newFunc, ok := funcObject.(func(ctx context.Context, config *DataSourceConfig, finder *Finder) (int, error))
if ok {
oldFunc = selectCount
selectCount = newFunc
}
case "wrapPageSQL": //分页SQL
newFunc, ok := funcObject.(func(ctx context.Context, dialect string, sqlstr *string, page *Page) error)
newFunc, ok := funcObject.(func(ctx context.Context, config *DataSourceConfig, sqlstr *string, page *Page) error)
if ok {
oldFunc = wrapPageSQL
wrapPageSQL = newFunc
}
case "wrapInsertSQL": //Insert IEntityStruct SQL
newFunc, ok := funcObject.(func(ctx context.Context, typeOf *reflect.Type, entity IEntityStruct, columns *[]reflect.StructField, values *[]interface{}) (*string, int, string, error))
newFunc, ok := funcObject.(func(ctx context.Context, config *DataSourceConfig, typeOf *reflect.Type, entity IEntityStruct, columns *[]reflect.StructField, values *[]interface{}) (*string, int, string, error))
if ok {
oldFunc = wrapInsertSQL
wrapInsertSQL = newFunc
}
case "wrapAutoIncrementInsertSQL": //IEntityStruct 主键自增值的SQL
newFunc, ok := funcObject.(func(ctx context.Context, pkColumnName string, sqlstr *string, dialect string, values *[]interface{}) (*int64, *int64))
case "wrapAutoIncrementInsertSQL": //Insert IEntityStruct 主键自增值的SQL
newFunc, ok := funcObject.(func(ctx context.Context, config *DataSourceConfig, pkColumnName string, sqlstr *string, values *[]interface{}) (*int64, *int64))
if ok {
oldFunc = wrapAutoIncrementInsertSQL
wrapAutoIncrementInsertSQL = newFunc
@@ -166,7 +167,7 @@ func OverrideFunc(funcName string, funcObject interface{}) (bool, interface{}, e
wrapInsertSliceSQL = newFunc
}
case "wrapInsertEntityMapSQL": //插入 IEntityMap 的SQL
newFunc, ok := funcObject.(func(ctx context.Context, entity IEntityMap) (string, *[]interface{}, bool, error))
newFunc, ok := funcObject.(func(ctx context.Context, config *DataSourceConfig, entity IEntityMap) (string, *[]interface{}, bool, error))
if ok {
oldFunc = wrapInsertEntityMapSQL
wrapInsertEntityMapSQL = newFunc

View File

@@ -34,14 +34,14 @@ import (
// wrapPageSQL 包装分页的SQL语句
// wrapPageSQL SQL statement for wrapping paging
var wrapPageSQL = func(ctx context.Context, dialect string, sqlstr *string, page *Page) error {
var wrapPageSQL = func(ctx context.Context, config *DataSourceConfig, sqlstr *string, page *Page) error {
if page.PageNo < 1 { // 默认第一页
page.PageNo = 1
}
var sqlbuilder strings.Builder
sqlbuilder.Grow(stringBuilderGrowLen)
sqlbuilder.WriteString(*sqlstr)
switch dialect {
switch config.Dialect {
case "mysql", "sqlite", "dm", "gbase", "clickhouse", "tdengine", "db2": // MySQL,sqlite3,dm,南通,clickhouse,TDengine,db2 7.2+
sqlbuilder.WriteString(" LIMIT ")
sqlbuilder.WriteString(strconv.Itoa(page.PageSize * (page.PageNo - 1)))
@@ -74,7 +74,7 @@ var wrapPageSQL = func(ctx context.Context, dialect string, sqlstr *string, page
sqlbuilder.WriteString(strconv.Itoa(page.PageSize))
sqlbuilder.WriteString(" ROWS ONLY ")
default:
return errors.New("->wrapPageSQL-->不支持的数据库类型:" + dialect)
return errors.New("->wrapPageSQL-->不支持的数据库类型:" + config.Dialect)
}
*sqlstr = sqlbuilder.String()
@@ -86,7 +86,7 @@ var wrapPageSQL = func(ctx context.Context, dialect string, sqlstr *string, page
// 数组传递,如果外部方法有调用append的逻辑append会破坏指针引用所以传递指针
// wrapInsertSQL Pack and save 'Struct' statement. Return SQL statement, whether it is incremented, error message
// Array transfer, if the external method has logic to call append, append will destroy the pointer reference, so the pointer is passed
var wrapInsertSQL = func(ctx context.Context, typeOf *reflect.Type, entity IEntityStruct, columns *[]reflect.StructField, values *[]interface{}) (*string, int, string, error) {
var wrapInsertSQL = func(ctx context.Context, config *DataSourceConfig, typeOf *reflect.Type, entity IEntityStruct, columns *[]reflect.StructField, values *[]interface{}) (*string, int, string, error) {
sqlstr := ""
inserColumnName, valuesql, autoIncrement, pktype, err := wrapInsertValueSQL(ctx, typeOf, entity, columns, values)
if err != nil {
@@ -506,7 +506,7 @@ var wrapDeleteSQL = func(ctx context.Context, entity IEntityStruct) (string, err
// wrapInsertEntityMapSQL 包装保存Map语句,Map因为没有字段属性,无法完成Id的类型判断和赋值,需要确保Map的值是完整的
// wrapInsertEntityMapSQL Pack and save the Map statement. Because Map does not have field attributes,
// it cannot complete the type judgment and assignment of Id. It is necessary to ensure that the value of Map is complete
var wrapInsertEntityMapSQL = func(ctx context.Context, entity IEntityMap) (string, *[]interface{}, bool, error) {
var wrapInsertEntityMapSQL = func(ctx context.Context, config *DataSourceConfig, entity IEntityMap) (string, *[]interface{}, bool, error) {
sqlstr := ""
inserColumnName, valuesql, values, autoIncrement, err := wrapInsertValueEntityMapSQL(entity)
if err != nil {
@@ -649,7 +649,7 @@ var wrapUpdateEntityMapSQL = func(ctx context.Context, entity IEntityMap) (*stri
// wrapQuerySQL 封装查询语句
// wrapQuerySQL Encapsulated query statement
var wrapQuerySQL = func(ctx context.Context, dialect string, finder *Finder, page *Page) (string, error) {
var wrapQuerySQL = func(ctx context.Context, config *DataSourceConfig, finder *Finder, page *Page) (string, error) {
// 获取到没有page的sql的语句
// Get the SQL statement without page.
sqlstr, err := finder.GetSQL()
@@ -657,7 +657,7 @@ var wrapQuerySQL = func(ctx context.Context, dialect string, finder *Finder, pag
return "", err
}
if page != nil {
err = wrapPageSQL(ctx, dialect, &sqlstr, page)
err = wrapPageSQL(ctx, config, &sqlstr, page)
}
if err != nil {
return "", err
@@ -1024,7 +1024,7 @@ func reUpdateSQL(dialect string, sqlstr *string) error {
}
// wrapAutoIncrementInsertSQL 包装自增的自增主键的插入sql
var wrapAutoIncrementInsertSQL = func(ctx context.Context, pkColumnName string, sqlstr *string, dialect string, values *[]interface{}) (*int64, *int64) {
var wrapAutoIncrementInsertSQL = func(ctx context.Context, config *DataSourceConfig, pkColumnName string, sqlstr *string, values *[]interface{}) (*int64, *int64) {
// oracle 12c+ 支持IDENTITY属性的自增列,因为分页也要求12c+的语法,所以数据库就IDENTITY创建自增吧
// 处理序列产生的自增主键,例如oracle,postgresql等
var lastInsertID, zormSQLOutReturningID *int64
@@ -1032,7 +1032,7 @@ var wrapAutoIncrementInsertSQL = func(ctx context.Context, pkColumnName string,
// sqlBuilder.Grow(len(*sqlstr) + len(pkColumnName) + 40)
sqlBuilder.Grow(stringBuilderGrowLen)
sqlBuilder.WriteString(*sqlstr)
switch dialect {
switch config.Dialect {
case "postgresql", "kingbase":
var p int64 = 0
lastInsertID = &p

View File

@@ -491,7 +491,7 @@ func checkEntityKind(entity interface{}) (*reflect.Type, error) {
// 当读取数据库的值为NULL时,由于基本类型不支持为NULL,通过反射将未知driver.Value改为interface{},不再映射到struct实体类
// 感谢@fastabler提交的pr
// oneColumnScanner 只有一个字段,而且可以直接Scan,例如string或者[]string,不需要反射StructType进行处理
func sqlRowsValues(ctx context.Context, dialect string, valueOf *reflect.Value, typeOf *reflect.Type, rows *sql.Rows, driverValue *reflect.Value, columnTypes []*sql.ColumnType, entity interface{}, dbColumnFieldMap, exportFieldMap *map[string]reflect.StructField) error {
func sqlRowsValues(ctx context.Context, config *DataSourceConfig, valueOf *reflect.Value, typeOf *reflect.Type, rows *sql.Rows, driverValue *reflect.Value, columnTypes []*sql.ColumnType, entity interface{}, dbColumnFieldMap, exportFieldMap *map[string]reflect.StructField) error {
if entity == nil && valueOf == nil {
return errors.New("->sqlRowsValues-->valueOfElem为nil")
}
@@ -518,7 +518,7 @@ func sqlRowsValues(ctx context.Context, dialect string, valueOf *reflect.Value,
if iscdvm {
databaseTypeName := strings.ToUpper(columnType.DatabaseTypeName())
// 根据接收的类型,获取到类型转换的接口实现,优先匹配指定的数据库类型
customDriverValueConver, converOK = customDriverValueMap[dialect+"."+databaseTypeName]
customDriverValueConver, converOK = customDriverValueMap[config.Dialect+"."+databaseTypeName]
if !converOK {
customDriverValueConver, converOK = customDriverValueMap[databaseTypeName]
}

View File

@@ -26,7 +26,7 @@ import (
)
// FuncDecimalValue 设置decimal类型接收值,复写函数自定义decimal实现,例如github.com/shopspring/decimal,返回的是指针
var FuncDecimalValue = func(ctx context.Context, dialect string) interface{} {
var FuncDecimalValue = func(ctx context.Context, config *DataSourceConfig) interface{} {
return &decimal.Decimal{}
}