1//===-- CIRLoweringEmitter.cpp - Generate CIR lowering patterns -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This TableGen backend emits CIR operation lowering patterns.
10//
11//===----------------------------------------------------------------------===//
12
#include "TableGenBackends.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include <optional>
#include <string>
#include <utility>
#include <vector>
19
20using namespace llvm;
21using namespace clang;
22
23namespace {
/// An ordered collection of generated C++ snippets or pattern class names.
using PatternList = std::vector<std::string>;

/// Generated CXXABI lowering pattern class definitions, printed under
/// GET_ABI_LOWERING_PATTERNS.
PatternList CXXABILoweringPatterns;
/// Names of every generated CXXABI lowering pattern, printed as a
/// comma-separated list under GET_ABI_LOWERING_PATTERNS_LIST.
PatternList CXXABILoweringPatternsList;
/// Generated CIR-to-LLVM lowering pattern class definitions, printed under
/// GET_LLVM_LOWERING_PATTERNS.
PatternList LLVMLoweringPatterns;
/// Names of the CIR-to-LLVM lowering patterns that use the default
/// constructor, printed as a comma-separated list under
/// GET_LLVM_LOWERING_PATTERNS_LIST.
PatternList LLVMLoweringPatternsList;

/// Extra constructor parameters requested by an op through its
/// `customLLVMLoweringConstructorDecl` record.
struct CustomLoweringCtor {
  /// One extra parameter. `Name` is used both as the constructor parameter
  /// name and as the name of the data member it initializes.
  struct Param {
    std::string Type;
    std::string Name;
  };

  /// Parameters in the order they appear in the record's `dagParams` DAG.
  std::vector<Param> Params;
};
37
38// Adapted from mlir/lib/TableGen/Operator.cpp
39// Returns the C++ class name of the operation, which is the name of the
40// operation with the dialect prefix removed and the first underscore removed.
41// If the operation name starts with an underscore, the underscore is considered
42// part of the class name.
43std::string GetOpCppClassName(const Record *OpRecord) {
44 StringRef Name = OpRecord->getName();
45 StringRef Prefix;
46 StringRef CppClassName;
47 std::tie(args&: Prefix, args&: CppClassName) = Name.split(Separator: '_');
48 if (Prefix.empty()) {
49 // Class name with a leading underscore and without dialect prefix
50 return Name.str();
51 }
52 if (CppClassName.empty()) {
53 // Class name without dialect prefix
54 return Prefix.str();
55 }
56
57 return CppClassName.str();
58}
59
60std::string GetOpABILoweringPatternName(llvm::StringRef OpName) {
61 std::string Name = "CIR";
62 Name += OpName;
63 Name += "ABILowering";
64 return Name;
65}
66
67std::string GetOpLLVMLoweringPatternName(llvm::StringRef OpName) {
68 std::string Name = "CIRToLLVM";
69 Name += OpName;
70 Name += "Lowering";
71 return Name;
72}
73std::optional<CustomLoweringCtor> parseCustomLoweringCtor(const Record *R) {
74 if (!R)
75 return std::nullopt;
76
77 CustomLoweringCtor Ctor;
78 const DagInit *Args = R->getValueAsDag(FieldName: "dagParams");
79
80 for (const auto &[Arg, Name] : Args->getArgAndNames()) {
81 Ctor.Params.push_back(
82 x: {.Type: Arg->getAsUnquotedString(), .Name: Name->getAsUnquotedString()});
83 }
84
85 return Ctor;
86}
87
88void emitCustomParamList(raw_ostream &Code,
89 ArrayRef<CustomLoweringCtor::Param> Params) {
90 for (const CustomLoweringCtor::Param &Param : Params) {
91 Code << ", ";
92 Code << Param.Type << " " << Param.Name;
93 }
94}
95
96void emitCustomInitList(raw_ostream &Code,
97 ArrayRef<CustomLoweringCtor::Param> Params) {
98 for (const CustomLoweringCtor::Param &P : Params)
99 Code << ", " << P.Name << "(" << P.Name << ")";
100}
101
102void GenerateABILoweringPattern(llvm::StringRef OpName,
103 llvm::StringRef PatternName) {
104 std::string CodeBuffer;
105 llvm::raw_string_ostream Code(CodeBuffer);
106
107 Code << "class " << PatternName
108 << " : public mlir::OpConversionPattern<cir::" << OpName << "> {\n";
109 Code << " [[maybe_unused]] mlir::DataLayout *dataLayout;\n";
110 Code << " [[maybe_unused]] cir::LowerModule *lowerModule;\n";
111 Code << "\n";
112
113 Code << "public:\n";
114 Code << " " << PatternName
115 << "(mlir::MLIRContext *context, const mlir::TypeConverter "
116 "&typeConverter, mlir::DataLayout &dataLayout, cir::LowerModule "
117 "&lowerModule)\n";
118 Code << " : OpConversionPattern<cir::" << OpName
119 << ">(typeConverter, context), dataLayout(&dataLayout), "
120 "lowerModule(&lowerModule) {}\n";
121 Code << "\n";
122
123 Code << " mlir::LogicalResult matchAndRewrite(cir::" << OpName
124 << " op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) "
125 "const override;\n";
126
127 Code << "};\n";
128
129 CXXABILoweringPatterns.push_back(x: std::move(CodeBuffer));
130}
131
132void GenerateLLVMLoweringPattern(llvm::StringRef OpName,
133 llvm::StringRef PatternName, bool IsRecursive,
134 llvm::StringRef ExtraDecl,
135 const Record *CustomCtorRec) {
136 std::optional<CustomLoweringCtor> CustomCtor =
137 parseCustomLoweringCtor(R: CustomCtorRec);
138 std::string CodeBuffer;
139 llvm::raw_string_ostream Code(CodeBuffer);
140
141 Code << "class " << PatternName
142 << " : public mlir::OpConversionPattern<cir::" << OpName << "> {\n";
143 Code << " [[maybe_unused]] mlir::DataLayout const &dataLayout;\n";
144
145 if (CustomCtor) {
146 for (const CustomLoweringCtor::Param &P : CustomCtor->Params)
147 Code << " " << P.Type << " " << P.Name << ";\n";
148 }
149
150 Code << "\n";
151
152 Code << "public:\n";
153 Code << " using mlir::OpConversionPattern<cir::" << OpName
154 << ">::OpConversionPattern;\n";
155
156 // Constructor
157 Code << " " << PatternName
158 << "(const mlir::TypeConverter &typeConverter, "
159 "mlir::MLIRContext *context, const mlir::DataLayout &dataLayout";
160
161 if (CustomCtor)
162 emitCustomParamList(Code, Params: CustomCtor->Params);
163
164 Code << ")\n";
165
166 Code << " : OpConversionPattern<cir::" << OpName
167 << ">(typeConverter, context), dataLayout(dataLayout)";
168
169 if (CustomCtor)
170 emitCustomInitList(Code, Params: CustomCtor->Params);
171
172 Code << " {\n";
173
174 if (IsRecursive)
175 Code << " setHasBoundedRewriteRecursion();\n";
176
177 Code << " }\n\n";
178
179 Code << " mlir::LogicalResult matchAndRewrite(cir::" << OpName
180 << " op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) "
181 "const override;\n";
182
183 if (!ExtraDecl.empty()) {
184 Code << "\nprivate:\n";
185 Code << ExtraDecl << "\n";
186 }
187
188 Code << "};\n";
189
190 LLVMLoweringPatterns.push_back(x: std::move(CodeBuffer));
191}
192
193void Generate(const Record *OpRecord) {
194 std::string OpName = GetOpCppClassName(OpRecord);
195
196 if (OpRecord->getValueAsBit(FieldName: "hasCXXABILowering")) {
197 std::string PatternName = GetOpABILoweringPatternName(OpName);
198 GenerateABILoweringPattern(OpName, PatternName);
199 CXXABILoweringPatternsList.push_back(x: std::move(PatternName));
200 }
201
202 if (OpRecord->getValueAsBit(FieldName: "hasLLVMLowering")) {
203 std::string PatternName = GetOpLLVMLoweringPatternName(OpName);
204 bool IsRecursive = OpRecord->getValueAsBit(FieldName: "isLLVMLoweringRecursive");
205 const Record *CustomCtor =
206 OpRecord->getValueAsOptionalDef(FieldName: "customLLVMLoweringConstructorDecl");
207 llvm::StringRef ExtraDecl =
208 OpRecord->getValueAsString(FieldName: "extraLLVMLoweringPatternDecl");
209
210 GenerateLLVMLoweringPattern(OpName, PatternName, IsRecursive, ExtraDecl,
211 CustomCtorRec: CustomCtor);
212 // Only automatically register patterns that use the default constructor.
213 // Patterns with a custom constructor must be manually registered by the
214 // lowering pass.
215 if (!CustomCtor)
216 LLVMLoweringPatternsList.push_back(x: std::move(PatternName));
217 }
218}
219} // namespace
220
221void clang::EmitCIRLowering(const llvm::RecordKeeper &RK,
222 llvm::raw_ostream &OS) {
223 emitSourceFileHeader(Desc: "Lowering patterns for CIR operations", OS);
224 for (const auto *OpRecord : RK.getAllDerivedDefinitions(ClassName: "CIR_Op"))
225 Generate(OpRecord);
226
227 OS << "#ifdef GET_ABI_LOWERING_PATTERNS\n"
228 << llvm::join(R&: CXXABILoweringPatterns, Separator: "\n") << "#endif\n\n";
229 OS << "#ifdef GET_ABI_LOWERING_PATTERNS_LIST\n"
230 << llvm::join(R&: CXXABILoweringPatternsList, Separator: ",\n") << "\n#endif\n\n";
231
232 OS << "#ifdef GET_LLVM_LOWERING_PATTERNS\n"
233 << llvm::join(R&: LLVMLoweringPatterns, Separator: "\n") << "#endif\n\n";
234 OS << "#ifdef GET_LLVM_LOWERING_PATTERNS_LIST\n"
235 << llvm::join(R&: LLVMLoweringPatternsList, Separator: ",\n") << "\n#endif\n\n";
236}
237