master
/ .ipynb_checkpoints / write_poem-checkpoint.py

write_poem-checkpoint.py @6135863 raw · history · blame

import json
import os
import sys
import time
import logging
import math
import numpy as np
import tensorflow as tf
from char_rnn_model import CharRNNLM, SampleType
from config_poem import config_sample
from word2vec_helper import Word2Vec
from rhyme_helper import RhymeWords


class WritePoem():
    def __init__(self, args):
        self.args = args

        logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s:%(message)s',
                        level=logging.INFO, datefmt='%I:%M:%S')

        with open(os.path.join(self.args.model_dir, 'result.json'), 'r') as f:
            result = json.load(f)

        params = result['params']
        best_model = result['best_model']
        best_valid_ppl = result['best_valid_ppl']
        if 'encoding' in result:
            self.args.encoding = result['encoding']
        else:
            self.args.encoding = 'utf-8'

        base_path = args.data_dir
        w2v_file = os.path.join(base_path, "vectors_poem.bin")
        self.w2v = Word2Vec(w2v_file)

        RhymeWords.read_rhyme_words(os.path.join(base_path, 'rhyme_words.txt'))

        if args.seed >= 0:
            np.random.seed(args.seed)

        logging.info('best_model: %s\n', best_model)

        self.sess = tf.Session()
        w2v_vocab_size = len(self.w2v.model.vocab)
        with tf.name_scope('evaluation'):
            self.model = CharRNNLM(is_training=False, w2v_model = self.w2v.model, vocab_size=w2v_vocab_size, infer=True, **params)
            saver = tf.train.Saver(name='model_saver')
            saver.restore(self.sess, best_model)

    def free_verse(self):
        '''
        自由诗
        Returns:

        '''
        sample = self.model.sample_seq(self.sess, 40, '[',sample_type= SampleType.weighted_sample)
        if not sample:
            return 'err occar!'

        # 暂时屏蔽
        # print('free_verse:',sample)

        idx_end = sample.find(']')
        parts = sample.split('。')
        if len(parts) > 1:
            two_sentence_len = len(parts[0]) + len(parts[1])
            if idx_end < 0 or two_sentence_len < idx_end:
                return sample[1:two_sentence_len + 2]

        return sample[1:idx_end]

    @staticmethod
    def assemble(sample):
        if  sample:
            parts = sample.split('。')
            if len(parts) > 1:
                return '{}{}。'.format(parts[0][1:],parts[1][:len(parts[0])])

        return ''


    def rhyme_verse(self):
        '''
        押韵诗
        Returns:

        '''
        gen_len = 20
        sample = self.model.sample_seq(self.sess, gen_len, start_text='[',sample_type= SampleType.weighted_sample)
        if not sample:
            return 'err occar!'

        # 暂时屏蔽
        # print('rhyme_verse:',sample)

        parts = sample.split('。')
        if len(parts) > 0:
           start = parts[0] + '。'
           rhyme_ref_word = start[-2]
           rhyme_seq = len(start) - 3

           sample = self.model.sample_seq(self.sess, gen_len , start,
                                                  sample_type= SampleType.weighted_sample,rhyme_ref =rhyme_ref_word,rhyme_idx = rhyme_seq )
        # 暂时屏蔽
        #    print(sample)
           return WritePoem.assemble(sample)

        return sample[1:]

    def hide_words(self,given_text):
        '''
        藏字诗
        Args:
            given_text:

        Returns:

        '''
        if(not given_text):
            return self.rhyme_verse()

        givens = ['','']
        split_len = math.ceil(len(given_text)/2)
        givens[0] = given_text[:split_len]
        givens[1] = given_text[split_len:]

        gen_len = 20
        sample = self.model.sample_seq(self.sess, gen_len, start_text='[',sample_type= SampleType.select_given,given=givens[0])
        if not sample:
            return 'err occar!'
        # 暂时屏蔽
        # print('rhyme_verse:',sample)

        parts = sample.split('。')
        if len(parts) > 0:
           start = parts[0] + '。'
           rhyme_ref_word = start[-2]
           rhyme_seq = len(start) - 3
           # gen_len = len(start) - 1

           sample = self.model.sample_seq(self.sess, gen_len , start,
                                                  sample_type= SampleType.select_given,given=givens[1],rhyme_ref =rhyme_ref_word,rhyme_idx = rhyme_seq )
        # 暂时屏蔽
        #    print(sample)
           return WritePoem.assemble(sample)

        return sample[1:]

    def cangtou(self, given_text):
        '''
        藏头诗
        Returns:

        '''
        if(not given_text):
            return self.rhyme_verse()

        start = ''
        rhyme_ref_word = ''
        rhyme_seq = 0

        # for i,word in enumerate(given_text):
        for i in range(4):
            word = ''
            if i < len(given_text):
                word = given_text[i]

            if i == 0:
                start = '[' + word
            else:
                start += word

            before_idx = len(start)
            if(i != 3):
                sample = self.model.sample_seq(self.sess, self.args.length, start,
                                         sample_type= SampleType.weighted_sample )
            else:
                if not word:
                    rhyme_seq += 1

                sample = self.model.sample_seq(self.sess, self.args.length, start,
                                      sample_type= SampleType.max_prob, rhyme_ref =rhyme_ref_word, rhyme_idx = rhyme_seq )

            # 暂时屏蔽
            # print('Sampled text is:\n\n%s' % sample)

            sample = sample[before_idx:]
            idx1 = sample.find(',')
            idx2 = sample.find('。')
            min_idx = min(idx1, idx2)

            if min_idx == -1:
                if idx1 > -1 :
                    min_idx = idx1
                else:
                    min_idx = idx2
            if min_idx > 0:
                # last_sample.append(sample[:min_idx + 1])
                start = '{}{}'.format(start, sample[:min_idx + 1])

                if i == 1:
                    rhyme_seq = min_idx - 1
                    rhyme_ref_word = sample[rhyme_seq]

            # 暂时屏蔽
            # print('last_sample text is:\n\n%s' % start)

        return WritePoem.assemble(start)


def start_model():
    now = int(time.time())
    args = config_sample('--model_dir ./results/output_poem --length 16 --seed {}'.format(now))
    writer = WritePoem(args)
    return writer

if __name__ == '__main__':
    writer = start_model()