@nkitaev I'm using this to feed in multiple text files. Do you think I can tweak any of the hyperparameters in parse_config to run the model longer than half an hour without running into memory issues?
```python
import os
import random

from tensorflow.io.gfile import GFile

def my_inputs(n_devices):
  while True:
    # Pick a random file from the 'files' directory and tokenize it.
    file = random.choice(os.listdir('files'))
    with GFile(os.path.join('files', file)) as f:
      text = f.read()
    IDS = TOKENIZER.EncodeAsIds(text)
```
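For context, a `my_inputs`-style generator is expected to yield `(inputs, targets)` pairs forever. Here is a minimal, self-contained sketch of that shape; the in-memory `FILES` dict, the character-level `toy_encode`, and `MAX_LEN` are illustrative stand-ins for the `files` directory, `TOKENIZER.EncodeAsIds`, and the model's sequence length, not the actual Trax pipeline:

```python
import random
import numpy as np

FILES = {'a.txt': 'hello world', 'b.txt': 'goodbye world'}  # stand-in corpus
MAX_LEN = 16  # illustrative fixed sequence length

def toy_encode(text):
    """Stand-in tokenizer: one integer ID per character."""
    return [ord(c) for c in text]

def my_inputs(n_devices):
    while True:
        name = random.choice(list(FILES))
        ids = toy_encode(FILES[name])[:MAX_LEN]       # truncate
        ids = ids + [0] * (MAX_LEN - len(ids))        # zero-pad
        x = np.array([ids] * n_devices, dtype=np.int32)
        yield (x, x)  # for language modeling, targets equal inputs

stream = my_inputs(n_devices=1)
inputs, targets = next(stream)
```

The fixed sequence length matters for memory: every batch has the same shape, so no step allocates more than the first one did.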
`MultifactorSchedule` controls the learning rate schedule, which only affects how long training takes, not how much memory is used. You can try running with a few more warmup steps, and more `steps_per_cycle` in the cyclic cosine schedule.

`my_inputs` will let you feed in your own data, and you can tune the model hyperparameters as well.
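To see how those two knobs interact, here is a hedged sketch of a warmup-plus-cyclic-cosine learning-rate curve; the exact factor names and formula in Trax's `MultifactorSchedule` may differ, so treat this as an illustration rather than the library code:

```python
import math

def lr_schedule(step, base_lr=1e-3, warmup_steps=400, steps_per_cycle=1000):
    """Linear warmup multiplied by a cyclic cosine decay (illustrative)."""
    warmup = min(1.0, step / warmup_steps)            # ramps 0 -> 1
    cycle = (step % steps_per_cycle) / steps_per_cycle
    cosine = 0.5 * (1.0 + math.cos(math.pi * cycle))  # 1 -> 0 within a cycle
    return base_lr * warmup * cosine
```

Raising `warmup_steps` stretches the initial ramp, and raising `steps_per_cycle` makes each cosine decay last longer; neither changes the batch or model size, which is what actually drives memory use.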
Just wanted to chime in to say we're really excited to be making use of Trax in the Project Clarify (https://github.com/projectclarify/clarify) codebase, and we'd like to invite anyone with experience with Trax or Jax and an interest in mentoring to come give a tutorial session at one of our upcoming hackathon/training days (Jan 25 and Feb 29): https://forms.gle/oFWkN7UuAxS7NUGJ9 Especially you @afrozenator, I didn't have your email to send you an invite. Looking forward to adding value to Trax as we get up to speed with using it.
Thanks @cwbeitel - that looks exciting, will follow up over email.
```
Step 1563: train accuracy | 0.21875000
Step 1563: train loss     | 13.42460442
Step 3126: train accuracy | 0.15625000
Step 3126: train loss     | 2.90936065
Step 4689: train accuracy | 0.28125000
Step 4689: train loss     | 1.86861885
Step 6252: train accuracy | 0.09375000
Step 6252: train loss     | 20935.30468750
Step 7815: train accuracy | 0.46875000
Step 7815: train loss     | 1.39475393
```