mirror of
https://gitee.com/chunanyong/zorm.git
synced 2025-12-06 07:08:49 +08:00
统一为config *DataSourceConfig参数
This commit is contained in:
41
DBDao.go
41
DBDao.go
@@ -501,11 +501,9 @@ var queryRow = func(ctx context.Context, finder *Finder, entity interface{}) (ha
|
||||
FuncLogError(ctx, errConfig)
|
||||
return has, errConfig
|
||||
}
|
||||
dialect := config.Dialect
|
||||
|
||||
// 获取到sql语句
|
||||
// Get the sql statement
|
||||
sqlstr, errSQL := wrapQuerySQL(ctx, dialect, finder, nil)
|
||||
sqlstr, errSQL := wrapQuerySQL(ctx, config, finder, nil)
|
||||
if errSQL != nil {
|
||||
errSQL = fmt.Errorf("->QueryRow-->wrapQuerySQL获取查询SQL语句错误:%w", errSQL)
|
||||
FuncLogError(ctx, errSQL)
|
||||
@@ -601,10 +599,10 @@ var queryRow = func(ctx context.Context, finder *Finder, entity interface{}) (ha
|
||||
return has, errQueryRow
|
||||
}
|
||||
if oneColumnScanner {
|
||||
err = sqlRowsValues(ctx, dialect, nil, typeOf, rows, &driverValue, columnTypes, entity, dbColumnFieldMap, exportFieldMap)
|
||||
err = sqlRowsValues(ctx, config, nil, typeOf, rows, &driverValue, columnTypes, entity, dbColumnFieldMap, exportFieldMap)
|
||||
} else {
|
||||
pv := reflect.ValueOf(entity)
|
||||
err = sqlRowsValues(ctx, dialect, &pv, typeOf, rows, &driverValue, columnTypes, nil, dbColumnFieldMap, exportFieldMap)
|
||||
err = sqlRowsValues(ctx, config, &pv, typeOf, rows, &driverValue, columnTypes, nil, dbColumnFieldMap, exportFieldMap)
|
||||
}
|
||||
|
||||
// pv = pv.Elem()
|
||||
@@ -689,9 +687,7 @@ var query = func(ctx context.Context, finder *Finder, rowsSlicePtr interface{},
|
||||
FuncLogError(ctx, errConfig)
|
||||
return errConfig
|
||||
}
|
||||
dialect := config.Dialect
|
||||
|
||||
sqlstr, errSQL := wrapQuerySQL(ctx, dialect, finder, page)
|
||||
sqlstr, errSQL := wrapQuerySQL(ctx, config, finder, page)
|
||||
if errSQL != nil {
|
||||
errSQL = fmt.Errorf("->Query-->wrapQuerySQL获取查询SQL语句错误:%w", errSQL)
|
||||
FuncLogError(ctx, errSQL)
|
||||
@@ -781,9 +777,9 @@ var query = func(ctx context.Context, finder *Finder, rowsSlicePtr interface{},
|
||||
for rows.Next() {
|
||||
pv := reflect.New(sliceElementType)
|
||||
if oneColumnScanner {
|
||||
err = sqlRowsValues(ctx, dialect, nil, &sliceElementType, rows, &driverValue, columnTypes, pv.Interface(), dbColumnFieldMap, exportFieldMap)
|
||||
err = sqlRowsValues(ctx, config, nil, &sliceElementType, rows, &driverValue, columnTypes, pv.Interface(), dbColumnFieldMap, exportFieldMap)
|
||||
} else {
|
||||
err = sqlRowsValues(ctx, dialect, &pv, &sliceElementType, rows, &driverValue, columnTypes, nil, dbColumnFieldMap, exportFieldMap)
|
||||
err = sqlRowsValues(ctx, config, &pv, &sliceElementType, rows, &driverValue, columnTypes, nil, dbColumnFieldMap, exportFieldMap)
|
||||
}
|
||||
|
||||
// err = sqlRowsValues(ctx, dialect, &pv, rows, &driverValue, columnTypes, oneColumnScanner, structType, &dbColumnFieldMap, &exportFieldMap)
|
||||
@@ -811,7 +807,7 @@ var query = func(ctx context.Context, finder *Finder, rowsSlicePtr interface{},
|
||||
// 查询总条数
|
||||
// Query total number
|
||||
if finder.SelectTotalCount && page != nil {
|
||||
count, errCount := selectCount(ctx, finder)
|
||||
count, errCount := selectCount(ctx, config, finder)
|
||||
if errCount != nil {
|
||||
errCount = fmt.Errorf("->Query-->selectCount查询总条数错误:%w", errCount)
|
||||
FuncLogError(ctx, errCount)
|
||||
@@ -895,8 +891,7 @@ var queryMap = func(ctx context.Context, finder *Finder, page *Page) (resultMapL
|
||||
FuncLogError(ctx, errConfig)
|
||||
return nil, errConfig
|
||||
}
|
||||
dialect := config.Dialect
|
||||
sqlstr, errSQL := wrapQuerySQL(ctx, dialect, finder, page)
|
||||
sqlstr, errSQL := wrapQuerySQL(ctx, config, finder, page)
|
||||
if errSQL != nil {
|
||||
errSQL = fmt.Errorf("->QueryMap -->wrapQuerySQL查询SQL语句错误:%w", errSQL)
|
||||
FuncLogError(ctx, errSQL)
|
||||
@@ -984,7 +979,7 @@ var queryMap = func(ctx context.Context, finder *Finder, page *Page) (resultMapL
|
||||
databaseTypeName := strings.ToUpper(columnType.DatabaseTypeName())
|
||||
// 判断是否有自定义扩展,避免无意义的反射
|
||||
if iscdvm {
|
||||
customDriverValueConver, converOK = customDriverValueMap[dialect+"."+databaseTypeName]
|
||||
customDriverValueConver, converOK = customDriverValueMap[config.Dialect+"."+databaseTypeName]
|
||||
if !converOK {
|
||||
customDriverValueConver, converOK = customDriverValueMap[databaseTypeName]
|
||||
}
|
||||
@@ -1031,7 +1026,7 @@ var queryMap = func(ctx context.Context, finder *Finder, page *Page) (resultMapL
|
||||
case "NUMBER":
|
||||
precision, scale, isDecimal := columnType.DecimalSize()
|
||||
if isDecimal || precision > 18 || precision-scale > 18 { // 如果是Decimal类型
|
||||
values[i] = FuncDecimalValue(ctx, dialect)
|
||||
values[i] = FuncDecimalValue(ctx, config)
|
||||
} else if scale > 0 { // 有小数位,默认使用float64接收
|
||||
values[i] = new(float64)
|
||||
} else if precision-scale > 9 { // 超过9位,使用int64
|
||||
@@ -1041,7 +1036,7 @@ var queryMap = func(ctx context.Context, finder *Finder, page *Page) (resultMapL
|
||||
}
|
||||
|
||||
case "DECIMAL", "NUMERIC", "DEC":
|
||||
values[i] = FuncDecimalValue(ctx, dialect)
|
||||
values[i] = FuncDecimalValue(ctx, config)
|
||||
case "BOOLEAN", "BOOL", "BIT":
|
||||
values[i] = new(bool)
|
||||
default:
|
||||
@@ -1098,7 +1093,7 @@ var queryMap = func(ctx context.Context, finder *Finder, page *Page) (resultMapL
|
||||
// 查询总条数
|
||||
// Query total number
|
||||
if finder.SelectTotalCount && page != nil {
|
||||
count, errCount := selectCount(ctx, finder)
|
||||
count, errCount := selectCount(ctx, config, finder)
|
||||
if errCount != nil {
|
||||
errCount = fmt.Errorf("->QueryMap-->selectCount查询总条数错误:%w", errCount)
|
||||
FuncLogError(ctx, errCount)
|
||||
@@ -1180,7 +1175,7 @@ var insert = func(ctx context.Context, entity IEntityStruct) (int, error) {
|
||||
|
||||
// SQL语句
|
||||
// SQL statement
|
||||
sqlstr, autoIncrement, pktype, err := wrapInsertSQL(ctx, typeOf, entity, columns, values)
|
||||
sqlstr, autoIncrement, pktype, err := wrapInsertSQL(ctx, dbConnection.config, typeOf, entity, columns, values)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("->Insert-->wrapInsertSQL获取保存语句错误:%w", err)
|
||||
FuncLogError(ctx, err)
|
||||
@@ -1197,8 +1192,7 @@ var insert = func(ctx context.Context, entity IEntityStruct) (int, error) {
|
||||
if errConfig != nil {
|
||||
return affected, errConfig
|
||||
}
|
||||
dialect := config.Dialect
|
||||
lastInsertID, zormSQLOutReturningID = wrapAutoIncrementInsertSQL(ctx, entity.GetPKColumnName(), sqlstr, dialect, values)
|
||||
lastInsertID, zormSQLOutReturningID = wrapAutoIncrementInsertSQL(ctx, config, entity.GetPKColumnName(), sqlstr, values)
|
||||
|
||||
}
|
||||
|
||||
@@ -1420,7 +1414,7 @@ var insertEntityMap = func(ctx context.Context, entity IEntityMap) (int, error)
|
||||
}
|
||||
|
||||
// SQL语句
|
||||
sqlstr, values, autoIncrement, err := wrapInsertEntityMapSQL(ctx, entity)
|
||||
sqlstr, values, autoIncrement, err := wrapInsertEntityMapSQL(ctx, dbConnection.config, entity)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("->InsertEntityMap-->wrapInsertEntityMapSQL获取SQL语句错误:%w", err)
|
||||
FuncLogError(ctx, err)
|
||||
@@ -1435,8 +1429,7 @@ var insertEntityMap = func(ctx context.Context, entity IEntityMap) (int, error)
|
||||
if errConfig != nil {
|
||||
return affected, errConfig
|
||||
}
|
||||
dialect := config.Dialect
|
||||
lastInsertID, zormSQLOutReturningID = wrapAutoIncrementInsertSQL(ctx, entity.GetPKColumnName(), &sqlstr, dialect, values)
|
||||
lastInsertID, zormSQLOutReturningID = wrapAutoIncrementInsertSQL(ctx, config, entity.GetPKColumnName(), &sqlstr, values)
|
||||
}
|
||||
|
||||
// 包装update执行,赋值给影响的函数指针变量,返回*sql.Result
|
||||
@@ -1625,7 +1618,7 @@ func WrapUpdateStructFinder(ctx context.Context, entity IEntityStruct, onlyUpdat
|
||||
// context必须传入,不能为空
|
||||
// selectCount Query the total number of items according to finder
|
||||
// context must be passed in and cannot be empty
|
||||
var selectCount = func(ctx context.Context, finder *Finder) (int, error) {
|
||||
var selectCount = func(ctx context.Context, config *DataSourceConfig, finder *Finder) (int, error) {
|
||||
if finder == nil {
|
||||
return -1, errors.New("->selectCount-->finder参数为nil")
|
||||
}
|
||||
|
||||
@@ -127,33 +127,34 @@ func OverrideFunc(funcName string, funcObject interface{}) (bool, interface{}, e
|
||||
}
|
||||
|
||||
case "wrapQuerySQL": //查询的SQL
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, dialect string, finder *Finder, page *Page) (string, error))
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, config *DataSourceConfig, finder *Finder, page *Page) (string, error))
|
||||
if ok {
|
||||
oldFunc = wrapQuerySQL
|
||||
wrapQuerySQL = newFunc
|
||||
}
|
||||
|
||||
case "selectCount": //查询总条数
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, finder *Finder) (int, error))
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, config *DataSourceConfig, finder *Finder) (int, error))
|
||||
if ok {
|
||||
oldFunc = selectCount
|
||||
selectCount = newFunc
|
||||
}
|
||||
case "wrapPageSQL": //分页SQL
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, dialect string, sqlstr *string, page *Page) error)
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, config *DataSourceConfig, sqlstr *string, page *Page) error)
|
||||
if ok {
|
||||
oldFunc = wrapPageSQL
|
||||
wrapPageSQL = newFunc
|
||||
}
|
||||
|
||||
case "wrapInsertSQL": //Insert IEntityStruct SQL
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, typeOf *reflect.Type, entity IEntityStruct, columns *[]reflect.StructField, values *[]interface{}) (*string, int, string, error))
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, config *DataSourceConfig, typeOf *reflect.Type, entity IEntityStruct, columns *[]reflect.StructField, values *[]interface{}) (*string, int, string, error))
|
||||
if ok {
|
||||
oldFunc = wrapInsertSQL
|
||||
wrapInsertSQL = newFunc
|
||||
}
|
||||
case "wrapAutoIncrementInsertSQL": //IEntityStruct 主键自增值的SQL
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, pkColumnName string, sqlstr *string, dialect string, values *[]interface{}) (*int64, *int64))
|
||||
|
||||
case "wrapAutoIncrementInsertSQL": //Insert IEntityStruct 主键自增值的SQL
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, config *DataSourceConfig, pkColumnName string, sqlstr *string, values *[]interface{}) (*int64, *int64))
|
||||
if ok {
|
||||
oldFunc = wrapAutoIncrementInsertSQL
|
||||
wrapAutoIncrementInsertSQL = newFunc
|
||||
@@ -166,7 +167,7 @@ func OverrideFunc(funcName string, funcObject interface{}) (bool, interface{}, e
|
||||
wrapInsertSliceSQL = newFunc
|
||||
}
|
||||
case "wrapInsertEntityMapSQL": //插入 IEntityMap 的SQL
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, entity IEntityMap) (string, *[]interface{}, bool, error))
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, config *DataSourceConfig, entity IEntityMap) (string, *[]interface{}, bool, error))
|
||||
if ok {
|
||||
oldFunc = wrapInsertEntityMapSQL
|
||||
wrapInsertEntityMapSQL = newFunc
|
||||
|
||||
18
dialect.go
18
dialect.go
@@ -34,14 +34,14 @@ import (
|
||||
|
||||
// wrapPageSQL 包装分页的SQL语句
|
||||
// wrapPageSQL SQL statement for wrapping paging
|
||||
var wrapPageSQL = func(ctx context.Context, dialect string, sqlstr *string, page *Page) error {
|
||||
var wrapPageSQL = func(ctx context.Context, config *DataSourceConfig, sqlstr *string, page *Page) error {
|
||||
if page.PageNo < 1 { // 默认第一页
|
||||
page.PageNo = 1
|
||||
}
|
||||
var sqlbuilder strings.Builder
|
||||
sqlbuilder.Grow(stringBuilderGrowLen)
|
||||
sqlbuilder.WriteString(*sqlstr)
|
||||
switch dialect {
|
||||
switch config.Dialect {
|
||||
case "mysql", "sqlite", "dm", "gbase", "clickhouse", "tdengine", "db2": // MySQL,sqlite3,dm,南通,clickhouse,TDengine,db2 7.2+
|
||||
sqlbuilder.WriteString(" LIMIT ")
|
||||
sqlbuilder.WriteString(strconv.Itoa(page.PageSize * (page.PageNo - 1)))
|
||||
@@ -74,7 +74,7 @@ var wrapPageSQL = func(ctx context.Context, dialect string, sqlstr *string, page
|
||||
sqlbuilder.WriteString(strconv.Itoa(page.PageSize))
|
||||
sqlbuilder.WriteString(" ROWS ONLY ")
|
||||
default:
|
||||
return errors.New("->wrapPageSQL-->不支持的数据库类型:" + dialect)
|
||||
return errors.New("->wrapPageSQL-->不支持的数据库类型:" + config.Dialect)
|
||||
|
||||
}
|
||||
*sqlstr = sqlbuilder.String()
|
||||
@@ -86,7 +86,7 @@ var wrapPageSQL = func(ctx context.Context, dialect string, sqlstr *string, page
|
||||
// 数组传递,如果外部方法有调用append的逻辑,append会破坏指针引用,所以传递指针
|
||||
// wrapInsertSQL Pack and save 'Struct' statement. Return SQL statement, whether it is incremented, error message
|
||||
// Array transfer, if the external method has logic to call append, append will destroy the pointer reference, so the pointer is passed
|
||||
var wrapInsertSQL = func(ctx context.Context, typeOf *reflect.Type, entity IEntityStruct, columns *[]reflect.StructField, values *[]interface{}) (*string, int, string, error) {
|
||||
var wrapInsertSQL = func(ctx context.Context, config *DataSourceConfig, typeOf *reflect.Type, entity IEntityStruct, columns *[]reflect.StructField, values *[]interface{}) (*string, int, string, error) {
|
||||
sqlstr := ""
|
||||
inserColumnName, valuesql, autoIncrement, pktype, err := wrapInsertValueSQL(ctx, typeOf, entity, columns, values)
|
||||
if err != nil {
|
||||
@@ -506,7 +506,7 @@ var wrapDeleteSQL = func(ctx context.Context, entity IEntityStruct) (string, err
|
||||
// wrapInsertEntityMapSQL 包装保存Map语句,Map因为没有字段属性,无法完成Id的类型判断和赋值,需要确保Map的值是完整的
|
||||
// wrapInsertEntityMapSQL Pack and save the Map statement. Because Map does not have field attributes,
|
||||
// it cannot complete the type judgment and assignment of Id. It is necessary to ensure that the value of Map is complete
|
||||
var wrapInsertEntityMapSQL = func(ctx context.Context, entity IEntityMap) (string, *[]interface{}, bool, error) {
|
||||
var wrapInsertEntityMapSQL = func(ctx context.Context, config *DataSourceConfig, entity IEntityMap) (string, *[]interface{}, bool, error) {
|
||||
sqlstr := ""
|
||||
inserColumnName, valuesql, values, autoIncrement, err := wrapInsertValueEntityMapSQL(entity)
|
||||
if err != nil {
|
||||
@@ -649,7 +649,7 @@ var wrapUpdateEntityMapSQL = func(ctx context.Context, entity IEntityMap) (*stri
|
||||
|
||||
// wrapQuerySQL 封装查询语句
|
||||
// wrapQuerySQL Encapsulated query statement
|
||||
var wrapQuerySQL = func(ctx context.Context, dialect string, finder *Finder, page *Page) (string, error) {
|
||||
var wrapQuerySQL = func(ctx context.Context, config *DataSourceConfig, finder *Finder, page *Page) (string, error) {
|
||||
// 获取到没有page的sql的语句
|
||||
// Get the SQL statement without page.
|
||||
sqlstr, err := finder.GetSQL()
|
||||
@@ -657,7 +657,7 @@ var wrapQuerySQL = func(ctx context.Context, dialect string, finder *Finder, pag
|
||||
return "", err
|
||||
}
|
||||
if page != nil {
|
||||
err = wrapPageSQL(ctx, dialect, &sqlstr, page)
|
||||
err = wrapPageSQL(ctx, config, &sqlstr, page)
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -1024,7 +1024,7 @@ func reUpdateSQL(dialect string, sqlstr *string) error {
|
||||
}
|
||||
|
||||
// wrapAutoIncrementInsertSQL 包装自增的自增主键的插入sql
|
||||
var wrapAutoIncrementInsertSQL = func(ctx context.Context, pkColumnName string, sqlstr *string, dialect string, values *[]interface{}) (*int64, *int64) {
|
||||
var wrapAutoIncrementInsertSQL = func(ctx context.Context, config *DataSourceConfig, pkColumnName string, sqlstr *string, values *[]interface{}) (*int64, *int64) {
|
||||
// oracle 12c+ 支持IDENTITY属性的自增列,因为分页也要求12c+的语法,所以数据库就IDENTITY创建自增吧
|
||||
// 处理序列产生的自增主键,例如oracle,postgresql等
|
||||
var lastInsertID, zormSQLOutReturningID *int64
|
||||
@@ -1032,7 +1032,7 @@ var wrapAutoIncrementInsertSQL = func(ctx context.Context, pkColumnName string,
|
||||
// sqlBuilder.Grow(len(*sqlstr) + len(pkColumnName) + 40)
|
||||
sqlBuilder.Grow(stringBuilderGrowLen)
|
||||
sqlBuilder.WriteString(*sqlstr)
|
||||
switch dialect {
|
||||
switch config.Dialect {
|
||||
case "postgresql", "kingbase":
|
||||
var p int64 = 0
|
||||
lastInsertID = &p
|
||||
|
||||
@@ -491,7 +491,7 @@ func checkEntityKind(entity interface{}) (*reflect.Type, error) {
|
||||
// 当读取数据库的值为NULL时,由于基本类型不支持为NULL,通过反射将未知driver.Value改为interface{},不再映射到struct实体类
|
||||
// 感谢@fastabler提交的pr
|
||||
// oneColumnScanner 只有一个字段,而且可以直接Scan,例如string或者[]string,不需要反射StructType进行处理
|
||||
func sqlRowsValues(ctx context.Context, dialect string, valueOf *reflect.Value, typeOf *reflect.Type, rows *sql.Rows, driverValue *reflect.Value, columnTypes []*sql.ColumnType, entity interface{}, dbColumnFieldMap, exportFieldMap *map[string]reflect.StructField) error {
|
||||
func sqlRowsValues(ctx context.Context, config *DataSourceConfig, valueOf *reflect.Value, typeOf *reflect.Type, rows *sql.Rows, driverValue *reflect.Value, columnTypes []*sql.ColumnType, entity interface{}, dbColumnFieldMap, exportFieldMap *map[string]reflect.StructField) error {
|
||||
if entity == nil && valueOf == nil {
|
||||
return errors.New("->sqlRowsValues-->valueOfElem为nil")
|
||||
}
|
||||
@@ -518,7 +518,7 @@ func sqlRowsValues(ctx context.Context, dialect string, valueOf *reflect.Value,
|
||||
if iscdvm {
|
||||
databaseTypeName := strings.ToUpper(columnType.DatabaseTypeName())
|
||||
// 根据接收的类型,获取到类型转换的接口实现,优先匹配指定的数据库类型
|
||||
customDriverValueConver, converOK = customDriverValueMap[dialect+"."+databaseTypeName]
|
||||
customDriverValueConver, converOK = customDriverValueMap[config.Dialect+"."+databaseTypeName]
|
||||
if !converOK {
|
||||
customDriverValueConver, converOK = customDriverValueMap[databaseTypeName]
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ import (
|
||||
)
|
||||
|
||||
// FuncDecimalValue 设置decimal类型接收值,复写函数自定义decimal实现,例如github.com/shopspring/decimal,返回的是指针
|
||||
var FuncDecimalValue = func(ctx context.Context, dialect string) interface{} {
|
||||
var FuncDecimalValue = func(ctx context.Context, config *DataSourceConfig) interface{} {
|
||||
return &decimal.Decimal{}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user