Compare commits

...

4 Commits

Author SHA1 Message Date
tan-wei-cheng
a1eb2ec173 !5068 infer support infer boost off mode with kvcache and aclnn op
Merge pull request !5068 from tan-wei-cheng/develop-twc-dev4
2025-01-14 07:00:16 +00:00
twc
8cb628a9ab infer support infer boost off mode with kvcache and aclnn op 2025-01-14 14:56:37 +08:00
tan-wei-cheng
562584dc0e !5019 MOE infer realize fusion op
Merge pull request !5019 from tan-wei-cheng/develop-twc-dev2
2025-01-14 06:48:45 +00:00
twc
75078d33d5 MOE infer realize fusion op 2025-01-13 15:49:58 +08:00
30 changed files with 503 additions and 572 deletions

View File

@@ -20,6 +20,7 @@ import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Parameter, Tensor, mint, nn, ops
from mindspore.common.initializer import initializer
from mindspore.ops.auto_generate import PagedAttention
from mindformers.experimental.parallel_core.pynative.transformer.scale_mask_softmax import ScaleMaskSoftmax
from mindformers.experimental.parallel_core.pynative.utils import divide
@@ -29,7 +30,7 @@ from mindformers.experimental.infer.core.utils import get_tp_world_size, create_
from mindformers.modules.flash_attention import FlashAttention
from mindformers.modules.infer_attention import InferRotaryEmbedding
from mindformers.modules.layers import FreqsMgr, RotaryEmbedding
from mindformers.modules.paged_attention_mgr import PagedAttentionMgr
from mindformers.modules.kv_cache_mgr import KVCacheMgr
from mindformers.modules.transformer import LowerTriangularMaskWithDynamic
from mindformers.tools.utils import is_pynative
@@ -382,24 +383,14 @@ class ParallelAttention(nn.Cell):
if self.use_past:
kv_shape = (self.config.num_blocks, self.config.block_size, self.kv_num_heads_per_partition, self.head_dim)
self.paged_attention_mgr = PagedAttentionMgr(self.num_heads_per_partition,
self.head_dim,
self.kv_num_heads_per_partition,
kv_shape,
config.seq_length,
compute_dtype=self.compute_dtype)
self.paged_attention_mgr.key_cache = create_empty_parameter(
shape=self.paged_attention_mgr.key_cache.shape,
dtype=self.paged_attention_mgr.key_cache.dtype,
name=self.paged_attention_mgr.key_cache.name,
requires_grad=self.paged_attention_mgr.key_cache.requires_grad,
)
self.paged_attention_mgr.value_cache = create_empty_parameter(
shape=self.paged_attention_mgr.value_cache.shape,
dtype=self.paged_attention_mgr.value_cache.dtype,
name=self.paged_attention_mgr.value_cache.name,
requires_grad=self.paged_attention_mgr.value_cache.requires_grad,
)
self.kv_cache_mgr = KVCacheMgr(self.kv_num_heads_per_partition,
self.head_dim,
num_blocks=self.config.num_blocks,
block_size=self.config.block_size,
compute_dtype=self.compute_dtype)
self.paged_attention = PagedAttention(self.num_heads_per_partition,
1.0 / self.norm_factor,
self.kv_num_heads_per_partition)
self.rotary_embedding = InferRotaryEmbedding(rotary_cos_format=2)
else:
self.apply_rotary_emb = RotaryEmbedding(self.head_dim, config.rotary_dtype)
@@ -461,8 +452,7 @@ class ParallelAttention(nn.Cell):
if self.is_first_iteration:
key, value = self._cat_prefix(key, value, prefix_keys_values)
key_out = self.paged_attention_mgr(key, value, slot_mapping, batch_valid_length)
query = ops.depend(query, key_out)
key_cache, value_cache = self.kv_cache_mgr(key, value, slot_mapping, batch_valid_length)
if self.is_first_iteration:
if self.use_flash_attention:
@@ -496,9 +486,8 @@ class ParallelAttention(nn.Cell):
context_layer = context_layer.transpose(0, 2, 1, 3).reshape(
bs, seq_len, self.hidden_size_per_partition)
else:
context_layer = self.paged_attention_mgr.paged_attn(query, batch_valid_length, block_tables)
# [B, S, N, D]
context_layer = self.paged_attention(query, key_cache, value_cache, block_tables, batch_valid_length,
None, None, None, None)
else:
# [B, S, N, D] --> [B, N, S, D]
query = query.transpose(0, 2, 1, 3)

View File

@@ -133,7 +133,6 @@ class ParallelLlamaForCausalLM(LlamaPreTrainedModel):
for layer in self.model.layers:
layer.add_flags(is_first_iteration=is_first_iteration)
layer.attention.add_flags(is_first_iteration=is_first_iteration)
layer.attention.paged_attention_mgr.add_flags(is_first_iteration=is_first_iteration)
# pylint: disable=W0613
def construct(self, input_ids, labels=None, input_position=None, position_ids=None, attention_mask=None,
@@ -164,8 +163,8 @@ class ParallelLlamaForCausalLM(LlamaPreTrainedModel):
return logits, input_ids, input_mask
def kvcache(self, layer_idx):
key_cache = self.model.layers[layer_idx].attention.paged_attention_mgr.key_cache
value_cache = self.model.layers[layer_idx].attention.paged_attention_mgr.value_cache
key_cache = self.model.layers[layer_idx].attention.kv_cache_mgr.key_cache
value_cache = self.model.layers[layer_idx].attention.kv_cache_mgr.value_cache
return key_cache, value_cache
@classmethod
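Reviewer note, not part of the diff: every kvcache(layer_idx) accessor touched in this patch keeps the same (key_cache, value_cache) return contract; only the attribute it reads moves from paged_attention_mgr to kv_cache_mgr. A minimal caller sketch under assumed names ("network" and "config.num_layers" are hypothetical, not from the diff):

num_layers = network.config.num_layers                 # assumed attribute, for illustration only
per_layer_caches = [network.kvcache(i) for i in range(num_layers)]
key_cache_0, value_cache_0 = per_layer_caches[0]       # each entry is a (Parameter, Parameter) pair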

View File

@@ -504,6 +504,6 @@ class CogVLM2VideoLM(LlamaPreTrainedModel):
def kvcache(self, layer_idx):
"""Get kvcache with input layer index."""
key_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.key_cache
value_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.value_cache
key_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.key_cache
value_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.value_cache
return key_cache, value_cache

View File

@@ -891,8 +891,8 @@ class LlamaForCausalLMForCogVLM2Image(LlamaPreTrainedModel):
"""Get kvcache with input layer index."""
key_cache = self.model.layers[
layer_idx
].self_attn.infer_attention.paged_attention_mgr.key_cache
].self_attn.infer_attention.kv_cache_mgr.key_cache
value_cache = self.model.layers[
layer_idx
].self_attn.infer_attention.paged_attention_mgr.value_cache
].self_attn.infer_attention.kv_cache_mgr.value_cache
return key_cache, value_cache

View File

@@ -478,7 +478,7 @@ class ChatGLM2WithPtuning2(ChatGLM2ForConditionalGeneration):
def kvcache(self, layer_idx):
key_cache = \
self.transformer.encoder.layers[layer_idx].self_attention.infer_attention.paged_attention_mgr.key_cache
self.transformer.encoder.layers[layer_idx].self_attention.infer_attention.kv_cache_mgr.key_cache
value_cache = \
self.transformer.encoder.layers[layer_idx].self_attention.infer_attention.paged_attention_mgr.value_cache
self.transformer.encoder.layers[layer_idx].self_attention.infer_attention.kv_cache_mgr.value_cache
return key_cache, value_cache

View File

@@ -35,7 +35,7 @@ from mindformers.modules.layers import Linear, FreqsMgr
from mindformers.modules.transformer import LowerTriangularMaskWithDynamic
from mindformers.modules.transformer.op_parallel_config import _check_config
from mindformers.tools.register.register import MindFormerModuleType, MindFormerRegister
from mindformers.tools.utils import get_predict_run_mode
from mindformers.tools.utils import get_predict_run_mode, get_infer_boost
from .llama_config import LlamaConfig
from .llama_layer import LlamaEmbedding, LlamaRMSNorm
@@ -95,7 +95,9 @@ class LlamaModel(LlamaPreTrainedModel):
self.shape = P.Shape()
self.reshape = P.Reshape()
self.rmsnorm_compute_2d = config.rmsnorm_compute_2d
self.enable_infer_boost = get_infer_boost()
if self.use_past and not self.enable_infer_boost:
self.range = Tensor(np.arange(config.seq_length).reshape((1, 1, -1)), mstype.int32)
if config.moe_config.expert_num > 1:
logger.info("MoE config is provided, use MoE FFN")
else:
@@ -280,6 +282,34 @@ class LlamaModel(LlamaPreTrainedModel):
else:
self.norm_out.shard((dp, cp, 1))
def gen_infer_freqs_and_mask(self, batch_size, seq_len, tokens, batch_valid_length, prefix_keys_values):
"""generate infer rope freqs and attention mask."""
# for infer boost off mode or O2 mode
if not self.enable_infer_boost:
if self.is_first_iteration:
freqs_cis = self.freqs_mgr(seq_len)
mask = self.casual_mask(tokens) # mask: [bs, seq, seq]
else:
freqs_cis = self.freqs_mgr.increment(batch_valid_length)
mask = self.casual_mask.increment(self.range, batch_valid_length)
return freqs_cis, mask
# for O0 + infer boost on mode
mask = None
if self.is_first_iteration:
freqs_cis = self.freqs_mgr.prefill(batch_size, seq_len)
mask = self.casual_mask.prefill()
if prefix_keys_values is not None:
if mask is None:
mask = self.casual_mask(tokens)
prefix_length = prefix_keys_values[0].shape[2]
prefix_mask = Tensor(np.zeros((batch_size, 1, seq_len, prefix_length)), dtype=mask.dtype)
mask = self.concat((prefix_mask, mask))
else:
freqs_cis = self.freqs_mgr.increment(batch_valid_length)
return freqs_cis, mask
# pylint: disable=W0613
def construct(self, tokens: Tensor, input_embeds=None, batch_valid_length=None, batch_index=None,
zactivate_len=None, block_tables=None, slot_mapping=None, prefix_keys_values=None,
@@ -325,17 +355,8 @@ class LlamaModel(LlamaPreTrainedModel):
else:
mask = None
if self.use_past:
if self.is_first_iteration:
freqs_cis = self.freqs_mgr.prefill(bs, seq_len)
mask = self.casual_mask.prefill()
if prefix_keys_values is not None:
if mask is None:
mask = self.casual_mask(tokens)
prefix_length = prefix_keys_values[0].shape[2]
prefix_mask = Tensor(np.zeros((bs, 1, seq_len, prefix_length)), dtype=mask.dtype)
mask = self.concat((prefix_mask, mask))
else:
freqs_cis = self.freqs_mgr.increment(batch_valid_length)
freqs_cis, mask = self.gen_infer_freqs_and_mask(bs, seq_len, tokens, batch_valid_length,
prefix_keys_values)
else:
if self.seq_pipe:
mask = self.casual_mask(tokens, seq_chunk=self.seq_chunk)
@@ -359,7 +380,7 @@ class LlamaModel(LlamaPreTrainedModel):
else:
h = self.cast(self.tok_embeddings(tokens), self.dtype)
if not rmsnorm_compute_2d:
h = self.reshape(h, (bs, seq_len, self.hidden_size)) # h: [bs, seq/1, hidden_dim]
h = self.reshape(h, (bs, seq_len, self.hidden_size)) # h: [bs, seq/1, hidden_dim]
for i in range(self.num_layers):
prefix_kv = prefix_keys_values[i] if prefix_keys_values is not None else None
h = self.layers[i](h, freqs_cis, mask, batch_valid_length=batch_valid_length, block_tables=block_tables,
@@ -580,7 +601,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
for layer in self.model.layers:
layer.add_flags(is_first_iteration=is_first_iteration)
layer.attention.infer_attention.add_flags(is_first_iteration=is_first_iteration)
layer.attention.infer_attention.paged_attention_mgr.add_flags(is_first_iteration=is_first_iteration)
def pre_gather_func(self, pre_gather, output, batch_valid_length, gather_index=None):
"""Pre gather operation in infer mode."""
@@ -655,8 +675,8 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
return loss
def kvcache(self, layer_idx):
key_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.key_cache
value_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.value_cache
key_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.key_cache
value_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.value_cache
return key_cache, value_cache
@classmethod

View File

@@ -664,7 +664,7 @@ class LlamaMoeInferFeedForward(Cell):
self.w3.shard(strategy_matmul=(((1, 1),), ((1, 1, mp),), ((),), ((),), ((),), ((),), ((),), (1,)))
self.w2.shard(strategy_matmul=(((1, mp),), ((1, mp, 1),), ((),), ((),), ((),), ((),), ((),), (1,)))
else:
self.mul.shard((1, mp), (1, mp))
self.mul.shard(((1, mp), (1, mp)))
self.w1.shard(strategy_matmul=(((1, 1),), ((1, 1, mp),), ((),), ((),), ((),), ((),), ((),), (1,)),
strategy_activation=((1, 1, mp, 1),))
self.w3.shard(strategy_matmul=(((1, 1),), ((1, 1, mp),), ((),), ((),), ((),), ((),), ((),), (1,)))
@@ -771,8 +771,10 @@ class LlamaFeedForwardWithMoE(Cell):
self.mul.shard(((dp, 1, 1), (dp, 1, 1)))
self.add.shard(((dp, 1, 1), (dp, 1, 1)))
self.sigmoid.shard(((dp, 1, 1),))
self.routed_experts.ffn.shard(parallel_config)
if self.use_moe_infer:
self.routed_experts.shard(parallel_config)
else:
self.routed_experts.ffn.shard(parallel_config)
self.shared_experts.shard(parallel_config)
self.shared_experts.mul.shard(((dp, 1, mp), (dp, 1, mp)))

View File

@@ -41,6 +41,6 @@ from .layers import (
RotaryEmbedding
)
from .local_block_sparse_attention import LocalBlockSparseAttention
from .paged_attention_mgr import PagedAttentionMgr
from .kv_cache_mgr import KVCacheMgr
__all__ = []

View File

@@ -20,9 +20,11 @@ from mindspore import ops
from mindspore.common.tensor import Tensor
from mindspore.nn.cell import Cell
from mindspore.ops import operations as P
from mindspore.ops.auto_generate import PagedAttention
from mindformers.modules import PagedAttentionMgr
from mindformers.modules import KVCacheMgr
from mindformers.modules.flash_attention import FlashAttention
from mindformers.tools.utils import get_infer_boost
class InferRotaryEmbedding(Cell):
@@ -59,6 +61,22 @@ class InferRotaryEmbedding(Cell):
self.rotary_embedding_op.shard(((dp, 1, mp), (dp, 1, mp), (1, 1), (1, 1), (dp,)))
class AttentionInput:
"""Infer Attention Input."""
def __init__(self, query, key, value, batch_valid_length, block_tables, slot_mapping):
self.query = query
self.key = key
self.value = value
self.batch_valid_length = batch_valid_length
self.block_tables = block_tables
self.slot_mapping = slot_mapping
def __repr__(self):
return f"AttentionInput(query={self.query}, key={self.key}, value={self.value}," \
f"batch_valid_length={self.batch_valid_length}, block_tables={self.block_tables}," \
f"slot_mapping={self.slot_mapping})"
class InferAttention(Cell):
"""Infer Attention Layer.
@@ -215,6 +233,7 @@ class InferAttention(Cell):
sparse_mode=0,
block_size=16,
num_blocks=1024,
batch_size=32,
seq_length=-1,
is_dynamic=True,
use_flash_attention=True,
@@ -238,6 +257,8 @@ class InferAttention(Cell):
self.sparse_mode = sparse_mode
self.block_size = block_size
self.num_blocks = num_blocks
self.batch_size = batch_size
self.seq_length = seq_length
self.use_flash_attention = use_flash_attention
self.use_alibi_mask = use_alibi_mask
self.use_rope_rotary_emb = use_rope_rotary_emb
@@ -259,16 +280,20 @@ class InferAttention(Cell):
self.inv_norm_factor = Tensor(1.0 / math.sqrt(self.head_dim), dtype=compute_dtype)
self.not_equal = P.NotEqual()
self.n_rep = self.n_head // self.n_kv_head
if self.use_alibi_mask:
self.add_alibi = P.Add()
self.use_attention_mask = True
self.enable_infer_boost = get_infer_boost()
self.is_dynamic = is_dynamic
if self.is_dynamic:
self.input_layout = "TH"
self.use_attention_mask = True
self.parallel_decoding = parallel_decoding
if self.enable_infer_boost:
if self.is_dynamic:
self.input_layout = "TH"
self.use_attention_mask = True
else:
self.input_layout = "BSH"
self.use_attention_mask = False
else:
self.input_layout = "BSH"
self.use_attention_mask = False
self.input_layout = "BNSD"
self.use_attention_mask = True
if self.use_flash_attention:
self.flash_attention = FlashAttention(head_num=self.n_head,
@@ -281,16 +306,17 @@ class InferAttention(Cell):
use_alibi_mask=self.use_alibi_mask,
input_layout=self.input_layout)
kv_shape = (self.num_blocks, self.block_size, self.n_kv_head, self.head_dim)
self.paged_attention_mgr = PagedAttentionMgr(self.pa_n_head_split,
self.head_dim,
self.pa_n_kv_head_split,
kv_shape,
seq_length,
compute_dtype=self.compute_dtype,
parallel_decoding=parallel_decoding,
chunk_prefill=chunk_prefill,
)
self.kv_cache_mgr = KVCacheMgr(self.n_kv_head,
self.head_dim,
num_blocks=self.num_blocks,
block_size=self.block_size,
batch_size=self.batch_size,
seq_length=self.seq_length,
compute_dtype=self.compute_dtype)
self.paged_attention = PagedAttention(self.pa_n_head_split,
self.scale_value,
self.pa_n_kv_head_split)
if use_rope_rotary_emb:
self.rotary_embedding = InferRotaryEmbedding(self.rotary_cos_format)
@@ -425,12 +451,51 @@ class InferAttention(Cell):
raise ValueError("FlashAttention input layout:{} is not supported.".format(self.input_layout))
def _incre_attention(self, query, batch_valid_length, block_tables, alibi_mask=None, attn_mask=None,
q_seq_lens=None):
if self.use_alibi_mask:
return self.paged_attention_mgr.paged_attn_with_alibi(query, batch_valid_length, block_tables, alibi_mask)
return self.paged_attention_mgr.paged_attn(query, batch_valid_length, block_tables, attn_mask=attn_mask,
q_seq_lens=q_seq_lens)
def _infer_boost_attention(self, query, key, value, batch_valid_length, block_tables, slot_mapping, attn_mask=None, alibi_mask=None, prefix_keys_values=None,
q_seq_lens=None):
"""The forward compute of infer Attention with boost."""
if prefix_keys_values is not None:
prefix_len = prefix_keys_values.shape[2]
slot_mapping = slot_mapping + self.cast(self.not_equal(slot_mapping, -1), mstype.int32) * prefix_len
if self.is_first_iteration:
key, value = self._cat_prefix(key, value, prefix_keys_values)
key_cache, value_cache = self.kv_cache_mgr(key, value, slot_mapping, batch_valid_length)
if self.chunk_prefill:
return self.paged_attention(query, key_cache, value_cache, block_tables, batch_valid_length,
None, None, attn_mask, q_seq_lens)
if self.is_first_iteration:
return self._prefill_attention(query, key, value, attn_mask, alibi_mask, batch_valid_length,
batch_valid_length)
else:
if self.parallel_decoding:
return self.paged_attention(query, key_cache, value_cache, block_tables, batch_valid_length,
None, None, attn_mask, q_seq_lens)
return self.paged_attention(query, key_cache, value_cache, block_tables, batch_valid_length)
def _infer_normal_attention(self, query, key, value, batch_valid_length, attn_mask):
"""The forward compute of infer Attention without boost."""
bs, seq_len, _ = query.shape
key_seq_len = key.shape[1]
value_seq_len = value.shape[1]
# (B,S,H) -> (B,N,S,D)
query = self.transpose(self.reshape(query, (bs, seq_len, self.n_head, self.head_dim)), (0, 2, 1, 3))
key = self.transpose(self.reshape(key, (bs, key_seq_len, self.n_kv_head, self.head_dim)), (0, 2, 1, 3))
value = self.transpose(self.reshape(value, (bs, value_seq_len, self.n_kv_head, self.head_dim)), (0, 2, 1, 3))
if self.is_first_iteration:
batch_valid_length = batch_valid_length * 0
key_cache, value_cache = self.kv_cache_mgr(key, value, None, batch_valid_length)
if self.use_flash_attention:
attention = self.flash_attention(query, key_cache, value_cache, attn_mask)
return self._merge_heads(attention)
key_cache = self._repeat_kv(key_cache, self.n_rep)
value_cache = self._repeat_kv(value_cache, self.n_rep)
return self._core_attention(query, key_cache, value_cache, attn_mask)
def construct(self, query, key, value, batch_valid_length, block_tables, slot_mapping, freqs_cis=None,
attn_mask=None, alibi_mask=None, prefix_keys_values=None, q_seq_lens=None):
@@ -438,23 +503,9 @@ class InferAttention(Cell):
if self.use_rope_rotary_emb:
query, key = self._apply_rotary_pos_emb(query, key, freqs_cis, batch_valid_length)
if prefix_keys_values is not None:
prefix_len = prefix_keys_values.shape[2]
slot_mapping = slot_mapping + self.cast(self.not_equal(slot_mapping, -1), mstype.int32) * prefix_len
if self.is_first_iteration:
key, value = self._cat_prefix(key, value, prefix_keys_values)
key_out = self.paged_attention_mgr(key, value, slot_mapping, batch_valid_length)
query = ops.depend(query, key_out)
if self.chunk_prefill:
return self.paged_attention_mgr.paged_attn(query, batch_valid_length, block_tables, attn_mask=attn_mask,
q_seq_lens=q_seq_lens)
if self.is_first_iteration:
return self._prefill_attention(query, key, value, attn_mask, alibi_mask, batch_valid_length,
batch_valid_length)
return self._incre_attention(query, batch_valid_length, block_tables, alibi_mask, attn_mask, q_seq_lens)
if self.enable_infer_boost:
return self._infer_boost_attention(query, key, value, batch_valid_length, block_tables, slot_mapping, attn_mask, alibi_mask, prefix_keys_values, q_seq_lens)
return self._infer_normal_attention(query, key, value, batch_valid_length, attn_mask)
def shard(self, parallel_config):
"""Parallel strategy configuratiuon interface."""
@@ -465,7 +516,11 @@ class InferAttention(Cell):
self.rotary_embedding.shard(parallel_config)
if self.use_flash_attention:
self.flash_attention.shard(parallel_config)
self.paged_attention_mgr.shard(parallel_config)
self.kv_cache_mgr.shard(parallel_config)
if self.parallel_decoding:
self.paged_attention.shard(((dp, 1, mp), (1, 1, mp, 1), (1, 1, mp, 1), (dp, 1), (dp,), (dp, 1), (1,)))
else:
self.paged_attention.shard(((dp, 1, mp), (1, 1, mp, 1), (1, 1, mp, 1), (dp, 1), (dp,)))
self.transpose.shard(((dp, 1, mp, 1),))
self.merger_head_transpose.shard(((dp, mp, 1, 1),))
@@ -473,8 +528,6 @@ class InferAttention(Cell):
self.batch_matmul.shard(((dp, mp, 1, 1), (dp, mp, 1, 1)))
self.mul.shard(((dp, mp, 1, 1), ()))
self.add.shard(((dp, 1, 1, 1), (dp, mp, 1, 1)))
if self.use_alibi_mask:
self.add_alibi.shard(((dp, mp, 1, 1), (dp, mp, 1, 1)))
self.softmax.shard(((dp, mp, 1, 1),))
self.tile_kv.shard(((dp, mp, 1, 1),))
return self
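Reviewer note: in the non-boost branch the KV cache stays contiguous, so _infer_normal_attention first moves query/key/value from (B, S, H) to (B, N, S, D) before FlashAttention or the core attention. A minimal NumPy sketch of that layout change (illustrative shapes only, not taken from the diff):

import numpy as np

bs, seq_len, n_head, head_dim = 2, 4, 8, 16
hidden = n_head * head_dim

query = np.ones((bs, seq_len, hidden), dtype=np.float16)   # (B, S, H)
query = query.reshape(bs, seq_len, n_head, head_dim)       # (B, S, N, D)
query = query.transpose(0, 2, 1, 3)                        # (B, N, S, D)
assert query.shape == (bs, n_head, seq_len, head_dim)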

View File

@@ -0,0 +1,80 @@
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""KV Cache Attention Manager for inference."""
import mindspore.common.dtype as mstype
from mindspore import nn, Parameter, ops
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import Zero
from mindspore.ops.auto_generate import KVCacheScatterUpdate, ReshapeAndCache
from mindformers.tools.utils import get_infer_boost
class KVCacheMgr(nn.Cell):
"""KV Cache Manager."""
def __init__(self,
n_kv_head,
head_dim,
num_blocks=1024,
block_size=128,
batch_size=32,
seq_length=4096,
compute_dtype=mstype.float16):
super().__init__()
self.n_kv_head = n_kv_head
self.head_dim = head_dim
self.num_blocks = num_blocks
self.block_size = block_size
self.batch_size = batch_size
self.seq_length = seq_length
self.enable_infer_boost = get_infer_boost()
print("enable_infer_boost-----",self.enable_infer_boost)
if self.enable_infer_boost:
kv_shape = (self.num_blocks, self.block_size, self.n_kv_head, self.head_dim)
self.reshape_and_cache = ReshapeAndCache()
else:
kv_shape = (self.batch_size, self.n_kv_head, self.seq_length, self.head_dim)
self.kv_cache_scatter_update = KVCacheScatterUpdate()
self.key_cache = Parameter(Tensor(shape=kv_shape, dtype=compute_dtype, init=Zero()), name="key_cache",
requires_grad=False)
self.value_cache = Parameter(Tensor(shape=kv_shape, dtype=compute_dtype, init=Zero()), name="value_cache",
requires_grad=False)
print("KVCacheMgr--finish--------------------")
def construct(self, key_update, value_update, slot_mapping=None, batch_valid_length=None):
"""The forward compute of KVCache for Attention."""
key_cache = self.key_cache
value_cache = self.value_cache
if self.enable_infer_boost:
self.reshape_and_cache(key_update, value_update, self.key_cache, self.value_cache, slot_mapping)
else:
# update shape: [real_bs, n_head, max_seqlen, head_dim]
self.kv_cache_scatter_update(self.key_cache, batch_valid_length, key_update, -2, 'update')
self.kv_cache_scatter_update(self.value_cache, batch_valid_length, value_update, -2, 'update')
key_cache = ops.depend(key_cache, key_update)
value_cache = ops.depend(value_cache, value_update)
return key_cache, value_cache
def shard(self, parallel_config):
"""The shard strategy."""
dp = 1 if parallel_config is None else parallel_config.data_parallel
mp = 1 if parallel_config is None else parallel_config.model_parallel
if self.enable_infer_boost:
self.reshape_and_cache.shard(((dp, 1, mp), (dp, 1, mp), (1, 1, mp, 1), (1, 1, mp, 1), (1,)))
else:
self.kv_cache_scatter_update.shard(((dp, mp, 1, 1), (1,), (dp, mp, 1, 1)))
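Reviewer note: the new KVCacheMgr picks its cache layout from get_infer_boost() — "on" wraps ReshapeAndCache over a paged (num_blocks, block_size, n_kv_head, head_dim) cache written through slot_mapping, "off" wraps KVCacheScatterUpdate over a contiguous (batch_size, n_kv_head, seq_length, head_dim) cache indexed by batch_valid_length. A hedged usage sketch of the "on" path, with shapes borrowed from the removed PagedAttentionMgr unit test (requires an Ascend device; treat shapes as illustrative, not canonical):

import numpy as np
import mindspore as ms
from mindspore import Tensor, dtype as mstype
from mindformers.modules import KVCacheMgr

# "on" selects the ReshapeAndCache / paged-cache branch of KVCacheMgr.
ms.set_context(mode=ms.GRAPH_MODE, jit_config={"jit_level": "O0", "infer_boost": "on"})

bsz, seq_len, n_kv_head, head_dim = 1, 4096, 16, 128
block_size, num_blocks = 1024, 16

mgr = KVCacheMgr(n_kv_head, head_dim, num_blocks=num_blocks, block_size=block_size,
                 compute_dtype=mstype.float16)

key = Tensor(np.ones((bsz, seq_len, n_kv_head * head_dim)), mstype.float16)    # (B, S, H)
value = Tensor(np.ones((bsz, seq_len, n_kv_head * head_dim)), mstype.float16)
slot_mapping = Tensor(np.arange(bsz * seq_len), mstype.int32)                  # token -> cache slot

# Writes key/value into the paged caches and returns the cache Parameters,
# each shaped (num_blocks, block_size, n_kv_head, head_dim).
key_cache, value_cache = mgr(key, value, slot_mapping, None)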

View File

@@ -1,254 +0,0 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""KVCache Manager for inference."""
import numpy as np
from mindspore.common.tensor import Tensor
from mindspore import nn, Parameter, ops
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
class KVCacheMgr(nn.Cell):
"""KVCache Manager."""
def __init__(self,
n_head,
head_dim,
max_batch_size=8,
max_seq_length=4096,
compute_dtype=mstype.float16,
is_dynamic=False,
use_kvcache_op=True,
is_flexible_shape=False):
super().__init__()
self.n_head = n_head
self.head_dim = head_dim
self.max_batch_size = max_batch_size
self.max_seq_length = max_seq_length
self.dtype = compute_dtype
self.use_kvcache_op = use_kvcache_op
self.is_dynamic = is_dynamic
self.is_flexible_shape = is_flexible_shape
self.is_first_iteration = True
self.cache_length_tensor = Tensor([max_batch_size * max_seq_length], dtype=mstype.int32)
self.cache_pad_tensor = Tensor([3], dtype=mstype.int64)
self.seq_length_tensor = Tensor([max_seq_length], dtype=mstype.int32)
self.seq_length_tensor_pad = Tensor([max_seq_length, 3], dtype=mstype.int64)
self.seqlen_axis_tensor_pad = Tensor([2, 3], dtype=mstype.int64)
self.pad_before = Tensor([0, 0, 0, 0, 0], mstype.int32)
self.pad_after = Tensor([0, 0], mstype.int32)
self.pad_zero = Tensor(0, compute_dtype)
if self.use_kvcache_op:
# pylint: disable=W0212
self.prompt_kvcache = P._inner_ops.PromptKVCache()
# pylint: disable=W0212
self.decoder_kvcache = P._inner_ops.DecoderKVCache()
else:
self.add = P.Add()
self.mul = P.Mul()
self.assign = P.Assign()
self.concat = P.Concat(axis=0)
self.sub = P.Sub()
self.div = P.Div()
self.pad = P.PadV3()
self.slice = P.StridedSlice()
self.cast = P.Cast()
self.shape = P.Shape()
self.reshape = P.Reshape().add_prim_attr("skip_redistribution", True)
kv_shape = (max_batch_size, n_head, max_seq_length, head_dim)
self.key_past = Parameter(Tensor(np.zeros(kv_shape), compute_dtype), name="key_past", requires_grad=False)
self.value_past = Parameter(Tensor(np.zeros(kv_shape), compute_dtype), name="value_past", requires_grad=False)
def shard(self, parallel_config):
"""shard"""
dp = parallel_config.data_parallel
mp = parallel_config.model_parallel
self.pad.shard(((dp, mp, 1, 1), (1,), ()))
self.slice.shard(((dp, mp, 1, 1),))
if self.use_kvcache_op:
self.prompt_kvcache.shard(((dp, mp, 1, 1), (dp, mp, 1, 1), (dp,), (1,), (1,), (1,), (1,)))
self.decoder_kvcache.shard(((dp, mp, 1, 1), (dp, mp, 1, 1), (dp,), (1,), (1,), (1,), (1,)))
else:
self.add.shard(((dp, mp, 1, 1), (dp, mp, 1, 1)))
self.mul.shard(((dp, mp, 1, 1), (dp, 1, 1, 1)))
self.assign.shard(((dp, mp, 1, 1), (dp, mp, 1, 1)))
def padding(self, key, value, seq_length):
"""padding key, value"""
pad_length = self.sub(self.seq_length_tensor, seq_length)
# calculate padding parameter: (0, 0),(0,0),(0,pad_length),(0,0), append values of 'pad_length' in axis
pad_config = self.concat((self.pad_before, pad_length, self.pad_after))
key_padding = self.pad(key, pad_config, self.pad_zero)
value_padding = self.pad(value, pad_config, self.pad_zero)
return key_padding, value_padding
def trimming(self, key, value, zactivate_len, batch_size):
"""tramming key, value"""
if self.is_flexible_shape:
key = self.reshape(key, (batch_size, self.n_head, -1, self.head_dim))
value = self.reshape(value, (batch_size, self.n_head, -1, self.head_dim))
if zactivate_len is not None:
act_len = self.shape(zactivate_len)[0]
key = self.slice(key, (0, 0, 0, 0), (batch_size, self.n_head, act_len, self.head_dim), (1, 1, 1, 1))
value = self.slice(value, (0, 0, 0, 0), (batch_size, self.n_head, act_len, self.head_dim), (1, 1, 1, 1))
elif not self.is_flexible_shape:
key = self.slice(key, (0, 0, 0, 0),
(batch_size, self.n_head, self.max_seq_length, self.head_dim), (1, 1, 1, 1))
value = self.slice(value, (0, 0, 0, 0),
(batch_size, self.n_head, self.max_seq_length, self.head_dim), (1, 1, 1, 1))
return key, value
def auto_caching(self, key_update, value_update, batch_valid_length, seq_length_tensor_pad, batch_index_pad=None):
"""use kvcache op to cache key, value"""
# key_update shape: [real_bs, n_head, max_seqlen, head_dim]
if self.is_first_iteration:
batch_valid_length = batch_valid_length * 0
self.prompt_kvcache(self.key_past, key_update, batch_valid_length, batch_index_pad,
self.seqlen_axis_tensor_pad, seq_length_tensor_pad, seq_length_tensor_pad)
self.prompt_kvcache(self.value_past, value_update, batch_valid_length, batch_index_pad,
self.seqlen_axis_tensor_pad, seq_length_tensor_pad, seq_length_tensor_pad)
return None
key_cache = self.key_past
value_cache = self.value_past
key_update = self.decoder_kvcache(self.key_past, key_update, batch_valid_length, batch_index_pad,
self.seqlen_axis_tensor_pad, seq_length_tensor_pad, seq_length_tensor_pad)
value_update = self.decoder_kvcache(self.value_past, value_update, batch_valid_length, batch_index_pad,
self.seqlen_axis_tensor_pad, seq_length_tensor_pad, seq_length_tensor_pad)
key_cache = ops.depend(key_cache, key_update)
value_cache = ops.depend(value_cache, value_update)
return key_cache, value_cache
def manual_caching(self, key_update, value_update, valid_length_vector, batch_size):
"""use assign to cache key, value"""
# key_update shape: [real_bs, n_head, 1, head_dim]
if self.is_first_iteration:
if self.is_dynamic:
self.assign(self.key_past,
self.reshape(key_update, (self.max_batch_size, self.n_head, -1, self.head_dim)))
self.assign(self.value_past,
self.reshape(value_update, (self.max_batch_size, self.n_head, -1, self.head_dim)))
else:
self.assign(self.key_past, self.mul(key_update, valid_length_vector))
self.assign(self.value_past, self.mul(value_update, valid_length_vector))
return None
if self.is_dynamic:
key = self.add(self.reshape(self.key_past, (batch_size, self.n_head, -1, self.head_dim)),
self.mul(key_update, valid_length_vector))
value = self.add(self.reshape(self.value_past, (batch_size, self.n_head, -1, self.head_dim)),
self.mul(value_update, valid_length_vector))
self.assign(self.key_past,
self.reshape(key, (self.max_batch_size, self.n_head, -1, self.head_dim)))
self.assign(self.value_past,
self.reshape(value, (self.max_batch_size, self.n_head, -1, self.head_dim)))
else:
key = self.add(self.key_past, self.mul(key_update, valid_length_vector))
value = self.add(self.value_past, self.mul(value_update, valid_length_vector))
self.assign(self.key_past, key)
self.assign(self.value_past, value)
# key shape: [real_bs, n_head, max_cache_len // real_bs, head_dim]
return key, value
def construct(self, key, value, kvcache_inputs=None):
"""The forward compute of KVCacheMgr."""
# add inputs check
batch_valid_length, zactivate_len, batch_index_pad, seq_length_tensor_pad = kvcache_inputs
if not self.use_kvcache_op:
batch_valid_length = self.cast(batch_valid_length, self.dtype)
batch_size, _, seq_length, _ = self.shape(key)
if self.is_first_iteration:
if self.is_dynamic:
key_padding, value_padding = self.padding(key, value, seq_length=seq_length)
else:
key_padding, value_padding = key, value
if self.use_kvcache_op:
self.auto_caching(key_padding, value_padding, batch_valid_length,
seq_length_tensor_pad, batch_index_pad)
else:
self.manual_caching(key_padding, value_padding, batch_valid_length, batch_size=batch_size)
else:
if self.use_kvcache_op:
key, value = self.auto_caching(key, value, batch_valid_length,
seq_length_tensor_pad, batch_index_pad)
else:
key, value = self.manual_caching(key, value, batch_valid_length, batch_size=batch_size)
key, value = self.trimming(key, value, zactivate_len, batch_size=batch_size)
return key, value
class KVCachePreprocess(nn.Cell):
"""KVCache Manager."""
def __init__(self,
max_batch_size=8,
max_seq_length=4096,
is_dynamic=False,
use_kvcache_op=False,
is_flexible_shape=False):
super().__init__()
self.is_dynamic = is_dynamic
self.use_kvcache_op = use_kvcache_op
self.is_flexible_shape = is_flexible_shape
self.max_cache_length = max_batch_size * max_seq_length
range_len = self.max_cache_length if self.is_flexible_shape else max_seq_length
self.range = Tensor(np.arange(range_len).reshape((1, 1, -1)), mstype.int32)
self.cache_length_tensor = Tensor([max_batch_size * max_seq_length], dtype=mstype.int32)
self.cache_pad_tensor = Tensor([3], dtype=mstype.int64)
self.seq_length_tensor = Tensor([max_seq_length], dtype=mstype.int32)
self.seq_length_tensor_pad = Tensor([max_seq_length, 3], dtype=mstype.int64)
self.is_first_iteration = True
self.slice = P.StridedSlice()
self.reshape = P.Reshape().add_prim_attr("skip_redistribution", True)
self.equal = P.Equal().shard(((1, 1, 1), (1, 1, 1)))
self.less = P.Less().shard(((1, 1, 1), (1, 1, 1)))
self.expand_dims = P.ExpandDims().shard(((1, 1, 1),))
self.div = P.Div()
self.concat = P.Concat(axis=0)
self.cast = P.Cast()
def construct(self, batch_size, batch_valid_length=None, batch_index=None, zactivate_len=None):
"""precompute kvcache inputs"""
seq_range = self.range
if self.is_dynamic and self.is_flexible_shape and not self.use_kvcache_op:
seq_range = self.slice(seq_range, (0, 0, 0), (1, 1, self.max_cache_length // batch_size), (1, 1, 1))
if self.use_kvcache_op:
if batch_index is None:
batch_index = ops.arange(0, batch_size, 1)
batch_index_pad = self.concat((batch_index, self.cache_pad_tensor))
seq_length_tensor_pad = self.get_seq_length_tensor_pad(batch_size=batch_size)
batch_valid_length = self.cast(self.reshape(batch_valid_length, (-1,)), mstype.int64)
kvcache_inputs = (batch_valid_length, zactivate_len, batch_index_pad, seq_length_tensor_pad)
else:
if self.is_first_iteration:
valid_length_vector = self.less(seq_range, self.reshape(batch_valid_length, (-1, 1, 1)))
else:
valid_length_vector = self.equal(seq_range, self.reshape(batch_valid_length, (-1, 1, 1)))
valid_length_vector = self.expand_dims(valid_length_vector, 3)
kvcache_inputs = (valid_length_vector, zactivate_len, None, None)
return kvcache_inputs
def get_seq_length_tensor_pad(self, batch_size):
"""get seq_length_tensor_pad"""
if self.is_flexible_shape:
max_seq_length = self.div(self.cache_length_tensor, batch_size).astype(mstype.int64)
return self.concat((max_seq_length, self.cache_pad_tensor))
return self.seq_length_tensor_pad

View File

@@ -1,87 +0,0 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Paged Attention Manager for inference."""
import math
import mindspore.common.dtype as mstype
from mindspore import nn, Parameter
from mindspore import ops as P
from mindspore.common.initializer import initializer
class PagedAttentionMgr(nn.Cell):
"""Paged Attention Manager."""
def __init__(self,
n_heads,
head_dim,
n_kv_heads,
kv_shape,
seq_length=-1,
compute_dtype=mstype.float16,
parallel_decoding=False,
chunk_prefill=False):
super().__init__()
self.n_heads = n_heads
self.head_dim = head_dim
self.n_kv_heads = n_kv_heads
self.seq_length = seq_length
self.is_first_iteration = True
self.scale_value = 1 / math.sqrt(self.head_dim)
self.key_cache = Parameter(initializer('zeros', kv_shape, compute_dtype), name="key_cache",
requires_grad=False)
self.value_cache = Parameter(initializer('zeros', kv_shape, compute_dtype), name="value_cache",
requires_grad=False)
self.reshape_and_cache = P.auto_generate.ReshapeAndCache()
self.paged_attention = P.auto_generate.PagedAttention(self.n_heads,
self.scale_value,
self.n_kv_heads)
self.paged_attention_with_alibi = P.auto_generate.PagedAttentionMask(self.n_heads,
self.scale_value,
self.n_kv_heads)
self.parallel_decoding = parallel_decoding
self.chunk_prefill = chunk_prefill
# pylint: disable=W0613
def construct(self, key, value, slot_mapping, batch_valid_length=None):
"""The forward compute of KVCache for Paged Attention."""
return self.reshape_and_cache(key, value, self.key_cache, self.value_cache, slot_mapping)
def paged_attn(self, query, batch_valid_length, block_tables, attn_mask=None, q_seq_lens=None):
"""The forward compute of Paged Attention."""
if self.parallel_decoding or self.chunk_prefill:
attn_mask = attn_mask.astype(mstype.bool_).astype(query.dtype) * -10000
return self.paged_attention(query, self.key_cache, self.value_cache, block_tables, batch_valid_length,
None, None, attn_mask, q_seq_lens)
return self.paged_attention(query, self.key_cache, self.value_cache, block_tables, batch_valid_length)
def paged_attn_with_alibi(self, query, batch_valid_length, block_tables, alibi_tensor):
"""The forward compute of KVCache for Paged Attention with alibi tensor."""
return self.paged_attention_with_alibi(query, self.key_cache, self.value_cache,
block_tables, batch_valid_length, None, None, alibi_tensor)
def shard(self, parallel_config):
"""The shard strategy."""
dp = 1 if parallel_config is None else parallel_config.data_parallel
mp = 1 if parallel_config is None else parallel_config.model_parallel
self.reshape_and_cache.shard(((dp, 1, mp), (dp, 1, mp), (1, 1, mp, 1), (1, 1, mp, 1), (1,)))
if self.parallel_decoding:
self.paged_attention.shard(((dp, 1, mp), (1, 1, mp, 1), (1, 1, mp, 1), (dp, 1), (dp,), (dp, 1), (1,)))
else:
self.paged_attention.shard(((dp, 1, mp), (1, 1, mp, 1), (1, 1, mp, 1), (dp, 1), (dp,)))
self.paged_attention_with_alibi.shard(((dp, 1, mp), (1, 1, mp, 1), (1, 1, mp, 1), (dp, 1), (dp,),
(dp, mp, 1, 1)))

View File

@@ -38,15 +38,16 @@ except ImportError:
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr
from mindspore.ops.auto_generate import MoeFinalizeRouting
from mindspore.nn.cell import Cell
from mindspore.nn.layer import Dense
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
from mindformers.modules.transformer.op_parallel_config import default_moeparallel_config, MoEParallelConfig
from mindformers.version_control import check_valid_moefinalizerouting_op
from mindformers.modules.transformer.moe_utils import ZLoss
from mindformers.tools.utils import get_predict_run_mode
from mindformers.version_control import check_valid_moe_big_kernel
__all__ = [
"MoEConfig"]
@@ -1793,6 +1794,7 @@ class TopkRouterV2(Cell):
return comm_load_loss
class MoEInfer(Cell):
r"""
MoEInfer. Routing each token to its top-k experts and calculating the final output.
@@ -1816,10 +1818,11 @@ class MoEInfer(Cell):
parallel_config):
super(MoEInfer, self).__init__()
self.hidden_size = dim
self.expert_dim = moe_config.expert_num
self.topk_norm_prob = moe_config.norm_topk_prob
self.expert_num = moe_config.expert_num
self.norm_topk_prob = moe_config.norm_topk_prob
self.num_experts_chosen = moe_config.num_experts_chosen
self.moe_config = moe_config
self.router_dense_type = moe_config.router_dense_type
self.routed_scaling_factor = moe_config.routed_scaling_factor
self.ffn = ffn
self.router = Router(d_model=self.hidden_size, moe_config=moe_config, routing_policy=None,
@@ -1831,23 +1834,58 @@ class MoEInfer(Cell):
self.reshape = P.Reshape()
self.shape = P.Shape()
self.cast = P.Cast()
self.mod = P.Mod().shard(((1,), ()))
self.topk = P.TopK().shard(((1, 1),))
self.softmax = P.Softmax().shard(((1, 1),))
self.expand_dims = P.ExpandDims().shard(((1,),))
self.transpose_2d = P.Transpose().shard(((1, 1),))
self.sort = P.Sort().shard(((1,),))
self.gather = P.Gather().shard(((1, 1), (1,)))
self.onehot = P.OneHot().shard(((1, 1), (), ()))
self.cumsum = P.CumSum(exclusive=False).shard(((1,),))
self.expand_dims = P.ExpandDims()
self.moe_finalize_routing = MoeFinalizeRouting()
self.use_moe_big_kernel = check_valid_moe_big_kernel()
if self.use_moe_big_kernel:
# fused op
from mindspore.ops.auto_generate import MoeGatingTopKSoftmax, MoeInitRouting, MoeComputeExpertTokens
self.moe_init_routing = MoeInitRouting()
self.moe_compute_expert_tokens = MoeComputeExpertTokens()
self.moe_gating_topk_softmax = MoeGatingTopKSoftmax()
else:
self.mod = ops.auto_generate.RemainderTensorScalar()
self.topk = P.TopK()
self.softmax = P.Softmax()
self.expand_dims = P.ExpandDims()
self.transpose_2d = P.Transpose()
self.sort = P.Sort()
self.gather = P.Gather()
self.onehot = P.OneHot()
self.cumsum = P.CumSum(exclusive=False)
self.on_value = Tensor(1.0, dtype=mstype.float32)
self.off_value = Tensor(0.0, dtype=mstype.float32)
if check_valid_moefinalizerouting_op():
from mindspore.ops.auto_generate import MoeFinalizeRouting
self.moe_finalize_routing = MoeFinalizeRouting().shard(((1, 1), (1, 1), (1, 1), (1, 1), (1,), (1, 1)))
self.on_value = Tensor(1.0, dtype=mstype.float32)
self.off_value = Tensor(0.0, dtype=mstype.float32)
def tensor_sort(self, input_tensor, expert_ids):
def tensor_sort(self, input_tensor, expert_index, row_index):
"""dispatch and get unsort map for routing"""
expanded_x, expanded_row_idx, expanded_expert_idx = \
self.moe_init_routing(input_tensor, row_index, expert_index, self.shape(input_tensor)[0])
expert_tokens = self.moe_compute_expert_tokens(expanded_expert_idx, self.expert_num)
return expanded_x, expert_tokens, expanded_row_idx
def tensor_moe_finalize_routing(self, input_tensor, expert_weight, expert_index, unsort_map):
"""calculate the final output by multiplying FeedForward's output and experts' weight in MoeFinalizeRouting"""
input_shape = input_tensor.shape
x1 = self.zeros((input_shape[0] // self.num_experts_chosen, input_shape[-1]), input_tensor.dtype)
x2 = None
bias = self.zeros((self.expert_num, input_shape[-1]), input_tensor.dtype)
expert_weight = self.cast(expert_weight, input_tensor.dtype)
output_tensor = self.moe_finalize_routing(input_tensor, x1, x2, bias, expert_weight, unsort_map, expert_index)
return output_tensor
def gating_topk_softmax(self, input_tensor):
"""calculate the expert value and expert index in MoeGatingTopKSoftmax"""
# (N, E)
gating_logits = self.gating(input_tensor.astype(self.router_dense_type))
# (N, num_experts_chosen), (N, num_experts_chosen), (N, num_experts_chosen)
expert_val, expert_index, row_index = self.moe_gating_topk_softmax(gating_logits, finished=None,
k=self.num_experts_chosen)
return expert_val, expert_index, row_index
def tensor_sort_self_define(self, input_tensor, expert_ids):
"""dispatch and get unsort map for routing"""
expert_shape = expert_ids.shape
transposed_index = self.transpose_2d(expert_ids, (1, 0)) # (N, k) -> (k, N)
@@ -1856,7 +1894,7 @@ class MoEInfer(Cell):
inter_map = self.mod(sort_map, expert_shape[0])
output_tensor = self.gather(input_tensor, inter_map, 0)
expert_mask = self.onehot(reshaped_index, self.expert_dim, self.on_value, self.off_value)
expert_mask = self.onehot(reshaped_index, self.expert_num, self.on_value, self.off_value)
expert_cnt = ops.sum(expert_mask, 0)
group_list = self.cast(self.cumsum(expert_cnt, 0), mstype.int64)
@@ -1864,43 +1902,90 @@ class MoEInfer(Cell):
unsort_map = self.cast(unsort_map, mstype.int32)
return output_tensor, group_list, unsort_map
def tensor_moe_finalize_routing(self, input_tensor, expert_weight, expert_index, unsort_map):
"""calculate the final output by multiplying FeedForward's output and experts' weight in MoeFinalizeRouting"""
input_shape = input_tensor.shape # (2N, h)
x1 = self.zeros((input_shape[0] // self.num_experts_chosen, input_shape[-1]), input_tensor.dtype)
x2 = None
bias = self.zeros((self.expert_dim, input_shape[-1]), input_tensor.dtype)
expert_weight = self.cast(expert_weight, input_tensor.dtype)
output_tensor = self.moe_finalize_routing(input_tensor, x1, x2, bias, expert_weight, unsort_map, expert_index)
return output_tensor
def construct(self, input_tensor):
"""forward process"""
input_tensor_shape = self.shape(input_tensor) # (B, S, H)
def moe_compute_self_define(self, input_tensor):
"""moe compute self define"""
input_dtype = input_tensor.dtype
input_tensor = self.reshape(
input_tensor, (-1, self.hidden_size)) # (bs, seq/1, h) -> (bs*seq, h) : use N replace bs*seq
# gating + topk + softmax
gating_logits = self.gating(input_tensor.astype(mstype.float32)) # (N, h) * (h, E) -> (bs*seq, E)
routing_weights = self.softmax(gating_logits.astype(mstype.float32)) # (N, E) -> (N, E)
expert_val, expert_index = self.topk(routing_weights, self.num_experts_chosen) # (N, E) -> (N, 2), (N, 2)
if self.moe_config.norm_topk_prob and self.num_experts_chosen > 1:
if self.norm_topk_prob and self.num_experts_chosen > 1:
expert_val = self.cast(expert_val, mstype.float32)
expert_weight = expert_val / (self.expand_dims(ops.sum(expert_val, -1), -1) + 1e-9)
else:
expert_weight = ops.mul(self.moe_config.routed_scaling_factor, expert_val)
expert_weight = ops.mul(self.routed_scaling_factor, expert_val)
expert_weight = self.cast(expert_weight, input_dtype)
sorted_input_tensor, group_list, unsort_map = self.tensor_sort(input_tensor, expert_index)
sorted_input_tensor, group_list, unsort_map = self.tensor_sort_self_define(input_tensor, expert_index)
return sorted_input_tensor, group_list, unsort_map, expert_index, expert_weight
# moeffn
expert_output = self.ffn(sorted_input_tensor, group_list) # (N, h) (N, 2) -> (N, 2, h)
def moe_compute(self, input_tensor):
"""moe compute with big kernel"""
input_dtype = input_tensor.dtype
# (N, num_experts_chosen), (N, num_experts_chosen), (N, num_experts_chosen)
expert_val, expert_index, row_index = self.gating_topk_softmax(input_tensor)
if self.norm_topk_prob and self.num_experts_chosen > 1:
expert_val = self.cast(expert_val, mstype.float32)
expert_weight = expert_val / (self.expand_dims(ops.sum(expert_val, -1), -1) + 1e-9)
else:
expert_weight = ops.mul(self.routed_scaling_factor, expert_val)
# (N, num_experts_chosen)
expert_weight = self.cast(expert_weight, input_dtype)
# MOE recompute
# (N * num_experts_chosen, h), (N * num_experts_chosen,), (N * num_experts_chosen,)
sorted_input_tensor, group_list, unsort_map = self.tensor_sort(input_tensor, expert_index, row_index)
group_list = group_list.astype(mstype.int64)
return sorted_input_tensor, group_list, unsort_map, expert_index, expert_weight
def construct(self, input_tensor):
"""forward process"""
# (B, S, H)
input_tensor_shape = self.shape(input_tensor)
# (bs, seq/1, h) -> (bs*seq, h): use N to denote bs*seq
input_tensor = self.reshape(input_tensor, (-1, self.hidden_size))
if self.use_moe_big_kernel:
sorted_input_tensor, group_list, unsort_map, expert_index, expert_weight = self.moe_compute(input_tensor)
else:
sorted_input_tensor, group_list, unsort_map, expert_index, expert_weight =\
self.moe_compute_self_define(input_tensor)
# MOEFFN
# (N * num_experts_chosen, h)
expert_output = self.ffn(sorted_input_tensor, group_list)
# MOE revert recompute
# (N, h)
moe_output = self.tensor_moe_finalize_routing(
expert_output, expert_weight, expert_index, unsort_map) # -> (N, h)
expert_output, expert_weight, expert_index, unsort_map)
output_tensor = self.reshape(moe_output, input_tensor_shape) # (N, h) -> (bs, seq, h)
# (N, h) -> (bs, seq, h)
output_tensor = self.reshape(moe_output, input_tensor_shape)
return output_tensor
def shard(self, parallel_config):
"""sharding for moe infer cell"""
self.ffn.shard(parallel_config)
self.expand_dims.shard(((1,),))
self.moe_finalize_routing.shard(((1, 1), (1, 1), (1, 1), (1, 1), (1,), (1, 1)))
if self.use_moe_big_kernel:
self.moe_init_routing.shard(((1, 1), (1, 1), (1, 1)))
self.moe_compute_expert_tokens.shard(((1,),))
self.moe_gating_topk_softmax.shard(((1, 1),))
else:
self.mod.shard(((1,), ()))
self.topk.shard(((1, 1),))
self.softmax.shard(((1, 1),))
self.expand_dims.shard(((1,),))
self.transpose_2d.shard(((1, 1),))
self.sort.shard(((1,),))
self.gather.shard(((1, 1), (1,)))
self.onehot.shard(((1, 1), (), ()))
self.cumsum.shard(((1,),))
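Reviewer note: when check_valid_moe_big_kernel() passes (MindSpore >= 2.5.0), MoEInfer routes through the fused MoeGatingTopKSoftmax / MoeInitRouting / MoeComputeExpertTokens ops; otherwise it falls back to the self-defined top-k softmax plus sort-based dispatch, and both paths end in MoeFinalizeRouting. The routing math kept in moe_compute_self_define reduces to the NumPy sketch below (illustration only; dispatch/unsort, the grouped FFN, and the routed_scaling_factor branch are elided, all names here are hypothetical):

import numpy as np

def route_tokens(hidden, gating_weight, k, norm_topk_prob=True):
    """hidden: (N, h); gating_weight: (h, E). Returns top-k expert ids and weights."""
    logits = hidden @ gating_weight                        # (N, E) gating
    probs = np.exp(logits - logits.max(-1, keepdims=True))
    probs = probs / probs.sum(-1, keepdims=True)           # softmax over experts
    expert_index = np.argsort(-probs, axis=-1)[:, :k]      # (N, k) chosen experts
    expert_val = np.take_along_axis(probs, expert_index, axis=-1)
    if norm_topk_prob and k > 1:
        expert_weight = expert_val / (expert_val.sum(-1, keepdims=True) + 1e-9)
    else:
        expert_weight = expert_val                         # scaling-factor branch elided
    return expert_index, expert_weight

N, h, E, k = 6, 8, 4, 2
idx, w = route_tokens(np.random.randn(N, h), np.random.randn(h, E), k)
assert idx.shape == (N, k) and np.allclose(w.sum(-1), 1.0)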

View File

@@ -62,7 +62,7 @@ from mindformers.version_control import get_dropout, choose_flash_attention_dtyp
from mindformers.tools.logger import _LogActionOnce
from mindformers.tools.logger import logger as log
from mindformers.tools.utils import is_pynative
from mindformers.tools.utils import is_pynative, get_infer_boost
__all__ = [
"AttentionMask",
@@ -1007,7 +1007,7 @@ class LowerTriangularMaskWithDynamic(Cell):
if use_past and chunk_prefill:
self.lower_triangle_mask = Tensor(np.tril(np.ones(shape=(seq_length, seq_length), dtype=np.int8)), \
dtype=compute_type)
elif use_past and not self.is_pynative:
elif use_past and get_infer_boost() and not self.is_pynative:
if self.is_dynamic:
mask_coeff = 1.0 if compute_type is mstype.bfloat16 else -10000.0
self.lower_triangle_mask = Tensor(

View File

@@ -479,6 +479,11 @@ def get_predict_run_mode():
return run_mode == "predict"
def get_infer_boost():
infer_boost = context.get_jit_config().get("infer_boost")
return infer_boost == "on"
def is_main_rank(ignore_check_modelarts=False):
return not get_real_rank() or \
((ignore_check_modelarts or check_in_modelarts()) and get_real_rank() % get_device_num_per_node() == 0)
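Reviewer note: get_infer_boost() simply reflects the "infer_boost" entry of the jit_config passed to set_context, which is how the new unit tests flip between the two attention paths. A minimal sketch (assumes a configured MindSpore environment):

import mindspore as ms
from mindformers.tools.utils import get_infer_boost

ms.set_context(mode=ms.GRAPH_MODE, jit_config={"jit_level": "O0", "infer_boost": "on"})
assert get_infer_boost()        # "on": paged KV cache + aclnn PagedAttention path

ms.set_context(mode=ms.GRAPH_MODE, jit_config={"jit_level": "O0", "infer_boost": "off"})
assert not get_infer_boost()    # "off": KVCacheScatterUpdate + BNSD attention path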

View File

@@ -295,12 +295,12 @@ def check_valid_gmm_op():
return True
def check_valid_moefinalizerouting_op():
"""check mindspore version is valid for groupedmatmul"""
version_valid = is_version_ge(ms.__version__, "2.3.0")
def check_valid_moe_big_kernel():
"""check mindspore version is valid for moe big kernel"""
version_valid = is_version_ge(ms.__version__, "2.5.0")
if not version_valid:
logger.warning(f"Current MindSpore do not support MoeFinalizeRouting, "
f"please upgrade to 2.3.0 or later version.")
logger.warning(f"Current MindSpore do not support Moe big kernel, "
f"please upgrade to 2.5.0 or later version.")
return False
return True

View File

@@ -221,8 +221,8 @@ class Baichuan13BV2ForCausalLM(Baichuan2PreTrainedModel):
return loss
def kvcache(self, layer_idx):
key_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.key_cache
value_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.value_cache
key_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.key_cache
value_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.value_cache
return key_cache, value_cache

View File

@@ -453,3 +453,8 @@ class Baichuan7BV2ForCausalLM(Baichuan2PreTrainedModel):
input_mask = self.reshape(input_mask, (-1,))
loss = self.loss(logits, labels, input_mask)
return loss
def kvcache(self, layer_idx):
key_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.key_cache
value_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.value_cache
return key_cache, value_cache

View File

@@ -193,7 +193,7 @@ class DeepSeekV2MoEInfer(Cell):
self.add.shard(((dp, 1, 1), (dp, 1, 1)))
self.routed_experts.ffn.shard(parallel_config)
self.routed_experts.shard(parallel_config)
self.shared_experts.shard(parallel_config)
self.shared_experts.mul.shard(((dp, 1, mp), (dp, 1, mp)))

View File

@@ -204,8 +204,8 @@ class QwenForCausalLM(QwenPreTrainedModel):
self.lm_head.shard(strategy_matmul=((1, 1), (dp * mp, 1)))
def kvcache(self, layer_idx):
key_cache = self.transformer.layers[layer_idx].attention.infer_attention.paged_attention_mgr.key_cache
value_cache = self.transformer.layers[layer_idx].attention.infer_attention.paged_attention_mgr.value_cache
key_cache = self.transformer.layers[layer_idx].attention.infer_attention.kv_cache_mgr.key_cache
value_cache = self.transformer.layers[layer_idx].attention.infer_attention.kv_cache_mgr.value_cache
return key_cache, value_cache

View File

@@ -286,8 +286,8 @@ class QwenForCausalLM(QwenPreTrainedModel):
self.lm_head.shard(strategy_matmul=((1, 1), (dp * mp, 1)))
def kvcache(self, layer_idx):
key_cache = self.transformer.layers[layer_idx].attention.infer_attention.paged_attention_mgr.key_cache
value_cache = self.transformer.layers[layer_idx].attention.infer_attention.paged_attention_mgr.value_cache
key_cache = self.transformer.layers[layer_idx].attention.infer_attention.kv_cache_mgr.key_cache
value_cache = self.transformer.layers[layer_idx].attention.infer_attention.kv_cache_mgr.value_cache
return key_cache, value_cache

View File

@@ -465,6 +465,6 @@ class TelechatForCausalLM(TelechatPreTrainedModel):
return loss
def kvcache(self, layer_idx):
key_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.key_cache
value_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.value_cache
key_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.key_cache
value_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.value_cache
return key_cache, value_cache

View File

@@ -142,7 +142,6 @@ class ParallelTelechatForCausalLM(TelechatPreTrainedModel):
for layer in self.model.layers:
layer.add_flags(is_first_iteration=is_first_iteration)
layer.attention.add_flags(is_first_iteration=is_first_iteration)
layer.attention.paged_attention_mgr.add_flags(is_first_iteration=is_first_iteration)
# pylint: disable=W0613
def construct(self, input_ids, labels=None, input_position=None, position_ids=None, attention_mask=None,
@@ -171,6 +170,6 @@ class ParallelTelechatForCausalLM(TelechatPreTrainedModel):
return logits, input_ids, input_mask
def kvcache(self, layer_idx):
key_cache = self.model.layers[layer_idx].attention.paged_attention_mgr.key_cache
value_cache = self.model.layers[layer_idx].attention.paged_attention_mgr.value_cache
key_cache = self.model.layers[layer_idx].attention.kv_cache_mgr.key_cache
value_cache = self.model.layers[layer_idx].attention.kv_cache_mgr.value_cache
return key_cache, value_cache

View File

@@ -147,8 +147,7 @@ class TelechatParallelAttention(ParallelAttention):
if self.is_first_iteration:
key, value = self._cat_prefix(key, value, prefix_keys_values)
key_out = self.paged_attention_mgr(key, value, slot_mapping)
query = ops.depend(query, key_out)
key_cache, value_cache = self.kv_cache_mgr(key, value, slot_mapping, batch_valid_length)
if self.is_first_iteration:
if self.use_flash_attention:
@@ -181,7 +180,8 @@ class TelechatParallelAttention(ParallelAttention):
value = value.transpose((0, 2, 1, 3))
context_layer = self.core_attention(query, key, value, attn_mask)
else:
context_layer = self.paged_attention_mgr.paged_attn(query, batch_valid_length, block_tables)
context_layer = self.paged_attention(query, key_cache, value_cache, block_tables, batch_valid_length,
None, None, None, None)
else:
# [B, S, H] -> [B, N, S, D]
query = query.reshape(bs, seq_len, -1, self.head_dim).transpose((0, 2, 1, 3))

View File

@@ -444,6 +444,6 @@ class TelechatForCausalLM(TelechatPreTrainedModel):
return loss
def kvcache(self, layer_idx):
key_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.key_cache
value_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.value_cache
key_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.key_cache
value_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.value_cache
return key_cache, value_cache

View File

@@ -462,7 +462,7 @@ class YiZhaoWithPtuning2(YiZhaoForCausalLM):
def kvcache(self, layer_idx):
key_cache = \
self.transformer.encoder.layers[layer_idx].self_attention.infer_attention.paged_attention_mgr.key_cache
self.transformer.encoder.layers[layer_idx].self_attention.infer_attention.kv_cache_mgr.key_cache
value_cache = \
self.transformer.encoder.layers[layer_idx].self_attention.infer_attention.paged_attention_mgr.value_cache
self.transformer.encoder.layers[layer_idx].self_attention.infer_attention.kv_cache_mgr.value_cache
return key_cache, value_cache

View File

@@ -501,6 +501,6 @@ class MixtralForCausalLM(LlamaPreTrainedModel):
return loss
def kvcache(self, layer_idx):
key_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.key_cache
value_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.value_cache
key_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.key_cache
value_cache = self.model.layers[layer_idx].attention.infer_attention.kv_cache_mgr.value_cache
return key_cache, value_cache

View File

@@ -51,6 +51,7 @@ def test_qwen2_0_5b_predict_standalone():
config.model.model_config.seq_length = seq_length
config.processor.tokenizer.vocab_file = vocab_file_path
config.processor.tokenizer.merges_file = merges_file_path
config.context.device_id = int(os.environ.get("DEVICE_ID", "0"))
# init context
build_context(config)

View File

@@ -0,0 +1,96 @@
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test infer attention"""
import math
import os
import pytest
import numpy as np
import mindspore as ms
from mindspore import dtype as mstype
from mindspore import Tensor
from mindformers.modules import KVCacheMgr
from mindformers.modules.infer_attention import InferAttention
from mindformers.modules.layers import FreqsMgr
@pytest.mark.level0
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_kv_cache_mgr():
"""
Feature: Test the kv cache manager.
Description: Test the forward
Expectation: No exception
"""
os.environ['ASCEND_HOME_PATH'] = "/usr/local/Ascend/latest"
jit_level = "O0"
infer_boost = "off"
ms.set_context(mode=ms.GRAPH_MODE, jit_config={"jit_level": jit_level, "infer_boost": infer_boost})
bsz, n_kv_head, seq_len, head_dim = 8, 16, 4096, 128
compute_dtype = mstype.float16
key = Tensor(np.ones((bsz, n_kv_head, seq_len, head_dim)), mstype.float16)
value = Tensor(np.ones((bsz, n_kv_head, seq_len, head_dim)), mstype.float16)
batch_valid_length = Tensor(np.zeros((bsz, 2)), mstype.int32)
kv_shape = (bsz, n_kv_head, seq_len, head_dim)
kv_cache_mgr = KVCacheMgr(n_kv_head,
head_dim,
batch_size=bsz,
seq_length=seq_len,
compute_dtype=compute_dtype)
key_cache, value_cache = kv_cache_mgr(key, value, None, batch_valid_length)
assert key_cache.shape == kv_shape and value_cache.shape == kv_shape
@pytest.mark.level0
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_infer_attention():
"""
Feature: Test the infer attention.
Description: Test the forward
Expectation: No exception
"""
os.environ['ASCEND_HOME_PATH'] = "/usr/local/Ascend/latest"
jit_level = "O0"
infer_boost = "off"
ms.set_context(mode=ms.GRAPH_MODE, jit_config={"jit_level": jit_level, "infer_boost": infer_boost})
bsz, head_num, seq_len, head_dim = 8, 80, 256, 128
n_kv_head = 8
hidden_size = head_num * head_dim
query = Tensor(np.ones((bsz, seq_len, hidden_size)), mstype.float16)
key = Tensor(np.ones((bsz, seq_len, n_kv_head * head_dim)), mstype.float16)
value = Tensor(np.ones((bsz, seq_len, n_kv_head * head_dim)), mstype.float16)
batch_valid_length = Tensor(np.zeros((bsz, 2)), mstype.int32)
attn_mask = Tensor(np.ones((bsz, 1, seq_len, seq_len)), mstype.uint8)
freqs_mgr = FreqsMgr(head_dim=head_dim, seq_length=seq_len, max_position_embedding=seq_len)
freqs_cis = freqs_mgr(seq_len)
infer_attention = InferAttention(head_num,
head_dim,
n_kv_head,
scale_value=1. / math.sqrt(head_dim),
pre_tokens=65536,
next_tokens=0,
batch_size=bsz,
seq_length=seq_len,
compute_dtype=mstype.float16)
output = infer_attention(query, key, value, batch_valid_length, None, None, freqs_cis, attn_mask)
assert output.shape == (bsz, seq_len, hidden_size)

View File

@@ -1,62 +0,0 @@
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test infer attention"""
import os
import pytest
import numpy as np
import mindspore as ms
from mindspore import dtype as mstype
from mindspore import Tensor
from mindformers.modules import PagedAttentionMgr
@pytest.mark.level0
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_paged_attention_mgr():
"""
Feature: Test the paged attention.
Description: Test the forward
Expectation: No exception
"""
os.environ['ASCEND_HOME_PATH'] = "/usr/local/Ascend/latest"
jit_level = "O0"
infer_boost = "on"
ms.set_context(jit_config={"jit_level": jit_level, "infer_boost": infer_boost})
bsz, head_num, seq_len, head_dim = 1, 16, 4096, 128
n_kv_head = 16
block_size = 1024
num_blocks = 16
compute_dtype = mstype.float16
hidden_size = head_num * head_dim
batch_valid_length = Tensor(np.ones((bsz, 1)), mstype.int32)
block_tables = Tensor(np.ones((bsz, num_blocks)), mstype.int64)
query = Tensor(np.ones((bsz, seq_len, hidden_size)), mstype.float16)
key = Tensor(np.ones((bsz, seq_len, hidden_size)), mstype.float16)
value = Tensor(np.ones((bsz, seq_len, hidden_size)), mstype.float16)
slot_mapping = Tensor(np.ones((bsz * seq_len)), mstype.int32)
kv_shape = (num_blocks, block_size, n_kv_head, head_dim)
paged_attention_mgr = PagedAttentionMgr(head_num,
head_dim,
n_kv_head,
kv_shape,
compute_dtype=compute_dtype)
paged_attention_mgr(key, value, slot_mapping)
context_layer = paged_attention_mgr.paged_attn(query, batch_valid_length, block_tables)
assert context_layer.shape == (1, 4096, 2048)