    Andreas Antoniou
    Will check it out and get back to you.
    Andreas Antoniou
    @OmarAlsaqa: I tried exporting the model as a TF model to serve it, but ran into many obstacles. Even after I got it working with some duct-taping, the model's performance is pretty bad.
    I also changed the mode to predict in order to adapt the infrastructure, but ran into further issues. Someone needs to shed some light on this.
    It takes at least 2 minutes for one simple prediction.
    Omar Alsaqa
    @andreas_antoniou:matrix.org Hopefully someone will help us both. @lukaszkaiser
    Rafid K. Al-Humaimidi

    What is the benefit of providing a signature when calling the init() method of a model? Depending on the definition of the model, the signature is usually part of it. For example, when defining an Embedding layer, then surely the input is of shape (batch_size), where the values range from 0 to embed_size. In fact, in the code of Embedding, they simply delete the input signature:


    So, what is the point behind it?

    1 reply
    Similarly, if I define a TransformerEncoder and provide all the necessary information like vocab_size, n_classes, etc., why do I need to further provide input signature when I call init()?
    Rafid K. Al-Humaimidi
    I am new to Trax so forgive the basic question. I noticed that when I use the Loop class to train my model, it takes one batch of training data per step. For example, if I have 1000 batches in my training data, I would have to call run(1000) to finish a training epoch. Usually, one provides an iterator and doesn't know how many batches there are in the training data, so how do I go about telling Loop to go through the training data X epochs?
    1 reply
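One way to think about epochs with a step-based loop: count the batches in one pass over the data, then request epochs × batches steps from a stream that restarts after each pass. The sketch below is framework-free. `cycle_epochs`, `steps_for` and `make_batches` are hypothetical helper names, and the final `loop.run(...)` call appears only as a comment, on the assumption that `loop` is a `training.Loop` fed by the cycled stream.

```python
import itertools

def cycle_epochs(make_batches):
    """Yield batches forever, restarting make_batches() after each full pass."""
    while True:
        yield from make_batches()

def steps_for(make_batches, n_epochs):
    """Count the batches in one pass and scale by the number of epochs."""
    batches_per_epoch = sum(1 for _ in make_batches())
    return batches_per_epoch * n_epochs

# Toy data: 5 examples in batches of 2, so 3 batches per epoch.
examples = list(range(5))

def make_batches(batch_size=2):
    for i in range(0, len(examples), batch_size):
        yield examples[i:i + batch_size]

n_steps = steps_for(make_batches, n_epochs=4)
stream = cycle_epochs(make_batches)
first_epoch = list(itertools.islice(stream, 3))
# With a real training loop (assumption): loop.run(n_steps)
```

Counting costs one extra pass over the data, which is usually acceptable when the batches come from files on disk.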
    Rafid K. Al-Humaimidi
    Does Trax simply reserve all GPU memory available? I defined an embedding layer of size 1000x64 and by the time I call init, almost all GPU memory is completely used. Why is that?
    2 replies
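A likely explanation, since Trax runs on JAX: by default JAX preallocates a large fraction of GPU memory when it first touches the device, regardless of model size. The sketch below shows the environment variables JAX documents for changing this. It assumes they are set before `jax` or `trax` is imported for the first time, and the 50% cap is an arbitrary example value.

```python
import os

# Must be set before jax (or trax) is imported for the first time.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Alternative: keep preallocation but cap it at 50% of GPU memory.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.50"

# import trax  # import only after the environment is configured
```

Disabling preallocation can make memory fragmentation more likely, so it is mostly useful for sharing a GPU or for inspecting actual usage.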
    How could I create my own subword representations, like the ende_32k.subword file used in the Trax Quick Intro?
    import trax

    # Create a Transformer model.
    # Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
    model = trax.models.Transformer(
        input_vocab_size=33300,
        d_model=512, d_ff=2048,
        n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
        max_len=2048, mode='predict')
    # Initialize using pre-trained weights.
    model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
                         weights_only=True)
    # Tokenize a sentence.
    sentence = 'It is nice to learn new things today!'
    tokenized = list(trax.data.tokenize(iter([sentence]),  # Operates on streams.
                                        vocab_dir='gs://trax-ml/vocabs/',
                                        vocab_file='ende_32k.subword'))[0]
    # Decode from the Transformer.
    tokenized = tokenized[None, :]  # Add batch dimension.
    tokenized_translation = trax.supervised.decoding.autoregressive_sample(
        model, tokenized, temperature=0.0)  # Higher temperature: more diverse results.
    # De-tokenize,
    tokenized_translation = tokenized_translation[0][:-1]  # Remove batch and EOS.
    translation = trax.data.detokenize(tokenized_translation,
                                       vocab_dir='gs://trax-ml/vocabs/',
                                       vocab_file='ende_32k.subword')
    print(translation)
    Fernando Costa

    [Attention Visualization]

    Hi guys,

    How can I visualize attention weights/scores on my input sequences? I want to see where each attention head is paying attention on my input sequence at each input processed. I'm using a personalized Transformer decoder on my model, containing only the Causal Attention and this layer has a name that I can use, if needed.

    In Keras we can do this: https://stackoverflow.com/questions/53867351/how-to-visualize-attention-weights. Is there some related way to do this in Trax?

    2 replies
    Jake Searcy

    Hi All,

    Has anyone else had trouble with initializing a model from a file? When I start with a file the model isn't callable without a JAX error, and I see the weights have a different type than they would from running init.

    >> <class 'jaxlib.xla_extension.DeviceArray'>
    >> <class 'numpy.ndarray'>
    2 replies
    Can I use this library for languages other than English, for example the Semitic languages?
    Ken Otwell
    Does trax run the Branch or Parallel layer splits actually in parallel in the GPU? My timing tests suggest not.
    1 reply
    Hi all! New to Trax, I was trying to figure out where the pretrained models live. Is there a way to get an overview of all the pretrained models available? The tutorial has ende_wmt32k.pkl.gz, but what about other languages? Or e.g. a ResNet50? Would it be possible to use pretrained models from tf_hub?
    1 reply
    Ken Otwell
    Anyone know how to get trax to actually use the GPU? If I call jax directly, it works fine - but when I train a trax model, there's a bit of activity at startup then nothing. Here's a tensorflow capture:
    Ryan Greenblatt
    Does anyone know how to implement a stop-gradient that only stops the gradient with respect to certain inputs?
    Francisco Javier Estrella Rodriguez
    Hello, is there a Trax community on Discord?
    Jess Edmund Fan
    Does anyone have a minimal example of converting Terraformer to Keras? I have this, but it won't get past creating the hidden layer: https://colab.research.google.com/drive/1Yiss6NKEimwcU9QAD4Vk7vsw59ZyTKz-?usp=sharing
    Peter Dippold
    Hello, I recently installed jax (0.3.7), jaxlib (0.3.7+cuda11.cudnn82) and trax (1.4.1) in a Linux Docker container (Cuda compilation tools, release 11.5, V11.5.119, Build cuda_11.5.r11.5/compiler.30672275_0, cudnn 8.3.1). I get no complaints (like "GPU/TPU not found") after importing trax packages into a Jupyter notebook, so everything looks fine. But when I try to train a model, about 90% of my GPU's memory is allocated while the calculations are done by the CPU only (100% load, GPU 0% load). Is this due to the different cuDNN versions (8.2 vs. 8.3.1), or could there be other reasons?
    2 replies
    Himanshu Chaturvedi
    Hi, I couldn't understand how to use the Parallel layer in the data module. Can somebody elaborate on that?
    Alaa Shaker
    I have a problem: I can't import trax on Colab. Just yesterday everything was working. I searched Stack Overflow and GitHub for a solution but found none.
    Can anyone please help me?
    Alaa Shaker
    Even on the official Trax Colab page I can't import it.
    Hello, is there a way to use a MaxPooling1D layer in Trax? I am unable to find one. If not, is it possible to write a custom one using trax.layers.Fn?
    5 replies

    I can't work out why my model is failing, and I am not sure how to change the output.

    Here's the model

    def pre_bert_model(mode = 'train', dropout = 0.2):
      model = tl.Serial(
          tl.Embedding(vocab_size = sp.get_piece_size(), d_feature = 32),
          tl.convolution.Conv1d(filters = 100, kernel_size = 2, stride=1, padding = 'SAME'),
          tl.MaxPool(pool_size=(2,), strides=(1,), padding='SAME'),
          #tl.MaxPool(pool_size = tuple([1]*30), padding = 'SAME'),
          tl.Dropout(rate = dropout, mode = mode),
          tl.convolution.Conv1d(filters = 100, kernel_size = 3, stride=1, padding = 'SAME'),
          tl.MaxPool(pool_size=(2,), strides=(1,), padding='SAME'),
          tl.Dropout(rate = dropout, mode = mode),
          tl.convolution.Conv1d(filters = 100, kernel_size = 4, stride=1, padding = 'SAME'),
          tl.MaxPool(pool_size=(2,), strides=(1,), padding='SAME'),
          tl.Dropout(rate = dropout, mode = mode),
          tl.convolution.Conv1d(filters = 100, kernel_size = 5, stride=1, padding = 'SAME'),
          tl.MaxPool(pool_size=(2,), strides=(1,), padding='SAME'),
          tl.Dropout(rate = dropout, mode = mode),
          tl.LSTM(n_units = 100, mode = mode),
          tl.MaxPool(pool_size=(2,), strides=(1,), padding='SAME'),
          tl.Dense(n_units = 256),
          tl.Dropout(rate = dropout, mode = mode),
      )
      return model

    the training task is defined here as:

    #train the model and save it in a file
    # UNQ_C5
    # GRADED FUNCTION: train_model
    def training_loop(pre_bert_model, train_gen, eval_gen, output_dir = "/content/drive/My Drive/Things to move/model/"):
        """
        Args:
            pre_bert_model (function): builds the model to train.
            train_gen (generator): train data generator.
            eval_gen (generator): Validation generator.
            output_dir (string): Path to save the model output.
        Returns:
            trax.supervised.training.Loop: Training loop for the model.
        """
        # use the warmup_and_rsqrt_decay learning rate schedule
        lr_schedule = trax.lr.warmup_and_rsqrt_decay(n_warmup_steps=1000, max_value=0.01)
        ### START CODE HERE ###
        # define the train task
        train_task = training.TrainTask(
            # labeled data
            labeled_data=train_gen,
            # loss layer
            # optimizer
            # lr_schedule
            lr_schedule=lr_schedule,
            # n_steps
        )
        # define the eval task
        eval_task = training.EvalTask(
            labeled_data=eval_gen,
            metrics=[tl.CrossEntropyLoss(), tl.Accuracy(classifier=tl.ThresholdToBinary())],
            n_eval_batches=20,  # For less variance in eval numbers.
        )
        ### END CODE HERE ###
        loop = training.Loop(pre_bert_model(mode='train'),
                             train_task,
                             eval_tasks=[eval_task],
                             output_dir=output_dir)
        return loop
    loop = training_loop(pre_bert_model = pre_bert_model, train_gen = train_stream, eval_gen = eval_stream)

    and the data pipeline is defined as:

    data_pipeline = trax.data.Serial(
        # tokenize the data
        #tokenizer(list(data['tweet']), padding=True, truncation=True),
        trax.data.Tokenize(vocab_dir = vocab_dir, vocab_file='sarcasm_model.model', vocab_type='sentencepiece',keys = [0]),
        trax.data.FilterByLength(max_length=2048, length_keys=[0]),
        #lambda g: map(lambda x: (x[0], x[1]),g)
    )
    def stream(data):
        # loop over the entire data
        tweets = iter([tuple([str(i[0]), np.int64(i[1])]) for i in data])
        return tweets

    The following outputs 1 or 0 for a given tweet.

    7 replies
    Hi, is there a Trax tutorial? I thought I had seen one on Coursera but cannot find it.
    I am pretty new to data science. I am trying to import a custom CSV dataset into this code.
    Can anyone show me how?

    # This will download the train dataset if no data_dir is specified.
    train_stream_fn = trax.data.TFDS('para_crawl/enfr',
                                     data_dir='/content/drive/MyDrive/Colab Notebooks/data/',
                                     keys=('en', 'fr'),
                                     eval_holdout_size=0.01, # 1% for eval
                                     train=True)

    # Get generator function for the eval set
    eval_stream_fn = trax.data.TFDS('para_crawl/enfr',
                                    data_dir='/content/drive/MyDrive/Colab Notebooks/data/',
                                    keys=('en', 'fr'),
                                    eval_holdout_size=0.01, # 1% for eval
                                    train=False)

    The CSV dataset is in Google Drive.
    I have mounted Google Drive.
    How do I use my own dataset instead of the ParaCrawl directory?

    Hi. Is this a version mismatch? The exact same code used to work two days ago, but now I get the following error at this cell:

    # define the output directory
    output_dir = '/content/self_model/'

    # define the training loop
    training_loop = training.Loop(model,

    /usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in device_put(x, device)
    1113 x = xla.canonicalize_dtype(x)
    1114 try:
    -> 1115 return device_put_handlers[type(x)](x, device)
    1116 except KeyError as err:
    1117 raise TypeError(f"No device_put handler for type: {type(x)}") from err

    /usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in _device_put_array(x, device)
    1124 if x.dtype == dtypes.float0:
    1125 x = np.zeros(x.shape, dtype=np.dtype(bool))
    -> 1126 return (backend.buffer_from_pyval(x, device),)
    1128 def _device_put_scalar(x, device):

    /usr/local/lib/python3.7/dist-packages/jax/_src/device_array.py in __array__(self, dtype, context)
    265 def __array__(self, dtype=None, context=None):
    --> 266 return np.asarray(self._value, dtype=dtype)
    268 setattr(device_array, "__array__", __array__)

    /usr/local/lib/python3.7/dist-packages/jax/interpreters/pxla.py in _sda_value(self)
    803 npy_value = np.empty(self.aval.shape, self.aval.dtype)
    804 for i in self.one_replica_buffer_indices:
    --> 805 npy_value[self.indices[i]] = np.asarray(self.device_buffers[i])
    806 self._npy_value = npy_value
    807 return self._npy_value

    TypeError: float() argument must be a string or a number, not 'jaxlib.tpu_client_extension.PyTpuBuffer'

    Arnav Gupta
    def train_stream_fn():
        for a in zip(english_sentences_train, french_sentences_train):
            yield a
    This could help you. I had two files; reading them line by line and using zip to pair each training sentence with its label let me write the same kind of stream function for my own dataset. I did the same for the eval data, after separating the two sets with train_test_split from sklearn.
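For the CSV case asked about earlier, the same idea can be sketched with the standard library alone. Everything here is an assumption for illustration: the column names `en` and `fr`, the helper names `read_pairs`, `split_pairs` and `make_stream_fn`, and the in-memory `demo_csv` that stands in for a file opened from a mounted Drive folder. The stream functions repeat their data endlessly, on the assumption that the downstream pipeline expects a stream that never runs out.

```python
import csv
import io
import random

def read_pairs(csv_file, src_col="en", tgt_col="fr"):
    """Read (source, target) sentence pairs from a CSV file with a header row."""
    return [(row[src_col], row[tgt_col]) for row in csv.DictReader(csv_file)]

def split_pairs(pairs, eval_fraction=0.01, seed=0):
    """Shuffle the pairs and hold out a fraction of them for evaluation."""
    pairs = list(pairs)
    random.Random(seed).shuffle(pairs)
    n_eval = max(1, int(len(pairs) * eval_fraction))
    return pairs[n_eval:], pairs[:n_eval]

def make_stream_fn(pairs):
    """Return a zero-argument function whose generator repeats the pairs endlessly."""
    def stream_fn():
        while True:
            yield from pairs
    return stream_fn

# Stand-in for a real file, e.g. open('my_data.csv', newline='', encoding='utf-8').
demo_csv = io.StringIO("en,fr\nhello,bonjour\nthank you,merci\ngood night,bonne nuit\n")
train_pairs, eval_pairs = split_pairs(read_pairs(demo_csv), eval_fraction=0.34)
train_stream_fn = make_stream_fn(train_pairs)
eval_stream_fn = make_stream_fn(eval_pairs)
```

Calling `train_stream_fn()` then gives a generator of `(en, fr)` tuples that can be passed to the tokenization steps of a data pipeline.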
    @agoliaei it looks like your TPU ran out of memory, or you had batch sizes not in multiples of 8 while defining the batches.
    Hey everyone, I just started learning Trax. Could anyone tell me where I can find the models from which a Trax model can be initialized?
    In the Trax intro they use this command: model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz', weights_only=True)
    I want to test other models, so where can I find them?
    @arngpt Thank you for replying. The exact same code used to run without any problem. Batch size is 64. I don't know what is changed.
    This will list every file and folder in trax-ml directory.
    !gsutil ls gs://trax-ml/


    Can someone help me?

    Basically, I want to translate the following code from PyTorch to JAX/Trax,
    but I don't know how to do it the Trax way.
    It is not that the code is too complicated, but I'm not familiar enough with the initialization of the layers
    (when a custom layer makes use of another layer, e.g. a Dense layer as in this example).

    At this point I've tried a lot of things, but I'm not getting anywhere.
    Also, in my opinion, the documentation of Trax is really lacking. Yes, there is some standard documentation,
    but if you want to do something more advanced, it seems to be very difficult to find a useful example.
    (E.g. there are barely any examples of Trax on Stack Overflow, blogs are almost non-existent, etc.)

    Here is the PyTorch layer that I want to translate to Trax:

    from torch import nn

    class FeatureWiseAffine(nn.Module):
        def __init__(self, in_channels, out_channels, use_affine_level=False):
            super(FeatureWiseAffine, self).__init__()
            self.use_affine_level = use_affine_level
            self.noise_func = nn.Sequential(
                nn.Linear(in_channels, out_channels*(1+self.use_affine_level))
            )

        def forward(self, x, noise_embed):
            batch = x.shape[0]
            if self.use_affine_level:
                # Split the projected noise embedding into a scale and a shift.
                gamma, beta = self.noise_func(noise_embed).view(
                    batch, -1, 1, 1).chunk(2, dim=1)
                x = (1 + gamma) * x + beta
            else:
                x = x + self.noise_func(noise_embed).view(batch, -1, 1, 1)
            return x

    Here is my attempt so far:
    As you can see, I've tried to use the Trax Dense layer, but I'm not sure if this is the correct way to do it.

    • I find it very strange that I have to initialize the Dense layer inside my custom layer.
      It doesn't seem to be the correct way to do it, but I don't know how to do it otherwise.
    class FeatureWiseAffine(tl.base.Layer):                   
        # https://github.com/google/trax
        def __init__(self, out_channels, use_affine_level, dense_kernel_initializer = None, dense_bias_initializer = None):
            super().__init__(n_in=2, name=f'FeatureWiseAffine_{out_channels}_{use_affine_level}')
            self._out_channels:int = out_channels
            self._use_affine_level: bool = use_affine_level
            self._dense_kernel_initializer = dense_kernel_initializer
            self._dense_bias_initializer = dense_bias_initializer
        def forward(self, inputs):
            """Executes this layer as part of a forward pass through the model."""
            x, noise_embed = inputs
            batch = x.shape[0]
            if self._use_affine_level:
                #TODO : view and chunk ????
                gamma, beta = jnp.array_split(jnp.reshape(self._noise_func(noise_embed),(batch, -1, 1, 1)), 2, axis=1)
                x = (1 + gamma) * x + beta
            else:
                x = x + jnp.reshape(self._noise_func(noise_embed),(batch, -1, 1, 1))
            return x
        def init_weights_and_state(self, input_signature):
            """Randomly initializes this layer's weights."""
            if self._dense_kernel_initializer is not None:
                self._noise_func = tl.Dense(n_units=self._out_channels * (1 + self._use_affine_level),
                                            kernel_initializer=self._dense_kernel_initializer,
                                            bias_initializer=self._dense_bias_initializer)
            else:
                self._noise_func = tl.Dense(n_units=self._out_channels * (1 + self._use_affine_level))

    I'm trying to build a transformer decoder for passage summarization. I want to create summaries given a passage and set of questions to the passage.
    For this, I'm modifying my architecture to compute cross-attention between question encoding and (summarizer)decoder output encoding.

    I'm training the model with a mixed objective.

    1. Passage Summarizer decoder loss (cross entropy over vocab size)
    2. 0/1 classification loss for whether the summary can answer the question or not.

    I combine the losses as the sum of the 2 CE losses.
    Currently I have 2 outputs from the model, with LogSoftmax() applied to:

    1. dense layer of vocab_size and,
    2. dense layer of size 2 for binary classification

    Should I write a custom loss function that simply takes the 2 outputs, computes the cross-entropy on each and returns the sum, and then pass this loss function as the loss_layer parameter of the Trax TrainTask?
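The arithmetic of such a combined objective can be sketched without any framework. This is only an illustration of the sum being described, not a Trax loss layer: the helper names, the averaging over decoder positions and the `cls_weight` factor are all assumptions. In Trax the same computation would presumably live in a layer that receives both model outputs and both targets.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def cross_entropy(log_probs, target):
    """Negative log-likelihood of the target class, given log-softmax outputs."""
    return -log_probs[target]

def combined_loss(vocab_log_probs, vocab_targets, cls_log_probs, cls_target, cls_weight=1.0):
    """Mean token-level CE over the summary plus a weighted CE for the 0/1 head."""
    token_ce = sum(cross_entropy(lp, t) for lp, t in zip(vocab_log_probs, vocab_targets))
    token_ce /= len(vocab_targets)
    return token_ce + cls_weight * cross_entropy(cls_log_probs, cls_target)

# Two decoder positions over a 3-word vocabulary, and one binary decision.
vocab_lp = [log_softmax([2.0, 0.0, 0.0]), log_softmax([0.0, 0.0, 0.0])]
cls_lp = log_softmax([0.0, 0.0])
loss = combined_loss(vocab_lp, [0, 2], cls_lp, 1)
```

A weight on the classification term is worth considering because the two losses can differ a lot in scale, and a plain sum lets the larger one dominate the gradient.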

    I'm trying to create a sarcasm detector. The training task is failing because it expects an extra parameter. I tried solving it with a custom loss function, but it didn't seem to work. Can someone please review it and suggest changes? The training task is the third-to-last cell in the notebook. https://colab.research.google.com/drive/16v9zJpILSJz7ah2WGy71-2bMIL6XuZve?usp=sharing
    1 reply
    Hi, I am trying to load weights from a pre-trained ResNet50 model. I don't see ResNet50 weights in gs://trax-ml, so I am guessing my other option is to load weights from one of the pre-trained ResNet50 models on HuggingFace. However, the HuggingFace model weights are in H5 format. So my question is: how can I load model weights from H5 format into a Trax ResNet50 model? Thanks for any suggestions!
    Mahima Bathla
    When I try to print a tensor I get: Traced<ShapedArray(float32[1,9,512])>with<DynamicJaxprTrace(level=0/1)>. Does anyone know how I can print a tensor? I want to see its values.
    Does Trax (fastmath) support the full NumPy API? Can someone share a link to the docs for this?