@@ -4,7 +4,30 @@ use femto_gpt::optimizer::AdamW;
 use femto_gpt::tokenizer::{SimpleTokenizer, Tokenizer};
 use std::fs;
 use std::io::prelude::*;
-use std::path::Path;
+use std::path::PathBuf;
+use structopt::StructOpt;
+
+#[derive(StructOpt, Debug)]
+enum Cli {
+    Train {
+        #[structopt(long, default_value = "dataset.txt")]
+        dataset: PathBuf,
+        #[structopt(long, default_value = "training_state.dat")]
+        model: PathBuf,
+    },
+    Infer {
+        #[structopt(long, default_value = "dataset.txt")]
+        tokenizer_dataset: PathBuf,
+        #[structopt(long, default_value = "training_state.dat")]
+        model: PathBuf,
+        #[structopt(long)]
+        prompt: String,
+        #[structopt(long, default_value = "100")]
+        count: usize,
+        #[structopt(long, default_value = "0.5")]
+        temperature: f32,
+    },
+}

 fn main() -> Result<(), GraphError> {
     #[cfg(not(feature = "gpu"))]
@@ -17,130 +40,193 @@ fn main() -> Result<(), GraphError> {
     #[cfg(feature = "gpu")]
     let is_gpu = true;

-    let training_state_path = Path::new("training_state.dat");
-
-    let mut rng = rand::thread_rng();
-
-    // Create a unique char-to-int mapping for all unique characters inside our dataset
-    let dataset_char =
-        fs::read_to_string("dataset.txt").expect("Should have been able to read the file");
-    let tokenizer = SimpleTokenizer::new(&dataset_char);
-
-    let dataset = tokenizer.tokenize(&dataset_char);
-
     let batch_size = 32;
     let num_tokens = 64;
-    let vocab_size = tokenizer.vocab_size();
     let embedding_degree = 64;
     let num_layers = 4;
     let num_heads = 4;
     let head_size = embedding_degree / num_heads;
     let dropout = 0.0;
-
     assert_eq!(num_heads * head_size, embedding_degree);

-    println!("Vocab-size: {} unique characters", vocab_size);
-
-    let mut gpt = GPT::new(
-        &mut rng,
-        graph,
-        is_gpu.then(|| batch_size), // Pre-allocate batches only when using GPUs
-        vocab_size,
-        embedding_degree,
-        num_tokens,
-        num_layers,
-        num_heads,
-        head_size,
-        dropout,
-    )?;
-
-    gpt.sync()?;
-
-    println!("Number of parameters: {}", gpt.num_params());
-
-    // Load the training state from disk (if it exists)
-    // If you want to reuse the training state of a smaller model in a bigger model, you may
-    // first start again with a new optimizer by setting load_optimizer=false
-    // WARN: YOU CAN ONLY REUSE THE WEIGHTS OF A MODEL WITH DIFFERENT NUM-LAYERS!
-    // IT'S NOT POSSIBLE TO CHANGE OTHER PROPERTIES ONCE THE MODEL IS TRAINED!
-    if training_state_path.is_file() {
-        let mut ts_file = fs::File::open(training_state_path).unwrap();
-        let mut bytes = Vec::new();
-        ts_file.read_to_end(&mut bytes).unwrap();
-        let ts: TrainingState = bincode::deserialize(&bytes).unwrap();
-        gpt.set_training_state(ts, true)?;
-    }
-
-    println!();
-    println!("Starting the training loop... (This may take hours to converge! Be patient!)");
-    println!();
-
-    let base_lr = 0.001;
-    let min_lr = 0.00001;
-    let warmup_steps = 100;
-    let decay_steps = 50000;
-
-    let learning_rate = |step| {
-        if step < warmup_steps {
-            (base_lr / warmup_steps as f32) * step as f32
-        } else {
-            // Fancy LR tuning, thanks to https://github.com/cutoken!
-            f32::max(
-                min_lr,
-                base_lr - (base_lr - min_lr) * (step - warmup_steps) as f32 / decay_steps as f32,
-            )
+    let cli = Cli::from_args();
+    match cli {
+        Cli::Infer {
+            tokenizer_dataset,
+            model,
+            prompt,
+            count,
+            temperature,
+        } => {
+            let training_state_path = &model;
+
+            let mut rng = rand::thread_rng();
+
+            // Create a unique char-to-int mapping for all unique characters inside our dataset
+            let dataset_char = fs::read_to_string(tokenizer_dataset)
+                .expect("Should have been able to read the file");
+            let tokenizer = SimpleTokenizer::new(&dataset_char);
+
+            assert_eq!(num_heads * head_size, embedding_degree);
+
+            let vocab_size = tokenizer.vocab_size();
+            println!("Vocab-size: {} unique characters", vocab_size);
+            let mut gpt = GPT::new(
+                &mut rng,
+                graph,
+                is_gpu.then(|| batch_size), // Pre-allocate batches only when using GPUs
+                vocab_size,
+                embedding_degree,
+                num_tokens,
+                num_layers,
+                num_heads,
+                head_size,
+                dropout,
+            )?;
+
+            gpt.sync()?;
+
+            // Load the trained weights before inference (if a saved model exists);
+            // otherwise the text would be generated from randomly initialized weights.
+            if training_state_path.is_file() {
+                let mut ts_file = fs::File::open(training_state_path).unwrap();
+                let mut bytes = Vec::new();
+                ts_file.read_to_end(&mut bytes).unwrap();
+                let ts: TrainingState = bincode::deserialize(&bytes).unwrap();
+                gpt.set_training_state(ts, true)?;
+            }
+
+            println!("Generating text:");
+
+            let inference = gpt.infer(
+                &mut rng,
+                &tokenizer.tokenize(&prompt),
+                count,
+                temperature,
+                |_ch| {},
+            )?;
+
+            // Print the characters generated from the prompt.
+            println!("{}", tokenizer.untokenize(&inference));
+
+            Ok(())
         }
-    };
-
-    let callback = |gpt: &mut GPT<_>| {
-        let mut rng = rand::thread_rng();
-        let inference_temperature = 0.5; // How creative? 0.0 min 1.0 max
-
-        println!("Generating text:");
-
-        let inference = gpt.infer(
-            &mut rng,
-            &tokenizer.tokenize("\n"),
-            100,
-            inference_temperature,
-            |_ch| {},
-        )?;
-
-        // Generate 100 characters with the currently trained model before
-        // starting the training loop.
-        println!("{}", tokenizer.untokenize(&inference));
-
-        println!("Saving the model...");
-        gpt.sync().unwrap();
-        let ts = gpt.get_training_state().unwrap();
-        let bytes = bincode::serialize(&ts).unwrap();
-        fs::write(training_state_path, &bytes).expect("Unable to write file");
-
-        Ok(())
-    };
-
-    // Training loop!
-    #[cfg(not(feature = "gpu"))]
-    gpt.train_cpu(
-        &dataset,
-        100000,
-        batch_size,
-        None, // or Some(n), limit backward process to last n computations
-        &AdamW::new(),
-        learning_rate,
-        callback,
-    )?;
-
-    #[cfg(feature = "gpu")]
-    gpt.train(
-        &dataset,
-        100000,
-        batch_size,
-        None, // or Some(n), limit backward process to last n computations
-        &AdamW::new(),
-        learning_rate,
-        callback,
-    )?;
-
-    Ok(())
+        Cli::Train { dataset, model } => {
+            let training_state_path = &model;
+
+            let mut rng = rand::thread_rng();
+
+            // Create a unique char-to-int mapping for all unique characters inside our dataset
+            let dataset_char =
+                fs::read_to_string(dataset).expect("Should have been able to read the file");
+            let tokenizer = SimpleTokenizer::new(&dataset_char);
+
+            let dataset = tokenizer.tokenize(&dataset_char);
+
+            let vocab_size = tokenizer.vocab_size();
+            println!("Vocab-size: {} unique characters", vocab_size);
+            let mut gpt = GPT::new(
+                &mut rng,
+                graph,
+                is_gpu.then(|| batch_size), // Pre-allocate batches only when using GPUs
+                vocab_size,
+                embedding_degree,
+                num_tokens,
+                num_layers,
+                num_heads,
+                head_size,
+                dropout,
+            )?;
+
+            gpt.sync()?;
+
+            println!("Number of parameters: {}", gpt.num_params());
+
+            // Load the training state from disk (if it exists)
+            // If you want to reuse the training state of a smaller model in a bigger model, you may
+            // first start again with a new optimizer by setting load_optimizer=false
+            // WARN: YOU CAN ONLY REUSE THE WEIGHTS OF A MODEL WITH DIFFERENT NUM-LAYERS!
+            // IT'S NOT POSSIBLE TO CHANGE OTHER PROPERTIES ONCE THE MODEL IS TRAINED!
+            if training_state_path.is_file() {
+                let mut ts_file = fs::File::open(&training_state_path).unwrap();
+                let mut bytes = Vec::new();
+                ts_file.read_to_end(&mut bytes).unwrap();
+                let ts: TrainingState = bincode::deserialize(&bytes).unwrap();
+                gpt.set_training_state(ts, true)?;
+            }
+
+            println!();
+            println!(
+                "Starting the training loop... (This may take hours to converge! Be patient!)"
+            );
+            println!();
+
+            let base_lr = 0.001;
+            let min_lr = 0.00001;
+            let warmup_steps = 100;
+            let decay_steps = 50000;
+
+            let learning_rate = |step| {
+                if step < warmup_steps {
+                    (base_lr / warmup_steps as f32) * step as f32
+                } else {
+                    // Fancy LR tuning, thanks to https://github.com/cutoken!
+                    f32::max(
+                        min_lr,
+                        base_lr
+                            - (base_lr - min_lr) * (step - warmup_steps) as f32
+                                / decay_steps as f32,
+                    )
+                }
+            };
+
+            let callback = |gpt: &mut GPT<_>| {
+                let mut rng = rand::thread_rng();
+                let inference_temperature = 0.5; // How creative? 0.0 min 1.0 max
+
+                println!("Generating text:");
+
+                let inference = gpt.infer(
+                    &mut rng,
+                    &tokenizer.tokenize("\n"),
+                    100,
+                    inference_temperature,
+                    |_ch| {},
+                )?;
+
+                // Generate 100 characters with the currently trained model before
+                // starting the training loop.
+                println!("{}", tokenizer.untokenize(&inference));
+
+                println!("Saving the model...");
+                gpt.sync().unwrap();
+                let ts = gpt.get_training_state().unwrap();
+                let bytes = bincode::serialize(&ts).unwrap();
+                fs::write(training_state_path, &bytes).expect("Unable to write file");
+
+                Ok(())
+            };
+
+            // Training loop!
+            #[cfg(not(feature = "gpu"))]
+            gpt.train_cpu(
+                &dataset,
+                100000,
+                batch_size,
+                None, // or Some(n), limit backward process to last n computations
+                &AdamW::new(),
+                learning_rate,
+                callback,
+            )?;
+
+            #[cfg(feature = "gpu")]
+            gpt.train(
+                &dataset,
+                100000,
+                batch_size,
+                None, // or Some(n), limit backward process to last n computations
+                &AdamW::new(),
+                learning_rate,
+                callback,
+            )?;
+
+            Ok(())
+        }
+    }
 }
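
With this change, the dataset and model paths are no longer hard-coded; they are supplied through the train and infer subcommands of the Cli enum above. As a rough sketch of how the new interface would be invoked (assuming the binary is run through cargo; the flag names follow StructOpt's default kebab-case mapping of the field names, the paths shown are just the declared defaults, and the prompt is an arbitrary example):

    cargo run --release -- train --dataset dataset.txt --model training_state.dat
    cargo run --release -- infer --tokenizer-dataset dataset.txt --model training_state.dat --prompt "Once upon a time" --count 100 --temperature 0.5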