1//===-- CIRLoweringEmitter.cpp - Generate CIR lowering patterns -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This TableGen backend emits CIR operation lowering patterns.
10//
11//===----------------------------------------------------------------------===//
12
13#include "TableGenBackends.h"
14#include "llvm/TableGen/Error.h"
15#include "llvm/TableGen/Record.h"
16#include "llvm/TableGen/TableGenBackend.h"
17#include <string>
18#include <utility>
19#include <vector>
20
21using namespace llvm;
22using namespace clang;
23
24namespace {
// Text collected while walking the records. EmitCIRLowering writes each of
// these lists into its own #ifdef-guarded section of the output.
using EmittedStrings = std::vector<std::string>;

/// Class definitions of the generated CXX ABI lowering patterns.
EmittedStrings CXXABILoweringPatterns;
/// Class names of the generated CXX ABI lowering patterns.
EmittedStrings CXXABILoweringPatternsList;
/// "cir::"-qualified attribute class names that are always legal during
/// CXX ABI lowering.
EmittedStrings CXXABILoweringAttrAlwaysLegal;
/// Class definitions of the generated CIR-to-LLVM lowering patterns.
EmittedStrings LLVMLoweringPatterns;
/// Class names of the CIR-to-LLVM patterns that are registered automatically
/// (only those without a custom constructor).
EmittedStrings LLVMLoweringPatternsList;

/// Extra constructor parameters an op requests for its LLVM lowering pattern.
/// Each parameter is also stored in the pattern as a same-named member.
struct CustomLoweringCtor {
  /// One parameter: the C++ type spelling and the identifier.
  struct Param {
    std::string Type;
    std::string Name;
  };

  std::vector<Param> Params;
};
39
40// Adapted from mlir/lib/TableGen/Operator.cpp
41// Returns the C++ class name of the operation, which is the name of the
42// operation with the dialect prefix removed and the first underscore removed.
43// If the operation name starts with an underscore, the underscore is considered
44// part of the class name.
45std::string GetOpCppClassName(const Record *OpRecord) {
46 StringRef Name = OpRecord->getName();
47 StringRef Prefix;
48 StringRef CppClassName;
49 std::tie(args&: Prefix, args&: CppClassName) = Name.split(Separator: '_');
50 if (Prefix.empty()) {
51 // Class name with a leading underscore and without dialect prefix
52 return Name.str();
53 }
54 if (CppClassName.empty()) {
55 // Class name without dialect prefix
56 return Prefix.str();
57 }
58
59 return CppClassName.str();
60}
61
62std::string GetOpABILoweringPatternName(llvm::StringRef OpName) {
63 std::string Name = "CIR";
64 Name += OpName;
65 Name += "ABILowering";
66 return Name;
67}
68
69std::string GetOpLLVMLoweringPatternName(llvm::StringRef OpName) {
70 std::string Name = "CIRToLLVM";
71 Name += OpName;
72 Name += "Lowering";
73 return Name;
74}
75std::optional<CustomLoweringCtor> parseCustomLoweringCtor(const Record *R) {
76 if (!R)
77 return std::nullopt;
78
79 CustomLoweringCtor Ctor;
80 const DagInit *Args = R->getValueAsDag(FieldName: "dagParams");
81
82 for (const auto &[Arg, Name] : Args->getArgAndNames()) {
83 Ctor.Params.push_back(
84 x: {.Type: Arg->getAsUnquotedString(), .Name: Name->getAsUnquotedString()});
85 }
86
87 return Ctor;
88}
89
90void emitCustomParamList(raw_ostream &Code,
91 ArrayRef<CustomLoweringCtor::Param> Params) {
92 for (const CustomLoweringCtor::Param &Param : Params) {
93 Code << ", ";
94 Code << Param.Type << " " << Param.Name;
95 }
96}
97
98void emitCustomInitList(raw_ostream &Code,
99 ArrayRef<CustomLoweringCtor::Param> Params) {
100 for (const CustomLoweringCtor::Param &P : Params)
101 Code << ", " << P.Name << "(" << P.Name << ")";
102}
103
104void GenerateABILoweringPattern(llvm::StringRef OpName,
105 llvm::StringRef PatternName) {
106 std::string CodeBuffer;
107 llvm::raw_string_ostream Code(CodeBuffer);
108
109 Code << "class " << PatternName
110 << " : public mlir::OpConversionPattern<cir::" << OpName << "> {\n";
111 Code << " [[maybe_unused]] mlir::DataLayout *dataLayout;\n";
112 Code << " [[maybe_unused]] cir::LowerModule *lowerModule;\n";
113 Code << "\n";
114
115 Code << "public:\n";
116 Code << " " << PatternName
117 << "(mlir::MLIRContext *context, const mlir::TypeConverter "
118 "&typeConverter, mlir::DataLayout &dataLayout, cir::LowerModule "
119 "&lowerModule)\n";
120 Code << " : OpConversionPattern<cir::" << OpName
121 << ">(typeConverter, context), dataLayout(&dataLayout), "
122 "lowerModule(&lowerModule) {}\n";
123 Code << "\n";
124
125 Code << " mlir::LogicalResult matchAndRewrite(cir::" << OpName
126 << " op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) "
127 "const override;\n";
128
129 Code << "};\n";
130
131 CXXABILoweringPatterns.push_back(x: std::move(CodeBuffer));
132}
133
134void GenerateLLVMLoweringPattern(llvm::StringRef OpName,
135 llvm::StringRef PatternName, bool IsRecursive,
136 llvm::StringRef ExtraDecl,
137 const Record *CustomCtorRec,
138 llvm::StringRef LLVMOp) {
139 std::optional<CustomLoweringCtor> CustomCtor =
140 parseCustomLoweringCtor(R: CustomCtorRec);
141 std::string CodeBuffer;
142 llvm::raw_string_ostream Code(CodeBuffer);
143
144 Code << "class " << PatternName
145 << " : public mlir::OpConversionPattern<cir::" << OpName << "> {\n";
146 Code << " [[maybe_unused]] mlir::DataLayout const &dataLayout;\n";
147
148 if (CustomCtor) {
149 for (const CustomLoweringCtor::Param &P : CustomCtor->Params)
150 Code << " " << P.Type << " " << P.Name << ";\n";
151 }
152
153 Code << "\n";
154
155 Code << "public:\n";
156 Code << " using mlir::OpConversionPattern<cir::" << OpName
157 << ">::OpConversionPattern;\n";
158
159 // Constructor
160 Code << " " << PatternName
161 << "(const mlir::TypeConverter &typeConverter, "
162 "mlir::MLIRContext *context, const mlir::DataLayout &dataLayout";
163
164 if (CustomCtor)
165 emitCustomParamList(Code, Params: CustomCtor->Params);
166
167 Code << ")\n";
168
169 Code << " : OpConversionPattern<cir::" << OpName
170 << ">(typeConverter, context), dataLayout(dataLayout)";
171
172 if (CustomCtor)
173 emitCustomInitList(Code, Params: CustomCtor->Params);
174
175 Code << " {\n";
176
177 if (IsRecursive)
178 Code << " setHasBoundedRewriteRecursion();\n";
179
180 Code << " }\n\n";
181
182 if (!LLVMOp.empty()) {
183 // Generate the matchAndRewrite body automatically.
184 Code
185 << " mlir::LogicalResult matchAndRewrite(cir::" << OpName
186 << " op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) "
187 "const override {\n";
188 Code
189 << " mlir::Type resTy = typeConverter->convertType(op.getType());\n";
190 Code << " rewriter.replaceOpWithNewOp<mlir::LLVM::" << LLVMOp
191 << ">(op, resTy, adaptor.getOperands());\n";
192 Code << " return mlir::success();\n";
193 Code << " }\n";
194 } else {
195 Code
196 << " mlir::LogicalResult matchAndRewrite(cir::" << OpName
197 << " op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) "
198 "const override;\n";
199 }
200
201 if (!ExtraDecl.empty()) {
202 Code << "\nprivate:\n";
203 Code << ExtraDecl << "\n";
204 }
205
206 Code << "};\n";
207
208 LLVMLoweringPatterns.push_back(x: std::move(CodeBuffer));
209}
210
211void Generate(const Record *OpRecord) {
212 std::string OpName = GetOpCppClassName(OpRecord);
213
214 if (OpRecord->getValueAsBit(FieldName: "hasCXXABILowering")) {
215 std::string PatternName = GetOpABILoweringPatternName(OpName);
216 GenerateABILoweringPattern(OpName, PatternName);
217 CXXABILoweringPatternsList.push_back(x: std::move(PatternName));
218 }
219
220 if (OpRecord->getValueAsBit(FieldName: "hasLLVMLowering")) {
221 std::string PatternName = GetOpLLVMLoweringPatternName(OpName);
222 bool IsRecursive = OpRecord->getValueAsBit(FieldName: "isLLVMLoweringRecursive");
223 const Record *CustomCtor =
224 OpRecord->getValueAsOptionalDef(FieldName: "customLLVMLoweringConstructorDecl");
225 llvm::StringRef ExtraDecl =
226 OpRecord->getValueAsString(FieldName: "extraLLVMLoweringPatternDecl");
227
228 llvm::StringRef LLVMOp = OpRecord->getValueAsString(FieldName: "llvmOp");
229
230 if (!LLVMOp.empty() && CustomCtor)
231 PrintFatalError(ErrorLoc: OpRecord->getLoc(),
232 Msg: "op '" + OpName +
233 "' has both llvmOp and a custom lowering "
234 "constructor, which is not supported");
235
236 GenerateLLVMLoweringPattern(OpName, PatternName, IsRecursive, ExtraDecl,
237 CustomCtorRec: CustomCtor, LLVMOp);
238 // Only automatically register patterns that use the default constructor.
239 // Patterns with a custom constructor must be manually registered by the
240 // lowering pass.
241 if (!CustomCtor)
242 LLVMLoweringPatternsList.push_back(x: std::move(PatternName));
243 }
244}
245
246void GenerateCIREnumAttrs(const Record *Record) {
247 std::string OpName = GetOpCppClassName(OpRecord: Record);
248 // EnumAttr is in a separate hierarchy, so we have to set these separately, as
249 // they never have an 'illegal' CXXABI type in them.
250 CXXABILoweringAttrAlwaysLegal.push_back(x: "cir::" + OpName);
251}
252void GenerateCIRAttrs(const Record *Record) {
253 std::string OpName = GetOpCppClassName(OpRecord: Record);
254 if (!Record->getValueAsBit(FieldName: "canHaveIllegalCXXABIType"))
255 CXXABILoweringAttrAlwaysLegal.push_back(x: "cir::" + OpName);
256}
257} // namespace
258
259void clang::EmitCIRLowering(const llvm::RecordKeeper &RK,
260 llvm::raw_ostream &OS) {
261 emitSourceFileHeader(Desc: "Lowering patterns for CIR operations", OS);
262 for (const auto *OpRecord : RK.getAllDerivedDefinitions(ClassName: "CIR_Op"))
263 Generate(OpRecord);
264 for (const auto *OpRecord : RK.getAllDerivedDefinitions(ClassName: "EnumAttr"))
265 GenerateCIREnumAttrs(Record: OpRecord);
266 for (const auto *OpRecord : RK.getAllDerivedDefinitions(ClassName: "CIR_Attr"))
267 GenerateCIRAttrs(Record: OpRecord);
268
269 OS << "#ifdef GET_ABI_LOWERING_PATTERNS\n"
270 << llvm::join(R&: CXXABILoweringPatterns, Separator: "\n") << "#endif\n\n";
271 OS << "#ifdef GET_ABI_LOWERING_PATTERNS_LIST\n"
272 << llvm::join(R&: CXXABILoweringPatternsList, Separator: ",\n") << "\n#endif\n\n";
273
274 OS << "#ifdef GET_LLVM_LOWERING_PATTERNS\n"
275 << llvm::join(R&: LLVMLoweringPatterns, Separator: "\n") << "#endif\n\n";
276 OS << "#ifdef GET_LLVM_LOWERING_PATTERNS_LIST\n"
277 << llvm::join(R&: LLVMLoweringPatternsList, Separator: ",\n") << "\n#endif\n\n";
278
279 OS << "#ifdef CXX_ABI_ALWAYS_LEGAL_ATTRS\n"
280 << llvm::join(R&: CXXABILoweringAttrAlwaysLegal, Separator: ",\n") << "\n#endif\n\n";
281}
282