mirror of
https://gitee.com/mindspore/akg.git
synced 2025-12-06 11:59:12 +08:00
add AkgAutoTilingFuncPass
This commit is contained in:
@@ -45,6 +45,7 @@
|
||||
#include "akg/Dialect/Affine/Transforms/SimplifyShape.h"
|
||||
#include "akg/Dialect/Affine/Transforms/UnifyShape.h"
|
||||
#include "akg/Dialect/Affine/Transforms/WorkaroundFixReduceInitialization.h"
|
||||
#include "akg/Dialect/Affine/Transforms/TilingFunc.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
|
||||
@@ -286,4 +286,9 @@ def AffineForVectPass : Pass<"affine-for-vectorize", "mlir::func::FuncOp"> {
|
||||
];
|
||||
}
|
||||
|
||||
// Declares the `tiling-func` pass. It is anchored on `func.func` operations
// and instantiated through `mlir::affine::createTilingFuncPass()`.
def TilingFunc
    : Pass<"tiling-func", "mlir::func::FuncOp"> {
  let summary =
    "This pass applies tiling to operations on tensors, dynamically processing tiles";
  let constructor = "mlir::affine::createTilingFuncPass()";
}
|
||||
|
||||
#endif // AKG_MLIR_DIALECT_AFFINE_PASSES
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
/**
|
||||
* Copyright 2025 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef COMPILER_INCLUDE_AKG_DIALECT_AFFINE_TRANSFORMS_TILINGFUNC_H_
|
||||
#define COMPILER_INCLUDE_AKG_DIALECT_AFFINE_TRANSFORMS_TILINGFUNC_H_
|
||||
|
||||
#include <memory>
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace func {
|
||||
class FuncOp;
|
||||
} // namespace func
|
||||
|
||||
namespace affine {
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
createTilingFuncPass();
|
||||
|
||||
} // namespace affine
|
||||
} // namespace mlir
|
||||
|
||||
#endif // COMPILER_INCLUDE_AKG_DIALECT_AFFINE_TRANSFORMS_TILINGFUNC_H_
|
||||
423
akg-mlir/compiler/lib/Dialect/Affine/Transforms/TilingFunc.cpp
Normal file
423
akg-mlir/compiler/lib/Dialect/Affine/Transforms/TilingFunc.cpp
Normal file
@@ -0,0 +1,423 @@
|
||||
/**
|
||||
* Copyright 2025 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "akg/Dialect/Affine/Transforms/TilingFunc.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <iostream>
|
||||
#include <algorithm>
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
#define DEBUG_TYPE "tiling-func"
|
||||
|
||||
namespace mlir {
|
||||
#define GEN_PASS_DECL_TILINGFUNC
|
||||
#define GEN_PASS_DEF_TILINGFUNC
|
||||
#include "akg/Dialect/Affine/Passes.h.inc"
|
||||
} // namespace mlir
|
||||
|
||||
namespace mlir::affine {
|
||||
|
||||
// String constants for the attributes this file attaches to the functions it
// generates, grouped into attribute names and attribute values.
namespace mockattr {

// Attribute names.
constexpr char kFunctionKind[] = "hacc.function_kind";
constexpr char kHostFuncType[] = "hacc.host_func_type";
constexpr char kEnableAutoMarkBufferSize[] = "enable_auto_mark_buffer_size";
constexpr char kBlockDim[] = "hacc.block_dim";
constexpr char kTilingFunction[] = "hacc.tiling_function";
constexpr char kFusionKind[] = "hfusion.fusion_kind";

// Values stored under `kFunctionKind`.
constexpr char kDevice[] = "DEVICE";
constexpr char kHost[] = "HOST";

// Value stored under `kHostFuncType`.
constexpr char kHostTilingFunction[] = "tiling_function";

// Value stored under `kFusionKind`.
constexpr char kFusionKindPureElemwise[] = "PURE_ELEMWISE";

}  // namespace mockattr
|
||||
|
||||
namespace {
|
||||
|
||||
/// Options shared by the tiling schemes defined in this file.
struct AutoTilingOptions {
  /// Block dimension recorded for each processed kernel; defaults to 40.
  unsigned blockDim{40};
  /// Reserved flag; no behaviour in this file depends on its value.
  [[maybe_unused]] bool enableManageHostResources{false};
};
|
||||
|
||||
struct KernelInfo {
|
||||
func::FuncOp originalKernel;
|
||||
std::string baseKernelName;
|
||||
unsigned blockDim = 40;
|
||||
};
|
||||
|
||||
struct TilingInfo {
|
||||
func::FuncOp hostTilingFunc;
|
||||
void setHostTilingFunc(func::FuncOp f) { hostTilingFunc = f; }
|
||||
func::FuncOp getHostTilingFunc() const { return hostTilingFunc; }
|
||||
};
|
||||
|
||||
class TilingBase {
|
||||
public:
|
||||
explicit TilingBase(func::FuncOp f)
|
||||
: originalKernel_(f),
|
||||
module_(f ? f->getParentOfType<ModuleOp>() : ModuleOp()),
|
||||
kernelInfo_(std::make_unique<KernelInfo>()),
|
||||
tilingInfo_(),
|
||||
tilingKernel_() {
|
||||
if (kernelInfo_) {
|
||||
kernelInfo_->originalKernel = f;
|
||||
}
|
||||
}
|
||||
virtual ~TilingBase() = default;
|
||||
|
||||
LogicalResult runOnOperation(OpBuilder &builder) {
|
||||
if (failed(runPreTilingProcedure(builder))) return failure();
|
||||
if (failed(runTilingProcedure(builder))) return failure();
|
||||
if (failed(runPostTilingProcedure(builder))) return failure();
|
||||
return success();
|
||||
}
|
||||
static void setAutoTilingOptions(const AutoTilingOptions &opt) { options_ = opt; }
|
||||
|
||||
protected:
|
||||
LogicalResult runPreTilingProcedure(OpBuilder &) {
|
||||
kernelInfo_->baseKernelName = originalKernel_.getSymName().str();
|
||||
kernelInfo_->blockDim = options_.blockDim;
|
||||
|
||||
if (options_.enableManageHostResources) {
|
||||
(void)0;
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult runTilingProcedure(OpBuilder &builder) {
|
||||
if (failed(createHostTilingFunction(builder))) return failure();
|
||||
if (failed(initTilingKernel(builder))) return failure();
|
||||
if (failed(applyTilingImpl(builder))) return failure();
|
||||
if (failed(fixCallSitesAndCaller(builder))) return failure();
|
||||
return success();
|
||||
}
|
||||
LogicalResult runPostTilingProcedure(OpBuilder &) { return success(); }
|
||||
|
||||
LogicalResult createHostTilingFunction(OpBuilder &builder) {
|
||||
OpBuilder::InsertionGuard g(builder);
|
||||
builder.setInsertionPoint(originalKernel_);
|
||||
|
||||
constexpr int64_t kDummyTiling[] = {0, 12280, 13, 1, 49120};
|
||||
constexpr unsigned kN = sizeof(kDummyTiling) / sizeof(kDummyTiling[0]);
|
||||
|
||||
auto origTy = originalKernel_.getFunctionType();
|
||||
if (origTy.getNumResults() != 1) {
|
||||
originalKernel_.emitError() << "expect exactly 1 result before rewriting";
|
||||
return failure();
|
||||
}
|
||||
|
||||
SmallVector<Type> argTypes(origTy.getInputs().begin(), origTy.getInputs().end());
|
||||
Type outTy = origTy.getResult(0);
|
||||
argTypes.push_back(outTy);
|
||||
|
||||
SmallVector<Type> resTypes(kN, builder.getI64Type());
|
||||
|
||||
std::string name = kernelInfo_->baseKernelName + "_single_outlined_0_0_tiling_function";
|
||||
auto funcTy = FunctionType::get(builder.getContext(), argTypes, resTypes);
|
||||
auto host = builder.create<func::FuncOp>(originalKernel_.getLoc(), name, funcTy);
|
||||
host.addEntryBlock();
|
||||
|
||||
host->setAttr(mockattr::kFunctionKind, StringAttr::get(builder.getContext(), mockattr::kHost));
|
||||
host->setAttr(mockattr::kHostFuncType, StringAttr::get(builder.getContext(), mockattr::kHostTilingFunction));
|
||||
|
||||
unsigned nInputs = origTy.getNumInputs();
|
||||
for (unsigned i = 0; i < nInputs; ++i) {
|
||||
host.setArgAttr(i, "hacc.arg_type", StringAttr::get(builder.getContext(), "input"));
|
||||
host.setArgAttr(i, "hacc.input_idx", builder.getI64IntegerAttr(i));
|
||||
}
|
||||
host.setArgAttr(nInputs, "hacc.arg_type", StringAttr::get(builder.getContext(), "output"));
|
||||
host.setArgAttr(nInputs, "hacc.output_idx", builder.getI64IntegerAttr(0));
|
||||
|
||||
host.setResultAttr(0, "hacc.arg_type", StringAttr::get(builder.getContext(), "tiling_key"));
|
||||
for (unsigned i = 1; i < kN; ++i)
|
||||
host.setResultAttr(i, "hacc.arg_type", StringAttr::get(builder.getContext(), "tiling_data"));
|
||||
|
||||
builder.setInsertionPointToEnd(&host.getBody().front());
|
||||
|
||||
SmallVector<Value> cst;
|
||||
cst.reserve(kN);
|
||||
std::transform(std::begin(kDummyTiling), std::end(kDummyTiling), std::back_inserter(cst),
|
||||
[&](int64_t v) { return builder.create<arith::ConstantIntOp>(host.getLoc(), v, 64); });
|
||||
|
||||
builder.create<func::ReturnOp>(host.getLoc(), cst);
|
||||
|
||||
tilingInfo_.setHostTilingFunc(host);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult initTilingKernel(OpBuilder &builder) {
|
||||
OpBuilder::InsertionGuard g(builder);
|
||||
builder.setInsertionPoint(originalKernel_);
|
||||
|
||||
constexpr unsigned kTilingCnt = 5;
|
||||
auto origTy = originalKernel_.getFunctionType();
|
||||
if (origTy.getNumResults() != 1) {
|
||||
originalKernel_.emitError() << "expect exactly 1 result";
|
||||
return failure();
|
||||
}
|
||||
|
||||
SmallVector<Type> devInputs(origTy.getInputs().begin(), origTy.getInputs().end());
|
||||
Type outTy = origTy.getResult(0);
|
||||
devInputs.push_back(outTy);
|
||||
|
||||
SmallVector<Type> devResults;
|
||||
std::string name = kernelInfo_->baseKernelName + "_single_outlined_0_0_0";
|
||||
auto devTy = FunctionType::get(builder.getContext(), devInputs, devResults);
|
||||
auto deviceFunc = builder.create<func::FuncOp>(originalKernel_.getLoc(), name, devTy);
|
||||
|
||||
unsigned nInputs = origTy.getNumInputs();
|
||||
for (unsigned i = 0; i < nInputs; ++i) {
|
||||
deviceFunc.setArgAttr(i, "hacc.arg_type", StringAttr::get(builder.getContext(), "input"));
|
||||
deviceFunc.setArgAttr(i, "hacc.input_idx", builder.getI64IntegerAttr(i));
|
||||
}
|
||||
deviceFunc.setArgAttr(nInputs, "hacc.arg_type", StringAttr::get(builder.getContext(), "output"));
|
||||
deviceFunc.setArgAttr(nInputs, "hacc.output_idx", builder.getI64IntegerAttr(0));
|
||||
|
||||
deviceFunc->setAttr(mockattr::kEnableAutoMarkBufferSize, builder.getUnitAttr());
|
||||
deviceFunc->setAttr(mockattr::kFunctionKind, StringAttr::get(builder.getContext(), mockattr::kDevice));
|
||||
deviceFunc->setAttr(mockattr::kFusionKind,
|
||||
StringAttr::get(builder.getContext(), mockattr::kFusionKindPureElemwise));
|
||||
deviceFunc->setAttr(mockattr::kBlockDim, builder.getI64IntegerAttr(kernelInfo_->blockDim));
|
||||
deviceFunc->setAttr("hacc.entry", builder.getUnitAttr());
|
||||
if (auto hostTiling = tilingInfo_.getHostTilingFunc())
|
||||
deviceFunc->setAttr(mockattr::kTilingFunction, FlatSymbolRefAttr::get(hostTiling.getSymNameAttr()));
|
||||
|
||||
Block *entry = deviceFunc.addEntryBlock();
|
||||
OpBuilder b = OpBuilder::atBlockEnd(entry);
|
||||
Location loc = deviceFunc.getLoc();
|
||||
|
||||
SmallVector<Value> inArgs;
|
||||
inArgs.reserve(nInputs);
|
||||
for (unsigned i = 0; i < nInputs; ++i) inArgs.push_back(entry->getArgument(i));
|
||||
Value outArg = entry->getArgument(nInputs);
|
||||
|
||||
Value returned;
|
||||
{
|
||||
IRMapping map;
|
||||
if (!originalKernel_.empty()) {
|
||||
Block &oldEntry = originalKernel_.front();
|
||||
unsigned argToMap = std::min<unsigned>(oldEntry.getNumArguments(), nInputs);
|
||||
for (unsigned i = 0; i < argToMap; ++i) map.map(oldEntry.getArgument(i), inArgs[i]);
|
||||
if (oldEntry.getNumArguments() > nInputs) map.map(oldEntry.getArgument(nInputs), outArg);
|
||||
|
||||
func::ReturnOp oldRet = nullptr;
|
||||
SmallVector<Operation *> toClone;
|
||||
for (Operation &op : oldEntry) {
|
||||
if (auto r = dyn_cast<func::ReturnOp>(op)) {
|
||||
oldRet = r;
|
||||
continue;
|
||||
}
|
||||
toClone.push_back(&op);
|
||||
}
|
||||
for (Operation *op : toClone) b.clone(*op, map);
|
||||
|
||||
if (oldRet && oldRet.getNumOperands() == 1) returned = map.lookupOrDefault(oldRet.getOperand(0));
|
||||
}
|
||||
}
|
||||
if (!returned) returned = outArg;
|
||||
|
||||
if (returned != outArg) b.create<memref::CopyOp>(loc, returned, outArg);
|
||||
|
||||
b.create<func::ReturnOp>(loc);
|
||||
|
||||
tilingKernel_ = deviceFunc;
|
||||
|
||||
if (failed(createOrGetGetTilingStructSizeFunction(builder, deviceFunc))) return failure();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
virtual LogicalResult applyTilingImpl(OpBuilder &) { return success(); }
|
||||
|
||||
LogicalResult fixCallSitesAndCaller(OpBuilder &builder) {
|
||||
assert(tilingKernel_);
|
||||
|
||||
constexpr unsigned kTilingCnt = 5;
|
||||
(void)kTilingCnt;
|
||||
|
||||
auto oldTy = originalKernel_.getFunctionType();
|
||||
unsigned nInputs = oldTy.getNumInputs();
|
||||
Type outTy = oldTy.getResult(0);
|
||||
|
||||
SmallVector<Type> newInputs(oldTy.getInputs().begin(), oldTy.getInputs().end());
|
||||
newInputs.push_back(outTy);
|
||||
SmallVector<Type> newResults = {outTy};
|
||||
|
||||
originalKernel_.setFunctionType(FunctionType::get(builder.getContext(), newInputs, newResults));
|
||||
|
||||
for (unsigned i = 0; i < nInputs; ++i) {
|
||||
originalKernel_.setArgAttr(i, "hacc.arg_type", StringAttr::get(builder.getContext(), "input"));
|
||||
originalKernel_.setArgAttr(i, "hacc.input_idx", builder.getI64IntegerAttr(i));
|
||||
}
|
||||
originalKernel_.setArgAttr(nInputs, "hacc.arg_type", StringAttr::get(builder.getContext(), "output"));
|
||||
originalKernel_.setArgAttr(nInputs, "hacc.output_idx", builder.getI64IntegerAttr(0));
|
||||
|
||||
while (!originalKernel_.getBody().empty()) originalKernel_.getBody().front().erase();
|
||||
|
||||
Block *entry = originalKernel_.addEntryBlock();
|
||||
OpBuilder b = OpBuilder::atBlockEnd(entry);
|
||||
Location loc = originalKernel_.getLoc();
|
||||
|
||||
SmallVector<Value> passArgs(entry->args_begin(), entry->args_end());
|
||||
b.create<func::CallOp>(loc, tilingKernel_.getSymName(), TypeRange{}, passArgs);
|
||||
|
||||
Value outArg = entry->getArgument(nInputs);
|
||||
b.create<func::ReturnOp>(loc, outArg);
|
||||
|
||||
auto hostTiling = tilingInfo_.getHostTilingFunc();
|
||||
SmallVector<func::CallOp> callers;
|
||||
module_.walk([&](func::CallOp c) {
|
||||
if (c.getCallee() == kernelInfo_->baseKernelName) callers.push_back(c);
|
||||
});
|
||||
|
||||
for (func::CallOp callOp : callers) {
|
||||
OpBuilder::InsertionGuard gg(builder);
|
||||
builder.setInsertionPoint(callOp);
|
||||
|
||||
SmallVector<Value> operands(callOp.getOperands().begin(), callOp.getOperands().end());
|
||||
auto outMemRefTy = dyn_cast<MemRefType>(outTy);
|
||||
if (!outMemRefTy) {
|
||||
callOp.emitError() << "expect memref out";
|
||||
return failure();
|
||||
}
|
||||
Value outVal = builder.create<memref::AllocOp>(callOp.getLoc(), outMemRefTy);
|
||||
|
||||
SmallVector<Value> tilArgs(operands);
|
||||
tilArgs.push_back(outVal);
|
||||
auto tilResTys = hostTiling.getFunctionType().getResults();
|
||||
auto tilCall = builder.create<func::CallOp>(callOp.getLoc(), hostTiling.getSymName(), tilResTys, tilArgs);
|
||||
|
||||
SmallVector<Value> newCallArgs(operands);
|
||||
newCallArgs.push_back(outVal);
|
||||
newCallArgs.append(tilCall.getResults().begin(), tilCall.getResults().end());
|
||||
|
||||
auto newCall = builder.create<func::CallOp>(callOp.getLoc(), originalKernel_.getSymName(),
|
||||
originalKernel_.getFunctionType().getResults(), newCallArgs);
|
||||
|
||||
callOp.replaceAllUsesWith(newCall.getResults());
|
||||
callOp.erase();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult createOrGetGetTilingStructSizeFunction(OpBuilder &builder, func::FuncOp deviceFunc) {
|
||||
ModuleOp module = deviceFunc->getParentOfType<ModuleOp>();
|
||||
if (!module) {
|
||||
deviceFunc.emitError() << "cannot find parent ModuleOp for device function";
|
||||
return failure();
|
||||
}
|
||||
|
||||
std::string base = deviceFunc.getSymName().str();
|
||||
if (base.size() >= 2 && base.substr(base.size() - 2) == "_0") {
|
||||
base = base.substr(0, base.size() - 2);
|
||||
}
|
||||
|
||||
std::string hostName = base + "_get_tiling_struct_size_function";
|
||||
|
||||
if (auto sym = SymbolTable::lookupSymbolIn(module, StringAttr::get(module.getContext(), hostName))) {
|
||||
if (isa<func::FuncOp>(sym)) return success();
|
||||
}
|
||||
|
||||
OpBuilder::InsertionGuard g(builder);
|
||||
builder.setInsertionPointToStart(module.getBody());
|
||||
|
||||
auto funcTy =
|
||||
FunctionType::get(module.getContext(), /*inputs=*/TypeRange{}, /*results=*/TypeRange{builder.getI64Type()});
|
||||
auto host = builder.create<func::FuncOp>(deviceFunc.getLoc(), hostName, funcTy);
|
||||
host.setVisibility(SymbolTable::Visibility::Public);
|
||||
|
||||
host->setAttr(mockattr::kFunctionKind, StringAttr::get(builder.getContext(), mockattr::kHost));
|
||||
|
||||
Block *entry = host.addEntryBlock();
|
||||
OpBuilder b = OpBuilder::atBlockEnd(entry);
|
||||
auto zero = b.create<arith::ConstantIntOp>(host.getLoc(), /*value=*/0, /*width=*/64);
|
||||
b.create<func::ReturnOp>(host.getLoc(), ValueRange{zero});
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
protected:
|
||||
func::FuncOp originalKernel_;
|
||||
ModuleOp module_;
|
||||
std::unique_ptr<KernelInfo> kernelInfo_;
|
||||
TilingInfo tilingInfo_;
|
||||
func::FuncOp tilingKernel_;
|
||||
static AutoTilingOptions options_;
|
||||
};
|
||||
|
||||
AutoTilingOptions TilingBase::options_;
|
||||
|
||||
class PureElemwiseTiling : public TilingBase {
|
||||
public:
|
||||
using TilingBase::TilingBase;
|
||||
LogicalResult applyTilingImpl(OpBuilder &) override { return success(); }
|
||||
};
|
||||
|
||||
/// Pass entry point for `tiling-func`.
///
/// The pass is anchored on `func.func`, so the pass manager invokes it once per
/// function. Each invocation tiles only the function it was invoked on;
/// functions that were generated by, or already rewritten by, an earlier
/// invocation are skipped, so visiting several functions of the same module
/// does not process any kernel twice.
///
/// NOTE(review): tiling still inserts sibling functions into the parent module
/// and rewrites call sites in other functions, which a function-anchored pass
/// is not supposed to do when functions are processed in parallel. Anchoring
/// the pass on the module would remove that hazard but changes the pass
/// declaration; confirm before relying on multi-threaded execution.
struct TilingFunc : public mlir::impl::TilingFuncBase<TilingFunc> {
  TilingFunc() = default;

