IR2Vec
Loading...
Searching...
No Matches
Corrupt.h
1#ifndef CORRUPT_H
2#define CORRUPT_H
3#include "Random.h"
4#include "Reader.h"
5#include "Triple.h"
6
7INT corrupt_head(INT id, INT h, INT r, bool filter_flag = true) {
8 INT lef, rig, mid, ll, rr;
9 if (not filter_flag) {
10 INT tmp = rand_max(id, entityTotal - 1);
11 if (tmp < h)
12 return tmp;
13 else
14 return tmp + 1;
15 }
16 lef = lefHead[h] - 1;
17 rig = rigHead[h];
18 while (lef + 1 < rig) {
19 mid = (lef + rig) >> 1;
20 if (trainHead[mid].r >= r)
21 rig = mid;
22 else
23 lef = mid;
24 }
25 ll = rig;
26 lef = lefHead[h];
27 rig = rigHead[h] + 1;
28 while (lef + 1 < rig) {
29 mid = (lef + rig) >> 1;
30 if (trainHead[mid].r <= r)
31 lef = mid;
32 else
33 rig = mid;
34 }
35 rr = lef;
36 INT tmp = rand_max(id, entityTotal - (rr - ll + 1));
37 if (tmp < trainHead[ll].t)
38 return tmp;
39 if (tmp > trainHead[rr].t - rr + ll - 1)
40 return tmp + rr - ll + 1;
41 lef = ll, rig = rr + 1;
42 while (lef + 1 < rig) {
43 mid = (lef + rig) >> 1;
44 if (trainHead[mid].t - mid + ll - 1 < tmp)
45 lef = mid;
46 else
47 rig = mid;
48 }
49 return tmp + lef - ll + 1;
50}
51
52INT corrupt_tail(INT id, INT t, INT r, bool filter_flag = true) {
53 INT lef, rig, mid, ll, rr;
54 if (not filter_flag) {
55 INT tmp = rand_max(id, entityTotal - 1);
56 if (tmp < t)
57 return tmp;
58 else
59 return tmp + 1;
60 }
61 lef = lefTail[t] - 1;
62 rig = rigTail[t];
63 while (lef + 1 < rig) {
64 mid = (lef + rig) >> 1;
65 if (trainTail[mid].r >= r)
66 rig = mid;
67 else
68 lef = mid;
69 }
70 ll = rig;
71 lef = lefTail[t];
72 rig = rigTail[t] + 1;
73 while (lef + 1 < rig) {
74 mid = (lef + rig) >> 1;
75 if (trainTail[mid].r <= r)
76 lef = mid;
77 else
78 rig = mid;
79 }
80 rr = lef;
81 INT tmp = rand_max(id, entityTotal - (rr - ll + 1));
82 if (tmp < trainTail[ll].h)
83 return tmp;
84 if (tmp > trainTail[rr].h - rr + ll - 1)
85 return tmp + rr - ll + 1;
86 lef = ll, rig = rr + 1;
87 while (lef + 1 < rig) {
88 mid = (lef + rig) >> 1;
89 if (trainTail[mid].h - mid + ll - 1 < tmp)
90 lef = mid;
91 else
92 rig = mid;
93 }
94 return tmp + lef - ll + 1;
95}
96
97INT corrupt_rel(INT id, INT h, INT t, INT r, bool p = false,
98 bool filter_flag = true) {
99 INT lef, rig, mid, ll, rr;
100 if (not filter_flag) {
101 INT tmp = rand_max(id, relationTotal - 1);
102 if (tmp < r)
103 return tmp;
104 else
105 return tmp + 1;
106 }
107 lef = lefRel[h] - 1;
108 rig = rigRel[h];
109 while (lef + 1 < rig) {
110 mid = (lef + rig) >> 1;
111 if (trainRel[mid].t >= t)
112 rig = mid;
113 else
114 lef = mid;
115 }
116 ll = rig;
117 lef = lefRel[h];
118 rig = rigRel[h] + 1;
119 while (lef + 1 < rig) {
120 mid = (lef + rig) >> 1;
121 if (trainRel[mid].t <= t)
122 lef = mid;
123 else
124 rig = mid;
125 }
126 rr = lef;
127 INT tmp;
128 if (p == false) {
129 tmp = rand_max(id, relationTotal - (rr - ll + 1));
130 } else {
131 INT start = r * (relationTotal - 1);
132 REAL sum = 1;
133 bool *record = (bool *)calloc(relationTotal - 1, sizeof(bool));
134 for (INT i = ll; i <= rr; ++i) {
135 if (trainRel[i].r > r) {
136 sum -= prob[start + trainRel[i].r - 1];
137 record[trainRel[i].r - 1] = true;
138 } else if (trainRel[i].r < r) {
139 sum -= prob[start + trainRel[i].r];
140 record[trainRel[i].r] = true;
141 }
142 }
143 REAL *prob_tmp =
144 (REAL *)calloc(relationTotal - (rr - ll + 1), sizeof(REAL));
145 INT cnt = 0;
146 REAL rec = 0;
147 for (INT i = start; i < start + relationTotal - 1; ++i) {
148 if (record[i - start])
149 continue;
150 rec += prob[i] / sum;
151 prob_tmp[cnt++] = rec;
152 }
153 REAL m = rand_max(id, 10000) / 10000.0;
154 lef = 0;
155 rig = cnt - 1;
156 while (lef < rig) {
157 mid = (lef + rig) >> 1;
158 if (prob_tmp[mid] < m)
159 lef = mid + 1;
160 else
161 rig = mid;
162 }
163 tmp = rig;
164 free(prob_tmp);
165 free(record);
166 }
167 if (tmp < trainRel[ll].r)
168 return tmp;
169 if (tmp > trainRel[rr].r - rr + ll - 1)
170 return tmp + rr - ll + 1;
171 lef = ll, rig = rr + 1;
172 while (lef + 1 < rig) {
173 mid = (lef + rig) >> 1;
174 if (trainRel[mid].r - mid + ll - 1 < tmp)
175 lef = mid;
176 else
177 rig = mid;
178 }
179 return tmp + lef - ll + 1;
180}
181
182bool _find(INT h, INT t, INT r) {
183 INT lef = 0;
184 INT rig = tripleTotal - 1;
185 INT mid;
186 while (lef + 1 < rig) {
187 INT mid = (lef + rig) >> 1;
188 if ((tripleList[mid].h < h) ||
189 (tripleList[mid].h == h && tripleList[mid].r < r) ||
190 (tripleList[mid].h == h && tripleList[mid].r == r &&
191 tripleList[mid].t < t))
192 lef = mid;
193 else
194 rig = mid;
195 }
196 if (tripleList[lef].h == h && tripleList[lef].r == r &&
197 tripleList[lef].t == t)
198 return true;
199 if (tripleList[rig].h == h && tripleList[rig].r == r &&
200 tripleList[rig].t == t)
201 return true;
202 return false;
203}
204
205INT corrupt(INT h, INT r) {
206 INT ll = tail_lef[r];
207 INT rr = tail_rig[r];
208 INT loop = 0;
209 INT t;
210 while (true) {
211 t = tail_type[rand(ll, rr)];
212 if (not _find(h, t, r)) {
213 return t;
214 } else {
215 loop++;
216 if (loop >= 1000) {
217 return corrupt_head(0, h, r);
218 }
219 }
220 }
221}
222#endif