IR2Vec
Loading...
Searching...
No Matches
Reader.h
1#ifndef READER_H
2#define READER_H
3#include "Setting.h"
4#include "Triple.h"
5#include <algorithm>
6#include <cmath>
7#include <cstdlib>
8#include <iostream>
9
10INT *freqRel, *freqEnt;
11INT *lefHead, *rigHead;
12INT *lefTail, *rigTail;
13INT *lefRel, *rigRel;
14REAL *left_mean, *right_mean;
15REAL *prob;
16
17Triple *trainList;
18Triple *trainHead;
19Triple *trainTail;
20Triple *trainRel;
21
22INT *testLef, *testRig;
23INT *validLef, *validRig;
24
25extern "C" void importProb(REAL temp) {
26 if (prob != NULL)
27 free(prob);
28 FILE *fin;
29 fin = fopen((inPath + "kl_prob.txt").c_str(), "r");
30 printf("Current temperature:%f\n", temp);
31 prob = (REAL *)calloc(relationTotal * (relationTotal - 1), sizeof(REAL));
32 INT tmp;
33 for (INT i = 0; i < relationTotal * (relationTotal - 1); ++i) {
34 tmp = fscanf(fin, "%f", &prob[i]);
35 }
36 REAL sum = 0.0;
37 for (INT i = 0; i < relationTotal; ++i) {
38 for (INT j = 0; j < relationTotal - 1; ++j) {
39 REAL tmp = exp(-prob[i * (relationTotal - 1) + j] / temp);
40 sum += tmp;
41 prob[i * (relationTotal - 1) + j] = tmp;
42 }
43 for (INT j = 0; j < relationTotal - 1; ++j) {
44 prob[i * (relationTotal - 1) + j] /= sum;
45 }
46 sum = 0;
47 }
48 fclose(fin);
49}
50
51extern "C" void importTrainFiles() {
52
53 printf("The toolkit is importing datasets.\n");
54 FILE *fin;
55 int tmp;
56
57 if (rel_file == "")
58 fin = fopen((inPath + "relation2id.txt").c_str(), "r");
59 else
60 fin = fopen(rel_file.c_str(), "r");
61 tmp = fscanf(fin, "%ld", &relationTotal);
62 printf("The total of relations is %ld.\n", relationTotal);
63 fclose(fin);
64
65 if (ent_file == "")
66 fin = fopen((inPath + "entity2id.txt").c_str(), "r");
67 else
68 fin = fopen(ent_file.c_str(), "r");
69 tmp = fscanf(fin, "%ld", &entityTotal);
70 printf("The total of entities is %ld.\n", entityTotal);
71 fclose(fin);
72
73 if (train_file == "")
74 fin = fopen((inPath + "train2id.txt").c_str(), "r");
75 else
76 fin = fopen(train_file.c_str(), "r");
77 tmp = fscanf(fin, "%ld", &trainTotal);
78 trainList = (Triple *)calloc(trainTotal, sizeof(Triple));
79 trainHead = (Triple *)calloc(trainTotal, sizeof(Triple));
80 trainTail = (Triple *)calloc(trainTotal, sizeof(Triple));
81 trainRel = (Triple *)calloc(trainTotal, sizeof(Triple));
82 freqRel = (INT *)calloc(relationTotal, sizeof(INT));
83 freqEnt = (INT *)calloc(entityTotal, sizeof(INT));
84 for (INT i = 0; i < trainTotal; i++) {
85 tmp = fscanf(fin, "%ld", &trainList[i].h);
86 tmp = fscanf(fin, "%ld", &trainList[i].t);
87 tmp = fscanf(fin, "%ld", &trainList[i].r);
88 }
89 fclose(fin);
90 std::sort(trainList, trainList + trainTotal, Triple::cmp_head);
91 tmp = trainTotal;
92 trainTotal = 1;
93 trainHead[0] = trainTail[0] = trainRel[0] = trainList[0];
94 freqEnt[trainList[0].t] += 1;
95 freqEnt[trainList[0].h] += 1;
96 freqRel[trainList[0].r] += 1;
97 for (INT i = 1; i < tmp; i++)
98 if (trainList[i].h != trainList[i - 1].h ||
99 trainList[i].r != trainList[i - 1].r ||
100 trainList[i].t != trainList[i - 1].t) {
101 trainHead[trainTotal] = trainTail[trainTotal] = trainRel[trainTotal] =
102 trainList[trainTotal] = trainList[i];
103 trainTotal++;
104 freqEnt[trainList[i].t]++;
105 freqEnt[trainList[i].h]++;
106 freqRel[trainList[i].r]++;
107 }
108
109 std::sort(trainHead, trainHead + trainTotal, Triple::cmp_head);
110 std::sort(trainTail, trainTail + trainTotal, Triple::cmp_tail);
111 std::sort(trainRel, trainRel + trainTotal, Triple::cmp_rel);
112 printf("The total of train triples is %ld.\n", trainTotal);
113
114 lefHead = (INT *)calloc(entityTotal, sizeof(INT));
115 rigHead = (INT *)calloc(entityTotal, sizeof(INT));
116 lefTail = (INT *)calloc(entityTotal, sizeof(INT));
117 rigTail = (INT *)calloc(entityTotal, sizeof(INT));
118 lefRel = (INT *)calloc(entityTotal, sizeof(INT));
119 rigRel = (INT *)calloc(entityTotal, sizeof(INT));
120 memset(rigHead, -1, sizeof(INT) * entityTotal);
121 memset(rigTail, -1, sizeof(INT) * entityTotal);
122 memset(rigRel, -1, sizeof(INT) * entityTotal);
123 for (INT i = 1; i < trainTotal; i++) {
124 if (trainTail[i].t != trainTail[i - 1].t) {
125 rigTail[trainTail[i - 1].t] = i - 1;
126 lefTail[trainTail[i].t] = i;
127 }
128 if (trainHead[i].h != trainHead[i - 1].h) {
129 rigHead[trainHead[i - 1].h] = i - 1;
130 lefHead[trainHead[i].h] = i;
131 }
132 if (trainRel[i].h != trainRel[i - 1].h) {
133 rigRel[trainRel[i - 1].h] = i - 1;
134 lefRel[trainRel[i].h] = i;
135 }
136 }
137 lefHead[trainHead[0].h] = 0;
138 rigHead[trainHead[trainTotal - 1].h] = trainTotal - 1;
139 lefTail[trainTail[0].t] = 0;
140 rigTail[trainTail[trainTotal - 1].t] = trainTotal - 1;
141 lefRel[trainRel[0].h] = 0;
142 rigRel[trainRel[trainTotal - 1].h] = trainTotal - 1;
143
144 left_mean = (REAL *)calloc(relationTotal, sizeof(REAL));
145 right_mean = (REAL *)calloc(relationTotal, sizeof(REAL));
146 for (INT i = 0; i < entityTotal; i++) {
147 for (INT j = lefHead[i] + 1; j <= rigHead[i]; j++)
148 if (trainHead[j].r != trainHead[j - 1].r)
149 left_mean[trainHead[j].r] += 1.0;
150 if (lefHead[i] <= rigHead[i])
151 left_mean[trainHead[lefHead[i]].r] += 1.0;
152 for (INT j = lefTail[i] + 1; j <= rigTail[i]; j++)
153 if (trainTail[j].r != trainTail[j - 1].r)
154 right_mean[trainTail[j].r] += 1.0;
155 if (lefTail[i] <= rigTail[i])
156 right_mean[trainTail[lefTail[i]].r] += 1.0;
157 }
158 for (INT i = 0; i < relationTotal; i++) {
159 left_mean[i] = freqRel[i] / left_mean[i];
160 right_mean[i] = freqRel[i] / right_mean[i];
161 }
162}
163
164Triple *testList;
165Triple *validList;
166Triple *tripleList;
167
168extern "C" void importTestFiles() {
169 FILE *fin;
170 INT tmp;
171
172 if (rel_file == "")
173 fin = fopen((inPath + "relation2id.txt").c_str(), "r");
174 else
175 fin = fopen(rel_file.c_str(), "r");
176 tmp = fscanf(fin, "%ld", &relationTotal);
177 fclose(fin);
178
179 if (ent_file == "")
180 fin = fopen((inPath + "entity2id.txt").c_str(), "r");
181 else
182 fin = fopen(ent_file.c_str(), "r");
183 tmp = fscanf(fin, "%ld", &entityTotal);
184 fclose(fin);
185
186 FILE *f_kb1, *f_kb2, *f_kb3;
187 if (train_file == "")
188 f_kb2 = fopen((inPath + "train2id.txt").c_str(), "r");
189 else
190 f_kb2 = fopen(train_file.c_str(), "r");
191 if (test_file == "")
192 f_kb1 = fopen((inPath + "test2id.txt").c_str(), "r");
193 else
194 f_kb1 = fopen(test_file.c_str(), "r");
195 if (valid_file == "")
196 f_kb3 = fopen((inPath + "valid2id.txt").c_str(), "r");
197 else
198 f_kb3 = fopen(valid_file.c_str(), "r");
199 tmp = fscanf(f_kb1, "%ld", &testTotal);
200 tmp = fscanf(f_kb2, "%ld", &trainTotal);
201 tmp = fscanf(f_kb3, "%ld", &validTotal);
202 tripleTotal = testTotal + trainTotal + validTotal;
203 testList = (Triple *)calloc(testTotal, sizeof(Triple));
204 validList = (Triple *)calloc(validTotal, sizeof(Triple));
205 tripleList = (Triple *)calloc(tripleTotal, sizeof(Triple));
206 for (INT i = 0; i < testTotal; i++) {
207 tmp = fscanf(f_kb1, "%ld", &testList[i].h);
208 tmp = fscanf(f_kb1, "%ld", &testList[i].t);
209 tmp = fscanf(f_kb1, "%ld", &testList[i].r);
210 tripleList[i] = testList[i];
211 }
212 for (INT i = 0; i < trainTotal; i++) {
213 tmp = fscanf(f_kb2, "%ld", &tripleList[i + testTotal].h);
214 tmp = fscanf(f_kb2, "%ld", &tripleList[i + testTotal].t);
215 tmp = fscanf(f_kb2, "%ld", &tripleList[i + testTotal].r);
216 }
217 for (INT i = 0; i < validTotal; i++) {
218 tmp = fscanf(f_kb3, "%ld", &tripleList[i + testTotal + trainTotal].h);
219 tmp = fscanf(f_kb3, "%ld", &tripleList[i + testTotal + trainTotal].t);
220 tmp = fscanf(f_kb3, "%ld", &tripleList[i + testTotal + trainTotal].r);
221 validList[i] = tripleList[i + testTotal + trainTotal];
222 }
223 fclose(f_kb1);
224 fclose(f_kb2);
225 fclose(f_kb3);
226
227 std::sort(tripleList, tripleList + tripleTotal, Triple::cmp_head);
228 std::sort(testList, testList + testTotal, Triple::cmp_rel2);
229 std::sort(validList, validList + validTotal, Triple::cmp_rel2);
230 printf("The total of test triples is %ld.\n", testTotal);
231 printf("The total of valid triples is %ld.\n", validTotal);
232
233 testLef = (INT *)calloc(relationTotal, sizeof(INT));
234 testRig = (INT *)calloc(relationTotal, sizeof(INT));
235 memset(testLef, -1, sizeof(INT) * relationTotal);
236 memset(testRig, -1, sizeof(INT) * relationTotal);
237 for (INT i = 1; i < testTotal; i++) {
238 if (testList[i].r != testList[i - 1].r) {
239 testRig[testList[i - 1].r] = i - 1;
240 testLef[testList[i].r] = i;
241 }
242 }
243 testLef[testList[0].r] = 0;
244 testRig[testList[testTotal - 1].r] = testTotal - 1;
245
246 validLef = (INT *)calloc(relationTotal, sizeof(INT));
247 validRig = (INT *)calloc(relationTotal, sizeof(INT));
248 memset(validLef, -1, sizeof(INT) * relationTotal);
249 memset(validRig, -1, sizeof(INT) * relationTotal);
250 for (INT i = 1; i < validTotal; i++) {
251 if (validList[i].r != validList[i - 1].r) {
252 validRig[validList[i - 1].r] = i - 1;
253 validLef[validList[i].r] = i;
254 }
255 }
256 validLef[validList[0].r] = 0;
257 validRig[validList[validTotal - 1].r] = validTotal - 1;
258}
259
260INT *head_lef;
261INT *head_rig;
262INT *tail_lef;
263INT *tail_rig;
264INT *head_type;
265INT *tail_type;
266
267extern "C" void importTypeFiles() {
268
269 head_lef = (INT *)calloc(relationTotal, sizeof(INT));
270 head_rig = (INT *)calloc(relationTotal, sizeof(INT));
271 tail_lef = (INT *)calloc(relationTotal, sizeof(INT));
272 tail_rig = (INT *)calloc(relationTotal, sizeof(INT));
273 INT total_lef = 0;
274 INT total_rig = 0;
275 FILE *f_type = fopen((inPath + "type_constrain.txt").c_str(), "r");
276 INT tmp;
277 tmp = fscanf(f_type, "%ld", &tmp);
278 for (INT i = 0; i < relationTotal; i++) {
279 INT rel, tot;
280 tmp = fscanf(f_type, "%ld %ld", &rel, &tot);
281 for (INT j = 0; j < tot; j++) {
282 tmp = fscanf(f_type, "%ld", &tmp);
283 total_lef++;
284 }
285 tmp = fscanf(f_type, "%ld%ld", &rel, &tot);
286 for (INT j = 0; j < tot; j++) {
287 tmp = fscanf(f_type, "%ld", &tmp);
288 total_rig++;
289 }
290 }
291 fclose(f_type);
292 head_type = (INT *)calloc(total_lef, sizeof(INT));
293 tail_type = (INT *)calloc(total_rig, sizeof(INT));
294 total_lef = 0;
295 total_rig = 0;
296 f_type = fopen((inPath + "type_constrain.txt").c_str(), "r");
297 tmp = fscanf(f_type, "%ld", &tmp);
298 for (INT i = 0; i < relationTotal; i++) {
299 INT rel, tot;
300 tmp = fscanf(f_type, "%ld%ld", &rel, &tot);
301 head_lef[rel] = total_lef;
302 for (INT j = 0; j < tot; j++) {
303 tmp = fscanf(f_type, "%ld", &head_type[total_lef]);
304 total_lef++;
305 }
306 head_rig[rel] = total_lef;
307 std::sort(head_type + head_lef[rel], head_type + head_rig[rel]);
308 tmp = fscanf(f_type, "%ld%ld", &rel, &tot);
309 tail_lef[rel] = total_rig;
310 for (INT j = 0; j < tot; j++) {
311 tmp = fscanf(f_type, "%ld", &tail_type[total_rig]);
312 total_rig++;
313 }
314 tail_rig[rel] = total_rig;
315 std::sort(tail_type + tail_lef[rel], tail_type + tail_rig[rel]);
316 }
317 fclose(f_type);
318}
319
320#endif
Definition Triple.h:5