2015-09-10 4 views
2

У меня есть 2D-Numpy массив расстояний:Получить индекс для argmin из 2d массива Numpy

a = np.array([[2.0, 12.1, 99.2], 
       [1.0, 1.1, 1.2], 
       [1.04, 1.05, 1.5], 
       [4.1, 4.2, 0.2], 
       [10.0, 11.0, 12.0], 
       [3.9, 4.9, 4.99] 
      ]) 

Мне нужна функция, которая оценивает каждую строку и возвращает индекс столбца для столбца, который имеет наименьшее значение. Конечно, это может быть сделано тривиально, выполнив:

np.argmin(a, axis=1) 

, который дает:

[0, 0, 0, 2, 0, 0] 

Однако, у меня есть несколько ограничений:

  1. Оценка argmin следует рассматривать лишь отдаляет ниже значение 5.0. Если ни одно из расстояний в строке не было ниже 5.0, тогда возвратите '-1' в качестве индекса
  2. Список индексов, возвращаемых для всех строк, должен быть уникальным (т. Е. Если две или несколько строк заканчиваются одним и тем же индексом столбца, тогда строка с меньшим расстоянием до заданного индекса столбца получает приоритет, а все остальные строки должны возвращать другой индекс столбца). Я предполагаю, что это сделает проблему итеративной, поскольку, если одна из строк набит, тогда она может впоследствии столкнуться с другой строкой с тем же индексом столбца.
  3. Любые нераспределенные строки должен возвращать '-1'

Таким образом, окончательный вывод должен выглядеть следующим образом:

[-1, 0, 1, 2, -1, -1] 

Один отправной точки было бы:

  1. выполнить argsort
  2. присваивать уникальные индексы колонн
  3. удалить е присвоены индексы столбцов из каждой строки
  4. Разрешая тай-брейки
  5. повторите шаги 2-4 до тех пор, как все индексы столбцов не назначены

Есть ли простой способ сделать это в Python?

+0

Итак, в чем вопрос? – wwii

+0

Как первый элемент ожидаемого o/p a 'nan', учитывая, что первая строка имеет' 2.0' в нем, которая меньше, чем '5.0'? Или вы имеете в виду, что все элементы в строке должны быть меньше, чем '5.0'? – Divakar

+0

Я не согласен с вашим новым ожидаемым выходом. В строке '3' столбец' 2' наименьший ('0,2'), но у вас есть' 'np.argmin (a, 1)' is '1', даже без ваших ограничений. Для новых 'a',' np.argmin (a, 1) 'дает' array ([0, 0, 0, 2, 0, 0]) ', поэтому окончательный вывод должен быть' array ([-1, 0, -1, 2, -1, -1]) ', я думаю. – askewchan

ответ

0

Этот цикл по числу столбцов, которые я предполагаю, меньше, чем число строк:

def find_smallest(a): 
    i = np.argmin(a, 1) 
    amin = a[np.arange(len(a)), i] # faster than a.min(1)? 
    toobig = amin >=5 
    i[toobig] = -1 
    for u, c in zip(*np.unique(i, return_counts=True)): 
     #u, c are the unique values and number of occurrences in `i` 
     if c < 2: 
      # no repeats of this index 
      continue 
     mask = i==u # the values in i that match u, which has repeats 
     notclosest = np.where(mask)[0].tolist() # indices of the repeats 
     notclosest.pop(np.argmin(amin[mask])) # the smallest a value is not a 'repeat', remove it from the list 
     i[notclosest] = -1 # and mark all the repeats as -1 
    return i 

Обратите внимание, я использовал -1 вместо np.nan, так как массив индекса int. Любое сокращение булевской индексации поможет. Я хотел использовать один из дополнительных дополнительных выходов от np.unique(i), но не смог.

+0

Не могли бы вы предоставить немного более подробную информацию о том, что происходит в цикле for? – slaw

+0

Похоже, что ваше решение не может обрабатывать 'a = np.array ([[2.0, 12.1, 99.2], [1.0, 1.1, 1.2], [1.04, 1.05, 1.5], [4.1, 4.2, 0.2] , [10.0, 11.0, 12.0], [3.9, 4.9, 4.99] ]) 'В частности, индекс столбца, присвоенный третьей строке, будет неправильным при применении вашего метода. – slaw

+0

@slaw. Редактируйте свой вопрос с помощью этих данных образца и сообщить нам ожидаемый результат, так как этот образец будет лучше проверять эти ограничения? – Divakar