# Licensed under a 3-clause BSD style license - see LICENSE.rst
import numpy as np
def _searchsorted(array, val, side="left"):
"""
Call np.searchsorted or use a custom binary
search if necessary.
"""
if hasattr(array, "searchsorted"):
return array.searchsorted(val, side=side)
# Python binary search
begin = 0
end = len(array)
while begin < end:
mid = (begin + end) // 2
if val > array[mid]:
begin = mid + 1
elif val < array[mid]:
end = mid
elif side == "right":
begin = mid + 1
else:
end = mid
return begin
class SortedArray:
    """
    Implements a sorted array container using a table of sorted key
    columns together with a parallel array of row numbers.

    Entries are kept in lexicographic order of their key columns. When the
    index is not unique, the row number acts as a final tie-breaking key,
    so entries with equal keys are ordered by row number.

    Parameters
    ----------
    data : Table
        Sorted columns of the original table
    row_index : Column object
        Row numbers corresponding to data columns
    unique : bool
        Whether the values of the index must be unique.
        Defaults to False.
    """

    def __init__(self, data, row_index, unique=False):
        self.data = data
        self.row_index = row_index
        # Data without ``colnames`` contributes no key columns.
        self.num_cols = len(getattr(data, "colnames", []))
        self.unique = unique

    @property
    def cols(self):
        """List of the key columns stored in ``data``."""
        return list(self.data.columns.values())

    def add(self, key, row):
        """
        Add a new entry to the sorted array.

        Parameters
        ----------
        key : tuple
            Column values at the given row
        row : int
            Row number

        Raises
        ------
        ValueError
            If the index is unique and ``key`` is already present.
        """
        pos = self.find_pos(key, row)  # first entry >= (key, row)
        if (
            self.unique
            and 0 <= pos < len(self.row_index)
            and all(self.data[pos][i] == key[i] for i in range(len(key)))
        ):
            # already exists
            raise ValueError(f'Cannot add duplicate value "{key}" in a unique index')
        self.data.insert_row(pos, key)
        if hasattr(self.row_index, "insert"):
            # ``insert`` returns a new object rather than modifying in place.
            self.row_index = self.row_index.insert(pos, row)
        else:
            # sort() and replace_rows() store a plain ndarray, which has no
            # ``insert`` method; use the equivalent numpy function instead.
            self.row_index = np.insert(self.row_index, pos, row)

    def _get_key_slice(self, i, begin, end):
        """
        Retrieve the ith slice of the sorted array
        from begin to end.

        Indices below ``num_cols`` select a key column; index ``num_cols``
        selects the row numbers, which act as the final tie-breaker in a
        non-unique index.
        """
        if i < self.num_cols:
            return self.cols[i][begin:end]
        else:
            return self.row_index[begin:end]

    def find_pos(self, key, data, exact=False):
        """
        Return the position of the first entry greater than or equal to the
        given (key, row) pair, i.e. where that pair would be inserted to keep
        the array sorted.

        If the index is unique, the row number ``data`` is ignored and only
        ``key`` is compared.

        Parameters
        ----------
        key : tuple
            Column key
        data : int
            Row number
        exact : bool
            If True, return the index of the given key in data
            or -1 if the key is not present.

        Returns
        -------
        pos : int
            Insertion position in ``[0, len(row_index)]``, or -1 when
            ``exact`` is True and no matching entry exists.
        """
        begin = 0
        end = len(self.row_index)
        num_cols = self.num_cols
        if not self.unique:
            # consider the row value as well
            key = key + (data,)
            num_cols += 1
        # Search through keys in lexicographic order, narrowing [begin, end)
        # to the run of entries matching every component seen so far.
        for i in range(num_cols):
            key_slice = self._get_key_slice(i, begin, end)
            t = _searchsorted(key_slice, key[i])
            # t is the smallest index >= key[i]
            if exact and (t == len(key_slice) or key_slice[t] != key[i]):
                # no match
                return -1
            elif t == len(key_slice) or (
                t == 0 and len(key_slice) > 0 and key[i] < key_slice[0]
            ):
                # too small or too large
                return begin + t
            end = begin + _searchsorted(key_slice, key[i], side="right")
            # t < len(key_slice) here, so begin stays below len(row_index).
            begin += t
        return begin

    def find(self, key):
        """
        Find all rows matching the given key.

        Parameters
        ----------
        key : tuple
            Column values

        Returns
        -------
        matching_rows : list
            List of rows matching the input key
        """
        begin = 0
        end = len(self.row_index)
        # search through keys in lexicographic order
        for i in range(self.num_cols):
            key_slice = self._get_key_slice(i, begin, end)
            t = _searchsorted(key_slice, key[i])
            # t is the smallest index >= key[i]
            if t == len(key_slice) or key_slice[t] != key[i]:
                # no match
                return []
            # Restrict the search to the run of entries equal to key[i].
            end = begin + _searchsorted(key_slice, key[i], side="right")
            begin += t
        return self.row_index[begin:end]

    def _bound_pos(self, key, side="left"):
        """
        Return the first position whose key columns compare >= ``key``
        (``side="left"``) or > ``key`` (``side="right"``).

        Row numbers are never consulted, so all entries sharing a key are
        treated as equal regardless of ``unique``.
        """
        begin = 0
        end = len(self.row_index)
        for i in range(self.num_cols):
            key_slice = self._get_key_slice(i, begin, end)
            left = _searchsorted(key_slice, key[i])
            right = _searchsorted(key_slice, key[i], side="right")
            if left == right:
                # key[i] does not occur in this run, so the full key has a
                # single insertion point whichever side is requested.
                return begin + left
            end = begin + right
            begin += left
        # [begin, end) now holds exactly the entries whose key equals ``key``.
        return begin if side == "left" else end

    def range(self, lower, upper, bounds):
        """
        Find values in the given range.

        Parameters
        ----------
        lower : tuple
            Lower search bound
        upper : tuple
            Upper search bound
        bounds : (2,) tuple of bool
            Indicates whether the search should be inclusive or
            exclusive with respect to the endpoints. The first
            argument corresponds to an inclusive lower bound,
            and the second argument to an inclusive upper bound.

        Returns
        -------
        rows : list or array
            Row numbers of all entries within the range, in key order.
        """
        # First entry inside the lower bound: >= lower if inclusive,
        # > lower otherwise (skipping every duplicate of ``lower``).
        lower_pos = self._bound_pos(lower, "left" if bounds[0] else "right")
        if lower_pos >= len(self.row_index):
            return []
        # One past the last entry inside the upper bound: <= upper if
        # inclusive (keeping every duplicate of ``upper``), < upper otherwise.
        upper_pos = self._bound_pos(upper, "right" if bounds[1] else "left")
        return self.row_index[lower_pos:upper_pos]

    def remove(self, key, data):
        """
        Remove the given entry from the sorted array.

        Parameters
        ----------
        key : tuple
            Column values
        data : int
            Row number

        Returns
        -------
        successful : bool
            Whether the entry was successfully removed
        """
        pos = self.find_pos(key, data, exact=True)
        if pos == -1:  # key not found
            return False
        self.data.remove_row(pos)
        keep_mask = np.ones(len(self.row_index), dtype=bool)
        keep_mask[pos] = False
        self.row_index = self.row_index[keep_mask]
        return True

    def shift_left(self, row):
        """
        Decrement all row numbers greater than the input row.

        Parameters
        ----------
        row : int
            Input row number
        """
        self.row_index[self.row_index > row] -= 1

    def shift_right(self, row):
        """
        Increment all row numbers greater than or equal to the input row.

        Parameters
        ----------
        row : int
            Input row number
        """
        self.row_index[self.row_index >= row] += 1

    def replace_rows(self, row_map):
        """
        Replace all rows with the values they map to in the
        given dictionary. Any rows not present as keys in
        the dictionary will have their entries deleted.

        Parameters
        ----------
        row_map : dict
            Mapping of row numbers to new row numbers
        """
        num_rows = len(row_map)
        keep_rows = np.zeros(len(self.row_index), dtype=bool)
        tagged = 0
        for i, row in enumerate(self.row_index):
            if row in row_map:
                keep_rows[i] = True
                tagged += 1
                # Every mapped row has been located; stop scanning early.
                if tagged == num_rows:
                    break
        self.data = self.data[keep_rows]
        new_rows = [row_map[x] for x in self.row_index[keep_rows]]
        # np.array([]) defaults to float64; keep row numbers integral.
        self.row_index = np.array(new_rows) if new_rows else np.array([], dtype=int)

    def items(self):
        """
        Retrieve all array items as a list of pairs of the form
        [(key, [row 1, row 2, ...]), ...]

        Consecutive entries with equal keys are grouped together.
        """
        array = []
        last_key = None
        for i, key in enumerate(zip(*self.data.columns.values())):
            row = self.row_index[i]
            if key == last_key:
                array[-1][1].append(row)
            else:
                last_key = key
                array.append((key, [row]))
        return array

    def sort(self):
        """
        Make row order align with key order.

        Renumbers ``row_index`` to ``0 .. len - 1`` in current entry order.
        """
        self.row_index = np.arange(len(self.row_index))

    def sorted_data(self):
        """
        Return rows in sorted order.
        """
        return self.row_index

    def __getitem__(self, item):
        """
        Return a sliced reference to this sorted array.

        The ``unique`` flag is carried over so the slice enforces the same
        uniqueness constraint as its parent.

        Parameters
        ----------
        item : slice
            Slice to use for referencing
        """
        return SortedArray(self.data[item], self.row_index[item], unique=self.unique)

    def __repr__(self):
        t = self.data.copy()
        t["rows"] = self.row_index
        return f"<{self.__class__.__name__} length={len(t)}>\n{t}"