feat:beta update

#261
This commit is contained in:
samwaf
2025-05-01 21:45:01 +08:00
parent f2cbd5c183
commit 28004c18d2
5 changed files with 506 additions and 42 deletions

View File

@@ -120,11 +120,7 @@ func (w *WafSysInfoApi) CheckVersionApi(c *gin.Context) {
CmdName: "samwaf_update", // The app name which is appended to the ApiURL to look for an update
//ForceCheck: true, // For this example, always check for an update unless the version is "dev"
}
available, newVer, desc, err := updater.UpdateAvailable()
if err != nil {
response.FailWithMessage("未发现新文件", c)
return
}
available, newVer, desc, _ := updater.UpdateAvailable()
if available {
global.GWAF_RUNTIME_NEW_VERSION = newVer
global.GWAF_RUNTIME_NEW_VERSION_DESC = desc
@@ -137,15 +133,30 @@ func (w *WafSysInfoApi) CheckVersionApi(c *gin.Context) {
VersionDesc: desc,
}, "有新版本", c)
} else {
response.FailWithMessage("没有最新版本", c)
return
available, newVer, desc, _ = updater.UpdateAvailableWithChannel("github")
if available {
global.GWAF_RUNTIME_NEW_VERSION = newVer
global.GWAF_RUNTIME_NEW_VERSION_DESC = desc
response.OkWithDetailed(model.VersionInfo{
Version: global.GWAF_RELEASE_VERSION,
VersionName: global.GWAF_RELEASE_VERSION_NAME,
VersionRelease: global.GWAF_RELEASE,
NeedUpdate: true,
VersionNew: newVer,
VersionDesc: desc,
}, "有新版本(测试版)", c)
} else {
response.FailWithMessage("没有最新版本", c)
return
}
}
}
// 去升级
func (w *WafSysInfoApi) UpdateApi(c *gin.Context) {
// 获取请求中的 channel 参数
channel := c.Query("channel")
if global.GWAF_RUNTIME_IS_UPDATETING == true {
response.OkWithMessage("正在升级中...请在消息等待结果", c)
return
@@ -180,18 +191,34 @@ func (w *WafSysInfoApi) UpdateApi(c *gin.Context) {
}
go func() {
// try to update
err := updater.BackgroundRun()
if err != nil {
if channel != "" {
err := updater.BackgroundRunWithChannel(channel)
if err != nil {
global.GWAF_RUNTIME_IS_UPDATETING = false
//发送websocket 推送消息
global.GQEQUE_MESSAGE_DB.Enqueue(innerbean.UpdateResultMessageInfo{
BaseMessageInfo: innerbean.BaseMessageInfo{OperaType: "升级结果", Server: global.GWAF_CUSTOM_SERVER_NAME},
Msg: "升级错误",
Success: "False",
})
zlog.Info("Failed to update app:", err)
global.GWAF_RUNTIME_IS_UPDATETING = false
//发送websocket 推送消息
global.GQEQUE_MESSAGE_DB.Enqueue(innerbean.UpdateResultMessageInfo{
BaseMessageInfo: innerbean.BaseMessageInfo{OperaType: "升级结果", Server: global.GWAF_CUSTOM_SERVER_NAME},
Msg: "升级错误",
Success: "False",
})
zlog.Info("Failed to update app:", err)
}
} else {
err := updater.BackgroundRun()
if err != nil {
global.GWAF_RUNTIME_IS_UPDATETING = false
//发送websocket 推送消息
global.GQEQUE_MESSAGE_DB.Enqueue(innerbean.UpdateResultMessageInfo{
BaseMessageInfo: innerbean.BaseMessageInfo{OperaType: "升级结果", Server: global.GWAF_CUSTOM_SERVER_NAME},
Msg: "升级错误",
Success: "False",
})
zlog.Info("Failed to update app:", err)
}
}
}()
response.OkWithMessage("已发起升级,等待通知结果", c)
}

View File

@@ -132,9 +132,9 @@ var (
GWebSocket *gwebsocket.WebSocketOnline
//升级相关
GUPDATE_VERSION_URL string = "https://update.samwaf.com/" //
GWAF_SNOWFLAKE_GEN *wafsnowflake.Snowflake //雪花算法
GUPDATE_VERSION_URL string = "https://update.samwaf.com/" // 官方下载
GUPDATE_GITHUB_VERSION_URL string = "https://api.github.com/repos/samwafgo/samwaf/releases/latest" //gitHub
GWAF_SNOWFLAKE_GEN *wafsnowflake.Snowflake //雪花算法
//任务开关信息
GWAF_SWITCH_TASK_COUNTER bool

149
utils/archiveutil.go Normal file
View File

@@ -0,0 +1,149 @@
package utils
import (
"archive/tar"
"archive/zip"
"compress/gzip"
"fmt"
"io"
"os"
"path/filepath"
"strings"
)
// Unzip 解压缩zip文件到指定目录
// zipFile: zip文件路径
// destDir: 解压目标目录
func Unzip(zipFile, destDir string) error {
// 打开zip文件
reader, err := zip.OpenReader(zipFile)
if err != nil {
return fmt.Errorf("打开zip文件失败: %v", err)
}
defer reader.Close()
// 确保目标目录存在
if err := os.MkdirAll(destDir, 0755); err != nil {
return fmt.Errorf("创建目标目录失败: %v", err)
}
// 遍历zip文件中的所有文件和目录
for _, file := range reader.File {
// 构建目标路径
path := filepath.Join(destDir, file.Name)
// 检查路径是否在目标目录内防止zip slip漏洞
if !strings.HasPrefix(path, filepath.Clean(destDir)+string(os.PathSeparator)) {
return fmt.Errorf("非法的文件路径: %s", file.Name)
}
// 如果是目录,创建它
if file.FileInfo().IsDir() {
if err := os.MkdirAll(path, file.Mode()); err != nil {
return fmt.Errorf("创建目录失败: %v", err)
}
continue
}
// 确保父目录存在
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
return fmt.Errorf("创建父目录失败: %v", err)
}
// 创建目标文件
destFile, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, file.Mode())
if err != nil {
return fmt.Errorf("创建目标文件失败: %v", err)
}
// 打开源文件
srcFile, err := file.Open()
if err != nil {
destFile.Close()
return fmt.Errorf("打开源文件失败: %v", err)
}
// 复制内容
_, err = io.Copy(destFile, srcFile)
srcFile.Close()
destFile.Close()
if err != nil {
return fmt.Errorf("复制文件内容失败: %v", err)
}
}
return nil
}
// ExtractTarGz 解压缩tar.gz文件到指定目录
// tarGzFile: tar.gz文件路径
// destDir: 解压目标目录
func ExtractTarGz(tarGzFile, destDir string) error {
// 打开tar.gz文件
file, err := os.Open(tarGzFile)
if err != nil {
return fmt.Errorf("打开tar.gz文件失败: %v", err)
}
defer file.Close()
// 创建gzip读取器
gzipReader, err := gzip.NewReader(file)
if err != nil {
return fmt.Errorf("创建gzip读取器失败: %v", err)
}
defer gzipReader.Close()
// 创建tar读取器
tarReader := tar.NewReader(gzipReader)
// 确保目标目录存在
if err := os.MkdirAll(destDir, 0755); err != nil {
return fmt.Errorf("创建目标目录失败: %v", err)
}
// 遍历tar文件中的所有文件和目录
for {
header, err := tarReader.Next()
if err == io.EOF {
break // 文件结束
}
if err != nil {
return fmt.Errorf("读取tar文件失败: %v", err)
}
// 构建目标路径
path := filepath.Join(destDir, header.Name)
// 检查路径是否在目标目录内防止zip slip漏洞
if !strings.HasPrefix(path, filepath.Clean(destDir)+string(os.PathSeparator)) {
return fmt.Errorf("非法的文件路径: %s", header.Name)
}
switch header.Typeflag {
case tar.TypeDir: // 目录
if err := os.MkdirAll(path, 0755); err != nil {
return fmt.Errorf("创建目录失败: %v", err)
}
case tar.TypeReg: // 普通文件
// 确保父目录存在
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
return fmt.Errorf("创建父目录失败: %v", err)
}
// 创建目标文件
file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode))
if err != nil {
return fmt.Errorf("创建目标文件失败: %v", err)
}
// 复制内容
if _, err := io.Copy(file, tarReader); err != nil {
file.Close()
return fmt.Errorf("复制文件内容失败: %v", err)
}
file.Close()
}
}
return nil
}

