diff --git a/.ipynb_checkpoints/char_rnn_model-checkpoint.py b/.ipynb_checkpoints/char_rnn_model-checkpoint.py index 58e2b1a..f3b933c 100644 --- a/.ipynb_checkpoints/char_rnn_model-checkpoint.py +++ b/.ipynb_checkpoints/char_rnn_model-checkpoint.py @@ -233,11 +233,11 @@ probs = unnormalized_probs / np.sum(unnormalized_probs) if rhyme_ref and i == rhyme_idx : - sample = self.select_rhyme(rhyme_ref,probs) + sample = self.select_rhyme(rhyme_ref, probs) elif sample_type == SampleType.max_prob: sample = np.argmax(probs) elif sample_type == SampleType.select_given: - sample,given = self.select_by_given(given,probs) + sample, given = self.select_by_given(given, probs) else: #SampleType.weighted_sample sample = np.random.choice(self.vocab_size, 1, p=probs)[0] @@ -246,35 +246,34 @@ return ''.join(seq) - def select_by_given(self,given,probs,max_prob = False): + def select_by_given(self, given, probs, max_prob=False): if given: - seq_probs = zip(probs,range(0,self.vocab_size)) - topn = heapq.nlargest(100,seq_probs,key=lambda sp :sp[0]) - - for _,seq in topn: + seq_probs = zip(probs, range(0, self.vocab_size)) + topn = heapq.nlargest(100, seq_probs, key=lambda sp: sp[0]) + + for _, seq in topn: if self.w2v_model.vocab[seq] in given: - given = given.replace(self.w2v_model.vocab[seq],'') - return seq,given + given = given.replace(self.w2v_model.vocab[seq], '') + return seq, given if max_prob: - return np.argmax(probs),given + return np.argmax(probs), given return np.random.choice(self.vocab_size, 1, p=probs)[0],given - - def select_rhyme(self,rhyme_ref,probs): + def select_rhyme(self, rhyme_ref, probs): if rhyme_ref: rhyme_set = RhymeWords.get_rhyme_words(rhyme_ref) if rhyme_set: - seq_probs = zip(probs,range(0,self.vocab_size)) - topn = heapq.nlargest(50,seq_probs,key=lambda sp :sp[0]) - - for _,seq in topn: + seq_probs = zip(probs, range(0, self.vocab_size)) + topn = heapq.nlargest(50, seq_probs, key=lambda sp: sp[0]) + + for _, seq in topn: if self.w2v_model.vocab[seq] in rhyme_set: return seq - return np.argmax(probs) - - def check_start(self,text): + return np.argmax(probs) + + def check_start(self, text): idx = text.find('<') if idx > -1: text = text[:idx] diff --git a/.ipynb_checkpoints/main-checkpoint.py b/.ipynb_checkpoints/main-checkpoint.py index 54d5b99..269fda9 100644 --- a/.ipynb_checkpoints/main-checkpoint.py +++ b/.ipynb_checkpoints/main-checkpoint.py @@ -3,7 +3,6 @@ import sys # Import necessary packages -from modules import json_parser from modules import json_parser from modules import Client from write_poem import start_model diff --git a/.ipynb_checkpoints/write_poem-checkpoint.py b/.ipynb_checkpoints/write_poem-checkpoint.py index 9bd6408..9b41385 100644 --- a/.ipynb_checkpoints/write_poem-checkpoint.py +++ b/.ipynb_checkpoints/write_poem-checkpoint.py @@ -45,7 +45,7 @@ self.sess = tf.Session() w2v_vocab_size = len(self.w2v.model.vocab) with tf.name_scope('evaluation'): - self.model = CharRNNLM(is_training=False,w2v_model = self.w2v.model,vocab_size=w2v_vocab_size, infer=True, **params) + self.model = CharRNNLM(is_training=False, w2v_model = self.w2v.model, vocab_size=w2v_vocab_size, infer=True, **params) saver = tf.train.Saver(name='model_saver') saver.restore(self.sess, best_model) @@ -148,7 +148,7 @@ return sample[1:] - def cangtou(self,given_text): + def cangtou(self, given_text): ''' 藏头诗 Returns: @@ -181,7 +181,7 @@ rhyme_seq += 1 sample = self.model.sample_seq(self.sess, self.args.length, start, - sample_type= SampleType.max_prob,rhyme_ref =rhyme_ref_word,rhyme_idx = rhyme_seq ) + sample_type= SampleType.max_prob, rhyme_ref =rhyme_ref_word, rhyme_idx = rhyme_seq ) # 暂时屏蔽 # print('Sampled text is:\n\n%s' % sample) @@ -189,15 +189,16 @@ sample = sample[before_idx:] idx1 = sample.find(',') idx2 = sample.find('。') - min_idx = min(idx1,idx2) + min_idx = min(idx1, idx2) if min_idx == -1: if idx1 > -1 : min_idx = idx1 - else: min_idx =idx2 + else: + min_idx = idx2 if min_idx > 0: # last_sample.append(sample[:min_idx + 1]) - start ='{}{}'.format(start, sample[:min_idx + 1]) + start = '{}{}'.format(start, sample[:min_idx + 1]) if i == 1: rhyme_seq = min_idx - 1 @@ -207,6 +208,7 @@ # print('last_sample text is:\n\n%s' % start) return WritePoem.assemble(start) + def start_model(): now = int(time.time()) diff --git a/char_rnn_model.py b/char_rnn_model.py index 58e2b1a..0ec194f 100644 --- a/char_rnn_model.py +++ b/char_rnn_model.py @@ -232,12 +232,12 @@ unnormalized_probs = np.exp(logits[0] - np.max(logits[0])) probs = unnormalized_probs / np.sum(unnormalized_probs) - if rhyme_ref and i == rhyme_idx : - sample = self.select_rhyme(rhyme_ref,probs) + if rhyme_ref and i == rhyme_idx: + sample = self.select_rhyme(rhyme_ref, probs) elif sample_type == SampleType.max_prob: sample = np.argmax(probs) elif sample_type == SampleType.select_given: - sample,given = self.select_by_given(given,probs) + sample, given = self.select_by_given(given, probs) else: #SampleType.weighted_sample sample = np.random.choice(self.vocab_size, 1, p=probs)[0] @@ -246,35 +246,34 @@ return ''.join(seq) - def select_by_given(self,given,probs,max_prob = False): + def select_by_given(self, given, probs, max_prob=False): if given: - seq_probs = zip(probs,range(0,self.vocab_size)) - topn = heapq.nlargest(100,seq_probs,key=lambda sp :sp[0]) - - for _,seq in topn: + seq_probs = zip(probs, range(0, self.vocab_size)) + topn = heapq.nlargest(100, seq_probs, key=lambda sp: sp[0]) + + for _, seq in topn: if self.w2v_model.vocab[seq] in given: - given = given.replace(self.w2v_model.vocab[seq],'') - return seq,given + given = given.replace(self.w2v_model.vocab[seq], '') + return seq, given if max_prob: - return np.argmax(probs),given + return np.argmax(probs), given return np.random.choice(self.vocab_size, 1, p=probs)[0],given - - def select_rhyme(self,rhyme_ref,probs): + def select_rhyme(self, rhyme_ref, probs): if rhyme_ref: rhyme_set = RhymeWords.get_rhyme_words(rhyme_ref) if rhyme_set: - seq_probs = zip(probs,range(0,self.vocab_size)) - topn = heapq.nlargest(50,seq_probs,key=lambda sp :sp[0]) - - for _,seq in topn: + seq_probs = zip(probs, range(0, self.vocab_size)) + topn = heapq.nlargest(50, seq_probs, key=lambda sp: sp[0]) + + for _, seq in topn: if self.w2v_model.vocab[seq] in rhyme_set: return seq - return np.argmax(probs) - - def check_start(self,text): + return np.argmax(probs) + + def check_start(self, text): idx = text.find('<') if idx > -1: text = text[:idx] diff --git a/faas_requirements.txt b/faas_requirements.txt deleted file mode 100644 index fff8e36..0000000 --- a/faas_requirements.txt +++ /dev/null @@ -1,122 +0,0 @@ -absl-py==0.6.1 -alembic==1.0.3 -astor==0.7.1 -autopep8==1.4 -backcall==0.1.0 -bleach==3.0.2 -certifi==2018.10.15 -chardet==3.0.4 -cloudpickle==0.6.1 -cycler==0.10.0 -Cython==0.29 -dask==0.20.2 -decorator==4.3.0 -defusedxml==0.5.0 -entrypoints==0.2.3 -future==0.16.0 -gast==0.2.0 -grpcio==1.16.0 -h5py==2.8.0 -html5lib==1.0.1 -idna==2.7 -imageio==2.4.1 -imgaug==0.2.6 -ipykernel==5.1.0 -ipython==7.1.1 -ipython-genutils==0.2.0 -ipywidgets==7.4.2 -jedi==0.13.1 -jieba==0.39 -Jinja2==2.10 -jsonschema==2.6.0 -jupyter==1.0.0 -jupyter-client==5.2.3 -jupyter-console==6.0.0 -jupyter-core==4.4.0 -jupyterhub==0.8.1 -jupyterlab==0.31.1 -jupyterlab-launcher==0.10.5 -Keras==2.2.4 -Keras-Applications==1.0.6 -Keras-Preprocessing==1.0.5 -kiwisolver==1.0.1 -Mako==1.0.7 -Markdown==3.0.1 -MarkupSafe==1.0 -matplotlib==3.0.2 -mccabe==0.6.1 -mistune==0.8.4 -nbconvert==5.4.0 -nbformat==4.4.0 -networkx==2.2 -nltk==3.3 -notebook==5.7.0 -numpy==1.15.2 -opencv-python==3.4.3.18 -pamela==0.3.0 -pandas==0.23.4 -pandocfilters==1.4.2 -parso==0.3.1 -pbr==5.1.1 -pexpect==4.6.0 -pickleshare==0.7.5 -Pillow==5.3.0 -pluggy==0.7.1 -prometheus-client==0.4.2 -prompt-toolkit==2.0.7 -protobuf==3.6.1 -ptyprocess==0.6.0 -pycodestyle==2.4.0 -pycurl==7.43.0 -pydocstyle==2.1.1 -pydot==1.2.4 -pyflakes==2.0.0 -Pygments==2.2.0 -pygobject==3.20.0 -pyparsing==2.3.0 -python-apt==1.1.0b1+ubuntu0.16.4.2 -python-dateutil==2.7.3 -python-editor==1.0.3 -python-jsonrpc-server==0.0.1 -python-language-server==0.21.2 -python-oauth2==1.1.0 -pytz==2018.5 -PyWavelets==1.0.1 -PyYAML==3.13 -pyzmq==17.1.2 -qtconsole==4.4.2 -requests==2.20.0 -rope==0.11.0 -scikit-image==0.14.1 -scikit-learn==0.20.0 -scipy==1.1.0 -seaborn==0.9.0 -Send2Trash==1.5.0 -Shapely==1.6.4.post2 -simplegeneric==0.8.1 -six==1.11.0 -sklearn==0.0 -snowballstemmer==1.2.1 -SQLAlchemy==1.2.14 -stevedore==1.29.0 -tensorboard==1.11.0 -tensorflow==1.12.0 -termcolor==1.1.0 -terminado==0.8.1 -testpath==0.4.2 -toolz==0.9.0 -torch==0.4.1 -torchvision==0.2.1 -tornado==5.1.1 -traitlets==4.3.2 -urllib3==1.24.1 -virtualenv==16.0.0 -virtualenv-clone==0.4.0 -virtualenvwrapper==4.8.2 -wcwidth==0.1.7 -webencodings==0.5.1 -Werkzeug==0.14.1 -widgetsnbextension==3.4.2 -word2vec==0.10.2 -xlrd==1.1.0 -yapf==0.24.0 diff --git a/main.py b/main.py index 54d5b99..269fda9 100644 --- a/main.py +++ b/main.py @@ -3,7 +3,6 @@ import sys # Import necessary packages -from modules import json_parser from modules import json_parser from modules import Client from write_poem import start_model diff --git a/results/tb_results/.ipynb_checkpoints/README-checkpoint.md b/results/tb_results/.ipynb_checkpoints/README-checkpoint.md new file mode 100644 index 0000000..90c2d64 --- /dev/null +++ b/results/tb_results/.ipynb_checkpoints/README-checkpoint.md @@ -0,0 +1 @@ +Please store your tensorboard results here diff --git a/train.py b/train.py index b17444d..8c7501b 100644 --- a/train.py +++ b/train.py @@ -23,7 +23,8 @@ args.save_best_model = os.path.join(args.output_dir, 'best_model/model') # args.tb_log_dir = os.path.join(args.output_dir, 'tensorboard_log/') timestamp = str(int(time.time())) - args.tb_log_dir = os.path.abspath(os.path.join(args.output_dir, "tensorboard_log", timestamp)) + # args.tb_log_dir = os.path.abspath(os.path.join(args.output_dir, "tensorboard_log", timestamp)) + args.tb_log_dir = os.path.abspath(os.path.join('./results/tb_results', "tensorboard_log", timestamp)) print("Writing to {}\n".format(args.tb_log_dir)) # Create necessary directories. diff --git a/write_poem.py b/write_poem.py index 9bd6408..9b41385 100644 --- a/write_poem.py +++ b/write_poem.py @@ -45,7 +45,7 @@ self.sess = tf.Session() w2v_vocab_size = len(self.w2v.model.vocab) with tf.name_scope('evaluation'): - self.model = CharRNNLM(is_training=False,w2v_model = self.w2v.model,vocab_size=w2v_vocab_size, infer=True, **params) + self.model = CharRNNLM(is_training=False, w2v_model = self.w2v.model, vocab_size=w2v_vocab_size, infer=True, **params) saver = tf.train.Saver(name='model_saver') saver.restore(self.sess, best_model) @@ -148,7 +148,7 @@ return sample[1:] - def cangtou(self,given_text): + def cangtou(self, given_text): ''' 藏头诗 Returns: @@ -181,7 +181,7 @@ rhyme_seq += 1 sample = self.model.sample_seq(self.sess, self.args.length, start, - sample_type= SampleType.max_prob,rhyme_ref =rhyme_ref_word,rhyme_idx = rhyme_seq ) + sample_type= SampleType.max_prob, rhyme_ref =rhyme_ref_word, rhyme_idx = rhyme_seq ) # 暂时屏蔽 # print('Sampled text is:\n\n%s' % sample) @@ -189,15 +189,16 @@ sample = sample[before_idx:] idx1 = sample.find(',') idx2 = sample.find('。') - min_idx = min(idx1,idx2) + min_idx = min(idx1, idx2) if min_idx == -1: if idx1 > -1 : min_idx = idx1 - else: min_idx =idx2 + else: + min_idx = idx2 if min_idx > 0: # last_sample.append(sample[:min_idx + 1]) - start ='{}{}'.format(start, sample[:min_idx + 1]) + start = '{}{}'.format(start, sample[:min_idx + 1]) if i == 1: rhyme_seq = min_idx - 1 @@ -207,6 +208,7 @@ # print('last_sample text is:\n\n%s' % start) return WritePoem.assemble(start) + def start_model(): now = int(time.time())