mirror of
https://gitee.com/mindspore/mindformers.git
synced 2025-12-06 11:29:59 +08:00
!7773 【master】【cleancode】fix softmax ques
Merge pull request !7773 from zyw_hw/fix_cleancode_ques
This commit is contained in:
@@ -14,8 +14,8 @@
|
||||
# ============================================================================
|
||||
"""utils for text generation."""
|
||||
from collections import UserDict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from threading import Thread
|
||||
from typing import Optional
|
||||
import numpy as np
|
||||
|
||||
@@ -59,22 +59,22 @@ def softmax(x, axis=None):
|
||||
|
||||
|
||||
def softmax_single(i, res, x):
|
||||
""" Worker used by thread pool to compute softmax safely. """
|
||||
res[i] = softmax(x)
|
||||
|
||||
|
||||
def softmax_with_threads(x, is_finished=None):
|
||||
"""calculate softmax with threads"""
|
||||
res = np.ones_like(x)
|
||||
all_threads = []
|
||||
for i in range(0, res.shape[0]):
|
||||
if is_finished and is_finished[i]:
|
||||
continue
|
||||
thread = Thread(target=softmax_single,
|
||||
args=(i, res, x[i]))
|
||||
all_threads.append(thread)
|
||||
thread.start()
|
||||
for thread in all_threads:
|
||||
thread.join()
|
||||
with ThreadPoolExecutor() as executor:
|
||||
futures = []
|
||||
for i in range(0, res.shape[0]):
|
||||
if is_finished and is_finished[i]:
|
||||
continue
|
||||
future = executor.submit(softmax_single, i, res, x[i])
|
||||
futures.append(future)
|
||||
for future in futures:
|
||||
future.result()
|
||||
return res
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user