# Seed the token sequence with the meta token, then append the prefix (if
# any): whitespace-separated words for word-level models, individual
# characters otherwise.
text = [meta_token]
if prefix:
    text += prefix.split() if word_level else list(prefix)
# Sample tokens one at a time until the model emits the meta token or the
# sequence reaches max_gen_length. A scalar temperature is wrapped in a
# list; the list is then cycled, indexed by (len(text) - 1) so successive
# positions use successive temperatures.
next_char = ""
temperature = temperature if isinstance(temperature, list) else [temperature]
if model_input_count(model) > 1:
    # Rebuild a single-input model mapping input[0] to output[1].
    # NOTE(review): presumably output[1] is the next-token distribution
    # head of the multi-input model -- confirm where the model is built.
    model = Model(inputs=model.input[0], outputs=model.output[1])
while len(text) < max_gen_length and next_char != meta_token:
    context = text[-maxlen:]
    encoded = textgenrnn_encode_sequence(context, vocab, maxlen)
    temp = temperature[(len(text) - 1) % len(temperature)]
    probs = model.predict(encoded, batch_size=1)[0]
    next_char = indices_char[textgenrnn_sample(probs, temp)]
    text.append(next_char)
# Turn the generated token list into a single string.
collapse_char = " " if word_level else ""
if single_text:
    # Single-text mode: discard the first maxlen tokens, i.e. the
    # sequences generated with padding.
    text = text[maxlen:]
else:
    # Multi-text mode: drop the leading meta token. Drop the trailing token
    # only when it is the meta token that ended generation; if the loop
    # stopped because it reached max_gen_length, the last token is real
    # output and must be kept.
    text = text[1:-1] if text[-1:] == [meta_token] else text[1:]
text_joined = collapse_char.join(text)
# Word-level output is joined with spaces; remove the single space on each
# side of newline/tab tokens so line breaks are not padded.
if word_level:
    # left_punct = "!%),.:;?@]_}\\n\\t""
    # right_punct = "$([_\\n\\t""
    # The doubled backslashes make the regex class [\n\t] (newline, tab).
    punct = "\\n\\t"
    text_joined = re.sub(" ([{}]) ".format(punct), r"\1", text_joined)
// text_joined = re.sub(" ([{}])".format(
// left_punct), r"\1", text_joined)
// text_joined = re.sub("([{}]) ".format(
After Change
# Tokenize the optional prefix: lowercased whitespace-separated words for
# word-level models, raw characters otherwise. `prefix` is only touched
# when it is truthy, so an absent prefix (e.g. None) never reaches
# `.split()`.
if not prefix:
    prefix_t = []
elif word_level:
    prefix_t = [x.lower() for x in prefix.split()]
else:
    prefix_t = list(prefix)
if single_text:
    # Seed with the prefix tokens, or a single empty token when there are
    # none (also covers a whitespace-only word-level prefix). The budget
    # grows by maxlen, presumably to offset the `text[maxlen:]` slice
    # applied to single-text output after generation.
    text = prefix_t or [""]
    max_gen_length += maxlen
else:
    text = [meta_token] + prefix_t
# Sample tokens one at a time until the model emits the meta token or the
# sequence reaches max_gen_length. A scalar temperature is wrapped in a
# list; the list is then cycled, indexed by (len(text) - 1) so successive
# positions use successive temperatures.
next_char = ""
temperature = temperature if isinstance(temperature, list) else [temperature]
if model_input_count(model) > 1:
    # Rebuild a single-input model mapping input[0] to output[1].
    # NOTE(review): presumably output[1] is the next-token distribution
    # head of the multi-input model -- confirm where the model is built.
    model = Model(inputs=model.input[0], outputs=model.output[1])
while len(text) < max_gen_length and next_char != meta_token:
    context = text[-maxlen:]
    encoded = textgenrnn_encode_sequence(context, vocab, maxlen)
    temp = temperature[(len(text) - 1) % len(temperature)]
    probs = model.predict(encoded, batch_size=1)[0]
    next_char = indices_char[textgenrnn_sample(probs, temp)]
    text.append(next_char)
# Turn the generated token list into a single string.
collapse_char = " " if word_level else ""
if single_text:
    # Single-text mode: discard the first maxlen tokens, i.e. the
    # sequences generated with padding.
    text = text[maxlen:]
else:
    # Multi-text mode: drop the leading meta token. Drop the trailing token
    # only when it is the meta token that ended generation; if the loop
    # stopped because it reached max_gen_length, the last token is real
    # output and must be kept.
    text = text[1:-1] if text[-1:] == [meta_token] else text[1:]
text_joined = collapse_char.join(text)
# Word-level output is joined with spaces; remove the single space on each
# side of newline/tab tokens so line breaks are not padded.
if word_level:
    # left_punct = "!%),.:;?@]_}\\n\\t""
    # right_punct = "$([_\\n\\t""
    # The doubled backslashes make the regex class [\n\t] (newline, tab).
    punct = "\\n\\t"
    text_joined = re.sub(" ([{}]) ".format(punct), r"\1", text_joined)
// text_joined = re.sub(" ([{}])".format(
// left_punct), r"\1", text_joined)
// text_joined = re.sub("([{}]) ".format(