View File

@@ -3,6 +3,7 @@ package wafupdate
import (
"SamWaf/binarydist"
"SamWaf/global"
"SamWaf/utils"
"bytes"
"compress/gzip"
"crypto/sha256"
@@ -18,6 +19,7 @@ import (
"os"
"path/filepath"
"runtime"
"strings"
"time"
)
@@ -56,6 +58,7 @@ type Updater struct {
CmdName string // Command name is appended to the ApiURL like http://apiurl/CmdName/. This represents one binary.
BinURL string // Base URL for full binary downloads.
DiffURL string // Base URL for diff downloads.
BinGithubURL string // Base URL for full binary downloads.
Dir string // Directory to store selfupdate state.
ForceCheck bool // Check for update regardless of cktime timestamp
CheckTime int // Time in hours before next check
@@ -99,13 +102,34 @@ func canUpdate() (err error) {
// BackgroundRun starts the update check and apply cycle.
func (u *Updater) BackgroundRun() error {
if err := os.MkdirAll(u.getExecRelativeDir(u.Dir), 0755); err != nil {
// fail
return err
}
// check to see if we want to check for updates based on version
// and last update time
if u.WantUpdate() {
return u.BackgroundRunWithChannel("official")
}
func (u *Updater) BackgroundRunWithChannel(channel string) error {
if channel == "" || channel == "official" {
if err := os.MkdirAll(u.getExecRelativeDir(u.Dir), 0755); err != nil {
// fail
return err
}
// check to see if we want to check for updates based on version
// and last update time
if u.WantUpdate() {
if err := canUpdate(); err != nil {
// fail
return err
}
u.SetUpdateTime()
if err := u.Update(); err != nil {
return err
}
}
} else if channel == "github" {
if err := os.MkdirAll(u.getExecRelativeDir(u.Dir), 0755); err != nil {
// fail
return err
}
if err := canUpdate(); err != nil {
// fail
return err
@@ -113,10 +137,11 @@ func (u *Updater) BackgroundRun() error {
u.SetUpdateTime()
if err := u.Update(); err != nil {
if err := u.UpdateWithChannel(channel); err != nil {
return err
}
}
return nil
}
@@ -154,9 +179,8 @@ func (u *Updater) ClearUpdateState() {
path := u.getExecRelativeDir(u.Dir + upcktimePath)
os.Remove(path)
}
func (u *Updater) UpdateAvailableWithChannel(channel string) (bool, string, string, error) {
// UpdateAvailable checks if update is available and returns version
func (u *Updater) UpdateAvailable() (bool, string, string, error) {
path, err := os.Executable()
if err != nil {
return false, "", "", err
@@ -167,21 +191,46 @@ func (u *Updater) UpdateAvailable() (bool, string, string, error) {
}
defer old.Close()
err = u.fetchInfo()
if err != nil {
return false, "", "", err
}
// 比较版本号
cmp := semver.Compare(u.Info.Version, u.CurrentVersion)
// 如果更新的版本大于当前版本,返回 true表示有可用的更新
if cmp > 0 {
return true, u.Info.Version, u.Info.Desc, nil
//渠道选择
if channel == "" || channel == "official" {
err = u.fetchInfo()
if err != nil {
return false, "", "", err
}
// 比较版本号
cmp := semver.Compare(u.Info.Version, u.CurrentVersion)
// 如果更新的版本大于当前版本,返回 true表示有可用的更新
if cmp > 0 {
return true, u.Info.Version, u.Info.Desc, nil
} else {
// 否则,返回 false表示没有可用的更新
return false, "", "", nil
}
} else if channel == "github" {
err = u.fetchInfoGithub()
if err != nil {
return false, "", "", err
}
// 比较版本号
cmp := semver.Compare(u.Info.Version, u.CurrentVersion)
// 如果更新的版本大于当前版本,返回 true表示有可用的更新
if cmp > 0 {
return true, u.Info.Version, u.Info.Desc, nil
} else {
// 否则,返回 false表示没有可用的更新
return false, "", "", nil
}
} else {
// 否则,返回 false表示没有可用的更新
return false, "", "", nil
}
}
// UpdateAvailable 默认从官方下载
func (u *Updater) UpdateAvailable() (bool, string, string, error) {
return u.UpdateAvailableWithChannel("official")
}
// Update initiates the self update process
func (u *Updater) Update() error {
path, err := os.Executable()
@@ -242,6 +291,149 @@ func (u *Updater) Update() error {
return nil
}
// UpdateWithChannel 通过渠道来检测
func (u *Updater) UpdateWithChannel(channel string) error {
path, err := os.Executable()
if err != nil {
return err
}
if resolvedPath, err := filepath.EvalSymlinks(path); err == nil {
path = resolvedPath
}
if channel == "" || channel == "official" {
// go fetch latest updates manifest
err = u.fetchInfo()
if err != nil {
return err
}
// 检测是新版本才更新,否则不更新
cmp := semver.Compare(u.Info.Version, u.CurrentVersion)
if cmp <= 0 {
return nil
}
old, err := os.Open(path)
if err != nil {
return err
}
defer old.Close()
// if patch failed grab the full new bin
bin, err := u.fetchAndVerifyFullBin()
if err != nil {
if err == ErrHashMismatch {
log.Println("update: hash mismatch from full binary")
} else {
log.Println("update: fetching full binary,", err)
}
return err
}
// close the old binary before installing because on windows
// it can't be renamed if a handle to the file is still open
old.Close()
err, errRecover := fromStream(bytes.NewBuffer(bin))
if errRecover != nil {
return fmt.Errorf("update and recovery errors: %q %q", err, errRecover)
}
if err != nil {
return err
}
// update was successful, run func if set
if u.OnSuccessfulUpdate != nil {
u.OnSuccessfulUpdate()
}
} else if channel == "github" {
err = u.fetchInfoGithub()
if err != nil {
return err
}
// 从 GitHub 下载资源
r, err := u.fetch(u.BinGithubURL)
if err != nil {
return err
}
defer r.Close()
// 创建临时目录用于解压文件
tempDir, err := ioutil.TempDir("", "samwaf_beta_update"+u.Info.Version)
if err != nil {
return err
}
defer os.RemoveAll(tempDir)
// 保存下载的文件到临时文件
tempFile := filepath.Join(tempDir, "download")
out, err := os.Create(tempFile)
if err != nil {
return err
}
_, err = io.Copy(out, r)
out.Close()
if err != nil {
return err
}
// 根据文件类型和平台解压并获取正确的可执行文件
var binPath string
if strings.HasSuffix(u.BinGithubURL, ".exe") {
//使用win7内核
binPath = tempFile
} else if strings.HasSuffix(u.BinGithubURL, ".zip") {
// 情况 1: 从 ZIP 中提取 SamWaf64.exe
err = utils.Unzip(tempFile, tempDir)
if err != nil {
return err
}
binPath = filepath.Join(tempDir, "SamWaf64.exe")
} else if strings.HasSuffix(u.BinGithubURL, ".tar.gz") {
// 处理 Linux 平台的 tar.gz 文件
err = utils.ExtractTarGz(tempFile, tempDir)
if err != nil {
return err
}
if strings.Contains(u.BinGithubURL, "Linux_x86_64") {
// 情况 2: 从 tar.gz 中提取 SamWafLinux64
binPath = filepath.Join(tempDir, "SamWafLinux64")
} else if strings.Contains(u.BinGithubURL, "Linux_arm64") {
// 情况 3: 从 tar.gz 中提取 SamWafLinuxArm64
binPath = filepath.Join(tempDir, "SamWafLinuxArm64")
}
}
// 检查是否找到了可执行文件
if binPath == "" {
return errors.New("无法找到适合当前平台的可执行文件")
}
fileBytes, err := ioutil.ReadFile(binPath)
if err != nil {
return err
}
err, errRecover := fromStream(bytes.NewBuffer(fileBytes))
if errRecover != nil {
return fmt.Errorf("update and recovery errors: %q %q", err, errRecover)
}
if err != nil {
return err
}
// update was successful, run func if set
if u.OnSuccessfulUpdate != nil {
u.OnSuccessfulUpdate()
}
}
return nil
}
func fromStream(updateWith io.Reader) (err error, errRecover error) {
updatePath, err := os.Executable()
if err != nil {
@@ -322,6 +514,72 @@ func (u *Updater) fetchInfo() error {
return nil
}
// fetchInfoGithub 从GitHub获取最新beta版本信息
func (u *Updater) fetchInfoGithub() error {
r, err := u.fetch(global.GUPDATE_GITHUB_VERSION_URL)
if err != nil {
return err
}
defer r.Close()
// 解析GitHub API返回的JSON数据
var githubRelease struct {
TagName string `json:"tag_name"`
Name string `json:"name"`
Body string `json:"body"`
Assets []struct {
Name string `json:"name"`
BrowserDownloadURL string `json:"browser_download_url"`
Size int64 `json:"size"`
ContentType string `json:"content_type"`
} `json:"assets"`
}
err = json.NewDecoder(r).Decode(&githubRelease)
if err != nil {
return err
}
//判断tagname 是否包含beta
if !strings.Contains(githubRelease.TagName, "beta") {
//return errors.New("not beta version")
}
// 查找适合当前平台的资源
var downloadURL string
// 根据平台选择合适的下载文件
platformSuffix := ""
utils.IsSupportedWindows7Version()
switch plat {
case "windows-amd64":
if utils.IsSupportedWindows7Version() {
platformSuffix = "SamWaf64ForWin7Win8Win2008"
} else {
platformSuffix = "Windows_x86_64"
}
case "linux-amd64":
platformSuffix = "Linux_x86_64"
case "linux-arm64":
platformSuffix = "Linux_arm64"
}
// 查找匹配当前平台的资源
for _, asset := range githubRelease.Assets {
if strings.Contains(asset.Name, platformSuffix) {
downloadURL = asset.BrowserDownloadURL
break
}
}
// 如果没有找到适合的资源,返回错误
if downloadURL == "" {
return errors.New("no suitable release asset found for this platform")
}
u.BinGithubURL = downloadURL
u.Info.Version = githubRelease.TagName
u.Info.Desc = githubRelease.Body
return nil
}
func (u *Updater) fetchAndVerifyPatch(old io.Reader) ([]byte, error) {
bin, err := u.fetchAndApplyPatch(old)
if err != nil {
@@ -370,10 +628,8 @@ func (u *Updater) fetchBin() ([]byte, error) {
if _, err = io.Copy(buf, gz); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func (u *Updater) fetch(url string) (io.ReadCloser, error) {
if u.Requester == nil {
return defaultHTTPRequester.Fetch(url)

View File

@@ -181,3 +181,35 @@ func TestUpdater_GetHttps(t *testing.T) {
println(strbody)
}
}
// 测试 fetchInfoGithub 在实际环境中的行为
func TestFetchInfoGithubIntegration(t *testing.T) {
// 创建更新器
updater := &Updater{
CurrentVersion: "v1.0.0",
}
// 调用被测试的函数
err := updater.fetchInfoGithub()
if err != nil {
t.Fatalf("获取 GitHub 信息失败: %v", err)
}
// 验证结果
if updater.BinGithubURL == "" {
t.Error("下载 URL 为空")
}
if updater.Info.Version == "" {
t.Error("版本为空")
}
if updater.Info.Desc == "" {
t.Error("描述为空")
}
t.Logf("获取到的版本: %s", updater.Info.Version)
t.Logf("下载 URL: %s", updater.BinGithubURL)
t.Logf("版本说明 Desc: %s", updater.Info.Desc)
}