Mirror of https://gitee.com/mindspore/mindformers.git
Synced 2025-12-06 19:42:57 +08:00

Compare commits: bc4ed9d124 ... kbk-infer

4 Commits

| Author | SHA1 | Date |
|---|---|---|
| | a1eb2ec173 | |
| | 8cb628a9ab | |
| | 562584dc0e | |
| | 75078d33d5 | |
@@ -20,6 +20,7 @@ import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Parameter, Tensor, mint, nn, ops
from mindspore.common.initializer import initializer
from mindspore.ops.auto_generate import PagedAttention

from mindformers.experimental.parallel_core.pynative.transformer.scale_mask_softmax import ScaleMaskSoftmax
from mindformers.experimental.parallel_core.pynative.utils import divide

@@ -29,7 +30,7 @@ from mindformers.experimental.infer.core.utils import get_tp_world_size, create_
from mindformers.modules.flash_attention import FlashAttention
from mindformers.modules.infer_attention import InferRotaryEmbedding
from mindformers.modules.layers import FreqsMgr, RotaryEmbedding
from mindformers.modules.paged_attention_mgr import PagedAttentionMgr
from mindformers.modules.kv_cache_mgr import KVCacheMgr
from mindformers.modules.transformer import LowerTriangularMaskWithDynamic
from mindformers.tools.utils import is_pynative

@@ -382,24 +383,14 @@ class ParallelAttention(nn.Cell):

if self.use_past:
kv_shape = (self.config.num_blocks, self.config.block_size, self.kv_num_heads_per_partition, self.head_dim)
self.paged_attention_mgr = PagedAttentionMgr(self.num_heads_per_partition,
self.head_dim,
self.kv_num_heads_per_partition,
kv_shape,
config.seq_length,
compute_dtype=self.compute_dtype)
self.paged_attention_mgr.key_cache = create_empty_parameter(
shape=self.paged_attention_mgr.key_cache.shape,
dtype=self.paged_attention_mgr.key_cache.dtype,
name=self.paged_attention_mgr.key_cache.name,
requires_grad=self.paged_attention_mgr.key_cache.requires_grad,
)
self.paged_attention_mgr.value_cache = create_empty_parameter(
shape=self.paged_attention_mgr.value_cache.shape,
dtype=self.paged_attention_mgr.value_cache.dtype,
name=self.paged_attention_mgr.value_cache.name,
requires_grad=self.paged_attention_mgr.value_cache.requires_grad,
)
self.kv_cache_mgr = KVCacheMgr(self.kv_num_heads_per_partition,
self.head_dim,
num_blocks=self.config.num_blocks,
block_size=self.config.block_size,
compute_dtype=self.compute_dtype)
self.paged_attention = PagedAttention(self.num_heads_per_partition,
1.0 / self.norm_factor,
self.kv_num_heads_per_partition)
self.rotary_embedding = InferRotaryEmbedding(rotary_cos_format=2)
else:
self.apply_rotary_emb = RotaryEmbedding(self.head_dim, config.rotary_dtype)

@@ -461,8 +452,7 @@ class ParallelAttention(nn.Cell):
if self.is_first_iteration:
key, value = self._cat_prefix(key, value, prefix_keys_values)

key_out = self.paged_attention_mgr(key, value, slot_mapping, batch_valid_length)
query = ops.depend(query, key_out)
key_cache, value_cache = self.kv_cache_mgr(key, value, slot_mapping, batch_valid_length)

if self.is_first_iteration:
if self.use_flash_attention:

@@ -496,9 +486,8 @@ class ParallelAttention(nn.Cell):
context_layer = context_layer.transpose(0, 2, 1, 3).reshape(
bs, seq_len, self.hidden_size_per_partition)
else:
context_layer = self.paged_attention_mgr.paged_attn(query, batch_valid_length, block_tables)

# [B, S, N, D]
context_layer = self.paged_attention(query, key_cache, value_cache, block_tables, batch_valid_length,
None, None, None, None)
else:
# [B, S, N, D] --> [B, N, S, D]
query = query.transpose(0, 2, 1, 3)
@@ -133,7 +133,6 @@ class ParallelLlamaForCausalLM(LlamaPreTrainedModel):
for layer in self.model.layers:
layer.add_flags(is_first_iteration=is_first_iteration)
layer.attention.add_flags(is_first_iteration=is_first_iteration)
layer.attention.paged_attention_mgr.add_flags(is_first_iteration=is_first_iteration)

# pylint: disable=W0613
def construct(self, input_ids, labels=None, input_position=None, position_ids=None, attention_mask=None,

@@ -164,8 +163,8 @@ class ParallelLlamaForCausalLM(LlamaPreTrainedModel):
return logits, input_ids, input_mask

def kvcache(self, layer_idx):
key_cache = self.model.layers[layer_idx].attention.paged_attention_mgr.key_cache
value_cache = self.model.layers[layer_idx].attention.paged_attention_mgr.value_cache
key_cache = self.model.layers[layer_idx].attention.kv_cache_mgr.key_cache
value_cache = self.model.layers[layer_idx].attention.kv_cache_mgr.value_cache
return key_cache, value_cache

@classmethod

@@ -504,6 +504,6 @@ class CogVLM2VideoLM(LlamaPreTrainedModel):

def kvcache(self, layer_idx):
"""Get kvcache with input layer index."""
key_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.key_cache
value_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.value_cache
key_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.key_cache
value_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.value_cache
return key_cache, value_cache

@@ -891,8 +891,8 @@ class LlamaForCausalLMForCogVLM2Image(LlamaPreTrainedModel):
"""Get kvcache with input layer index."""
key_cache = self.model.layers[
layer_idx
].self_attn.infer_attention.paged_attention_mgr.key_cache
].self_attn.infer_attention.kv_cache_mgr.key_cache
value_cache = self.model.layers[
layer_idx
].self_attn.infer_attention.paged_attention_mgr.value_cache
].self_attn.infer_attention.kv_cache_mgr.value_cache
return key_cache, value_cache

@@ -478,7 +478,7 @@ class ChatGLM2WithPtuning2(ChatGLM2ForConditionalGeneration):

def kvcache(self, layer_idx):
key_cache = \
self.transformer.encoder.layers[layer_idx].self_attention.infer_attention.paged_attention_mgr.key_cache
self.transformer.encoder.layers[layer_idx].self_attention.infer_attention.kv_cache_mgr.key_cache
value_cache = \
self.transformer.encoder.layers[layer_idx].self_attention.infer_attention.paged_attention_mgr.value_cache
self.transformer.encoder.layers[layer_idx].self_attention.infer_attention.kv_cache_mgr.value_cache
return key_cache, value_cache

@@ -35,7 +35,7 @@ from mindformers.modules.layers import Linear, FreqsMgr
from mindformers.modules.transformer import LowerTriangularMaskWithDynamic
from mindformers.modules.transformer.op_parallel_config import _check_config
from mindformers.tools.register.register import MindFormerModuleType, MindFormerRegister
from mindformers.tools.utils import get_predict_run_mode
from mindformers.tools.utils import get_predict_run_mode, get_infer_boost

from .llama_config import LlamaConfig
from .llama_layer import LlamaEmbedding, LlamaRMSNorm

@@ -95,7 +95,9 @@ class LlamaModel(LlamaPreTrainedModel):
self.shape = P.Shape()
self.reshape = P.Reshape()
self.rmsnorm_compute_2d = config.rmsnorm_compute_2d

self.enable_infer_boost = get_infer_boost()
if self.use_past and not self.enable_infer_boost:
self.range = Tensor(np.arange(config.seq_length).reshape((1, 1, -1)), mstype.int32)
if config.moe_config.expert_num > 1:
logger.info("MoE config is provided, use MoE FFN")
else:
@@ -280,6 +282,34 @@ class LlamaModel(LlamaPreTrainedModel):
else:
self.norm_out.shard((dp, cp, 1))

def gen_infer_freqs_and_mask(self, batch_size, seq_len, tokens, batch_valid_length, prefix_keys_values):
"""generate infer rope freqs and attention mask."""
# for infer boost off mode or o2 mode
if not self.enable_infer_boost:
if self.is_first_iteration:
freqs_cis = self.freqs_mgr(seq_len)
mask = self.casual_mask(tokens) # mask: [bs, seq, seq]
else:
freqs_cis = self.freqs_mgr.increment(batch_valid_length)
mask = self.casual_mask.increment(self.range, batch_valid_length)
return freqs_cis, mask

# for O0 + infer boost on mode
mask = None
if self.is_first_iteration:
freqs_cis = self.freqs_mgr.prefill(batch_size, seq_len)
mask = self.casual_mask.prefill()
if prefix_keys_values is not None:
if mask is None:
mask = self.casual_mask(tokens)
prefix_length = prefix_keys_values[0].shape[2]
prefix_mask = Tensor(np.zeros((batch_size, 1, seq_len, prefix_length)), dtype=mask.dtype)
mask = self.concat((prefix_mask, mask))
else:
freqs_cis = self.freqs_mgr.increment(batch_valid_length)

return freqs_cis, mask

# pylint: disable=W0613
def construct(self, tokens: Tensor, input_embeds=None, batch_valid_length=None, batch_index=None,
zactivate_len=None, block_tables=None, slot_mapping=None, prefix_keys_values=None,

@@ -325,17 +355,8 @@ class LlamaModel(LlamaPreTrainedModel):
else:
mask = None
if self.use_past:
if self.is_first_iteration:
freqs_cis = self.freqs_mgr.prefill(bs, seq_len)
mask = self.casual_mask.prefill()
if prefix_keys_values is not None:
if mask is None:
mask = self.casual_mask(tokens)
prefix_length = prefix_keys_values[0].shape[2]
prefix_mask = Tensor(np.zeros((bs, 1, seq_len, prefix_length)), dtype=mask.dtype)
mask = self.concat((prefix_mask, mask))
else:
freqs_cis = self.freqs_mgr.increment(batch_valid_length)
freqs_cis, mask = self.gen_infer_freqs_and_mask(bs, seq_len, tokens, batch_valid_length,
prefix_keys_values)
else:
if self.seq_pipe:
mask = self.casual_mask(tokens, seq_chunk=self.seq_chunk)

@@ -359,7 +380,7 @@ class LlamaModel(LlamaPreTrainedModel):
else:
h = self.cast(self.tok_embeddings(tokens), self.dtype)
if not rmsnorm_compute_2d:
h = self.reshape(h, (bs, seq_len, self.hidden_size)) # h: [bs, seq/1, hidden_dim]
h = self.reshape(h, (bs, seq_len, self.hidden_size)) # h: [bs, seq/1, hidden_dim]
for i in range(self.num_layers):
prefix_kv = prefix_keys_values[i] if prefix_keys_values is not None else None
h = self.layers[i](h, freqs_cis, mask, batch_valid_length=batch_valid_length, block_tables=block_tables,

@@ -580,7 +601,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
for layer in self.model.layers:
layer.add_flags(is_first_iteration=is_first_iteration)
layer.attention.infer_attention.add_flags(is_first_iteration=is_first_iteration)
layer.attention.infer_attention.paged_attention_mgr.add_flags(is_first_iteration=is_first_iteration)

def pre_gather_func(self, pre_gather, output, batch_valid_length, gather_index=None):
"""Pre gather operation in infer mode."""

@@ -655,8 +675,8 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
return loss

def kvcache(self, layer_idx):
key_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.key_cache
value_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.value_cache
key_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.key_cache
value_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.value_cache
return key_cache, value_cache

@classmethod
@@ -664,7 +664,7 @@ class LlamaMoeInferFeedForward(Cell):
self.w3.shard(strategy_matmul=(((1, 1),), ((1, 1, mp),), ((),), ((),), ((),), ((),), ((),), (1,)))
self.w2.shard(strategy_matmul=(((1, mp),), ((1, mp, 1),), ((),), ((),), ((),), ((),), ((),), (1,)))
else:
self.mul.shard((1, mp), (1, mp))
self.mul.shard(((1, mp), (1, mp)))
self.w1.shard(strategy_matmul=(((1, 1),), ((1, 1, mp),), ((),), ((),), ((),), ((),), ((),), (1,)),
strategy_activation=((1, 1, mp, 1),))
self.w3.shard(strategy_matmul=(((1, 1),), ((1, 1, mp),), ((),), ((),), ((),), ((),), ((),), (1,)))

@@ -771,8 +771,10 @@ class LlamaFeedForwardWithMoE(Cell):
self.mul.shard(((dp, 1, 1), (dp, 1, 1)))
self.add.shard(((dp, 1, 1), (dp, 1, 1)))
self.sigmoid.shard(((dp, 1, 1),))

self.routed_experts.ffn.shard(parallel_config)
if self.use_moe_infer:
self.routed_experts.shard(parallel_config)
else:
self.routed_experts.ffn.shard(parallel_config)
self.shared_experts.shard(parallel_config)
self.shared_experts.mul.shard(((dp, 1, mp), (dp, 1, mp)))

@@ -41,6 +41,6 @@ from .layers import (
RotaryEmbedding
)
from .local_block_sparse_attention import LocalBlockSparseAttention
from .paged_attention_mgr import PagedAttentionMgr
from .kv_cache_mgr import KVCacheMgr

__all__ = []

@@ -20,9 +20,11 @@ from mindspore import ops
from mindspore.common.tensor import Tensor
from mindspore.nn.cell import Cell
from mindspore.ops import operations as P
from mindspore.ops.auto_generate import PagedAttention

from mindformers.modules import PagedAttentionMgr
from mindformers.modules import KVCacheMgr
from mindformers.modules.flash_attention import FlashAttention
from mindformers.tools.utils import get_infer_boost


class InferRotaryEmbedding(Cell):

@@ -59,6 +61,22 @@ class InferRotaryEmbedding(Cell):
self.rotary_embedding_op.shard(((dp, 1, mp), (dp, 1, mp), (1, 1), (1, 1), (dp,)))


class AttentionInput:
"""Infer Attention Input."""
def __init__(self, query, key, value, batch_valid_length, block_tables, slot_mapping):
self.query = query
self.key = key
self.value = value
self.batch_valid_length = batch_valid_length
self.block_tables = block_tables
self.slot_mapping = slot_mapping

def __repr__(self):
return f"AttentionInput(query={self.query}, key={self.key}, value={self.value}," \
f"batch_valid_length={self.batch_valid_length}, block_tables={self.block_tables}," \
f"slot_mapping={self.slot_mapping})"


class InferAttention(Cell):
"""Infer Attention Layer.
@@ -215,6 +233,7 @@ class InferAttention(Cell):
sparse_mode=0,
block_size=16,
num_blocks=1024,
batch_size=32,
seq_length=-1,
is_dynamic=True,
use_flash_attention=True,

@@ -238,6 +257,8 @@ class InferAttention(Cell):
self.sparse_mode = sparse_mode
self.block_size = block_size
self.num_blocks = num_blocks
self.batch_size = batch_size
self.seq_length = seq_length
self.use_flash_attention = use_flash_attention
self.use_alibi_mask = use_alibi_mask
self.use_rope_rotary_emb = use_rope_rotary_emb

@@ -259,16 +280,20 @@ class InferAttention(Cell):
self.inv_norm_factor = Tensor(1.0 / math.sqrt(self.head_dim), dtype=compute_dtype)
self.not_equal = P.NotEqual()
self.n_rep = self.n_head // self.n_kv_head
if self.use_alibi_mask:
self.add_alibi = P.Add()
self.use_attention_mask = True
self.enable_infer_boost = get_infer_boost()
self.is_dynamic = is_dynamic
if self.is_dynamic:
self.input_layout = "TH"
self.use_attention_mask = True
self.parallel_decoding = parallel_decoding
if self.enable_infer_boost:
if self.is_dynamic:
self.input_layout = "TH"
self.use_attention_mask = True
else:
self.input_layout = "BSH"
self.use_attention_mask = False
else:
self.input_layout = "BSH"
self.use_attention_mask = False
self.input_layout = "BNSD"
self.use_attention_mask = True

if self.use_flash_attention:
self.flash_attention = FlashAttention(head_num=self.n_head,

@@ -281,16 +306,17 @@ class InferAttention(Cell):
use_alibi_mask=self.use_alibi_mask,
input_layout=self.input_layout)

kv_shape = (self.num_blocks, self.block_size, self.n_kv_head, self.head_dim)
self.paged_attention_mgr = PagedAttentionMgr(self.pa_n_head_split,
self.head_dim,
self.pa_n_kv_head_split,
kv_shape,
seq_length,
compute_dtype=self.compute_dtype,
parallel_decoding=parallel_decoding,
chunk_prefill=chunk_prefill,
)
self.kv_cache_mgr = KVCacheMgr(self.n_kv_head,
self.head_dim,
num_blocks=self.num_blocks,
block_size=self.block_size,
batch_size=self.batch_size,
seq_length=self.seq_length,
compute_dtype=self.compute_dtype)

self.paged_attention = PagedAttention(self.pa_n_head_split,
self.scale_value,
self.pa_n_kv_head_split)
if use_rope_rotary_emb:
self.rotary_embedding = InferRotaryEmbedding(self.rotary_cos_format)
@@ -425,12 +451,51 @@ class InferAttention(Cell):

raise ValueError("FlashAttention input layout:{} is not supported.".format(self.input_layout))

def _incre_attention(self, query, batch_valid_length, block_tables, alibi_mask=None, attn_mask=None,
q_seq_lens=None):
if self.use_alibi_mask:
return self.paged_attention_mgr.paged_attn_with_alibi(query, batch_valid_length, block_tables, alibi_mask)
return self.paged_attention_mgr.paged_attn(query, batch_valid_length, block_tables, attn_mask=attn_mask,
q_seq_lens=q_seq_lens)
def _infer_boost_attention(self, query, key, value, batch_valid_length, block_tables, slot_mapping, attn_mask=None, alibi_mask=None, prefix_keys_values=None,
q_seq_lens=None):
"""The forward compute of infer Attention with boost."""
if prefix_keys_values is not None:
prefix_len = prefix_keys_values.shape[2]
slot_mapping = slot_mapping + self.cast(self.not_equal(slot_mapping, -1), mstype.int32) * prefix_len
if self.is_first_iteration:
key, value = self._cat_prefix(key, value, prefix_keys_values)

key_cache, value_cache = self.kv_cache_mgr(key, value, slot_mapping, batch_valid_length)

if self.chunk_prefill:
return self.paged_attention(query, key_cache, value_cache, block_tables, batch_valid_length,
None, None, attn_mask, q_seq_lens)

if self.is_first_iteration:
return self._prefill_attention(query, key, value, attn_mask, alibi_mask, batch_valid_length,
batch_valid_length)
else:
if self.parallel_decoding:
return self.paged_attention(query, key_cache, value_cache, block_tables, batch_valid_length,
None, None, attn_mask, q_seq_lens)
return self.paged_attention(query, key_cache, value_cache, block_tables, batch_valid_length)

def _infer_normal_attention(self, query, key, value, batch_valid_length, attn_mask):
"""The forward compute of infer Attention without boost."""
bs, seq_len, _ = query.shape
key_seq_len = key.shape[1]
value_seq_len = value.shape[1]
# (B,S,H) -> (B,N,S,D)
query = self.transpose(self.reshape(query, (bs, seq_len, self.n_head, self.head_dim)), (0, 2, 1, 3))
key = self.transpose(self.reshape(key, (bs, key_seq_len, self.n_kv_head, self.head_dim)), (0, 2, 1, 3))
value = self.transpose(self.reshape(value, (bs, value_seq_len, self.n_kv_head, self.head_dim)), (0, 2, 1, 3))

if self.is_first_iteration:
batch_valid_length = batch_valid_length * 0
key_cache, value_cache = self.kv_cache_mgr(key, value, None, batch_valid_length)

if self.use_flash_attention:
attention = self.flash_attention(query, key_cache, value_cache, attn_mask)
return self._merge_heads(attention)

key_cache = self._repeat_kv(key_cache, self.n_rep)
value_cache = self._repeat_kv(value_cache, self.n_rep)
return self._core_attention(query, key_cache, value_cache, attn_mask)

def construct(self, query, key, value, batch_valid_length, block_tables, slot_mapping, freqs_cis=None,
attn_mask=None, alibi_mask=None, prefix_keys_values=None, q_seq_lens=None):
@@ -438,23 +503,9 @@
if self.use_rope_rotary_emb:
query, key = self._apply_rotary_pos_emb(query, key, freqs_cis, batch_valid_length)

if prefix_keys_values is not None:
prefix_len = prefix_keys_values.shape[2]
slot_mapping = slot_mapping + self.cast(self.not_equal(slot_mapping, -1), mstype.int32) * prefix_len
if self.is_first_iteration:
key, value = self._cat_prefix(key, value, prefix_keys_values)

key_out = self.paged_attention_mgr(key, value, slot_mapping, batch_valid_length)
query = ops.depend(query, key_out)

if self.chunk_prefill:
return self.paged_attention_mgr.paged_attn(query, batch_valid_length, block_tables, attn_mask=attn_mask,
q_seq_lens=q_seq_lens)

if self.is_first_iteration:
return self._prefill_attention(query, key, value, attn_mask, alibi_mask, batch_valid_length,
batch_valid_length)
return self._incre_attention(query, batch_valid_length, block_tables, alibi_mask, attn_mask, q_seq_lens)
if self.enable_infer_boost:
return self._infer_boost_attention(query, key, value, batch_valid_length, block_tables, slot_mapping, attn_mask, alibi_mask, prefix_keys_values, q_seq_lens)
return self._infer_normal_attention(query, key, value, batch_valid_length, attn_mask)

def shard(self, parallel_config):
"""Parallel strategy configuration interface."""
@@ -465,7 +516,11 @@
self.rotary_embedding.shard(parallel_config)
if self.use_flash_attention:
self.flash_attention.shard(parallel_config)
self.paged_attention_mgr.shard(parallel_config)
self.kv_cache_mgr.shard(parallel_config)
if self.parallel_decoding:
self.paged_attention.shard(((dp, 1, mp), (1, 1, mp, 1), (1, 1, mp, 1), (dp, 1), (dp,), (dp, 1), (1,)))
else:
self.paged_attention.shard(((dp, 1, mp), (1, 1, mp, 1), (1, 1, mp, 1), (dp, 1), (dp,)))

self.transpose.shard(((dp, 1, mp, 1),))
self.merger_head_transpose.shard(((dp, mp, 1, 1),))

@@ -473,8 +528,6 @@
self.batch_matmul.shard(((dp, mp, 1, 1), (dp, mp, 1, 1)))
self.mul.shard(((dp, mp, 1, 1), ()))
self.add.shard(((dp, 1, 1, 1), (dp, mp, 1, 1)))
if self.use_alibi_mask:
self.add_alibi.shard(((dp, mp, 1, 1), (dp, mp, 1, 1)))
self.softmax.shard(((dp, mp, 1, 1),))
self.tile_kv.shard(((dp, mp, 1, 1),))
return self
mindformers/modules/kv_cache_mgr.py (new file, 80 lines added)

@@ -0,0 +1,80 @@
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""KV Cache Attention Manager for inference."""
import mindspore.common.dtype as mstype
from mindspore import nn, Parameter, ops
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import Zero
from mindspore.ops.auto_generate import KVCacheScatterUpdate, ReshapeAndCache

from mindformers.tools.utils import get_infer_boost


class KVCacheMgr(nn.Cell):
"""KV Cache Manager."""

def __init__(self,
n_kv_head,
head_dim,
num_blocks=1024,
block_size=128,
batch_size=32,
seq_length=4096,
compute_dtype=mstype.float16):
super().__init__()
self.n_kv_head = n_kv_head
self.head_dim = head_dim
self.num_blocks = num_blocks
self.block_size = block_size
self.batch_size = batch_size
self.seq_length = seq_length
self.enable_infer_boost = get_infer_boost()
print("enable_infer_boost-----", self.enable_infer_boost)
if self.enable_infer_boost:
kv_shape = (self.num_blocks, self.block_size, self.n_kv_head, self.head_dim)
self.reshape_and_cache = ReshapeAndCache()
else:
kv_shape = (self.batch_size, self.n_kv_head, self.seq_length, self.head_dim)
self.kv_cache_scatter_update = KVCacheScatterUpdate()

self.key_cache = Parameter(Tensor(shape=kv_shape, dtype=compute_dtype, init=Zero()), name="key_cache",
requires_grad=False)
self.value_cache = Parameter(Tensor(shape=kv_shape, dtype=compute_dtype, init=Zero()), name="value_cache",
requires_grad=False)
print("KVCacheMgr--finish--------------------")

def construct(self, key_update, value_update, slot_mapping=None, batch_valid_length=None):
"""The forward compute of KVCache for Attention."""
key_cache = self.key_cache
value_cache = self.value_cache
if self.enable_infer_boost:
self.reshape_and_cache(key_update, value_update, self.key_cache, self.value_cache, slot_mapping)
else:
# update shape: [real_bs, n_head, max_seqlen, head_dim]
self.kv_cache_scatter_update(self.key_cache, batch_valid_length, key_update, -2, 'update')
self.kv_cache_scatter_update(self.value_cache, batch_valid_length, value_update, -2, 'update')
key_cache = ops.depend(key_cache, key_update)
value_cache = ops.depend(value_cache, value_update)
return key_cache, value_cache

def shard(self, parallel_config):
"""The shard strategy."""
dp = 1 if parallel_config is None else parallel_config.data_parallel
mp = 1 if parallel_config is None else parallel_config.model_parallel
if self.enable_infer_boost:
self.reshape_and_cache.shard(((dp, 1, mp), (dp, 1, mp), (1, 1, mp, 1), (1, 1, mp, 1), (1,)))
else:
self.kv_cache_scatter_update.shard(((dp, mp, 1, 1), (1,), (dp, mp, 1, 1)))
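The sketch below is an editor's illustration, not part of the change: a minimal standalone use of the new KVCacheMgr cell, assuming MindSpore on a backend that provides the ReshapeAndCache/KVCacheScatterUpdate kernels. With infer_boost left off, the cache follows the (batch_size, n_kv_head, seq_length, head_dim) branch; the sizes are hypothetical and chosen only to show the expected tensor layout.

```python
import numpy as np
from mindspore import Tensor
from mindformers.modules.kv_cache_mgr import KVCacheMgr  # module added by this change

# Hypothetical sizes for illustration only.
mgr = KVCacheMgr(n_kv_head=4, head_dim=128, batch_size=2, seq_length=1024)

key = Tensor(np.zeros((2, 4, 1024, 128), np.float16))    # [bs, n_kv_head, seq, head_dim]
value = Tensor(np.zeros((2, 4, 1024, 128), np.float16))
batch_valid_length = Tensor(np.zeros((2,), np.int64))    # per-sequence write offset

# Non-boost branch: KVCacheScatterUpdate writes the update into the persistent caches
# and the updated key/value caches are returned for downstream attention.
key_cache, value_cache = mgr(key, value, slot_mapping=None, batch_valid_length=batch_valid_length)
```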
@@ -1,254 +0,0 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""KVCache Manager for inference."""
import numpy as np
from mindspore.common.tensor import Tensor
from mindspore import nn, Parameter, ops
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P


class KVCacheMgr(nn.Cell):
"""KVCache Manager."""
def __init__(self,
n_head,
head_dim,
max_batch_size=8,
max_seq_length=4096,
compute_dtype=mstype.float16,
is_dynamic=False,
use_kvcache_op=True,
is_flexible_shape=False):
super().__init__()
self.n_head = n_head
self.head_dim = head_dim
self.max_batch_size = max_batch_size
self.max_seq_length = max_seq_length
self.dtype = compute_dtype
self.use_kvcache_op = use_kvcache_op
self.is_dynamic = is_dynamic
self.is_flexible_shape = is_flexible_shape
self.is_first_iteration = True

self.cache_length_tensor = Tensor([max_batch_size * max_seq_length], dtype=mstype.int32)
self.cache_pad_tensor = Tensor([3], dtype=mstype.int64)
self.seq_length_tensor = Tensor([max_seq_length], dtype=mstype.int32)
self.seq_length_tensor_pad = Tensor([max_seq_length, 3], dtype=mstype.int64)
self.seqlen_axis_tensor_pad = Tensor([2, 3], dtype=mstype.int64)
self.pad_before = Tensor([0, 0, 0, 0, 0], mstype.int32)
self.pad_after = Tensor([0, 0], mstype.int32)
self.pad_zero = Tensor(0, compute_dtype)

if self.use_kvcache_op:
# pylint: disable=W0212
self.prompt_kvcache = P._inner_ops.PromptKVCache()
# pylint: disable=W0212
self.decoder_kvcache = P._inner_ops.DecoderKVCache()
else:
self.add = P.Add()
self.mul = P.Mul()
self.assign = P.Assign()
self.concat = P.Concat(axis=0)
self.sub = P.Sub()
self.div = P.Div()
self.pad = P.PadV3()
self.slice = P.StridedSlice()
self.cast = P.Cast()
self.shape = P.Shape()
self.reshape = P.Reshape().add_prim_attr("skip_redistribution", True)

kv_shape = (max_batch_size, n_head, max_seq_length, head_dim)
self.key_past = Parameter(Tensor(np.zeros(kv_shape), compute_dtype), name="key_past", requires_grad=False)
self.value_past = Parameter(Tensor(np.zeros(kv_shape), compute_dtype), name="value_past", requires_grad=False)

def shard(self, parallel_config):
"""shard"""
dp = parallel_config.data_parallel
mp = parallel_config.model_parallel
self.pad.shard(((dp, mp, 1, 1), (1,), ()))
self.slice.shard(((dp, mp, 1, 1),))
if self.use_kvcache_op:
self.prompt_kvcache.shard(((dp, mp, 1, 1), (dp, mp, 1, 1), (dp,), (1,), (1,), (1,), (1,)))
self.decoder_kvcache.shard(((dp, mp, 1, 1), (dp, mp, 1, 1), (dp,), (1,), (1,), (1,), (1,)))
else:
self.add.shard(((dp, mp, 1, 1), (dp, mp, 1, 1)))
self.mul.shard(((dp, mp, 1, 1), (dp, 1, 1, 1)))
self.assign.shard(((dp, mp, 1, 1), (dp, mp, 1, 1)))

def padding(self, key, value, seq_length):
"""padding key, value"""
pad_length = self.sub(self.seq_length_tensor, seq_length)
# calculate padding parameter: (0, 0),(0,0),(0,pad_length),(0,0), append values of 'pad_length' in axis
pad_config = self.concat((self.pad_before, pad_length, self.pad_after))
key_padding = self.pad(key, pad_config, self.pad_zero)
value_padding = self.pad(value, pad_config, self.pad_zero)
return key_padding, value_padding

def trimming(self, key, value, zactivate_len, batch_size):
"""trimming key, value"""
if self.is_flexible_shape:
key = self.reshape(key, (batch_size, self.n_head, -1, self.head_dim))
value = self.reshape(value, (batch_size, self.n_head, -1, self.head_dim))
if zactivate_len is not None:
act_len = self.shape(zactivate_len)[0]
key = self.slice(key, (0, 0, 0, 0), (batch_size, self.n_head, act_len, self.head_dim), (1, 1, 1, 1))
value = self.slice(value, (0, 0, 0, 0), (batch_size, self.n_head, act_len, self.head_dim), (1, 1, 1, 1))
elif not self.is_flexible_shape:
key = self.slice(key, (0, 0, 0, 0),
(batch_size, self.n_head, self.max_seq_length, self.head_dim), (1, 1, 1, 1))
value = self.slice(value, (0, 0, 0, 0),
(batch_size, self.n_head, self.max_seq_length, self.head_dim), (1, 1, 1, 1))
return key, value

def auto_caching(self, key_update, value_update, batch_valid_length, seq_length_tensor_pad, batch_index_pad=None):
"""use kvcache op to cache key, value"""
# key_update shape: [real_bs, n_head, max_seqlen, head_dim]
if self.is_first_iteration:
batch_valid_length = batch_valid_length * 0
self.prompt_kvcache(self.key_past, key_update, batch_valid_length, batch_index_pad,
self.seqlen_axis_tensor_pad, seq_length_tensor_pad, seq_length_tensor_pad)
self.prompt_kvcache(self.value_past, value_update, batch_valid_length, batch_index_pad,
self.seqlen_axis_tensor_pad, seq_length_tensor_pad, seq_length_tensor_pad)
return None

key_cache = self.key_past
value_cache = self.value_past
key_update = self.decoder_kvcache(self.key_past, key_update, batch_valid_length, batch_index_pad,
self.seqlen_axis_tensor_pad, seq_length_tensor_pad, seq_length_tensor_pad)
value_update = self.decoder_kvcache(self.value_past, value_update, batch_valid_length, batch_index_pad,
self.seqlen_axis_tensor_pad, seq_length_tensor_pad, seq_length_tensor_pad)
key_cache = ops.depend(key_cache, key_update)
value_cache = ops.depend(value_cache, value_update)
return key_cache, value_cache

def manual_caching(self, key_update, value_update, valid_length_vector, batch_size):
"""use assign to cache key, value"""
# key_update shape: [real_bs, n_head, 1, head_dim]
if self.is_first_iteration:
if self.is_dynamic:
self.assign(self.key_past,
self.reshape(key_update, (self.max_batch_size, self.n_head, -1, self.head_dim)))
self.assign(self.value_past,
self.reshape(value_update, (self.max_batch_size, self.n_head, -1, self.head_dim)))
else:
self.assign(self.key_past, self.mul(key_update, valid_length_vector))
self.assign(self.value_past, self.mul(value_update, valid_length_vector))
return None

if self.is_dynamic:
key = self.add(self.reshape(self.key_past, (batch_size, self.n_head, -1, self.head_dim)),
self.mul(key_update, valid_length_vector))
value = self.add(self.reshape(self.value_past, (batch_size, self.n_head, -1, self.head_dim)),
self.mul(value_update, valid_length_vector))
self.assign(self.key_past,
self.reshape(key, (self.max_batch_size, self.n_head, -1, self.head_dim)))
self.assign(self.value_past,
self.reshape(value, (self.max_batch_size, self.n_head, -1, self.head_dim)))
else:
key = self.add(self.key_past, self.mul(key_update, valid_length_vector))
value = self.add(self.value_past, self.mul(value_update, valid_length_vector))
self.assign(self.key_past, key)
self.assign(self.value_past, value)
# key shape: [real_bs, n_head, max_cache_len // real_bs, head_dim]
return key, value
def construct(self, key, value, kvcache_inputs=None):
"""The forward compute of KVCacheMgr."""
# add inputs check
batch_valid_length, zactivate_len, batch_index_pad, seq_length_tensor_pad = kvcache_inputs
if not self.use_kvcache_op:
batch_valid_length = self.cast(batch_valid_length, self.dtype)
batch_size, _, seq_length, _ = self.shape(key)
if self.is_first_iteration:
if self.is_dynamic:
key_padding, value_padding = self.padding(key, value, seq_length=seq_length)
else:
key_padding, value_padding = key, value
if self.use_kvcache_op:
self.auto_caching(key_padding, value_padding, batch_valid_length,
seq_length_tensor_pad, batch_index_pad)
else:
self.manual_caching(key_padding, value_padding, batch_valid_length, batch_size=batch_size)
else:
if self.use_kvcache_op:
key, value = self.auto_caching(key, value, batch_valid_length,
seq_length_tensor_pad, batch_index_pad)
else:
key, value = self.manual_caching(key, value, batch_valid_length, batch_size=batch_size)
key, value = self.trimming(key, value, zactivate_len, batch_size=batch_size)

return key, value


class KVCachePreprocess(nn.Cell):
"""KVCache Manager."""
def __init__(self,
max_batch_size=8,
max_seq_length=4096,
is_dynamic=False,
use_kvcache_op=False,
is_flexible_shape=False):
super().__init__()
self.is_dynamic = is_dynamic
self.use_kvcache_op = use_kvcache_op
self.is_flexible_shape = is_flexible_shape

self.max_cache_length = max_batch_size * max_seq_length
range_len = self.max_cache_length if self.is_flexible_shape else max_seq_length
self.range = Tensor(np.arange(range_len).reshape((1, 1, -1)), mstype.int32)
self.cache_length_tensor = Tensor([max_batch_size * max_seq_length], dtype=mstype.int32)
self.cache_pad_tensor = Tensor([3], dtype=mstype.int64)
self.seq_length_tensor = Tensor([max_seq_length], dtype=mstype.int32)
self.seq_length_tensor_pad = Tensor([max_seq_length, 3], dtype=mstype.int64)
self.is_first_iteration = True

self.slice = P.StridedSlice()
self.reshape = P.Reshape().add_prim_attr("skip_redistribution", True)
self.equal = P.Equal().shard(((1, 1, 1), (1, 1, 1)))
self.less = P.Less().shard(((1, 1, 1), (1, 1, 1)))
self.expand_dims = P.ExpandDims().shard(((1, 1, 1),))
self.div = P.Div()
self.concat = P.Concat(axis=0)
self.cast = P.Cast()

def construct(self, batch_size, batch_valid_length=None, batch_index=None, zactivate_len=None):
"""precompute kvcache inputs"""
seq_range = self.range
if self.is_dynamic and self.is_flexible_shape and not self.use_kvcache_op:
seq_range = self.slice(seq_range, (0, 0, 0), (1, 1, self.max_cache_length // batch_size), (1, 1, 1))

if self.use_kvcache_op:
if batch_index is None:
batch_index = ops.arange(0, batch_size, 1)
batch_index_pad = self.concat((batch_index, self.cache_pad_tensor))
seq_length_tensor_pad = self.get_seq_length_tensor_pad(batch_size=batch_size)
batch_valid_length = self.cast(self.reshape(batch_valid_length, (-1,)), mstype.int64)
kvcache_inputs = (batch_valid_length, zactivate_len, batch_index_pad, seq_length_tensor_pad)
else:
if self.is_first_iteration:
valid_length_vector = self.less(seq_range, self.reshape(batch_valid_length, (-1, 1, 1)))
else:
valid_length_vector = self.equal(seq_range, self.reshape(batch_valid_length, (-1, 1, 1)))
valid_length_vector = self.expand_dims(valid_length_vector, 3)
kvcache_inputs = (valid_length_vector, zactivate_len, None, None)
return kvcache_inputs

def get_seq_length_tensor_pad(self, batch_size):
"""get seq_length_tensor_pad"""
if self.is_flexible_shape:
max_seq_length = self.div(self.cache_length_tensor, batch_size).astype(mstype.int64)
return self.concat((max_seq_length, self.cache_pad_tensor))
return self.seq_length_tensor_pad
@@ -1,87 +0,0 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Paged Attention Manager for inference."""
import math

import mindspore.common.dtype as mstype
from mindspore import nn, Parameter
from mindspore import ops as P
from mindspore.common.initializer import initializer


class PagedAttentionMgr(nn.Cell):
"""Paged Attention Manager."""

def __init__(self,
n_heads,
head_dim,
n_kv_heads,
kv_shape,
seq_length=-1,
compute_dtype=mstype.float16,
parallel_decoding=False,
chunk_prefill=False):
super().__init__()
self.n_heads = n_heads
self.head_dim = head_dim
self.n_kv_heads = n_kv_heads
self.seq_length = seq_length
self.is_first_iteration = True
self.scale_value = 1 / math.sqrt(self.head_dim)
self.key_cache = Parameter(initializer('zeros', kv_shape, compute_dtype), name="key_cache",
requires_grad=False)
self.value_cache = Parameter(initializer('zeros', kv_shape, compute_dtype), name="value_cache",
requires_grad=False)

self.reshape_and_cache = P.auto_generate.ReshapeAndCache()
self.paged_attention = P.auto_generate.PagedAttention(self.n_heads,
self.scale_value,
self.n_kv_heads)
self.paged_attention_with_alibi = P.auto_generate.PagedAttentionMask(self.n_heads,
self.scale_value,
self.n_kv_heads)
self.parallel_decoding = parallel_decoding
self.chunk_prefill = chunk_prefill

# pylint: disable=W0613
def construct(self, key, value, slot_mapping, batch_valid_length=None):
"""The forward compute of KVCache for Paged Attention."""
return self.reshape_and_cache(key, value, self.key_cache, self.value_cache, slot_mapping)

def paged_attn(self, query, batch_valid_length, block_tables, attn_mask=None, q_seq_lens=None):
"""The forward compute of Paged Attention."""
if self.parallel_decoding or self.chunk_prefill:
attn_mask = attn_mask.astype(mstype.bool_).astype(query.dtype) * -10000
return self.paged_attention(query, self.key_cache, self.value_cache, block_tables, batch_valid_length,
None, None, attn_mask, q_seq_lens)
return self.paged_attention(query, self.key_cache, self.value_cache, block_tables, batch_valid_length)

def paged_attn_with_alibi(self, query, batch_valid_length, block_tables, alibi_tensor):
"""The forward compute of KVCache for Paged Attention with alibi tensor."""
return self.paged_attention_with_alibi(query, self.key_cache, self.value_cache,
block_tables, batch_valid_length, None, None, alibi_tensor)

def shard(self, parallel_config):
"""The shard strategy."""
dp = 1 if parallel_config is None else parallel_config.data_parallel
mp = 1 if parallel_config is None else parallel_config.model_parallel
self.reshape_and_cache.shard(((dp, 1, mp), (dp, 1, mp), (1, 1, mp, 1), (1, 1, mp, 1), (1,)))
if self.parallel_decoding:
self.paged_attention.shard(((dp, 1, mp), (1, 1, mp, 1), (1, 1, mp, 1), (dp, 1), (dp,), (dp, 1), (1,)))
else:
self.paged_attention.shard(((dp, 1, mp), (1, 1, mp, 1), (1, 1, mp, 1), (dp, 1), (dp,)))
self.paged_attention_with_alibi.shard(((dp, 1, mp), (1, 1, mp, 1), (1, 1, mp, 1), (dp, 1), (dp,),
(dp, mp, 1, 1)))
@@ -38,15 +38,16 @@ except ImportError:
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr
from mindspore.ops.auto_generate import MoeFinalizeRouting
from mindspore.nn.cell import Cell
from mindspore.nn.layer import Dense
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation

from mindformers.modules.transformer.op_parallel_config import default_moeparallel_config, MoEParallelConfig
from mindformers.version_control import check_valid_moefinalizerouting_op
from mindformers.modules.transformer.moe_utils import ZLoss
from mindformers.tools.utils import get_predict_run_mode
from mindformers.version_control import check_valid_moe_big_kernel

__all__ = [
"MoEConfig"]

@@ -1793,6 +1794,7 @@ class TopkRouterV2(Cell):
return comm_load_loss


class MoEInfer(Cell):
r"""
MoEInfer. Routing each tokens to the topk expert and calculating the final output.

@@ -1816,10 +1818,11 @@ class MoEInfer(Cell):
parallel_config):
super(MoEInfer, self).__init__()
self.hidden_size = dim
self.expert_dim = moe_config.expert_num
self.topk_norm_prob = moe_config.norm_topk_prob
self.expert_num = moe_config.expert_num
self.norm_topk_prob = moe_config.norm_topk_prob
self.num_experts_chosen = moe_config.num_experts_chosen
self.moe_config = moe_config
self.router_dense_type = moe_config.router_dense_type
self.routed_scaling_factor = moe_config.routed_scaling_factor

self.ffn = ffn
self.router = Router(d_model=self.hidden_size, moe_config=moe_config, routing_policy=None,

@@ -1831,23 +1834,58 @@ class MoEInfer(Cell):
self.reshape = P.Reshape()
self.shape = P.Shape()
self.cast = P.Cast()
self.mod = P.Mod().shard(((1,), ()))
self.topk = P.TopK().shard(((1, 1),))
self.softmax = P.Softmax().shard(((1, 1),))
self.expand_dims = P.ExpandDims().shard(((1,),))
self.transpose_2d = P.Transpose().shard(((1, 1),))
self.sort = P.Sort().shard(((1,),))
self.gather = P.Gather().shard(((1, 1), (1,)))
self.onehot = P.OneHot().shard(((1, 1), (), ()))
self.cumsum = P.CumSum(exclusive=False).shard(((1,),))
self.expand_dims = P.ExpandDims()
self.moe_finalize_routing = MoeFinalizeRouting()
self.use_moe_big_kernel = check_valid_moe_big_kernel()
if self.use_moe_big_kernel:
# fused op
from mindspore.ops.auto_generate import MoeGatingTopKSoftmax, MoeInitRouting, MoeComputeExpertTokens
self.moe_init_routing = MoeInitRouting()
self.moe_compute_expert_tokens = MoeComputeExpertTokens()
self.moe_gating_topk_softmax = MoeGatingTopKSoftmax()
else:
self.mod = ops.auto_generate.RemainderTensorScalar()
self.topk = P.TopK()
self.softmax = P.Softmax()
self.expand_dims = P.ExpandDims()
self.transpose_2d = P.Transpose()
self.sort = P.Sort()
self.gather = P.Gather()
self.onehot = P.OneHot()
self.cumsum = P.CumSum(exclusive=False)

self.on_value = Tensor(1.0, dtype=mstype.float32)
self.off_value = Tensor(0.0, dtype=mstype.float32)
if check_valid_moefinalizerouting_op():
from mindspore.ops.auto_generate import MoeFinalizeRouting
self.moe_finalize_routing = MoeFinalizeRouting().shard(((1, 1), (1, 1), (1, 1), (1, 1), (1,), (1, 1)))
self.on_value = Tensor(1.0, dtype=mstype.float32)
self.off_value = Tensor(0.0, dtype=mstype.float32)

def tensor_sort(self, input_tensor, expert_ids):
def tensor_sort(self, input_tensor, expert_index, row_index):
"""dispatch and get unsort map for routing"""
expanded_x, expanded_row_idx, expanded_expert_idx = \
self.moe_init_routing(input_tensor, row_index, expert_index, self.shape(input_tensor)[0])

expert_tokens = self.moe_compute_expert_tokens(expanded_expert_idx, self.expert_num)
return expanded_x, expert_tokens, expanded_row_idx

def tensor_moe_finalize_routing(self, input_tensor, expert_weight, expert_index, unsort_map):
"""calculate the final output by multiplying FeedForward's output and experts' weight in MoeFinalizeRouting"""
input_shape = input_tensor.shape
x1 = self.zeros((input_shape[0] // self.num_experts_chosen, input_shape[-1]), input_tensor.dtype)
x2 = None
bias = self.zeros((self.expert_num, input_shape[-1]), input_tensor.dtype)
expert_weight = self.cast(expert_weight, input_tensor.dtype)
output_tensor = self.moe_finalize_routing(input_tensor, x1, x2, bias, expert_weight, unsort_map, expert_index)
return output_tensor

def gating_topk_softmax(self, input_tensor):
"""calculate the expert value and expert index in MoeGatingTopKSoftmax"""
# (N, E)
gating_logits = self.gating(input_tensor.astype(self.router_dense_type))
# (N, num_experts_chosen), (N, num_experts_chosen), (N, num_experts_chosen)
expert_val, expert_index, row_index = self.moe_gating_topk_softmax(gating_logits, finished=None,
k=self.num_experts_chosen)

return expert_val, expert_index, row_index

def tensor_sort_self_define(self, input_tensor, expert_ids):
"""dispatch and get unsort map for routing"""
expert_shape = expert_ids.shape
transposed_index = self.transpose_2d(expert_ids, (1, 0)) # (N, k) -> (k, N)
@@ -1856,7 +1894,7 @@ class MoEInfer(Cell):

inter_map = self.mod(sort_map, expert_shape[0])
output_tensor = self.gather(input_tensor, inter_map, 0)
expert_mask = self.onehot(reshaped_index, self.expert_dim, self.on_value, self.off_value)
expert_mask = self.onehot(reshaped_index, self.expert_num, self.on_value, self.off_value)
expert_cnt = ops.sum(expert_mask, 0)
group_list = self.cast(self.cumsum(expert_cnt, 0), mstype.int64)

@@ -1864,43 +1902,90 @@ class MoEInfer(Cell):
unsort_map = self.cast(unsort_map, mstype.int32)
return output_tensor, group_list, unsort_map

def tensor_moe_finalize_routing(self, input_tensor, expert_weight, expert_index, unsort_map):
"""calculate the final output by multiplying FeedForward's output and experts' weight in MoeFinalizeRouting"""
input_shape = input_tensor.shape # (2N, h)
x1 = self.zeros((input_shape[0] // self.num_experts_chosen, input_shape[-1]), input_tensor.dtype)
x2 = None
bias = self.zeros((self.expert_dim, input_shape[-1]), input_tensor.dtype)
expert_weight = self.cast(expert_weight, input_tensor.dtype)
output_tensor = self.moe_finalize_routing(input_tensor, x1, x2, bias, expert_weight, unsort_map, expert_index)
return output_tensor

def construct(self, input_tensor):
"""forward process"""
input_tensor_shape = self.shape(input_tensor) # (B, S, H)
def moe_compute_self_define(self, input_tensor):
"""moe compute self define"""
input_dtype = input_tensor.dtype
input_tensor = self.reshape(
input_tensor, (-1, self.hidden_size)) # (bs, seq/1, h) -> (bs*seq, h) : use N replace bs*seq

# gating + topk + softmax
gating_logits = self.gating(input_tensor.astype(mstype.float32)) # (N, h) * (h, E) -> (bs*seq, E)
routing_weights = self.softmax(gating_logits.astype(mstype.float32)) # (N, E) -> (N, E)
expert_val, expert_index = self.topk(routing_weights, self.num_experts_chosen) # (N, E) -> (N, 2), (N, 2)

if self.moe_config.norm_topk_prob and self.num_experts_chosen > 1:
if self.norm_topk_prob and self.num_experts_chosen > 1:
expert_val = self.cast(expert_val, mstype.float32)
expert_weight = expert_val / (self.expand_dims(ops.sum(expert_val, -1), -1) + 1e-9)
else:
expert_weight = ops.mul(self.moe_config.routed_scaling_factor, expert_val)
expert_weight = ops.mul(self.routed_scaling_factor, expert_val)

expert_weight = self.cast(expert_weight, input_dtype)

sorted_input_tensor, group_list, unsort_map = self.tensor_sort(input_tensor, expert_index)
sorted_input_tensor, group_list, unsort_map = self.tensor_sort_self_define(input_tensor, expert_index)
return sorted_input_tensor, group_list, unsort_map, expert_index, expert_weight

# moeffn
expert_output = self.ffn(sorted_input_tensor, group_list) # (N, h) (N, 2) -> (N, 2, h)
def moe_compute(self, input_tensor):
"""moe compute with big kernel"""
input_dtype = input_tensor.dtype
# (N, num_experts_chosen), (N, num_experts_chosen), (N, num_experts_chosen)
expert_val, expert_index, row_index = self.gating_topk_softmax(input_tensor)

if self.norm_topk_prob and self.num_experts_chosen > 1:
expert_val = self.cast(expert_val, mstype.float32)
expert_weight = expert_val / (self.expand_dims(ops.sum(expert_val, -1), -1) + 1e-9)
else:
expert_weight = ops.mul(self.routed_scaling_factor, expert_val)

# (N, num_experts_chosen)
expert_weight = self.cast(expert_weight, input_dtype)

# MOE recompute
# (N * num_experts_chosen, h), (N * num_experts_chosen,), (N * num_experts_chosen,)
sorted_input_tensor, group_list, unsort_map = self.tensor_sort(input_tensor, expert_index, row_index)

group_list = group_list.astype(mstype.int64)
return sorted_input_tensor, group_list, unsort_map, expert_index, expert_weight

def construct(self, input_tensor):
"""forward process"""
# (B, S, H)
input_tensor_shape = self.shape(input_tensor)
# (bs, seq/1, h) -> (bs*seq, h) : use N replace bs*seq
input_tensor = self.reshape(input_tensor, (-1, self.hidden_size))

if self.use_moe_big_kernel:
sorted_input_tensor, group_list, unsort_map, expert_index, expert_weight = self.moe_compute(input_tensor)
else:
sorted_input_tensor, group_list, unsort_map, expert_index, expert_weight =\
self.moe_compute_self_define(input_tensor)

# MOEFFN
# (N * num_experts_chosen, h)
expert_output = self.ffn(sorted_input_tensor, group_list)

# MOE revert recompute
# (N, h)
moe_output = self.tensor_moe_finalize_routing(
expert_output, expert_weight, expert_index, unsort_map) # -> (N, h)
expert_output, expert_weight, expert_index, unsort_map)

output_tensor = self.reshape(moe_output, input_tensor_shape) # (N, h) -> (bs, seq, h)
# (N, h) -> (bs, seq, h)
output_tensor = self.reshape(moe_output, input_tensor_shape)
return output_tensor

def shard(self, parallel_config):
"""sharding for moe infer cell"""
self.ffn.shard(parallel_config)
self.expand_dims.shard(((1,),))
self.moe_finalize_routing.shard(((1, 1), (1, 1), (1, 1), (1, 1), (1,), (1, 1)))

if self.use_moe_big_kernel:
self.moe_init_routing.shard(((1, 1), (1, 1), (1, 1)))
self.moe_compute_expert_tokens.shard(((1,),))
self.moe_gating_topk_softmax.shard(((1, 1),))
else:
self.mod.shard(((1,), ()))
self.topk.shard(((1, 1),))
self.softmax.shard(((1, 1),))
self.expand_dims.shard(((1,),))
self.transpose_2d.shard(((1, 1),))
self.sort.shard(((1,),))
self.gather.shard(((1, 1), (1,)))
self.onehot.shard(((1, 1), (), ()))
self.cumsum.shard(((1,),))
@@ -62,7 +62,7 @@ from mindformers.version_control import get_dropout, choose_flash_attention_dtyp

from mindformers.tools.logger import _LogActionOnce
from mindformers.tools.logger import logger as log
from mindformers.tools.utils import is_pynative
from mindformers.tools.utils import is_pynative, get_infer_boost

__all__ = [
    "AttentionMask",

@@ -1007,7 +1007,7 @@ class LowerTriangularMaskWithDynamic(Cell):
        if use_past and chunk_prefill:
            self.lower_triangle_mask = Tensor(np.tril(np.ones(shape=(seq_length, seq_length), dtype=np.int8)), \
                                              dtype=compute_type)
        elif use_past and not self.is_pynative:
        elif use_past and get_infer_boost() and not self.is_pynative:
            if self.is_dynamic:
                mask_coeff = 1.0 if compute_type is mstype.bfloat16 else -10000.0
                self.lower_triangle_mask = Tensor(

@@ -479,6 +479,11 @@ def get_predict_run_mode():
    return run_mode == "predict"


def get_infer_boost():
    infer_boost = context.get_jit_config().get("infer_boost")
    return infer_boost == "on"


def is_main_rank(ignore_check_modelarts=False):
    return not get_real_rank() or \
        ((ignore_check_modelarts or check_in_modelarts()) and get_real_rank() % get_device_num_per_node() == 0)
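get_infer_boost only reads back the jit-config switch; a minimal sketch of the caller side, mirroring how the tests later in this diff set it (assuming the same context module used by the function above):

# --- illustration only, not part of the diff ---
import mindspore as ms
from mindspore import context

# Enable the inference-boost path; get_infer_boost() then returns True.
ms.set_context(mode=ms.GRAPH_MODE, jit_config={"jit_level": "O0", "infer_boost": "on"})
assert context.get_jit_config().get("infer_boost") == "on"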
@@ -295,12 +295,12 @@ def check_valid_gmm_op():
    return True


def check_valid_moefinalizerouting_op():
    """check mindspore version is valid for groupedmatmul"""
    version_valid = is_version_ge(ms.__version__, "2.3.0")
def check_valid_moe_big_kernel():
    """check mindspore version is valid for moe big kernel"""
    version_valid = is_version_ge(ms.__version__, "2.5.0")
    if not version_valid:
        logger.warning(f"Current MindSpore do not support MoeFinalizeRouting, "
                       f"please upgrade to 2.3.0 or later version.")
        logger.warning(f"Current MindSpore do not support Moe big kernel, "
                       f"please upgrade to 2.5.0 or later version.")
        return False
    return True
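The renamed check gates the big-kernel path selected by use_moe_big_kernel in the MoE inference cell above; its call site is not part of this hunk, so the line below is only an assumed usage.

# --- assumed wiring, not shown in this diff ---
self.use_moe_big_kernel = check_valid_moe_big_kernel()  # False on MindSpore < 2.5.0, so construct()
                                                        # falls back to moe_compute_self_define()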
@@ -221,8 +221,8 @@ class Baichuan13BV2ForCausalLM(Baichuan2PreTrainedModel):
        return loss

    def kvcache(self, layer_idx):
        key_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.key_cache
        value_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.value_cache
        key_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.key_cache
        value_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.value_cache
        return key_cache, value_cache
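The same kvcache(layer_idx) accessor is updated in every model below; only the owning attribute changes from paged_attention_mgr to kv_cache_mgr. A minimal sketch of how a serving layer might collect the per-layer caches (the helper itself is an assumption, not code from this diff):

# --- hypothetical helper, works with any model exposing kvcache(layer_idx) ---
def collect_kv_caches(model, num_layers):
    key_caches, value_caches = [], []
    for layer_idx in range(num_layers):
        key_cache, value_cache = model.kvcache(layer_idx)  # per-layer cache Parameters
        key_caches.append(key_cache)
        value_caches.append(value_cache)
    return key_caches, value_caches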
@@ -453,3 +453,8 @@ class Baichuan7BV2ForCausalLM(Baichuan2PreTrainedModel):
        input_mask = self.reshape(input_mask, (-1,))
        loss = self.loss(logits, labels, input_mask)
        return loss

    def kvcache(self, layer_idx):
        key_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.key_cache
        value_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.value_cache
        return key_cache, value_cache

@@ -193,7 +193,7 @@ class DeepSeekV2MoEInfer(Cell):

        self.add.shard(((dp, 1, 1), (dp, 1, 1)))

        self.routed_experts.ffn.shard(parallel_config)
        self.routed_experts.shard(parallel_config)
        self.shared_experts.shard(parallel_config)
        self.shared_experts.mul.shard(((dp, 1, mp), (dp, 1, mp)))

@@ -204,8 +204,8 @@ class QwenForCausalLM(QwenPreTrainedModel):
        self.lm_head.shard(strategy_matmul=((1, 1), (dp * mp, 1)))

    def kvcache(self, layer_idx):
        key_cache = self.transformer.layers[layer_idx].attention.infer_attention.paged_attention_mgr.key_cache
        value_cache = self.transformer.layers[layer_idx].attention.infer_attention.paged_attention_mgr.value_cache
        key_cache = self.transformer.layers[layer_idx].attention.infer_attention.kv_cache_mgr.key_cache
        value_cache = self.transformer.layers[layer_idx].attention.infer_attention.kv_cache_mgr.value_cache
        return key_cache, value_cache

@@ -286,8 +286,8 @@ class QwenForCausalLM(QwenPreTrainedModel):
        self.lm_head.shard(strategy_matmul=((1, 1), (dp * mp, 1)))

    def kvcache(self, layer_idx):
        key_cache = self.transformer.layers[layer_idx].attention.infer_attention.paged_attention_mgr.key_cache
        value_cache = self.transformer.layers[layer_idx].attention.infer_attention.paged_attention_mgr.value_cache
        key_cache = self.transformer.layers[layer_idx].attention.infer_attention.kv_cache_mgr.key_cache
        value_cache = self.transformer.layers[layer_idx].attention.infer_attention.kv_cache_mgr.value_cache
        return key_cache, value_cache

@@ -465,6 +465,6 @@ class TelechatForCausalLM(TelechatPreTrainedModel):
        return loss

    def kvcache(self, layer_idx):
        key_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.key_cache
        value_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.value_cache
        key_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.key_cache
        value_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.value_cache
        return key_cache, value_cache

@@ -142,7 +142,6 @@ class ParallelTelechatForCausalLM(TelechatPreTrainedModel):
        for layer in self.model.layers:
            layer.add_flags(is_first_iteration=is_first_iteration)
            layer.attention.add_flags(is_first_iteration=is_first_iteration)
            layer.attention.paged_attention_mgr.add_flags(is_first_iteration=is_first_iteration)

    # pylint: disable=W0613
    def construct(self, input_ids, labels=None, input_position=None, position_ids=None, attention_mask=None,

@@ -171,6 +170,6 @@ class ParallelTelechatForCausalLM(TelechatPreTrainedModel):
        return logits, input_ids, input_mask

    def kvcache(self, layer_idx):
        key_cache = self.model.layers[layer_idx].attention.paged_attention_mgr.key_cache
        value_cache = self.model.layers[layer_idx].attention.paged_attention_mgr.value_cache
        key_cache = self.model.layers[layer_idx].attention.kv_cache_mgr.key_cache
        value_cache = self.model.layers[layer_idx].attention.kv_cache_mgr.value_cache
        return key_cache, value_cache

@@ -147,8 +147,7 @@ class TelechatParallelAttention(ParallelAttention):
        if self.is_first_iteration:
            key, value = self._cat_prefix(key, value, prefix_keys_values)

        key_out = self.paged_attention_mgr(key, value, slot_mapping)
        query = ops.depend(query, key_out)
        key_cache, value_cache = self.kv_cache_mgr(key, value, slot_mapping, batch_valid_length)

        if self.is_first_iteration:
            if self.use_flash_attention:

@@ -181,7 +180,8 @@ class TelechatParallelAttention(ParallelAttention):
                value = value.transpose((0, 2, 1, 3))
                context_layer = self.core_attention(query, key, value, attn_mask)
            else:
                context_layer = self.paged_attention_mgr.paged_attn(query, batch_valid_length, block_tables)
                context_layer = self.paged_attention(query, key_cache, value_cache, block_tables, batch_valid_length,
                                                     None, None, None, None)
        else:
            # [B, S, H] -> [B, N, S, D]
            query = query.reshape(bs, seq_len, -1, self.head_dim).transpose((0, 2, 1, 3))
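In the new call pattern, KVCacheMgr returns the full key and value cache tensors and the PagedAttention primitive consumes them directly, so the explicit ops.depend ordering hint on query is no longer needed. A condensed decode-path sketch, using only the calls and argument order that appear in this diff (the surrounding setup is assumed; the trailing None arguments are optional inputs left unset, as in the diff):

# --- condensed sketch of the updated decode path ---
key_cache, value_cache = self.kv_cache_mgr(key, value, slot_mapping, batch_valid_length)
if not self.is_first_iteration:
    # incremental step: attend over the block-wise cache instead of recomputing full attention
    context_layer = self.paged_attention(query, key_cache, value_cache, block_tables, batch_valid_length,
                                         None, None, None, None)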
@@ -444,6 +444,6 @@ class TelechatForCausalLM(TelechatPreTrainedModel):
        return loss

    def kvcache(self, layer_idx):
        key_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.key_cache
        value_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.value_cache
        key_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.key_cache
        value_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.value_cache
        return key_cache, value_cache

@@ -462,7 +462,7 @@ class YiZhaoWithPtuning2(YiZhaoForCausalLM):

    def kvcache(self, layer_idx):
        key_cache = \
            self.transformer.encoder.layers[layer_idx].self_attention.infer_attention.paged_attention_mgr.key_cache
            self.transformer.encoder.layers[layer_idx].self_attention.infer_attention.kv_cache_mgr.key_cache
        value_cache = \
            self.transformer.encoder.layers[layer_idx].self_attention.infer_attention.paged_attention_mgr.value_cache
            self.transformer.encoder.layers[layer_idx].self_attention.infer_attention.kv_cache_mgr.value_cache
        return key_cache, value_cache

@@ -501,6 +501,6 @@ class MixtralForCausalLM(LlamaPreTrainedModel):
        return loss

    def kvcache(self, layer_idx):
        key_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.key_cache
        value_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.value_cache
        key_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.key_cache
        value_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.value_cache
        return key_cache, value_cache

@@ -51,6 +51,7 @@ def test_qwen2_0_5b_predict_standalone():
    config.model.model_config.seq_length = seq_length
    config.processor.tokenizer.vocab_file = vocab_file_path
    config.processor.tokenizer.merges_file = merges_file_path
    config.context.device_id = int(os.environ.get("DEVICE_ID", "0"))

    # init context
    build_context(config)
tests/st/test_infer/test_infer_module/test_infer_normal.py  (new file, 96 lines)
@@ -0,0 +1,96 @@
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test infer attention"""
import math
import os
import pytest

import numpy as np
import mindspore as ms
from mindspore import dtype as mstype
from mindspore import Tensor
from mindformers.modules import KVCacheMgr
from mindformers.modules.infer_attention import InferAttention
from mindformers.modules.layers import FreqsMgr


@pytest.mark.level0
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_kv_cache_mgr():
    """
    Feature: Test the kv cache manager.
    Description: Test the forward
    Expectation: No exception
    """
    os.environ['ASCEND_HOME_PATH'] = "/usr/local/Ascend/latest"
    jit_level = "O0"
    infer_boost = "off"
    ms.set_context(mode=ms.GRAPH_MODE, jit_config={"jit_level": jit_level, "infer_boost": infer_boost})

    bsz, n_kv_head, seq_len, head_dim = 8, 16, 4096, 128
    compute_dtype = mstype.float16

    key = Tensor(np.ones((bsz, n_kv_head, seq_len, head_dim)), mstype.float16)
    value = Tensor(np.ones((bsz, n_kv_head, seq_len, head_dim)), mstype.float16)
    batch_valid_length = Tensor(np.zeros((bsz, 2)), mstype.int32)

    kv_shape = (bsz, n_kv_head, seq_len, head_dim)
    kv_cache_mgr = KVCacheMgr(n_kv_head,
                              head_dim,
                              batch_size=bsz,
                              seq_length=seq_len,
                              compute_dtype=compute_dtype)
    key_cache, value_cache = kv_cache_mgr(key, value, None, batch_valid_length)
    assert key_cache.shape == kv_shape and value_cache.shape == kv_shape


@pytest.mark.level0
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_infer_attention():
    """
    Feature: Test the infer attention.
    Description: Test the forward
    Expectation: No exception
    """
    os.environ['ASCEND_HOME_PATH'] = "/usr/local/Ascend/latest"
    jit_level = "O0"
    infer_boost = "off"
    ms.set_context(mode=ms.GRAPH_MODE, jit_config={"jit_level": jit_level, "infer_boost": infer_boost})

    bsz, head_num, seq_len, head_dim = 8, 80, 256, 128
    n_kv_head = 8
    hidden_size = head_num * head_dim
    query = Tensor(np.ones((bsz, seq_len, hidden_size)), mstype.float16)
    key = Tensor(np.ones((bsz, seq_len, n_kv_head * head_dim)), mstype.float16)
    value = Tensor(np.ones((bsz, seq_len, n_kv_head * head_dim)), mstype.float16)
    batch_valid_length = Tensor(np.zeros((bsz, 2)), mstype.int32)
    attn_mask = Tensor(np.ones((bsz, 1, seq_len, seq_len)), mstype.uint8)

    freqs_mgr = FreqsMgr(head_dim=head_dim, seq_length=seq_len, max_position_embedding=seq_len)
    freqs_cis = freqs_mgr(seq_len)
    infer_attention = InferAttention(head_num,
                                     head_dim,
                                     n_kv_head,
                                     scale_value=1. / math.sqrt(head_dim),
                                     pre_tokens=65536,
                                     next_tokens=0,
                                     batch_size=bsz,
                                     seq_length=seq_len,
                                     compute_dtype=mstype.float16)

    output = infer_attention(query, key, value, batch_valid_length, None, None, freqs_cis, attn_mask)
    assert output.shape == (bsz, seq_len, hidden_size)
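One plausible way to run the new test cases locally, expressed through pytest's Python entry point (the invocation and the Ascend 910B host environment are assumptions; the markers come from the file above):

# --- assumed local invocation, not part of the diff ---
import pytest

pytest.main(["-m", "level0 and env_onecard",
             "tests/st/test_infer/test_infer_module/test_infer_normal.py"])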
@@ -1,62 +0,0 @@
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test infer attention"""
import os
import pytest

import numpy as np
import mindspore as ms
from mindspore import dtype as mstype
from mindspore import Tensor
from mindformers.modules import PagedAttentionMgr


@pytest.mark.level0
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_paged_attention_mgr():
    """
    Feature: Test the paged attention.
    Description: Test the forward
    Expectation: No exception
    """
    os.environ['ASCEND_HOME_PATH'] = "/usr/local/Ascend/latest"
    jit_level = "O0"
    infer_boost = "on"
    ms.set_context(jit_config={"jit_level": jit_level, "infer_boost": infer_boost})

    bsz, head_num, seq_len, head_dim = 1, 16, 4096, 128
    n_kv_head = 16
    block_size = 1024
    num_blocks = 16
    compute_dtype = mstype.float16
    hidden_size = head_num * head_dim
    batch_valid_length = Tensor(np.ones((bsz, 1)), mstype.int32)
    block_tables = Tensor(np.ones((bsz, num_blocks)), mstype.int64)
    query = Tensor(np.ones((bsz, seq_len, hidden_size)), mstype.float16)
    key = Tensor(np.ones((bsz, seq_len, hidden_size)), mstype.float16)
    value = Tensor(np.ones((bsz, seq_len, hidden_size)), mstype.float16)
    slot_mapping = Tensor(np.ones((bsz * seq_len)), mstype.int32)

    kv_shape = (num_blocks, block_size, n_kv_head, head_dim)
    paged_attention_mgr = PagedAttentionMgr(head_num,
                                            head_dim,
                                            n_kv_head,
                                            kv_shape,
                                            compute_dtype=compute_dtype)
    paged_attention_mgr(key, value, slot_mapping)

    context_layer = paged_attention_mgr.paged_attn(query, batch_valid_length, block_tables)
    assert context_layer.shape == (1, 4096, 2048)