IR2Vec
Loading...
Searching...
No Matches
FlowAware.h
1//===- FlowAware.h - Flow-aware embeddings of IR2Vec ------------*- C++ -*-===//
2//
3// Part of the IR2Vec Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef __IR2Vec_FA_H__
10#define __IR2Vec_FA_H__
11
12#include "utils.h"
13
14#include "llvm/ADT/MapVector.h"
15#include "llvm/ADT/SmallSet.h"
16#include "llvm/Analysis/CallGraph.h"
17#include "llvm/Analysis/LoopInfo.h"
18#include "llvm/IR/BasicBlock.h"
19#include "llvm/IR/Dominators.h"
20#include "llvm/IR/Function.h"
21#include "llvm/Pass.h"
22#include "llvm/Support/raw_ostream.h"
23#include <fstream>
24#include <unordered_map>
25
26class IR2Vec_FA {
27
28private:
29 llvm::Module &M;
30 std::string res;
31 IR2Vec::VocabTy &vocabulary;
32 IR2Vec::Vector pgmVector;
33 unsigned dataMissCounter;
34 unsigned cyclicCounter;
35
36 llvm::SmallDenseMap<llvm::StringRef, unsigned> memWriteOps;
37 llvm::SmallDenseMap<const llvm::Instruction *, bool> livelinessMap;
38 llvm::SmallDenseMap<llvm::StringRef, unsigned> memAccessOps;
39
40 llvm::SmallMapVector<const llvm::Instruction *, IR2Vec::Vector, 128>
41 instVecMap;
42 llvm::SmallMapVector<const llvm::BasicBlock *, IR2Vec::Vector, 16> bbVecMap;
43 llvm::SmallMapVector<const llvm::Function *, IR2Vec::Vector, 16> funcVecMap;
44
45 llvm::SmallMapVector<const llvm::Function *,
46 llvm::SmallVector<const llvm::Function *, 10>, 16>
47 funcCallMap;
48
49 llvm::SmallMapVector<const llvm::Instruction *,
50 llvm::SmallVector<const llvm::Instruction *, 10>, 16>
51 writeDefsMap;
52
53 llvm::SmallMapVector<const llvm::Instruction *,
54 llvm::SmallVector<const llvm::Instruction *, 10>, 16>
55 instReachingDefsMap;
56
57 // Reverse instReachingDefsMap
58 llvm::SmallMapVector<const llvm::Instruction *,
59 llvm::SmallVector<const llvm::Instruction *, 10>, 16>
60 reverseReachingDefsMap;
61
62 llvm::SmallVector<const llvm::Instruction *, 20> instSolvedBySolver;
63
64 llvm::SmallVector<llvm::SmallVector<const llvm::Instruction *, 10>, 10>
65 allSCCs;
66
67 llvm::SmallMapVector<const llvm::Instruction *,
68 llvm::SmallVector<llvm::Instruction *, 16>, 16>
69 killMap;
70
71 std::map<int, std::vector<int>> SCCAdjList;
72
73 void getAllSCC();
74
75 IR2Vec::Vector getValue(std::string key);
76 void collectWriteDefsMap(llvm::Module &M);
77 void getTransitiveUse(
78 const llvm::Instruction *root, const llvm::Instruction *def,
79 llvm::SmallVector<const llvm::Instruction *, 100> &visitedList,
80 llvm::SmallVector<const llvm::Instruction *, 10> toAppend = {});
81 llvm::SmallVector<const llvm::Instruction *, 10>
82 getReachingDefs(const llvm::Instruction *, unsigned i);
83
84 void solveSingleComponent(
85 const llvm::Instruction &I,
86 llvm::SmallMapVector<const llvm::Instruction *, IR2Vec::Vector, 16>
87 &instValMap);
88 void getPartialVec(const llvm::Instruction &I,
89 llvm::SmallMapVector<const llvm::Instruction *,
90 IR2Vec::Vector, 16> &instValMap);
91
92 void solveInsts(llvm::SmallMapVector<const llvm::Instruction *,
93 IR2Vec::Vector, 16> &instValMap);
94 std::vector<int> topoOrder(int size);
95
96 void topoDFS(int vertex, std::vector<bool> &Visited,
97 std::vector<int> &visitStack);
98
99 void inst2Vec(const llvm::Instruction &I,
100 llvm::SmallVector<llvm::Function *, 15> &funcStack,
101 llvm::SmallMapVector<const llvm::Instruction *, IR2Vec::Vector,
102 16> &instValMap);
103 void traverseRD(const llvm::Instruction *inst,
104 std::unordered_map<const llvm::Instruction *, bool> &Visited,
105 llvm::SmallVector<const llvm::Instruction *, 10> &timeStack);
106
107 void DFSUtil(const llvm::Instruction *inst,
108 std::unordered_map<const llvm::Instruction *, bool> &Visited,
109 llvm::SmallVector<const llvm::Instruction *, 10> &set);
110
111 void bb2Vec(llvm::BasicBlock &B,
112 llvm::SmallVector<llvm::Function *, 15> &funcStack);
113 IR2Vec::Vector func2Vec(llvm::Function &F,
114 llvm::SmallVector<llvm::Function *, 15> &funcStack);
115
116 bool isMemOp(llvm::StringRef opcode, unsigned &operand,
117 llvm::SmallDenseMap<llvm::StringRef, unsigned> map);
118 std::string splitAndPipeFunctionName(std::string s);
119
120 void TransitiveReads(llvm::SmallVector<llvm::Instruction *, 16> &Killlist,
121 llvm::Instruction *Inst, llvm::BasicBlock *ParentBB);
122 llvm::SmallVector<llvm::Instruction *, 16>
123 createKilllist(llvm::Instruction *Arg, llvm::Instruction *writeInst);
124
125 // For Debugging
126 void print(IR2Vec::Vector t, unsigned pos) { llvm::outs() << t[pos]; }
127
128 void updateFuncVecMap(
129 llvm::Function *function,
130 llvm::SmallSet<const llvm::Function *, 16> &visitedFunctions);
131
132 void updateFuncVecMapWithCallee(const llvm::Function *function);
133
134public:
135 IR2Vec_FA(llvm::Module &M, IR2Vec::VocabTy &vocab) : M{M}, vocabulary{vocab} {
136
137 pgmVector = IR2Vec::Vector(IR2Vec::DIM, 0);
138 res = "";
139
140 memWriteOps.try_emplace("store", 1);
141 memWriteOps.try_emplace("cmpxchg", 0);
142 memWriteOps.try_emplace("atomicrmw", 0);
143
144 memAccessOps.try_emplace("getelementptr", 0);
145 memAccessOps.try_emplace("load", 0);
146
147 dataMissCounter = 0;
148 cyclicCounter = 0;
149
150 collectWriteDefsMap(M);
151
152 llvm::CallGraph cg = llvm::CallGraph(M);
153
154 for (auto callItr = cg.begin(); callItr != cg.end(); callItr++) {
155 if (callItr->first && !callItr->first->isDeclaration()) {
156 auto ParentFunc = callItr->first;
157 llvm::CallGraphNode *cgn = callItr->second.get();
158 if (cgn) {
159
160 for (auto It = cgn->begin(); It != cgn->end(); It++) {
161
162 auto func = It->second->getFunction();
163 if (func && !func->isDeclaration()) {
164 funcCallMap[ParentFunc].push_back(func);
165 }
166 }
167 }
168 }
169 }
170 }
171
172 void generateFlowAwareEncodings(std::ostream *o = nullptr,
173 std::ostream *missCount = nullptr,
174 std::ostream *cyclicCount = nullptr);
175
176 // newly added
177
178 void generateFlowAwareEncodingsForFunction(
179 std::ostream *o = nullptr, std::string name = "",
180 std::ostream *missCount = nullptr, std::ostream *cyclicCount = nullptr);
181
182 llvm::SmallMapVector<const llvm::Instruction *, IR2Vec::Vector, 128>
183 getInstVecMap() {
184 return instVecMap;
185 }
186
187 llvm::SmallMapVector<const llvm::BasicBlock *, IR2Vec::Vector, 16>
188 getBBVecMap() {
189 return bbVecMap;
190 }
191
192 llvm::SmallMapVector<const llvm::Function *, IR2Vec::Vector, 16>
193 getFuncVecMap() {
194 return funcVecMap;
195 }
196
197 IR2Vec::Vector getProgramVector() { return pgmVector; }
198};
199
200#endif
Definition FlowAware.h:26