File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -86,6 +86,12 @@ fn main() -> Result<(), GraphError> {
8686
8787 gpt. sync ( ) ?;
8888
89+ let mut ts_file = fs:: File :: open ( & training_state_path) . unwrap ( ) ;
90+ let mut bytes = Vec :: new ( ) ;
91+ ts_file. read_to_end ( & mut bytes) . unwrap ( ) ;
92+ let ts: TrainingState = bincode:: deserialize ( & bytes) . unwrap ( ) ;
93+ gpt. set_training_state ( ts, true ) ?;
94+
8995 println ! ( "Generating text:" ) ;
9096
9197 let inference = gpt. infer (
@@ -96,15 +102,9 @@ fn main() -> Result<(), GraphError> {
96102 |_ch| { } ,
97103 ) ?;
98104
99- // Generate 100 character with the currently trained model before
100- // starting the training loop.
105+ // Generate 100 character with the currently trained model
101106 println ! ( "{}" , tokenizer. untokenize( & inference) ) ;
102107
103- println ! ( "Saving the model..." ) ;
104- gpt. sync ( ) . unwrap ( ) ;
105- let ts = gpt. get_training_state ( ) . unwrap ( ) ;
106- let bytes = bincode:: serialize ( & ts) . unwrap ( ) ;
107- fs:: write ( training_state_path, & bytes) . expect ( "Unable to write file" ) ;
108108 Ok ( ( ) )
109109 }
110110 Cli :: Train { dataset, model } => {
You can’t perform that action at this time.
0 commit comments