-
Notifications
You must be signed in to change notification settings - Fork 220
Fast TensorAccessor #1396
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Fast TensorAccessor #1396
Changes from 2 commits
ce679e2
958a187
abe9990
0b20f13
7df8e46
d6865a6
d2857bf
f5e43d7
6235075
1ab3891
19effd7
9ff7866
9e6ba01
46fe73a
d24c709
63aa144
c29592b
c031c61
0fe2f96
ac06718
1ea4266
5d50ebe
e55f01c
d0e5db8
4304c3f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -273,3 +273,5 @@ packages/ | |
| /.idea | ||
| /test/TorchSharpTest/exportsd.py | ||
| .vscode/settings.json | ||
| /TestClear | ||
| TestClear/ | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -3,6 +3,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| using System.Collections.Generic; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| using System.Diagnostics; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| using System.Linq; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| using System.Runtime.InteropServices; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| using static TorchSharp.PInvoke.NativeMethods; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| namespace TorchSharp.Utils | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -47,6 +48,16 @@ public T[] ToArray() | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (_tensor.ndim < 2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return (T[])ToNDArray(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (_tensor.is_contiguous()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
haytham2597 marked this conversation as resolved.
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| //This is very fast. And work VERY WELL | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
haytham2597 marked this conversation as resolved.
Outdated
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var shps = _tensor.shape; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| long TempCount = 1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (int i = 0; i < shps.Length; i++) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TempCount *= shps[i]; //Theorically the numel is simple as product of each element shape | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
NiklasGustafsson marked this conversation as resolved.
Outdated
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| unsafe { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return new Span<T>(_tensor_data_ptr.ToPointer(), Convert.ToInt32(TempCount)).ToArray(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var result = new T[Count]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| CopyTo(result); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return result; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -231,8 +242,38 @@ private void validate(long index) | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (index >= Count) throw new IndexOutOfRangeException(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| private void CopyContiguous(T[] array, int index=0, int count=0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (!_tensor.is_contiguous()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| throw new Exception("The tensor is not contiguous"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var shps = _tensor.shape; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| long TempCount = 1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (int i = 0; i < shps.Length; i++) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TempCount *= shps[i]; //Theorically the numel is simple as product of each element shape | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
NiklasGustafsson marked this conversation as resolved.
Outdated
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (count > TempCount || count == 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| count = (int)TempCount; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (array is byte[] ba) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Marshal.Copy(_tensor_data_ptr, ba, index, count); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (array is short[] sa) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Marshal.Copy(_tensor_data_ptr, sa, index, count); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if(array is char[] ca) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Marshal.Copy(_tensor_data_ptr, ca, index, count); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (array is long[] la) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Marshal.Copy(_tensor_data_ptr, la, index, count); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (array is float[] fa) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Marshal.Copy(_tensor_data_ptr, fa, index, count); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (array is int[] ia) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Marshal.Copy(_tensor_data_ptr, ia, index, count); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (array is double[] da) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Marshal.Copy(_tensor_data_ptr, da, index, count); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+204
to
+217
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (array is byte[] ba) | |
| Marshal.Copy(_tensor_data_ptr, ba, index, count); | |
| if (array is short[] sa) | |
| Marshal.Copy(_tensor_data_ptr, sa, index, count); | |
| if(array is char[] ca) | |
| Marshal.Copy(_tensor_data_ptr, ca, index, count); | |
| if (array is long[] la) | |
| Marshal.Copy(_tensor_data_ptr, la, index, count); | |
| if (array is float[] fa) | |
| Marshal.Copy(_tensor_data_ptr, fa, index, count); | |
| if (array is int[] ia) | |
| Marshal.Copy(_tensor_data_ptr, ia, index, count); | |
| if (array is double[] da) | |
| Marshal.Copy(_tensor_data_ptr, da, index, count); | |
| bool copied = false; | |
| if (array is byte[] ba) { | |
| Marshal.Copy(_tensor_data_ptr, ba, index, count); | |
| copied = true; | |
| } | |
| if (array is short[] sa) { | |
| Marshal.Copy(_tensor_data_ptr, sa, index, count); | |
| copied = true; | |
| } | |
| if (array is char[] ca) { | |
| Marshal.Copy(_tensor_data_ptr, ca, index, count); | |
| copied = true; | |
| } | |
| if (array is long[] la) { | |
| Marshal.Copy(_tensor_data_ptr, la, index, count); | |
| copied = true; | |
| } | |
| if (array is float[] fa) { | |
| Marshal.Copy(_tensor_data_ptr, fa, index, count); | |
| copied = true; | |
| } | |
| if (array is int[] ia) { | |
| Marshal.Copy(_tensor_data_ptr, ia, index, count); | |
| copied = true; | |
| } | |
| if (array is double[] da) { | |
| Marshal.Copy(_tensor_data_ptr, da, index, count); | |
| copied = true; | |
| } | |
| if (!copied) { | |
| throw new NotSupportedException($"CopyContiguous does not support element type '{typeof(T)}' with array type '{array.GetType()}'."); | |
| } |
Uh oh!
There was an error while loading. Please reload this page.