850ec5b yangsaisai 7 years ago
5 changed files with 356 additions and 118 deletions
import logging
import time
from enum import Enum
import heapq

import numpy as np
import tensorflow as tf

from rhyme_helper import RhymeWords

logging.getLogger('tensorflow').setLevel(logging.WARNING)
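
# Sampling strategies used by CharRNNLM.sample_seq below:
#   max_prob        - greedy decoding: always take the argmax of the softmax
#   weighted_sample - draw the next token from the full softmax distribution
#   rhyme           - at rhyme_idx, restrict candidates to words that rhyme with rhyme_ref
#   select_given    - prefer tokens from a caller-supplied candidate string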
SampleType = Enum('SampleType', ('max_prob', 'weighted_sample', 'rhyme', 'select_given'))

class CharRNNLM(object):
    def __init__(self, is_training, batch_size, num_unrollings, vocab_size, w2v_model,
                 hidden_size, max_grad_norm, embedding_size, num_layers,
                 learning_rate, cell_type, dropout=0.0, input_dropout=0.0, infer=False):
        self.batch_size = batch_size
        self.num_unrollings = num_unrollings
        if infer:
            # At inference time the model consumes one token at a time.
            self.batch_size = 1
            self.num_unrollings = 1
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.max_grad_norm = max_grad_norm
        self.num_layers = num_layers
        self.embedding_size = embedding_size
        self.cell_type = cell_type
        self.dropout = dropout
        self.input_dropout = input_dropout
        self.w2v_model = w2v_model

        if embedding_size <= 0:
            # No embedding layer: feed one-hot vectors of size vocab_size instead.
            self.input_size = vocab_size
            self.input_dropout = 0.0
        else:
            self.input_size = embedding_size

        self.input_data = tf.placeholder(tf.int64, [self.batch_size, self.num_unrollings], name='inputs')
        self.targets = tf.placeholder(tf.int64, [self.batch_size, self.num_unrollings], name='targets')

        if self.cell_type == 'rnn':
            cell_fn = tf.nn.rnn_cell.BasicRNNCell
        elif self.cell_type == 'lstm':
            cell_fn = tf.nn.rnn_cell.LSTMCell
        elif self.cell_type == 'gru':
            cell_fn = tf.nn.rnn_cell.GRUCell
        else:
            raise ValueError('unsupported cell_type: %s' % self.cell_type)

        params = dict()
        if self.cell_type == 'lstm':
            params['forget_bias'] = 1.0  # 1.0 is the default value
        cell = cell_fn(self.hidden_size, **params)

        cells = [cell]
        for i in range(self.num_layers - 1):
            higher_layer_cell = cell_fn(self.hidden_size, **params)
            cells.append(higher_layer_cell)

        if is_training and self.dropout > 0:
            cells = [tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=1.0 - self.dropout)
                     for cell in cells]

        multi_cell = tf.nn.rnn_cell.MultiRNNCell(cells)

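        # State feeding: RNN/GRU cells carry a single state tensor per layer,
        # while LSTM cells carry an LSTMStateTuple (c, h), hence the two
        # placeholders per layer created below.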
        with tf.name_scope('initial_state'):
            self.zero_state = multi_cell.zero_state(self.batch_size, tf.float32)
            if self.cell_type == 'rnn' or self.cell_type == 'gru':
                self.initial_state = tuple(
                    [tf.placeholder(tf.float32,
                                    [self.batch_size, multi_cell.state_size[idx]],
                                    'initial_state_' + str(idx + 1)) for idx in range(self.num_layers)])
            elif self.cell_type == 'lstm':
                self.initial_state = tuple(
                    [tf.nn.rnn_cell.LSTMStateTuple(
                        tf.placeholder(tf.float32, [self.batch_size, multi_cell.state_size[idx][0]],
                                       'initial_lstm_state_c_' + str(idx + 1)),
                        tf.placeholder(tf.float32, [self.batch_size, multi_cell.state_size[idx][1]],
                                       'initial_lstm_state_h_' + str(idx + 1)))
                     for idx in range(self.num_layers)])

        with tf.name_scope('embedding_layer'):
            if embedding_size > 0:
                # Initialize the embedding matrix from the pre-trained word2vec vectors.
                self.embedding = tf.get_variable('word_embeddings',
                                                 initializer=self.w2v_model.vectors.astype(np.float32))
            else:
                self.embedding = tf.constant(np.eye(self.vocab_size), dtype=tf.float32)

            inputs = tf.nn.embedding_lookup(self.embedding, self.input_data)
            if is_training and self.input_dropout > 0:
                inputs = tf.nn.dropout(inputs, 1 - self.input_dropout)

        with tf.name_scope('slice_inputs'):
            # num_unrollings * (batch_size, embedding_size), the format static_rnn expects.
            sliced_inputs = [tf.squeeze(input_, [1]) for input_ in tf.split(
                axis=1, num_or_size_splits=self.num_unrollings, value=inputs)]

        # sliced_inputs: a length num_unrollings list of [batch_size, input_size] tensors.
        # inputs: a length T list of inputs, each a Tensor of shape [batch_size, input_size].
        # initial_state: an initial state for the RNN; if cell.state_size is an integer,
        #   this must be a Tensor of shape [batch_size, cell.state_size].
        # outputs: a length T list of outputs (one for each input).
        # state: the final state.
        outputs, final_state = tf.nn.static_rnn(
            cell=multi_cell,
            inputs=sliced_inputs,
            initial_state=self.initial_state)
        self.final_state = final_state

        with tf.name_scope('flatten_outputs'):
            flat_outputs = tf.reshape(tf.concat(axis=1, values=outputs), [-1, hidden_size])

        with tf.name_scope('flatten_targets'):
            flat_targets = tf.reshape(tf.concat(axis=1, values=self.targets), [-1])

        with tf.variable_scope('softmax') as sm_vs:
            softmax_w = tf.get_variable('softmax_w', [hidden_size, vocab_size])
            softmax_b = tf.get_variable('softmax_b', [vocab_size])
            self.logits = tf.matmul(flat_outputs, softmax_w) + softmax_b
            self.probs = tf.nn.softmax(self.logits)

        with tf.name_scope('loss'):
            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=self.logits, labels=flat_targets)
            self.mean_loss = tf.reduce_mean(loss)

        with tf.name_scope('loss_monitor'):
            # Running totals for the epoch-average loss; pure bookkeeping, so keep
            # them out of the trainable variable set.
            count = tf.Variable(1.0, name='count', trainable=False)
            sum_mean_loss = tf.Variable(1.0, name='sum_mean_loss', trainable=False)

            self.reset_loss_monitor = tf.group(sum_mean_loss.assign(0.0),
                                               count.assign(0.0), name='reset_loss_monitor')
            self.update_loss_monitor = tf.group(sum_mean_loss.assign(sum_mean_loss + self.mean_loss),
                                                count.assign(count + 1), name='update_loss_monitor')

            with tf.control_dependencies([self.update_loss_monitor]):
                self.average_loss = sum_mean_loss / count
                self.ppl = tf.exp(self.average_loss)

            average_loss_summary = tf.summary.scalar(
                name='average_loss', tensor=self.average_loss)
            ppl_summary = tf.summary.scalar(
                name='perplexity', tensor=self.ppl)

        self.summaries = tf.summary.merge(
            inputs=[average_loss_summary, ppl_summary], name='loss_monitor')

        self.global_step = tf.get_variable('global_step', [], trainable=False,
                                           initializer=tf.constant_initializer(0.0))

        self.learning_rate = tf.placeholder(tf.float32, [], name='learning_rate')

        if is_training:
            tvars = tf.trainable_variables()
            grads, _ = tf.clip_by_global_norm(tf.gradients(self.mean_loss, tvars), self.max_grad_norm)
            optimizer = tf.train.AdamOptimizer(self.learning_rate)
            self.train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=self.global_step)

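    # A hedged construction sketch. The hyper-parameter values are the ones
    # recorded in the saved results JSON in this commit; how the surrounding
    # training script actually wires this up is an assumption, not part of
    # this file:
    #
    #   model = CharRNNLM(is_training=True, batch_size=16, num_unrollings=64,
    #                     vocab_size=len(w2v_model.vocab), w2v_model=w2v_model,
    #                     hidden_size=128, max_grad_norm=5.0, embedding_size=128,
    #                     num_layers=2, learning_rate=0.005, cell_type='lstm')
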
    def run_epoch(self, session, batch_generator, is_training, learning_rate, verbose=0, freq=10):
        epoch_size = batch_generator.num_batches

        if verbose > 0:
            logging.info('epoch_size: %d', epoch_size)
            logging.info('data_size: %d', batch_generator.seq_length)
            logging.info('num_unrollings: %d', self.num_unrollings)
            logging.info('batch_size: %d', self.batch_size)

        if is_training:
            extra_op = self.train_op
        else:
            extra_op = tf.no_op()

        if self.cell_type in ['rnn', 'gru']:
            state = self.zero_state.eval()
        else:
            state = tuple([(np.zeros((self.batch_size, self.hidden_size)),
                            np.zeros((self.batch_size, self.hidden_size)))
                           for _ in range(self.num_layers)])

        self.reset_loss_monitor.run()
        batch_generator.reset_batch_pointer()
        start_time = time.time()
        ppl_cumsum = 0
        for step in range(epoch_size):
            x, y = batch_generator.next_batch()

            ops = [self.average_loss, self.ppl, self.final_state, extra_op,
                   self.summaries, self.global_step]

            feed_dict = {self.input_data: x, self.targets: y, self.initial_state: state,
                         self.learning_rate: learning_rate}

            results = session.run(ops, feed_dict)
            average_loss, ppl, final_state, _, summary_str, global_step = results
            state = final_state  # carry the RNN state across consecutive batches
            ppl_cumsum += ppl

            if (step + 1) % freq == 0:
                logging.info('%.1f%%, step:%d, perplexity: %.3f, speed: %.0f words per sec',
                             (step + 1) * 1.0 / epoch_size * 100, step, ppl_cumsum / (step + 1),
                             (step + 1) * self.batch_size * self.num_unrollings / (time.time() - start_time))
                logging.info('Perplexity: %.3f, speed: %.0f words per sec',
                             ppl, (step + 1) * self.batch_size * self.num_unrollings / (time.time() - start_time))

        return ppl, summary_str, global_step

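    # A hedged usage sketch for one training run (model, train_batches and
    # num_epochs are assumed to come from the surrounding training script):
    #
    #   with tf.Session() as session:
    #       session.run(tf.global_variables_initializer())
    #       for epoch in range(num_epochs):
    #           ppl, summary, step = model.run_epoch(session, train_batches,
    #                                                is_training=True,
    #                                                learning_rate=0.005)
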
    def sample_seq(self, session, length, start_text, sample_type=SampleType.max_prob,
                   given='', rhyme_ref='', rhyme_idx=0):
        if self.cell_type in ['rnn', 'gru']:
            state = self.zero_state.eval()
        else:
            state = tuple([(np.zeros((self.batch_size, self.hidden_size)),
                            np.zeros((self.batch_size, self.hidden_size)))
                           for _ in range(self.num_layers)])

        # Use start_text to warm up the RNN.
        start_text = self.check_start(start_text)
        if start_text is not None and len(start_text) > 0:
            seq = list(start_text)
            for char in start_text[:-1]:
                x = np.array([[self.w2v_model.vocab_hash[char]]])
                state = session.run(self.final_state, {self.input_data: x, self.initial_state: state})
            x = np.array([[self.w2v_model.vocab_hash[start_text[-1]]]])
        else:
            x = np.array([[np.random.randint(0, self.vocab_size)]])
            seq = []

        for i in range(length):
            state, logits = session.run([self.final_state, self.logits],
                                        {self.input_data: x, self.initial_state: state})
            # Numerically stable softmax over the single step's logits.
            unnormalized_probs = np.exp(logits[0] - np.max(logits[0]))
            probs = unnormalized_probs / np.sum(unnormalized_probs)

            if rhyme_ref and i == rhyme_idx:
                sample = self.select_rhyme(rhyme_ref, probs)
            elif sample_type == SampleType.max_prob:
                sample = np.argmax(probs)
            elif sample_type == SampleType.select_given:
                sample, given = self.select_by_given(given, probs)
            else:  # SampleType.weighted_sample
                sample = np.random.choice(self.vocab_size, 1, p=probs)[0]

            seq.append(self.w2v_model.vocab[sample])
            x = np.array([[sample]])

        return ''.join(seq)

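    # A hedged sampling example (assumes a model built with infer=True, i.e.
    # batch_size == 1, with weights restored into the default session; the
    # start text is illustrative only):
    #
    #   text = model.sample_seq(session, length=20, start_text=u'明月',
    #                           sample_type=SampleType.weighted_sample)
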
    def select_by_given(self, given, probs, max_prob=False):
        """Pick the next token from `given` if one of its characters ranks among
        the 100 most probable tokens; otherwise fall back to sampling."""
        if given:
            seq_probs = zip(probs, range(0, self.vocab_size))
            topn = heapq.nlargest(100, seq_probs, key=lambda sp: sp[0])

            for _, seq in topn:
                if self.w2v_model.vocab[seq] in given:
                    given = given.replace(self.w2v_model.vocab[seq], '')
                    return seq, given
        if max_prob:
            return np.argmax(probs), given

        return np.random.choice(self.vocab_size, 1, p=probs)[0], given

    def select_rhyme(self, rhyme_ref, probs):
        """Pick the most probable token (within the top 50) that rhymes with
        rhyme_ref; fall back to the overall argmax if none rhymes."""
        if rhyme_ref:
            rhyme_set = RhymeWords.get_rhyme_words(rhyme_ref)
            if rhyme_set:
                seq_probs = zip(probs, range(0, self.vocab_size))
                topn = heapq.nlargest(50, seq_probs, key=lambda sp: sp[0])

                for _, seq in topn:
                    if self.w2v_model.vocab[seq] in rhyme_set:
                        return seq

        return np.argmax(probs)

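    # check_start strips any '<...>' marker (e.g. <unknown>, <pad>) from the
    # prompt and drops characters missing from the word2vec vocabulary before
    # the warm-up pass.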
    def check_start(self, text):
        idx = text.find('<')
        if idx > -1:
            text = text[:idx]

        valid_text = []
        for w in text:
            if w in self.w2v_model.vocab:
                valid_text.append(w)
        return ''.join(valid_text)


import numpy as np
import word2vec


class Word2Vec():
    def __init__(self, file_path):
        self.model = word2vec.load(file_path)
        self.add_word('<unknown>')
        self.add_word('<pad>')

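    # Tokens missing from the pre-trained vocabulary (e.g. <unknown>, <pad>)
    # are appended with a small random vector of the same dimensionality as
    # the existing embeddings.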
    def add_word(self, word):
        if word not in self.model.vocab_hash:
            w_vec = np.random.uniform(-0.1, 0.1, size=self.model.vectors.shape[1])
            self.model.vocab_hash[word] = len(self.model.vocab)
            self.model.vectors = np.row_stack((self.model.vectors, w_vec))
            self.model.vocab = np.concatenate((self.model.vocab, np.array([word])))

    def get(self, word):
        if word not in self.model.vocab_hash:
            word = '<unknown>'
        return self.model[word]


if __name__ == '__main__':
    # w2vpath = './corpus/vectors_xhj_shj.bin'  # character-level vectors
    w2vpath = './corpus/vectors_qa_word.bin'  # word-level (segmented) vectors

    w2v = Word2Vec(w2vpath)
    with open('./corpus/vocab_word.txt', 'w', encoding='utf-8') as fw:
        for w in w2v.model.vocab:
            fw.write(w + '\n')

absl-py==0.6.1
alembic==1.0.3
astor==0.7.1
autopep8==1.4
backcall==0.1.0
bleach==3.0.2
certifi==2018.10.15
chardet==3.0.4
cloudpickle==0.6.1
cycler==0.10.0
Cython==0.29
dask==0.20.2
decorator==4.3.0
defusedxml==0.5.0
entrypoints==0.2.3
future==0.16.0
gast==0.2.0
grpcio==1.16.0
h5py==2.8.0
html5lib==1.0.1
idna==2.7
imageio==2.4.1
imgaug==0.2.6
ipykernel==5.1.0
ipython==7.1.1
ipython-genutils==0.2.0
ipywidgets==7.4.2
jedi==0.13.1
jieba==0.39
Jinja2==2.10
jsonschema==2.6.0
jupyter==1.0.0
jupyter-client==5.2.3
jupyter-console==6.0.0
jupyter-core==4.4.0
jupyterhub==0.8.1
jupyterlab==0.31.1
jupyterlab-launcher==0.10.5
Keras==2.2.4
Keras-Applications==1.0.6
Keras-Preprocessing==1.0.5
kiwisolver==1.0.1
Mako==1.0.7
Markdown==3.0.1
MarkupSafe==1.0
matplotlib==3.0.2
mccabe==0.6.1
mistune==0.8.4
nbconvert==5.4.0
nbformat==4.4.0
networkx==2.2
nltk==3.3
notebook==5.7.0
numpy==1.15.2
opencv-python==3.4.3.18
pamela==0.3.0
pandas==0.23.4
pandocfilters==1.4.2
parso==0.3.1
pbr==5.1.1
pexpect==4.6.0
pickleshare==0.7.5
Pillow==5.3.0
pluggy==0.7.1
prometheus-client==0.4.2
prompt-toolkit==2.0.7
protobuf==3.6.1
ptyprocess==0.6.0
pycodestyle==2.4.0
pycurl==7.43.0
pydocstyle==2.1.1
pydot==1.2.4
pyflakes==2.0.0
Pygments==2.2.0
pygobject==3.20.0
pyparsing==2.3.0
python-apt==1.1.0b1+ubuntu0.16.4.2
python-dateutil==2.7.3
python-editor==1.0.3
python-jsonrpc-server==0.0.1
python-language-server==0.21.2
python-oauth2==1.1.0
pytz==2018.5
PyWavelets==1.0.1
PyYAML==3.13
pyzmq==17.1.2
qtconsole==4.4.2
requests==2.20.0
rope==0.11.0
scikit-image==0.14.1
scikit-learn==0.20.0
scipy==1.1.0
seaborn==0.9.0
Send2Trash==1.5.0
Shapely==1.6.4.post2
simplegeneric==0.8.1
six==1.11.0
sklearn==0.0
snowballstemmer==1.2.1
SQLAlchemy==1.2.14
stevedore==1.29.0
tensorboard==1.11.0
tensorflow==1.12.0
termcolor==1.1.0
terminado==0.8.1
testpath==0.4.2
toolz==0.9.0
torch==0.4.1
torchvision==0.2.1
tornado==5.1.1
traitlets==4.3.2
urllib3==1.24.1
virtualenv==16.0.0
virtualenv-clone==0.4.0
virtualenvwrapper==4.8.2
wcwidth==0.1.7
webencodings==0.5.1
Werkzeug==0.14.1
widgetsnbextension==3.4.2
word2vec==0.10.2
xlrd==1.1.0
yapf==0.24.0

{
    "best_model": "/home/jovyan/work/results/output_poem/best_model/model-20390",
    "best_valid_ppl": 22.313692092895508,
    "latest_model": "/home/jovyan/work/results/output_poem/save_model/model-20390",
    "params": {
        "batch_size": 16,
        "cell_type": "lstm",
        "dropout": 0.0,
        "embedding_size": 128,
        "hidden_size": 128,
        "input_dropout": 0.0,
        "learning_rate": 0.005,
        "max_grad_norm": 5.0,
        "num_layers": 2,
        "num_unrollings": 64
    },
    "test_ppl": 27.004379272460938
}