IR2Vec
Loading...
Searching...
No Matches
Test.h
1#ifndef TEST_H
2#define TEST_H
#include <cstdio>
#include <cstdlib>

#include "Corrupt.h"
#include "Reader.h"
#include "Setting.h"
6
7/*=====================================================================================
8link prediction
9======================================================================================*/
10INT lastHead = 0;
11INT lastTail = 0;
12INT lastRel = 0;
13REAL l1_filter_tot = 0, l1_tot = 0, r1_tot = 0, r1_filter_tot = 0, l_tot = 0,
14 r_tot = 0, l_filter_rank = 0, l_rank = 0, l_filter_reci_rank = 0,
15 l_reci_rank = 0;
16REAL l3_filter_tot = 0, l3_tot = 0, r3_tot = 0, r3_filter_tot = 0,
17 l_filter_tot = 0, r_filter_tot = 0, r_filter_rank = 0, r_rank = 0,
18 r_filter_reci_rank = 0, r_reci_rank = 0;
19REAL rel3_tot = 0, rel3_filter_tot = 0, rel_filter_tot = 0, rel_filter_rank = 0,
20 rel_rank = 0, rel_filter_reci_rank = 0, rel_reci_rank = 0, rel_tot = 0,
21 rel1_tot = 0, rel1_filter_tot = 0;
22
23REAL l1_filter_tot_constrain = 0, l1_tot_constrain = 0, r1_tot_constrain = 0,
24 r1_filter_tot_constrain = 0, l_tot_constrain = 0, r_tot_constrain = 0,
25 l_filter_rank_constrain = 0, l_rank_constrain = 0,
26 l_filter_reci_rank_constrain = 0, l_reci_rank_constrain = 0;
27REAL l3_filter_tot_constrain = 0, l3_tot_constrain = 0, r3_tot_constrain = 0,
28 r3_filter_tot_constrain = 0, l_filter_tot_constrain = 0,
29 r_filter_tot_constrain = 0, r_filter_rank_constrain = 0,
30 r_rank_constrain = 0, r_filter_reci_rank_constrain = 0,
31 r_reci_rank_constrain = 0;
32REAL hit1, hit3, hit10, mr, mrr;
33REAL hit1TC, hit3TC, hit10TC, mrTC, mrrTC;
34
35extern "C" void initTest() {
36 lastHead = 0;
37 lastTail = 0;
38 lastRel = 0;
39 l1_filter_tot = 0, l1_tot = 0, r1_tot = 0, r1_filter_tot = 0, l_tot = 0,
40 r_tot = 0, l_filter_rank = 0, l_rank = 0, l_filter_reci_rank = 0,
41 l_reci_rank = 0;
42 l3_filter_tot = 0, l3_tot = 0, r3_tot = 0, r3_filter_tot = 0,
43 l_filter_tot = 0, r_filter_tot = 0, r_filter_rank = 0, r_rank = 0,
44 r_filter_reci_rank = 0, r_reci_rank = 0;
45 REAL rel3_tot = 0, rel3_filter_tot = 0, rel_filter_tot = 0,
46 rel_filter_rank = 0, rel_rank = 0, rel_filter_reci_rank = 0,
47 rel_reci_rank = 0, rel_tot = 0, rel1_tot = 0, rel1_filter_tot = 0;
48
49 l1_filter_tot_constrain = 0, l1_tot_constrain = 0, r1_tot_constrain = 0,
50 r1_filter_tot_constrain = 0, l_tot_constrain = 0, r_tot_constrain = 0,
51 l_filter_rank_constrain = 0, l_rank_constrain = 0,
52 l_filter_reci_rank_constrain = 0, l_reci_rank_constrain = 0;
53 l3_filter_tot_constrain = 0, l3_tot_constrain = 0, r3_tot_constrain = 0,
54 r3_filter_tot_constrain = 0, l_filter_tot_constrain = 0,
55 r_filter_tot_constrain = 0, r_filter_rank_constrain = 0, r_rank_constrain = 0,
56 r_filter_reci_rank_constrain = 0, r_reci_rank_constrain = 0;
57}
58
59extern "C" void getHeadBatch(INT *ph, INT *pt, INT *pr) {
60 for (INT i = 0; i < entityTotal; i++) {
61 ph[i] = i;
62 pt[i] = testList[lastHead].t;
63 pr[i] = testList[lastHead].r;
64 }
65 lastHead++;
66}
67
68extern "C" void getTailBatch(INT *ph, INT *pt, INT *pr) {
69 for (INT i = 0; i < entityTotal; i++) {
70 ph[i] = testList[lastTail].h;
71 pt[i] = i;
72 pr[i] = testList[lastTail].r;
73 }
74 lastTail++;
75}
76
77extern "C" void getRelBatch(INT *ph, INT *pt, INT *pr) {
78 for (INT i = 0; i < relationTotal; i++) {
79 ph[i] = testList[lastRel].h;
80 pt[i] = testList[lastRel].t;
81 pr[i] = i;
82 }
83}
84
85extern "C" void testHead(REAL *con, INT lastHead, bool type_constrain = false) {
86 INT h = testList[lastHead].h;
87 INT t = testList[lastHead].t;
88 INT r = testList[lastHead].r;
89 INT lef, rig;
90 if (type_constrain) {
91 lef = head_lef[r];
92 rig = head_rig[r];
93 }
94 REAL minimal = con[h];
95 INT l_s = 0;
96 INT l_filter_s = 0;
97 INT l_s_constrain = 0;
98 INT l_filter_s_constrain = 0;
99
100 for (INT j = 0; j < entityTotal; j++) {
101 if (j != h) {
102 REAL value = con[j];
103 if (value < minimal) {
104 l_s += 1;
105 if (not _find(j, t, r))
106 l_filter_s += 1;
107 }
108 if (type_constrain) {
109 while (lef < rig && head_type[lef] < j)
110 lef++;
111 if (lef < rig && j == head_type[lef]) {
112 if (value < minimal) {
113 l_s_constrain += 1;
114 if (not _find(j, t, r)) {
115 l_filter_s_constrain += 1;
116 }
117 }
118 }
119 }
120 }
121 }
122
123 if (l_filter_s < 10)
124 l_filter_tot += 1;
125 if (l_s < 10)
126 l_tot += 1;
127 if (l_filter_s < 3)
128 l3_filter_tot += 1;
129 if (l_s < 3)
130 l3_tot += 1;
131 if (l_filter_s < 1)
132 l1_filter_tot += 1;
133 if (l_s < 1)
134 l1_tot += 1;
135
136 l_filter_rank += (l_filter_s + 1);
137 l_rank += (1 + l_s);
138 l_filter_reci_rank += 1.0 / (l_filter_s + 1);
139 l_reci_rank += 1.0 / (l_s + 1);
140
141 if (type_constrain) {
142 if (l_filter_s_constrain < 10)
143 l_filter_tot_constrain += 1;
144 if (l_s_constrain < 10)
145 l_tot_constrain += 1;
146 if (l_filter_s_constrain < 3)
147 l3_filter_tot_constrain += 1;
148 if (l_s_constrain < 3)
149 l3_tot_constrain += 1;
150 if (l_filter_s_constrain < 1)
151 l1_filter_tot_constrain += 1;
152 if (l_s_constrain < 1)
153 l1_tot_constrain += 1;
154
155 l_filter_rank_constrain += (l_filter_s_constrain + 1);
156 l_rank_constrain += (1 + l_s_constrain);
157 l_filter_reci_rank_constrain += 1.0 / (l_filter_s_constrain + 1);
158 l_reci_rank_constrain += 1.0 / (l_s_constrain + 1);
159 }
160}
161
162extern "C" void testTail(REAL *con, INT lastTail, bool type_constrain = false) {
163 INT h = testList[lastTail].h;
164 INT t = testList[lastTail].t;
165 INT r = testList[lastTail].r;
166 INT lef, rig;
167 if (type_constrain) {
168 lef = tail_lef[r];
169 rig = tail_rig[r];
170 }
171 REAL minimal = con[t];
172 INT r_s = 0;
173 INT r_filter_s = 0;
174 INT r_s_constrain = 0;
175 INT r_filter_s_constrain = 0;
176 for (INT j = 0; j < entityTotal; j++) {
177 if (j != t) {
178 REAL value = con[j];
179 if (value < minimal) {
180 r_s += 1;
181 if (not _find(h, j, r))
182 r_filter_s += 1;
183 }
184 if (type_constrain) {
185 while (lef < rig && tail_type[lef] < j)
186 lef++;
187 if (lef < rig && j == tail_type[lef]) {
188 if (value < minimal) {
189 r_s_constrain += 1;
190 if (not _find(h, j, r)) {
191 r_filter_s_constrain += 1;
192 }
193 }
194 }
195 }
196 }
197 }
198
199 if (r_filter_s < 10)
200 r_filter_tot += 1;
201 if (r_s < 10)
202 r_tot += 1;
203 if (r_filter_s < 3)
204 r3_filter_tot += 1;
205 if (r_s < 3)
206 r3_tot += 1;
207 if (r_filter_s < 1)
208 r1_filter_tot += 1;
209 if (r_s < 1)
210 r1_tot += 1;
211
212 r_filter_rank += (1 + r_filter_s);
213 r_rank += (1 + r_s);
214 r_filter_reci_rank += 1.0 / (1 + r_filter_s);
215 r_reci_rank += 1.0 / (1 + r_s);
216
217 if (type_constrain) {
218 if (r_filter_s_constrain < 10)
219 r_filter_tot_constrain += 1;
220 if (r_s_constrain < 10)
221 r_tot_constrain += 1;
222 if (r_filter_s_constrain < 3)
223 r3_filter_tot_constrain += 1;
224 if (r_s_constrain < 3)
225 r3_tot_constrain += 1;
226 if (r_filter_s_constrain < 1)
227 r1_filter_tot_constrain += 1;
228 if (r_s_constrain < 1)
229 r1_tot_constrain += 1;
230
231 r_filter_rank_constrain += (1 + r_filter_s_constrain);
232 r_rank_constrain += (1 + r_s_constrain);
233 r_filter_reci_rank_constrain += 1.0 / (1 + r_filter_s_constrain);
234 r_reci_rank_constrain += 1.0 / (1 + r_s_constrain);
235 }
236}
237
238extern "C" void testRel(REAL *con) {
239 INT h = testList[lastRel].h;
240 INT t = testList[lastRel].t;
241 INT r = testList[lastRel].r;
242
243 REAL minimal = con[r];
244 INT rel_s = 0;
245 INT rel_filter_s = 0;
246
247 for (INT j = 0; j < relationTotal; j++) {
248 if (j != r) {
249 REAL value = con[j];
250 if (value < minimal) {
251 rel_s += 1;
252 if (not _find(h, t, j))
253 rel_filter_s += 1;
254 }
255 }
256 }
257
258 if (rel_filter_s < 10)
259 rel_filter_tot += 1;
260 if (rel_s < 10)
261 rel_tot += 1;
262 if (rel_filter_s < 3)
263 rel3_filter_tot += 1;
264 if (rel_s < 3)
265 rel3_tot += 1;
266 if (rel_filter_s < 1)
267 rel1_filter_tot += 1;
268 if (rel_s < 1)
269 rel1_tot += 1;
270
271 rel_filter_rank += (rel_filter_s + 1);
272 rel_rank += (1 + rel_s);
273 rel_filter_reci_rank += 1.0 / (rel_filter_s + 1);
274 rel_reci_rank += 1.0 / (rel_s + 1);
275
276 lastRel++;
277}
278
279extern "C" void test_link_prediction(bool type_constrain = false) {
280 l_rank /= testTotal;
281 r_rank /= testTotal;
282 l_reci_rank /= testTotal;
283 r_reci_rank /= testTotal;
284
285 l_tot /= testTotal;
286 l3_tot /= testTotal;
287 l1_tot /= testTotal;
288
289 r_tot /= testTotal;
290 r3_tot /= testTotal;
291 r1_tot /= testTotal;
292
293 // with filter
294 l_filter_rank /= testTotal;
295 r_filter_rank /= testTotal;
296 l_filter_reci_rank /= testTotal;
297 r_filter_reci_rank /= testTotal;
298
299 l_filter_tot /= testTotal;
300 l3_filter_tot /= testTotal;
301 l1_filter_tot /= testTotal;
302
303 r_filter_tot /= testTotal;
304 r3_filter_tot /= testTotal;
305 r1_filter_tot /= testTotal;
306
307 printf("no type constraint results:\n");
308
309 printf("metric:\t\t\t MRR \t\t MR \t\t hit@10 \t hit@3 \t hit@1 \n");
310 printf("l(raw):\t\t\t %f \t %f \t %f \t %f \t %f \n", l_reci_rank, l_rank,
311 l_tot, l3_tot, l1_tot);
312 printf("r(raw):\t\t\t %f \t %f \t %f \t %f \t %f \n", r_reci_rank, r_rank,
313 r_tot, r3_tot, r1_tot);
314 printf("averaged(raw):\t\t %f \t %f \t %f \t %f \t %f \n",
315 (l_reci_rank + r_reci_rank) / 2, (l_rank + r_rank) / 2,
316 (l_tot + r_tot) / 2, (l3_tot + r3_tot) / 2, (l1_tot + r1_tot) / 2);
317 printf("\n");
318 printf("l(filter):\t\t %f \t %f \t %f \t %f \t %f \n", l_filter_reci_rank,
319 l_filter_rank, l_filter_tot, l3_filter_tot, l1_filter_tot);
320 printf("r(filter):\t\t %f \t %f \t %f \t %f \t %f \n", r_filter_reci_rank,
321 r_filter_rank, r_filter_tot, r3_filter_tot, r1_filter_tot);
322 printf("averaged(filter):\t %f \t %f \t %f \t %f \t %f \n",
323 (l_filter_reci_rank + r_filter_reci_rank) / 2,
324 (l_filter_rank + r_filter_rank) / 2, (l_filter_tot + r_filter_tot) / 2,
325 (l3_filter_tot + r3_filter_tot) / 2,
326 (l1_filter_tot + r1_filter_tot) / 2);
327
328 mrr = (l_filter_reci_rank + r_filter_reci_rank) / 2;
329 mr = (l_filter_rank + r_filter_rank) / 2;
330 hit10 = (l_filter_tot + r_filter_tot) / 2;
331 hit3 = (l3_filter_tot + r3_filter_tot) / 2;
332 hit1 = (l1_filter_tot + r1_filter_tot) / 2;
333
334 if (type_constrain) {
335 // type constrain
336 l_rank_constrain /= testTotal;
337 r_rank_constrain /= testTotal;
338 l_reci_rank_constrain /= testTotal;
339 r_reci_rank_constrain /= testTotal;
340
341 l_tot_constrain /= testTotal;
342 l3_tot_constrain /= testTotal;
343 l1_tot_constrain /= testTotal;
344
345 r_tot_constrain /= testTotal;
346 r3_tot_constrain /= testTotal;
347 r1_tot_constrain /= testTotal;
348
349 // with filter
350 l_filter_rank_constrain /= testTotal;
351 r_filter_rank_constrain /= testTotal;
352 l_filter_reci_rank_constrain /= testTotal;
353 r_filter_reci_rank_constrain /= testTotal;
354
355 l_filter_tot_constrain /= testTotal;
356 l3_filter_tot_constrain /= testTotal;
357 l1_filter_tot_constrain /= testTotal;
358
359 r_filter_tot_constrain /= testTotal;
360 r3_filter_tot_constrain /= testTotal;
361 r1_filter_tot_constrain /= testTotal;
362
363 printf("type constraint results:\n");
364
365 printf("metric:\t\t\t MRR \t\t MR \t\t hit@10 \t hit@3 \t hit@1 \n");
366 printf("l(raw):\t\t\t %f \t %f \t %f \t %f \t %f \n", l_reci_rank_constrain,
367 l_rank_constrain, l_tot_constrain, l3_tot_constrain,
368 l1_tot_constrain);
369 printf("r(raw):\t\t\t %f \t %f \t %f \t %f \t %f \n", r_reci_rank_constrain,
370 r_rank_constrain, r_tot_constrain, r3_tot_constrain,
371 r1_tot_constrain);
372 printf("averaged(raw):\t\t %f \t %f \t %f \t %f \t %f \n",
373 (l_reci_rank_constrain + r_reci_rank_constrain) / 2,
374 (l_rank_constrain + r_rank_constrain) / 2,
375 (l_tot_constrain + r_tot_constrain) / 2,
376 (l3_tot_constrain + r3_tot_constrain) / 2,
377 (l1_tot_constrain + r1_tot_constrain) / 2);
378 printf("\n");
379 printf("l(filter):\t\t %f \t %f \t %f \t %f \t %f \n",
380 l_filter_reci_rank_constrain, l_filter_rank_constrain,
381 l_filter_tot_constrain, l3_filter_tot_constrain,
382 l1_filter_tot_constrain);
383 printf("r(filter):\t\t %f \t %f \t %f \t %f \t %f \n",
384 r_filter_reci_rank_constrain, r_filter_rank_constrain,
385 r_filter_tot_constrain, r3_filter_tot_constrain,
386 r1_filter_tot_constrain);
387 printf("averaged(filter):\t %f \t %f \t %f \t %f \t %f \n",
388 (l_filter_reci_rank_constrain + r_filter_reci_rank_constrain) / 2,
389 (l_filter_rank_constrain + r_filter_rank_constrain) / 2,
390 (l_filter_tot_constrain + r_filter_tot_constrain) / 2,
391 (l3_filter_tot_constrain + r3_filter_tot_constrain) / 2,
392 (l1_filter_tot_constrain + r1_filter_tot_constrain) / 2);
393
394 mrrTC = (l_filter_reci_rank_constrain + r_filter_reci_rank_constrain) / 2;
395 mrTC = (l_filter_rank_constrain + r_filter_rank_constrain) / 2;
396 hit10TC = (l_filter_tot_constrain + r_filter_tot_constrain) / 2;
397 hit3TC = (l3_filter_tot_constrain + r3_filter_tot_constrain) / 2;
398 hit1TC = (l1_filter_tot_constrain + r1_filter_tot_constrain) / 2;
399 }
400}
401
402extern "C" void test_relation_prediction() {
403 rel_rank /= testTotal;
404 rel_reci_rank /= testTotal;
405
406 rel_tot /= testTotal;
407 rel3_tot /= testTotal;
408 rel1_tot /= testTotal;
409
410 // with filter
411 rel_filter_rank /= testTotal;
412 rel_filter_reci_rank /= testTotal;
413
414 rel_filter_tot /= testTotal;
415 rel3_filter_tot /= testTotal;
416 rel1_filter_tot /= testTotal;
417
418 printf("no type constraint results:\n");
419
420 printf("metric:\t\t\t MRR \t\t MR \t\t hit@10 \t hit@3 \t hit@1 \n");
421 printf("averaged(raw):\t\t %f \t %f \t %f \t %f \t %f \n", rel_reci_rank,
422 rel_rank, rel_tot, rel3_tot, rel1_tot);
423 printf("\n");
424 printf("averaged(filter):\t %f \t %f \t %f \t %f \t %f \n",
425 rel_filter_reci_rank, rel_filter_rank, rel_filter_tot, rel3_filter_tot,
426 rel1_filter_tot);
427}
428
429extern "C" REAL getTestLinkHit10(bool type_constrain = false) {
430 if (type_constrain)
431 return hit10TC;
432 printf("%f\n", hit10);
433 return hit10;
434}
435
436extern "C" REAL getTestLinkHit3(bool type_constrain = false) {
437 if (type_constrain)
438 return hit3TC;
439 return hit3;
440}
441
442extern "C" REAL getTestLinkHit1(bool type_constrain = false) {
443 if (type_constrain)
444 return hit1TC;
445 return hit1;
446}
447
448extern "C" REAL getTestLinkMR(bool type_constrain = false) {
449 if (type_constrain)
450 return mrTC;
451 return mr;
452}
453
454extern "C" REAL getTestLinkMRR(bool type_constrain = false) {
455 if (type_constrain)
456 return mrrTC;
457 return mrr;
458}
459
460/*=====================================================================================
461triple classification
462======================================================================================*/
463Triple *negTestList = NULL;
464
465extern "C" void getNegTest() {
466 if (negTestList == NULL)
467 negTestList = (Triple *)calloc(testTotal, sizeof(Triple));
468 for (INT i = 0; i < testTotal; i++) {
469 negTestList[i] = testList[i];
470 if (randd(0) % 1000 < 500)
471 negTestList[i].t = corrupt_head(0, testList[i].h, testList[i].r);
472 else
473 negTestList[i].h = corrupt_tail(0, testList[i].t, testList[i].r);
474 }
475}
476
477extern "C" void getTestBatch(INT *ph, INT *pt, INT *pr, INT *nh, INT *nt,
478 INT *nr) {
479 getNegTest();
480 for (INT i = 0; i < testTotal; i++) {
481 ph[i] = testList[i].h;
482 pt[i] = testList[i].t;
483 pr[i] = testList[i].r;
484 nh[i] = negTestList[i].h;
485 nt[i] = negTestList[i].t;
486 nr[i] = negTestList[i].r;
487 }
488}
489#endif
Definition Triple.h:5