8c8dc33
yangsaisai 7 years ago
9 changed file(s) with 56 addition(s) and 176 deletion(s). Raw diff Collapse all Expand all
232232 probs = unnormalized_probs / np.sum(unnormalized_probs)
233233
234234 if rhyme_ref and i == rhyme_idx :
235 sample = self.select_rhyme(rhyme_ref,probs)
235 sample = self.select_rhyme(rhyme_ref, probs)
236236 elif sample_type == SampleType.max_prob:
237237 sample = np.argmax(probs)
238238 elif sample_type == SampleType.select_given:
239 sample,given = self.select_by_given(given,probs)
239 sample, given = self.select_by_given(given, probs)
240240 else: #SampleType.weighted_sample
241241 sample = np.random.choice(self.vocab_size, 1, p=probs)[0]
242242
245245
246246 return ''.join(seq)
247247
248 def select_by_given(self,given,probs,max_prob = False):
248 def select_by_given(self, given, probs, max_prob=False):
249249 if given:
250 seq_probs = zip(probs,range(0,self.vocab_size))
251 topn = heapq.nlargest(100,seq_probs,key=lambda sp :sp[0])
252
253 for _,seq in topn:
250 seq_probs = zip(probs, range(0, self.vocab_size))
251 topn = heapq.nlargest(100, seq_probs, key=lambda sp: sp[0])
252
253 for _, seq in topn:
254254 if self.w2v_model.vocab[seq] in given:
255 given = given.replace(self.w2v_model.vocab[seq],'')
256 return seq,given
255 given = given.replace(self.w2v_model.vocab[seq], '')
256 return seq, given
257257 if max_prob:
258 return np.argmax(probs),given
258 return np.argmax(probs), given
259259
260260 return np.random.choice(self.vocab_size, 1, p=probs)[0],given
261261
262
263 def select_rhyme(self,rhyme_ref,probs):
262 def select_rhyme(self, rhyme_ref, probs):
264263 if rhyme_ref:
265264 rhyme_set = RhymeWords.get_rhyme_words(rhyme_ref)
266265 if rhyme_set:
267 seq_probs = zip(probs,range(0,self.vocab_size))
268 topn = heapq.nlargest(50,seq_probs,key=lambda sp :sp[0])
269
270 for _,seq in topn:
266 seq_probs = zip(probs, range(0, self.vocab_size))
267 topn = heapq.nlargest(50, seq_probs, key=lambda sp: sp[0])
268
269 for _, seq in topn:
271270 if self.w2v_model.vocab[seq] in rhyme_set:
272271 return seq
273272
274 return np.argmax(probs)
275
276 def check_start(self,text):
273 return np.argmax(probs)
274
275 def check_start(self, text):
277276 idx = text.find('<')
278277 if idx > -1:
279278 text = text[:idx]
22 import sys
33
44 # Import necessary packages
5 from modules import json_parser
65 from modules import json_parser
76 from modules import Client
87 from write_poem import start_model
4444 self.sess = tf.Session()
4545 w2v_vocab_size = len(self.w2v.model.vocab)
4646 with tf.name_scope('evaluation'):
47 self.model = CharRNNLM(is_training=False,w2v_model = self.w2v.model,vocab_size=w2v_vocab_size, infer=True, **params)
47 self.model = CharRNNLM(is_training=False, w2v_model = self.w2v.model, vocab_size=w2v_vocab_size, infer=True, **params)
4848 saver = tf.train.Saver(name='model_saver')
4949 saver.restore(self.sess, best_model)
5050
147147
148148 return sample[1:]
149149
150 def cangtou(self,given_text):
150 def cangtou(self, given_text):
151151 '''
152152 藏头诗
153153 Returns:
180180 rhyme_seq += 1
181181
182182 sample = self.model.sample_seq(self.sess, self.args.length, start,
183 sample_type= SampleType.max_prob,rhyme_ref =rhyme_ref_word,rhyme_idx = rhyme_seq )
183 sample_type= SampleType.max_prob, rhyme_ref =rhyme_ref_word, rhyme_idx = rhyme_seq )
184184
185185 # 暂时屏蔽
186186 # print('Sampled text is:\n\n%s' % sample)
188188 sample = sample[before_idx:]
189189 idx1 = sample.find(',')
190190 idx2 = sample.find('。')
191 min_idx = min(idx1,idx2)
191 min_idx = min(idx1, idx2)
192192
193193 if min_idx == -1:
194194 if idx1 > -1 :
195195 min_idx = idx1
196 else: min_idx =idx2
196 else:
197 min_idx = idx2
197198 if min_idx > 0:
198199 # last_sample.append(sample[:min_idx + 1])
199 start ='{}{}'.format(start, sample[:min_idx + 1])
200 start = '{}{}'.format(start, sample[:min_idx + 1])
200201
201202 if i == 1:
202203 rhyme_seq = min_idx - 1
206207 # print('last_sample text is:\n\n%s' % start)
207208
208209 return WritePoem.assemble(start)
210
209211
210212 def start_model():
211213 now = int(time.time())
231231 unnormalized_probs = np.exp(logits[0] - np.max(logits[0]))
232232 probs = unnormalized_probs / np.sum(unnormalized_probs)
233233
234 if rhyme_ref and i == rhyme_idx :
235 sample = self.select_rhyme(rhyme_ref,probs)
234 if rhyme_ref and i == rhyme_idx:
235 sample = self.select_rhyme(rhyme_ref, probs)
236236 elif sample_type == SampleType.max_prob:
237237 sample = np.argmax(probs)
238238 elif sample_type == SampleType.select_given:
239 sample,given = self.select_by_given(given,probs)
239 sample, given = self.select_by_given(given, probs)
240240 else: #SampleType.weighted_sample
241241 sample = np.random.choice(self.vocab_size, 1, p=probs)[0]
242242
245245
246246 return ''.join(seq)
247247
248 def select_by_given(self,given,probs,max_prob = False):
248 def select_by_given(self, given, probs, max_prob=False):
249249 if given:
250 seq_probs = zip(probs,range(0,self.vocab_size))
251 topn = heapq.nlargest(100,seq_probs,key=lambda sp :sp[0])
252
253 for _,seq in topn:
250 seq_probs = zip(probs, range(0, self.vocab_size))
251 topn = heapq.nlargest(100, seq_probs, key=lambda sp: sp[0])
252
253 for _, seq in topn:
254254 if self.w2v_model.vocab[seq] in given:
255 given = given.replace(self.w2v_model.vocab[seq],'')
256 return seq,given
255 given = given.replace(self.w2v_model.vocab[seq], '')
256 return seq, given
257257 if max_prob:
258 return np.argmax(probs),given
258 return np.argmax(probs), given
259259
260260 return np.random.choice(self.vocab_size, 1, p=probs)[0],given
261261
262
263 def select_rhyme(self,rhyme_ref,probs):
262 def select_rhyme(self, rhyme_ref, probs):
264263 if rhyme_ref:
265264 rhyme_set = RhymeWords.get_rhyme_words(rhyme_ref)
266265 if rhyme_set:
267 seq_probs = zip(probs,range(0,self.vocab_size))
268 topn = heapq.nlargest(50,seq_probs,key=lambda sp :sp[0])
269
270 for _,seq in topn:
266 seq_probs = zip(probs, range(0, self.vocab_size))
267 topn = heapq.nlargest(50, seq_probs, key=lambda sp: sp[0])
268
269 for _, seq in topn:
271270 if self.w2v_model.vocab[seq] in rhyme_set:
272271 return seq
273272
274 return np.argmax(probs)
275
276 def check_start(self,text):
273 return np.argmax(probs)
274
275 def check_start(self, text):
277276 idx = text.find('<')
278277 if idx > -1:
279278 text = text[:idx]
+0
-122
faas_requirements.txt less more
0 absl-py==0.6.1
1 alembic==1.0.3
2 astor==0.7.1
3 autopep8==1.4
4 backcall==0.1.0
5 bleach==3.0.2
6 certifi==2018.10.15
7 chardet==3.0.4
8 cloudpickle==0.6.1
9 cycler==0.10.0
10 Cython==0.29
11 dask==0.20.2
12 decorator==4.3.0
13 defusedxml==0.5.0
14 entrypoints==0.2.3
15 future==0.16.0
16 gast==0.2.0
17 grpcio==1.16.0
18 h5py==2.8.0
19 html5lib==1.0.1
20 idna==2.7
21 imageio==2.4.1
22 imgaug==0.2.6
23 ipykernel==5.1.0
24 ipython==7.1.1
25 ipython-genutils==0.2.0
26 ipywidgets==7.4.2
27 jedi==0.13.1
28 jieba==0.39
29 Jinja2==2.10
30 jsonschema==2.6.0
31 jupyter==1.0.0
32 jupyter-client==5.2.3
33 jupyter-console==6.0.0
34 jupyter-core==4.4.0
35 jupyterhub==0.8.1
36 jupyterlab==0.31.1
37 jupyterlab-launcher==0.10.5
38 Keras==2.2.4
39 Keras-Applications==1.0.6
40 Keras-Preprocessing==1.0.5
41 kiwisolver==1.0.1
42 Mako==1.0.7
43 Markdown==3.0.1
44 MarkupSafe==1.0
45 matplotlib==3.0.2
46 mccabe==0.6.1
47 mistune==0.8.4
48 nbconvert==5.4.0
49 nbformat==4.4.0
50 networkx==2.2
51 nltk==3.3
52 notebook==5.7.0
53 numpy==1.15.2
54 opencv-python==3.4.3.18
55 pamela==0.3.0
56 pandas==0.23.4
57 pandocfilters==1.4.2
58 parso==0.3.1
59 pbr==5.1.1
60 pexpect==4.6.0
61 pickleshare==0.7.5
62 Pillow==5.3.0
63 pluggy==0.7.1
64 prometheus-client==0.4.2
65 prompt-toolkit==2.0.7
66 protobuf==3.6.1
67 ptyprocess==0.6.0
68 pycodestyle==2.4.0
69 pycurl==7.43.0
70 pydocstyle==2.1.1
71 pydot==1.2.4
72 pyflakes==2.0.0
73 Pygments==2.2.0
74 pygobject==3.20.0
75 pyparsing==2.3.0
76 python-apt==1.1.0b1+ubuntu0.16.4.2
77 python-dateutil==2.7.3
78 python-editor==1.0.3
79 python-jsonrpc-server==0.0.1
80 python-language-server==0.21.2
81 python-oauth2==1.1.0
82 pytz==2018.5
83 PyWavelets==1.0.1
84 PyYAML==3.13
85 pyzmq==17.1.2
86 qtconsole==4.4.2
87 requests==2.20.0
88 rope==0.11.0
89 scikit-image==0.14.1
90 scikit-learn==0.20.0
91 scipy==1.1.0
92 seaborn==0.9.0
93 Send2Trash==1.5.0
94 Shapely==1.6.4.post2
95 simplegeneric==0.8.1
96 six==1.11.0
97 sklearn==0.0
98 snowballstemmer==1.2.1
99 SQLAlchemy==1.2.14
100 stevedore==1.29.0
101 tensorboard==1.11.0
102 tensorflow==1.12.0
103 termcolor==1.1.0
104 terminado==0.8.1
105 testpath==0.4.2
106 toolz==0.9.0
107 torch==0.4.1
108 torchvision==0.2.1
109 tornado==5.1.1
110 traitlets==4.3.2
111 urllib3==1.24.1
112 virtualenv==16.0.0
113 virtualenv-clone==0.4.0
114 virtualenvwrapper==4.8.2
115 wcwidth==0.1.7
116 webencodings==0.5.1
117 Werkzeug==0.14.1
118 widgetsnbextension==3.4.2
119 word2vec==0.10.2
120 xlrd==1.1.0
121 yapf==0.24.0
22 import sys
33
44 # Import necessary packages
5 from modules import json_parser
65 from modules import json_parser
76 from modules import Client
87 from write_poem import start_model
0 Please store your tensorboard results here
2222 args.save_best_model = os.path.join(args.output_dir, 'best_model/model')
2323 # args.tb_log_dir = os.path.join(args.output_dir, 'tensorboard_log/')
2424 timestamp = str(int(time.time()))
25 args.tb_log_dir = os.path.abspath(os.path.join(args.output_dir, "tensorboard_log", timestamp))
25 # args.tb_log_dir = os.path.abspath(os.path.join(args.output_dir, "tensorboard_log", timestamp))
26 args.tb_log_dir = os.path.abspath(os.path.join('./results/tb_results', "tensorboard_log", timestamp))
2627 print("Writing to {}\n".format(args.tb_log_dir))
2728
2829 # Create necessary directories.
4444 self.sess = tf.Session()
4545 w2v_vocab_size = len(self.w2v.model.vocab)
4646 with tf.name_scope('evaluation'):
47 self.model = CharRNNLM(is_training=False,w2v_model = self.w2v.model,vocab_size=w2v_vocab_size, infer=True, **params)
47 self.model = CharRNNLM(is_training=False, w2v_model = self.w2v.model, vocab_size=w2v_vocab_size, infer=True, **params)
4848 saver = tf.train.Saver(name='model_saver')
4949 saver.restore(self.sess, best_model)
5050
147147
148148 return sample[1:]
149149
150 def cangtou(self,given_text):
150 def cangtou(self, given_text):
151151 '''
152152 藏头诗
153153 Returns:
180180 rhyme_seq += 1
181181
182182 sample = self.model.sample_seq(self.sess, self.args.length, start,
183 sample_type= SampleType.max_prob,rhyme_ref =rhyme_ref_word,rhyme_idx = rhyme_seq )
183 sample_type= SampleType.max_prob, rhyme_ref =rhyme_ref_word, rhyme_idx = rhyme_seq )
184184
185185 # 暂时屏蔽
186186 # print('Sampled text is:\n\n%s' % sample)
188188 sample = sample[before_idx:]
189189 idx1 = sample.find(',')
190190 idx2 = sample.find('。')
191 min_idx = min(idx1,idx2)
191 min_idx = min(idx1, idx2)
192192
193193 if min_idx == -1:
194194 if idx1 > -1 :
195195 min_idx = idx1
196 else: min_idx =idx2
196 else:
197 min_idx = idx2
197198 if min_idx > 0:
198199 # last_sample.append(sample[:min_idx + 1])
199 start ='{}{}'.format(start, sample[:min_idx + 1])
200 start = '{}{}'.format(start, sample[:min_idx + 1])
200201
201202 if i == 1:
202203 rhyme_seq = min_idx - 1
206207 # print('last_sample text is:\n\n%s' % start)
207208
208209 return WritePoem.assemble(start)
210
209211
210212 def start_model():
211213 now = int(time.time())