Вероятно, чистый способ сделать это в NumPy, особенно если у вас есть много классов, через сортировки:
SAMPLES = 50000
FEATURES = 784
CLASSES = 10
data = np.random.rand(SAMPLES, FEATURES)
classes = np.random.randint(CLASSES, size=SAMPLES)
sorter = np.argsort(classes)
classes_sorted = classes[sorter]
splitter, = np.where(classes_sorted[:-1] != classes_sorted[1:])
data_splitted = np.split(data[sorter], splitter + 1)
data_splitted
будет список массивов, один для каждого класс найден в classes
. Запуск выше кода с SAMPLES = 10
, FEATURES = 2
и CLASSES = 3
я получаю:
>>> data
array([[ 0.45813694, 0.47942962],
[ 0.96587082, 0.73260743],
[ 0.70539842, 0.76376921],
[ 0.01031978, 0.93660231],
[ 0.45434223, 0.03778273],
[ 0.01985781, 0.04272293],
[ 0.93026735, 0.40216376],
[ 0.39089845, 0.01891637],
[ 0.70937483, 0.16077439],
[ 0.45383099, 0.82074859]])
>>> classes
array([1, 1, 2, 1, 1, 2, 0, 2, 0, 1])
>>> data_splitted
[array([[ 0.93026735, 0.40216376],
[ 0.70937483, 0.16077439]]),
array([[ 0.45813694, 0.47942962],
[ 0.96587082, 0.73260743],
[ 0.01031978, 0.93660231],
[ 0.45434223, 0.03778273],
[ 0.45383099, 0.82074859]]),
array([[ 0.70539842, 0.76376921],
[ 0.01985781, 0.04272293],
[ 0.39089845, 0.01891637]])]
Если вы хотите, чтобы убедиться, что сорт является стабильным, т.е. точек данных в одном классе остается в том же относительном порядке после сортировки, вы будете необходимо указать sorter = np.argsort(classes, kind='mergesort')
.
Лучшее решение зависит от параметров, таких как размер данных (вся матрица вписывается в память?), Форма данных (массив данных numpy, строки, ....), ..... Благодаря точным данным пунктам. –