import logging
import time
from enum import Enum
import heapq
import numpy as np
import tensorflow as tf
from rhyme_helper import RhymeWords
logging.getLogger('tensorflow').setLevel(logging.WARNING)
SampleType = Enum('SampleType',('max_prob', 'weighted_sample', 'rhyme','select_given'))
class CharRNNLM(object):
def __init__(self, is_training, batch_size, num_unrollings, vocab_size,w2v_model,
hidden_size, max_grad_norm, embedding_size, num_layers,
learning_rate, cell_type, dropout=0.0, input_dropout=0.0, infer=False):
self.batch_size = batch_size
self.num_unrollings = num_unrollings
if infer:
self.batch_size = 1
self.num_unrollings = 1
self.hidden_size = hidden_size
self.vocab_size = vocab_size
self.max_grad_norm = max_grad_norm
self.num_layers = num_layers
self.embedding_size = embedding_size
self.cell_type = cell_type
self.dropout = dropout
self.input_dropout = input_dropout
self.w2v_model = w2v_model
if embedding_size <= 0:
self.input_size = vocab_size
self.input_dropout = 0.0
else:
self.input_size = embedding_size
self.input_data = tf.placeholder(tf.int64, [self.batch_size, self.num_unrollings], name='inputs')
self.targets = tf.placeholder(tf.int64, [self.batch_size, self.num_unrollings], name='targets')
if self.cell_type == 'rnn':
cell_fn = tf.nn.rnn_cell.BasicRNNCell
elif self.cell_type == 'lstm':
cell_fn = tf.nn.rnn_cell.LSTMCell
elif self.cell_type == 'gru':
cell_fn = tf.nn.rnn_cell.GRUCell
params = dict()
#params = {'input_size': self.input_size}
if self.cell_type == 'lstm':
params['forget_bias'] = 1.0 # 1.0 is default value
cell = cell_fn(self.hidden_size, **params)
cells = [cell]
#params['input_size'] = self.hidden_size
for i in range(self.num_layers-1):
higher_layer_cell = cell_fn(self.hidden_size, **params)
cells.append(higher_layer_cell)
if is_training and self.dropout > 0:
cells = [tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=1.0-self.dropout) for cell in cells]
multi_cell = tf.nn.rnn_cell.MultiRNNCell(cells)
with tf.name_scope('initial_state'):
self.zero_state = multi_cell.zero_state(self.batch_size, tf.float32)
if self.cell_type == 'rnn' or self.cell_type == 'gru':
self.initial_state = tuple(
[tf.placeholder(tf.float32,
[self.batch_size, multi_cell.state_size[idx]],
'initial_state_'+str(idx+1)) for idx in range(self.num_layers)])
elif self.cell_type == 'lstm':
self.initial_state = tuple(
[tf.nn.rnn_cell.LSTMStateTuple(
tf.placeholder(tf.float32, [self.batch_size, multi_cell.state_size[idx][0]],
'initial_lstm_state_'+str(idx+1)),
tf.placeholder(tf.float32, [self.batch_size, multi_cell.state_size[idx][1]],
'initial_lstm_state_'+str(idx+1)))
for idx in range(self.num_layers)])
with tf.name_scope('embedding_layer'):
if embedding_size > 0:
# self.embedding = tf.get_variable('embedding', [self.vocab_size, self.embedding_size])
self.embedding = tf.get_variable("word_embeddings",
initializer=self.w2v_model.vectors.astype(np.float32))
else:
self.embedding = tf.constant(np.eye(self.vocab_size), dtype=tf.float32)
inputs = tf.nn.embedding_lookup(self.embedding, self.input_data)
if is_training and self.input_dropout > 0:
inputs = tf.nn.dropout(inputs, 1-self.input_dropout)
with tf.name_scope('slice_inputs'):
# num_unrollings * (batch_size, embedding_size), the format of rnn inputs.
sliced_inputs = [tf.squeeze(input_, [1]) for input_ in tf.split(
axis = 1, num_or_size_splits = self.num_unrollings, value = inputs)]
# sliced_inputs: list of shape xx
# inputs: A length T list of inputs, each a Tensor of shape [batch_size, input_size]
# initial_state: An initial state for the RNN.
# If cell.state_size is an integer, this must be a Tensor of appropriate
# type and shape [batch_size, cell.state_size]
# outputs: a length T list of outputs (one for each input), or a nested tuple of such elements.
# state: the final state
outputs, final_state = tf.nn.static_rnn(
cell = multi_cell,
inputs = sliced_inputs,
initial_state=self.initial_state)
self.final_state = final_state
with tf.name_scope('flatten_outputs'):
flat_outputs = tf.reshape(tf.concat(axis = 1, values = outputs), [-1, hidden_size])
with tf.name_scope('flatten_targets'):
flat_targets = tf.reshape(tf.concat(axis = 1, values = self.targets), [-1])
with tf.variable_scope('softmax') as sm_vs:
softmax_w = tf.get_variable('softmax_w', [hidden_size, vocab_size])
softmax_b = tf.get_variable('softmax_b', [vocab_size])
self.logits = tf.matmul(flat_outputs, softmax_w) + softmax_b
self.probs = tf.nn.softmax(self.logits)
with tf.name_scope('loss'):
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits = self.logits, labels = flat_targets)
self.mean_loss = tf.reduce_mean(loss)
with tf.name_scope('loss_montor'):
count = tf.Variable(1.0, name='count')
sum_mean_loss = tf.Variable(1.0, name='sum_mean_loss')
self.reset_loss_monitor = tf.group(sum_mean_loss.assign(0.0),
count.assign(0.0), name='reset_loss_monitor')
self.update_loss_monitor = tf.group(sum_mean_loss.assign(sum_mean_loss+self.mean_loss),
count.assign(count+1), name='update_loss_monitor')
with tf.control_dependencies([self.update_loss_monitor]):
self.average_loss = sum_mean_loss / count
self.ppl = tf.exp(self.average_loss)
average_loss_summary = tf.summary.scalar(
name = 'average loss', tensor = self.average_loss)
ppl_summary = tf.summary.scalar(
name = 'perplexity', tensor = self.ppl)
self.summaries = tf.summary.merge(
inputs = [average_loss_summary, ppl_summary], name='loss_monitor')
self.global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0.0))
# self.learning_rate = tf.constant(learning_rate)
self.learning_rate = tf.placeholder(tf.float32, [], name='learning_rate')
if is_training:
tvars = tf.trainable_variables()
grads, _ = tf.clip_by_global_norm(tf.gradients(self.mean_loss, tvars), self.max_grad_norm)
optimizer = tf.train.AdamOptimizer(self.learning_rate)
self.train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=self.global_step)
def run_epoch(self, session, batch_generator, is_training, learning_rate, verbose=0, freq=10):
epoch_size = batch_generator.num_batches
if verbose > 0:
logging.info('epoch_size: %d', epoch_size)
logging.info('data_size: %d', batch_generator.seq_length)
logging.info('num_unrollings: %d', self.num_unrollings)
logging.info('batch_size: %d', self.batch_size)
if is_training:
extra_op = self.train_op
else:
extra_op = tf.no_op()
if self.cell_type in ['rnn', 'gru']:
state = self.zero_state.eval()
else:
state = tuple([(np.zeros((self.batch_size, self.hidden_size)),
np.zeros((self.batch_size, self.hidden_size)))
for _ in range(self.num_layers)])
self.reset_loss_monitor.run()
batch_generator.reset_batch_pointer()
start_time = time.time()
ppl_cumsum = 0
for step in range(epoch_size):
x, y = batch_generator.next_batch()
ops = [self.average_loss, self.ppl, self.final_state, extra_op,
self.summaries, self.global_step]
feed_dict = {self.input_data: x, self.targets: y, self.initial_state: state,
self.learning_rate: learning_rate}
results = session.run(ops, feed_dict)
average_loss, ppl, final_state, _, summary_str, global_step = results
ppl_cumsum += ppl
# if (verbose > 0) and ((step+1) % freq == 0):
if ((step+1) % freq == 0):
logging.info('%.1f%%, step:%d, perplexity: %.3f, speed: %.0f words',
(step + 1) * 1.0 / epoch_size * 100, step, ppl_cumsum/(step+1),
(step + 1) * self.batch_size * self.num_unrollings / (time.time() - start_time))
logging.info("Perplexity: %.3f, speed: %.0f words per sec",
ppl, (step + 1) * self.batch_size * self.num_unrollings / (time.time() - start_time))
return ppl, summary_str, global_step
def sample_seq(self, session, length, start_text, sample_type= SampleType.max_prob,given='',rhyme_ref='',rhyme_idx = 0):
#state = self.zero_state.eval()
if self.cell_type in ['rnn', 'gru']:
state = self.zero_state.eval()
else:
state = tuple([(np.zeros((self.batch_size, self.hidden_size)),
np.zeros((self.batch_size, self.hidden_size)))
for _ in range(self.num_layers)])
# use start_text to warm up the RNN.
start_text = self.check_start(start_text)
if start_text is not None and len(start_text) > 0:
seq = list(start_text)
for char in start_text[:-1]:
x = np.array([[self.w2v_model.vocab_hash[char]]])
state = session.run(self.final_state, {self.input_data: x, self.initial_state: state})
x = np.array([[self.w2v_model.vocab_hash[start_text[-1]]]])
else:
x = np.array([[np.random.randint(0, self.vocab_size)]])
seq = []
for i in range(length):
state, logits = session.run([self.final_state, self.logits],
{self.input_data: x, self.initial_state: state})
unnormalized_probs = np.exp(logits[0] - np.max(logits[0]))
probs = unnormalized_probs / np.sum(unnormalized_probs)
if rhyme_ref and i == rhyme_idx :
sample = self.select_rhyme(rhyme_ref, probs)
elif sample_type == SampleType.max_prob:
sample = np.argmax(probs)
elif sample_type == SampleType.select_given:
sample, given = self.select_by_given(given, probs)
else: #SampleType.weighted_sample
sample = np.random.choice(self.vocab_size, 1, p=probs)[0]
seq.append(self.w2v_model.vocab[sample])
x = np.array([[sample]])
return ''.join(seq)
def select_by_given(self, given, probs, max_prob=False):
if given:
seq_probs = zip(probs, range(0, self.vocab_size))
topn = heapq.nlargest(100, seq_probs, key=lambda sp: sp[0])
for _, seq in topn:
if self.w2v_model.vocab[seq] in given:
given = given.replace(self.w2v_model.vocab[seq], '')
return seq, given
if max_prob:
return np.argmax(probs), given
return np.random.choice(self.vocab_size, 1, p=probs)[0],given
def select_rhyme(self, rhyme_ref, probs):
if rhyme_ref:
rhyme_set = RhymeWords.get_rhyme_words(rhyme_ref)
if rhyme_set:
seq_probs = zip(probs, range(0, self.vocab_size))
topn = heapq.nlargest(50, seq_probs, key=lambda sp: sp[0])
for _, seq in topn:
if self.w2v_model.vocab[seq] in rhyme_set:
return seq
return np.argmax(probs)
def check_start(self, text):
idx = text.find('<')
if idx > -1:
text = text[:idx]
valid_text = []
for w in text:
if w in self.w2v_model.vocab:
valid_text.append(w)
return ''.join(valid_text)