00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043 #include "ckd_alloc.h"
00044 #include "ngram_model_arpa.h"
00045 #include "err.h"
00046 #include "pio.h"
00047 #include "listelem_alloc.h"
00048 #include "strfuncs.h"
00049
00050 #include <string.h>
00051 #include <limits.h>
00052
00053 static ngram_funcs_t ngram_model_arpa_funcs;
00054
/* Trigram base index stored for the bigram segment that bigram index b falls in. */
#define TSEG_BASE(m, b) ((m)->lm3g.tseg_base[(b) >> LOG_BG_SEG_SZ])
/* Value of the `bigrams` field of unigram u (set by ReadBigrams to its first bigram index). */
#define FIRST_BG(m, u)  ((m)->lm3g.unigrams[(u)].bigrams)
/* First trigram index for bigram b: its segment base plus the offset stored in the bigram. */
#define FIRST_TG(m, b)  (TSEG_BASE((m), (b)) + ((m)->lm3g.bigrams[(b)].trigrams))
00058
00059
00060
00061
00062
00063 static void
00064 init_sorted_list(sorted_list_t * l)
00065 {
00066
00067 l->list = ckd_calloc(MAX_SORTED_ENTRIES,
00068 sizeof(sorted_entry_t));
00069 l->list[0].val.l = INT_MIN;
00070 l->list[0].lower = 0;
00071 l->list[0].higher = 0;
00072 l->free = 1;
00073 }
00074
00075 static void
00076 free_sorted_list(sorted_list_t * l)
00077 {
00078 free(l->list);
00079 }
00080
00081 static lmprob_t *
00082 vals_in_sorted_list(sorted_list_t * l)
00083 {
00084 lmprob_t *vals;
00085 int32 i;
00086
00087 vals = ckd_calloc(l->free, sizeof(lmprob_t));
00088 for (i = 0; i < l->free; i++)
00089 vals[i] = l->list[i].val;
00090 return (vals);
00091 }
00092
00093 static int32
00094 sorted_id(sorted_list_t * l, int32 *val)
00095 {
00096 int32 i = 0;
00097
00098 for (;;) {
00099 if (*val == l->list[i].val.l)
00100 return (i);
00101 if (*val < l->list[i].val.l) {
00102 if (l->list[i].lower == 0) {
00103 if (l->free >= MAX_SORTED_ENTRIES) {
00104
00105 E_WARN("sorted list overflow (%d => %d)\n",
00106 *val, l->list[i].val.l);
00107 return i;
00108 }
00109
00110 l->list[i].lower = l->free;
00111 (l->free)++;
00112 i = l->list[i].lower;
00113 l->list[i].val.l = *val;
00114 return (i);
00115 }
00116 else
00117 i = l->list[i].lower;
00118 }
00119 else {
00120 if (l->list[i].higher == 0) {
00121 if (l->free >= MAX_SORTED_ENTRIES) {
00122
00123 E_WARN("sorted list overflow (%d => %d)\n",
00124 *val, l->list[i].val);
00125 return i;
00126 }
00127
00128 l->list[i].higher = l->free;
00129 (l->free)++;
00130 i = l->list[i].higher;
00131 l->list[i].val.l = *val;
00132 return (i);
00133 }
00134 else
00135 i = l->list[i].higher;
00136 }
00137 }
00138 }
00139
00140
00141
00142
00143 static int
00144 ReadNgramCounts(FILE * fp, int32 * n_ug, int32 * n_bg, int32 * n_tg)
00145 {
00146 char string[256];
00147 int32 ngram, ngram_cnt;
00148
00149
00150 do
00151 fgets(string, sizeof(string), fp);
00152 while ((strcmp(string, "\\data\\\n") != 0) && (!feof(fp)));
00153
00154 if (strcmp(string, "\\data\\\n") != 0) {
00155 E_ERROR("No \\data\\ mark in LM file\n");
00156 return -1;
00157 }
00158
00159 *n_ug = *n_bg = *n_tg = 0;
00160 while (fgets(string, sizeof(string), fp) != NULL) {
00161 if (sscanf(string, "ngram %d=%d", &ngram, &ngram_cnt) != 2)
00162 break;
00163 switch (ngram) {
00164 case 1:
00165 *n_ug = ngram_cnt;
00166 break;
00167 case 2:
00168 *n_bg = ngram_cnt;
00169 break;
00170 case 3:
00171 *n_tg = ngram_cnt;
00172 break;
00173 default:
00174 E_ERROR("Unknown ngram (%d)\n", ngram);
00175 return -1;
00176 }
00177 }
00178
00179
00180 while ((strcmp(string, "\\1-grams:\n") != 0) && (!feof(fp)))
00181 fgets(string, sizeof(string), fp);
00182
00183
00184 if ((*n_ug <= 0) || (*n_bg <= 0) || (*n_tg < 0)) {
00185 E_ERROR("Bad or missing ngram count\n");
00186 return -1;
00187 }
00188 return 0;
00189 }
00190
00191
00192
00193
00194
00195
00196 static int
00197 ReadUnigrams(FILE * fp, ngram_model_arpa_t * model)
00198 {
00199 ngram_model_t *base = &model->base;
00200 char string[256];
00201 int32 wcnt;
00202 float p1;
00203
00204 E_INFO("Reading unigrams\n");
00205
00206 wcnt = 0;
00207 while ((fgets(string, sizeof(string), fp) != NULL) &&
00208 (strcmp(string, "\\2-grams:\n") != 0)) {
00209 char *wptr[3], *name;
00210 float32 bo_wt = 0.0f;
00211 int n;
00212
00213 if ((n = str2words(string, wptr, 3)) < 2) {
00214 if (string[0] != '\n')
00215 E_WARN("Format error; unigram ignored: %s\n", string);
00216 continue;
00217 }
00218 else {
00219 p1 = (float)atof_c(wptr[0]);
00220 name = wptr[1];
00221 if (n == 3)
00222 bo_wt = (float)atof_c(wptr[2]);
00223 }
00224
00225 if (wcnt >= base->n_counts[0]) {
00226 E_ERROR("Too many unigrams\n");
00227 return -1;
00228 }
00229
00230
00231 base->word_str[wcnt] = ckd_salloc(name);
00232 if ((hash_table_enter(base->wid, base->word_str[wcnt], (void *)(long)wcnt))
00233 != (void *)(long)wcnt) {
00234 E_WARN("Duplicate word in dictionary: %s\n", base->word_str[wcnt]);
00235 }
00236 model->lm3g.unigrams[wcnt].prob1.l = logmath_log10_to_log(base->lmath, p1);
00237 model->lm3g.unigrams[wcnt].bo_wt1.l = logmath_log10_to_log(base->lmath, bo_wt);
00238 wcnt++;
00239 }
00240
00241 if (base->n_counts[0] != wcnt) {
00242 E_WARN("lm_t.ucount(%d) != #unigrams read(%d)\n",
00243 base->n_counts[0], wcnt);
00244 base->n_counts[0] = wcnt;
00245 base->n_words = wcnt;
00246 }
00247 return 0;
00248 }
00249
00250
00251
00252
00253 static int
00254 ReadBigrams(FILE * fp, ngram_model_arpa_t * model)
00255 {
00256 ngram_model_t *base = &model->base;
00257 char string[1024];
00258 int32 w1, w2, prev_w1, bgcount;
00259 bigram_t *bgptr;
00260
00261 E_INFO("Reading bigrams\n");
00262
00263 bgcount = 0;
00264 bgptr = model->lm3g.bigrams;
00265 prev_w1 = -1;
00266
00267 while (fgets(string, sizeof(string), fp) != NULL) {
00268 float32 p, bo_wt = 0.0f;
00269 int32 p2, bo_wt2;
00270 char *wptr[4], *word1, *word2;
00271 int n;
00272
00273 wptr[3] = NULL;
00274 if ((n = str2words(string, wptr, 4)) < 3) {
00275 if (string[0] != '\n')
00276 break;
00277 continue;
00278 }
00279 else {
00280 p = (float32)atof_c(wptr[0]);
00281 word1 = wptr[1];
00282 word2 = wptr[2];
00283 if (wptr[3])
00284 bo_wt = (float32)atof_c(wptr[3]);
00285 }
00286
00287 if ((w1 = ngram_wid(base, word1)) == NGRAM_INVALID_WID) {
00288 E_ERROR("Unknown word: %s, skipping bigram (%s %s)\n",
00289 word1, word1, word2);
00290 continue;
00291 }
00292 if ((w2 = ngram_wid(base, word2)) == NGRAM_INVALID_WID) {
00293 E_ERROR("Unknown word: %s, skipping bigram (%s %s)\n",
00294 word2, word1, word2);
00295 continue;
00296 }
00297
00298
00299
00300 p = (float32)((int32)(p * 10000)) / 10000;
00301 bo_wt = (float32)((int32)(bo_wt * 10000)) / 10000;
00302
00303 p2 = logmath_log10_to_log(base->lmath, p);
00304 bo_wt2 = logmath_log10_to_log(base->lmath, bo_wt);
00305
00306 if (bgcount >= base->n_counts[1]) {
00307 E_ERROR("Too many bigrams\n");
00308 return -1;
00309 }
00310
00311 bgptr->wid = w2;
00312 bgptr->prob2 = sorted_id(&model->sorted_prob2, &p2);
00313 if (base->n_counts[2] > 0)
00314 bgptr->bo_wt2 = sorted_id(&model->sorted_bo_wt2, &bo_wt2);
00315
00316 if (w1 != prev_w1) {
00317 if (w1 < prev_w1) {
00318 E_ERROR("Bigrams not in unigram order\n");
00319 return -1;
00320 }
00321
00322 for (prev_w1++; prev_w1 <= w1; prev_w1++)
00323 model->lm3g.unigrams[prev_w1].bigrams = bgcount;
00324 prev_w1 = w1;
00325 }
00326
00327 bgcount++;
00328 bgptr++;
00329
00330 if ((bgcount & 0x0000ffff) == 0) {
00331 E_INFOCONT(".");
00332 }
00333 }
00334 if ((strcmp(string, "\\end\\") != 0)
00335 && (strcmp(string, "\\3-grams:") != 0)) {
00336 E_ERROR("Bad bigram: %s\n", string);
00337 return -1;
00338 }
00339
00340 for (prev_w1++; prev_w1 <= base->n_counts[0]; prev_w1++)
00341 model->lm3g.unigrams[prev_w1].bigrams = bgcount;
00342
00343 return 0;
00344 }
00345
00346
00347
00348
00349 static int
00350 ReadTrigrams(FILE * fp, ngram_model_arpa_t * model)
00351 {
00352 ngram_model_t *base = &model->base;
00353 char string[1024];
00354 int32 i, w1, w2, w3, prev_w1, prev_w2, tgcount, prev_bg, bg, endbg;
00355 int32 seg, prev_seg, prev_seg_lastbg;
00356 trigram_t *tgptr;
00357 bigram_t *bgptr;
00358
00359 E_INFO("Reading trigrams\n");
00360
00361 tgcount = 0;
00362 tgptr = model->lm3g.trigrams;
00363 prev_w1 = -1;
00364 prev_w2 = -1;
00365 prev_bg = -1;
00366 prev_seg = -1;
00367
00368 while (fgets(string, sizeof(string), fp) != NULL) {
00369 float32 p;
00370 int32 p3;
00371 char *wptr[4], *word1, *word2, *word3;
00372
00373 if (str2words(string, wptr, 4) != 4) {
00374 if (string[0] != '\n')
00375 break;
00376 continue;
00377 }
00378 else {
00379 p = (float32)atof_c(wptr[0]);
00380 word1 = wptr[1];
00381 word2 = wptr[2];
00382 word3 = wptr[3];
00383 }
00384
00385 if ((w1 = ngram_wid(base, word1)) == NGRAM_INVALID_WID) {
00386 E_ERROR("Unknown word: %s, skipping trigram (%s %s %s)\n",
00387 word1, word1, word2, word3);
00388 continue;
00389 }
00390 if ((w2 = ngram_wid(base, word2)) == NGRAM_INVALID_WID) {
00391 E_ERROR("Unknown word: %s, skipping trigram (%s %s %s)\n",
00392 word2, word1, word2, word3);
00393 continue;
00394 }
00395 if ((w3 = ngram_wid(base, word3)) == NGRAM_INVALID_WID) {
00396 E_ERROR("Unknown word: %s, skipping trigram (%s %s %s)\n",
00397 word3, word1, word2, word3);
00398 continue;
00399 }
00400
00401
00402
00403 p = (float32)((int32)(p * 10000)) / 10000;
00404 p3 = logmath_log10_to_log(base->lmath, p);
00405
00406 if (tgcount >= base->n_counts[2]) {
00407 E_ERROR("Too many trigrams\n");
00408 return -1;
00409 }
00410
00411 tgptr->wid = w3;
00412 tgptr->prob3 = sorted_id(&model->sorted_prob3, &p3);
00413
00414 if ((w1 != prev_w1) || (w2 != prev_w2)) {
00415
00416 if ((w1 < prev_w1) || ((w1 == prev_w1) && (w2 < prev_w2))) {
00417 E_ERROR("Trigrams not in bigram order\n");
00418 return -1;
00419 }
00420
00421 bg = (w1 !=
00422 prev_w1) ? model->lm3g.unigrams[w1].bigrams : prev_bg + 1;
00423 endbg = model->lm3g.unigrams[w1 + 1].bigrams;
00424 bgptr = model->lm3g.bigrams + bg;
00425 for (; (bg < endbg) && (bgptr->wid != w2); bg++, bgptr++);
00426 if (bg >= endbg) {
00427 E_ERROR("Missing bigram for trigram: %s", string);
00428 return -1;
00429 }
00430
00431
00432 seg = bg >> LOG_BG_SEG_SZ;
00433 for (i = prev_seg + 1; i <= seg; i++)
00434 model->lm3g.tseg_base[i] = tgcount;
00435
00436
00437 if (prev_seg < seg) {
00438 int32 tgoff = 0;
00439
00440 if (prev_seg >= 0) {
00441 tgoff = tgcount - model->lm3g.tseg_base[prev_seg];
00442 if (tgoff > 65535) {
00443 E_ERROR("Offset from tseg_base > 65535\n");
00444 return -1;
00445 }
00446 }
00447
00448 prev_seg_lastbg = ((prev_seg + 1) << LOG_BG_SEG_SZ) - 1;
00449 bgptr = model->lm3g.bigrams + prev_bg;
00450 for (++prev_bg, ++bgptr; prev_bg <= prev_seg_lastbg;
00451 prev_bg++, bgptr++)
00452 bgptr->trigrams = tgoff;
00453
00454 for (; prev_bg <= bg; prev_bg++, bgptr++)
00455 bgptr->trigrams = 0;
00456 }
00457 else {
00458 int32 tgoff;
00459
00460 tgoff = tgcount - model->lm3g.tseg_base[prev_seg];
00461 if (tgoff > 65535) {
00462 E_ERROR("Offset from tseg_base > 65535\n");
00463 return -1;
00464 }
00465
00466 bgptr = model->lm3g.bigrams + prev_bg;
00467 for (++prev_bg, ++bgptr; prev_bg <= bg; prev_bg++, bgptr++)
00468 bgptr->trigrams = tgoff;
00469 }
00470
00471 prev_w1 = w1;
00472 prev_w2 = w2;
00473 prev_bg = bg;
00474 prev_seg = seg;
00475 }
00476
00477 tgcount++;
00478 tgptr++;
00479
00480 if ((tgcount & 0x0000ffff) == 0) {
00481 E_INFOCONT(".");
00482 }
00483 }
00484 if (strcmp(string, "\\end\\") != 0) {
00485 E_ERROR("Bad trigram: %s\n", string);
00486 return -1;
00487 }
00488
00489 for (prev_bg++; prev_bg <= base->n_counts[1]; prev_bg++) {
00490 if ((prev_bg & (BG_SEG_SZ - 1)) == 0)
00491 model->lm3g.tseg_base[prev_bg >> LOG_BG_SEG_SZ] = tgcount;
00492 if ((tgcount - model->lm3g.tseg_base[prev_bg >> LOG_BG_SEG_SZ]) > 65535) {
00493 E_ERROR("Offset from tseg_base > 65535\n");
00494 return -1;
00495 }
00496 model->lm3g.bigrams[prev_bg].trigrams =
00497 tgcount - model->lm3g.tseg_base[prev_bg >> LOG_BG_SEG_SZ];
00498 }
00499 return 0;
00500 }
00501
00502 static unigram_t *
00503 new_unigram_table(int32 n_ug)
00504 {
00505 unigram_t *table;
00506 int32 i;
00507
00508 table = ckd_calloc(n_ug, sizeof(unigram_t));
00509 for (i = 0; i < n_ug; i++) {
00510 table[i].prob1.l = INT_MIN;
00511 table[i].bo_wt1.l = INT_MIN;
00512 }
00513 return table;
00514 }
00515
00516 ngram_model_t *
00517 ngram_model_arpa_read(cmd_ln_t *config,
00518 const char *file_name,
00519 logmath_t *lmath)
00520 {
00521 FILE *fp;
00522 int32 is_pipe;
00523 int32 n_unigram;
00524 int32 n_bigram;
00525 int32 n_trigram;
00526 int32 n;
00527 ngram_model_arpa_t *model;
00528 ngram_model_t *base;
00529
00530 if ((fp = fopen_comp(file_name, "r", &is_pipe)) == NULL) {
00531 E_ERROR("File %s not found\n", file_name);
00532 return NULL;
00533 }
00534
00535
00536 if (ReadNgramCounts(fp, &n_unigram, &n_bigram, &n_trigram) == -1) {
00537 fclose_comp(fp, is_pipe);
00538 return NULL;
00539 }
00540 E_INFO("ngrams 1=%d, 2=%d, 3=%d\n", n_unigram, n_bigram, n_trigram);
00541
00542
00543 model = ckd_calloc(1, sizeof(*model));
00544 base = &model->base;
00545 if (n_trigram > 0)
00546 n = 3;
00547 else if (n_bigram > 0)
00548 n = 2;
00549 else
00550 n = 1;
00551
00552 ngram_model_init(base, &ngram_model_arpa_funcs, lmath, n, n_unigram);
00553 base->n_counts[0] = n_unigram;
00554 base->n_counts[1] = n_bigram;
00555 base->n_counts[2] = n_trigram;
00556 base->writable = TRUE;
00557
00558
00559
00560
00561
00562 model->lm3g.unigrams = new_unigram_table(n_unigram + 1);
00563 model->lm3g.bigrams =
00564 ckd_calloc(n_bigram + 1, sizeof(bigram_t));
00565 if (n_trigram > 0)
00566 model->lm3g.trigrams =
00567 ckd_calloc(n_trigram, sizeof(trigram_t));
00568
00569 if (n_trigram > 0) {
00570 model->lm3g.tseg_base =
00571 ckd_calloc((n_bigram + 1) / BG_SEG_SZ + 1,
00572 sizeof(int32));
00573 }
00574 if (ReadUnigrams(fp, model) == -1) {
00575 fclose_comp(fp, is_pipe);
00576 ngram_model_free(base);
00577 return NULL;
00578 }
00579 E_INFO("%8d = #unigrams created\n", base->n_counts[0]);
00580
00581 init_sorted_list(&model->sorted_prob2);
00582 if (base->n_counts[2] > 0)
00583 init_sorted_list(&model->sorted_bo_wt2);
00584
00585 if (ReadBigrams(fp, model) == -1) {
00586 fclose_comp(fp, is_pipe);
00587 ngram_model_free(base);
00588 return NULL;
00589 }
00590
00591 base->n_counts[1] = FIRST_BG(model, base->n_counts[0]);
00592 model->lm3g.n_prob2 = model->sorted_prob2.free;
00593 model->lm3g.prob2 = vals_in_sorted_list(&model->sorted_prob2);
00594 free_sorted_list(&model->sorted_prob2);
00595 E_INFO("%8d = #bigrams created\n", base->n_counts[1]);
00596 E_INFO("%8d = #prob2 entries\n", model->lm3g.n_prob2);
00597
00598 if (base->n_counts[2] > 0) {
00599
00600 model->lm3g.n_bo_wt2 = model->sorted_bo_wt2.free;
00601 model->lm3g.bo_wt2 = vals_in_sorted_list(&model->sorted_bo_wt2);
00602 free_sorted_list(&model->sorted_bo_wt2);
00603 E_INFO("%8d = #bo_wt2 entries\n", model->lm3g.n_bo_wt2);
00604
00605 init_sorted_list(&model->sorted_prob3);
00606
00607 if (ReadTrigrams(fp, model) == -1) {
00608 fclose_comp(fp, is_pipe);
00609 ngram_model_free(base);
00610 return NULL;
00611 }
00612
00613 base->n_counts[2] = FIRST_TG(model, base->n_counts[1]);
00614 model->lm3g.n_prob3 = model->sorted_prob3.free;
00615 model->lm3g.prob3 = vals_in_sorted_list(&model->sorted_prob3);
00616 E_INFO("%8d = #trigrams created\n", base->n_counts[2]);
00617 E_INFO("%8d = #prob3 entries\n", model->lm3g.n_prob3);
00618
00619 free_sorted_list(&model->sorted_prob3);
00620
00621
00622 model->lm3g.tginfo = ckd_calloc(n_unigram, sizeof(tginfo_t *));
00623 model->lm3g.le = listelem_alloc_init(sizeof(tginfo_t));
00624 }
00625
00626 fclose_comp(fp, is_pipe);
00627 return base;
00628 }
00629
00630 int
00631 ngram_model_arpa_write(ngram_model_t *model,
00632 const char *file_name)
00633 {
00634 return -1;
00635 }
00636
00637 static int
00638 ngram_model_arpa_apply_weights(ngram_model_t *base, float32 lw,
00639 float32 wip, float32 uw)
00640 {
00641 ngram_model_arpa_t *model = (ngram_model_arpa_t *)base;
00642 lm3g_apply_weights(base, &model->lm3g, lw, wip, uw);
00643 return 0;
00644 }
00645
00646
00647
00648
00649 #define NGRAM_MODEL_TYPE ngram_model_arpa_t
00650 #include "lm3g_templates.c"
00651
00652 static void
00653 ngram_model_arpa_free(ngram_model_t *base)
00654 {
00655 ngram_model_arpa_t *model = (ngram_model_arpa_t *)base;
00656 ckd_free(model->lm3g.unigrams);
00657 ckd_free(model->lm3g.bigrams);
00658 ckd_free(model->lm3g.trigrams);
00659 ckd_free(model->lm3g.prob2);
00660 ckd_free(model->lm3g.bo_wt2);
00661 ckd_free(model->lm3g.prob3);
00662 lm3g_tginfo_free(base, &model->lm3g);
00663 ckd_free(model->lm3g.tseg_base);
00664 }
00665
/*
 * Function table handed to ngram_model_init() for ARPA-format models.
 * The initializer is positional, so the entry order must match the
 * member order of ngram_funcs_t (declared elsewhere -- verify against
 * that declaration before reordering).  The first two entries are
 * defined in this file; the lm3g_template_* entries come from the
 * included lm3g_templates.c.
 */
static ngram_funcs_t ngram_model_arpa_funcs = {
    ngram_model_arpa_free,
    ngram_model_arpa_apply_weights,
    lm3g_template_score,
    lm3g_template_raw_score,
    lm3g_template_add_ug,
    lm3g_template_flush,
    lm3g_template_iter,
    lm3g_template_mgrams,
    lm3g_template_successors,
    lm3g_template_iter_get,
    lm3g_template_iter_next,
    lm3g_template_iter_free
};