1

Я пытаюсь использовать стратегию перекрестной проверки TimeSeriesSplit в версии 0.18.1 sklearn с оценкой LogisticRegression. Я получаю ошибку о том, что:sklearn TimeSeriesSplit cross_val_predict работает только для разделов

cross_val_predict работает только для разделов

Следующий фрагмент кода показывает, как воспроизвести:

from sklearn import linear_model, neighbors 
from sklearn.model_selection import train_test_split, cross_val_predict, TimeSeriesSplit, KFold, cross_val_score 
import pandas as pd 
import numpy as np 
from datetime import date, datetime 

df = pd.DataFrame(data=np.random.randint(0,10,(100,5)), index=pd.date_range(start=date.today(), periods=100), columns='x1 x2 x3 x4 y'.split()) 


X, y = df['x1 x2 x3 x4'.split()], df['y'] 
score = cross_val_score(linear_model.LogisticRegression(fit_intercept=True), X, y, cv=TimeSeriesSplit(n_splits=2)) 
y_hat = cross_val_predict(linear_model.LogisticRegression(fit_intercept=True), X, y, cv=TimeSeriesSplit(n_splits=2), method='predict_proba') 

Что я делаю не так?

ответ

5

Есть несколько способов передать аргумент cv в cross_val_score. Здесь вам нужно передать генератор для расщепления. Например,

y = range(14) 
cv = TimeSeriesSplit(n_splits=2).split(y) 

дает генератор. С помощью этого вы можете генерировать CV и тестовые массивы. Первый выглядит следующим образом:

print cv.next() 
    (array([0, 1, 2, 3, 4, 5, 6, 7]), array([ 8, 9, 10, 11, 12, 13])) 

Вы также можете взять dataframe в качестве входных данных для split.

df = pd.DataFrame(data=np.random.randint(0,10,(100,5)), 
        index=pd.date_range(start=date.today(), 
        periods=100), columns='x1 x2 x3 x4 y'.split()) 

cv = TimeSeriesSplit(n_splits=2).split(df) 
print cv.next() 
    (array([ 0, 1, 2, ..., 31, 32, 33]), array([34, 35, 36, ..., 64, 65, 66])) 

В вашем случае это должно работать:

score = cross_val_score(linear_model.LogisticRegression(fit_intercept=True), 
         X, y, cv=TimeSeriesSplit(n_splits=2).split(df)) 

Посмотрите cross_val_score и TimeSeriesSplit для деталей.

+0

Что такое '14' здесь, в' range (14) '? Это какое-то произвольное число? То же самое с значениями dataframe. Мне не очень понятно, как вы пришли к этим ценностям. – keithhackbarth

+0

14 является произвольным и, как указано, примером. Что касается кадра данных: посмотрите на вопрос, я просто скопировал его. – glao