@@ -488,7 +488,7 @@ def device_ptr(self):
488488 Note
489489 ----
490490 - This can be used to integrate with custom C code and / or PyCUDA or PyOpenCL.
491- - No other arrays will share the same device pointer.
491+ - No other arrays will share the same device pointer.
492492 - A copy of the memory is done if multiple arrays share the same memory or the array is not the owner of the memory.
493493 - In case of a copy the return value points to the newly allocated memory which is now exclusively owned by the array.
494494 """
@@ -985,6 +985,12 @@ def __getitem__(self, key):
985985 try:
986986 out = Array()
987987 n_dims = self.numdims()
988+
989+ if (isinstance(key, Array) and key.type() == Dtype.b8.value):
990+ n_dims = 1
991+ if (count(key) == 0):
992+ return out
993+
988994 inds = _get_indices(key)
989995
990996 safe_call(backend.get().af_index_gen(ct.pointer(out.arr),
@@ -1005,9 +1011,21 @@ def __setitem__(self, key, val):
10051011 try:
10061012 n_dims = self.numdims()
10071013
1014+ is_boolean_idx = isinstance(key, Array) and key.type() == Dtype.b8.value
1015+
1016+ if (is_boolean_idx):
1017+ n_dims = 1
1018+ num = count(key)
1019+ if (num == 0):
1020+ return
1021+
10081022 if (_is_number(val)):
10091023 tdims = _get_assign_dims(key, self.dims())
1010- other_arr = constant_array(val, tdims[0], tdims[1], tdims[2], tdims[3], self.type())
1024+ if (is_boolean_idx):
1025+ n_dims = 1
1026+ other_arr = constant_array(val, int(num), dtype=self.type())
1027+ else:
1028+ other_arr = constant_array(val, tdims[0] , tdims[1], tdims[2], tdims[3], self.type())
10111029 del_other = True
10121030 else:
10131031 other_arr = val.arr
@@ -1017,8 +1035,8 @@ def __setitem__(self, key, val):
10171035 inds = _get_indices(key)
10181036
10191037 safe_call(backend.get().af_assign_gen(ct.pointer(out_arr),
1020- self.arr, ct.c_longlong(n_dims), inds.pointer,
1021- other_arr))
1038+ self.arr, ct.c_longlong(n_dims), inds.pointer,
1039+ other_arr))
10221040 safe_call(backend.get().af_release_array(self.arr))
10231041 if del_other:
10241042 safe_call(backend.get().af_release_array(other_arr))
@@ -1235,5 +1253,5 @@ def read_array(filename, index=None, key=None):
12351253
12361254 return out
12371255
1238- from .algorithm import sum
1256+ from .algorithm import ( sum, count)
12391257from .arith import cast
0 commit comments