Skip to content

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
- fix syntax
- add transaltion interface for remoteParserClient
  • Loading branch information
Mehrad0711 committed Oct 7, 2021
1 parent dbaa1a1 commit 58edbfb
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 41 deletions.
33 changes: 4 additions & 29 deletions lib/prediction/localparserclient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import {
const SEMANTIC_PARSING_TASK = 'almond';
const NLU_TASK = 'almond_dialogue_nlu';
const NLG_TASK = 'almond_dialogue_nlg';
const Translation_TASK = 'almond_translate';
const TRANSLATION_TASK = 'almond_translate';
const NLG_QUESTION = 'what should the agent say ?';

export interface LocalParserOptions {
Expand All @@ -60,22 +60,6 @@ function compareScore(a : PredictionCandidate, b : PredictionCandidate) : number
return b.score - a.score;
}

function substringSpan(sequence : string[], substring : string[]) : [number, number] | null {
for (let i=0; i < sequence.length; i++) {
let found = true;
for (let j = 0; j < substring.length; j++) {
if (sequence[i+j] !== substring[j]) {
found = false;
break;
}
}
if (found)
return [i, i + substring.length + 1];
}
return null;
}


export default class LocalParserClient {
private _locale : string;
private _langPack : I18n.LanguagePack;
Expand Down Expand Up @@ -280,18 +264,9 @@ export default class LocalParserClient {
});
}

