Skip to content

Commit d00da38

Browse files
authored
Merge pull request #24 from dylan-park/main
2 parents 1af3258 + fee4270 commit d00da38

1 file changed

Lines changed: 7 additions & 7 deletions

File tree

src/main.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@ fn main() -> Result<(), GraphError> {
8686

8787
gpt.sync()?;
8888

89+
let mut ts_file = fs::File::open(&training_state_path).unwrap();
90+
let mut bytes = Vec::new();
91+
ts_file.read_to_end(&mut bytes).unwrap();
92+
let ts: TrainingState = bincode::deserialize(&bytes).unwrap();
93+
gpt.set_training_state(ts, true)?;
94+
8995
println!("Generating text:");
9096

9197
let inference = gpt.infer(
@@ -96,15 +102,9 @@ fn main() -> Result<(), GraphError> {
96102
|_ch| {},
97103
)?;
98104

99-
// Generate 100 character with the currently trained model before
100-
// starting the training loop.
105+
// Generate 100 character with the currently trained model
101106
println!("{}", tokenizer.untokenize(&inference));
102107

103-
println!("Saving the model...");
104-
gpt.sync().unwrap();
105-
let ts = gpt.get_training_state().unwrap();
106-
let bytes = bincode::serialize(&ts).unwrap();
107-
fs::write(training_state_path, &bytes).expect("Unable to write file");
108108
Ok(())
109109
}
110110
Cli::Train { dataset, model } => {

0 commit comments

Comments
 (0)