-
Notifications
You must be signed in to change notification settings - Fork 497
/
func.h
231 lines (189 loc) · 7.73 KB
/
func.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
// Copyright (C) 2022-2024 Exaloop Inc. <https://exaloop.io>
#pragma once
#include "codon/cir/flow.h"
#include "codon/cir/util/iterators.h"
#include "codon/cir/var.h"
namespace codon {
namespace ir {
/// CIR function. Extends Var (via AcceptorExtend) with an argument list, the
/// unmangled source-level name, a generator flag and an optional parent type
/// for methods.
class Func : public AcceptorExtend<Func, Var> {
private:
  /// unmangled (source code) name of the function
  std::string unmangledName;
  /// whether the function is a generator
  bool generator = false;
  /// parent type if the function is a method, or null if not
  types::Type *parentType = nullptr;

protected:
  /// list of arguments
  std::list<Var *> args;

  // Node hooks: report/replace the variables and types this function uses.
  std::vector<Var *> doGetUsedVariables() const override;
  int doReplaceUsedVariable(id_t id, Var *newVar) override;
  std::vector<types::Type *> doGetUsedTypes() const override;
  int doReplaceUsedType(const std::string &name, types::Type *newType) override;

public:
  static const char NodeId;

  /// Constructs an unrealized CIR function; use realize() to give it a type
  /// and argument names.
  /// @param name the function's name
  explicit Func(std::string name = "")
      : AcceptorExtend(nullptr, true, false, std::move(name)) {}

  /// Re-initializes the function with a new type and names.
  /// @param newType the function's new type
  /// @param names the function's new argument names
  void realize(types::Type *newType, const std::vector<std::string> &names);

  /// @return iterator to the first arg
  auto arg_begin() { return args.begin(); }
  /// @return iterator beyond the last arg
  auto arg_end() { return args.end(); }
  /// @return iterator to the first arg
  auto arg_begin() const { return args.begin(); }
  /// @return iterator beyond the last arg
  auto arg_end() const { return args.end(); }

  /// @return a pointer to the first arg; the function must have at least one
  ///         argument (std::list::front on an empty list is undefined)
  Var *arg_front() { return args.front(); }
  /// @return a pointer to the last arg; the function must have at least one
  ///         argument (std::list::back on an empty list is undefined)
  Var *arg_back() { return args.back(); }
  /// @return a pointer to the last arg; the function must have at least one
  ///         argument (std::list::back on an empty list is undefined)
  const Var *arg_back() const { return args.back(); }
  /// @return a pointer to the first arg; the function must have at least one
  ///         argument (std::list::front on an empty list is undefined)
  const Var *arg_front() const { return args.front(); }

  /// @return the function's unmangled (source code) name
  std::string getUnmangledName() const { return unmangledName; }
  /// Sets the unmangled name.
  /// @param v the new value
  void setUnmangledName(std::string v) { unmangledName = std::move(v); }

  /// @return true if the function is a generator
  bool isGenerator() const { return generator; }
  /// Sets the function's generator flag.
  /// @param v the new value
  void setGenerator(bool v = true) { generator = v; }

  /// @return the variable corresponding to the given argument name
  ///         (NOTE(review): behavior for an unknown name is defined in the
  ///         implementation file — confirm before relying on it)
  /// @param n the argument name
  Var *getArgVar(const std::string &n);

  /// @return the parent type, or null if the function is not a method
  types::Type *getParentType() const { return parentType; }
  /// Sets the parent type.
  /// @param p the new parent
  void setParentType(types::Type *p) { parentType = p; }
};
/// CIR function with a CIR body and a list of symbols (variables defined and
/// used within the function).
class BodiedFunc : public AcceptorExtend<BodiedFunc, Func> {
private:
  /// list of variables defined and used within the function
  std::list<Var *> symbols;
  /// the function body
  Value *body = nullptr;
  /// whether the function is a JIT input
  bool jit = false;

public:
  static const char NodeId;

  using AcceptorExtend::AcceptorExtend;

  /// @return iterator to the first symbol
  auto begin() { return symbols.begin(); }
  /// @return iterator beyond the last symbol
  auto end() { return symbols.end(); }
  /// @return iterator to the first symbol
  auto begin() const { return symbols.begin(); }
  /// @return iterator beyond the last symbol
  auto end() const { return symbols.end(); }

  /// @return a pointer to the first symbol; there must be at least one symbol
  Var *front() { return symbols.front(); }
  /// @return a pointer to the last symbol; there must be at least one symbol
  Var *back() { return symbols.back(); }
  /// @return a pointer to the first symbol; there must be at least one symbol
  const Var *front() const { return symbols.front(); }
  /// @return a pointer to the last symbol; there must be at least one symbol
  const Var *back() const { return symbols.back(); }