  void runOnOperation() override {
    func::FuncOp func = getOperation();
    ModuleOp module = func->getParentOfType<ModuleOp>();
    if (!module) {
      func.emitError() << "cannot find parent ModuleOp";
      signalPassFailure();
      return;
    }

    if (!shouldTile(func, module)) return;

    TilingBase::setAutoTilingOptions(AutoTilingOptions{});

    OpBuilder builder(func.getContext());
    PureElemwiseTiling scheme(func);
    if (failed(scheme.runOnOperation(builder))) {
      func.emitError() << "auto-tiling failed";
      signalPassFailure();
    }
  }

 private:
  /// Returns true when `f` is a kernel that has not been processed yet.
  static bool shouldTile(func::FuncOp f, ModuleOp module) {
    // Device functions generated by this pass reference their host tiling function.
    if (f->hasAttr(mockattr::kTilingFunction)) return false;

    // Functions explicitly marked with a kind other than DEVICE (e.g. HOST helpers) are not kernels.
    auto kind = f->getAttrOfType<StringAttr>(mockattr::kFunctionKind);
    if (kind && kind.getValue() != mockattr::kDevice) return false;

    // A kernel whose outlined device function already exists has been rewritten into a wrapper.
    std::string outlinedName = f.getSymName().str() + "_single_outlined_0_0_0";
    if (SymbolTable::lookupSymbolIn(module, StringAttr::get(module.getContext(), outlinedName))) return false;

    return true;
  }
};
|
||||
|
||||
} // namespace
|
||||
} // namespace mlir::affine
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> mlir::affine::createTilingFuncPass() {
|
||||
return std::make_unique<TilingFunc>();
|
||||
}
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
#include "akg/Pipelines/AscendPipelines/AscendOpt.h"
|
||||
|
||||
#include <cstdlib>
|
||||
#include <nlohmann/json.hpp>
|
||||
#include "akg/Conversion/Passes.h"
|
||||
#include "akg/Dialect/Affine/Passes.h"
|
||||
@@ -71,7 +72,6 @@ void createAscendOptPipelineImpl(OpPassManager &pm, const mlir::AscendOptPipelin
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.addPass(mlir::createMemrefCopyToLoopsPass());
|
||||
|
||||
|
||||
OpPassManager &nestedFusionPM = pm.nest<mlir::func::FuncOp>();
|
||||
nestedFusionPM.addPass(mlir::createConvertLinalgToAffineLoopsPass());
|
||||
nestedFusionPM.addPass(mlir::affine::createAffineReductionAnnotationPass());
|
||||
@@ -87,6 +87,9 @@ void createAscendOptPipelineImpl(OpPassManager &pm, const mlir::AscendOptPipelin
|
||||
nestedFusionPM.addPass(mlir::createAKGLoopTilingPass(false)); // useAutoTiling = false
|
||||
nestedFusionPM.addPass(mlir::affine::createAffineForVectPass());
|
||||
nestedFusionPM.addPass(mlir::affine::createVectorTransferTensorizePass());
|
||||
if (const char *v = std::getenv("TILINGFUNC"); v && std::string(v) == "1") {
|
||||
nestedFusionPM.addPass(mlir::affine::createTilingFuncPass());
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
65
akg-mlir/tests/ut/Dialect/Affine/tiling_func.mlir
Normal file
65
akg-mlir/tests/ut/Dialect/Affine/tiling_func.mlir
Normal file
@@ -0,0 +1,65 @@
|
||||
// RUN: akg-opt %s --tiling-func | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: #map = affine_map<(d0) -> (d0)>
|
||||
// CHECK-NEXT: #map1 = affine_map<(d0) -> (d0 + 512)>
|
||||
// CHECK-NEXT: module {
|
||||
// CHECK-NEXT: func.func @Fused_Add_fusion_3324123131231234556_single_outlined_0_0_get_tiling_struct_size_function() -> i64 attributes {hacc.function_kind = "HOST"} {
|
||||
// CHECK-NEXT: %c0_i64 = arith.constant 0 : i64
|
||||
// CHECK-NEXT: return %c0_i64 : i64
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: func.func @Fused_Add_fusion_3324123131231234556_single_outlined_0_0_tiling_function(%arg0: memref<171520xf32> {hacc.arg_type = "input", hacc.input_idx = 0 : i64}, %arg1: memref<171520xf32> {hacc.arg_type = "input", hacc.input_idx = 1 : i64}, %arg2: memref<171520xf32> {hacc.arg_type = "output", hacc.output_idx = 0 : i64}) -> (i64 {hacc.arg_type = "tiling_key"}, i64 {hacc.arg_type = "tiling_data"}, i64 {hacc.arg_type = "tiling_data"}, i64 {hacc.arg_type = "tiling_data"}, i64 {hacc.arg_type = "tiling_data"}) attributes {hacc.function_kind = "HOST", hacc.host_func_type = "tiling_function"} {
|
||||
// CHECK-NEXT: %c0_i64 = arith.constant 0 : i64
|
||||
// CHECK-NEXT: %c12280_i64 = arith.constant 12280 : i64
|
||||
// CHECK-NEXT: %c13_i64 = arith.constant 13 : i64
|
||||
// CHECK-NEXT: %c1_i64 = arith.constant 1 : i64
|
||||
// CHECK-NEXT: %c49120_i64 = arith.constant 49120 : i64
|
||||
// CHECK-NEXT: return %c0_i64, %c12280_i64, %c13_i64, %c1_i64, %c49120_i64 : i64, i64, i64, i64, i64
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: func.func @Fused_Add_fusion_3324123131231234556_single_outlined_0_0_0(%arg0: memref<171520xf32> {hacc.arg_type = "input", hacc.input_idx = 0 : i64}, %arg1: memref<171520xf32> {hacc.arg_type = "input", hacc.input_idx = 1 : i64}, %arg2: memref<171520xf32> {hacc.arg_type = "output", hacc.output_idx = 0 : i64}) attributes {enable_auto_mark_buffer_size, hacc.block_dim = 40 : i64, hacc.entry, hacc.function_kind = "DEVICE", hacc.tiling_function = @Fused_Add_fusion_3324123131231234556_single_outlined_0_0_tiling_function, hfusion.fusion_kind = "PURE_ELEMWISE"} {
|
||||
// CHECK-NEXT: %alloc = memref.alloc() {alignment = 64 : i64} : memref<171520xf32>
|
||||
// CHECK-NEXT: affine.for %arg3 = 0 to 171520 step 512 {
|
||||
// CHECK-NEXT: affine.for %arg4 = #map(%arg3) to #map1(%arg3) step 512 {
|
||||
// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32
|
||||
// CHECK-NEXT: %subview = memref.subview %arg0[0] [512] [1] : memref<171520xf32> to memref<512xf32>
|
||||
// CHECK-NEXT: %0 = bufferization.to_tensor %subview restrict writable : memref<512xf32>
|
||||
// CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32
|
||||
// CHECK-NEXT: %subview_1 = memref.subview %arg1[0] [512] [1] : memref<171520xf32> to memref<512xf32>
|
||||
// CHECK-NEXT: %1 = bufferization.to_tensor %subview_1 restrict writable : memref<512xf32>
|
||||
// CHECK-NEXT: %2 = arith.addf %0, %1 : tensor<512xf32>
|
||||
// CHECK-NEXT: %subview_2 = memref.subview %alloc[0] [512] [1] : memref<171520xf32> to memref<512xf32>
|
||||
// CHECK-NEXT: %3 = bufferization.to_memref %2 : memref<512xf32>
|
||||
// CHECK-NEXT: memref.copy %3, %subview_2 : memref<512xf32> to memref<512xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: memref.copy %alloc, %arg2 : memref<171520xf32> to memref<171520xf32>
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: func.func @Fused_Add_fusion_3324123131231234556(%arg0: memref<171520xf32> {hacc.arg_type = "input", hacc.input_idx = 0 : i64}, %arg1: memref<171520xf32> {hacc.arg_type = "input", hacc.input_idx = 1 : i64}, %arg2: memref<171520xf32> {hacc.arg_type = "output", hacc.output_idx = 0 : i64}) -> memref<171520xf32> attributes {OperatorType = "Elementwise", compute_capability = "", mindspore_kernel, process = "aicore"} {
|
||||
// CHECK-NEXT: call @Fused_Add_fusion_3324123131231234556_single_outlined_0_0_0(%arg0, %arg1, %arg2) : (memref<171520xf32>, memref<171520xf32>, memref<171520xf32>) -> ()
|
||||
// CHECK-NEXT: return %arg2 : memref<171520xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
#map = affine_map<(d0) -> (d0)>
|
||||
#map1 = affine_map<(d0) -> (d0 + 512)>
|
||||
// Test input for `--tiling-func`: a module with a single kernel and no callers.
// The kernel takes two memref<171520xf32> operands, allocates a result buffer
// of the same type, fills it inside a two-level affine loop nest (step 512),
// and returns the buffer. The CHECK lines above expect this body to be moved
// into the generated `..._single_outlined_0_0_0` function.
module {
|
||||
// Kernel under test: exactly one memref result and no `hacc.function_kind` attribute.
func.func @Fused_Add_fusion_3324123131231234556(%arg0: memref<171520xf32>, %arg1: memref<171520xf32>) -> memref<171520xf32> attributes {OperatorType = "Elementwise", compute_capability = "", mindspore_kernel, process = "aicore"} {
|
||||
%alloc = memref.alloc() {alignment = 64 : i64} : memref<171520xf32>
|
||||
affine.for %arg2 = 0 to 171520 step 512 {
|
||||
affine.for %arg3 = #map(%arg2) to #map1(%arg2) step 512 {
|
||||
%cst = arith.constant 0.000000e+00 : f32
|
||||
// Note: the subviews below use the constant offset 0 rather than a loop induction variable.
%subview = memref.subview %arg0[0] [512] [1] : memref<171520xf32> to memref<512xf32>
|
||||
%0 = bufferization.to_tensor %subview restrict writable : memref<512xf32>
|
||||
%cst_0 = arith.constant 0.000000e+00 : f32
|
||||
%subview_1 = memref.subview %arg1[0] [512] [1] : memref<171520xf32> to memref<512xf32>
|
||||
%1 = bufferization.to_tensor %subview_1 restrict writable : memref<512xf32>
|
||||
%2 = arith.addf %0, %1 : tensor<512xf32>
|
||||
%subview_2 = memref.subview %alloc[0] [512] [1] : memref<171520xf32> to memref<512xf32>
|
||||
%3 = bufferization.to_memref %2 : memref<512xf32>
|
||||
memref.copy %3, %subview_2 : memref<512xf32> to memref<512xf32>
|
||||
}
|
||||
}
|
||||
// The returned value is the locally allocated buffer, not a function argument.
return %alloc : memref<171520xf32>
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user