# Branch: master
# Path: .ipynb_checkpoints/train-checkpoint.py
#
# File: train-checkpoint.py @ revision 6135863

import codecs
import json
import logging
import os
import shutil
import sys
import time
import numpy as np
import tensorflow as tf
from char_rnn_model import CharRNNLM
from config_poem import config_poem_train
from data_loader import DataLoader
from word2vec_helper import Word2Vec

# Put this script's directory at the front of the module search path.
# NOTE(review): this runs after the project-local imports above, so it cannot
# help *those* resolve; it only affects imports performed after this line.
sys.path.insert(0, os.path.dirname(__file__))


def _tf_version_code(version_string):
    """Encode a dotted TensorFlow version string as ``major * 100 + minor``.

    The previous encoding kept only the minor component. That made
    TensorFlow 1.0-1.7 compare as older than 0.8 (e.g. ``'1.4.0'`` -> 4).
    Folding in the major component keeps every 0.x value unchanged
    (``'0.12.1'`` -> 12) and orders later releases correctly
    (``'1.4.0'`` -> 104).

    Args:
        version_string: a version such as ``tf.__version__``; must contain at
            least ``major.minor`` with integer components.

    Returns:
        int: a version code that compares like the release order.
    """
    major, minor = version_string.split('.')[:2]
    return int(major) * 100 + int(minor)


# Comparable TensorFlow version code; ``TF_VERSION >= 8`` means "0.8 or newer".
TF_VERSION = _tf_version_code(tf.__version__)


def _ensure_parent_dirs(*paths):
    """Create the parent directory of every path in ``paths`` if it is missing.

    Existing directories and their contents are left untouched.
    """
    for path in paths:
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)


def _write_result(result, output_dir):
    """Write ``result`` as JSON to ``<output_dir>/result.json``.

    The data is first written to a temporary file and then moved into place
    with ``os.replace``. A crash mid-write therefore never leaves a truncated
    result.json, which is the file a later resume reads from.
    """
    result_path = os.path.join(output_dir, 'result.json')
    tmp_path = result_path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(result, f, indent=2, sort_keys=True)
    os.replace(tmp_path, result_path)


