1//===-- CIRLoweringEmitter.cpp - Generate CIR lowering patterns -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This TableGen backend emits CIR operation lowering patterns.
10//
11//===----------------------------------------------------------------------===//
12
13#include "TableGenBackends.h"
14#include "llvm/TableGen/Error.h"
15#include "llvm/TableGen/Record.h"
16#include "llvm/TableGen/TableGenBackend.h"
17#include <string>
18#include <utility>
19#include <vector>
20
21using namespace llvm;
22using namespace clang;
23
24namespace {
25std::vector<std::string> CXXABILoweringPatterns;
26std::vector<std::string> CXXABILoweringPatternsList;
27std::vector<std::string> CXXABILoweringAttrAlwaysLegal;
28std::vector<std::string> LLVMLoweringPatterns;
29std::vector<std::string> LLVMLoweringPatternsList;
30std::string CIRAttrToValueVisitFunc;
31std::vector<std::string> CIRAttrToValueVisitorCaseTypes;
32std::vector<std::string> CIRAttrToValueVisitorDecls;
33
34struct CustomLoweringCtor {
35 struct Param {
36 std::string Type;
37 std::string Name;
38 };
39
40 std::vector<Param> Params;
41};
42
43// Adapted from mlir/lib/TableGen/Operator.cpp
44// Returns the C++ class name of the operation, which is the name of the
45// operation with the dialect prefix removed and the first underscore removed.
46// If the operation name starts with an underscore, the underscore is considered
47// part of the class name.
48std::string GetOpCppClassName(const Record *OpRecord) {
49 StringRef Name = OpRecord->getName();
50 StringRef Prefix;
51 StringRef CppClassName;
52 std::tie(args&: Prefix, args&: CppClassName) = Name.split(Separator: '_');
53 if (Prefix.empty()) {
54 // Class name with a leading underscore and without dialect prefix
55 return Name.str();
56 }
57 if (CppClassName.empty()) {
58 // Class name without dialect prefix
59 return Prefix.str();
60 }
61
62 return CppClassName.str();
63}
64
65std::string GetOpABILoweringPatternName(llvm::StringRef OpName) {
66 std::string Name = "CIR";
67 Name += OpName;
68 Name += "ABILowering";
69 return Name;
70}
71
72std::string GetOpLLVMLoweringPatternName(llvm::StringRef OpName) {
73 std::string Name = "CIRToLLVM";
74 Name += OpName;
75 Name += "Lowering";
76 return Name;
77}
78std::optional<CustomLoweringCtor> parseCustomLoweringCtor(const Record *R) {
79 if (!R)
80 return std::nullopt;
81
82 CustomLoweringCtor Ctor;
83 const DagInit *Args = R->getValueAsDag(FieldName: "dagParams");
84
85 for (const auto &[Arg, Name] : Args->getArgAndNames()) {
86 Ctor.Params.push_back(
87 x: {.Type: Arg->getAsUnquotedString(), .Name: Name->getAsUnquotedString()});
88 }
89
90 return Ctor;
91}
92
93void emitCustomParamList(raw_ostream &Code,
94 ArrayRef<CustomLoweringCtor::Param> Params) {
95 for (const CustomLoweringCtor::Param &Param : Params) {
96 Code << ", ";
97 Code << Param.Type << " " << Param.Name;
98 }
99}
100
101void emitCustomInitList(raw_ostream &Code,
102 ArrayRef<CustomLoweringCtor::Param> Params) {
103 for (const CustomLoweringCtor::Param &P : Params)
104 Code << ", " << P.Name << "(" << P.Name << ")";
105}
106
107void GenerateABILoweringPattern(llvm::StringRef OpName,
108 llvm::StringRef PatternName) {
109 std::string CodeBuffer;
110 llvm::raw_string_ostream Code(CodeBuffer);
111
112 Code << "class " << PatternName
113 << " : public mlir::OpConversionPattern<cir::" << OpName << "> {\n";
114 Code << " [[maybe_unused]] mlir::DataLayout *dataLayout;\n";
115 Code << " [[maybe_unused]] cir::LowerModule *lowerModule;\n";
116 Code << "\n";
117
118 Code << "public:\n";
119 Code << " " << PatternName
120 << "(mlir::MLIRContext *context, const mlir::TypeConverter "
121 "&typeConverter, mlir::DataLayout &dataLayout, cir::LowerModule "
122 "&lowerModule)\n";
123 Code << " : OpConversionPattern<cir::" << OpName
124 << ">(typeConverter, context), dataLayout(&dataLayout), "
125 "lowerModule(&lowerModule) {}\n";
126 Code << "\n";
127
128 Code << " mlir::LogicalResult matchAndRewrite(cir::" << OpName
129 << " op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) "
130 "const override;\n";
131
132 Code << "};\n";
133
134 CXXABILoweringPatterns.push_back(x: std::move(CodeBuffer));
135}
136
137void GenerateLLVMLoweringPattern(llvm::StringRef OpName,
138 llvm::StringRef PatternName, bool IsRecursive,
139 llvm::StringRef ExtraDecl,
140 const Record *CustomCtorRec,
141 llvm::StringRef LLVMOp, bool HasZeroResult) {
142 std::optional<CustomLoweringCtor> CustomCtor =
143 parseCustomLoweringCtor(R: CustomCtorRec);
144 std::string CodeBuffer;
145 llvm::raw_string_ostream Code(CodeBuffer);
146
147 Code << "class " << PatternName
148 << " : public mlir::OpConversionPattern<cir::" << OpName << "> {\n";
149 Code << " [[maybe_unused]] mlir::DataLayout const &dataLayout;\n";
150
151 if (CustomCtor) {
152 for (const CustomLoweringCtor::Param &P : CustomCtor->Params)
153 Code << " " << P.Type << " " << P.Name << ";\n";
154 }
155
156 Code << "\n";
157
158 Code << "public:\n";
159 Code << " using mlir::OpConversionPattern<cir::" << OpName
160 << ">::OpConversionPattern;\n";
161
162 // Constructor
163 Code << " " << PatternName
164 << "(const mlir::TypeConverter &typeConverter, "
165 "mlir::MLIRContext *context, const mlir::DataLayout &dataLayout";
166
167 if (CustomCtor)
168 emitCustomParamList(Code, Params: CustomCtor->Params);
169
170 Code << ")\n";
171
172 Code << " : OpConversionPattern<cir::" << OpName
173 << ">(typeConverter, context), dataLayout(dataLayout)";
174
175 if (CustomCtor)
176 emitCustomInitList(Code, Params: CustomCtor->Params);
177
178 Code << " {\n";
179
180 if (IsRecursive)
181 Code << " setHasBoundedRewriteRecursion();\n";
182
183 Code << " }\n\n";
184
185 if (!LLVMOp.empty()) {
186 // Generate the matchAndRewrite body automatically.
187 Code
188 << " mlir::LogicalResult matchAndRewrite(cir::" << OpName
189 << " op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) "
190 "const override {\n";
191 if (HasZeroResult) {
192 Code << " rewriter.replaceOpWithNewOp<mlir::LLVM::" << LLVMOp
193 << ">(op, mlir::TypeRange{}, adaptor.getOperands());\n";
194 } else {
195 Code << " mlir::Type resTy = "
196 "typeConverter->convertType(op.getType());\n";
197 Code << " rewriter.replaceOpWithNewOp<mlir::LLVM::" << LLVMOp
198 << ">(op, resTy, adaptor.getOperands());\n";
199 }
200 Code << " return mlir::success();\n";
201 Code << " }\n";
202 } else {
203 Code
204 << " mlir::LogicalResult matchAndRewrite(cir::" << OpName
205 << " op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) "
206 "const override;\n";
207 }
208
209 if (!ExtraDecl.empty()) {
210 Code << "\nprivate:\n";
211 Code << ExtraDecl << "\n";
212 }
213
214 Code << "};\n";
215
216 LLVMLoweringPatterns.push_back(x: std::move(CodeBuffer));
217}
218
219void Generate(const Record *OpRecord) {
220 std::string OpName = GetOpCppClassName(OpRecord);
221
222 if (OpRecord->getValueAsBit(FieldName: "hasCXXABILowering")) {
223 std::string PatternName = GetOpABILoweringPatternName(OpName);
224 GenerateABILoweringPattern(OpName, PatternName);
225 CXXABILoweringPatternsList.push_back(x: std::move(PatternName));
226 }
227
228 if (OpRecord->getValueAsBit(FieldName: "hasLLVMLowering")) {
229 std::string PatternName = GetOpLLVMLoweringPatternName(OpName);
230 bool IsRecursive = OpRecord->getValueAsBit(FieldName: "isLLVMLoweringRecursive");
231 const Record *CustomCtor =
232 OpRecord->getValueAsOptionalDef(FieldName: "customLLVMLoweringConstructorDecl");
233 llvm::StringRef ExtraDecl =
234 OpRecord->getValueAsString(FieldName: "extraLLVMLoweringPatternDecl");
235
236 llvm::StringRef LLVMOp = OpRecord->getValueAsString(FieldName: "llvmOp");
237
238 if (!LLVMOp.empty() && CustomCtor)
239 PrintFatalError(ErrorLoc: OpRecord->getLoc(),
240 Msg: "op '" + OpName +
241 "' has both llvmOp and a custom lowering "
242 "constructor, which is not supported");
243
244 const DagInit *ResultsDag = OpRecord->getValueAsDag(FieldName: "results");
245 bool IsZeroResult = ResultsDag->getNumArgs() == 0;
246 GenerateLLVMLoweringPattern(OpName, PatternName, IsRecursive, ExtraDecl,
247 CustomCtorRec: CustomCtor, LLVMOp, HasZeroResult: IsZeroResult);
248 // Only automatically register patterns that use the default constructor.
249 // Patterns with a custom constructor must be manually registered by the
250 // lowering pass.
251 if (!CustomCtor)
252 LLVMLoweringPatternsList.push_back(x: std::move(PatternName));
253 }
254}
255
256void GenerateCIREnumAttrs(const Record *Record) {
257 std::string OpName = GetOpCppClassName(OpRecord: Record);
258 // EnumAttr is in a separate hierarchy, so we have to set these separately, as
259 // they never have an 'illegal' CXXABI type in them.
260 CXXABILoweringAttrAlwaysLegal.push_back(x: "cir::" + OpName);
261}
262
263void GenerateAttrToValueVisitor(const Record *Rec) {
264 const Record *DialectRec = Rec->getValueAsDef(FieldName: "dialect");
265 llvm::StringRef Ns = DialectRec->getValueAsString(FieldName: "cppNamespace");
266 Ns.consume_front(Prefix: "::");
267 std::string CppClassRef = Ns.str();
268 CppClassRef += "::";
269 CppClassRef += Rec->getValueAsString(FieldName: "cppClassName");
270
271 std::string CodeBuffer;
272 llvm::raw_string_ostream Code(CodeBuffer);
273 Code << " mlir::Value visitCirAttr(" << CppClassRef << " attr);";
274 CIRAttrToValueVisitorDecls.push_back(x: std::move(CodeBuffer));
275 CIRAttrToValueVisitorCaseTypes.push_back(x: std::move(CppClassRef));
276}
277
278void GenerateAttrToValueVisitFunc() {
279 std::string CodeBuffer;
280 llvm::raw_string_ostream Code(CodeBuffer);
281 Code << " mlir::Value visit(mlir::Attribute attr) {\n"
282 << " return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)\n"
283 << " .Case<\n "
284 << llvm::join(R&: CIRAttrToValueVisitorCaseTypes, Separator: ",\n ")
285 << ">(\n"
286 << " [&](auto attrT) { return visitCirAttr(attrT); })\n"
287 << " .Default([this](mlir::Attribute attr) {\n"
288 << " mlir::emitError(parentOp->getLoc(), \"unsupported CIR "
289 "attribute in LLVM constant lowering\")\n"
290 << " << attr;\n"
291 << " return mlir::Value();\n"
292 << " });\n"
293 << " }\n";
294 CIRAttrToValueVisitFunc = std::move(CodeBuffer);
295}
296
297void GenerateCIRAttrs(const Record *Record) {
298 std::string OpName = GetOpCppClassName(OpRecord: Record);
299 if (!Record->getValueAsBit(FieldName: "canHaveIllegalCXXABIType"))
300 CXXABILoweringAttrAlwaysLegal.push_back(x: "cir::" + OpName);
301 if (Record->getValueAsBit(FieldName: "hasAttrToValueLowering"))
302 GenerateAttrToValueVisitor(Rec: Record);
303}
304} // namespace
305
306void clang::EmitCIRLowering(const llvm::RecordKeeper &RK,
307 llvm::raw_ostream &OS) {
308 emitSourceFileHeader(Desc: "Lowering patterns for CIR operations", OS);
309 for (const auto *OpRecord : RK.getAllDerivedDefinitions(ClassName: "CIR_Op"))
310 Generate(OpRecord);
311 for (const auto *OpRecord : RK.getAllDerivedDefinitions(ClassName: "EnumAttr"))
312 GenerateCIREnumAttrs(Record: OpRecord);
313 for (const auto *OpRecord : RK.getAllDerivedDefinitions(ClassName: "CIR_Attr"))
314 GenerateCIRAttrs(Record: OpRecord);
315 GenerateAttrToValueVisitFunc();
316
317 OS << "#ifdef GET_CIR_ATTR_TO_VALUE_VISITOR_DECLS\n"
318 << CIRAttrToValueVisitFunc << "\n"
319 << llvm::join(R&: CIRAttrToValueVisitorDecls, Separator: "\n") << "\n"
320 << "#endif // GET_CIR_ATTR_TO_VALUE_VISITOR_DECLS\n\n";
321
322 OS << "#ifdef GET_ABI_LOWERING_PATTERNS\n"
323 << llvm::join(R&: CXXABILoweringPatterns, Separator: "\n") << "#endif\n\n";
324 OS << "#ifdef GET_ABI_LOWERING_PATTERNS_LIST\n"
325 << llvm::join(R&: CXXABILoweringPatternsList, Separator: ",\n") << "\n#endif\n\n";
326
327 OS << "#ifdef GET_LLVM_LOWERING_PATTERNS\n"
328 << llvm::join(R&: LLVMLoweringPatterns, Separator: "\n") << "#endif\n\n";
329 OS << "#ifdef GET_LLVM_LOWERING_PATTERNS_LIST\n"
330 << llvm::join(R&: LLVMLoweringPatternsList, Separator: ",\n") << "\n#endif\n\n";
331
332 OS << "#ifdef CXX_ABI_ALWAYS_LEGAL_ATTRS\n"
333 << llvm::join(R&: CXXABILoweringAttrAlwaysLegal, Separator: ",\n") << "\n#endif\n\n";
334}
335