diff --git a/.ipynb_checkpoints/data_loader-checkpoint.py b/.ipynb_checkpoints/data_loader-checkpoint.py
new file mode 100644
index 0000000..0de9a62
--- /dev/null
+++ b/.ipynb_checkpoints/data_loader-checkpoint.py
@@ -0,0 +1,262 @@
+import os
+# import collections
+from six.moves import cPickle
+import numpy as np
+from word2vec_helper import Word2Vec
+import math
+
+
+class DataLoader():
+    def __init__(self, data_dir, batch_size, seq_max_length, w2v, data_type):
+        self.data_dir = data_dir
+        self.batch_size = batch_size
+        self.seq_max_length = seq_max_length
+        self.w2v = w2v
+        self.trainingSamples = []
+        self.validationSamples = []
+        self.testingSamples = []
+        self.train_frac = 0.85
+        self.valid_frac = 0.05
+
+        self.load_corpus(self.data_dir)
+
+        if data_type == 'train':
+            self.create_batches(self.trainingSamples)
+        elif data_type == 'test':
+            self.create_batches(self.testingSamples)
+        elif data_type == 'valid':
+            self.create_batches(self.validationSamples)
+
+        self.reset_batch_pointer()
+
+    def _print_stats(self):
+        print('Loaded {}: trainingSamples: {}, validationSamples: {}, testingSamples: {}'.format(
+            self.data_dir, len(self.trainingSamples), len(self.validationSamples), len(self.testingSamples)))
+
+    def load_corpus(self, base_path):
+        """Read or create the dialogue data.
+
+        Two files are involved when the training file is created:
+        1. self.fullSamplePath
+        2. self.filteredSamplesPath
+        """
+        tensor_file = os.path.join(base_path, 'poem_ids.txt')
+        print('tensor_file:%s' % tensor_file)
+
+        datasetExist = os.path.isfile(tensor_file)
+        # If the preprocessed data file does not exist, create it
+        if not datasetExist:
+            print('Training samples not found. Creating them from the raw sample dataset...')
+
+            fullSamplesPath = os.path.join(self.data_dir, 'poems_edge_split.txt')
+            # Create/read the raw sample dataset: self.trainingSamples
+            print('fullSamplesPath:%s' % fullSamplesPath)
+            self.load_from_text_file(fullSamplesPath)
+
+        else:
+            self.load_dataset(tensor_file)
+
+        self.padToken = self.w2v.ix('')      # padding id (empty-string entry in the vocabulary)
+        self.goToken = self.w2v.ix('[')      # '[' marks the start of a poem
+        self.eosToken = self.w2v.ix(']')     # ']' marks the end of a poem
+        self.unknownToken = self.w2v.ix('')  # out-of-vocabulary words map to the same empty-string entry
+
+        self._print_stats()
+        # assert self.padToken == 0
+
+    def load_from_text_file(self, in_file):
+        # base_path = 'F:\BaiduYunDownload\chatbot_lecture\lecture2\data\ice_and_fire_zh'
+        # in_file = os.path.join(base_path,'poems_edge.txt')
+        fr = open(in_file, "r", encoding='utf-8')
+        poems = fr.readlines()
+        fr.close()
+
+        print("Total number of Tang poems: %d" % len(poems))
+        # self.seq_max_length = max([len(poem) for poem in poems])
+        # print("seq_max_length: %d" % (self.seq_max_length))
+
+        poem_ids = DataLoader.get_text_idx(poems, self.w2v.vocab_hash, self.seq_max_length)
+
+        # # Post-processing
+        # # 1. Word filtering: drop rare words (count <= filterVocab) and keep the vocabSize most frequent ones
+        # print('Filtering words (vocabSize = {} and wordCount > {})...'.format(
+        #     self.args.vocabularySize,
+        #     self.args.filterVocab
+        # ))
+        # self.filterFromFull()
+
+        # 2. Split the data
+        print('Splitting data into train, valid and test sets...')
+        n_samples = len(poem_ids)
+        train_size = int(self.train_frac * n_samples)
+        valid_size = int(self.valid_frac * n_samples)
+        test_size = n_samples - train_size - valid_size
+
+        print('n_samples=%d, train_size=%d, valid_size=%d, test_size=%d' % (
+            n_samples, train_size, valid_size, test_size))
+        self.testingSamples = poem_ids[-test_size:]
+        self.validationSamples = poem_ids[-valid_size - test_size: -test_size]
+        self.trainingSamples = poem_ids[:train_size]
+
+        # Save the preprocessed training dataset
+        print('Saving dataset...')
+        poem_ids_file = os.path.join(self.data_dir, 'poem_ids.txt')
+        self.save_dataset(poem_ids_file)
+
+    # 2. Utility function: write the dataset to disk with pickle
+    def save_dataset(self, filename):
+        """Save the data file with pickle.
+
+        The data file contains the dictionary and the dialogue samples.
+
+        Args:
+            filename (str): pickle filename
+        """
+        with open(filename, 'wb') as handle:
+            data = {
+                'trainingSamples': self.trainingSamples
+            }
+
+            if len(self.validationSamples) > 0:
+                data['validationSamples'] = self.validationSamples
+                data['testingSamples'] = self.testingSamples
+            data['maxSeqLen'] = self.seq_max_length
+
+            cPickle.dump(data, handle, -1)  # Using the highest protocol available
+
+    # 3. Utility function: read the dataset back with pickle
+    def load_dataset(self, filename):
+        """Load the data file with pickle.
+
+        Args:
+            filename (str): pickle filename
+        """
+        print('Loading dataset from {}'.format(filename))
+        with open(filename, 'rb') as handle:
+            data = cPickle.load(handle)
+            self.trainingSamples = data['trainingSamples']
+
+            if 'validationSamples' in data:
+                self.validationSamples = data['validationSamples']
+                self.testingSamples = data['testingSamples']
+
+            print('file maxSeqLen = {}'.format(data['maxSeqLen']))
+
+    # NOTE: shadowed by the fuller definition of the same name further down
+    @staticmethod
+    def get_text_idx(text, vocab, max_document_length):
+        text_array = []
+        for i, x in enumerate(text):
+            line = []
+            for j, w in enumerate(x):
+                if w not in vocab:
+                    w = ''
+                line.append(vocab[w])
+            text_array.append(line)
+            # else:
+            #     print(w, 'not exist')
+
+        return text_array
+
+    def create_batches(self, samples):
+
+        sample_size = len(samples)
+        self.num_batches = math.ceil(sample_size / self.batch_size)
+        new_sample_size = self.num_batches * self.batch_size
+
+        # Create the batch tensor
+        # x_lengths = [len(sample) for sample in samples]
+
+        x_lengths = []
+        x_seqs = np.ndarray((new_sample_size, self.seq_max_length), dtype=np.int32)
+        y_seqs = np.ndarray((new_sample_size, self.seq_max_length), dtype=np.int32)
+        self.x_lengths = []
+        for i, sample in enumerate(samples):
+            # fill with padding to align batchSize samples into one 2D list
+            x_lengths.append(len(sample))
+            x_seqs[i] = sample + [self.padToken] * (self.seq_max_length - len(sample))
+
+        # pad the sample count up to a multiple of batch_size by repeating the first samples
+        for i in range(sample_size, new_sample_size):
+            copyi = i - sample_size
+            x_seqs[i] = x_seqs[copyi]
+            x_lengths.append(x_lengths[copyi])
+
+        # targets are the inputs shifted left by one step; the last position wraps to the first token
+        y_seqs[:, :-1] = x_seqs[:, 1:]
+        y_seqs[:, -1] = x_seqs[:, 0]
+        x_len_array = np.array(x_lengths)
+
+        self.x_batches = np.split(x_seqs.reshape(self.batch_size, -1), self.num_batches, 1)
+        self.x_len_batches = np.split(x_len_array.reshape(self.batch_size, -1), self.num_batches, 1)
+        self.y_batches = np.split(y_seqs.reshape(self.batch_size, -1), self.num_batches, 1)
+
+    def next_batch_dynamic(self):
+        # returns inputs, true sequence lengths and targets for the current batch
+        x, x_len, y = self.x_batches[self.pointer], self.x_len_batches[self.pointer], self.y_batches[self.pointer]
+        self.pointer += 1
+        return x, x_len, y
+
+    def next_batch(self):
+        x, y = self.x_batches[self.pointer], self.y_batches[self.pointer]
+        self.pointer += 1
+        return x, y
+
+    def reset_batch_pointer(self):
+        self.pointer = 0
+
+    @staticmethod
+    def get_text_idx(text, vocab, max_document_length):
+        max_document_length_without_end = max_document_length - 1
+        text_array = []
+        for i, x in enumerate(text):
+            line = []
+            if len(x) > max_document_length:
+                # truncate over-long poems at the last full stop and close with the ']' end marker
+                x_parts = x[:max_document_length_without_end]
+                idx = x_parts.rfind('。')
+                if idx > -1:
+                    x_parts = x_parts[0:idx + 1] + ']'
+                x = x_parts
+
+            for j, w in enumerate(x):
+                # if j >= max_document_length:
+                #     break
+
+                if w not in vocab:
+                    w = ''
+                line.append(vocab[w])
+            text_array.append(line)
+            # else:
+            #     print(w, 'not exist')
+
+        return text_array
+
+
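+# Illustrative sketch (not part of the original loader): a typical epoch loop over
+# the batches built above. It only uses what DataLoader already defines
+# (num_batches, reset_batch_pointer, next_batch_dynamic) and is never called here.
+def example_epoch_loop(loader, num_epochs=1):
+    for epoch in range(num_epochs):
+        loader.reset_batch_pointer()
+        for _ in range(loader.num_batches):
+            # x: padded input ids, x_len: true lengths, y: targets (inputs shifted by one)
+            x, x_len, y = loader.next_batch_dynamic()
+            # feed x, x_len and y to the model here
+
+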
+if __name__ == '__main__':
+    base_path = './data/poem'
+    # poem = '风急云轻鹤背寒,洞天谁道却归难。千山万水瀛洲路,何处烟飞是醮坛。是的'
+    # idx = poem.rfind('。')
+    # poem_part = poem[:idx + 1]
+    w2v_file = os.path.join(base_path, "vectors_poem.bin")
+    w2v = Word2Vec(w2v_file)
+
+    # vect = w2v_model['['][:10]
+    # print(vect)
+    #
+    # vect = w2v_model['春'][:10]
+    # print(vect)
+
+    in_file = os.path.join(base_path, 'poems_edge.txt')
+    # fr = open(in_file, "r", encoding='utf-8')
+    # poems = fr.readlines()
+    # fr.close()
+    #
+    # print("Total number of Tang poems: %d" % len(poems))
+    #
+    # poem_ids = get_text_idx(poems, w2v.model.vocab_hash, 100)
+    # poem_ids_file = os.path.join(base_path, 'poem_ids.txt')
+    # with open(poem_ids_file, 'wb') as f:
+    #     cPickle.dump(poem_ids, f)
+
+    # seq_max_length=100 follows the commented-out get_text_idx example above
+    dataloader = DataLoader(base_path, 20, 100, w2v.model, 'train')
+
diff --git a/faas_requirements.txt b/faas_requirements.txt
new file mode 100644
index 0000000..fff8e36
--- /dev/null
+++ b/faas_requirements.txt
@@ -0,0 +1,122 @@
+absl-py==0.6.1
+alembic==1.0.3
+astor==0.7.1
+autopep8==1.4
+backcall==0.1.0
+bleach==3.0.2
+certifi==2018.10.15
+chardet==3.0.4
+cloudpickle==0.6.1
+cycler==0.10.0
+Cython==0.29
+dask==0.20.2
+decorator==4.3.0
+defusedxml==0.5.0
+entrypoints==0.2.3
+future==0.16.0
+gast==0.2.0
+grpcio==1.16.0
+h5py==2.8.0
+html5lib==1.0.1
+idna==2.7
+imageio==2.4.1
+imgaug==0.2.6
+ipykernel==5.1.0
+ipython==7.1.1
+ipython-genutils==0.2.0
+ipywidgets==7.4.2
+jedi==0.13.1
+jieba==0.39
+Jinja2==2.10
+jsonschema==2.6.0
+jupyter==1.0.0
+jupyter-client==5.2.3
+jupyter-console==6.0.0
+jupyter-core==4.4.0
+jupyterhub==0.8.1
+jupyterlab==0.31.1
+jupyterlab-launcher==0.10.5
+Keras==2.2.4
+Keras-Applications==1.0.6
+Keras-Preprocessing==1.0.5
+kiwisolver==1.0.1
+Mako==1.0.7
+Markdown==3.0.1
+MarkupSafe==1.0
+matplotlib==3.0.2
+mccabe==0.6.1
+mistune==0.8.4
+nbconvert==5.4.0
+nbformat==4.4.0
+networkx==2.2
+nltk==3.3
+notebook==5.7.0
+numpy==1.15.2
+opencv-python==3.4.3.18
+pamela==0.3.0
+pandas==0.23.4
+pandocfilters==1.4.2
+parso==0.3.1
+pbr==5.1.1
+pexpect==4.6.0
+pickleshare==0.7.5
+Pillow==5.3.0
+pluggy==0.7.1
+prometheus-client==0.4.2
+prompt-toolkit==2.0.7
+protobuf==3.6.1
+ptyprocess==0.6.0
+pycodestyle==2.4.0
+pycurl==7.43.0
+pydocstyle==2.1.1
+pydot==1.2.4
+pyflakes==2.0.0
+Pygments==2.2.0
+pygobject==3.20.0
+pyparsing==2.3.0
+python-apt==1.1.0b1+ubuntu0.16.4.2
+python-dateutil==2.7.3
+python-editor==1.0.3
+python-jsonrpc-server==0.0.1
+python-language-server==0.21.2
+python-oauth2==1.1.0
+pytz==2018.5
+PyWavelets==1.0.1
+PyYAML==3.13
+pyzmq==17.1.2
+qtconsole==4.4.2
+requests==2.20.0
+rope==0.11.0
+scikit-image==0.14.1
+scikit-learn==0.20.0
+scipy==1.1.0
+seaborn==0.9.0
+Send2Trash==1.5.0
+Shapely==1.6.4.post2
+simplegeneric==0.8.1
+six==1.11.0
+sklearn==0.0
+snowballstemmer==1.2.1
+SQLAlchemy==1.2.14
+stevedore==1.29.0
+tensorboard==1.11.0
+tensorflow==1.12.0
+termcolor==1.1.0
+terminado==0.8.1
+testpath==0.4.2
+toolz==0.9.0
+torch==0.4.1
+torchvision==0.2.1
+tornado==5.1.1
+traitlets==4.3.2
+urllib3==1.24.1
+virtualenv==16.0.0
+virtualenv-clone==0.4.0
+virtualenvwrapper==4.8.2
+wcwidth==0.1.7
+webencodings==0.5.1
+Werkzeug==0.14.1
+widgetsnbextension==3.4.2
+word2vec==0.10.2
+xlrd==1.1.0
+yapf==0.24.0