9 changes: 8 additions & 1 deletion src/spikeinterface/core/basesorting.py
@@ -166,6 +166,8 @@ def get_unit_spike_train(
per unit and per segment compact in memory.
Using the cache makes the first call quite slow but then future calls are very fast.

Note: if use_cache=False but the matching lexsorted cache has already been computed, it will be used anyway.

Returns
-------
spike_train : np.ndarray
@@ -188,9 +190,14 @@
)

segment_index = self._check_segment_index(segment_index)

lexsort_key = ("sample_index", "segment_index", "unit_index")
if lexsort_key in self._cached_lexsorted_spike_vector:
use_cache = True

if use_cache:
ordered_spike_vector, slices = self.to_reordered_spike_vector(
lexsort=("sample_index", "segment_index", "unit_index"),
lexsort=lexsort_key,
return_order=False,
return_slices=True,
)
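Context for this change: once the spike vector is lexsorted with `unit_index` as the most significant key, every per-unit spike train is a contiguous slice, so an already-computed lexsorted cache is always worth reusing. A minimal self-contained sketch of the idea (plain NumPy, not the SpikeInterface API; the dtype layout is an assumption approximating `minimum_spike_dtype`):

```python
import numpy as np

# assumed stand-in for spikeinterface's minimum_spike_dtype
spike_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")]

rng = np.random.default_rng(42)
n = 1_000
spikes = np.zeros(n, dtype=spike_dtype)
spikes["sample_index"] = rng.integers(0, 30_000, n)
spikes["unit_index"] = rng.integers(0, 5, n)

# np.lexsort takes keys from least to most significant, so this sorts
# primarily by unit_index, then segment_index, then sample_index
order = np.lexsort((spikes["sample_index"], spikes["segment_index"], spikes["unit_index"]))
sorted_spikes = spikes[order]

# spike train of unit 3: two binary searches instead of a full boolean mask
left = np.searchsorted(sorted_spikes["unit_index"], 3, side="left")
right = np.searchsorted(sorted_spikes["unit_index"], 3, side="right")
unit3_spike_train = sorted_spikes["sample_index"][left:right]
```

This is why the first cached call is slow (one global lexsort) while every later call costs only a pair of O(log n) searches.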
23 changes: 18 additions & 5 deletions src/spikeinterface/core/unitsselectionsorting.py
@@ -59,11 +59,24 @@ def _compute_and_cache_spike_vector(self) -> None:
all_old_unit_ids=self._parent_sorting.unit_ids,
all_new_unit_ids=self._unit_ids,
)
# lexsort by segment_index, sample_index, unit_index
sort_indices = np.lexsort(
(spike_vector["unit_index"], spike_vector["sample_index"], spike_vector["segment_index"])
)
self._cached_spike_vector = spike_vector[sort_indices]

# check if order is preserved
pos = np.searchsorted(self._parent_sorting.unit_ids, self.unit_ids)
order_is_preserved = np.all(np.diff(pos) > 0)
print('order_is_preserved', order_is_preserved)
Member Author: remove this
if not order_is_preserved:
# note from Sam:
# this can be very costly and make big datasets very slow
# the only goal of this is to ensure the unit_index order when the sample_index is the same
# (and maybe this is not so useful: to be discussed)

# lexsort by segment_index, sample_index, unit_index
sort_indices = np.lexsort(
(spike_vector["unit_index"], spike_vector["sample_index"], spike_vector["segment_index"])
)
spike_vector = spike_vector[sort_indices]
self._cached_spike_vector = spike_vector


class UnitsSelectionSortingSegment(BaseSortingSegment):
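The fast path added here rests on one observation: if the selected `unit_ids` keep the parent's relative order, the old-index to new-index relabelling is monotonic, so the parent's lexsorted spike vector stays correctly ordered after re-indexing and the expensive `np.lexsort` can be skipped. A standalone sketch of the check (assuming, as `np.searchsorted` requires, that the parent `unit_ids` are sorted):

```python
import numpy as np

parent_unit_ids = np.array([0, 1, 2, 5, 7, 9])

# selection keeps the parent's relative order -> no lexsort needed
pos = np.searchsorted(parent_unit_ids, np.array([1, 5, 9]))
print(np.all(np.diff(pos) > 0))  # True: unit_index relabelling is monotonic

# selection permutes the units -> spikes sharing a sample_index could end
# up with decreasing unit_index values, so a re-lexsort is required
pos = np.searchsorted(parent_unit_ids, np.array([9, 1, 5]))
print(np.all(np.diff(pos) > 0))  # False
```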
59 changes: 46 additions & 13 deletions src/spikeinterface/extractors/phykilosortextractors.py
@@ -14,6 +14,7 @@
SortingAnalyzer,
)
from spikeinterface.core.core_tools import define_function_from_class
from spikeinterface.core.base import minimum_spike_dtype

from spikeinterface.postprocessing import ComputeSpikeAmplitudes, ComputeSpikeLocations
from probeinterface import read_prb, Probe
@@ -72,7 +73,7 @@ def __init__(
raise ImportError(self.installation_mesg)

phy_folder = Path(folder_path)
spike_times = np.load(phy_folder / "spike_times.npy").astype(int)
spike_times = np.load(phy_folder / "spike_times.npy").astype("int64")

if (phy_folder / "spike_clusters.npy").is_file():
spike_clusters = np.load(phy_folder / "spike_clusters.npy")
@@ -83,8 +84,8 @@
spike_times = np.atleast_1d(spike_times.squeeze())
spike_clusters = np.atleast_1d(spike_clusters.squeeze())

clust_id = np.unique(spike_clusters)
unique_unit_ids = [int(c) for c in clust_id]
unique_unit_ids = np.unique(spike_clusters).astype("int64")

params = read_python(str(phy_folder / "params.py"))
sampling_frequency = params["sample_rate"]

@@ -151,10 +152,15 @@ def __init__(
cluster_info = cluster_info.query(f"cluster_id in {unique_unit_ids}")

# update spike clusters and times values
bad_clusters = [clust for clust in clust_id if clust not in cluster_info["cluster_id"].values]
spike_clusters_clean_idxs = ~np.isin(spike_clusters, bad_clusters)
spike_clusters_clean = spike_clusters[spike_clusters_clean_idxs]
spike_times_clean = spike_times[spike_clusters_clean_idxs]
bad_clusters = [clust for clust in unique_unit_ids if clust not in cluster_info["cluster_id"].values]
if len(bad_clusters) > 0:
# if there are no bad clusters, we avoid this data reduction, which costs a lot for long datasets
spike_clusters_clean_idxs = ~np.isin(spike_clusters, bad_clusters)
spike_clusters_clean = spike_clusters[spike_clusters_clean_idxs]
spike_times_clean = spike_times[spike_clusters_clean_idxs]
else:
spike_clusters_clean = spike_clusters
spike_times_clean = spike_times

if "si_unit_id" in cluster_info.columns:
unit_ids = cluster_info["si_unit_id"].values
@@ -180,7 +186,7 @@ def __init__(
idx = np.searchsorted(from_values, spike_clusters_clean, sorter=sort_idx)
spike_clusters_new = unit_ids[sort_idx][idx]

unit_ids = unit_ids.astype(int)
unit_ids = unit_ids.astype("int64")
spike_clusters_clean = spike_clusters_new
del cluster_info["si_unit_id"]
else:
@@ -224,20 +230,47 @@

self.add_sorting_segment(PhySortingSegment(spike_times_clean, spike_clusters_clean))

def _compute_and_cache_spike_vector(self) -> None:
# build the spike_vector quickly from the internal spike_times/spike_clusters
# using a small unit_id -> unit_index mapping
# the order of 2 units with the same sample_index is not guaranteed here but should be OK

unit_ids = self.unit_ids

# mapping unit_id to unit_index
mapping = -np.ones(np.max(unit_ids) + 1, dtype="int64")
for unit_ind, unit_id in enumerate(unit_ids):
mapping[unit_id] = unit_ind

spike_times = self.segments[0]._all_spike_times
spike_clusters = self.segments[0]._all_clusters
n = spike_times.size
spikes = np.zeros(n, dtype=minimum_spike_dtype)
spikes["sample_index"] = spike_times
spikes["unit_index"] = mapping[spike_clusters]
# This is unnecessary because phy is always single-segment
# spikes["segment_index"] = 0

self._cached_spike_vector = spikes
self._cached_spike_vector_segment_slices = np.zeros((1, 2), dtype="int64")
self._cached_spike_vector_segment_slices[0, 1] = n


class PhySortingSegment(BaseSortingSegment):
def __init__(self, all_spikes, all_clusters):
def __init__(self, all_spike_times, all_clusters):
BaseSortingSegment.__init__(self)
self._all_spikes = all_spikes
self._all_spike_times = all_spike_times
self._all_clusters = all_clusters

def get_unit_spike_train(self, unit_id, start_frame, end_frame):
start = 0 if start_frame is None else np.searchsorted(self._all_spikes, start_frame, side="left")
start = 0 if start_frame is None else np.searchsorted(self._all_spike_times, start_frame, side="left")
end = (
len(self._all_spikes) if end_frame is None else np.searchsorted(self._all_spikes, end_frame, side="left")
len(self._all_spike_times)
if end_frame is None
else np.searchsorted(self._all_spike_times, end_frame, side="left")
) # Exclude end frame

spike_times = self._all_spikes[start:end][self._all_clusters[start:end] == unit_id]
spike_times = self._all_spike_times[start:end][self._all_clusters[start:end] == unit_id]
return np.atleast_1d(spike_times.copy().squeeze())


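Two small sketches of the techniques used above. First, the `_compute_and_cache_spike_vector` override replaces per-spike id lookups with a dense lookup table indexed by `unit_id`, which vectorizes the unit_id to unit_index translation (toy data; assumes ids are non-negative and reasonably small, since the table has `max(unit_ids) + 1` entries):

```python
import numpy as np

unit_ids = np.array([2, 5, 9])           # cluster ids kept as unit ids
spike_clusters = np.array([5, 2, 9, 5])  # one cluster id per spike

# dense lookup table: index by unit_id, read back unit_index (-1 = not a unit)
mapping = -np.ones(unit_ids.max() + 1, dtype="int64")
mapping[unit_ids] = np.arange(unit_ids.size)

unit_indices = mapping[spike_clusters]   # one vectorized fancy-index, no loop
print(unit_indices)                      # [1 0 2 1]
```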
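Second, `get_unit_spike_train` relies on `_all_spike_times` being globally sorted: two binary searches bound the `[start_frame, end_frame)` window and only that window is masked for the requested unit. The same pattern in isolation:

```python
import numpy as np

all_spike_times = np.array([10, 25, 25, 40, 90, 120])  # sorted sample indices
all_clusters = np.array([1, 2, 1, 1, 2, 1])

def unit_spike_train(unit_id, start_frame=None, end_frame=None):
    # side="left" on end_frame excludes the end frame itself
    start = 0 if start_frame is None else np.searchsorted(all_spike_times, start_frame, side="left")
    end = all_spike_times.size if end_frame is None else np.searchsorted(all_spike_times, end_frame, side="left")
    # mask only within the window, not over the whole recording
    return all_spike_times[start:end][all_clusters[start:end] == unit_id]

print(unit_spike_train(1, 20, 100))  # [25 40]
```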