fix:host import export

#92 #114
This commit is contained in:
samwaf
2025-01-09 08:34:06 +08:00
parent 70ab1d175f
commit cd65e37f26
2 changed files with 193 additions and 140 deletions

View File

@@ -12,7 +12,6 @@ import (
"github.com/360EntSecGroup-Skylar/excelize"
"github.com/gin-gonic/gin"
uuid "github.com/satori/go.uuid"
"gorm.io/gorm"
"log"
"net/http"
"os"
@@ -50,10 +49,11 @@ func (w *WafCommonApi) ExportExcelApi(c *gin.Context) {
for i := 0; i < dataType.NumField(); i++ {
field := dataType.Field(i)
if field.Name == "BaseOrm" {
f.SetCellValue(sheetName, fmt.Sprintf("%c%d", 'A'+i, 1), " - ")
f.SetCellValue(sheetName, fmt.Sprintf("%s%d", w.GetColumnName(i), 1), " - ")
} else {
colName := field.Tag.Get("json") // 获取 excel 标签的值,即表头名称
f.SetCellValue(sheetName, fmt.Sprintf("%c%d", 'A'+i, 1), colName)
//fmt.Println(fmt.Sprintf("%v %v %v %v", i, colName, dataType.NumField(), fmt.Sprintf("%s%d", w.GetColumnName(i), 1)))
f.SetCellValue(sheetName, fmt.Sprintf("%s%d", w.GetColumnName(i), 1), colName)
}
}
@@ -64,12 +64,11 @@ func (w *WafCommonApi) ExportExcelApi(c *gin.Context) {
for j := 0; j < dataType.NumField(); j++ {
field := dataType.Field(j)
if field.Name == "BaseOrm" {
f.SetCellValue(sheetName, fmt.Sprintf("%c%d", 'A'+j, rowNum), "")
f.SetCellValue(sheetName, fmt.Sprintf("%s%d", w.GetColumnName(j), rowNum), "")
} else {
colValue := rowValue.Field(j).Interface()
f.SetCellValue(sheetName, fmt.Sprintf("%c%d", 'A'+j, rowNum), colValue)
f.SetCellValue(sheetName, fmt.Sprintf("%s%d", w.GetColumnName(j), rowNum), colValue)
}
}
}
@@ -106,7 +105,14 @@ func (w *WafCommonApi) ImportExcelApi(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
importTable, succResult := c.GetPostForm("import_table")
if !succResult {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
}
importCodeStrategy, succResult := c.GetPostForm("import_code_strategy")
if !succResult {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
}
// 创建 Excel 文件
f, err := excelize.OpenFile(fileName)
if err != nil {
@@ -120,7 +126,7 @@ func (w *WafCommonApi) ImportExcelApi(c *gin.Context) {
fmt.Println(err)
return
}
ret := saveDataToDatabase("hosts", rows)
ret := saveDataToDatabase(importTable, rows, importCodeStrategy)
// 删除临时文件
if err := os.Remove(fileName); err != nil {
log.Println("无法删除临时文件:", err)
@@ -139,143 +145,23 @@ func getStructTypeValueByName(name string) (reflect.Type, reflect.Value) {
return nil, reflect.ValueOf(nil)
}
}
func (w *WafCommonApi) GetColumnName(colIdx int) string {
// 英文字母有26个列号超过26时使用多字母
colName := ""
for colIdx >= 0 {
colName = fmt.Sprintf("%c", 'A'+colIdx%26) + colName
colIdx = colIdx/26 - 1
}
return colName
}
// 保存Excel数据
func saveDataToDatabase(name string, rows [][]string) ReturnImportData {
// 通用数据插入函数
func saveDataToDatabase(tableName string, rows [][]string, importCodeStrategy string) ReturnImportData {
successInt := 0
failInt := 0
msg := ""
switch name {
case "hosts":
needJumpFristCol := false
var header []string
// 获取header
for _, row := range rows {
for _, colCell := range row {
if colCell == " - " {
needJumpFristCol = true
continue
}
//fmt.Print(colCell, "\t")
header = append(header, colCell)
}
break
//fmt.Println()
}
//fmt.Println("一下是数据")
// 获取数据
rowNumber := 0
for _, row := range rows {
if rowNumber == 0 && needJumpFristCol == true {
rowNumber++
continue
}
colNumber := 0
data := make(map[string]string)
//循环列
for _, colCell := range row {
if colNumber == 0 && needJumpFristCol == true {
colNumber++
continue
}
headerNumber := colNumber
if needJumpFristCol == true {
headerNumber = headerNumber - 1
}
data[header[headerNumber]] = colCell
//fmt.Println(header[headerNumber], ":", colCell, "\t")
colNumber++
}
//准备插入数据
if wafHostService.GetDetailByCodeApi(data["code"]).Code != "" {
//fmt.Println(data["code"], " 数据已存在不进行插入\t")
msg += "行" + strconv.Itoa(rowNumber) + " code:" + data["code"] + " 数据已存在不进行插入 "
failInt++
rowNumber++
continue
}
if data["host"] == "全局网站" {
//fmt.Println(data["code"], " 数据已存在不进行插入\t")
msg += "行" + strconv.Itoa(rowNumber) + " 全局网站的跳过"
failInt++
rowNumber++
continue
}
err := wafHostService.CheckIsExist(data["host"], data["port"])
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
var wafHost = &model.Hosts{
BaseOrm: baseorm.BaseOrm{
Id: uuid.NewV4().String(),
USER_CODE: global.GWAF_USER_CODE,
Tenant_ID: global.GWAF_TENANT_ID,
CREATE_TIME: customtype.JsonTime(time.Now()),
UPDATE_TIME: customtype.JsonTime(time.Now()),
},
Code: data["code"],
Host: data["host"],
/*Port: data["port"],
Ssl: data["ssl"],
GUARD_STATUS: data["guard_status"],*/
REMOTE_SYSTEM: data["remote_system"],
REMOTE_APP: data["remote_app"],
Remote_host: data["remote_host"],
/*Remote_port: data["remote_port"],*/
Remote_ip: data["remote_ip"],
Certfile: data["certfile"],
Keyfile: data["keyfile"],
REMARKS: data["remarks"],
}
port, err := strconv.Atoi(data["port"])
if err != nil {
//fmt.Println("转换出错:", err)
msg += "行" + strconv.Itoa(rowNumber) + " port:" + data["port"] + " 转换出错 "
failInt++
continue
}
wafHost.Port = port
ssl, err := strconv.Atoi(data["ssl"])
if err != nil {
//fmt.Println("转换出错:", err)
msg += "行" + strconv.Itoa(rowNumber) + " ssl:" + data["ssl"] + " 转换出错 "
failInt++
continue
}
wafHost.Ssl = ssl
guard_status, err := strconv.Atoi(data["guard_status"])
if err != nil {
//fmt.Println("转换出错:", err)
msg += "行" + strconv.Itoa(rowNumber) + " guard_status:" + data["guard_status"] + " 转换出错 "
failInt++
continue
}
wafHost.GUARD_STATUS = guard_status
remote_port, err := strconv.Atoi(data["remote_port"])
if err != nil {
//fmt.Println("转换出错:", err)
msg += "行" + strconv.Itoa(rowNumber) + " remote_port:" + data["remote_port"] + " 转换出错 "
failInt++
continue
}
wafHost.Remote_port = remote_port
global.GWAF_LOCAL_DB.Create(wafHost)
successInt++
} else {
failInt++
msg += "行" + strconv.Itoa(rowNumber) + " host:" + data["host"] + " port:" + data["port"] + " 数据已存在不进行插入\t"
//fmt.Println(data["host"], data["port"], " 数据已存在不进行插入\t")
}
rowNumber++
//fmt.Println()
}
default:
}
processImportData(&model.Hosts{}, tableName, rows, &successInt, &failInt, &msg, importCodeStrategy)
return ReturnImportData{
SuccessInt: successInt,
@@ -283,3 +169,159 @@ func saveDataToDatabase(name string, rows [][]string) ReturnImportData {
Msg: msg,
}
}
// processImportData 是通用的导入数据函数
func processImportData(dataType interface{}, tableName string, rows [][]string, successInt, failInt *int, msg *string, importCodeStrategy string) {
var header []string
needJumpFristCol := false
rowNumber := 0
var dataMap map[string]string
// 获取结构体类型和字段信息
dataValue := reflect.ValueOf(dataType).Elem()
dataTypeFields := dataValue.Type()
// 获取表头
for _, row := range rows {
for _, colCell := range row {
if colCell == " - " {
needJumpFristCol = true
continue
}
header = append(header, colCell)
}
break
}
// 处理数据 获取数据并插入数据库
for _, row := range rows {
if rowNumber == 0 && needJumpFristCol {
rowNumber++
continue
}
// 建立一个 map 来存储数据
dataMap = make(map[string]string)
colNumber := 0
for _, colCell := range row {
if colNumber == 0 && needJumpFristCol {
colNumber++
continue
}
headerNumber := colNumber
if needJumpFristCol {
headerNumber = headerNumber - 1
}
dataMap[header[headerNumber]] = colCell
colNumber++
}
// 动态创建结构体实例,并映射数据
newInstance := reflect.New(dataValue.Type()).Elem()
//TODO 插入之前校验一下 数据是否已经存在
for fieldIdx := 0; fieldIdx < dataTypeFields.NumField(); fieldIdx++ {
field := dataTypeFields.Field(fieldIdx)
fieldName := field.Name
jsonTag := field.Tag.Get("json")
// 如果 dataMap 中有匹配的字段
if val, exists := dataMap[jsonTag]; exists {
//排除一些特定数据
if tableName == "hosts" && fieldName == "Host" && val == "全局网站" {
continue
}
//检查数据是否已经存在
if tableName == "hosts" && fieldName == "Host" {
if importCodeStrategy == "1" {
errMsg, err := checkHostCodeData(dataMap["code"])
if err != nil {
*msg += fmt.Sprintf("行 %d, 检测数据合法性时候 出错: %v |", rowNumber, errMsg)
*failInt++
continue
}
}
errMsg, err := checkHostPortData(dataMap["host"], dataMap["port"])
if err != nil {
*msg += fmt.Sprintf("行 %d, 检测数据合法性时候 出错: %v |", rowNumber, errMsg)
*failInt++
continue
}
}
// 将字段值设置到结构体中
fieldVal := newInstance.Field(fieldIdx)
// 转换并设置字段的值
switch fieldVal.Kind() {
case reflect.String:
if tableName == "hosts" {
if importCodeStrategy == "0" && fieldName == "Code" {
fieldVal.SetString(uuid.NewV4().String())
} else {
fieldVal.SetString(val)
}
} else {
fieldVal.SetString(val)
}
case reflect.Int:
intVal, err := strconv.Atoi(val)
if err != nil {
*msg += fmt.Sprintf("行 %d, 字段 %s 转换为 int 错误: %v |", rowNumber, fieldName, err)
*failInt++
continue
}
fieldVal.SetInt(int64(intVal))
default:
*msg += fmt.Sprintf("不支持的字段类型: %s |", fieldVal.Kind())
*failInt++
}
} else if fieldName == "BaseOrm" {
fieldVal := newInstance.Field(fieldIdx)
// 给 BaseOrm 赋值
baseOrm := baseorm.BaseOrm{
Id: uuid.NewV4().String(), // 新生成的 ID
USER_CODE: global.GWAF_USER_CODE,
Tenant_ID: global.GWAF_TENANT_ID,
CREATE_TIME: customtype.JsonTime(time.Now()),
UPDATE_TIME: customtype.JsonTime(time.Now()),
}
// 将 BaseOrm 设置到结构体字段
fieldVal.Set(reflect.ValueOf(baseOrm))
} else {
*msg += fmt.Sprintf("行 %d, 缺少字段 %s 数据 |", rowNumber, fieldName)
*failInt++
}
}
if err := global.GWAF_LOCAL_DB.Create(newInstance.Interface()); err != nil {
errGorm := err.Error
if errGorm != nil {
*msg += fmt.Sprintf("行 %d 插入失败: %v |", rowNumber, err.Error)
*failInt++
} else {
*successInt++
}
}
rowNumber++
}
}
// checkHostData 检查host信息是否合法
func checkHostCodeData(code string) (string, error) {
// 唯一性校验:检查 `Code` 是否已存在
if wafHostService.GetDetailByCodeApi(code).Code != "" {
errorMsg := "Code 数据已存在不进行插入"
// 数据已存在,不插入
return errorMsg, errors.New(errorMsg)
}
return "数据正常", nil
}
func checkHostPortData(host string, port string) (string, error) {
//唯一性校验:检查 `Host` 和 `Port` 的组合是否已存在
if err := wafHostService.CheckIsExist(host, port); err == nil {
errorMsg := "Host+Port 数据已存在不进行插入"
// 数据已存在,不插入
return errorMsg, errors.New(errorMsg)
}
return "数据正常", nil
}

View File

@@ -2,6 +2,7 @@ package api
import (
"SamWaf/global"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"net/http/httptest"
@@ -37,3 +38,13 @@ func TestExportExcelApi(t *testing.T) {
t.Errorf("期望的状态码:%d实际状态码%d", http.StatusOK, rec.Code)
}
}
func TestGetColumnName(t *testing.T) {
colIdx := 500
colName := ""
for colIdx >= 0 {
colName = fmt.Sprintf("%c", 'A'+colIdx%26) + colName
colIdx = colIdx/26 - 1
}
fmt.Println(colName)
}