def main(args=''):
    """Train the CharRNNLM poem model, track the best checkpoint, then test it.

    With a non-empty ``init_dir``, training resumes from the ``result.json``
    stored there: hyper-parameters, latest/best checkpoints and best
    validation perplexity. All outputs are then written into ``init_dir``.
    Otherwise a fresh model is trained with the hyper-parameters from the
    command line.

    After every epoch the latest checkpoint is saved and the model is scored
    on the validation set. The best checkpoint (lowest validation perplexity)
    is kept separately. Whenever validation perplexity does not improve, the
    learning rate is halved. ``result.json`` is rewritten after every epoch
    and once more on exit, including when training fails.

    Args:
        args: command-line style argument string passed to
            ``config_poem_train``.
    """
    args = config_poem_train(args)
    resuming = bool(args.init_dir)

    # When resuming, every output goes into the directory being resumed from.
    # This must happen before the checkpoint paths are derived from output_dir.
    if resuming:
        args.output_dir = args.init_dir

    # Locations for the latest model, the best model and the TensorBoard log.
    args.save_model = os.path.join(args.output_dir, 'save_model/model')
    args.save_best_model = os.path.join(args.output_dir, 'best_model/model')
    timestamp = str(int(time.time()))
    args.tb_log_dir = os.path.abspath(
        os.path.join('./results/tb_results', "tensorboard_log", timestamp))
    print("Writing to {}\n".format(args.tb_log_dir))

    # tf.train.Saver.save raises if the checkpoint's parent directory does not
    # exist, and result.json needs output_dir. Nothing existing is removed.
    _ensure_parent_dirs(args.save_model, args.save_best_model,
                        os.path.join(args.output_dir, 'result.json'))

    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s:%(message)s',
                        level=logging.INFO, datefmt='%I:%M:%S')

    print('=' * 60)
    print('All final and intermediate outputs will be stored in %s/' % args.output_dir)
    print('=' * 60 + '\n')

    logging.info('args are:\n%s', args)

    # Bookkeeping persisted to result.json.
    result = {}
    best_model = ''
    best_valid_ppl = float('inf')
    if resuming:
        with open(os.path.join(args.init_dir, 'result.json'), 'r', encoding='utf-8') as f:
            previous = json.load(f)
        params = previous['params']
        args.init_model = previous['latest_model']
        best_model = previous['best_model']
        best_valid_ppl = previous['best_valid_ppl']
        args.encoding = previous.get('encoding', 'utf-8')
        # Carry the resume information forward so that a failure before the
        # first epoch completes does not overwrite result.json with params only.
        for key in ('latest_model', 'best_model', 'best_valid_ppl', 'encoding'):
            if key in previous:
                result[key] = previous[key]
    else:
        params = {'batch_size': args.batch_size,
                  'num_unrollings': args.num_unrollings,
                  'hidden_size': args.hidden_size,
                  'max_grad_norm': args.max_grad_norm,
                  'embedding_size': args.embedding_size,
                  'num_layers': args.num_layers,
                  'learning_rate': args.learning_rate,
                  'cell_type': args.cell_type,
                  'dropout': args.dropout,
                  'input_dropout': args.input_dropout}
    result['params'] = params
    logging.info('Parameters are:\n%s\n', json.dumps(params, sort_keys=True, indent=4))

    # Create batch generators.
    batch_size = params['batch_size']
    num_unrollings = params['num_unrollings']

    base_path = args.data_path
    w2v = Word2Vec(os.path.join(base_path, "vectors_poem.bin"))

    train_data_loader = DataLoader(base_path, batch_size, num_unrollings, w2v.model, 'train')
    test_data_loader = DataLoader(base_path, batch_size, num_unrollings, w2v.model, 'test')
    valid_data_loader = DataLoader(base_path, batch_size, num_unrollings, w2v.model, 'valid')

    # Build one graph holding the training, validation and evaluation models.
    # After the training model is built, reuse is switched on for the current
    # variable scope, so the later models share its variables.
    logging.info('Creating graph')
    graph = tf.Graph()
    with graph.as_default():
        w2v_vocab_size = len(w2v.model.vocab)
        with tf.name_scope('training'):
            train_model = CharRNNLM(is_training=True, w2v_model=w2v.model,
                                    vocab_size=w2v_vocab_size, infer=False, **params)
            tf.get_variable_scope().reuse_variables()

        with tf.name_scope('validation'):
            valid_model = CharRNNLM(is_training=False, w2v_model=w2v.model,
                                    vocab_size=w2v_vocab_size, infer=False, **params)

        with tf.name_scope('evaluation'):
            test_model = CharRNNLM(is_training=False, w2v_model=w2v.model,
                                   vocab_size=w2v_vocab_size, infer=False, **params)
            saver = tf.train.Saver(name='model_saver')
            best_model_saver = tf.train.Saver(name='best_model_saver')

    logging.info('Start training\n')

    # Defined up front so the summary below also works when num_epochs == 0.
    saved_path = args.init_model

    try:
        with tf.Session(graph=graph) as session:
            # TensorFlow 0.8 changed the summary-writer API to take the graph
            # instead of the graph_def.
            graph_info = session.graph if TF_VERSION >= 8 else session.graph_def

            train_writer = tf.summary.FileWriter(
                os.path.join(args.tb_log_dir, "summaries", "train"), graph_info)
            valid_writer = tf.summary.FileWriter(
                os.path.join(args.tb_log_dir, "summaries", "valid"), graph_info)

            # Load a saved model or start from random initialization.
            if args.init_model:
                saver.restore(session, args.init_model)
            else:
                tf.global_variables_initializer().run()

            learning_rate = args.learning_rate
            for epoch in range(args.num_epochs):
                logging.info('=' * 19 + ' Epoch %d ' + '=' * 19 + '\n', epoch)
                logging.info('Training on training set')
                ppl, train_summary_str, global_step = train_model.run_epoch(
                    session, train_data_loader, is_training=True,
                    learning_rate=learning_rate, verbose=args.verbose,
                    freq=args.progress_freq)
                train_writer.add_summary(train_summary_str, global_step)
                train_writer.flush()

                saved_path = saver.save(session, args.save_model,
                                        global_step=train_model.global_step)
                logging.info('Latest model saved in %s\n', saved_path)

                logging.info('Evaluate on validation set')
                valid_ppl, valid_summary_str, _ = valid_model.run_epoch(
                    session, valid_data_loader, is_training=False,
                    learning_rate=learning_rate, verbose=args.verbose,
                    freq=args.progress_freq)

                # Keep the best checkpoint; halve the learning rate when
                # validation perplexity does not improve.
                if not best_model or valid_ppl < best_valid_ppl:
                    best_model = best_model_saver.save(
                        session, args.save_best_model,
                        global_step=train_model.global_step)
                    best_valid_ppl = valid_ppl
                else:
                    learning_rate /= 2.0
                    logging.info('Decay the learning rate: %s', learning_rate)

                valid_writer.add_summary(valid_summary_str, global_step)
                valid_writer.flush()

                logging.info('Best model is saved in %s', best_model)
                logging.info('Best validation ppl is %f\n', best_valid_ppl)

                result['latest_model'] = saved_path
                result['best_model'] = best_model
                # float() because numpy floats are not JSON serializable.
                result['best_valid_ppl'] = float(best_valid_ppl)
                _write_result(result, args.output_dir)

            logging.info('Latest model is saved in %s', saved_path)
            logging.info('Best model is saved in %s', best_model)
            logging.info('Best validation ppl is %f\n', best_valid_ppl)

            if best_model:
                logging.info('Evaluate the best model on test set')
                saver.restore(session, best_model)
                test_ppl, _, _ = test_model.run_epoch(
                    session, test_data_loader, is_training=False,
                    learning_rate=learning_rate, verbose=args.verbose,
                    freq=args.progress_freq)
                result['test_ppl'] = float(test_ppl)
            else:
                logging.warning('No best model available; skipping test evaluation')
    except Exception as e:
        # Best-effort: the failure is reported (with traceback) and whatever
        # progress was made is still persisted by the finally clause.
        logging.exception('err :%s', e)
    finally:
        _write_result(result, args.output_dir)


if __name__ == '__main__':
    # Fixed training command line used when this file is executed directly.
    args = ('--output_dir ./results/output_poem '
            '--data_path ./datasets/yangsaisai-poetrydatasets-0_0_1/  '
            '--hidden_size 128 --embedding_size 128 --cell_type lstm')
    main(args=args)