Skip to content

Commit 58ff2a3

Browse files
authored
Merge pull request #835 from kitsudaiki/feat/improve-task-processing
related issue: #834
2 parents 7c28baf + 7babb2b commit 58ff2a3

18 files changed

Lines changed: 126 additions & 81 deletions

File tree

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
- quota-management was added to the admin-section of the dashboard
2828
- dataset list in the dashboard now prints the number of rows and columns
2929
- cluster list in the dashboard now prints the cluster adress
30+
- added forecast-length as new parameter for train-tasks (only via python-sdk at the moment)
3031

3132
### Changed
3233

@@ -46,7 +47,8 @@
4647
- fixed broken version-output of the pre-build docker images
4748
- restart of hanami and sakura instance in regard of the cluster was now fixed to avoid broken cluster and hosts after restart
4849
- restarts of the same sakura-host doesn't result in duplications in hanami anymore
49-
50+
- progress-bar in dashboard for multiple epochs is now calculated correctly
51+
- fixed number of output-values in case of int and float output-type
5052

5153
## v0.10.0
5254

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,5 @@ overflow-checks = true
2424
[profile.release]
2525
opt-level = 3
2626
lto = "fat"
27-
codegen-units = 1
27+
codegen-units = 8
2828
strip = true

src/binaries/sakura/src/api/http_endpoints/model/task/checkpoint_restore_task_v1_0.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ pub async fn checkpoint_restore_task(
8888
model_uuid: *model_uuid,
8989
name: body.name.clone(),
9090
info: TaskVariant::CheckpointRestore(info),
91-
meta: TaskMeta::new(1, 1, 1),
91+
meta: TaskMeta::new(1, 1, 1, 0),
9292
};
9393
super::add_task_to_model(task, &task_type, &context)?;
9494

src/binaries/sakura/src/api/http_endpoints/model/task/checkpoint_save_task_v1_0.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ pub async fn checkpoint_save_task(
8989
model_uuid: *model_uuid,
9090
name: body.name.clone(),
9191
info: TaskVariant::CheckpointSave(info),
92-
meta: TaskMeta::new(1, 1, 1),
92+
meta: TaskMeta::new(1, 1, 1, 0),
9393
};
9494
super::add_task_to_model(task, &task_type, &context)?;
9595

src/binaries/sakura/src/api/http_endpoints/model/task/create_request_task_v1_0.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ pub async fn create_request_task(
141141
model_uuid: *model_uuid,
142142
name: body.name.clone(),
143143
info: TaskVariant::Request(Box::new(info)),
144-
meta: TaskMeta::new(number_of_cycles, 1, time_length),
144+
meta: TaskMeta::new(number_of_cycles, 1, time_length, 0),
145145
};
146146
super::add_task_to_model(task, &task_type, &context)?;
147147

src/binaries/sakura/src/api/http_endpoints/model/task/create_train_task_v1_0.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ pub async fn create_train_task(
5353
let task_uuid = Uuid::new_v4();
5454
let task_type = TaskType::Train;
5555
let time_length = body.time_length.unwrap_or(1);
56+
let forecast_length = body.forecast_length.unwrap_or(0);
5657
let mut number_of_cycles = u64::MAX;
5758

5859
if time_length < 1 {
@@ -111,21 +112,31 @@ pub async fn create_train_task(
111112
}
112113

113114
// handle the time-lenght-value
114-
if number_of_cycles < time_length {
115+
let cycle_length = time_length + forecast_length;
116+
if number_of_cycles < cycle_length {
115117
let msg = format!(
116-
"Time-length {time_length} is bigger than at least of of the seleced datasets."
118+
"Time-length {cycle_length} is bigger than at least of of the seleced datasets."
117119
);
118120
return Err(ErrorResponse::BadRequest(msg));
119121
}
122+
// TODO: seems not fully correct calculated
120123
number_of_cycles -= time_length - 1;
124+
if forecast_length > 0 {
125+
number_of_cycles /= forecast_length;
126+
}
121127

122128
// create new task
123129
let task = Task {
124130
uuid: task_uuid,
125131
model_uuid: *model_uuid,
126132
name: body.name.clone(),
127133
info: TaskVariant::Training(info),
128-
meta: TaskMeta::new(number_of_cycles, body.number_of_epochs, time_length),
134+
meta: TaskMeta::new(
135+
number_of_cycles,
136+
body.number_of_epochs,
137+
time_length,
138+
forecast_length,
139+
),
129140
};
130141
super::add_task_to_model(task, &task_type, &context)?;
131142

src/binaries/sakura/src/core/blocks/core_block.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,8 @@ use super::super::processing::worker_queue::*;
4242
pub struct Synapse {
4343
/// The threshold value that must be exceeded for the synapse to activate.
4444
pub border: f32,
45-
/// First synaptic weight value.
45+
/// synaptic weight values.
4646
pub weight_1: f32,
47-
/// Second synaptic weight value.
4847
pub weight_2: f32,
4948

5049
/// Counter tracking how many times this synapse has been activated.

src/binaries/sakura/src/core/blocks/input_block.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,14 +144,14 @@ impl InputBlock {
144144
/// * `input_size` - Size of the input data.
145145
/// * `offset` - Offset to apply to the input data.
146146
/// * `time_length` - Length of the time dimension of the input data.
147-
/// * `allow_cration` - Whether to allow creation of new input links.
147+
/// * `allow_creation` - Whether to allow creation of new input links.
148148
pub fn apply_input(
149149
&mut self,
150150
input_ptr: &[f32],
151151
input_size: usize,
152152
offset: usize,
153153
time_length: usize,
154-
allow_cration: bool,
154+
allow_creation: bool,
155155
) {
156156
// resize links, if necessary
157157
let maximum_size = input_size * 2 * time_length;
@@ -162,7 +162,7 @@ impl InputBlock {
162162
let mut is_negative;
163163
let mut total_position;
164164

165-
if allow_cration {
165+
if allow_creation {
166166
// update links
167167
for (i, val) in input_ptr.iter().enumerate().take(input_size) {
168168
total_position = (offset + i) * 2;

src/binaries/sakura/src/core/model_interface.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,6 @@ impl ModelInterface {
218218
model_data_handler.get_output_buffer(&self.model_uuid, hexagon_name)?;
219219

220220
let mut output_buffer = output_buffer_mutex.lock().expect("mutex poisoned");
221-
data.resize(output_buffer.output_neurons.len(), 0.0f32);
222221
convert_output_to_buffer(data, &mut output_buffer);
223222
}
224223

src/binaries/sakura/src/core/processing/finish_counter.rs

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,25 +21,12 @@ use super::tasks::Task;
2121
/// within a specific cycle, and can trigger actions when the completion criteria are met.
2222
#[derive(Default, Debug)]
2323
pub struct FinishCounter {
24-
/// Number of input tasks that have been compared in the current cycle.
2524
pub input_compare: usize,
26-
27-
/// Number of output tasks that have been compared in the current cycle.
2825
pub output_compare: usize,
29-
30-
/// The threshold number of task comparisons needed to consider the cycle finished.
3126
task_compare: usize,
32-
33-
/// Current count of completed tasks in the current cycle.
3427
counter: usize,
35-
36-
/// The cycle number we're expecting to complete.
3728
expected_cycle_number: u64,
38-
39-
/// Flag indicating whether this cycle has already been marked as finished.
4029
already_finished: bool,
41-
42-
/// Optional reference to the task that this counter is associated with.
4330
pub task: Option<Arc<Mutex<Task>>>,
4431
}
4532

0 commit comments

Comments
 (0)