mirror of
https://gitee.com/samwaf/SamWaf.git
synced 2025-12-06 14:59:18 +08:00
@@ -120,11 +120,7 @@ func (w *WafSysInfoApi) CheckVersionApi(c *gin.Context) {
|
||||
CmdName: "samwaf_update", // The app name which is appended to the ApiURL to look for an update
|
||||
//ForceCheck: true, // For this example, always check for an update unless the version is "dev"
|
||||
}
|
||||
available, newVer, desc, err := updater.UpdateAvailable()
|
||||
if err != nil {
|
||||
response.FailWithMessage("未发现新文件", c)
|
||||
return
|
||||
}
|
||||
available, newVer, desc, _ := updater.UpdateAvailable()
|
||||
if available {
|
||||
global.GWAF_RUNTIME_NEW_VERSION = newVer
|
||||
global.GWAF_RUNTIME_NEW_VERSION_DESC = desc
|
||||
@@ -137,15 +133,30 @@ func (w *WafSysInfoApi) CheckVersionApi(c *gin.Context) {
|
||||
VersionDesc: desc,
|
||||
}, "有新版本", c)
|
||||
} else {
|
||||
response.FailWithMessage("没有最新版本", c)
|
||||
return
|
||||
available, newVer, desc, _ = updater.UpdateAvailableWithChannel("github")
|
||||
if available {
|
||||
global.GWAF_RUNTIME_NEW_VERSION = newVer
|
||||
global.GWAF_RUNTIME_NEW_VERSION_DESC = desc
|
||||
response.OkWithDetailed(model.VersionInfo{
|
||||
Version: global.GWAF_RELEASE_VERSION,
|
||||
VersionName: global.GWAF_RELEASE_VERSION_NAME,
|
||||
VersionRelease: global.GWAF_RELEASE,
|
||||
NeedUpdate: true,
|
||||
VersionNew: newVer,
|
||||
VersionDesc: desc,
|
||||
}, "有新版本(测试版)", c)
|
||||
} else {
|
||||
response.FailWithMessage("没有最新版本", c)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// 去升级
|
||||
func (w *WafSysInfoApi) UpdateApi(c *gin.Context) {
|
||||
|
||||
// 获取请求中的 channel 参数
|
||||
channel := c.Query("channel")
|
||||
if global.GWAF_RUNTIME_IS_UPDATETING == true {
|
||||
response.OkWithMessage("正在升级中...请在消息等待结果", c)
|
||||
return
|
||||
@@ -180,18 +191,34 @@ func (w *WafSysInfoApi) UpdateApi(c *gin.Context) {
|
||||
}
|
||||
go func() {
|
||||
// try to update
|
||||
err := updater.BackgroundRun()
|
||||
if err != nil {
|
||||
if channel != "" {
|
||||
err := updater.BackgroundRunWithChannel(channel)
|
||||
if err != nil {
|
||||
|
||||
global.GWAF_RUNTIME_IS_UPDATETING = false
|
||||
//发送websocket 推送消息
|
||||
global.GQEQUE_MESSAGE_DB.Enqueue(innerbean.UpdateResultMessageInfo{
|
||||
BaseMessageInfo: innerbean.BaseMessageInfo{OperaType: "升级结果", Server: global.GWAF_CUSTOM_SERVER_NAME},
|
||||
Msg: "升级错误",
|
||||
Success: "False",
|
||||
})
|
||||
zlog.Info("Failed to update app:", err)
|
||||
global.GWAF_RUNTIME_IS_UPDATETING = false
|
||||
//发送websocket 推送消息
|
||||
global.GQEQUE_MESSAGE_DB.Enqueue(innerbean.UpdateResultMessageInfo{
|
||||
BaseMessageInfo: innerbean.BaseMessageInfo{OperaType: "升级结果", Server: global.GWAF_CUSTOM_SERVER_NAME},
|
||||
Msg: "升级错误",
|
||||
Success: "False",
|
||||
})
|
||||
zlog.Info("Failed to update app:", err)
|
||||
}
|
||||
} else {
|
||||
err := updater.BackgroundRun()
|
||||
if err != nil {
|
||||
|
||||
global.GWAF_RUNTIME_IS_UPDATETING = false
|
||||
//发送websocket 推送消息
|
||||
global.GQEQUE_MESSAGE_DB.Enqueue(innerbean.UpdateResultMessageInfo{
|
||||
BaseMessageInfo: innerbean.BaseMessageInfo{OperaType: "升级结果", Server: global.GWAF_CUSTOM_SERVER_NAME},
|
||||
Msg: "升级错误",
|
||||
Success: "False",
|
||||
})
|
||||
zlog.Info("Failed to update app:", err)
|
||||
}
|
||||
}
|
||||
|
||||
}()
|
||||
response.OkWithMessage("已发起升级,等待通知结果", c)
|
||||
}
|
||||
|
||||
@@ -132,9 +132,9 @@ var (
|
||||
GWebSocket *gwebsocket.WebSocketOnline
|
||||
|
||||
//升级相关
|
||||
GUPDATE_VERSION_URL string = "https://update.samwaf.com/" //
|
||||
|
||||
GWAF_SNOWFLAKE_GEN *wafsnowflake.Snowflake //雪花算法
|
||||
GUPDATE_VERSION_URL string = "https://update.samwaf.com/" // 官方下载
|
||||
GUPDATE_GITHUB_VERSION_URL string = "https://api.github.com/repos/samwafgo/samwaf/releases/latest" //gitHub
|
||||
GWAF_SNOWFLAKE_GEN *wafsnowflake.Snowflake //雪花算法
|
||||
|
||||
//任务开关信息
|
||||
GWAF_SWITCH_TASK_COUNTER bool
|
||||
|
||||
149
utils/archiveutil.go
Normal file
149
utils/archiveutil.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"archive/zip"
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Unzip 解压缩zip文件到指定目录
|
||||
// zipFile: zip文件路径
|
||||
// destDir: 解压目标目录
|
||||
func Unzip(zipFile, destDir string) error {
|
||||
// 打开zip文件
|
||||
reader, err := zip.OpenReader(zipFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开zip文件失败: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// 确保目标目录存在
|
||||
if err := os.MkdirAll(destDir, 0755); err != nil {
|
||||
return fmt.Errorf("创建目标目录失败: %v", err)
|
||||
}
|
||||
|
||||
// 遍历zip文件中的所有文件和目录
|
||||
for _, file := range reader.File {
|
||||
// 构建目标路径
|
||||
path := filepath.Join(destDir, file.Name)
|
||||
|
||||
// 检查路径是否在目标目录内(防止zip slip漏洞)
|
||||
if !strings.HasPrefix(path, filepath.Clean(destDir)+string(os.PathSeparator)) {
|
||||
return fmt.Errorf("非法的文件路径: %s", file.Name)
|
||||
}
|
||||
|
||||
// 如果是目录,创建它
|
||||
if file.FileInfo().IsDir() {
|
||||
if err := os.MkdirAll(path, file.Mode()); err != nil {
|
||||
return fmt.Errorf("创建目录失败: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 确保父目录存在
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
||||
return fmt.Errorf("创建父目录失败: %v", err)
|
||||
}
|
||||
|
||||
// 创建目标文件
|
||||
destFile, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, file.Mode())
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建目标文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 打开源文件
|
||||
srcFile, err := file.Open()
|
||||
if err != nil {
|
||||
destFile.Close()
|
||||
return fmt.Errorf("打开源文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 复制内容
|
||||
_, err = io.Copy(destFile, srcFile)
|
||||
srcFile.Close()
|
||||
destFile.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("复制文件内容失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExtractTarGz 解压缩tar.gz文件到指定目录
|
||||
// tarGzFile: tar.gz文件路径
|
||||
// destDir: 解压目标目录
|
||||
func ExtractTarGz(tarGzFile, destDir string) error {
|
||||
// 打开tar.gz文件
|
||||
file, err := os.Open(tarGzFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开tar.gz文件失败: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// 创建gzip读取器
|
||||
gzipReader, err := gzip.NewReader(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建gzip读取器失败: %v", err)
|
||||
}
|
||||
defer gzipReader.Close()
|
||||
|
||||
// 创建tar读取器
|
||||
tarReader := tar.NewReader(gzipReader)
|
||||
|
||||
// 确保目标目录存在
|
||||
if err := os.MkdirAll(destDir, 0755); err != nil {
|
||||
return fmt.Errorf("创建目标目录失败: %v", err)
|
||||
}
|
||||
|
||||
// 遍历tar文件中的所有文件和目录
|
||||
for {
|
||||
header, err := tarReader.Next()
|
||||
if err == io.EOF {
|
||||
break // 文件结束
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取tar文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 构建目标路径
|
||||
path := filepath.Join(destDir, header.Name)
|
||||
|
||||
// 检查路径是否在目标目录内(防止zip slip漏洞)
|
||||
if !strings.HasPrefix(path, filepath.Clean(destDir)+string(os.PathSeparator)) {
|
||||
return fmt.Errorf("非法的文件路径: %s", header.Name)
|
||||
}
|
||||
|
||||
switch header.Typeflag {
|
||||
case tar.TypeDir: // 目录
|
||||
if err := os.MkdirAll(path, 0755); err != nil {
|
||||
return fmt.Errorf("创建目录失败: %v", err)
|
||||
}
|
||||
case tar.TypeReg: // 普通文件
|
||||
// 确保父目录存在
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
||||
return fmt.Errorf("创建父目录失败: %v", err)
|
||||
}
|
||||
|
||||
// 创建目标文件
|
||||
file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode))
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建目标文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 复制内容
|
||||
if _, err := io.Copy(file, tarReader); err != nil {
|
||||
file.Close()
|
||||
return fmt.Errorf("复制文件内容失败: %v", err)
|
||||
}
|
||||
file.Close()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package wafupdate
|
||||
import (
|
||||
"SamWaf/binarydist"
|
||||
"SamWaf/global"
|
||||
"SamWaf/utils"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"crypto/sha256"
|
||||
@@ -18,6 +19,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -56,6 +58,7 @@ type Updater struct {
|
||||
CmdName string // Command name is appended to the ApiURL like http://apiurl/CmdName/. This represents one binary.
|
||||
BinURL string // Base URL for full binary downloads.
|
||||
DiffURL string // Base URL for diff downloads.
|
||||
BinGithubURL string // Base URL for full binary downloads.
|
||||
Dir string // Directory to store selfupdate state.
|
||||
ForceCheck bool // Check for update regardless of cktime timestamp
|
||||
CheckTime int // Time in hours before next check
|
||||
@@ -99,13 +102,34 @@ func canUpdate() (err error) {
|
||||
|
||||
// BackgroundRun starts the update check and apply cycle.
|
||||
func (u *Updater) BackgroundRun() error {
|
||||
if err := os.MkdirAll(u.getExecRelativeDir(u.Dir), 0755); err != nil {
|
||||
// fail
|
||||
return err
|
||||
}
|
||||
// check to see if we want to check for updates based on version
|
||||
// and last update time
|
||||
if u.WantUpdate() {
|
||||
|
||||
return u.BackgroundRunWithChannel("official")
|
||||
}
|
||||
func (u *Updater) BackgroundRunWithChannel(channel string) error {
|
||||
if channel == "" || channel == "official" {
|
||||
if err := os.MkdirAll(u.getExecRelativeDir(u.Dir), 0755); err != nil {
|
||||
// fail
|
||||
return err
|
||||
}
|
||||
// check to see if we want to check for updates based on version
|
||||
// and last update time
|
||||
if u.WantUpdate() {
|
||||
if err := canUpdate(); err != nil {
|
||||
// fail
|
||||
return err
|
||||
}
|
||||
|
||||
u.SetUpdateTime()
|
||||
|
||||
if err := u.Update(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else if channel == "github" {
|
||||
if err := os.MkdirAll(u.getExecRelativeDir(u.Dir), 0755); err != nil {
|
||||
// fail
|
||||
return err
|
||||
}
|
||||
if err := canUpdate(); err != nil {
|
||||
// fail
|
||||
return err
|
||||
@@ -113,10 +137,11 @@ func (u *Updater) BackgroundRun() error {
|
||||
|
||||
u.SetUpdateTime()
|
||||
|
||||
if err := u.Update(); err != nil {
|
||||
if err := u.UpdateWithChannel(channel); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -154,9 +179,8 @@ func (u *Updater) ClearUpdateState() {
|
||||
path := u.getExecRelativeDir(u.Dir + upcktimePath)
|
||||
os.Remove(path)
|
||||
}
|
||||
func (u *Updater) UpdateAvailableWithChannel(channel string) (bool, string, string, error) {
|
||||
|
||||
// UpdateAvailable checks if update is available and returns version
|
||||
func (u *Updater) UpdateAvailable() (bool, string, string, error) {
|
||||
path, err := os.Executable()
|
||||
if err != nil {
|
||||
return false, "", "", err
|
||||
@@ -167,21 +191,46 @@ func (u *Updater) UpdateAvailable() (bool, string, string, error) {
|
||||
}
|
||||
defer old.Close()
|
||||
|
||||
err = u.fetchInfo()
|
||||
if err != nil {
|
||||
return false, "", "", err
|
||||
}
|
||||
// 比较版本号
|
||||
cmp := semver.Compare(u.Info.Version, u.CurrentVersion)
|
||||
// 如果更新的版本大于当前版本,返回 true,表示有可用的更新
|
||||
if cmp > 0 {
|
||||
return true, u.Info.Version, u.Info.Desc, nil
|
||||
//渠道选择
|
||||
if channel == "" || channel == "official" {
|
||||
err = u.fetchInfo()
|
||||
if err != nil {
|
||||
return false, "", "", err
|
||||
}
|
||||
// 比较版本号
|
||||
cmp := semver.Compare(u.Info.Version, u.CurrentVersion)
|
||||
// 如果更新的版本大于当前版本,返回 true,表示有可用的更新
|
||||
if cmp > 0 {
|
||||
return true, u.Info.Version, u.Info.Desc, nil
|
||||
} else {
|
||||
// 否则,返回 false,表示没有可用的更新
|
||||
return false, "", "", nil
|
||||
}
|
||||
} else if channel == "github" {
|
||||
err = u.fetchInfoGithub()
|
||||
if err != nil {
|
||||
return false, "", "", err
|
||||
}
|
||||
// 比较版本号
|
||||
cmp := semver.Compare(u.Info.Version, u.CurrentVersion)
|
||||
// 如果更新的版本大于当前版本,返回 true,表示有可用的更新
|
||||
if cmp > 0 {
|
||||
return true, u.Info.Version, u.Info.Desc, nil
|
||||
} else {
|
||||
// 否则,返回 false,表示没有可用的更新
|
||||
return false, "", "", nil
|
||||
}
|
||||
} else {
|
||||
// 否则,返回 false,表示没有可用的更新
|
||||
return false, "", "", nil
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateAvailable 默认从官方下载
|
||||
func (u *Updater) UpdateAvailable() (bool, string, string, error) {
|
||||
return u.UpdateAvailableWithChannel("official")
|
||||
}
|
||||
|
||||
// Update initiates the self update process
|
||||
func (u *Updater) Update() error {
|
||||
path, err := os.Executable()
|
||||
@@ -242,6 +291,149 @@ func (u *Updater) Update() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateWithChannel 通过渠道来检测
|
||||
func (u *Updater) UpdateWithChannel(channel string) error {
|
||||
path, err := os.Executable()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resolvedPath, err := filepath.EvalSymlinks(path); err == nil {
|
||||
path = resolvedPath
|
||||
}
|
||||
|
||||
if channel == "" || channel == "official" {
|
||||
// go fetch latest updates manifest
|
||||
err = u.fetchInfo()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 检测是新版本才更新,否则不更新
|
||||
cmp := semver.Compare(u.Info.Version, u.CurrentVersion)
|
||||
if cmp <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
old, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer old.Close()
|
||||
|
||||
// if patch failed grab the full new bin
|
||||
bin, err := u.fetchAndVerifyFullBin()
|
||||
if err != nil {
|
||||
if err == ErrHashMismatch {
|
||||
log.Println("update: hash mismatch from full binary")
|
||||
} else {
|
||||
log.Println("update: fetching full binary,", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// close the old binary before installing because on windows
|
||||
// it can't be renamed if a handle to the file is still open
|
||||
old.Close()
|
||||
|
||||
err, errRecover := fromStream(bytes.NewBuffer(bin))
|
||||
if errRecover != nil {
|
||||
return fmt.Errorf("update and recovery errors: %q %q", err, errRecover)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update was successful, run func if set
|
||||
if u.OnSuccessfulUpdate != nil {
|
||||
u.OnSuccessfulUpdate()
|
||||
}
|
||||
} else if channel == "github" {
|
||||
err = u.fetchInfoGithub()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 从 GitHub 下载资源
|
||||
r, err := u.fetch(u.BinGithubURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
// 创建临时目录用于解压文件
|
||||
tempDir, err := ioutil.TempDir("", "samwaf_beta_update"+u.Info.Version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// 保存下载的文件到临时文件
|
||||
tempFile := filepath.Join(tempDir, "download")
|
||||
out, err := os.Create(tempFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = io.Copy(out, r)
|
||||
out.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 根据文件类型和平台解压并获取正确的可执行文件
|
||||
var binPath string
|
||||
|
||||
if strings.HasSuffix(u.BinGithubURL, ".exe") {
|
||||
//使用win7内核
|
||||
binPath = tempFile
|
||||
} else if strings.HasSuffix(u.BinGithubURL, ".zip") {
|
||||
// 情况 1: 从 ZIP 中提取 SamWaf64.exe
|
||||
err = utils.Unzip(tempFile, tempDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
binPath = filepath.Join(tempDir, "SamWaf64.exe")
|
||||
} else if strings.HasSuffix(u.BinGithubURL, ".tar.gz") {
|
||||
// 处理 Linux 平台的 tar.gz 文件
|
||||
err = utils.ExtractTarGz(tempFile, tempDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if strings.Contains(u.BinGithubURL, "Linux_x86_64") {
|
||||
// 情况 2: 从 tar.gz 中提取 SamWafLinux64
|
||||
binPath = filepath.Join(tempDir, "SamWafLinux64")
|
||||
} else if strings.Contains(u.BinGithubURL, "Linux_arm64") {
|
||||
// 情况 3: 从 tar.gz 中提取 SamWafLinuxArm64
|
||||
binPath = filepath.Join(tempDir, "SamWafLinuxArm64")
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否找到了可执行文件
|
||||
if binPath == "" {
|
||||
return errors.New("无法找到适合当前平台的可执行文件")
|
||||
}
|
||||
fileBytes, err := ioutil.ReadFile(binPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err, errRecover := fromStream(bytes.NewBuffer(fileBytes))
|
||||
if errRecover != nil {
|
||||
return fmt.Errorf("update and recovery errors: %q %q", err, errRecover)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update was successful, run func if set
|
||||
if u.OnSuccessfulUpdate != nil {
|
||||
u.OnSuccessfulUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func fromStream(updateWith io.Reader) (err error, errRecover error) {
|
||||
updatePath, err := os.Executable()
|
||||
if err != nil {
|
||||
@@ -322,6 +514,72 @@ func (u *Updater) fetchInfo() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// fetchInfoGithub 从GitHub获取最新beta版本信息
|
||||
func (u *Updater) fetchInfoGithub() error {
|
||||
r, err := u.fetch(global.GUPDATE_GITHUB_VERSION_URL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
// 解析GitHub API返回的JSON数据
|
||||
var githubRelease struct {
|
||||
TagName string `json:"tag_name"`
|
||||
Name string `json:"name"`
|
||||
Body string `json:"body"`
|
||||
Assets []struct {
|
||||
Name string `json:"name"`
|
||||
BrowserDownloadURL string `json:"browser_download_url"`
|
||||
Size int64 `json:"size"`
|
||||
ContentType string `json:"content_type"`
|
||||
} `json:"assets"`
|
||||
}
|
||||
|
||||
err = json.NewDecoder(r).Decode(&githubRelease)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
//判断tagname 是否包含beta
|
||||
if !strings.Contains(githubRelease.TagName, "beta") {
|
||||
//return errors.New("not beta version")
|
||||
}
|
||||
// 查找适合当前平台的资源
|
||||
var downloadURL string
|
||||
|
||||
// 根据平台选择合适的下载文件
|
||||
platformSuffix := ""
|
||||
utils.IsSupportedWindows7Version()
|
||||
switch plat {
|
||||
case "windows-amd64":
|
||||
if utils.IsSupportedWindows7Version() {
|
||||
platformSuffix = "SamWaf64ForWin7Win8Win2008"
|
||||
} else {
|
||||
platformSuffix = "Windows_x86_64"
|
||||
}
|
||||
case "linux-amd64":
|
||||
platformSuffix = "Linux_x86_64"
|
||||
case "linux-arm64":
|
||||
platformSuffix = "Linux_arm64"
|
||||
}
|
||||
|
||||
// 查找匹配当前平台的资源
|
||||
for _, asset := range githubRelease.Assets {
|
||||
if strings.Contains(asset.Name, platformSuffix) {
|
||||
downloadURL = asset.BrowserDownloadURL
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有找到适合的资源,返回错误
|
||||
if downloadURL == "" {
|
||||
return errors.New("no suitable release asset found for this platform")
|
||||
}
|
||||
u.BinGithubURL = downloadURL
|
||||
u.Info.Version = githubRelease.TagName
|
||||
u.Info.Desc = githubRelease.Body
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *Updater) fetchAndVerifyPatch(old io.Reader) ([]byte, error) {
|
||||
bin, err := u.fetchAndApplyPatch(old)
|
||||
if err != nil {
|
||||
@@ -370,10 +628,8 @@ func (u *Updater) fetchBin() ([]byte, error) {
|
||||
if _, err = io.Copy(buf, gz); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (u *Updater) fetch(url string) (io.ReadCloser, error) {
|
||||
if u.Requester == nil {
|
||||
return defaultHTTPRequester.Fetch(url)
|
||||
|
||||
@@ -181,3 +181,35 @@ func TestUpdater_GetHttps(t *testing.T) {
|
||||
println(strbody)
|
||||
}
|
||||
}
|
||||
|
||||
// 测试 fetchInfoGithub 在实际环境中的行为
|
||||
func TestFetchInfoGithubIntegration(t *testing.T) {
|
||||
|
||||
// 创建更新器
|
||||
updater := &Updater{
|
||||
CurrentVersion: "v1.0.0",
|
||||
}
|
||||
|
||||
// 调用被测试的函数
|
||||
err := updater.fetchInfoGithub()
|
||||
if err != nil {
|
||||
t.Fatalf("获取 GitHub 信息失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证结果
|
||||
if updater.BinGithubURL == "" {
|
||||
t.Error("下载 URL 为空")
|
||||
}
|
||||
|
||||
if updater.Info.Version == "" {
|
||||
t.Error("版本为空")
|
||||
}
|
||||
|
||||
if updater.Info.Desc == "" {
|
||||
t.Error("描述为空")
|
||||
}
|
||||
|
||||
t.Logf("获取到的版本: %s", updater.Info.Version)
|
||||
t.Logf("下载 URL: %s", updater.BinGithubURL)
|
||||
t.Logf("版本说明 Desc: %s", updater.Info.Desc)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user