diff --git a/api/waf_common.go b/api/waf_common.go index 5e63c87..7b73225 100644 --- a/api/waf_common.go +++ b/api/waf_common.go @@ -12,7 +12,6 @@ import ( "github.com/360EntSecGroup-Skylar/excelize" "github.com/gin-gonic/gin" uuid "github.com/satori/go.uuid" - "gorm.io/gorm" "log" "net/http" "os" @@ -50,10 +49,11 @@ func (w *WafCommonApi) ExportExcelApi(c *gin.Context) { for i := 0; i < dataType.NumField(); i++ { field := dataType.Field(i) if field.Name == "BaseOrm" { - f.SetCellValue(sheetName, fmt.Sprintf("%c%d", 'A'+i, 1), " - ") + f.SetCellValue(sheetName, fmt.Sprintf("%s%d", w.GetColumnName(i), 1), " - ") } else { colName := field.Tag.Get("json") // 获取 excel 标签的值,即表头名称 - f.SetCellValue(sheetName, fmt.Sprintf("%c%d", 'A'+i, 1), colName) + //fmt.Println(fmt.Sprintf("%v , %v %v %v", i, colName, dataType.NumField(), fmt.Sprintf("%s%d", w.GetColumnName(i), 1))) + f.SetCellValue(sheetName, fmt.Sprintf("%s%d", w.GetColumnName(i), 1), colName) } } @@ -64,12 +64,11 @@ func (w *WafCommonApi) ExportExcelApi(c *gin.Context) { for j := 0; j < dataType.NumField(); j++ { field := dataType.Field(j) if field.Name == "BaseOrm" { - f.SetCellValue(sheetName, fmt.Sprintf("%c%d", 'A'+j, rowNum), "") + f.SetCellValue(sheetName, fmt.Sprintf("%s%d", w.GetColumnName(j), rowNum), "") } else { colValue := rowValue.Field(j).Interface() - f.SetCellValue(sheetName, fmt.Sprintf("%c%d", 'A'+j, rowNum), colValue) + f.SetCellValue(sheetName, fmt.Sprintf("%s%d", w.GetColumnName(j), rowNum), colValue) } - } } @@ -106,7 +105,14 @@ func (w *WafCommonApi) ImportExcelApi(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - + importTable, succResult := c.GetPostForm("import_table") + if !succResult { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + } + importCodeStrategy, succResult := c.GetPostForm("import_code_strategy") + if !succResult { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + } // 创建 Excel 文件 f, err := excelize.OpenFile(fileName) if err != nil { @@ -120,7 +126,7 @@ func (w *WafCommonApi) ImportExcelApi(c *gin.Context) { fmt.Println(err) return } - ret := saveDataToDatabase("hosts", rows) + ret := saveDataToDatabase(importTable, rows, importCodeStrategy) // 删除临时文件 if err := os.Remove(fileName); err != nil { log.Println("无法删除临时文件:", err) @@ -139,143 +145,23 @@ func getStructTypeValueByName(name string) (reflect.Type, reflect.Value) { return nil, reflect.ValueOf(nil) } } +func (w *WafCommonApi) GetColumnName(colIdx int) string { + // 英文字母有26个,列号超过26时使用多字母 + colName := "" + for colIdx >= 0 { + colName = fmt.Sprintf("%c", 'A'+colIdx%26) + colName + colIdx = colIdx/26 - 1 + } + return colName +} -// 保存Excel数据 -func saveDataToDatabase(name string, rows [][]string) ReturnImportData { +// 通用数据插入函数 +func saveDataToDatabase(tableName string, rows [][]string, importCodeStrategy string) ReturnImportData { successInt := 0 failInt := 0 msg := "" - switch name { - case "hosts": - needJumpFristCol := false - var header []string - // 获取header - for _, row := range rows { - for _, colCell := range row { - if colCell == " - " { - needJumpFristCol = true - continue - } - //fmt.Print(colCell, "\t") - header = append(header, colCell) - } - break - //fmt.Println() - } - //fmt.Println("一下是数据") - // 获取数据 - rowNumber := 0 - for _, row := range rows { - if rowNumber == 0 && needJumpFristCol == true { - rowNumber++ - continue - } - colNumber := 0 - data := make(map[string]string) - //循环列 - for _, colCell := range row { - if colNumber == 0 && needJumpFristCol == true { - colNumber++ - continue - } - headerNumber := colNumber - if needJumpFristCol == true { - headerNumber = headerNumber - 1 - } - data[header[headerNumber]] = colCell - //fmt.Println(header[headerNumber], ":", colCell, "\t") - colNumber++ - } - - //准备插入数据 - if wafHostService.GetDetailByCodeApi(data["code"]).Code != "" { - //fmt.Println(data["code"], " 数据已存在不进行插入\t") - msg += "行" + strconv.Itoa(rowNumber) + " code:" + data["code"] + " 数据已存在不进行插入 " - failInt++ - rowNumber++ - continue - } - if data["host"] == "全局网站" { - //fmt.Println(data["code"], " 数据已存在不进行插入\t") - msg += "行" + strconv.Itoa(rowNumber) + " 全局网站的跳过" - failInt++ - rowNumber++ - continue - } - - err := wafHostService.CheckIsExist(data["host"], data["port"]) - if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { - var wafHost = &model.Hosts{ - BaseOrm: baseorm.BaseOrm{ - Id: uuid.NewV4().String(), - USER_CODE: global.GWAF_USER_CODE, - Tenant_ID: global.GWAF_TENANT_ID, - CREATE_TIME: customtype.JsonTime(time.Now()), - UPDATE_TIME: customtype.JsonTime(time.Now()), - }, - Code: data["code"], - Host: data["host"], - /*Port: data["port"], - Ssl: data["ssl"], - GUARD_STATUS: data["guard_status"],*/ - REMOTE_SYSTEM: data["remote_system"], - REMOTE_APP: data["remote_app"], - Remote_host: data["remote_host"], - /*Remote_port: data["remote_port"],*/ - Remote_ip: data["remote_ip"], - Certfile: data["certfile"], - Keyfile: data["keyfile"], - REMARKS: data["remarks"], - } - port, err := strconv.Atoi(data["port"]) - if err != nil { - //fmt.Println("转换出错:", err) - msg += "行" + strconv.Itoa(rowNumber) + " port:" + data["port"] + " 转换出错 " - failInt++ - continue - } - wafHost.Port = port - ssl, err := strconv.Atoi(data["ssl"]) - if err != nil { - //fmt.Println("转换出错:", err) - msg += "行" + strconv.Itoa(rowNumber) + " ssl:" + data["ssl"] + " 转换出错 " - failInt++ - continue - } - wafHost.Ssl = ssl - - guard_status, err := strconv.Atoi(data["guard_status"]) - if err != nil { - //fmt.Println("转换出错:", err) - msg += "行" + strconv.Itoa(rowNumber) + " guard_status:" + data["guard_status"] + " 转换出错 " - failInt++ - continue - } - wafHost.GUARD_STATUS = guard_status - - remote_port, err := strconv.Atoi(data["remote_port"]) - if err != nil { - //fmt.Println("转换出错:", err) - msg += "行" + strconv.Itoa(rowNumber) + " remote_port:" + data["remote_port"] + " 转换出错 " - failInt++ - continue - } - wafHost.Remote_port = remote_port - - global.GWAF_LOCAL_DB.Create(wafHost) - successInt++ - } else { - failInt++ - msg += "行" + strconv.Itoa(rowNumber) + " host:" + data["host"] + " port:" + data["port"] + " 数据已存在不进行插入\t" - //fmt.Println(data["host"], data["port"], " 数据已存在不进行插入\t") - } - rowNumber++ - //fmt.Println() - } - - default: - } + processImportData(&model.Hosts{}, tableName, rows, &successInt, &failInt, &msg, importCodeStrategy) return ReturnImportData{ SuccessInt: successInt, @@ -283,3 +169,159 @@ func saveDataToDatabase(name string, rows [][]string) ReturnImportData { Msg: msg, } } + +// processImportData 是通用的导入数据函数 +func processImportData(dataType interface{}, tableName string, rows [][]string, successInt, failInt *int, msg *string, importCodeStrategy string) { + var header []string + needJumpFristCol := false + rowNumber := 0 + var dataMap map[string]string + + // 获取结构体类型和字段信息 + dataValue := reflect.ValueOf(dataType).Elem() + dataTypeFields := dataValue.Type() + + // 获取表头 + for _, row := range rows { + for _, colCell := range row { + if colCell == " - " { + needJumpFristCol = true + continue + } + header = append(header, colCell) + } + break + } + + // 处理数据 获取数据并插入数据库 + for _, row := range rows { + if rowNumber == 0 && needJumpFristCol { + rowNumber++ + continue + } + + // 建立一个 map 来存储数据 + dataMap = make(map[string]string) + colNumber := 0 + for _, colCell := range row { + if colNumber == 0 && needJumpFristCol { + colNumber++ + continue + } + headerNumber := colNumber + if needJumpFristCol { + headerNumber = headerNumber - 1 + } + dataMap[header[headerNumber]] = colCell + colNumber++ + } + + // 动态创建结构体实例,并映射数据 + newInstance := reflect.New(dataValue.Type()).Elem() + + //TODO 插入之前校验一下 数据是否已经存在 + + for fieldIdx := 0; fieldIdx < dataTypeFields.NumField(); fieldIdx++ { + field := dataTypeFields.Field(fieldIdx) + fieldName := field.Name + jsonTag := field.Tag.Get("json") + + // 如果 dataMap 中有匹配的字段 + if val, exists := dataMap[jsonTag]; exists { + //排除一些特定数据 + if tableName == "hosts" && fieldName == "Host" && val == "全局网站" { + continue + } + //检查数据是否已经存在 + if tableName == "hosts" && fieldName == "Host" { + if importCodeStrategy == "1" { + errMsg, err := checkHostCodeData(dataMap["code"]) + if err != nil { + *msg += fmt.Sprintf("行 %d, 检测数据合法性时候 出错: %v |", rowNumber, errMsg) + *failInt++ + continue + } + } + errMsg, err := checkHostPortData(dataMap["host"], dataMap["port"]) + if err != nil { + *msg += fmt.Sprintf("行 %d, 检测数据合法性时候 出错: %v |", rowNumber, errMsg) + *failInt++ + continue + } + } + // 将字段值设置到结构体中 + fieldVal := newInstance.Field(fieldIdx) + // 转换并设置字段的值 + switch fieldVal.Kind() { + case reflect.String: + if tableName == "hosts" { + if importCodeStrategy == "0" && fieldName == "Code" { + fieldVal.SetString(uuid.NewV4().String()) + } else { + fieldVal.SetString(val) + } + } else { + fieldVal.SetString(val) + } + case reflect.Int: + intVal, err := strconv.Atoi(val) + if err != nil { + *msg += fmt.Sprintf("行 %d, 字段 %s 转换为 int 错误: %v |", rowNumber, fieldName, err) + *failInt++ + continue + } + fieldVal.SetInt(int64(intVal)) + default: + *msg += fmt.Sprintf("不支持的字段类型: %s |", fieldVal.Kind()) + *failInt++ + } + } else if fieldName == "BaseOrm" { + fieldVal := newInstance.Field(fieldIdx) + // 给 BaseOrm 赋值 + baseOrm := baseorm.BaseOrm{ + Id: uuid.NewV4().String(), // 新生成的 ID + USER_CODE: global.GWAF_USER_CODE, + Tenant_ID: global.GWAF_TENANT_ID, + CREATE_TIME: customtype.JsonTime(time.Now()), + UPDATE_TIME: customtype.JsonTime(time.Now()), + } + // 将 BaseOrm 设置到结构体字段 + fieldVal.Set(reflect.ValueOf(baseOrm)) + } else { + *msg += fmt.Sprintf("行 %d, 缺少字段 %s 数据 |", rowNumber, fieldName) + *failInt++ + } + } + if err := global.GWAF_LOCAL_DB.Create(newInstance.Interface()); err != nil { + errGorm := err.Error + if errGorm != nil { + *msg += fmt.Sprintf("行 %d 插入失败: %v |", rowNumber, err.Error) + *failInt++ + } else { + *successInt++ + } + } + + rowNumber++ + } +} + +// checkHostData 检查host信息是否合法 +func checkHostCodeData(code string) (string, error) { + // 唯一性校验:检查 `Code` 是否已存在 + if wafHostService.GetDetailByCodeApi(code).Code != "" { + errorMsg := "Code 数据已存在不进行插入" + // 数据已存在,不插入 + return errorMsg, errors.New(errorMsg) + } + return "数据正常", nil +} +func checkHostPortData(host string, port string) (string, error) { + //唯一性校验:检查 `Host` 和 `Port` 的组合是否已存在 + if err := wafHostService.CheckIsExist(host, port); err == nil { + errorMsg := "Host+Port 数据已存在不进行插入" + // 数据已存在,不插入 + return errorMsg, errors.New(errorMsg) + } + return "数据正常", nil +} diff --git a/api/waf_common_test.go b/api/waf_common_test.go index aa4a930..4ae52fa 100644 --- a/api/waf_common_test.go +++ b/api/waf_common_test.go @@ -2,6 +2,7 @@ package api import ( "SamWaf/global" + "fmt" "github.com/gin-gonic/gin" "net/http" "net/http/httptest" @@ -37,3 +38,13 @@ func TestExportExcelApi(t *testing.T) { t.Errorf("期望的状态码:%d,实际状态码:%d", http.StatusOK, rec.Code) } } + +func TestGetColumnName(t *testing.T) { + colIdx := 500 + colName := "" + for colIdx >= 0 { + colName = fmt.Sprintf("%c", 'A'+colIdx%26) + colName + colIdx = colIdx/26 - 1 + } + fmt.Println(colName) +}