  /// Inserts a symbol at the given position.
  /// @param pos the position
  /// @param v the symbol
  /// @return an iterator to the newly added symbol
  template <typename It> auto insert(It pos, Var *v) { return symbols.insert(pos, v); }
  /// Appends a symbol.
  /// @param v the new symbol
  void push_back(Var *v) { symbols.push_back(v); }
  /// Erases the symbol at the given position.
  /// @param pos the position
  /// @return an iterator following the removed symbol
  template <typename It> auto erase(It pos) { return symbols.erase(pos); }

  /// @return the function body (the stored body starts out null;
  ///         NOTE(review): relies on cast<> returning null for a null input)
  Flow *getBody() { return cast<Flow>(body); }
  /// @return the function body (the stored body starts out null;
  ///         NOTE(review): relies on cast<> returning null for a null input)
  const Flow *getBody() const { return cast<Flow>(body); }
  /// Sets the function's body.
  /// @param b the new body
  void setBody(Flow *b) { body = b; }

  /// @return true if the function is a JIT input
  bool isJIT() const { return jit; }
  /// Changes the function's JIT input status.
  /// @param v true if JIT input, false otherwise
  void setJIT(bool v = true) { jit = v; }

protected:
  // The body is the only directly-used value; report nothing when unset.
  std::vector<Value *> doGetUsedValues() const override {
    return body ? std::vector<Value *>{body} : std::vector<Value *>{};
  }
  int doReplaceUsedValue(id_t id, Value *newValue) override;
  std::vector<Var *> doGetUsedVariables() const override;
  int doReplaceUsedVariable(id_t id, Var *newVar) override;
};
/// CIR function without a CIR body (declared externally).
class ExternalFunc : public AcceptorExtend<ExternalFunc, Func> {
public:
  static const char NodeId;

  using AcceptorExtend::AcceptorExtend;

  /// @return true if the function is variadic; false otherwise, including
  ///         when the function's type is not (yet) a function type
  bool isVariadic() const {
    // Func's constructor passes a null first argument (presumably the type),
    // so an unrealized function has no FuncType; cast<> then yields null,
    // which must not be dereferenced.
    auto *funcType = cast<types::FuncType>(getType());
    return funcType && funcType->isVariadic();
  }
};
/// Internal function that exists only at the LLVM level.
class InternalFunc : public AcceptorExtend<InternalFunc, Func> {
public:
  // Reuse every constructor of the AcceptorExtend base.
  using AcceptorExtend::AcceptorExtend;

  static const char NodeId;
};
/// Function whose body is written in LLVM IR inside Codon (formerly Seq)
/// source.
class LLVMFunc : public AcceptorExtend<LLVMFunc, Func> {
private:
  using LiteralList = std::vector<types::Generic>;

  /// literals to be formatted into the body
  LiteralList llvmLiterals;
  /// declarations needed by the LLVM-only function
  std::string llvmDeclares;
  /// text of the LLVM-only function body
  std::string llvmBody;

public:
  static const char NodeId;
  using AcceptorExtend::AcceptorExtend;

  /// Replaces the stored LLVM literals.
  /// @param v the new values.
  void setLLVMLiterals(std::vector<types::Generic> v) { llvmLiterals.swap(v); }

  /// @return mutable iterator to the first literal
  LiteralList::iterator literal_begin() { return llvmLiterals.begin(); }
  /// @return mutable iterator one past the last literal
  LiteralList::iterator literal_end() { return llvmLiterals.end(); }
  /// @return read-only iterator to the first literal
  LiteralList::const_iterator literal_begin() const { return llvmLiterals.begin(); }
  /// @return read-only iterator one past the last literal
  LiteralList::const_iterator literal_end() const { return llvmLiterals.end(); }

  /// @return mutable reference to the first literal (list must be non-empty)
  types::Generic &literal_front() { return llvmLiterals.front(); }
  /// @return mutable reference to the last literal (list must be non-empty)
  types::Generic &literal_back() { return llvmLiterals.back(); }
  /// @return read-only reference to the first literal (list must be non-empty)
  const types::Generic &literal_front() const { return llvmLiterals.front(); }
  /// @return read-only reference to the last literal (list must be non-empty)
  const types::Generic &literal_back() const { return llvmLiterals.back(); }

  /// @return the stored LLVM declarations
  const std::string &getLLVMDeclarations() const { return llvmDeclares; }
  /// Replaces the LLVM declarations.
  /// @param v the new value
  void setLLVMDeclarations(std::string v) { llvmDeclares.swap(v); }

  /// @return the stored LLVM body text
  const std::string &getLLVMBody() const { return llvmBody; }
  /// Replaces the LLVM body text.
  /// @param v the new value
  void setLLVMBody(std::string v) { llvmBody.swap(v); }

protected:
  // Node hooks: report/replace the types this function uses.
  std::vector<types::Type *> doGetUsedTypes() const override;
  int doReplaceUsedType(const std::string &name, types::Type *newType) override;
};
} // namespace ir
} // namespace codon
template <> struct fmt::formatter<codon::ir::Func> : fmt::ostream_formatter {};