Chinmaya’s GSoC 2017 Summary: Integration with sklearn & Keras and implementing fastText

Chinmaya Pancholi gensim, Google Summer of Code, Student Incubator

This blog summarizes the work that I did for Google Summer of Code 2017 with Gensim. My work during the summer was divided into two parts: integrating Gensim with scikit-learn & Keras and adding a Python implementation of the fastText model to Gensim.

Gensim integration with scikit-learn and Keras

Gensim is a topic modelling and information extraction library that mainly serves unsupervised tasks. However, in many cases, to usefully apply it to a real business problem, the generated output must be fed to a supervised classifier. Since the most popular supervised learning packages currently are scikit-learn (for simpler data analysis) and Keras (for artificial neural networks), this project aimed to create wrappers for scikit-learn and Keras around all Gensim models. This provided a new interface for using Gensim models and also allowed seamless integration of Gensim with these two libraries.

For providing an API for scikit-learn, I created wrappers for the following Gensim models:

  • Latent Semantic Indexing
  • Latent Dirichlet Allocation
  • Random Projections
  • Author-Topic
  • LDA Seq
  • Word2Vec
  • Doc2Vec
  • Text-to-BagOfWords
  • Term Frequency-Inverse Document Frequency
  • Hierarchical Dirichlet Process
  • Phrases

As acknowledged here, the wrappers make using Gensim models very convenient. They also make it very easy to use sklearn constructs like GridSearchCV and Pipeline with Gensim models. For instance, you can now use the LDA and Text-to-BOW models in an sklearn Pipeline as follows:


import numpy

from sklearn.pipeline import Pipeline
from sklearn import linear_model
from sklearn.datasets import fetch_20newsgroups

from gensim.sklearn_api.ldamodel import LdaTransformer
from gensim.sklearn_api.text2bow import Text2BowTransformer

# load a three-category subset of the 20NewsGroups dataset
data = fetch_20newsgroups(subset='train', categories=['alt.atheism', 'comp.graphics', 'sci.space'])

# chain text -> bag-of-words -> LDA topic vectors -> logistic regression
text2bow_model = Text2BowTransformer()
lda_model = LdaTransformer(num_topics=2, passes=10, minimum_probability=0, random_state=numpy.random.RandomState(0))
clf = linear_model.LogisticRegression(penalty='l2', C=0.1)
text_lda = Pipeline([('bow_model', text2bow_model), ('ldamodel', lda_model), ('classifier', clf)])
text_lda.fit(data.data, data.target)
score = text_lda.score(data.data, data.target)

For more examples of using these transformers, you can check out this notebook.
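
The same pipeline also plugs directly into GridSearchCV. Below is a minimal sketch continuing from the text_lda pipeline above; the parameter values are only illustrative.

from sklearn.model_selection import GridSearchCV

# search over the number of LDA topics inside the same pipeline
param_grid = {'ldamodel__num_topics': [2, 3, 5]}
grid_search = GridSearchCV(text_lda, param_grid, cv=3)
grid_search.fit(data.data, data.target)
print(grid_search.best_params_)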

For integration with Keras, the most important model that we were looking to integrate was the Word2Vec model. This wrapper allows you to use Gensim’s Word2Vec model as part of your Keras model and perform various tasks like computing word similarity and predicting the classes of input words & phrases. Here is an example of employing this integration for a classification task using the 20NewsGroups dataset.


import numpy as np

from gensim.models import word2vec

from sklearn.datasets import fetch_20newsgroups

import keras
from keras.engine import Input
from keras.models import Model
from keras.layers.merge import dot
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.utils.np_utils import to_categorical
from keras.layers import Dense, Flatten
from keras.layers import Conv1D, MaxPooling1D

model_twenty_ng = word2vec.Word2Vec(min_count=1)

MAX_SEQUENCE_LENGTH = 1000

# Prepare text samples and their labels

# Processing text dataset
texts = [] # list of text samples
texts_w2v = [] # used to train the word embeddings
labels = [] # list of label ids

data = fetch_20newsgroups(subset='train', categories=['alt.atheism', 'comp.graphics', 'sci.space'])
for index in range(len(data.data)):
    label_id = data.target[index]
    file_data = data.data[index]
    i = file_data.find('\n\n') # skip the header
    if i > 0:
        file_data = file_data[i:]
    try:
        curr_str = str(file_data)
        sentence_list = curr_str.split('\n')
        for sentence in sentence_list:
            sentence = (sentence.strip()).lower()
            texts.append(sentence)
            texts_w2v.append(sentence.split(' '))
            labels.append(label_id)
    except Exception:
        pass

# Vectorize the text samples into a 2D integer tensor
tokenizer = Tokenizer()
tokenizer.fit_on_texts(texts)
sequences = tokenizer.texts_to_sequences(texts)

# word_index = tokenizer.word_index
data = pad_sequences(sequences, maxlen=MAX_SEQUENCE_LENGTH)
labels = to_categorical(np.asarray(labels))

x_train = data
y_train = labels

# prepare the embedding layer using the wrapper
Keras_w2v = model_twenty_ng
Keras_w2v.build_vocab(texts_w2v)
Keras_w2v.train(texts_w2v, total_examples=Keras_w2v.corpus_count, epochs=Keras_w2v.iter)
Keras_w2v_wv = Keras_w2v.wv
embedding_layer = Keras_w2v_wv.get_embedding_layer()

# create a 1D convnet to solve our classification task
sequence_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32')
embedded_sequences = embedding_layer(sequence_input)
x = Conv1D(128, 5, activation='relu')(embedded_sequences)
x = MaxPooling1D(5)(x)
x = Conv1D(128, 5, activation='relu')(x)
x = MaxPooling1D(5)(x)
x = Conv1D(128, 5, activation='relu')(x)
x = MaxPooling1D(35)(x) # global max pooling
x = Flatten()(x)
x = Dense(128, activation='relu')(x)
preds = Dense(y_train.shape[1], activation='softmax')(x)

model = Model(sequence_input, preds)
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['acc'])
fit_ret_val = model.fit(x_train, y_train, epochs=1)
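
The wrapper can also be used for the word-similarity task mentioned above, by reusing the same embedding layer in a small Keras graph. Below is a minimal sketch continuing from the embedding_layer and Keras_w2v_wv objects defined above; the query words are only illustrative and assumed to be in the vocabulary.

# build a tiny Keras graph that outputs the cosine similarity of two word vectors
input_a = Input(shape=(1,), dtype='int32', name='input_a')
input_b = Input(shape=(1,), dtype='int32', name='input_b')
embedding_a = embedding_layer(input_a)
embedding_b = embedding_layer(input_b)
similarity = dot([embedding_a, embedding_b], axes=-1, normalize=True)  # cosine similarity

similarity_model = Model(inputs=[input_a, input_b], outputs=similarity)

# look up the vocabulary indices of two words and compare their embeddings
idx_a = np.asarray([[Keras_w2v_wv.vocab['space'].index]])
idx_b = np.asarray([[Keras_w2v_wv.vocab['nasa'].index]])
print(similarity_model.predict([idx_a, idx_b]))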

More such examples can be seen in this tutorial notebook.

Since this part of the work was a joint project with the library shorttext, I made several changes to shorttext’s codebase as well. This included updating the CNNWordEmbed, DoubleCNNWordEmbed & CLSTMWordEmbed classes to add a with_gensim=True codepath, as per the changes made in PR #1248, and setting up Travis CI for the library’s GitHub repository.

Here is an example showing the change made to the CNNWordEmbed class:


import os
import shorttext

# download w2v model
os.system('wget https://raw.githubusercontent.com/chinmayapancholi13/shorttext_test_data/master/test_w2v_model')

w2v_model = shorttext.utils.load_word2vec_model('test_w2v_model', binary=False) # load word2vec model
trainclass_dict = shorttext.data.subjectkeywords() # load training data

# with_gensim = False
# create keras model using `CNNWordEmbed` class
keras_model_without_gensim = shorttext.classifiers.frameworks.CNNWordEmbed(wvmodel=w2v_model, nb_labels=len(trainclass_dict.keys()), vecsize=100, with_gensim=False)
# create and train classifier using keras model constructed above
main_classifier = shorttext.classifiers.VarNNEmbeddedVecClassifier(w2v_model, with_gensim=False, vecsize=100)
main_classifier.train(trainclass_dict, keras_model_without_gensim, nb_epoch=2)
# compute classification score
score_vals = main_classifier.score('artificial intelligence')

# with_gensim = True
# create keras model using `CNNWordEmbed` class
keras_model_with_gensim = shorttext.classifiers.frameworks.CNNWordEmbed(wvmodel=w2v_model, nb_labels=len(trainclass_dict.keys()), vecsize=100, with_gensim=True)
# create and train classifier using keras model constructed above
main_classifier = shorttext.classifiers.VarNNEmbeddedVecClassifier(w2v_model, with_gensim=True, vecsize=100)
main_classifier.train(trainclass_dict, keras_model_with_gensim, nb_epoch=2)
# compute classification score
score_vals = main_classifier.score('artificial intelligence')

Implementing fastText in Gensim

fastText is a word-embedding and classification library released recently by Facebook Research, which performs better than Word2Vec on syntactic tasks and trains much faster for supervised text classification. fastText’s training architecture is an extension of Word2Vec’s: rather than learning just one vector per vocabulary word, it also takes into account the character n-grams of each word. Having vectors for n-grams also allows us to obtain word vectors for out-of-vocabulary words. Facebook’s original fastText implementation is in C++, and the aim of this project was to add a Python implementation to Gensim. This lets us continue training previously saved fastText models after loading them, and makes it easy to work with small input examples without using unnecessary memory, both of which are currently problems with the original C++ implementation.

As can be seen in this PR, for the Python implementation I created a new class, FastText, which is a subclass of the Word2Vec class and thus directly inherits and reuses several of its parameters and functions. The main work involved was writing the functions to train the fastText model: computing the character n-grams for the words in the vocabulary, and implementing backpropagation for the n-gram vectors under the various training modes (skipgram, continuous bag-of-words, hierarchical softmax and negative sampling). Since the backpropagation architecture in fastText remains the same as in Word2Vec, I could modify the existing functions train_cbow_pair() and train_sg_pair() to add a separate codepath for fastText. I also reused some functions, such as compute_ngrams() and ft_hash(), from Gensim’s existing wrapper for fastText’s C++ code.
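
To give an idea of what these character n-grams look like, here is a small standalone sketch; the real logic lives in compute_ngrams(), and this helper only approximates its behaviour for the default n-gram range of 3 to 6.

def char_ngrams(word, min_n=3, max_n=6):
    """Roughly what fastText's character n-grams of a word look like."""
    extended = '<' + word + '>'  # fastText pads each word with angle brackets
    ngrams = []
    for n in range(min_n, max_n + 1):
        for i in range(len(extended) - n + 1):
            ngrams.append(extended[i:i + n])
    return ngrams

print(char_ngrams('night'))
# ['<ni', 'nig', 'igh', 'ght', 'ht>', '<nig', 'nigh', 'ight', 'ght>', '<nigh', ...]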

To check the correctness of my implementation, I used three functions to compare the results of the Python code with the original C++ code: accuracy(), evaluate_word_pairs() and most_similar(). The values below were obtained after training on the 100 MB text8 corpus.
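
For reference, here is roughly how the first two comparisons can be invoked on a trained model (a hedged sketch; the evaluation file names below are placeholders for the standard analogy and word-similarity test sets).

# `model` is a trained FastText (or Word2Vec) model; file paths are placeholders
analogy_results = model.wv.accuracy('questions-words.txt')
pearson, spearman, oov_ratio = model.wv.evaluate_word_pairs('wordsim353.tsv')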

Results for accuracy() function

Training mode    Semantic accuracy (Facebook)    Semantic accuracy (Gensim)    Syntactic accuracy (Facebook)    Syntactic accuracy (Gensim)
sg, neg          4.82%                           5.95%                         57.86%                           59.83%
sg, hs           12.99%                          13.16%                        60.89%                           60.18%
cbow, neg        3.73%                           4.19%                         62.82%                           64.61%
cbow, hs         10.14%                          7.99%                         63.92%                           64.97%

Results for evaluate_word_pairs() function (correlation values and p-values rounded)

Training mode    Pearson r (Facebook)     Pearson r (Gensim)       Spearman rho (Facebook)    Spearman rho (Gensim)
sg, neg          0.4064 (p = 2.2e-15)     0.4375 (p = 7.6e-18)     0.4147 (p = 5.1e-16)       0.4493 (p = 7.7e-19)
sg, hs           0.5223 (p = 5.8e-26)     0.4730 (p = 5.8e-21)     0.5287 (p = 1.2e-26)       0.4742 (p = 4.5e-21)
cbow, neg        0.3671 (p = 1.2e-12)     0.4259 (p = 6.7e-17)     0.3729 (p = 5.1e-13)       0.4264 (p = 6.1e-17)
cbow, hs         0.4819 (p = 8.2e-22)     0.4503 (p = 6.2e-19)     0.4855 (p = 3.7e-22)       0.4535 (p = 3.3e-19)

Results for most_similar() function
The top-10 most similar words for the input word 'night', using the models trained in (cbow, neg) mode, were:

  • For the Gensim implementation:
    [(u'midnight', 0.9214520454406738), (u'nightjar', 0.8952612280845642), (u'tonight', 0.8734667897224426), (u'nighthawk', 0.8727679252624512), (u'nightbreed', 0.8692173361778259), (u'nightfall', 0.8459283709526062), (u'nightmare', 0.8459077477455139), (u'nighttime', 0.8353838920593262), (u'mcknight', 0.8227508068084717), (u'nightjars', 0.8224337697029114)]
  • For the Facebook implementation:
    [(u'midnight', 0.9323179721832275), (u'nightjar', 0.9195586442947388), (u'nighthawk', 0.8968080282211304), (u'nightfall', 0.8818791508674622), (u'mcknight', 0.8758728504180908), (u'nightbreed', 0.8738420009613037), (u'tonight', 0.8719567656517029), (u'nightmare', 0.857421875), (u'nightjars', 0.8562690019607544), (u'nighttime', 0.8551853895187378)]

As can be observed, all the top-10 words are the same in the two implementations (only the ordering and similarity values differ slightly), so we get very good overlap for most_similar() queries.

Here is an example of training the model and using the most_similar() function:


from gensim.models.fasttext import FastText

sentences = [
 ['human', 'interface', 'computer'],
 ['survey', 'user', 'computer', 'system', 'response', 'time'],
 ['eps', 'user', 'interface', 'system'],
 ['system', 'human', 'system', 'eps'],
 ['user', 'response', 'time'],
 ['trees'],
 ['graph', 'trees'],
 ['graph', 'minors', 'trees'],
 ['graph', 'minors', 'survey']
]

model = FastText(size=10, min_count=1, hs=1, negative=0)
model.build_vocab(sentences)

model.train(sentences, total_examples=model.corpus_count, epochs=model.iter)
sims = model.most_similar('graph', topn=10)
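
Since the word vectors are composed from character n-gram vectors, we can also query words that never appeared in the training data. Below is a small sketch continuing from the model trained above; the query word is only illustrative.

print('graphs' in model.wv.vocab)   # False: 'graphs' itself is out-of-vocabulary
oov_vector = model['graphs']        # still returns a vector, built from its character n-grams
print(oov_vector.shape)             # (10,) -- same dimensionality as in-vocabulary words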

For more examples, you can check out this tutorial notebook for fastText.

You can find all my PRs for Gensim here and for shorttext here. I have also been maintaining a live blog for GSoC here, where I posted weekly updates. Although the GSoC period is now over, I plan to keep working with Gensim on Cythonising the Python implementation of fastText.

Last but not least, I want to thank all my mentors – Lev, Ivan, Jayant, Stephen and Radim – for guiding me throughout the GSoC period. This was a great experience for me, where I not only learned a lot of new things but also realised how much there still is to learn. 🙂