def my_inputs(n_devices):
  while True:
    file = random.choice(os.listdir('files'))
    with GFile('/files/' + file) as f:
      text = f.read()
    IDS = TOKENIZER.EncodeAsIds(text)
@nkitaev I'm using this to feed in multiple text files. Do you think I can tweak any of the hyperparameters in the parse_config to run the model for longer than half an hour without running into memory issues?
The MultifactorSchedule settings control the learning rate schedule, which only affects how long training takes and not how much memory is used. You can try running with a few more warmup steps, and more steps_per_cycle in the cyclic cosine schedule.
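To illustrate why these settings change training length but not memory, here is a minimal sketch of a `constant * linear_warmup * cosine_decay` schedule in plain Python. It is an illustration, not Trax's own code: the function name `lr_at_step` is hypothetical, and the default values for `constant`, `warmup_steps` and `steps_per_cycle` are assumptions chosen for the example.

```python
import math


def lr_at_step(step, constant=0.01, warmup_steps=100, steps_per_cycle=900):
    """Illustrative learning rate: constant * linear_warmup * cosine_decay."""
    # Linear warmup: ramp the multiplier from 0 to 1 over the first warmup_steps.
    warmup = min(1.0, step / warmup_steps)
    # Cyclic cosine decay: after warmup, fall from 1 towards 0 over each cycle,
    # then restart at 1 when the next cycle begins.
    progress = max(0.0, (step - warmup_steps) / steps_per_cycle)
    cosine = 0.5 * (1.0 + math.cos(math.pi * (progress % 1.0)))
    return constant * warmup * cosine


# Print the rate at a few points: during warmup, at the peak, mid-cycle,
# near the end of the first cycle, and at the start of the second cycle.
for step in (0, 50, 100, 550, 999, 1000):
    print(step, round(lr_at_step(step), 6))
```

The function only maps a step number to a scalar, so raising `warmup_steps` or `steps_per_cycle` stretches the schedule over more steps without changing the size of anything held in memory.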
my_inputs will let you feed in your own data, and you can tune the model hyperparameters as well.
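As a rough sketch of what a custom input function over several text files could look like, the example below picks a random file per device, tokenizes it, and pads or truncates to a fixed length. Several things here are assumptions rather than anything confirmed in this thread: the `(inputs, targets, mask)` tuple layout, the `data_dir`, `max_len` and `pad_id` parameters, and the byte-level `encode_as_ids` helper, which is a hypothetical stand-in for `TOKENIZER.EncodeAsIds`. It yields plain lists; converting them to arrays is left out to keep the sketch dependency-free.

```python
import os
import random
import tempfile


def encode_as_ids(text):
    # Hypothetical stand-in for TOKENIZER.EncodeAsIds: one ID per UTF-8 byte.
    return list(text.encode("utf-8"))


def my_inputs(n_devices, data_dir, max_len=16, pad_id=0):
    """Yield (inputs, targets, mask) batches with one row per device."""
    names = sorted(os.listdir(data_dir))
    while True:
        inputs, mask = [], []
        for _ in range(n_devices):
            # Pick a random file and read it from the same directory we listed.
            path = os.path.join(data_dir, random.choice(names))
            with open(path, encoding="utf-8") as f:
                ids = encode_as_ids(f.read())[:max_len]  # truncate long texts
            pad = max_len - len(ids)
            inputs.append(ids + [pad_id] * pad)          # pad short texts
            mask.append([1.0] * len(ids) + [0.0] * pad)  # 0.0 marks padding
        # Language modelling: the targets are the inputs themselves.
        yield inputs, inputs, mask


# Usage: write two small files to a temporary directory and draw one batch.
with tempfile.TemporaryDirectory() as d:
    for name, text in (("a.txt", "hello"), ("b.txt", "trax")):
        with open(os.path.join(d, name), "w", encoding="utf-8") as f:
            f.write(text)
    inputs, targets, mask = next(my_inputs(n_devices=2, data_dir=d))
```

Because every row is padded or truncated to `max_len`, the batch shape stays the same no matter which file is drawn.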
Just wanted to chime in to say we are really excited to be making use of Trax in the Project Clarify (https://github.com/projectclarify/clarify) codebase. We would like to invite anyone with experience with Trax or Jax and an interest in mentoring to come give a tutorial session at one of our upcoming hackathon/training days (Jan 25th and Feb 29): https://forms.gle/oFWkN7UuAxS7NUGJ9 Especially you, @afrozenator: I didn't have your email to send you an invite. Looking forward to adding value to Trax as we get up to speed with using it.
Thanks @cwbeitel - that looks exciting, will follow up over email.