1//===-- CIRLoweringEmitter.cpp - Generate CIR lowering patterns -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This TableGen backend emits CIR operation lowering patterns.
10//
11//===----------------------------------------------------------------------===//
12
#include "TableGenBackends.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include <optional>
#include <string>
#include <utility>
#include <vector>
19
20using namespace llvm;
21using namespace clang;
22
23namespace {
// Generated C++ declarations of the CXXABI lowering patterns, one per op.
std::vector<std::string> CXXABILoweringPatterns{};
// Class names of the CXXABI lowering patterns, emitted as a ",\n"-joined list.
std::vector<std::string> CXXABILoweringPatternsList{};
// Generated C++ declarations of the CIR-to-LLVM lowering patterns, one per op.
std::vector<std::string> LLVMLoweringPatterns{};
// Class names of the LLVM lowering patterns that use the default constructor;
// patterns with a custom constructor are deliberately left out of this list.
std::vector<std::string> LLVMLoweringPatternsList{};
28
// Extra constructor parameters requested through an op's
// `customLLVMLoweringConstructorDecl` record. Each parameter is emitted both
// as a data member and as a trailing constructor parameter of the generated
// LLVM lowering pattern.
struct CustomLoweringCtor {
  // One extra parameter; emitted verbatim as `<Type> <Name>`.
  struct Param {
    // C++ type spelling of the parameter.
    std::string Type;
    // Parameter name, also used as the data member name.
    std::string Name;
  };

  // Extra parameters, in declaration order.
  std::vector<Param> Params;
};
37
38// Adapted from mlir/lib/TableGen/Operator.cpp
39// Returns the C++ class name of the operation, which is the name of the
40// operation with the dialect prefix removed and the first underscore removed.
41// If the operation name starts with an underscore, the underscore is considered
42// part of the class name.
43std::string GetOpCppClassName(const Record *OpRecord) {
44 StringRef Name = OpRecord->getName();
45 StringRef Prefix;
46 StringRef CppClassName;
47 std::tie(args&: Prefix, args&: CppClassName) = Name.split(Separator: '_');
48 if (Prefix.empty()) {
49 // Class name with a leading underscore and without dialect prefix
50 return Name.str();
51 }
52 if (CppClassName.empty()) {
53 // Class name without dialect prefix
54 return Prefix.str();
55 }
56
57 return CppClassName.str();
58}
59
60std::string GetOpABILoweringPatternName(llvm::StringRef OpName) {
61 std::string Name = "CIR";
62 Name += OpName;
63 Name += "ABILowering";
64 return Name;
65}
66
67std::string GetOpLLVMLoweringPatternName(llvm::StringRef OpName) {
68 std::string Name = "CIRToLLVM";
69 Name += OpName;
70 Name += "Lowering";
71 return Name;
72}
73std::optional<CustomLoweringCtor> parseCustomLoweringCtor(const Record *R) {
74 if (!R)
75 return std::nullopt;
76
77 CustomLoweringCtor Ctor;
78 const DagInit *Args = R->getValueAsDag(FieldName: "dagParams");
79
80 for (const auto &[Arg, Name] : Args->getArgAndNames()) {
81 Ctor.Params.push_back(
82 x: {.Type: Arg->getAsUnquotedString(), .Name: Name->getAsUnquotedString()});
83 }
84
85 return Ctor;
86}
87
88void emitCustomParamList(raw_ostream &Code,
89 ArrayRef<CustomLoweringCtor::Param> Params) {
90 for (const CustomLoweringCtor::Param &Param : Params) {
91 Code << ", ";
92 Code << Param.Type << " " << Param.Name;
93 }
94}
95
96void emitCustomInitList(raw_ostream &Code,
97 ArrayRef<CustomLoweringCtor::Param> Params) {
98 for (const CustomLoweringCtor::Param &P : Params)
99 Code << ", " << P.Name << "(" << P.Name << ")";
100}
101
102void GenerateABILoweringPattern(llvm::StringRef OpName,
103 llvm::StringRef PatternName) {
104 std::string CodeBuffer;
105 llvm::raw_string_ostream Code(CodeBuffer);
106
107 Code << "class " << PatternName
108 << " : public mlir::OpConversionPattern<cir::" << OpName << "> {\n";
109 Code << " [[maybe_unused]] mlir::DataLayout *dataLayout;\n";
110 Code << " [[maybe_unused]] cir::LowerModule *lowerModule;\n";
111 Code << "\n";
112
113 Code << "public:\n";
114 Code << " " << PatternName
115 << "(mlir::MLIRContext *context, const mlir::TypeConverter "
116 "&typeConverter, mlir::DataLayout &dataLayout, cir::LowerModule "
117 "&lowerModule)\n";
118 Code << " : OpConversionPattern<cir::" << OpName
119 << ">(typeConverter, context), dataLayout(&dataLayout), "
120 "lowerModule(&lowerModule) {}\n";
121 Code << "\n";
122
123 Code << " mlir::LogicalResult matchAndRewrite(cir::" << OpName
124 << " op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) "
125 "const override;\n";
126
127 Code << "};\n";
128
129 CXXABILoweringPatterns.push_back(x: std::move(CodeBuffer));
130}
131
132void GenerateLLVMLoweringPattern(llvm::StringRef OpName,
133 llvm::StringRef PatternName, bool IsRecursive,
134 llvm::StringRef ExtraDecl,
135 const Record *CustomCtorRec) {
136 std::optional<CustomLoweringCtor> CustomCtor =
137 parseCustomLoweringCtor(R: CustomCtorRec);
138 std::string CodeBuffer;
139 llvm::raw_string_ostream Code(CodeBuffer);
140
141 Code << "class " << PatternName
142 << " : public mlir::OpConversionPattern<cir::" << OpName << "> {\n";
143 Code << " [[maybe_unused]] cir::LowerModule *lowerMod;\n";
144 Code << " [[maybe_unused]] mlir::DataLayout const &dataLayout;\n";
145
146 if (CustomCtor) {
147 for (const CustomLoweringCtor::Param &P : CustomCtor->Params)
148 Code << " " << P.Type << " " << P.Name << ";\n";
149 }
150
151 Code << "\n";
152
153 Code << "public:\n";
154 Code << " using mlir::OpConversionPattern<cir::" << OpName
155 << ">::OpConversionPattern;\n";
156
157 // Constructor
158 Code << " " << PatternName
159 << "(mlir::TypeConverter const "
160 "&typeConverter, mlir::MLIRContext *context, "
161 "cir::LowerModule *lowerMod, mlir::DataLayout const "
162 "&dataLayout";
163
164 if (CustomCtor)
165 emitCustomParamList(Code, Params: CustomCtor->Params);
166
167 Code << ")\n";
168
169 Code << " : OpConversionPattern<cir::" << OpName
170 << ">(typeConverter, context), lowerMod(lowerMod), "
171 "dataLayout(dataLayout)";
172
173 if (CustomCtor)
174 emitCustomInitList(Code, Params: CustomCtor->Params);
175
176 Code << " {\n";
177
178 if (IsRecursive)
179 Code << " setHasBoundedRewriteRecursion();\n";
180
181 Code << " }\n\n";
182
183 Code << " mlir::LogicalResult matchAndRewrite(cir::" << OpName
184 << " op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) "
185 "const override;\n";
186
187 if (!ExtraDecl.empty()) {
188 Code << "\nprivate:\n";
189 Code << ExtraDecl << "\n";
190 }
191
192 Code << "};\n";
193
194 LLVMLoweringPatterns.push_back(x: std::move(CodeBuffer));
195}
196
197void Generate(const Record *OpRecord) {
198 std::string OpName = GetOpCppClassName(OpRecord);
199
200 if (OpRecord->getValueAsBit(FieldName: "hasCXXABILowering")) {
201 std::string PatternName = GetOpABILoweringPatternName(OpName);
202 GenerateABILoweringPattern(OpName, PatternName);
203 CXXABILoweringPatternsList.push_back(x: std::move(PatternName));
204 }
205
206 if (OpRecord->getValueAsBit(FieldName: "hasLLVMLowering")) {
207 std::string PatternName = GetOpLLVMLoweringPatternName(OpName);
208 bool IsRecursive = OpRecord->getValueAsBit(FieldName: "isLLVMLoweringRecursive");
209 const Record *CustomCtor =
210 OpRecord->getValueAsOptionalDef(FieldName: "customLLVMLoweringConstructorDecl");
211 llvm::StringRef ExtraDecl =
212 OpRecord->getValueAsString(FieldName: "extraLLVMLoweringPatternDecl");
213
214 GenerateLLVMLoweringPattern(OpName, PatternName, IsRecursive, ExtraDecl,
215 CustomCtorRec: CustomCtor);
216 // Only automatically register patterns that use the default constructor.
217 // Patterns with a custom constructor must be manually registered by the
218 // lowering pass.
219 if (!CustomCtor)
220 LLVMLoweringPatternsList.push_back(x: std::move(PatternName));
221 }
222}
223} // namespace
224
225void clang::EmitCIRLowering(const llvm::RecordKeeper &RK,
226 llvm::raw_ostream &OS) {
227 emitSourceFileHeader(Desc: "Lowering patterns for CIR operations", OS);
228 for (const auto *OpRecord : RK.getAllDerivedDefinitions(ClassName: "CIR_Op"))
229 Generate(OpRecord);
230
231 OS << "#ifdef GET_ABI_LOWERING_PATTERNS\n"
232 << llvm::join(R&: CXXABILoweringPatterns, Separator: "\n") << "#endif\n\n";
233 OS << "#ifdef GET_ABI_LOWERING_PATTERNS_LIST\n"
234 << llvm::join(R&: CXXABILoweringPatternsList, Separator: ",\n") << "\n#endif\n\n";
235
236 OS << "#ifdef GET_LLVM_LOWERING_PATTERNS\n"
237 << llvm::join(R&: LLVMLoweringPatterns, Separator: "\n") << "#endif\n\n";
238 OS << "#ifdef GET_LLVM_LOWERING_PATTERNS_LIST\n"
239 << llvm::join(R&: LLVMLoweringPatternsList, Separator: ",\n") << "\n#endif\n\n";
240}
241