From cf938f108d516690254273d1612c0661cbe379f4 Mon Sep 17 00:00:00 2001 From: springrain Date: Wed, 13 Sep 2023 12:00:14 +0800 Subject: [PATCH] =?UTF-8?q?=E7=BB=9F=E4=B8=80=E4=B8=BAconfig=20*DataSource?= =?UTF-8?q?Config=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- DBDao.go | 41 +++++++++++++++++------------------------ OverrideFunc.go | 15 ++++++++------- dialect.go | 18 +++++++++--------- structFieldInfo.go | 4 ++-- typeConvert.go | 2 +- 5 files changed, 37 insertions(+), 43 deletions(-) diff --git a/DBDao.go b/DBDao.go index 89d360e..436311a 100644 --- a/DBDao.go +++ b/DBDao.go @@ -501,11 +501,9 @@ var queryRow = func(ctx context.Context, finder *Finder, entity interface{}) (ha FuncLogError(ctx, errConfig) return has, errConfig } - dialect := config.Dialect - // 获取到sql语句 // Get the sql statement - sqlstr, errSQL := wrapQuerySQL(ctx, dialect, finder, nil) + sqlstr, errSQL := wrapQuerySQL(ctx, config, finder, nil) if errSQL != nil { errSQL = fmt.Errorf("->QueryRow-->wrapQuerySQL获取查询SQL语句错误:%w", errSQL) FuncLogError(ctx, errSQL) @@ -601,10 +599,10 @@ var queryRow = func(ctx context.Context, finder *Finder, entity interface{}) (ha return has, errQueryRow } if oneColumnScanner { - err = sqlRowsValues(ctx, dialect, nil, typeOf, rows, &driverValue, columnTypes, entity, dbColumnFieldMap, exportFieldMap) + err = sqlRowsValues(ctx, config, nil, typeOf, rows, &driverValue, columnTypes, entity, dbColumnFieldMap, exportFieldMap) } else { pv := reflect.ValueOf(entity) - err = sqlRowsValues(ctx, dialect, &pv, typeOf, rows, &driverValue, columnTypes, nil, dbColumnFieldMap, exportFieldMap) + err = sqlRowsValues(ctx, config, &pv, typeOf, rows, &driverValue, columnTypes, nil, dbColumnFieldMap, exportFieldMap) } // pv = pv.Elem() @@ -689,9 +687,7 @@ var query = func(ctx context.Context, finder *Finder, rowsSlicePtr interface{}, FuncLogError(ctx, errConfig) return errConfig } - dialect := config.Dialect - - sqlstr, errSQL := wrapQuerySQL(ctx, dialect, finder, page) + sqlstr, errSQL := wrapQuerySQL(ctx, config, finder, page) if errSQL != nil { errSQL = fmt.Errorf("->Query-->wrapQuerySQL获取查询SQL语句错误:%w", errSQL) FuncLogError(ctx, errSQL) @@ -781,9 +777,9 @@ var query = func(ctx context.Context, finder *Finder, rowsSlicePtr interface{}, for rows.Next() { pv := reflect.New(sliceElementType) if oneColumnScanner { - err = sqlRowsValues(ctx, dialect, nil, &sliceElementType, rows, &driverValue, columnTypes, pv.Interface(), dbColumnFieldMap, exportFieldMap) + err = sqlRowsValues(ctx, config, nil, &sliceElementType, rows, &driverValue, columnTypes, pv.Interface(), dbColumnFieldMap, exportFieldMap) } else { - err = sqlRowsValues(ctx, dialect, &pv, &sliceElementType, rows, &driverValue, columnTypes, nil, dbColumnFieldMap, exportFieldMap) + err = sqlRowsValues(ctx, config, &pv, &sliceElementType, rows, &driverValue, columnTypes, nil, dbColumnFieldMap, exportFieldMap) } // err = sqlRowsValues(ctx, dialect, &pv, rows, &driverValue, columnTypes, oneColumnScanner, structType, &dbColumnFieldMap, &exportFieldMap) @@ -811,7 +807,7 @@ var query = func(ctx context.Context, finder *Finder, rowsSlicePtr interface{}, // 查询总条数 // Query total number if finder.SelectTotalCount && page != nil { - count, errCount := selectCount(ctx, finder) + count, errCount := selectCount(ctx, config, finder) if errCount != nil { errCount = fmt.Errorf("->Query-->selectCount查询总条数错误:%w", errCount) FuncLogError(ctx, errCount) @@ -895,8 +891,7 @@ var queryMap = func(ctx context.Context, finder *Finder, page *Page) (resultMapL FuncLogError(ctx, errConfig) return nil, errConfig } - dialect := config.Dialect - sqlstr, errSQL := wrapQuerySQL(ctx, dialect, finder, page) + sqlstr, errSQL := wrapQuerySQL(ctx, config, finder, page) if errSQL != nil { errSQL = fmt.Errorf("->QueryMap -->wrapQuerySQL查询SQL语句错误:%w", errSQL) FuncLogError(ctx, errSQL) @@ -984,7 +979,7 @@ var queryMap = func(ctx context.Context, finder *Finder, page *Page) (resultMapL databaseTypeName := strings.ToUpper(columnType.DatabaseTypeName()) // 判断是否有自定义扩展,避免无意义的反射 if iscdvm { - customDriverValueConver, converOK = customDriverValueMap[dialect+"."+databaseTypeName] + customDriverValueConver, converOK = customDriverValueMap[config.Dialect+"."+databaseTypeName] if !converOK { customDriverValueConver, converOK = customDriverValueMap[databaseTypeName] } @@ -1031,7 +1026,7 @@ var queryMap = func(ctx context.Context, finder *Finder, page *Page) (resultMapL case "NUMBER": precision, scale, isDecimal := columnType.DecimalSize() if isDecimal || precision > 18 || precision-scale > 18 { // 如果是Decimal类型 - values[i] = FuncDecimalValue(ctx, dialect) + values[i] = FuncDecimalValue(ctx, config) } else if scale > 0 { // 有小数位,默认使用float64接收 values[i] = new(float64) } else if precision-scale > 9 { // 超过9位,使用int64 @@ -1041,7 +1036,7 @@ var queryMap = func(ctx context.Context, finder *Finder, page *Page) (resultMapL } case "DECIMAL", "NUMERIC", "DEC": - values[i] = FuncDecimalValue(ctx, dialect) + values[i] = FuncDecimalValue(ctx, config) case "BOOLEAN", "BOOL", "BIT": values[i] = new(bool) default: @@ -1098,7 +1093,7 @@ var queryMap = func(ctx context.Context, finder *Finder, page *Page) (resultMapL // 查询总条数 // Query total number if finder.SelectTotalCount && page != nil { - count, errCount := selectCount(ctx, finder) + count, errCount := selectCount(ctx, config, finder) if errCount != nil { errCount = fmt.Errorf("->QueryMap-->selectCount查询总条数错误:%w", errCount) FuncLogError(ctx, errCount) @@ -1180,7 +1175,7 @@ var insert = func(ctx context.Context, entity IEntityStruct) (int, error) { // SQL语句 // SQL statement - sqlstr, autoIncrement, pktype, err := wrapInsertSQL(ctx, typeOf, entity, columns, values) + sqlstr, autoIncrement, pktype, err := wrapInsertSQL(ctx, dbConnection.config, typeOf, entity, columns, values) if err != nil { err = fmt.Errorf("->Insert-->wrapInsertSQL获取保存语句错误:%w", err) FuncLogError(ctx, err) @@ -1197,8 +1192,7 @@ var insert = func(ctx context.Context, entity IEntityStruct) (int, error) { if errConfig != nil { return affected, errConfig } - dialect := config.Dialect - lastInsertID, zormSQLOutReturningID = wrapAutoIncrementInsertSQL(ctx, entity.GetPKColumnName(), sqlstr, dialect, values) + lastInsertID, zormSQLOutReturningID = wrapAutoIncrementInsertSQL(ctx, config, entity.GetPKColumnName(), sqlstr, values) } @@ -1420,7 +1414,7 @@ var insertEntityMap = func(ctx context.Context, entity IEntityMap) (int, error) } // SQL语句 - sqlstr, values, autoIncrement, err := wrapInsertEntityMapSQL(ctx, entity) + sqlstr, values, autoIncrement, err := wrapInsertEntityMapSQL(ctx, dbConnection.config, entity) if err != nil { err = fmt.Errorf("->InsertEntityMap-->wrapInsertEntityMapSQL获取SQL语句错误:%w", err) FuncLogError(ctx, err) @@ -1435,8 +1429,7 @@ var insertEntityMap = func(ctx context.Context, entity IEntityMap) (int, error) if errConfig != nil { return affected, errConfig } - dialect := config.Dialect - lastInsertID, zormSQLOutReturningID = wrapAutoIncrementInsertSQL(ctx, entity.GetPKColumnName(), &sqlstr, dialect, values) + lastInsertID, zormSQLOutReturningID = wrapAutoIncrementInsertSQL(ctx, config, entity.GetPKColumnName(), &sqlstr, values) } // 包装update执行,赋值给影响的函数指针变量,返回*sql.Result @@ -1625,7 +1618,7 @@ func WrapUpdateStructFinder(ctx context.Context, entity IEntityStruct, onlyUpdat // context必须传入,不能为空 // selectCount Query the total number of items according to finder // context must be passed in and cannot be empty -var selectCount = func(ctx context.Context, finder *Finder) (int, error) { +var selectCount = func(ctx context.Context, config *DataSourceConfig, finder *Finder) (int, error) { if finder == nil { return -1, errors.New("->selectCount-->finder参数为nil") } diff --git a/OverrideFunc.go b/OverrideFunc.go index 04def4c..345c3de 100644 --- a/OverrideFunc.go +++ b/OverrideFunc.go @@ -127,33 +127,34 @@ func OverrideFunc(funcName string, funcObject interface{}) (bool, interface{}, e } case "wrapQuerySQL": //查询的SQL - newFunc, ok := funcObject.(func(ctx context.Context, dialect string, finder *Finder, page *Page) (string, error)) + newFunc, ok := funcObject.(func(ctx context.Context, config *DataSourceConfig, finder *Finder, page *Page) (string, error)) if ok { oldFunc = wrapQuerySQL wrapQuerySQL = newFunc } case "selectCount": //查询总条数 - newFunc, ok := funcObject.(func(ctx context.Context, finder *Finder) (int, error)) + newFunc, ok := funcObject.(func(ctx context.Context, config *DataSourceConfig, finder *Finder) (int, error)) if ok { oldFunc = selectCount selectCount = newFunc } case "wrapPageSQL": //分页SQL - newFunc, ok := funcObject.(func(ctx context.Context, dialect string, sqlstr *string, page *Page) error) + newFunc, ok := funcObject.(func(ctx context.Context, config *DataSourceConfig, sqlstr *string, page *Page) error) if ok { oldFunc = wrapPageSQL wrapPageSQL = newFunc } case "wrapInsertSQL": //Insert IEntityStruct SQL - newFunc, ok := funcObject.(func(ctx context.Context, typeOf *reflect.Type, entity IEntityStruct, columns *[]reflect.StructField, values *[]interface{}) (*string, int, string, error)) + newFunc, ok := funcObject.(func(ctx context.Context, config *DataSourceConfig, typeOf *reflect.Type, entity IEntityStruct, columns *[]reflect.StructField, values *[]interface{}) (*string, int, string, error)) if ok { oldFunc = wrapInsertSQL wrapInsertSQL = newFunc } - case "wrapAutoIncrementInsertSQL": //IEntityStruct 主键自增值的SQL - newFunc, ok := funcObject.(func(ctx context.Context, pkColumnName string, sqlstr *string, dialect string, values *[]interface{}) (*int64, *int64)) + + case "wrapAutoIncrementInsertSQL": //Insert IEntityStruct 主键自增值的SQL + newFunc, ok := funcObject.(func(ctx context.Context, config *DataSourceConfig, pkColumnName string, sqlstr *string, values *[]interface{}) (*int64, *int64)) if ok { oldFunc = wrapAutoIncrementInsertSQL wrapAutoIncrementInsertSQL = newFunc @@ -166,7 +167,7 @@ func OverrideFunc(funcName string, funcObject interface{}) (bool, interface{}, e wrapInsertSliceSQL = newFunc } case "wrapInsertEntityMapSQL": //插入 IEntityMap 的SQL - newFunc, ok := funcObject.(func(ctx context.Context, entity IEntityMap) (string, *[]interface{}, bool, error)) + newFunc, ok := funcObject.(func(ctx context.Context, config *DataSourceConfig, entity IEntityMap) (string, *[]interface{}, bool, error)) if ok { oldFunc = wrapInsertEntityMapSQL wrapInsertEntityMapSQL = newFunc diff --git a/dialect.go b/dialect.go index 4fb8b84..a78dde3 100644 --- a/dialect.go +++ b/dialect.go @@ -34,14 +34,14 @@ import ( // wrapPageSQL 包装分页的SQL语句 // wrapPageSQL SQL statement for wrapping paging -var wrapPageSQL = func(ctx context.Context, dialect string, sqlstr *string, page *Page) error { +var wrapPageSQL = func(ctx context.Context, config *DataSourceConfig, sqlstr *string, page *Page) error { if page.PageNo < 1 { // 默认第一页 page.PageNo = 1 } var sqlbuilder strings.Builder sqlbuilder.Grow(stringBuilderGrowLen) sqlbuilder.WriteString(*sqlstr) - switch dialect { + switch config.Dialect { case "mysql", "sqlite", "dm", "gbase", "clickhouse", "tdengine", "db2": // MySQL,sqlite3,dm,南通,clickhouse,TDengine,db2 7.2+ sqlbuilder.WriteString(" LIMIT ") sqlbuilder.WriteString(strconv.Itoa(page.PageSize * (page.PageNo - 1))) @@ -74,7 +74,7 @@ var wrapPageSQL = func(ctx context.Context, dialect string, sqlstr *string, page sqlbuilder.WriteString(strconv.Itoa(page.PageSize)) sqlbuilder.WriteString(" ROWS ONLY ") default: - return errors.New("->wrapPageSQL-->不支持的数据库类型:" + dialect) + return errors.New("->wrapPageSQL-->不支持的数据库类型:" + config.Dialect) } *sqlstr = sqlbuilder.String() @@ -86,7 +86,7 @@ var wrapPageSQL = func(ctx context.Context, dialect string, sqlstr *string, page // 数组传递,如果外部方法有调用append的逻辑,append会破坏指针引用,所以传递指针 // wrapInsertSQL Pack and save 'Struct' statement. Return SQL statement, whether it is incremented, error message // Array transfer, if the external method has logic to call append, append will destroy the pointer reference, so the pointer is passed -var wrapInsertSQL = func(ctx context.Context, typeOf *reflect.Type, entity IEntityStruct, columns *[]reflect.StructField, values *[]interface{}) (*string, int, string, error) { +var wrapInsertSQL = func(ctx context.Context, config *DataSourceConfig, typeOf *reflect.Type, entity IEntityStruct, columns *[]reflect.StructField, values *[]interface{}) (*string, int, string, error) { sqlstr := "" inserColumnName, valuesql, autoIncrement, pktype, err := wrapInsertValueSQL(ctx, typeOf, entity, columns, values) if err != nil { @@ -506,7 +506,7 @@ var wrapDeleteSQL = func(ctx context.Context, entity IEntityStruct) (string, err // wrapInsertEntityMapSQL 包装保存Map语句,Map因为没有字段属性,无法完成Id的类型判断和赋值,需要确保Map的值是完整的 // wrapInsertEntityMapSQL Pack and save the Map statement. Because Map does not have field attributes, // it cannot complete the type judgment and assignment of Id. It is necessary to ensure that the value of Map is complete -var wrapInsertEntityMapSQL = func(ctx context.Context, entity IEntityMap) (string, *[]interface{}, bool, error) { +var wrapInsertEntityMapSQL = func(ctx context.Context, config *DataSourceConfig, entity IEntityMap) (string, *[]interface{}, bool, error) { sqlstr := "" inserColumnName, valuesql, values, autoIncrement, err := wrapInsertValueEntityMapSQL(entity) if err != nil { @@ -649,7 +649,7 @@ var wrapUpdateEntityMapSQL = func(ctx context.Context, entity IEntityMap) (*stri // wrapQuerySQL 封装查询语句 // wrapQuerySQL Encapsulated query statement -var wrapQuerySQL = func(ctx context.Context, dialect string, finder *Finder, page *Page) (string, error) { +var wrapQuerySQL = func(ctx context.Context, config *DataSourceConfig, finder *Finder, page *Page) (string, error) { // 获取到没有page的sql的语句 // Get the SQL statement without page. sqlstr, err := finder.GetSQL() @@ -657,7 +657,7 @@ var wrapQuerySQL = func(ctx context.Context, dialect string, finder *Finder, pag return "", err } if page != nil { - err = wrapPageSQL(ctx, dialect, &sqlstr, page) + err = wrapPageSQL(ctx, config, &sqlstr, page) } if err != nil { return "", err @@ -1024,7 +1024,7 @@ func reUpdateSQL(dialect string, sqlstr *string) error { } // wrapAutoIncrementInsertSQL 包装自增的自增主键的插入sql -var wrapAutoIncrementInsertSQL = func(ctx context.Context, pkColumnName string, sqlstr *string, dialect string, values *[]interface{}) (*int64, *int64) { +var wrapAutoIncrementInsertSQL = func(ctx context.Context, config *DataSourceConfig, pkColumnName string, sqlstr *string, values *[]interface{}) (*int64, *int64) { // oracle 12c+ 支持IDENTITY属性的自增列,因为分页也要求12c+的语法,所以数据库就IDENTITY创建自增吧 // 处理序列产生的自增主键,例如oracle,postgresql等 var lastInsertID, zormSQLOutReturningID *int64 @@ -1032,7 +1032,7 @@ var wrapAutoIncrementInsertSQL = func(ctx context.Context, pkColumnName string, // sqlBuilder.Grow(len(*sqlstr) + len(pkColumnName) + 40) sqlBuilder.Grow(stringBuilderGrowLen) sqlBuilder.WriteString(*sqlstr) - switch dialect { + switch config.Dialect { case "postgresql", "kingbase": var p int64 = 0 lastInsertID = &p diff --git a/structFieldInfo.go b/structFieldInfo.go index 8465e55..f4fcbb5 100644 --- a/structFieldInfo.go +++ b/structFieldInfo.go @@ -491,7 +491,7 @@ func checkEntityKind(entity interface{}) (*reflect.Type, error) { // 当读取数据库的值为NULL时,由于基本类型不支持为NULL,通过反射将未知driver.Value改为interface{},不再映射到struct实体类 // 感谢@fastabler提交的pr // oneColumnScanner 只有一个字段,而且可以直接Scan,例如string或者[]string,不需要反射StructType进行处理 -func sqlRowsValues(ctx context.Context, dialect string, valueOf *reflect.Value, typeOf *reflect.Type, rows *sql.Rows, driverValue *reflect.Value, columnTypes []*sql.ColumnType, entity interface{}, dbColumnFieldMap, exportFieldMap *map[string]reflect.StructField) error { +func sqlRowsValues(ctx context.Context, config *DataSourceConfig, valueOf *reflect.Value, typeOf *reflect.Type, rows *sql.Rows, driverValue *reflect.Value, columnTypes []*sql.ColumnType, entity interface{}, dbColumnFieldMap, exportFieldMap *map[string]reflect.StructField) error { if entity == nil && valueOf == nil { return errors.New("->sqlRowsValues-->valueOfElem为nil") } @@ -518,7 +518,7 @@ func sqlRowsValues(ctx context.Context, dialect string, valueOf *reflect.Value, if iscdvm { databaseTypeName := strings.ToUpper(columnType.DatabaseTypeName()) // 根据接收的类型,获取到类型转换的接口实现,优先匹配指定的数据库类型 - customDriverValueConver, converOK = customDriverValueMap[dialect+"."+databaseTypeName] + customDriverValueConver, converOK = customDriverValueMap[config.Dialect+"."+databaseTypeName] if !converOK { customDriverValueConver, converOK = customDriverValueMap[databaseTypeName] } diff --git a/typeConvert.go b/typeConvert.go index 9b605e1..e6452d9 100644 --- a/typeConvert.go +++ b/typeConvert.go @@ -26,7 +26,7 @@ import ( ) // FuncDecimalValue 设置decimal类型接收值,复写函数自定义decimal实现,例如github.com/shopspring/decimal,返回的是指针 -var FuncDecimalValue = func(ctx context.Context, dialect string) interface{} { +var FuncDecimalValue = func(ctx context.Context, config *DataSourceConfig) interface{} { return &decimal.Decimal{} }