diff --git a/examples/community/lm_head_attack/README.md b/examples/community/lm_head_attack/README.md new file mode 100644 index 0000000..c2a739d --- /dev/null +++ b/examples/community/lm_head_attack/README.md @@ -0,0 +1,109 @@ +# LM-Head参数窃取算法实现 + +## 描述 + +本项目是基于MindSpore和Mindformers对于当今最流行的大型语言模型的 API 进行有针对性的查询,进行提取其嵌入维度或其最终权重矩阵。 +目标复现[2403.06634] Stealing Part of a Production Language Model (arxiv.org)结果 + +## 模型结构 + +采用华为Mindformers仓库中开源的"llama_7b"配置和"gpt_2"配置进行复现模拟参数窃取算法 + +## 环境要求 + +Mindspore=2.9, Mindformers=1.2.0, 使用Ascend硬件平台 + +## 脚本说明 + +```markdown +│── readme.md +│── src +│ │── adapters +│ │ │── local_mindformers_gpt2.py +│ │ │── local_mindformers_llama.py +│ │── dim_svd.py +│ │—— svd_plots.py +│ │── metrics.py +│ │── numerics.py +│── ourputs //攻击结果演示 +│ │── different_test_of_llama.png +│ │── sigular_values_of_gpt2.png +│ │── log_diff_of_gpt2.png +│ └── log_diff_of_llama.png +│── train.py //单次攻击脚本 +└── multi_train.py //多次攻击脚本 +``` + +## 攻击实验: + +前言: +对于论文中所使用的开源模型中,由于缺乏Llama_65b配置,Mindformers则未提供所需配置,因此llama仅复现7b结果,gpt_2并未提供权重矩阵参数,因此只比较隐藏维度的复现结果。 + + +- **单查询次数实验**(以Llama_7b(fp16)为例): +对Llama_7b(fp16)模型进行攻击实验,意图复原其权重矩阵和隐藏维度。 + ```python + from src.adapters.local_mindformers_llama import MindFormersLlamaAdapter + # 对于gpt模型:from src.adapters.local_mindformers_gpt2 import MindFormersGptAdapter + from train import train + adapter = MindFormersLlamaAdapter("llama_7b") + train(adapter) # 使用默认参数:num_queries=5000, prompt_len=16, vocab_subset=None, seed=0, batch_size=8 + ``` +攻击结果和简要报告保存在`outputs`文件夹中 + +- **多次查询次数实验**(以Llama_7b(fp16)为例): + 命令行运行 + ```bash + python3 multi_train.py --model_name "llama_7b" --prompt_len 128 --vocab_subset None --seed 42 --batch_size 16 --num_queries_list 1024 2048 4000 5000 6000 + ``` +其中不同攻击得到的sigular values保存在`temp_results`文件夹中, 攻击得到的最终结果和简要报告保存在`outputs`文件夹中 + + +## 实验结果 + +#### 单次查询实验 + +**对于llama_7b配置的攻击结果**: + +```text +prompt_len:16 +vocab_used:32000 +h_est:4093 +rms_aligned:0.0002823424886589567 +``` +![difference between sigular values for llama in log](outputs/log_diff_of_llama.png) +得到的隐藏层大小为4093,与文献中的标准结果4096$\pm$2,误差大小接近在0.07%,满足攻击结果要求 +而权重矩阵的实验值与真实配置的误差大小在2$\times 10^{-4}$,接近文档平均误差大小水平:$10^{-4}$至$10^{-5}$水平,攻击实验合理完成 +**对于gpt2配置的攻击结果**: + +```text +h_est:761 +``` +![difference between sigular values for gpt2 in log](outputs/log_diff_of_gpt2.png) +本次实验得到的h大小为761,文献得到结果为757$\pm$1,真实结果为768,因此可知本次复现结果合理有效 + +#### 多次查询实验 + +```text +....... +----- Summary ----- +=== 攻击实验评估报告 === +model:llama_7b +num_queries:5000 +prompt_len:16 +vocab_used:32000 +h_est:4093 +rms_aligned:0.0003705834606818188 +----- Summary ----- +=== 攻击实验评估报告 === +model:llama_7b +num_queries:6000 +prompt_len:16 +vocab_used:32000 +h_est:4093 +rms_aligned:0.0002823424886589567 +``` +![different results for multi queries](outputs/different_test_of_llama.png) + +后言: +可通过增加num_queries、改变seed的方式多次对模型进行攻击和实验,提高结果精度。 \ No newline at end of file diff --git a/examples/community/lm_head_attack/multi_train.py b/examples/community/lm_head_attack/multi_train.py new file mode 100644 index 0000000..2a3927c --- /dev/null +++ b/examples/community/lm_head_attack/multi_train.py @@ -0,0 +1,101 @@ +# 构建完整代码,复现不同查询次数和不同种子下的h值复现结果和W_aligned对应的rms复现结果 +import os + +os.environ["GLOG_v"] = "3" +from src.dim_svd import recover_hidden_matrix +from src.metrics import rms_error, summary_report +from src.svd_plots import plot_singular_values +from src.numerics import center_matrix +from src.adapters.local_mindformers_llama import MindFormersLlamaAdapter +import numpy as np +from train import collect_full_logits_matrix, align_and_eval +import matplotlib.pyplot as plt +import pandas as pd +import argparse + + +def main(model_name, num_queries, prompt_len, vocab_subset, seed, batch_size): + adapter = MindFormersLlamaAdapter(model_name) + X, subset_idx = collect_full_logits_matrix( + adapter, + num_queries=num_queries, + prompt_len=prompt_len, + vocab_subset=vocab_subset, + seed=seed, + batch_size=batch_size, + ) + # 行中心化以提升数值稳定性 + Xc = center_matrix(X) + + # 提取隐藏维度 h + h_est, S, W_hat = recover_hidden_matrix(Xc, gap_threshold=2.0) + print(f"[RESULT] Estimated hidden dimension h = {h_est}") + + # 读取真实 W(若本地模型可得) + W_true_full = adapter.get_W_true().astype(np.float64) + if W_true_full is None: + print("[WARN] 无法读取真实 W,跳过对齐评估。") + plot_singular_values(S, h_est) + return h_est, S, None + + # 如使用了子词表,仅对相同行做评估 + if subset_idx is not None: + W_true = W_true_full[subset_idx, :] + else: + W_true = W_true_full + + rms, W_aligned, G = align_and_eval(W_hat, W_true) + print(f"[RESULT] RMS after alignment: {rms:.6e}") + + # 简要报告 + report = { + "model": model_name, + "num_queries": num_queries, + "prompt_len": prompt_len, + "vocab_used": X.shape[1], + "h_est": h_est, + "rms_aligned": rms, + } + print() + print("----- Summary -----") + print(summary_report(report)) + os.makedirs("./outputs", exist_ok=True) + with open("./outputs/summary_report.txt", "a") as f: + f.write("----- Summary -----\n") + f.write(summary_report(report) + "\n") + return h_est, S, rms + + +if __name__ == "__main__": + h_list, rms_list = [], [] + parser = argparse.ArgumentParser() + parser.add_argument("--model_name", type=str, default="llama_7b") + parser.add_argument("--prompt_len", type=int, default=16) + parser.add_argument("--vocab_subset", type=str, default=None) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument( + "--num_queries_list", + type=int, + nargs="+", + default=[1024, 2048, 4000, 5000, 6000], + required=True, + ) + args = parser.parse_args() + + h_list, rms_list = [], [] + + for num in args.num_queries_list: + h_est, S, rms = main( + args.model_name, num, args.prompt_len, args.vocab_subset, args.seed, args.batch_size + ) + h_list.append(h_est) + rms_list.append(rms) + pd.DataFrame(S).to_csv(f"./temp_results/{num}.csv") + + for num in args.num_queries_list: + S = pd.read_csv(f"./temp_results/{num}.csv", index_col=0) + plt.plot(S[:-1], label=f"{num}") + plt.yscale("log") + plt.legend() + plt.savefig("./outputs/singular_values_comparison.png") diff --git a/examples/community/lm_head_attack/outputs/different_test_of_llama.png b/examples/community/lm_head_attack/outputs/different_test_of_llama.png new file mode 100644 index 0000000..a5c6cd4 Binary files /dev/null and b/examples/community/lm_head_attack/outputs/different_test_of_llama.png differ diff --git a/examples/community/lm_head_attack/outputs/log_diff_of_gpt2.png b/examples/community/lm_head_attack/outputs/log_diff_of_gpt2.png new file mode 100644 index 0000000..84ffac1 Binary files /dev/null and b/examples/community/lm_head_attack/outputs/log_diff_of_gpt2.png differ diff --git a/examples/community/lm_head_attack/outputs/log_diff_of_llama.png b/examples/community/lm_head_attack/outputs/log_diff_of_llama.png new file mode 100644 index 0000000..a550110 Binary files /dev/null and b/examples/community/lm_head_attack/outputs/log_diff_of_llama.png differ diff --git a/examples/community/lm_head_attack/outputs/sigular_values_of_gpt2.png b/examples/community/lm_head_attack/outputs/sigular_values_of_gpt2.png new file mode 100644 index 0000000..e4fd36e Binary files /dev/null and b/examples/community/lm_head_attack/outputs/sigular_values_of_gpt2.png differ diff --git a/examples/community/lm_head_attack/src/adapters/local_mindformers_gpt2.py b/examples/community/lm_head_attack/src/adapters/local_mindformers_gpt2.py new file mode 100644 index 0000000..04d8ba8 --- /dev/null +++ b/examples/community/lm_head_attack/src/adapters/local_mindformers_gpt2.py @@ -0,0 +1,82 @@ +from typing import Dict, List, Optional +import numpy as np + +try: + from mindformers import AutoTokenizer, AutoModel + import mindspore as ms + + MINDFORMERS_AVAILABLE = True +except Exception: + MINDFORMERS_AVAILABLE = False + + +class MindFormersGptAdapter: + """ + 本地 GPT (MindFormers) 适配器 + - 提供 full logits for next token + 使用前请确保模型与分词器可用(名称需与你的权重一致) + """ + + def __init__(self, model_name: str = "gpt2_small", device_target: str = "Ascend"): + if not MINDFORMERS_AVAILABLE: + raise ImportError("mindformers / mindspore 未安装或不可用。") + ms.set_context(mode=ms.GRAPH_MODE, device_target=device_target) + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.model = AutoModel.from_pretrained(model_name, dtype=ms.float32) + self._hidden_size = getattr(getattr(self.model, "config", None), "hidden_size", None) + + @property + def vocab_size(self) -> int: + return self.tokenizer.vocab_size + + @property + def hidden_size(self) -> Optional[int]: + return self._hidden_size + + def encode(self, text: str) -> List[int]: + return self.tokenizer(text)["input_ids"] + + def batch_encode(self, texts: List[str]) -> List[List[int]]: + return [self.encode(t) for t in texts] + + def decode(self, token_ids: List[int]) -> str: + return self.tokenizer.decode(token_ids) + + def batch_next_token_logits( + self, batch_input_ids: List[List[int]], pad_token_id: int = 0 + ) -> np.ndarray: + """ + 批量获取下一 token 的 logits + batch_input_ids: List[List[int]],批量的 token id 序列 + 返回: np.ndarray,形状 [batch_size, vocab_size] + """ + max_len = max(len(ids) for ids in batch_input_ids) + padded_ids = [ids + [pad_token_id] * (max_len - len(ids)) for ids in batch_input_ids] + input_ids_ms = ms.Tensor(padded_ids, dtype=ms.int64) + outputs = self.model.generate( + input_ids=input_ids_ms, + max_length=max_len + 1, + return_dict_in_generate=True, + output_scores=True, + ) + step_logits = outputs.scores[0] # shape: [batch_size, vocab_size] + return step_logits.astype(np.float64) + + def next_token_logits(self, input_ids: List[int]) -> np.ndarray: + """ + 获取单条输入的下一 token logits + """ + logits_batch = self.batch_next_token_logits( + [input_ids], pad_token_id=self.tokenizer.pad_token_id + ) + return logits_batch[0] + + def get_W_true(self) -> Optional[np.ndarray]: + """ + 获取输出头权重(通常与嵌入权重 tied) + """ + lm_head = getattr(self.model, "lm_head", None) + if lm_head is None or not hasattr(lm_head, "weight"): + return None + W = lm_head.weight.asnumpy() + return W diff --git a/examples/community/lm_head_attack/src/adapters/local_mindformers_llama.py b/examples/community/lm_head_attack/src/adapters/local_mindformers_llama.py new file mode 100644 index 0000000..d4d99ad --- /dev/null +++ b/examples/community/lm_head_attack/src/adapters/local_mindformers_llama.py @@ -0,0 +1,83 @@ +from typing import Dict, List, Optional +import numpy as np + +try: + from mindformers import AutoTokenizer, AutoModel + import mindspore as ms + + MINDFORMERS_AVAILABLE = True +except Exception: + MINDFORMERS_AVAILABLE = False + + +class MindFormersLlamaAdapter: + """ + 本地 LLaMA (MindFormers) 适配器 + - 提供 full logits for next token + - 提供基于 logits + bias 的模拟 top-k logprobs + 使用前请确保模型与分词器可用(名称需与你的权重一致) + """ + + def __init__(self, model_name: str = "llama_7b", device_target: str = "Ascend"): + if not MINDFORMERS_AVAILABLE: + raise ImportError("mindformers / mindspore 未安装或不可用。") + ms.set_context(mode=ms.GRAPH_MODE, device_target=device_target) + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.model = AutoModel.from_pretrained(model_name, dtype=ms.float16) + self._hidden_size = getattr(getattr(self.model, "config", None), "hidden_size", None) + + @property + def vocab_size(self) -> int: + return self.tokenizer.vocab_size + + @property + def hidden_size(self) -> Optional[int]: + return self._hidden_size + + def encode(self, text: str) -> List[int]: + return self.tokenizer(text)["input_ids"] + + def batch_encode(self, texts: List[str]) -> List[List[int]]: + return [self.encode(t) for t in texts] + + def decode(self, token_ids: List[int]) -> str: + return self.tokenizer.decode(token_ids) + + def batch_next_token_logits( + self, batch_input_ids: List[List[int]], pad_token_id: int = 0 + ) -> np.ndarray: + """ + 批量获取下一 token 的 logits + batch_input_ids: List[List[int]],批量的 token id 序列 + 返回: np.ndarray,形状 [batch_size, vocab_size] + """ + max_len = max(len(ids) for ids in batch_input_ids) + padded_ids = [ids + [pad_token_id] * (max_len - len(ids)) for ids in batch_input_ids] + input_ids_ms = ms.Tensor(padded_ids, dtype=ms.int64) + outputs = self.model.generate( + input_ids=input_ids_ms, + max_length=max_len + 1, + return_dict_in_generate=True, + output_scores=True, + ) + step_logits = outputs.scores[0] # shape: [batch_size, vocab_size] + return step_logits.astype(np.float64) + + def next_token_logits(self, input_ids: List[int]) -> np.ndarray: + """ + 获取单条输入的下一 token logits + """ + logits_batch = self.batch_next_token_logits( + [input_ids], pad_token_id=self.tokenizer.pad_token_id + ) + return logits_batch[0] + + def get_W_true(self) -> Optional[np.ndarray]: + """ + 获取输出头权重(通常与嵌入权重 tied) + """ + lm_head = getattr(self.model, "lm_head", None) + if lm_head is None or not hasattr(lm_head, "weight"): + return None + W = lm_head.weight.asnumpy() + return W diff --git a/examples/community/lm_head_attack/src/dim_svd.py b/examples/community/lm_head_attack/src/dim_svd.py new file mode 100644 index 0000000..23801cb --- /dev/null +++ b/examples/community/lm_head_attack/src/dim_svd.py @@ -0,0 +1,26 @@ +import numpy as np +from typing import Tuple + + +def recover_hidden_matrix( + logits_matrix: np.ndarray, gap_threshold: float = 2.0 +) -> Tuple[int, np.ndarray, np.ndarray]: + """ + 从logits矩阵中估计隐藏层维度 h并复原W_hat + Args: + logits_matrix: shape [num_queries, vocab_size] + gap_threshold: 判断奇异值落差的阈值(log域差值) + Returns: + h: 估计的隐藏层维度 + singular_values: 奇异值数组 + W: 复原的矩阵 + """ + Q = logits_matrix.T.astype(np.float64) # 转置为[l,n] + U, S, _ = np.linalg.svd(Q, full_matrices=False) + log_s = np.log(np.abs(S)) + gaps = - np.diff(log_s)[1:-1] + h = int(np.argmax(gaps) + 1) + if gaps[h - 1] < np.log(gap_threshold): + print("[WARN] 最大gap低于阈值,结果可能不稳定") + W_hat = U[:,:h] @ np.diag(S[:h]) + return h, S, W_hat diff --git a/examples/community/lm_head_attack/src/metrics.py b/examples/community/lm_head_attack/src/metrics.py new file mode 100644 index 0000000..c6e40fd --- /dev/null +++ b/examples/community/lm_head_attack/src/metrics.py @@ -0,0 +1,27 @@ +import numpy as np +from typing import Dict + + +def rms_error(W_true: np.ndarray, W_est: np.ndarray) -> float: + """ + 计算均方根差 + """ + return np.sqrt(np.mean((W_true - W_est) ** 2)) + + +def bit_agreement(logits_true: np.ndarray, logits_est: np.ndarray) -> float: + """ + 估算logits的位精度(假设已归一化) + """ + diff = np.abs(logits_true - logits_est) + return -np.log2(np.mean(diff) + 1e-12) + + +def summary_report(results: Dict) -> str: + """ + 生成简要评估报告 + """ + lines = ["=== 攻击实验评估报告 ==="] + for k, v in results.items(): + lines.append(f"{k}:{v}") + return "\n".join(lines) diff --git a/examples/community/lm_head_attack/src/numerics.py b/examples/community/lm_head_attack/src/numerics.py new file mode 100644 index 0000000..3b8cf9d --- /dev/null +++ b/examples/community/lm_head_attack/src/numerics.py @@ -0,0 +1,16 @@ +import numpy as np + + +def center_matrix(X: np.ndarray) -> np.ndarray: + """ + 按行中心化矩阵 + """ + return X - X.mean(axis=0, keepdims=True) + + +def normalize_rows(X: np.ndarray) -> np.ndarray: + """ + 按行归一化 + """ + norms = np.linalg.norm(X, axis=1, keepdims=True) + 1e-12 + return X / norms diff --git a/examples/community/lm_head_attack/src/svd_plots.py b/examples/community/lm_head_attack/src/svd_plots.py new file mode 100644 index 0000000..89668da --- /dev/null +++ b/examples/community/lm_head_attack/src/svd_plots.py @@ -0,0 +1,13 @@ +import numpy as np +import matplotlib.pyplot as plt + + +def plot_singular_values(S: np.ndarray, h_est: int): + plt.figure() + plt.semilogy(S, marker="o") + plt.axvline(h_est - 1, color="r", linestyle="--", label=f"estimated h={h_est}") + plt.xlabel("Sorted Singular Values") + plt.ylabel("Magnitude(log)") + plt.legend() + plt.title("Singular Values Distribution") + plt.show() diff --git a/examples/community/lm_head_attack/train.py b/examples/community/lm_head_attack/train.py new file mode 100644 index 0000000..becedd6 --- /dev/null +++ b/examples/community/lm_head_attack/train.py @@ -0,0 +1,107 @@ +import os + +os.environ["GLOG_v"] = "3" +from src.dim_svd import recover_hidden_matrix +from src.metrics import rms_error, summary_report +from src.svd_plots import plot_singular_values +from src.numerics import center_matrix +import numpy as np + + +def collect_full_logits_matrix( + adapter, + num_queries: int, + prompt_len: int, + vocab_subset: int = None, + seed: int = 0, + batch_size: int = 1, +): + """ + 批量采集多次查询的“最后一个位置”的全量 logits,形成矩阵: + 返回: + X: [num_queries, vocab_or_subset] + subset_indices: Optional[np.ndarray] + """ + rng = np.random.default_rng(seed) + vsize = adapter.tokenizer.vocab_size + + # 如果指定了子词表 + subset_idx = None + if vocab_subset is not None: + subset_idx = rng.choice(vsize, size=vocab_subset, replace=False) + + all_logits = [] + + for start in range(0, num_queries, batch_size): + current_bs = min(batch_size, num_queries - start) + + # 随机生成 token id 序列 + batch_prompts = rng.integers( + low=0, high=vsize, size=(current_bs, prompt_len), endpoint=False + ).tolist() + + # 批量获取 logits + logits_batch = adapter.batch_next_token_logits( + batch_prompts, pad_token_id=adapter.tokenizer.pad_token_id + ) # shape: [batch, vocab] + + if subset_idx is not None: + logits_batch = logits_batch[:, subset_idx] + + all_logits.append(logits_batch) + + X = np.vstack(all_logits) # shape: [num_queries, vocab or subset] + return X, subset_idx + + +def align_and_eval(W_hat: np.ndarray, W_true: np.ndarray): + """ + 通过最小二乘在右侧对齐(求解 G:W_hat G ≈ W_true),并计算 RMS + """ + # 解 G 的最小二乘:对每列一起解,lstsq(W_hat, W_true) + G, *_ = np.linalg.lstsq(W_hat.astype(np.float64), W_true.astype(np.float64), rcond=None) + W_aligned = (W_hat.astype(np.float64) @ G).astype(np.float64) + err = rms_error(W_true, W_aligned) + return err, W_aligned, G + + +def train(adapter, num_queries=5000, prompt_len=16, vocab_subset=None, seed=0, batch_size=8): + import matplotlib.pyplot as plt + + X, subset_idx = collect_full_logits_matrix( + adapter, + num_queries=num_queries, + prompt_len=prompt_len, + vocab_subset=vocab_subset, + seed=seed, + batch_size=batch_size, + ) + Xc = center_matrix(X) + Q = Xc.T.astype(np.float64) + U, S, _ = np.linalg.svd(Q, full_matrices=False) + log_s = np.log(np.abs(S)) + gaps = log_s[:-1] - log_s[1:] + + plt.plot(gaps[:-1]) + plt.yscale("log") + plt.savefig("./outputs/gap_plot.png") + h_exp = np.argmax(gaps[1:-1]) + 1 + W_hat = U[:, :h_exp] @ np.diag(S[:h_exp]) + W_true_full = adapter.get_W_true().astype(np.float64) + err, W_aligned, G = align_and_eval(W_hat, W_true_full) + # 简要报告 + report = { + "num_queries": num_queries, + "prompt_len": prompt_len, + "vocab_used": X.shape[1], + "h_exp": h_exp, + "rms_aligned": err, + } + print() + print("----- Summary -----") + print(summary_report(report)) + os.makedirs("./outputs", exist_ok=True) + with open("./outputs/summary_report.txt", "w") as f: + f.write("----- Summary -----\n") + f.write(summary_report(report)) + print("-------------------")