yangsaisai
7 years ago
| 232 | 232 | probs = unnormalized_probs / np.sum(unnormalized_probs) |
| 233 | 233 | |
| 234 | 234 | if rhyme_ref and i == rhyme_idx : |
| 235 | sample = self.select_rhyme(rhyme_ref,probs) | |
| 235 | sample = self.select_rhyme(rhyme_ref, probs) | |
| 236 | 236 | elif sample_type == SampleType.max_prob: |
| 237 | 237 | sample = np.argmax(probs) |
| 238 | 238 | elif sample_type == SampleType.select_given: |
| 239 | sample,given = self.select_by_given(given,probs) | |
| 239 | sample, given = self.select_by_given(given, probs) | |
| 240 | 240 | else: #SampleType.weighted_sample |
| 241 | 241 | sample = np.random.choice(self.vocab_size, 1, p=probs)[0] |
| 242 | 242 | |
| 245 | 245 | |
| 246 | 246 | return ''.join(seq) |
| 247 | 247 | |
| 248 | def select_by_given(self,given,probs,max_prob = False): | |
| 248 | def select_by_given(self, given, probs, max_prob=False): | |
| 249 | 249 | if given: |
| 250 | seq_probs = zip(probs,range(0,self.vocab_size)) | |
| 251 | topn = heapq.nlargest(100,seq_probs,key=lambda sp :sp[0]) | |
| 252 | ||
| 253 | for _,seq in topn: | |
| 250 | seq_probs = zip(probs, range(0, self.vocab_size)) | |
| 251 | topn = heapq.nlargest(100, seq_probs, key=lambda sp: sp[0]) | |
| 252 | ||
| 253 | for _, seq in topn: | |
| 254 | 254 | if self.w2v_model.vocab[seq] in given: |
| 255 | given = given.replace(self.w2v_model.vocab[seq],'') | |
| 256 | return seq,given | |
| 255 | given = given.replace(self.w2v_model.vocab[seq], '') | |
| 256 | return seq, given | |
| 257 | 257 | if max_prob: |
| 258 | return np.argmax(probs),given | |
| 258 | return np.argmax(probs), given | |
| 259 | 259 | |
| 260 | 260 | return np.random.choice(self.vocab_size, 1, p=probs)[0],given |
| 261 | 261 | |
| 262 | ||
| 263 | def select_rhyme(self,rhyme_ref,probs): | |
| 262 | def select_rhyme(self, rhyme_ref, probs): | |
| 264 | 263 | if rhyme_ref: |
| 265 | 264 | rhyme_set = RhymeWords.get_rhyme_words(rhyme_ref) |
| 266 | 265 | if rhyme_set: |
| 267 | seq_probs = zip(probs,range(0,self.vocab_size)) | |
| 268 | topn = heapq.nlargest(50,seq_probs,key=lambda sp :sp[0]) | |
| 269 | ||
| 270 | for _,seq in topn: | |
| 266 | seq_probs = zip(probs, range(0, self.vocab_size)) | |
| 267 | topn = heapq.nlargest(50, seq_probs, key=lambda sp: sp[0]) | |
| 268 | ||
| 269 | for _, seq in topn: | |
| 271 | 270 | if self.w2v_model.vocab[seq] in rhyme_set: |
| 272 | 271 | return seq |
| 273 | 272 | |
| 274 | return np.argmax(probs) | |
| 275 | ||
| 276 | def check_start(self,text): | |
| 273 | return np.argmax(probs) | |
| 274 | ||
| 275 | def check_start(self, text): | |
| 277 | 276 | idx = text.find('<') |
| 278 | 277 | if idx > -1: |
| 279 | 278 | text = text[:idx] |
| 2 | 2 | import sys |
| 3 | 3 | |
| 4 | 4 | # Import necessary packages |
| 5 | from modules import json_parser | |
| 6 | 5 | from modules import json_parser |
| 7 | 6 | from modules import Client |
| 8 | 7 | from write_poem import start_model |
| 44 | 44 | self.sess = tf.Session() |
| 45 | 45 | w2v_vocab_size = len(self.w2v.model.vocab) |
| 46 | 46 | with tf.name_scope('evaluation'): |
| 47 | self.model = CharRNNLM(is_training=False,w2v_model = self.w2v.model,vocab_size=w2v_vocab_size, infer=True, **params) | |
| 47 | self.model = CharRNNLM(is_training=False, w2v_model = self.w2v.model, vocab_size=w2v_vocab_size, infer=True, **params) | |
| 48 | 48 | saver = tf.train.Saver(name='model_saver') |
| 49 | 49 | saver.restore(self.sess, best_model) |
| 50 | 50 | |
| 147 | 147 | |
| 148 | 148 | return sample[1:] |
| 149 | 149 | |
| 150 | def cangtou(self,given_text): | |
| 150 | def cangtou(self, given_text): | |
| 151 | 151 | ''' |
| 152 | 152 | 藏头诗 |
| 153 | 153 | Returns: |
| 180 | 180 | rhyme_seq += 1 |
| 181 | 181 | |
| 182 | 182 | sample = self.model.sample_seq(self.sess, self.args.length, start, |
| 183 | sample_type= SampleType.max_prob,rhyme_ref =rhyme_ref_word,rhyme_idx = rhyme_seq ) | |
| 183 | sample_type= SampleType.max_prob, rhyme_ref =rhyme_ref_word, rhyme_idx = rhyme_seq ) | |
| 184 | 184 | |
| 185 | 185 | # 暂时屏蔽 |
| 186 | 186 | # print('Sampled text is:\n\n%s' % sample) |
| 188 | 188 | sample = sample[before_idx:] |
| 189 | 189 | idx1 = sample.find(',') |
| 190 | 190 | idx2 = sample.find('。') |
| 191 | min_idx = min(idx1,idx2) | |
| 191 | min_idx = min(idx1, idx2) | |
| 192 | 192 | |
| 193 | 193 | if min_idx == -1: |
| 194 | 194 | if idx1 > -1 : |
| 195 | 195 | min_idx = idx1 |
| 196 | else: min_idx =idx2 | |
| 196 | else: | |
| 197 | min_idx = idx2 | |
| 197 | 198 | if min_idx > 0: |
| 198 | 199 | # last_sample.append(sample[:min_idx + 1]) |
| 199 | start ='{}{}'.format(start, sample[:min_idx + 1]) | |
| 200 | start = '{}{}'.format(start, sample[:min_idx + 1]) | |
| 200 | 201 | |
| 201 | 202 | if i == 1: |
| 202 | 203 | rhyme_seq = min_idx - 1 |
| 206 | 207 | # print('last_sample text is:\n\n%s' % start) |
| 207 | 208 | |
| 208 | 209 | return WritePoem.assemble(start) |
| 210 | ||
| 209 | 211 | |
| 210 | 212 | def start_model(): |
| 211 | 213 | now = int(time.time()) |
| 231 | 231 | unnormalized_probs = np.exp(logits[0] - np.max(logits[0])) |
| 232 | 232 | probs = unnormalized_probs / np.sum(unnormalized_probs) |
| 233 | 233 | |
| 234 | if rhyme_ref and i == rhyme_idx : | |
| 235 | sample = self.select_rhyme(rhyme_ref,probs) | |
| 234 | if rhyme_ref and i == rhyme_idx: | |
| 235 | sample = self.select_rhyme(rhyme_ref, probs) | |
| 236 | 236 | elif sample_type == SampleType.max_prob: |
| 237 | 237 | sample = np.argmax(probs) |
| 238 | 238 | elif sample_type == SampleType.select_given: |
| 239 | sample,given = self.select_by_given(given,probs) | |
| 239 | sample, given = self.select_by_given(given, probs) | |
| 240 | 240 | else: #SampleType.weighted_sample |
| 241 | 241 | sample = np.random.choice(self.vocab_size, 1, p=probs)[0] |
| 242 | 242 | |
| 245 | 245 | |
| 246 | 246 | return ''.join(seq) |
| 247 | 247 | |
| 248 | def select_by_given(self,given,probs,max_prob = False): | |
| 248 | def select_by_given(self, given, probs, max_prob=False): | |
| 249 | 249 | if given: |
| 250 | seq_probs = zip(probs,range(0,self.vocab_size)) | |
| 251 | topn = heapq.nlargest(100,seq_probs,key=lambda sp :sp[0]) | |
| 252 | ||
| 253 | for _,seq in topn: | |
| 250 | seq_probs = zip(probs, range(0, self.vocab_size)) | |
| 251 | topn = heapq.nlargest(100, seq_probs, key=lambda sp: sp[0]) | |
| 252 | ||
| 253 | for _, seq in topn: | |
| 254 | 254 | if self.w2v_model.vocab[seq] in given: |
| 255 | given = given.replace(self.w2v_model.vocab[seq],'') | |
| 256 | return seq,given | |
| 255 | given = given.replace(self.w2v_model.vocab[seq], '') | |
| 256 | return seq, given | |
| 257 | 257 | if max_prob: |
| 258 | return np.argmax(probs),given | |
| 258 | return np.argmax(probs), given | |
| 259 | 259 | |
| 260 | 260 | return np.random.choice(self.vocab_size, 1, p=probs)[0],given |
| 261 | 261 | |
| 262 | ||
| 263 | def select_rhyme(self,rhyme_ref,probs): | |
| 262 | def select_rhyme(self, rhyme_ref, probs): | |
| 264 | 263 | if rhyme_ref: |
| 265 | 264 | rhyme_set = RhymeWords.get_rhyme_words(rhyme_ref) |
| 266 | 265 | if rhyme_set: |
| 267 | seq_probs = zip(probs,range(0,self.vocab_size)) | |
| 268 | topn = heapq.nlargest(50,seq_probs,key=lambda sp :sp[0]) | |
| 269 | ||
| 270 | for _,seq in topn: | |
| 266 | seq_probs = zip(probs, range(0, self.vocab_size)) | |
| 267 | topn = heapq.nlargest(50, seq_probs, key=lambda sp: sp[0]) | |
| 268 | ||
| 269 | for _, seq in topn: | |
| 271 | 270 | if self.w2v_model.vocab[seq] in rhyme_set: |
| 272 | 271 | return seq |
| 273 | 272 | |
| 274 | return np.argmax(probs) | |
| 275 | ||
| 276 | def check_start(self,text): | |
| 273 | return np.argmax(probs) | |
| 274 | ||
| 275 | def check_start(self, text): | |
| 277 | 276 | idx = text.find('<') |
| 278 | 277 | if idx > -1: |
| 279 | 278 | text = text[:idx] |
| 0 | absl-py==0.6.1 | |
| 1 | alembic==1.0.3 | |
| 2 | astor==0.7.1 | |
| 3 | autopep8==1.4 | |
| 4 | backcall==0.1.0 | |
| 5 | bleach==3.0.2 | |
| 6 | certifi==2018.10.15 | |
| 7 | chardet==3.0.4 | |
| 8 | cloudpickle==0.6.1 | |
| 9 | cycler==0.10.0 | |
| 10 | Cython==0.29 | |
| 11 | dask==0.20.2 | |
| 12 | decorator==4.3.0 | |
| 13 | defusedxml==0.5.0 | |
| 14 | entrypoints==0.2.3 | |
| 15 | future==0.16.0 | |
| 16 | gast==0.2.0 | |
| 17 | grpcio==1.16.0 | |
| 18 | h5py==2.8.0 | |
| 19 | html5lib==1.0.1 | |
| 20 | idna==2.7 | |
| 21 | imageio==2.4.1 | |
| 22 | imgaug==0.2.6 | |
| 23 | ipykernel==5.1.0 | |
| 24 | ipython==7.1.1 | |
| 25 | ipython-genutils==0.2.0 | |
| 26 | ipywidgets==7.4.2 | |
| 27 | jedi==0.13.1 | |
| 28 | jieba==0.39 | |
| 29 | Jinja2==2.10 | |
| 30 | jsonschema==2.6.0 | |
| 31 | jupyter==1.0.0 | |
| 32 | jupyter-client==5.2.3 | |
| 33 | jupyter-console==6.0.0 | |
| 34 | jupyter-core==4.4.0 | |
| 35 | jupyterhub==0.8.1 | |
| 36 | jupyterlab==0.31.1 | |
| 37 | jupyterlab-launcher==0.10.5 | |
| 38 | Keras==2.2.4 | |
| 39 | Keras-Applications==1.0.6 | |
| 40 | Keras-Preprocessing==1.0.5 | |
| 41 | kiwisolver==1.0.1 | |
| 42 | Mako==1.0.7 | |
| 43 | Markdown==3.0.1 | |
| 44 | MarkupSafe==1.0 | |
| 45 | matplotlib==3.0.2 | |
| 46 | mccabe==0.6.1 | |
| 47 | mistune==0.8.4 | |
| 48 | nbconvert==5.4.0 | |
| 49 | nbformat==4.4.0 | |
| 50 | networkx==2.2 | |
| 51 | nltk==3.3 | |
| 52 | notebook==5.7.0 | |
| 53 | numpy==1.15.2 | |
| 54 | opencv-python==3.4.3.18 | |
| 55 | pamela==0.3.0 | |
| 56 | pandas==0.23.4 | |
| 57 | pandocfilters==1.4.2 | |
| 58 | parso==0.3.1 | |
| 59 | pbr==5.1.1 | |
| 60 | pexpect==4.6.0 | |
| 61 | pickleshare==0.7.5 | |
| 62 | Pillow==5.3.0 | |
| 63 | pluggy==0.7.1 | |
| 64 | prometheus-client==0.4.2 | |
| 65 | prompt-toolkit==2.0.7 | |
| 66 | protobuf==3.6.1 | |
| 67 | ptyprocess==0.6.0 | |
| 68 | pycodestyle==2.4.0 | |
| 69 | pycurl==7.43.0 | |
| 70 | pydocstyle==2.1.1 | |
| 71 | pydot==1.2.4 | |
| 72 | pyflakes==2.0.0 | |
| 73 | Pygments==2.2.0 | |
| 74 | pygobject==3.20.0 | |
| 75 | pyparsing==2.3.0 | |
| 76 | python-apt==1.1.0b1+ubuntu0.16.4.2 | |
| 77 | python-dateutil==2.7.3 | |
| 78 | python-editor==1.0.3 | |
| 79 | python-jsonrpc-server==0.0.1 | |
| 80 | python-language-server==0.21.2 | |
| 81 | python-oauth2==1.1.0 | |
| 82 | pytz==2018.5 | |
| 83 | PyWavelets==1.0.1 | |
| 84 | PyYAML==3.13 | |
| 85 | pyzmq==17.1.2 | |
| 86 | qtconsole==4.4.2 | |
| 87 | requests==2.20.0 | |
| 88 | rope==0.11.0 | |
| 89 | scikit-image==0.14.1 | |
| 90 | scikit-learn==0.20.0 | |
| 91 | scipy==1.1.0 | |
| 92 | seaborn==0.9.0 | |
| 93 | Send2Trash==1.5.0 | |
| 94 | Shapely==1.6.4.post2 | |
| 95 | simplegeneric==0.8.1 | |
| 96 | six==1.11.0 | |
| 97 | sklearn==0.0 | |
| 98 | snowballstemmer==1.2.1 | |
| 99 | SQLAlchemy==1.2.14 | |
| 100 | stevedore==1.29.0 | |
| 101 | tensorboard==1.11.0 | |
| 102 | tensorflow==1.12.0 | |
| 103 | termcolor==1.1.0 | |
| 104 | terminado==0.8.1 | |
| 105 | testpath==0.4.2 | |
| 106 | toolz==0.9.0 | |
| 107 | torch==0.4.1 | |
| 108 | torchvision==0.2.1 | |
| 109 | tornado==5.1.1 | |
| 110 | traitlets==4.3.2 | |
| 111 | urllib3==1.24.1 | |
| 112 | virtualenv==16.0.0 | |
| 113 | virtualenv-clone==0.4.0 | |
| 114 | virtualenvwrapper==4.8.2 | |
| 115 | wcwidth==0.1.7 | |
| 116 | webencodings==0.5.1 | |
| 117 | Werkzeug==0.14.1 | |
| 118 | widgetsnbextension==3.4.2 | |
| 119 | word2vec==0.10.2 | |
| 120 | xlrd==1.1.0 | |
| 121 | yapf==0.24.0 |
| 2 | 2 | import sys |
| 3 | 3 | |
| 4 | 4 | # Import necessary packages |
| 5 | from modules import json_parser | |
| 6 | 5 | from modules import json_parser |
| 7 | 6 | from modules import Client |
| 8 | 7 | from write_poem import start_model |
| 0 | Please store your tensorboard results here |
| 22 | 22 | args.save_best_model = os.path.join(args.output_dir, 'best_model/model') |
| 23 | 23 | # args.tb_log_dir = os.path.join(args.output_dir, 'tensorboard_log/') |
| 24 | 24 | timestamp = str(int(time.time())) |
| 25 | args.tb_log_dir = os.path.abspath(os.path.join(args.output_dir, "tensorboard_log", timestamp)) | |
| 25 | # args.tb_log_dir = os.path.abspath(os.path.join(args.output_dir, "tensorboard_log", timestamp)) | |
| 26 | args.tb_log_dir = os.path.abspath(os.path.join('./results/tb_results', "tensorboard_log", timestamp)) | |
| 26 | 27 | print("Writing to {}\n".format(args.tb_log_dir)) |
| 27 | 28 | |
| 28 | 29 | # Create necessary directories. |
| 44 | 44 | self.sess = tf.Session() |
| 45 | 45 | w2v_vocab_size = len(self.w2v.model.vocab) |
| 46 | 46 | with tf.name_scope('evaluation'): |
| 47 | self.model = CharRNNLM(is_training=False,w2v_model = self.w2v.model,vocab_size=w2v_vocab_size, infer=True, **params) | |
| 47 | self.model = CharRNNLM(is_training=False, w2v_model = self.w2v.model, vocab_size=w2v_vocab_size, infer=True, **params) | |
| 48 | 48 | saver = tf.train.Saver(name='model_saver') |
| 49 | 49 | saver.restore(self.sess, best_model) |
| 50 | 50 | |
| 147 | 147 | |
| 148 | 148 | return sample[1:] |
| 149 | 149 | |
| 150 | def cangtou(self,given_text): | |
| 150 | def cangtou(self, given_text): | |
| 151 | 151 | ''' |
| 152 | 152 | 藏头诗 |
| 153 | 153 | Returns: |
| 180 | 180 | rhyme_seq += 1 |
| 181 | 181 | |
| 182 | 182 | sample = self.model.sample_seq(self.sess, self.args.length, start, |
| 183 | sample_type= SampleType.max_prob,rhyme_ref =rhyme_ref_word,rhyme_idx = rhyme_seq ) | |
| 183 | sample_type= SampleType.max_prob, rhyme_ref =rhyme_ref_word, rhyme_idx = rhyme_seq ) | |
| 184 | 184 | |
| 185 | 185 | # 暂时屏蔽 |
| 186 | 186 | # print('Sampled text is:\n\n%s' % sample) |
| 188 | 188 | sample = sample[before_idx:] |
| 189 | 189 | idx1 = sample.find(',') |
| 190 | 190 | idx2 = sample.find('。') |
| 191 | min_idx = min(idx1,idx2) | |
| 191 | min_idx = min(idx1, idx2) | |
| 192 | 192 | |
| 193 | 193 | if min_idx == -1: |
| 194 | 194 | if idx1 > -1 : |
| 195 | 195 | min_idx = idx1 |
| 196 | else: min_idx =idx2 | |
| 196 | else: | |
| 197 | min_idx = idx2 | |
| 197 | 198 | if min_idx > 0: |
| 198 | 199 | # last_sample.append(sample[:min_idx + 1]) |
| 199 | start ='{}{}'.format(start, sample[:min_idx + 1]) | |
| 200 | start = '{}{}'.format(start, sample[:min_idx + 1]) | |
| 200 | 201 | |
| 201 | 202 | if i == 1: |
| 202 | 203 | rhyme_seq = min_idx - 1 |
| 206 | 207 | # print('last_sample text is:\n\n%s' % start) |
| 207 | 208 | |
| 208 | 209 | return WritePoem.assemble(start) |
| 210 | ||
| 209 | 211 | |
| 210 | 212 | def start_model(): |
| 211 | 213 | now = int(time.time()) |