mirror of
https://gitee.com/mindspore/akg.git
synced 2025-12-06 11:59:12 +08:00
!1293 [AKG-MLIR] Harden StoreLoadElim memref handling to avoid invalid casts
Merge pull request !1293 from yuziyu/br_aikg
This commit is contained in:
@@ -406,13 +406,20 @@ class CommonUtils {
|
||||
}
|
||||
|
||||
static Value getStoreMemref(Operation *storeOp) {
|
||||
Value memref;
|
||||
if (dyn_cast<affine::AffineStoreOp>(storeOp)) {
|
||||
memref = dyn_cast<affine::AffineStoreOp>(storeOp).getMemref();
|
||||
} else if (dyn_cast<memref::StoreOp>(storeOp)) {
|
||||
memref = dyn_cast<memref::StoreOp>(storeOp).getMemref();
|
||||
} else {
|
||||
if (!storeOp) {
|
||||
return Value();
|
||||
}
|
||||
if (!isa<affine::AffineStoreOp>(storeOp) && !isa<memref::StoreOp>(storeOp)) {
|
||||
llvm::errs() << "can only get memref from AffineStore or memref::StoreOp.\n";
|
||||
return Value();
|
||||
}
|
||||
if (storeOp->getNumOperands() < 2) {
|
||||
llvm::errs() << "store op has insufficient operands when querying memref.\n";
|
||||
return Value();
|
||||
}
|
||||
Value memref = storeOp->getOperand(1);
|
||||
if (!memref || !isa<BaseMemRefType>(memref.getType())) {
|
||||
return Value();
|
||||
}
|
||||
return memref;
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ namespace mlir {
|
||||
|
||||
#define DEBUG_TYPE "elim-store-load"
|
||||
|
||||
using namespace mlir;
|
||||
namespace mlir {
|
||||
namespace {
|
||||
|
||||
// ===----------------------------------------------------------------------===//
|
||||
@@ -56,7 +56,7 @@ namespace {
|
||||
// load alloc
|
||||
// ...
|
||||
// memref.copy alloc, %global_out
|
||||
// 2. store and load are not located in the same branch (whatever comes from differnt If or For)
|
||||
// 2. store and load are not located in the same branch (whatever comes from different If or For)
|
||||
// because the stored variable %x will be out of scope for load
|
||||
// e.g.
|
||||
// for arg0
|
||||
@@ -89,6 +89,10 @@ struct StoreLoadElimPass : public StoreLoadElimBase<StoreLoadElimPass> {
|
||||
SmallVector<Operation *> getPossibleElimLoads(Operation *storeOp) const {
|
||||
SmallVector<Operation *> elimLoads;
|
||||
auto memref = CommonUtils::getStoreMemref(storeOp);
|
||||
// check if the memref is valid and the type is correct
|
||||
if (!memref || !isa<MemRefType>(memref.getType())) {
|
||||
return SmallVector<Operation *>();
|
||||
}
|
||||
for (auto user : memref.getUsers()) {
|
||||
if (user == storeOp) {
|
||||
continue;
|
||||
@@ -118,6 +122,11 @@ void StoreLoadElimPass::runOnOperation() {
|
||||
SmallVector<Operation *> toElimLoads;
|
||||
getOperation()->walk([&](Operation *op) {
|
||||
if (dyn_cast<memref::StoreOp>(op) || dyn_cast<affine::AffineStoreOp>(op)) {
|
||||
// check if the memref is valid and the type is correct, skip invalid store operations
|
||||
auto memref = CommonUtils::getStoreMemref(op);
|
||||
if (!memref || !isa<MemRefType>(memref.getType())) {
|
||||
return;
|
||||
}
|
||||
auto elimLoads = getPossibleElimLoads(op);
|
||||
size_t eraseSize = 0;
|
||||
for (auto loadOp : elimLoads) {
|
||||
@@ -129,7 +138,7 @@ void StoreLoadElimPass::runOnOperation() {
|
||||
eraseSize++;
|
||||
}
|
||||
}
|
||||
bool isGlobalBuffer = CommonUtils::getStoreMemref(op).getDefiningOp() == nullptr;
|
||||
bool isGlobalBuffer = memref.getDefiningOp() == nullptr;
|
||||
bool elimAllLoads = eraseSize > 0 && eraseSize == elimLoads.size();
|
||||
if (elimAllLoads && !isGlobalBuffer) {
|
||||
toElimStores.push_back(op);
|
||||
@@ -140,15 +149,20 @@ void StoreLoadElimPass::runOnOperation() {
|
||||
loadOp->erase();
|
||||
}
|
||||
for (auto storeOp : toElimStores) {
|
||||
// before erasing storeOp, capture the memref
|
||||
auto memref = CommonUtils::getStoreMemref(storeOp);
|
||||
if (storeOp->use_empty()) {
|
||||
storeOp->erase();
|
||||
}
|
||||
auto memrefOp = CommonUtils::getStoreMemref(storeOp).getDefiningOp();
|
||||
if (memrefOp->use_empty()) {
|
||||
memrefOp->erase();
|
||||
if (memref && isa<MemRefType>(memref.getType())) {
|
||||
auto memrefOp = memref.getDefiningOp();
|
||||
if (memrefOp && memrefOp->use_empty()) {
|
||||
memrefOp->erase();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // end anonymous namespace
|
||||
|
||||
std::unique_ptr<Pass> mlir::createStoreLoadElimPass() { return std::make_unique<StoreLoadElimPass>(); }
|
||||
std::unique_ptr<Pass> createStoreLoadElimPass() { return std::make_unique<StoreLoadElimPass>(); }
|
||||
} // namespace mlir
|
||||
|
||||
Reference in New Issue
Block a user