mirror of
https://gitee.com/mindspore/mindarmour.git
synced 2025-12-06 11:59:05 +08:00
Create lm_head_attack
Update the attack experiment and report
Signed-off-by: Yichi <yichi@isrc.iscas.ac.cn>
Update the newest network code
Signed-off-by: Yichi <yichi@isrc.iscas.ac.cn>
Final version: correct the README.md
109 examples/community/lm_head_attack/README.md Normal file
@@ -0,0 +1,109 @@
# LM-Head Parameter-Stealing Algorithm Implementation

## Description

This project uses MindSpore and MindFormers to issue targeted queries against the API of today's most popular large language models, extracting either the embedding dimension or the final weight matrix.

The goal is to reproduce the results of [2403.06634] Stealing Part of a Production Language Model (arxiv.org).
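
The core observation the paper relies on: the next-token logit vector is W @ h(x), so a matrix of logits stacked over many prompts has rank at most the hidden dimension h, which shows up as a sharp drop in the singular values. A minimal, self-contained sketch with toy shapes (not the project code, and not llama_7b's real dimensions):

```python
import numpy as np

# Toy shapes only; a real attack observes X through the model API
# instead of constructing H and W directly.
rng = np.random.default_rng(0)
n_queries, vocab, h = 256, 1000, 64
H = rng.normal(size=(n_queries, h))    # hidden states (unknown to the attacker)
W = rng.normal(size=(vocab, h))        # lm_head weight (the attack target)
X = H @ W.T                            # observed next-token logits

S = np.linalg.svd(X, compute_uv=False)
gaps = np.log(S[:-1]) - np.log(S[1:])  # largest log-gap sits at index h-1
print(int(np.argmax(gaps)) + 1)        # prints 64
```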

## Model Architecture

The parameter-stealing algorithm is reproduced against the open-source "llama_7b" and "gpt_2" configurations from Huawei's MindFormers repository.

## Requirements

MindSpore 2.9 and MindFormers 1.2.0, running on the Ascend hardware platform.

## Script Description

```markdown
│── README.md
│── src
│   │── adapters
│   │   │── local_mindformers_gpt2.py
│   │   │── local_mindformers_llama.py
│   │── dim_svd.py
│   │── svd_plots.py
│   │── metrics.py
│   │── numerics.py
│── outputs                            // demo of attack results
│   │── different_test_of_llama.png
│   │── sigular_values_of_gpt2.png
│   │── log_diff_of_gpt2.png
│   └── log_diff_of_llama.png
│── train.py                           // single-attack script
└── multi_train.py                     // multi-attack script
```

## Attack Experiments

Preliminary notes: among the open-source models used in the paper, MindFormers does not provide a Llama_65b configuration, so for LLaMA only the 7b results are reproduced; the gpt_2 configuration does not ship its weight-matrix parameters, so for GPT-2 only the recovered hidden dimension is compared.

- **Single query-count experiment** (using Llama_7b (fp16) as an example):

Attack the Llama_7b (fp16) model, attempting to recover its weight matrix and hidden dimension.

```python
from src.adapters.local_mindformers_llama import MindFormersLlamaAdapter
# For the GPT model: from src.adapters.local_mindformers_gpt2 import MindFormersGptAdapter
from train import train

adapter = MindFormersLlamaAdapter("llama_7b")
train(adapter)  # uses the default arguments: num_queries=5000, prompt_len=16, vocab_subset=None, seed=0, batch_size=8
```

The attack results and a brief report are saved in the `outputs` folder.

- **Multiple query-count experiment** (using Llama_7b (fp16) as an example):

Run from the command line:

```bash
python3 multi_train.py --model_name "llama_7b" --prompt_len 128 --vocab_subset None --seed 42 --batch_size 16 --num_queries_list 1024 2048 4000 5000 6000
```

The singular values from each attack are saved in the `temp_results` folder; the final attack results and a brief report are saved in the `outputs` folder.

## Experimental Results

#### Single-Query Experiment

**Attack results for the llama_7b configuration**:

```text
prompt_len:16
vocab_used:32000
h_est:4093
rms_aligned:0.0002823424886589567
```



The recovered hidden size is 4093, versus the reference value of 4096 $\pm$ 2 from the paper; the error is about 0.07% ($|4093 - 4096| / 4096 \approx 0.073\%$), which meets the requirements for the attack.

The RMS error between the recovered weight matrix and the true configuration is about $2 \times 10^{-4}$, close to the paper's average error level of $10^{-4}$ to $10^{-5}$, so the attack experiment completes as expected.

**Attack results for the gpt2 configuration**:

```text
h_est:761
```



This run recovers h = 761; the paper reports 757 $\pm$ 1, and the true value is 768, so this reproduction is consistent with the reference result.

#### Multi-Query Experiment

```text
.......
----- Summary -----
=== Attack Experiment Evaluation Report ===
model:llama_7b
num_queries:5000
prompt_len:16
vocab_used:32000
h_est:4093
rms_aligned:0.0003705834606818188
----- Summary -----
=== Attack Experiment Evaluation Report ===
model:llama_7b
num_queries:6000
prompt_len:16
vocab_used:32000
h_est:4093
rms_aligned:0.0002823424886589567
```



Postscript: you can attack the model repeatedly with larger num_queries values and different seeds to improve the precision of the results.
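
For example, a hypothetical shell loop over a few seeds (the seed values are illustrative; the flags are the ones multi_train.py defines):

```bash
# Repeat the attack under several seeds; reports accumulate in outputs/summary_report.txt.
for s in 0 1 2 3; do
    python3 multi_train.py --model_name "llama_7b" --seed $s --num_queries_list 6000
done
```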
101 examples/community/lm_head_attack/multi_train.py Normal file
@@ -0,0 +1,101 @@
# Full pipeline: reproduce the recovered hidden dimension h and the RMS of
# W_aligned under different query counts and seeds.
import os

os.environ["GLOG_v"] = "3"
from src.dim_svd import recover_hidden_matrix
from src.metrics import rms_error, summary_report
from src.svd_plots import plot_singular_values
from src.numerics import center_matrix
from src.adapters.local_mindformers_llama import MindFormersLlamaAdapter
import numpy as np
from train import collect_full_logits_matrix, align_and_eval
import matplotlib.pyplot as plt
import pandas as pd
import argparse


def main(model_name, num_queries, prompt_len, vocab_subset, seed, batch_size):
    adapter = MindFormersLlamaAdapter(model_name)
    X, subset_idx = collect_full_logits_matrix(
        adapter,
        num_queries=num_queries,
        prompt_len=prompt_len,
        vocab_subset=vocab_subset,
        seed=seed,
        batch_size=batch_size,
    )
    # Center the logit matrix (per-token mean over queries) for numerical stability
    Xc = center_matrix(X)

    # Estimate the hidden dimension h
    h_est, S, W_hat = recover_hidden_matrix(Xc, gap_threshold=2.0)
    print(f"[RESULT] Estimated hidden dimension h = {h_est}")

    # Read the true W (if the local model exposes it); check for None
    # before converting, since get_W_true() may return None.
    W_true_full = adapter.get_W_true()
    if W_true_full is None:
        print("[WARN] Cannot read the true W; skipping alignment evaluation.")
        plot_singular_values(S, h_est)
        return h_est, S, None
    W_true_full = W_true_full.astype(np.float64)

    # If a vocabulary subset was used, evaluate only on the matching rows
    if subset_idx is not None:
        W_true = W_true_full[subset_idx, :]
    else:
        W_true = W_true_full

    rms, W_aligned, G = align_and_eval(W_hat, W_true)
    print(f"[RESULT] RMS after alignment: {rms:.6e}")

    # Brief report
    report = {
        "model": model_name,
        "num_queries": num_queries,
        "prompt_len": prompt_len,
        "vocab_used": X.shape[1],
        "h_est": h_est,
        "rms_aligned": rms,
    }
    print()
    print("----- Summary -----")
    print(summary_report(report))
    os.makedirs("./outputs", exist_ok=True)
    with open("./outputs/summary_report.txt", "a") as f:
        f.write("----- Summary -----\n")
        f.write(summary_report(report) + "\n")
    return h_est, S, rms


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="llama_7b")
    parser.add_argument("--prompt_len", type=int, default=16)
    # Accept "None" on the command line (as in the README example) or an integer
    parser.add_argument(
        "--vocab_subset",
        type=lambda s: None if s == "None" else int(s),
        default=None,
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument(
        "--num_queries_list",
        type=int,
        nargs="+",
        default=[1024, 2048, 4000, 5000, 6000],
    )
    args = parser.parse_args()

    h_list, rms_list = [], []
    os.makedirs("./temp_results", exist_ok=True)

    for num in args.num_queries_list:
        h_est, S, rms = main(
            args.model_name, num, args.prompt_len, args.vocab_subset, args.seed, args.batch_size
        )
        h_list.append(h_est)
        rms_list.append(rms)
        pd.DataFrame(S).to_csv(f"./temp_results/{num}.csv")

    for num in args.num_queries_list:
        S = pd.read_csv(f"./temp_results/{num}.csv", index_col=0)
        plt.plot(S[:-1], label=f"{num}")
    plt.yscale("log")
    plt.legend()
    plt.savefig("./outputs/singular_values_comparison.png")
BIN examples/community/lm_head_attack/outputs/different_test_of_llama.png Normal file
Binary file not shown. After Width: | Height: | Size: 31 KiB
BIN examples/community/lm_head_attack/outputs/log_diff_of_gpt2.png Normal file
Binary file not shown. After Width: | Height: | Size: 29 KiB
BIN examples/community/lm_head_attack/outputs/log_diff_of_llama.png Normal file
Binary file not shown. After Width: | Height: | Size: 36 KiB
BIN examples/community/lm_head_attack/outputs/sigular_values_of_gpt2.png Normal file
Binary file not shown. After Width: | Height: | Size: 25 KiB
82 examples/community/lm_head_attack/src/adapters/local_mindformers_gpt2.py Normal file
@@ -0,0 +1,82 @@
from typing import Dict, List, Optional
import numpy as np

try:
    from mindformers import AutoTokenizer, AutoModel
    import mindspore as ms

    MINDFORMERS_AVAILABLE = True
except Exception:
    MINDFORMERS_AVAILABLE = False


class MindFormersGptAdapter:
    """
    Adapter for a local GPT model (MindFormers).
    - Provides the full logits for the next token.
    Before use, make sure the model and tokenizer are available (the name
    must match your weights).
    """

    def __init__(self, model_name: str = "gpt2_small", device_target: str = "Ascend"):
        if not MINDFORMERS_AVAILABLE:
            raise ImportError("mindformers / mindspore is not installed or unavailable.")
        ms.set_context(mode=ms.GRAPH_MODE, device_target=device_target)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name, dtype=ms.float32)
        self._hidden_size = getattr(getattr(self.model, "config", None), "hidden_size", None)

    @property
    def vocab_size(self) -> int:
        return self.tokenizer.vocab_size

    @property
    def hidden_size(self) -> Optional[int]:
        return self._hidden_size

    def encode(self, text: str) -> List[int]:
        return self.tokenizer(text)["input_ids"]

    def batch_encode(self, texts: List[str]) -> List[List[int]]:
        return [self.encode(t) for t in texts]

    def decode(self, token_ids: List[int]) -> str:
        return self.tokenizer.decode(token_ids)

    def batch_next_token_logits(
        self, batch_input_ids: List[List[int]], pad_token_id: int = 0
    ) -> np.ndarray:
        """
        Get the next-token logits for a batch of inputs.
        batch_input_ids: List[List[int]], a batch of token-id sequences
        Returns: np.ndarray of shape [batch_size, vocab_size]
        """
        max_len = max(len(ids) for ids in batch_input_ids)
        padded_ids = [ids + [pad_token_id] * (max_len - len(ids)) for ids in batch_input_ids]
        input_ids_ms = ms.Tensor(padded_ids, dtype=ms.int64)
        outputs = self.model.generate(
            input_ids=input_ids_ms,
            max_length=max_len + 1,
            return_dict_in_generate=True,
            output_scores=True,
        )
        step_logits = outputs.scores[0]  # shape: [batch_size, vocab_size]
        return step_logits.astype(np.float64)

    def next_token_logits(self, input_ids: List[int]) -> np.ndarray:
        """
        Get the next-token logits for a single input.
        """
        logits_batch = self.batch_next_token_logits(
            [input_ids], pad_token_id=self.tokenizer.pad_token_id
        )
        return logits_batch[0]

    def get_W_true(self) -> Optional[np.ndarray]:
        """
        Get the output-head weights (usually tied to the embedding weights).
        """
        lm_head = getattr(self.model, "lm_head", None)
        if lm_head is None or not hasattr(lm_head, "weight"):
            return None
        W = lm_head.weight.asnumpy()
        return W
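
A hypothetical smoke test for the adapter (it only runs in an environment where MindFormers and the gpt2_small weights are available; the token ids are arbitrary illustrative values):

```python
# Hypothetical usage: shorter sequences are right-padded, and one forward
# pass returns the next-token logits for the whole batch.
adapter = MindFormersGptAdapter("gpt2_small")
logits = adapter.batch_next_token_logits([[1, 2, 3], [4, 5]], pad_token_id=0)
print(logits.shape)  # (2, vocab_size)
```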
83 examples/community/lm_head_attack/src/adapters/local_mindformers_llama.py Normal file
@@ -0,0 +1,83 @@
from typing import Dict, List, Optional
import numpy as np

try:
    from mindformers import AutoTokenizer, AutoModel
    import mindspore as ms

    MINDFORMERS_AVAILABLE = True
except Exception:
    MINDFORMERS_AVAILABLE = False


class MindFormersLlamaAdapter:
    """
    Adapter for a local LLaMA model (MindFormers).
    - Provides the full logits for the next token.
    - Provides simulated top-k logprobs based on logits + bias.
    Before use, make sure the model and tokenizer are available (the name
    must match your weights).
    """

    def __init__(self, model_name: str = "llama_7b", device_target: str = "Ascend"):
        if not MINDFORMERS_AVAILABLE:
            raise ImportError("mindformers / mindspore is not installed or unavailable.")
        ms.set_context(mode=ms.GRAPH_MODE, device_target=device_target)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name, dtype=ms.float16)
        self._hidden_size = getattr(getattr(self.model, "config", None), "hidden_size", None)

    @property
    def vocab_size(self) -> int:
        return self.tokenizer.vocab_size

    @property
    def hidden_size(self) -> Optional[int]:
        return self._hidden_size

    def encode(self, text: str) -> List[int]:
        return self.tokenizer(text)["input_ids"]

    def batch_encode(self, texts: List[str]) -> List[List[int]]:
        return [self.encode(t) for t in texts]

    def decode(self, token_ids: List[int]) -> str:
        return self.tokenizer.decode(token_ids)

    def batch_next_token_logits(
        self, batch_input_ids: List[List[int]], pad_token_id: int = 0
    ) -> np.ndarray:
        """
        Get the next-token logits for a batch of inputs.
        batch_input_ids: List[List[int]], a batch of token-id sequences
        Returns: np.ndarray of shape [batch_size, vocab_size]
        """
        max_len = max(len(ids) for ids in batch_input_ids)
        padded_ids = [ids + [pad_token_id] * (max_len - len(ids)) for ids in batch_input_ids]
        input_ids_ms = ms.Tensor(padded_ids, dtype=ms.int64)
        outputs = self.model.generate(
            input_ids=input_ids_ms,
            max_length=max_len + 1,
            return_dict_in_generate=True,
            output_scores=True,
        )
        step_logits = outputs.scores[0]  # shape: [batch_size, vocab_size]
        return step_logits.astype(np.float64)

    def next_token_logits(self, input_ids: List[int]) -> np.ndarray:
        """
        Get the next-token logits for a single input.
        """
        logits_batch = self.batch_next_token_logits(
            [input_ids], pad_token_id=self.tokenizer.pad_token_id
        )
        return logits_batch[0]

    def get_W_true(self) -> Optional[np.ndarray]:
        """
        Get the output-head weights (usually tied to the embedding weights).
        """
        lm_head = getattr(self.model, "lm_head", None)
        if lm_head is None or not hasattr(lm_head, "weight"):
            return None
        W = lm_head.weight.asnumpy()
        return W
26 examples/community/lm_head_attack/src/dim_svd.py Normal file
@@ -0,0 +1,26 @@
import numpy as np
from typing import Tuple


def recover_hidden_matrix(
    logits_matrix: np.ndarray, gap_threshold: float = 2.0
) -> Tuple[int, np.ndarray, np.ndarray]:
    """
    Estimate the hidden dimension h from the logits matrix and recover W_hat.
    Args:
        logits_matrix: shape [num_queries, vocab_size]
        gap_threshold: threshold on the singular-value drop (difference in the log domain)
    Returns:
        h: the estimated hidden dimension
        singular_values: the array of singular values
        W: the recovered matrix
    """
    Q = logits_matrix.T.astype(np.float64)  # transpose to [vocab_size, num_queries]
    U, S, _ = np.linalg.svd(Q, full_matrices=False)
    log_s = np.log(np.abs(S))
    # Drops between consecutive log singular values, excluding the first and last
    gaps = -np.diff(log_s)[1:-1]
    h = int(np.argmax(gaps) + 1)
    if gaps[h - 1] < np.log(gap_threshold):
        print("[WARN] The largest gap is below the threshold; the result may be unstable")
    W_hat = U[:, :h] @ np.diag(S[:h])
    return h, S, W_hat
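
A minimal synthetic usage sketch for recover_hidden_matrix (hypothetical shapes; note that with the gap-index convention above, the estimate on exactly low-rank data can land slightly below the true rank, so treat h as approximate):

```python
import numpy as np
from src.dim_svd import recover_hidden_matrix

# Synthetic rank-64 "logits" of shape [num_queries, vocab_size]
rng = np.random.default_rng(0)
X = rng.normal(size=(512, 64)) @ rng.normal(size=(64, 1200))
h, S, W_hat = recover_hidden_matrix(X, gap_threshold=2.0)
print(h, W_hat.shape)  # approximately 64, with W_hat of shape (1200, h)
```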
27 examples/community/lm_head_attack/src/metrics.py Normal file
@@ -0,0 +1,27 @@
import numpy as np
from typing import Dict


def rms_error(W_true: np.ndarray, W_est: np.ndarray) -> float:
    """
    Compute the root-mean-square error.
    """
    return np.sqrt(np.mean((W_true - W_est) ** 2))


def bit_agreement(logits_true: np.ndarray, logits_est: np.ndarray) -> float:
    """
    Estimate the bit precision of the logits (assumes they are normalized).
    """
    diff = np.abs(logits_true - logits_est)
    return -np.log2(np.mean(diff) + 1e-12)


def summary_report(results: Dict) -> str:
    """
    Generate a brief evaluation report.
    """
    lines = ["=== Attack Experiment Evaluation Report ==="]
    for k, v in results.items():
        lines.append(f"{k}:{v}")
    return "\n".join(lines)
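
A quick worked check of bit_agreement with hypothetical inputs: logits that differ by $2^{-10}$ on average should report roughly 10 bits.

```python
import numpy as np
from src.metrics import bit_agreement

a = np.zeros(1000)
b = np.full(1000, 2.0 ** -10)  # mean absolute difference is exactly 2^-10
print(bit_agreement(a, b))     # ≈ 10.0
```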
16 examples/community/lm_head_attack/src/numerics.py Normal file
@@ -0,0 +1,16 @@
import numpy as np


def center_matrix(X: np.ndarray) -> np.ndarray:
    """
    Center the matrix by subtracting the column-wise mean (mean over axis 0).
    """
    return X - X.mean(axis=0, keepdims=True)


def normalize_rows(X: np.ndarray) -> np.ndarray:
    """
    Normalize each row to unit norm.
    """
    norms = np.linalg.norm(X, axis=1, keepdims=True) + 1e-12
    return X / norms
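
A tiny illustrative check (hypothetical values) of the centering axis: center_matrix subtracts each column's mean across rows, which for the logit matrix means the per-token mean across queries.

```python
import numpy as np
from src.numerics import center_matrix

X = np.array([[1.0, 10.0],
              [3.0, 30.0]])
print(center_matrix(X))  # [[-1. -10.] [ 1.  10.]]
```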
13 examples/community/lm_head_attack/src/svd_plots.py Normal file
@@ -0,0 +1,13 @@
import numpy as np
import matplotlib.pyplot as plt


def plot_singular_values(S: np.ndarray, h_est: int):
    plt.figure()
    plt.semilogy(S, marker="o")
    plt.axvline(h_est - 1, color="r", linestyle="--", label=f"estimated h={h_est}")
    plt.xlabel("Sorted singular value index")
    plt.ylabel("Magnitude (log)")
    plt.legend()
    plt.title("Singular Values Distribution")
    plt.show()
107 examples/community/lm_head_attack/train.py Normal file
@@ -0,0 +1,107 @@
import os

os.environ["GLOG_v"] = "3"
from src.dim_svd import recover_hidden_matrix
from src.metrics import rms_error, summary_report
from src.svd_plots import plot_singular_values
from src.numerics import center_matrix
import numpy as np


def collect_full_logits_matrix(
    adapter,
    num_queries: int,
    prompt_len: int,
    vocab_subset: int = None,
    seed: int = 0,
    batch_size: int = 1,
):
    """
    Collect the full logits at the last position over many queries and stack
    them into a matrix.
    Returns:
        X: [num_queries, vocab_or_subset]
        subset_indices: Optional[np.ndarray]
    """
    rng = np.random.default_rng(seed)
    vsize = adapter.tokenizer.vocab_size

    # If a vocabulary subset was requested
    subset_idx = None
    if vocab_subset is not None:
        subset_idx = rng.choice(vsize, size=vocab_subset, replace=False)

    all_logits = []

    for start in range(0, num_queries, batch_size):
        current_bs = min(batch_size, num_queries - start)

        # Generate random token-id sequences
        batch_prompts = rng.integers(
            low=0, high=vsize, size=(current_bs, prompt_len), endpoint=False
        ).tolist()

        # Fetch the logits in batches
        logits_batch = adapter.batch_next_token_logits(
            batch_prompts, pad_token_id=adapter.tokenizer.pad_token_id
        )  # shape: [batch, vocab]

        if subset_idx is not None:
            logits_batch = logits_batch[:, subset_idx]

        all_logits.append(logits_batch)

    X = np.vstack(all_logits)  # shape: [num_queries, vocab or subset]
    return X, subset_idx


def align_and_eval(W_hat: np.ndarray, W_true: np.ndarray):
    """
    Align on the right via least squares (solve G such that W_hat G ≈ W_true)
    and compute the RMS error.
    """
    # Least-squares solve for G, all columns at once: lstsq(W_hat, W_true)
    G, *_ = np.linalg.lstsq(W_hat.astype(np.float64), W_true.astype(np.float64), rcond=None)
    W_aligned = (W_hat.astype(np.float64) @ G).astype(np.float64)
    err = rms_error(W_true, W_aligned)
    return err, W_aligned, G


def train(adapter, num_queries=5000, prompt_len=16, vocab_subset=None, seed=0, batch_size=8):
    import matplotlib.pyplot as plt

    X, subset_idx = collect_full_logits_matrix(
        adapter,
        num_queries=num_queries,
        prompt_len=prompt_len,
        vocab_subset=vocab_subset,
        seed=seed,
        batch_size=batch_size,
    )
    Xc = center_matrix(X)
    Q = Xc.T.astype(np.float64)
    U, S, _ = np.linalg.svd(Q, full_matrices=False)
    log_s = np.log(np.abs(S))
    gaps = log_s[:-1] - log_s[1:]

    # Create the output directory before saving the plot
    os.makedirs("./outputs", exist_ok=True)
    plt.plot(gaps[:-1])
    plt.yscale("log")
    plt.savefig("./outputs/gap_plot.png")
    h_est = np.argmax(gaps[1:-1]) + 1
    W_hat = U[:, :h_est] @ np.diag(S[:h_est])
    W_true_full = adapter.get_W_true().astype(np.float64)
    err, W_aligned, G = align_and_eval(W_hat, W_true_full)
    # Brief report
    report = {
        "num_queries": num_queries,
        "prompt_len": prompt_len,
        "vocab_used": X.shape[1],
        "h_est": h_est,
        "rms_aligned": err,
    }
    print()
    print("----- Summary -----")
    print(summary_report(report))
    with open("./outputs/summary_report.txt", "w") as f:
        f.write("----- Summary -----\n")
        f.write(summary_report(report))
    print("-------------------")