!1293 [AKG-MLIR] Harden StoreLoadElim memref handling to avoid invalid casts

Merge pull request !1293 from yuziyu/br_aikg
i-robot (committed by Gitee), 2025-11-28 01:20:27 +00:00
2 changed files with 34 additions and 13 deletions


@@ -406,13 +406,20 @@ class CommonUtils {
}
static Value getStoreMemref(Operation *storeOp) {
Value memref;
if (dyn_cast<affine::AffineStoreOp>(storeOp)) {
memref = dyn_cast<affine::AffineStoreOp>(storeOp).getMemref();
} else if (dyn_cast<memref::StoreOp>(storeOp)) {
memref = dyn_cast<memref::StoreOp>(storeOp).getMemref();
} else {
if (!storeOp) {
return Value();
}
if (!isa<affine::AffineStoreOp>(storeOp) && !isa<memref::StoreOp>(storeOp)) {
llvm::errs() << "can only get memref from AffineStore or memref::StoreOp.\n";
return Value();
}
if (storeOp->getNumOperands() < 2) {
llvm::errs() << "store op has insufficient operands when querying memref.\n";
return Value();
}
Value memref = storeOp->getOperand(1);
if (!memref || !isa<BaseMemRefType>(memref.getType())) {
return Value();
}
return memref;
}
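
Callers of the hardened helper must now handle a null Value instead of assuming a cast succeeded. A minimal caller sketch, assuming the CommonUtils header from this repo is included; inspectStore is a hypothetical name used only for illustration:

#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/raw_ostream.h"

// Hypothetical caller: tolerate a null Value from getStoreMemref.
void inspectStore(mlir::Operation *op) {
  mlir::Value memref = CommonUtils::getStoreMemref(op);
  if (!memref) {
    return;  // not an affine.store / memref.store, or malformed operands
  }
  // Operand 1 is the memref for both memref.store (value, memref, indices...)
  // and affine.store (value, memref, map operands...), which is what the
  // helper reads after the isa<> and operand-count checks above.
  llvm::outs() << "store writes into: " << memref << "\n";
}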


@@ -36,7 +36,7 @@ namespace mlir {
#define DEBUG_TYPE "elim-store-load"
using namespace mlir;
namespace mlir {
namespace {
// ===----------------------------------------------------------------------===//
@@ -56,7 +56,7 @@ namespace {
// load alloc
// ...
// memref.copy alloc, %global_out
// 2. store and load are not located in the same branch (whatever comes from differnt If or For)
// 2. store and load are not located in the same branch (whatever comes from different If or For)
// because the stored variable %x will be out of scope for load
// e.g.
// for arg0
@@ -89,6 +89,10 @@ struct StoreLoadElimPass : public StoreLoadElimBase<StoreLoadElimPass> {
SmallVector<Operation *> getPossibleElimLoads(Operation *storeOp) const {
SmallVector<Operation *> elimLoads;
auto memref = CommonUtils::getStoreMemref(storeOp);
// check if the memref is valid and the type is correct
if (!memref || !isa<MemRefType>(memref.getType())) {
return SmallVector<Operation *>();
}
for (auto user : memref.getUsers()) {
if (user == storeOp) {
continue;
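
The walk above enumerates memref.getUsers() to find candidate loads. A simplified sketch of that pattern; collectLoads is a hypothetical name, and the real pass additionally applies the branch and copy checks described in the comment block above before treating a load as eliminable:

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "llvm/ADT/SmallVector.h"

// Hypothetical helper: gather every load op that reads the given memref.
llvm::SmallVector<mlir::Operation *> collectLoads(mlir::Value memref) {
  llvm::SmallVector<mlir::Operation *> loads;
  for (mlir::Operation *user : memref.getUsers()) {
    if (mlir::isa<mlir::memref::LoadOp, mlir::affine::AffineLoadOp>(user)) {
      loads.push_back(user);
    }
  }
  return loads;
}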
@@ -118,6 +122,11 @@ void StoreLoadElimPass::runOnOperation() {
SmallVector<Operation *> toElimLoads;
getOperation()->walk([&](Operation *op) {
if (dyn_cast<memref::StoreOp>(op) || dyn_cast<affine::AffineStoreOp>(op)) {
// check if the memref is valid and the type is correct, skip invalid store operations
auto memref = CommonUtils::getStoreMemref(op);
if (!memref || !isa<MemRefType>(memref.getType())) {
return;
}
auto elimLoads = getPossibleElimLoads(op);
size_t eraseSize = 0;
for (auto loadOp : elimLoads) {
@@ -129,7 +138,7 @@ void StoreLoadElimPass::runOnOperation() {
eraseSize++;
}
}
bool isGlobalBuffer = CommonUtils::getStoreMemref(op).getDefiningOp() == nullptr;
bool isGlobalBuffer = memref.getDefiningOp() == nullptr;
bool elimAllLoads = eraseSize > 0 && eraseSize == elimLoads.size();
if (elimAllLoads && !isGlobalBuffer) {
toElimStores.push_back(op);
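
The getDefiningOp() == nullptr test identifies a global output buffer because a Value with no defining op is a BlockArgument, e.g. a memref passed into the function, so stores into it are externally visible and must be kept. A hedged equivalent, with isExternallyVisibleBuffer as a hypothetical name:

#include "mlir/IR/Value.h"

// A memref with no defining op is a BlockArgument (e.g. a function argument),
// so writes to it escape the pass's view and the store must be preserved.
bool isExternallyVisibleBuffer(mlir::Value memref) {
  // Equivalent to memref.getDefiningOp() == nullptr.
  return mlir::isa<mlir::BlockArgument>(memref);
}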
@@ -140,15 +149,20 @@ void StoreLoadElimPass::runOnOperation() {
loadOp->erase();
}
for (auto storeOp : toElimStores) {
// before erasing storeOp, capture the memref
auto memref = CommonUtils::getStoreMemref(storeOp);
if (storeOp->use_empty()) {
storeOp->erase();
}
auto memrefOp = CommonUtils::getStoreMemref(storeOp).getDefiningOp();
if (memrefOp->use_empty()) {
memrefOp->erase();
if (memref && isa<MemRefType>(memref.getType())) {
auto memrefOp = memref.getDefiningOp();
if (memrefOp && memrefOp->use_empty()) {
memrefOp->erase();
}
}
}
}
} // end anonymous namespace
std::unique_ptr<Pass> mlir::createStoreLoadElimPass() { return std::make_unique<StoreLoadElimPass>(); }
std::unique_ptr<Pass> createStoreLoadElimPass() { return std::make_unique<StoreLoadElimPass>(); }
} // namespace mlir
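
For downstream pipelines, the factory remains the entry point. A hedged usage sketch: the pass-manager setup and the addStoreLoadElim name are assumptions, and the header declaring createStoreLoadElimPass() must be included; only the factory itself comes from this file:

#include "mlir/Pass/PassManager.h"

// Hypothetical wiring: run StoreLoadElim as part of a larger pipeline.
void addStoreLoadElim(mlir::PassManager &pm) {
  pm.addPass(mlir::createStoreLoadElimPass());
}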