import logging
import time
from enum import Enum
import heapq
import numpy as np
import tensorflow as tf
from rhyme_helper import RhymeWords

logging.getLogger('tensorflow').setLevel(logging.WARNING)
# Sampling strategies used by sample_seq below.
SampleType = Enum('SampleType', ('max_prob', 'weighted_sample', 'rhyme', 'select_given'))

class CharRNNLM(object):
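    """Character-level RNN language model (TensorFlow 1.x, graph mode).

    Builds a multi-layer RNN/LSTM/GRU over character embeddings with a
    softmax output layer, a loss/perplexity monitor, and, when training,
    an Adam optimizer with global-norm gradient clipping. Sampling supports
    max-probability, weighted, given-character, and rhyme-constrained modes.
    """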
    def __init__(self, is_training, batch_size, num_unrollings, vocab_size, w2v_model,
                 hidden_size, max_grad_norm, embedding_size, num_layers,
                 learning_rate, cell_type, dropout=0.0, input_dropout=0.0, infer=False):
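        """Build the model graph.

        w2v_model must expose .vectors (the pre-trained embedding matrix),
        .vocab (index -> character) and .vocab_hash (character -> index).
        When infer is True, batch_size and num_unrollings are forced to 1
        so the model can be driven one character at a time.
        """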
        self.batch_size = batch_size
        self.num_unrollings = num_unrollings
        if infer:
            self.batch_size = 1
            self.num_unrollings = 1
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.max_grad_norm = max_grad_norm
        self.num_layers = num_layers
        self.embedding_size = embedding_size
        self.cell_type = cell_type
        self.dropout = dropout
        self.input_dropout = input_dropout
        self.w2v_model = w2v_model

        if embedding_size <= 0:
            self.input_size = vocab_size
            self.input_dropout = 0.0
        else:
            self.input_size = embedding_size

        self.input_data = tf.placeholder(tf.int64, [self.batch_size, self.num_unrollings], name='inputs')
        self.targets = tf.placeholder(tf.int64, [self.batch_size, self.num_unrollings], name='targets')

        if self.cell_type == 'rnn':
            cell_fn = tf.nn.rnn_cell.BasicRNNCell
        elif self.cell_type == 'lstm':
            cell_fn = tf.nn.rnn_cell.LSTMCell
        elif self.cell_type == 'gru':
            cell_fn = tf.nn.rnn_cell.GRUCell
        else:
            raise ValueError('Unsupported cell_type: %s' % self.cell_type)

        params = dict()
        #params = {'input_size': self.input_size}
        if self.cell_type == 'lstm':
            params['forget_bias'] = 1.0  # 1.0 is default value
        cell = cell_fn(self.hidden_size, **params)

        cells = [cell]
        #params['input_size'] = self.hidden_size
        for _ in range(self.num_layers - 1):
            higher_layer_cell = cell_fn(self.hidden_size, **params)
            cells.append(higher_layer_cell)

        if is_training and self.dropout > 0:
            cells = [tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=1.0-self.dropout) for cell in cells]

        multi_cell = tf.nn.rnn_cell.MultiRNNCell(cells)

        with tf.name_scope('initial_state'):
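            # Besides a concrete zero state, expose the initial state as
            # placeholders so callers can feed the previous run's final state
            # back in (epoch iteration and step-by-step sampling both do this).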
            self.zero_state = multi_cell.zero_state(self.batch_size, tf.float32)
            if self.cell_type == 'rnn' or self.cell_type == 'gru':
                self.initial_state = tuple(
                        [tf.placeholder(tf.float32,
                            [self.batch_size, multi_cell.state_size[idx]],
                            'initial_state_'+str(idx+1)) for idx in range(self.num_layers)])
            elif self.cell_type == 'lstm':
                self.initial_state = tuple(
                        [tf.nn.rnn_cell.LSTMStateTuple(
                            tf.placeholder(tf.float32, [self.batch_size, multi_cell.state_size[idx][0]],
                                          'initial_lstm_state_'+str(idx+1)),
                            tf.placeholder(tf.float32, [self.batch_size, multi_cell.state_size[idx][1]],
                                           'initial_lstm_state_'+str(idx+1)))
                            for idx in range(self.num_layers)])

        with tf.name_scope('embedding_layer'):
            if embedding_size > 0:
                # self.embedding = tf.get_variable('embedding', [self.vocab_size, self.embedding_size])
                self.embedding = tf.get_variable("word_embeddings",
                    initializer=self.w2v_model.vectors.astype(np.float32))
            else:
                self.embedding = tf.constant(np.eye(self.vocab_size), dtype=tf.float32)

            inputs = tf.nn.embedding_lookup(self.embedding, self.input_data)
            if is_training and self.input_dropout > 0:
                inputs = tf.nn.dropout(inputs, 1-self.input_dropout)

        with tf.name_scope('slice_inputs'):
            # num_unrollings * (batch_size, embedding_size): the input format static_rnn expects.
            sliced_inputs = [tf.squeeze(input_, [1]) for input_ in tf.split(
                axis=1, num_or_size_splits=self.num_unrollings, value=inputs)]

        # sliced_inputs: a length num_unrollings list of [batch_size, input_size] tensors.
        # static_rnn expects:
        #   inputs: a length T list of tensors, each of shape [batch_size, input_size]
        #   initial_state: if cell.state_size is an integer, a tensor of shape
        #                  [batch_size, cell.state_size]; a nested state tuple otherwise
        # and returns:
        #   outputs: a length T list of outputs (one per input step)
        #   state: the final state
        outputs, final_state = tf.nn.static_rnn(
                cell = multi_cell,
                inputs = sliced_inputs,
                initial_state=self.initial_state)
        self.final_state = final_state

        with tf.name_scope('flatten_outputs'):
            flat_outputs = tf.reshape(tf.concat(axis=1, values=outputs), [-1, hidden_size])

        with tf.name_scope('flatten_targets'):
            flat_targets = tf.reshape(tf.concat(axis=1, values=self.targets), [-1])

        with tf.variable_scope('softmax'):
            softmax_w = tf.get_variable('softmax_w', [hidden_size, vocab_size])
            softmax_b = tf.get_variable('softmax_b', [vocab_size])
            self.logits = tf.matmul(flat_outputs, softmax_w) + softmax_b
            self.probs = tf.nn.softmax(self.logits)

        with tf.name_scope('loss'):
            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=self.logits, labels=flat_targets)
            self.mean_loss = tf.reduce_mean(loss)

        with tf.name_scope('loss_monitor'):
            # Running counters used to report epoch-level average loss and perplexity.
            count = tf.Variable(0.0, trainable=False, name='count')
            sum_mean_loss = tf.Variable(0.0, trainable=False, name='sum_mean_loss')

            self.reset_loss_monitor = tf.group(sum_mean_loss.assign(0.0),
                                               count.assign(0.0), name='reset_loss_monitor')
            self.update_loss_monitor = tf.group(sum_mean_loss.assign(sum_mean_loss+self.mean_loss),
                                                count.assign(count+1), name='update_loss_monitor')

            with tf.control_dependencies([self.update_loss_monitor]):
                self.average_loss = sum_mean_loss / count
                self.ppl = tf.exp(self.average_loss)

            average_loss_summary = tf.summary.scalar(
                    name='average_loss', tensor=self.average_loss)
            ppl_summary = tf.summary.scalar(
                    name='perplexity', tensor=self.ppl)

        self.summaries = tf.summary.merge(
                inputs=[average_loss_summary, ppl_summary], name='loss_monitor')

        self.global_step = tf.get_variable('global_step', [], trainable=False,
                                           initializer=tf.constant_initializer(0.0))

        # self.learning_rate = tf.constant(learning_rate)
        self.learning_rate = tf.placeholder(tf.float32, [], name='learning_rate')

        if is_training:
            tvars = tf.trainable_variables()
            # Clip gradients to max_grad_norm by global norm before applying Adam updates.
            grads, _ = tf.clip_by_global_norm(tf.gradients(self.mean_loss, tvars), self.max_grad_norm)
            optimizer = tf.train.AdamOptimizer(self.learning_rate)
            self.train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=self.global_step)


    def run_epoch(self, session, batch_generator, is_training, learning_rate, verbose=0, freq=10):
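        """Run one full pass over the data; returns (ppl, summary_str, global_step)."""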
        epoch_size = batch_generator.num_batches

        if verbose > 0:
            logging.info('epoch_size: %d', epoch_size)
            logging.info('data_size: %d', batch_generator.seq_length)
            logging.info('num_unrollings: %d', self.num_unrollings)
            logging.info('batch_size: %d', self.batch_size)

        if is_training:
            extra_op = self.train_op
        else:
            extra_op = tf.no_op()

        if self.cell_type in ['rnn', 'gru']:
            state = session.run(self.zero_state)
        else:
            state = tuple([(np.zeros((self.batch_size, self.hidden_size)),
                np.zeros((self.batch_size, self.hidden_size)))
                for _ in range(self.num_layers)])

        session.run(self.reset_loss_monitor)
        batch_generator.reset_batch_pointer()
        start_time = time.time()
        ppl_cumsum = 0
        for step in range(epoch_size):
            x, y = batch_generator.next_batch()

            ops = [self.average_loss, self.ppl, self.final_state, extra_op,
                   self.summaries, self.global_step]

            feed_dict = {self.input_data: x, self.targets: y, self.initial_state: state,
                         self.learning_rate: learning_rate}

            results = session.run(ops, feed_dict)
            average_loss, ppl, final_state, _, summary_str, global_step = results
            # Carry the RNN state across consecutive batches within the epoch.
            state = final_state
            ppl_cumsum += ppl

            if (step + 1) % freq == 0:
                logging.info('%.1f%%, step: %d, perplexity: %.3f, speed: %.0f words per sec',
                             (step + 1) * 1.0 / epoch_size * 100, step, ppl_cumsum / (step + 1),
                             (step + 1) * self.batch_size * self.num_unrollings / (time.time() - start_time))
        logging.info('Perplexity: %.3f, speed: %.0f words per sec',
                     ppl, (step + 1) * self.batch_size * self.num_unrollings / (time.time() - start_time))

        return ppl, summary_str, global_step

    def sample_seq(self, session, length, start_text, sample_type=SampleType.max_prob,
                   given='', rhyme_ref='', rhyme_idx=0):
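        """Generate `length` characters, optionally warming up on start_text.

        If rhyme_ref is given, the character at position rhyme_idx is forced
        to rhyme with it; otherwise selection follows sample_type.
        """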
        if self.cell_type in ['rnn', 'gru']:
            state = session.run(self.zero_state)
        else:
            state = tuple([(np.zeros((self.batch_size, self.hidden_size)),
                np.zeros((self.batch_size, self.hidden_size)))
                for _ in range(self.num_layers)])

        # use start_text to warm up the RNN.
        start_text = self.check_start(start_text)
        if start_text:
            seq = list(start_text)
            for char in start_text[:-1]:
                x = np.array([[self.w2v_model.vocab_hash[char]]])
                state = session.run(self.final_state, {self.input_data: x, self.initial_state: state})
            x = np.array([[self.w2v_model.vocab_hash[start_text[-1]]]])
        else:
            x = np.array([[np.random.randint(0, self.vocab_size)]])
            seq = []

        for i in range(length):
            state, logits = session.run([self.final_state, self.logits],
                                        {self.input_data: x, self.initial_state: state})
            unnormalized_probs = np.exp(logits[0] - np.max(logits[0]))
            probs = unnormalized_probs / np.sum(unnormalized_probs)

            if rhyme_ref and i == rhyme_idx:
                sample = self.select_rhyme(rhyme_ref, probs)
            elif sample_type == SampleType.max_prob:
                sample = np.argmax(probs)
            elif sample_type == SampleType.select_given:
                sample, given = self.select_by_given(given, probs)
            else: #SampleType.weighted_sample
                sample = np.random.choice(self.vocab_size, 1, p=probs)[0]

            seq.append(self.w2v_model.vocab[sample])
            x = np.array([[sample]])

        return ''.join(seq)

    def select_by_given(self, given, probs, max_prob=False):
        if given:
            # Prefer the most probable character that still appears in `given`,
            # searching the top-100 candidates by probability.
            seq_probs = zip(probs, range(0, self.vocab_size))
            topn = heapq.nlargest(100, seq_probs, key=lambda sp: sp[0])

            for _, seq in topn:
                if self.w2v_model.vocab[seq] in given:
                    given = given.replace(self.w2v_model.vocab[seq], '')
                    return seq, given

        if max_prob:
            return np.argmax(probs), given

        return np.random.choice(self.vocab_size, 1, p=probs)[0], given

    def select_rhyme(self, rhyme_ref, probs):
        # Pick the most probable character that rhymes with rhyme_ref,
        # searching the top-50 candidates; fall back to the plain argmax.
        if rhyme_ref:
            rhyme_set = RhymeWords.get_rhyme_words(rhyme_ref)
            if rhyme_set:
                seq_probs = zip(probs, range(0, self.vocab_size))
                topn = heapq.nlargest(50, seq_probs, key=lambda sp: sp[0])

                for _, seq in topn:
                    if self.w2v_model.vocab[seq] in rhyme_set:
                        return seq

        return np.argmax(probs)

    def check_start(self, text):
        # Truncate at the first '<' marker, then keep only characters that
        # are present in the model vocabulary.
        idx = text.find('<')
        if idx > -1:
            text = text[:idx]

        valid_text = []
        for w in text:
            if w in self.w2v_model.vocab:
                valid_text.append(w)
        return ''.join(valid_text)
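

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not part of the original module).
# Assumes `w2v_model` is the project's word2vec wrapper exposing .vectors,
# .vocab and .vocab_hash, and `ckpt_path` is a hypothetical checkpoint saved
# by a prior training run.
#
#   with tf.Session() as session:
#       model = CharRNNLM(is_training=False, batch_size=1, num_unrollings=1,
#                         vocab_size=len(w2v_model.vocab), w2v_model=w2v_model,
#                         hidden_size=128, max_grad_norm=5.0,
#                         embedding_size=w2v_model.vectors.shape[1],
#                         num_layers=2, learning_rate=1e-3, cell_type='lstm',
#                         infer=True)
#       tf.train.Saver().restore(session, ckpt_path)
#       print(model.sample_seq(session, length=20, start_text=u'',
#                              sample_type=SampleType.weighted_sample))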