master
/ .ipynb_checkpoints / write_poem-checkpoint.py

write_poem-checkpoint.py @6135863

af5088c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b707fb
af5088c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b707fb
af5088c
 
 
 
 
 
 
 
 
8c8dc33
af5088c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c8dc33
af5088c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c8dc33
af5088c
 
 
 
 
 
 
8c8dc33
af5088c
 
 
 
8c8dc33
 
af5088c
 
8c8dc33
af5088c
 
 
 
 
 
 
 
 
 
8c8dc33
af5088c
 
6a25108
af5088c
 
 
 
 
import json
import os
import sys
import time
import logging
import math
import numpy as np
import tensorflow as tf
from char_rnn_model import CharRNNLM, SampleType
from config_poem import config_sample
from word2vec_helper import Word2Vec
from rhyme_helper import RhymeWords


class WritePoem():
    """Sample Chinese poems in several styles from a trained CharRNNLM model.

    The constructor restores the best checkpoint listed in
    ``<model_dir>/result.json``. The public methods then produce free verse,
    rhyming verse, hidden-character poems and acrostic poems.

    Raw samples returned by the model start with ``'['`` (the prompt every
    method uses). Sentences are split on ``'。'`` and half-lines on ``','``.
    ``']'`` is treated as an end marker by ``free_verse``.
    """

    # Text returned when the model yields an empty sample (kept byte-identical
    # to the historical value because callers may display or compare it).
    _ERROR_TEXT = 'err occar!'

    def __init__(self, args):
        """Load the word vectors, the rhyme table and the best checkpoint.

        Args:
            args: sampling configuration. It must provide ``model_dir``,
                ``data_dir`` and ``seed``; ``length`` is read later by
                ``cangtou``. ``args.encoding`` is overwritten from
                ``result.json`` (default ``'utf-8'``).
        """
        self.args = args

        logging.basicConfig(stream=sys.stdout,
                            format='%(asctime)s %(levelname)s:%(message)s',
                            level=logging.INFO, datefmt='%I:%M:%S')

        with open(os.path.join(self.args.model_dir, 'result.json'), 'r') as f:
            result = json.load(f)

        params = result['params']
        best_model = result['best_model']
        self.args.encoding = result.get('encoding', 'utf-8')

        base_path = args.data_dir
        w2v_file = os.path.join(base_path, "vectors_poem.bin")
        self.w2v = Word2Vec(w2v_file)

        RhymeWords.read_rhyme_words(os.path.join(base_path, 'rhyme_words.txt'))

        # A negative seed leaves NumPy's global RNG unseeded.
        if args.seed >= 0:
            np.random.seed(args.seed)

        logging.info('best_model: %s\n', best_model)

        self.sess = tf.Session()
        w2v_vocab_size = len(self.w2v.model.vocab)
        with tf.name_scope('evaluation'):
            self.model = CharRNNLM(is_training=False, w2v_model=self.w2v.model,
                                   vocab_size=w2v_vocab_size, infer=True, **params)
            saver = tf.train.Saver(name='model_saver')
            saver.restore(self.sess, best_model)

    def close(self):
        """Release the TensorFlow session opened in ``__init__``."""
        self.sess.close()

    def free_verse(self):
        """Free verse: sample up to 40 characters and keep at most two sentences.

        Returns:
            The poem without the leading ``'['``, or ``'err occar!'`` when
            the model returns an empty sample.
        """
        sample = self.model.sample_seq(self.sess, 40, '[',
                                       sample_type=SampleType.weighted_sample)
        if not sample:
            return self._ERROR_TEXT

        logging.debug('free_verse: %s', sample)
        return WritePoem._trim_free_verse(sample)

    @staticmethod
    def _trim_free_verse(sample):
        """Cut a raw ``'['``-prefixed sample to two sentences or up to ``']'``."""
        idx_end = sample.find(']')
        parts = sample.split('。')
        if len(parts) > 1:
            two_sentence_len = len(parts[0]) + len(parts[1])
            if idx_end < 0 or two_sentence_len < idx_end:
                # +2: skip the '[' offset and keep the second '。'.
                return sample[1:two_sentence_len + 2]

        if idx_end < 0:
            # No end marker: keep everything after '[' (slicing to -1 here
            # would drop the last real character).
            return sample[1:]
        return sample[1:idx_end]

    @staticmethod
    def assemble(sample):
        """Format the first two sentences of a raw sample as a poem.

        The leading ``'['`` is removed from the first sentence. The second
        sentence is truncated to the first one's length, so trailing model
        output such as ``']'`` is dropped. Each sentence is terminated with
        ``'。'``.

        Args:
            sample: raw model output, normally starting with ``'['``.

        Returns:
            ``'<first>。<second>。'``, or ``''`` when ``sample`` is empty or
            contains no ``'。'``.
        """
        if not sample:
            return ''
        parts = sample.split('。')
        if len(parts) < 2:
            return ''

        first = parts[0][1:] if parts[0].startswith('[') else parts[0]
        second = parts[1][:len(first)]
        return '{}。{}。'.format(first, second)

    def _continue_with_rhyme(self, sample, gen_len, **sample_kwargs):
        """Keep the first sentence of ``sample`` and sample a second one after it.

        The rhyme reference is the last character of the first sentence. That
        character's index within the sentence, excluding the leading ``'['``,
        is passed as ``rhyme_idx``; presumably this lets the model make the
        second sentence rhyme at the same position.

        Args:
            sample: non-empty raw sample starting with ``'['``.
            gen_len: length argument forwarded to ``sample_seq``.
            **sample_kwargs: extra ``sample_seq`` keywords (``sample_type``,
                ``given``, ...).

        Returns:
            The assembled two-sentence poem (see ``assemble``).
        """
        first_sentence = sample.split('。')[0] + '。'
        rhyme_ref_word = first_sentence[-2]
        rhyme_idx = len(first_sentence) - 3
        # NOTE(review): a first sentence of just '[' yields rhyme_ref '[' and
        # rhyme_idx -1 — confirm sample_seq tolerates that.
        continued = self.model.sample_seq(self.sess, gen_len, first_sentence,
                                          rhyme_ref=rhyme_ref_word,
                                          rhyme_idx=rhyme_idx, **sample_kwargs)
        logging.debug('rhyme continuation: %s', continued)
        return WritePoem.assemble(continued)

    def rhyme_verse(self):
        """Rhyming verse: two sentences whose final characters rhyme.

        Returns:
            The assembled poem, or ``'err occar!'`` when the first sample is
            empty.
        """
        gen_len = 20
        sample = self.model.sample_seq(self.sess, gen_len, start_text='[',
                                       sample_type=SampleType.weighted_sample)
        if not sample:
            return self._ERROR_TEXT

        logging.debug('rhyme_verse: %s', sample)
        return self._continue_with_rhyme(sample, gen_len,
                                         sample_type=SampleType.weighted_sample)

    def hide_words(self, given_text):
        """Hidden-character poem: embed ``given_text`` across two rhyming sentences.

        The first half of ``given_text`` (rounded up) goes to the first
        sentence and the rest to the second, using ``SampleType.select_given``.

        Args:
            given_text: characters to hide; when empty, a rhyming verse is
                produced instead.

        Returns:
            The assembled poem, or ``'err occar!'`` when the first sample is
            empty.
        """
        if not given_text:
            return self.rhyme_verse()

        split_len = (len(given_text) + 1) // 2
        first_given, second_given = given_text[:split_len], given_text[split_len:]

        gen_len = 20
        sample = self.model.sample_seq(self.sess, gen_len, start_text='[',
                                       sample_type=SampleType.select_given,
                                       given=first_given)
        if not sample:
            return self._ERROR_TEXT

        logging.debug('hide_words: %s', sample)
        return self._continue_with_rhyme(sample, gen_len,
                                         sample_type=SampleType.select_given,
                                         given=second_given)

    @staticmethod
    def _first_break(text):
        """Index of the first ``','`` or ``'。'`` in ``text``, or -1 if none."""
        hits = [idx for idx in (text.find(','), text.find('。')) if idx >= 0]
        return min(hits) if hits else -1

    def cangtou(self, given_text):
        """Acrostic poem: four half-lines headed by the characters of ``given_text``.

        Only the first four characters are used; missing heads are generated
        by the model. The fourth half-line is sampled with
        ``SampleType.max_prob`` and asked to rhyme with the last character of
        the second half-line.

        Args:
            given_text: head characters; when empty, a rhyming verse is
                produced instead.

        Returns:
            The assembled poem, or ``'err occar!'`` when the model returns an
            empty sample.
        """
        if not given_text:
            return self.rhyme_verse()

        poem = '['
        rhyme_ref_word = ''
        # Rhyme position counted from the character after a line's head.
        rhyme_seq = 0

        for i in range(4):
            head = given_text[i] if i < len(given_text) else ''
            poem += head
            prefix_len = len(poem)

            if i != 3:
                sample = self.model.sample_seq(self.sess, self.args.length, poem,
                                               sample_type=SampleType.weighted_sample)
            else:
                # Without a given head the model also generates the line's
                # first character, so the rhyme lands one position later.
                if not head:
                    rhyme_seq += 1
                sample = self.model.sample_seq(self.sess, self.args.length, poem,
                                               sample_type=SampleType.max_prob,
                                               rhyme_ref=rhyme_ref_word,
                                               rhyme_idx=rhyme_seq)
            if not sample:
                return self._ERROR_TEXT

            logging.debug('cangtou sample: %s', sample)

            continuation = sample[prefix_len:]
            break_idx = WritePoem._first_break(continuation)
            if break_idx > 0:
                poem += continuation[:break_idx + 1]

                if i == 1:
                    rhyme_ref_word = continuation[break_idx - 1]
                    # Normalise to "after the head" so line 4 can apply its
                    # own head shift; when line 2 had no head, its
                    # continuation already started at the line's first char.
                    rhyme_seq = break_idx - 1 if head else break_idx - 2

        return WritePoem.assemble(poem)


def start_model():
    """Create a WritePoem for the default poem model.

    Sampling is seeded with the current Unix time in whole seconds, so
    successive runs produce different poems.

    Returns:
        The constructed WritePoem instance.
    """
    seed = int(time.time())
    cli = '--model_dir ./results/output_poem --length 16 --seed %d' % seed
    return WritePoem(config_sample(cli))

if "__main__" == __name__:
    # Running the module as a script only builds (and thereby loads) the model.
    writer = start_model()