import logging
import time
from enum import Enum
import heapq

import numpy as np
import tensorflow as tf

from rhyme_helper import RhymeWords

logging.getLogger('tensorflow').setLevel(logging.WARNING)

SampleType = Enum('SampleType', ('max_prob', 'weighted_sample', 'rhyme', 'select_given'))


class CharRNNLM(object):
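    """Character-level RNN language model.

    Builds a multi-layer RNN/GRU/LSTM graph over character indices, optionally
    initialized from a pre-trained word2vec embedding, and exposes run_epoch
    for training/evaluation and sample_seq for generation.
    """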
    def __init__(self, is_training, batch_size, num_unrollings, vocab_size, w2v_model,
                 hidden_size, max_grad_norm, embedding_size, num_layers,
                 learning_rate, cell_type, dropout=0.0, input_dropout=0.0, infer=False):
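        """Build the graph.

        Args:
            is_training: build gradient clipping, the Adam optimizer and train_op when True.
            batch_size, num_unrollings: minibatch shape; both forced to 1 when infer=True.
            vocab_size: number of characters in the vocabulary.
            w2v_model: word2vec-style model exposing `vectors`, `vocab` and `vocab_hash`.
            hidden_size, num_layers, cell_type: RNN size, depth and cell ('rnn', 'gru' or 'lstm').
            embedding_size: input embedding width; a value <= 0 means one-hot inputs.
            max_grad_norm: global-norm gradient clipping threshold.
            learning_rate: unused here; the rate is fed through the learning_rate placeholder.
            dropout, input_dropout: dropout probabilities applied only during training.
        """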
        self.batch_size = batch_size
        self.num_unrollings = num_unrollings
        if infer:
            self.batch_size = 1
            self.num_unrollings = 1
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.max_grad_norm = max_grad_norm
        self.num_layers = num_layers
        self.embedding_size = embedding_size
        self.cell_type = cell_type
        self.dropout = dropout
        self.input_dropout = input_dropout
        self.w2v_model = w2v_model

        if embedding_size <= 0:
            self.input_size = vocab_size
            self.input_dropout = 0.0
        else:
            self.input_size = embedding_size

        self.input_data = tf.placeholder(tf.int64, [self.batch_size, self.num_unrollings], name='inputs')
        self.targets = tf.placeholder(tf.int64, [self.batch_size, self.num_unrollings], name='targets')

        if self.cell_type == 'rnn':
            cell_fn = tf.nn.rnn_cell.BasicRNNCell
        elif self.cell_type == 'lstm':
            cell_fn = tf.nn.rnn_cell.LSTMCell
        elif self.cell_type == 'gru':
            cell_fn = tf.nn.rnn_cell.GRUCell
        else:
            raise ValueError('unsupported cell_type: %s' % self.cell_type)

        params = dict()
        #params = {'input_size': self.input_size}
        if self.cell_type == 'lstm':
            params['forget_bias'] = 1.0  # 1.0 is default value
        cell = cell_fn(self.hidden_size, **params)

        cells = [cell]
        #params['input_size'] = self.hidden_size
        for _ in range(self.num_layers - 1):
            higher_layer_cell = cell_fn(self.hidden_size, **params)
            cells.append(higher_layer_cell)

        if is_training and self.dropout > 0:
            cells = [tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=1.0 - self.dropout) for cell in cells]

        multi_cell = tf.nn.rnn_cell.MultiRNNCell(cells)

        with tf.name_scope('initial_state'):
            self.zero_state = multi_cell.zero_state(self.batch_size, tf.float32)
            if self.cell_type == 'rnn' or self.cell_type == 'gru':
                self.initial_state = tuple(
                    [tf.placeholder(tf.float32,
                                    [self.batch_size, multi_cell.state_size[idx]],
                                    'initial_state_' + str(idx + 1)) for idx in range(self.num_layers)])
            elif self.cell_type == 'lstm':
                self.initial_state = tuple(
                    [tf.nn.rnn_cell.LSTMStateTuple(
                        tf.placeholder(tf.float32, [self.batch_size, multi_cell.state_size[idx][0]],
                                       'initial_lstm_state_' + str(idx + 1)),
                        tf.placeholder(tf.float32, [self.batch_size, multi_cell.state_size[idx][1]],
                                       'initial_lstm_state_' + str(idx + 1)))
                     for idx in range(self.num_layers)])

        with tf.name_scope('embedding_layer'):
            if embedding_size > 0:
                # self.embedding = tf.get_variable('embedding', [self.vocab_size, self.embedding_size])
                self.embedding = tf.get_variable('word_embeddings',
                                                 initializer=self.w2v_model.vectors.astype(np.float32))
            else:
                self.embedding = tf.constant(np.eye(self.vocab_size), dtype=tf.float32)

            inputs = tf.nn.embedding_lookup(self.embedding, self.input_data)
            if is_training and self.input_dropout > 0:
                inputs = tf.nn.dropout(inputs, 1 - self.input_dropout)

        with tf.name_scope('slice_inputs'):
            # num_unrollings * (batch_size, embedding_size), the format of rnn inputs.
            sliced_inputs = [tf.squeeze(input_, [1]) for input_ in tf.split(
                axis=1, num_or_size_splits=self.num_unrollings, value=inputs)]

        # sliced_inputs: a length num_unrollings list of tensors, each of shape [batch_size, input_size].
        # inputs: a length T list of inputs, each a Tensor of shape [batch_size, input_size].
        # initial_state: an initial state for the RNN.
        #     If cell.state_size is an integer, this must be a Tensor of appropriate
        #     type and shape [batch_size, cell.state_size].
        # outputs: a length T list of outputs (one for each input), or a nested tuple of such elements.
        # state: the final state.
        outputs, final_state = tf.nn.static_rnn(
            cell=multi_cell,
            inputs=sliced_inputs,
            initial_state=self.initial_state)
        self.final_state = final_state

        with tf.name_scope('flatten_outputs'):
            flat_outputs = tf.reshape(tf.concat(axis=1, values=outputs), [-1, hidden_size])

        with tf.name_scope('flatten_targets'):
            flat_targets = tf.reshape(tf.concat(axis=1, values=self.targets), [-1])

        with tf.variable_scope('softmax') as sm_vs:
            softmax_w = tf.get_variable('softmax_w', [hidden_size, vocab_size])
            softmax_b = tf.get_variable('softmax_b', [vocab_size])
            self.logits = tf.matmul(flat_outputs, softmax_w) + softmax_b
            self.probs = tf.nn.softmax(self.logits)

        with tf.name_scope('loss'):
            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=self.logits, labels=flat_targets)
            self.mean_loss = tf.reduce_mean(loss)

        with tf.name_scope('loss_monitor'):
            count = tf.Variable(1.0, name='count')
            sum_mean_loss = tf.Variable(1.0, name='sum_mean_loss')

            self.reset_loss_monitor = tf.group(sum_mean_loss.assign(0.0),
                                               count.assign(0.0), name='reset_loss_monitor')
            self.update_loss_monitor = tf.group(sum_mean_loss.assign(sum_mean_loss + self.mean_loss),
                                                count.assign(count + 1), name='update_loss_monitor')

            with tf.control_dependencies([self.update_loss_monitor]):
                self.average_loss = sum_mean_loss / count
                self.ppl = tf.exp(self.average_loss)

            average_loss_summary = tf.summary.scalar(
                name='average_loss', tensor=self.average_loss)
            ppl_summary = tf.summary.scalar(
                name='perplexity', tensor=self.ppl)

        self.summaries = tf.summary.merge(
            inputs=[average_loss_summary, ppl_summary], name='loss_monitor')

        self.global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0.0))

        # self.learning_rate = tf.constant(learning_rate)
        self.learning_rate = tf.placeholder(tf.float32, [], name='learning_rate')

        if is_training:
            tvars = tf.trainable_variables()
            grads, _ = tf.clip_by_global_norm(tf.gradients(self.mean_loss, tvars), self.max_grad_norm)
            optimizer = tf.train.AdamOptimizer(self.learning_rate)
            self.train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=self.global_step)

    def run_epoch(self, session, batch_generator, is_training, learning_rate, verbose=0, freq=10):
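        """Run one full pass over the batches produced by batch_generator.

        Returns the monitored perplexity, the merged loss summaries of the
        last batch, and the current global step.
        """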
        epoch_size = batch_generator.num_batches

        if verbose > 0:
            logging.info('epoch_size: %d', epoch_size)
            logging.info('data_size: %d', batch_generator.seq_length)
            logging.info('num_unrollings: %d', self.num_unrollings)
            logging.info('batch_size: %d', self.batch_size)

        if is_training:
            extra_op = self.train_op
        else:
            extra_op = tf.no_op()

        if self.cell_type in ['rnn', 'gru']:
            # zero_state is a tuple of per-layer tensors, so evaluate it through the session.
            state = session.run(self.zero_state)
        else:
            state = tuple([(np.zeros((self.batch_size, self.hidden_size)),
                            np.zeros((self.batch_size, self.hidden_size)))
                           for _ in range(self.num_layers)])

        session.run(self.reset_loss_monitor)
        batch_generator.reset_batch_pointer()
        start_time = time.time()
        ppl_cumsum = 0
        for step in range(epoch_size):
            x, y = batch_generator.next_batch()

            ops = [self.average_loss, self.ppl, self.final_state, extra_op,
                   self.summaries, self.global_step]

            feed_dict = {self.input_data: x, self.targets: y, self.initial_state: state,
                         self.learning_rate: learning_rate}

            results = session.run(ops, feed_dict)
            average_loss, ppl, final_state, _, summary_str, global_step = results
            ppl_cumsum += ppl

            # if (verbose > 0) and ((step+1) % freq == 0):
            if (step + 1) % freq == 0:
                logging.info('%.1f%%, step: %d, perplexity: %.3f, speed: %.0f words per sec',
                             (step + 1) * 1.0 / epoch_size * 100, step, ppl_cumsum / (step + 1),
                             (step + 1) * self.batch_size * self.num_unrollings / (time.time() - start_time))
                logging.info('Perplexity: %.3f, speed: %.0f words per sec',
                             ppl, (step + 1) * self.batch_size * self.num_unrollings / (time.time() - start_time))

        return ppl, summary_str, global_step

    def sample_seq(self, session, length, start_text, sample_type=SampleType.max_prob,
                   given='', rhyme_ref='', rhyme_idx=0):
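        """Sample a sequence of `length` characters from the model.

        start_text (filtered by check_start) warms up the RNN state.
        sample_type chooses greedy, weighted, or given-constrained sampling;
        rhyme_ref/rhyme_idx force the character emitted at position rhyme_idx
        to rhyme with rhyme_ref.
        """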
        #state = self.zero_state.eval()
        if self.cell_type in ['rnn', 'gru']:
            state = session.run(self.zero_state)
        else:
            state = tuple([(np.zeros((self.batch_size, self.hidden_size)),
                            np.zeros((self.batch_size, self.hidden_size)))
                           for _ in range(self.num_layers)])

        # Use start_text to warm up the RNN.
        start_text = self.check_start(start_text)
        if start_text is not None and len(start_text) > 0:
            seq = list(start_text)
            for char in start_text[:-1]:
                x = np.array([[self.w2v_model.vocab_hash[char]]])
                state = session.run(self.final_state, {self.input_data: x, self.initial_state: state})
            x = np.array([[self.w2v_model.vocab_hash[start_text[-1]]]])
        else:
            x = np.array([[np.random.randint(0, self.vocab_size)]])
            seq = []

        for i in range(length):
            state, logits = session.run([self.final_state, self.logits],
                                        {self.input_data: x, self.initial_state: state})
            unnormalized_probs = np.exp(logits[0] - np.max(logits[0]))
            probs = unnormalized_probs / np.sum(unnormalized_probs)

            if rhyme_ref and i == rhyme_idx:
                sample = self.select_rhyme(rhyme_ref, probs)
            elif sample_type == SampleType.max_prob:
                sample = np.argmax(probs)
            elif sample_type == SampleType.select_given:
                sample, given = self.select_by_given(given, probs)
            else:  # SampleType.weighted_sample
                sample = np.random.choice(self.vocab_size, 1, p=probs)[0]

            seq.append(self.w2v_model.vocab[sample])
            x = np.array([[sample]])

        return ''.join(seq)

    def select_by_given(self, given, probs, max_prob=False):
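        """Pick the next character, preferring characters still left in `given`.

        Scans the 100 most probable characters and returns the first one found
        in `given` (removing it from `given`); otherwise falls back to argmax
        when max_prob is set, or to weighted sampling.
        """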
        if given:
            seq_probs = zip(probs, range(0, self.vocab_size))
            topn = heapq.nlargest(100, seq_probs, key=lambda sp: sp[0])

            for _, seq in topn:
                if self.w2v_model.vocab[seq] in given:
                    given = given.replace(self.w2v_model.vocab[seq], '')
                    return seq, given
        if max_prob:
            return np.argmax(probs), given

        return np.random.choice(self.vocab_size, 1, p=probs)[0], given

    def select_rhyme(self, rhyme_ref, probs):
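        """Pick a probable character that rhymes with rhyme_ref.

        Scans the 50 most probable characters for one in the rhyme set of
        rhyme_ref (via RhymeWords); falls back to argmax when none is found.
        """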
        if rhyme_ref:
            rhyme_set = RhymeWords.get_rhyme_words(rhyme_ref)
            if rhyme_set:
                seq_probs = zip(probs, range(0, self.vocab_size))
                topn = heapq.nlargest(50, seq_probs, key=lambda sp: sp[0])

                for _, seq in topn:
                    if self.w2v_model.vocab[seq] in rhyme_set:
                        return seq

        return np.argmax(probs)

    def check_start(self, text):
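        """Truncate `text` at the first '<' and drop characters missing from the vocabulary."""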
        idx = text.find('<')
        if idx > -1:
            text = text[:idx]

        valid_text = []
        for w in text:
            if w in self.w2v_model.vocab:
                valid_text.append(w)
        return ''.join(valid_text)
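
# ---------------------------------------------------------------------------
# Illustrative usage sketch (assumptions, not part of the original module):
# `load_w2v` and 'vectors.bin' are hypothetical stand-ins for whatever
# word2vec wrapper this project actually provides; the wrapper must expose
# .vectors, .vocab and .vocab_hash as used above.
#
#   w2v = load_w2v('vectors.bin')
#   model = CharRNNLM(is_training=False, batch_size=1, num_unrollings=1,
#                     vocab_size=len(w2v.vocab), w2v_model=w2v,
#                     hidden_size=128, max_grad_norm=5.0, embedding_size=128,
#                     num_layers=2, learning_rate=1e-3, cell_type='lstm',
#                     infer=True)
#   with tf.Session() as sess:
#       sess.run(tf.global_variables_initializer())  # or restore a checkpoint
#       print(model.sample_seq(sess, length=20, start_text='',
#                              sample_type=SampleType.weighted_sample))
# ---------------------------------------------------------------------------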