2015-08-30 9 views
3

У меня есть два массива numpy, users и dat. Для каждого пользователя в users Мне нужно найти данные, относящиеся к пользователю в dat и подсчитать количество уникальных значений. Мне нужно обработать случай, где len(users)=200000 и len(dat)=2800000. В настоящее время я не использую тот факт, что dat отсортирован, что делает метод очень медленным. Как мне это сделать?Подсчет количества уникальных значений в подмножестве отсортированной матрицы

Значение 'other' в dat просто показывает, что другие значения будут присутствовать и в структурированном массиве.

import numpy as np 

users = np.array([111, 222, 333]) 
info = np.zeros(len(users)) 
dt = [('id', np.int32), ('group', np.int16), ('other', np.float)] 
dat = np.array([(111, 1, 0.0), (111, 3, 0.0), (111, 2, 0.0), (111, 1, 0.0), 
       (222, 1, 0.0), (222, 1, 0.0), (222, 4, 0.0), 
       (333, 2, 0.0), (333, 1, 0.0), (333, 2, 0.0)], 
       dtype=dt) 

for i, u in enumerate(users): 
    u_dat = dat[np.in1d(dat['id'], u)] 
    uniq = set(u_dat['group']) 
    info[i] = int(len(uniq)) 

print info 
+0

В C, вам петлю и увеличивать свой счетчик, когда текущее значение = предыдущее значение!. Это, вероятно, не полезно здесь, поскольку элементы цикла в python обычно не так, как вы пишете быстрый код numpy. –

ответ

2

Если вы хотите получить прибыль от векторизации Numpy, это помогло бы значительно, если вы можете удалить все дубликаты из dat, прежде чем руки. Вы можете найти первое и последнее вхождение значения с двумя вызовами searchsorted:

dat_unq = np.unique(dat) 
first = dat_unq['id'].searchsorted(users, side='left') 
last = dat_unq['id'].searchsorted(users, side='right') 
info = last - first 

Это будет только полезно, если вы собираетесь искать много записей в dat. Если это меньшая часть, вы можете использовать два вызова searchsorted, чтобы выяснить, какие нарезает назвать unique на:

info = np.empty_like(users, dtype=np.intp) 
first = dat['id'].searchsorted(users, side='left') 
last = dat['id'].searchsorted(users, side='right') 
for idx, (start, stop) in enumerate(zip(first, last)): 
    info[idx] = len(np.unique(dat[start:stop]))