async translateUtterance(input : string[], contextEntities : EntityMap|undefined, translationOptions : Record<string, any>) : Promise<GenerationResult[]> {
if (contextEntities) {
const allEntities = Object.keys(contextEntities).map((ent) => ent.split(' '));
for (const entity of allEntities) {
const span = substringSpan(input, entity);
if (span) {
input.splice(span[0], 0, '"');
input.splice(span[1], 0, '"');
}
}
}
const candidates = await this._predictor.predict('', input.join(' '), undefined, Translation_TASK, 'id-null', translationOptions);
async translateUtterance(input : string[], contextEntities : EntityMap|undefined, translationOptions : Record<string, unknown>) : Promise<GenerationResult[]> {
input = Utils.qpisEntities(input, contextEntities);
const candidates = await this._predictor.predict('', input.join(' '), undefined, TRANSLATION_TASK, 'id-null', translationOptions);
return candidates.map((cand) => {
return {
answer: cand.answer,
Expand Down
12 changes: 6 additions & 6 deletions lib/prediction/predictor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class LocalWorker extends events.EventEmitter {
this._requests.clear();
}

request(task : string, minibatch : Example[], options : Record<string, any>) : Promise<RawPredictionCandidate[][]> {
request(task : string, minibatch : Example[], options : Record<string, unknown>) : Promise<RawPredictionCandidate[][]> {
const id = this._nextId ++;

return new Promise((resolve, reject) => {
Expand Down Expand Up @@ -179,7 +179,7 @@ class RemoteWorker extends events.EventEmitter {
start() {}
stop() {}

async request(task : string, minibatch : Example[], options : Record<string, any>) : Promise<RawPredictionCandidate[][]> {
async request(task : string, minibatch : Example[], options : Record<string, unknown>) : Promise<RawPredictionCandidate[][]> {
const response = await Tp.Helpers.Http.post(this._url, JSON.stringify({
task,
instances: minibatch,
Expand Down Expand Up @@ -244,7 +244,7 @@ export default class Predictor {
});
}

private _startRequest(ex : Example, task : string, options : Record<string, any>, now : number) {
private _startRequest(ex : Example, task : string, options : Record<string, unknown>, now : number) {
assert(this._minibatch.length === 0);
this._minibatch.push(ex);
this._minibatchTask = task;
Expand All @@ -257,7 +257,7 @@ export default class Predictor {
}, this._maxLatency);
}

private _addRequest(ex : Example, task : string, options : Record<string, any>) {
private _addRequest(ex : Example, task : string, options : Record<string, unknown>) {
const now = Date.now();
if (this._minibatch.length === 0) {
this._startRequest(ex, task, options, now);
Expand All @@ -271,7 +271,7 @@ export default class Predictor {
}
}

predict(context : string, question : string = DEFAULT_QUESTION, answer ?: string, task = 'almond', example_id ?: string, options : Record<string, any> = {}) : Promise<RawPredictionCandidate[]> {
predict(context : string, question : string = DEFAULT_QUESTION, answer ?: string, task = 'almond', example_id ?: string, options : Record<string, unknown> = {}) : Promise<RawPredictionCandidate[]> {

// ensure we have a worker, in case it recently died
if (!this._worker)
Expand All @@ -283,7 +283,7 @@ export default class Predictor {
resolve = _resolve;
reject = _reject;
});
this._addRequest({ context, question, answer, example_id, resolve, reject }, task, options);
this._addRequest({ context, question, answer, resolve, reject }, task, options);

return promise;
}
Expand Down
24 changes: 24 additions & 0 deletions lib/prediction/remoteparserclient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import * as ThingTalk from 'thingtalk';
import * as Tp from 'thingpedia';
import * as Utils from '../utils/misc-utils';
import qs from 'qs';

import { EntityMap } from '../utils/entity-utils';
Expand Down Expand Up @@ -168,4 +169,27 @@ export default class RemoteParserClient {

return parsed.candidates;
}

async translateUtterance(input : string[], contextEntities : EntityMap|undefined, translationOptions : Record<string, unknown>) : Promise<GenerationResult[]> {
input = Utils.qpisEntities(input, contextEntities);

const data = {
input: input.join(' '),
tgt_locale: translationOptions.tgt_locale,
entities: contextEntities,
alignment: translationOptions.alignment,
src_locale: translationOptions.src_locale,
align_remove_output_quotation: translationOptions.align_remove_output_quotation
};

const response = await Tp.Helpers.Http.post(`${this._baseUrl}/translate`, JSON.stringify(data), {
dataContentType: 'application/json' //'
});
const parsed = JSON.parse(response);
if (parsed.error)
throw new Error('Error received from Genie server: ' + parsed.error);

return parsed.candidates;
}

}
34 changes: 34 additions & 0 deletions lib/utils/misc-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import {
makeDummyEntity,
makeDummyEntities,
renumberEntities,
EntityMap,
} from './entity-utils';

class ValidationError extends Error {
Expand Down Expand Up @@ -200,6 +201,37 @@ function isHumanEntity(type : Type|string) : boolean {
return false;
}

function substringSpan(sequence : string[], substring : string[]) : [number, number] | null {
for (let i=0; i < sequence.length; i++) {
let found = true;
for (let j = 0; j < substring.length; j++) {
if (sequence[i+j] !== substring[j]) {
found = false;
break;
}
}
if (found)
return [i, i + substring.length + 1];
}
return null;
}


function qpisEntities(input : string[], contextEntities : EntityMap|undefined) : string[] {
if (contextEntities) {
const allEntities = Object.keys(contextEntities).map((ent) => ent.split(' '));
for (const entity of allEntities) {
const span = substringSpan(input, entity);
if (span) {
input.splice(span[0], 0, '"');
input.splice(span[1], 0, '"');
}
}
}
return input;
}


export {
splitParams,
split,
Expand All @@ -212,4 +244,6 @@ export {
makeDummyEntity,
makeDummyEntities,
renumberEntities,

qpisEntities
};
37 changes: 31 additions & 6 deletions tool/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ interface TranslationData {
limit ?: string;
alignment ?: boolean;
src_locale ?: string;
align_remove_output_quotation ?: boolean
}
const Translation_PARAMS = {
input: 'string',
Expand All @@ -165,8 +166,31 @@ const Translation_PARAMS = {
limit: '?number',
alignment: '?boolean',
src_locale: '?string',
align_remove_output_quotation: '?boolean'
};


// const VALID_PARSER_OPTIONS = new Set([
// "num_beams",
// "num_beam_groups",
// "diversity_penalty",
// "num_outputs",
// "no_repeat_ngram_size",
// "top_p",
// "top_k",
// "repetition_penalty",
// "temperature",
// "max_output_length",
// "reduce_metrics",
// "database_dir",
// "do_alignment",
// "align_preserve_input_quotation",
// "align_remove_output_quotation",
// "src_locale",
// "tgt_locale",
// "translate_example_split"
// ]);

async function queryTranslate(params : Record<string, string>,
data : TranslationData,
res : express.Response) {
Expand All @@ -177,14 +201,15 @@ async function queryTranslate(params : Record<string, string>,
return;
}

if (! data.src_locale)
if (!data.src_locale)
data.src_locale = 'en-US';

const translationOptions : Record<string, any> = {
'do_alignment': data.alignment,
'align_remove_output_quotation': true,
'src_locale': data.src_locale,
'tgt_locale': data.tgt_locale
const translationOptions : Record<string, unknown> = {
'src_locale': data.src_locale,
'tgt_locale': data.tgt_locale,
'do_alignment': data.alignment,
'align_remove_output_quotation': data.align_remove_output_quotation,

};

const result = await res.app.backend.translator!.translateUtterance(
Expand Down

0 comments on commit 58edbfb

Please sign in to comment.