Wrap Statsmodels In Sklearn

21 May 2018

So, you want to use scikit-learn’s cross-validation framework with a statsmodels model written in formula syntax? No problem.

David Dale has an excellent Stack Overflow answer on this, but it targets the standard statsmodels.api. Here’s a slight adaptation for statsmodels.formula.api:

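from sklearn.base import BaseEstimator, RegressorMixin

# Mixins go to the left of BaseEstimator, as the sklearn developer guide recommends
class SMFormulaWrapper(RegressorMixin, BaseEstimator):
    """ A sklearn-style wrapper for formula-based statsmodels regressors """
    def __init__(self, model_class, formula):
        self.model_class = model_class  # e.g. smf.ols
        self.formula = formula
    def fit(self, X, y=None):
        # X is a DataFrame holding every column named in the formula,
        # including the response; y is accepted only for API compatibility
        self.model_ = self.model_class(self.formula, data=X)
        self.results_ = self.model_.fit()
        return self  # sklearn convention: fit returns the estimator
    def predict(self, X):
        return self.results_.predict(X)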

Then you can do something like:

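import statsmodels.formula.api as smf
from sklearn.model_selection import cross_val_score

# Xy is a DataFrame with the predictors and the response 'y' as columns
formula = 'y ~ numeric_var + categorical_var1'
cv_ols = cross_val_score(SMFormulaWrapper(smf.ols, formula), Xy, Xy['y'], cv=3)

fit never uses the y argument, since that information is already included in formula. Pass the response column as y anyway: the default scorer calls RegressorMixin.score, which needs the true values to compute R².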
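If you would rather not pass y to cross_val_score at all, one option is to give the wrapper its own score method that pulls the response column out of the data. Here is a minimal, self-contained sketch of that idea. The class name SMFormulaWrapperWithScore, the synthetic data, and the trick of reading the response name from the left-hand side of the formula are all assumptions made for illustration; the string split only works when the left-hand side is a plain column name, not a transformed term.

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.metrics import r2_score
from sklearn.model_selection import cross_val_score


class SMFormulaWrapperWithScore(RegressorMixin, BaseEstimator):
    """Hypothetical variant of the wrapper that can score without an explicit y."""

    def __init__(self, model_class, formula):
        self.model_class = model_class
        self.formula = formula

    def fit(self, X, y=None):
        # The response is read from X via the formula; y is ignored.
        self.model_ = self.model_class(self.formula, data=X)
        self.results_ = self.model_.fit()
        return self

    def predict(self, X):
        return self.results_.predict(X)

    def score(self, X, y=None):
        # Assumption: the formula's left-hand side is a plain column name.
        if y is None:
            y = X[self.formula.split('~')[0].strip()]
        return r2_score(y, self.predict(X))


# Synthetic data, invented purely for illustration.
rng = np.random.default_rng(0)
n = 90
Xy = pd.DataFrame({
    'numeric_var': rng.normal(size=n),
    # Tiled so that every contiguous fold contains all three levels.
    'categorical_var1': np.tile(['a', 'b', 'c'], n // 3),
})
Xy['y'] = (
    2.0 * Xy['numeric_var']
    + (Xy['categorical_var1'] == 'b').astype(float)
    + rng.normal(scale=0.1, size=n)
)

formula = 'y ~ numeric_var + categorical_var1'
cv_ols = cross_val_score(SMFormulaWrapperWithScore(smf.ols, formula), Xy, None, cv=3)
print(cv_ols)  # one R² value per fold
```

One thing to watch with categorical predictors: if a test fold contains a level that never appeared in the training fold, prediction will fail, so a shuffled or stratified splitter may be needed on real data.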