What’s what in TensorFlow 2.0 – Part Deux

Welcome to part 2 of my “What’s what in TensorFlow 2.0” series. If you have not already, I encourage you to read part one first. In the previous article, we went through a couple of major changes that made TensorFlow 2.0 a revolution rather than an evolution. In this article, we will continue with the remaining changes and hopefully, by the time you finish reading, you will have learned, among other things:

  • How subclassing allowed us to write cleaner code and, at the same time, made it easier to create custom Models and Layers
  • How the @tf.function decorator allowed us to create more maintainable mini-graphs (don’t worry, they are step-through debuggable, unlike the old Session-powered Graphs)
  • How much easier navigating this large framework became after the namespace cleanup (it is more exciting than it sounds)

Model Subclassing

This may sound unbelievable, but before TensorFlow 2.0 there was no single way of creating classes that encapsulated network models or custom network layers. There was no out-of-the-box or recommended way of leveraging object-oriented programming. While some may argue that OOP is not so great, like this gentleman here, I think it is the best thing that happened to complex systems since sliced bread. In TF1 you had to come up with your own way of grouping methods and logic into objects and classes if you wanted to leverage OOP. And as one can imagine, many people did, which resulted in many, many different shapes and sizes of spaghetti-code monsters. TensorFlow 2.0 has borrowed the Layer and Model base classes from Keras, which was very smart. After some modifications, there is finally a well-described set of best practices on how to do OOP in the TensorFlow world!

Have a look at the class below, encapsulating a Residual Network model composed of two ResNet blocks:

import tensorflow as tf
from tensorflow.keras import layers


class ResNet(tf.keras.Model):
    """A small Residual Network built from two ResNet blocks.

    Assumes that a custom ResNetBlock layer and num_classes are defined elsewhere.
    """

    def __init__(self):
        super(ResNet, self).__init__()
        self.block_1 = ResNetBlock()
        self.block_2 = ResNetBlock()
        self.global_pool = layers.GlobalAveragePooling2D()
        self.classifier = layers.Dense(num_classes)

    def call(self, inputs):
        # A single forward pass of data through the network
        x = self.block_1(inputs)
        x = self.block_2(x)
        x = self.global_pool(x)
        return self.classifier(x)


resnet = ResNet()
dataset = ...
# Note: the model needs to be compiled with an optimizer and a loss before
# training; the exact choices below are just illustrative.
resnet.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
resnet.fit(dataset, epochs=10)
resnet.save_weights(filepath)

As you can see, the tf.keras.Model base class gives us a solid template on which programmers can base their models and on which many other TensorFlow APIs can depend. It exposes a .call() method, which is the standard way of encapsulating a single forward pass of data through the network. The Model class also provides standard helper methods for saving and loading weights to and from disk. There are also other standard methods like .compile() and .fit(), which, in that order, configure the network for training with the parameters of your choice and then execute the training for the desired number of epochs. There is quite a lot of validation and optimization logic baked into this standard model under the hood. Think of it as a lot of solid boilerplate code that you will never have to write yourself!
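
To make the role of .call() concrete, here is a minimal sketch; the batch of random images and its shape are just assumptions I made for illustration, and it relies on the ResNet class (and its ResNetBlock) defined above being available:

import tensorflow as tf

model = ResNet()

# Calling the model instance on a batch of data runs its call() method,
# i.e. a single forward pass through the network.
images = tf.random.normal([8, 32, 32, 3])
predictions = model(images)
print(predictions.shape)  # (8, num_classes)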

Layer Subclassing

Similarly to Models, you can create your own custom network layers based on a common template class. You may want to override the built-in LSTM layer to change its internal behaviour, or simply create a brand new custom layer with your own improvements. Both scenarios are easily achievable with the new subclassing model. Below you can see a simple Linear layer in which we control the internal math applied to the data passing through it:

import tensorflow as tf
from tensorflow.keras import layers


class Linear(layers.Layer):

  def __init__(self, units=32):
    super(Linear, self).__init__()
    self.units = units

  def build(self, input_shape):
    # Weights are created lazily, once the shape of the input is known.
    self.w = self.add_weight(shape=(input_shape[-1], self.units),
                             initializer='random_normal',
                             trainable=True)
    self.b = self.add_weight(shape=(self.units,),
                             initializer='random_normal',
                             trainable=True)

  def call(self, inputs):
    # The forward computation: y = xW + b
    return tf.matmul(inputs, self.w) + self.b

The Layer base class is purposely similar to the Model base class, sharing a lot of helper methods, for example for saving and loading weights. The .call() method represents the layer’s forward computation, i.e. the mathematical operation applied to tensors passing through this layer. So, again, the advantage of this model is that, hopefully soon, we will see more standardised examples and snippets. They will, in turn, allow the community to share code and learn from each other more efficiently. If you want to dig deeper, here is a link to the documentation of the Model base class and the Layer base class.
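
As a quick illustration (the unit count and input shape below are arbitrary values I picked for the example), note that the weights are only created the first time the layer sees data, because build() runs on the first call:

import tensorflow as tf

linear = Linear(units=4)

x = tf.ones((2, 8))         # a batch of 2 vectors with 8 features each
y = linear(x)               # build() runs here, creating w with shape (8, 4)

print(y.shape)              # (2, 4)
print(len(linear.weights))  # 2 -- the kernel w and the bias b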

tf.function

The @tf.function method decorator has done most of the hard work of getting us off that old and clunky TF1 Graph and Session concept (may it forever rest peacefully in the backward-compatibility namespace). Any Python method decorated with @tf.function will be statically analysed at runtime and converted into its own mini-graph. This conversion will change any control statements, like loops, into their TensorFlow operation equivalents. So, if you write a regular Python control statement like:

while x > 0:  
    x = x - 1

the AutoGraph feature of TensorFlow 2.0 will convert it, behind the scenes, into its Graph-ready equivalent:

x = tf.while_loop(..., loop_vars=(x,))

You will not, however, see this happening, as it is converted at runtime. For you, this will be regular Python code, highly optimized for performance and efficiency behind the scenes. And if you remove the decorator, you will even be able to step-through debug it! This auto-graphing feature allows TensorFlow to convert and optimize our simple, readable Python code into the more complex but efficient constructs that form the TensorFlow Graph. Believe me, complex ideas and models were not easy to write and read in the old days of TF1. If you’d like to dig deeper, here is a great document describing the limitations of AutoGraph. I think it’s the best resource to get to know tf.function’s capabilities.
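
To make this concrete, here is a minimal sketch (the countdown function is just a toy example I made up): a decorated function containing an ordinary Python while loop works as-is, and you can even peek at the graph-ready code AutoGraph generates from it:

import tensorflow as tf

@tf.function
def count_down(x):
  # An ordinary Python loop; AutoGraph rewrites it into a tf.while_loop
  # when the function is traced into a graph.
  while x > 0:
    x = x - 1
  return x

print(count_down(tf.constant(5)))  # tf.Tensor(0, shape=(), dtype=int32)

# Inspect the code AutoGraph generated from the original Python source.
print(tf.autograph.to_code(count_down.python_function))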

Note: The Graph is still present, but it has been abstracted away and no longer requires a Session to control its scope. This is a great improvement, allowing for simpler, more readable and more maintainable code.

Namespaces

TensorFlow 1 was fat. And I mean really fat, like “deep-fried Snickers bars for breakfast” kind of fat. It had over 2,000 endpoints (ways of reaching inner modules), with 500 of them in the root tf namespace. This meant that if you depended on our friendly neighbourhood IntelliSense to find something useful in TensorFlow, you were in for a long trip. Once you pressed the dot after “tf”, you would see a long list of 500+ suggestions of possible modules or functions you could use. And it’s not only the number of endpoints: naming conventions were suboptimal, to say the least, and some modules you just wouldn’t expect in the endpoints where they were placed. In short, it was a mess causing a lot of pain, especially to newcomers.

As part of the namespace refactor, a lot of modules were moved, renamed or deprecated, and some got moved to the backward-compatibility “.compat” namespace. The current endpoint setup is a huge step forward in discoverability and simplicity, achieved via consistent naming conventions. You can see the details of the endpoint cleanup project here. You may find it especially helpful if you are switching from TF 1.x.
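
To give a feel for what this means in practice, here is a small sketch (the optimizer comparison is my own illustration, not part of the official cleanup notes):

import tensorflow as tf

# Old TF1 endpoints did not simply disappear; many live on under the
# backward-compatibility namespace:
legacy_opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.01)

# The cleaned-up TF2 namespaces put the equivalent concept where you would
# expect to find it:
new_opt = tf.keras.optimizers.SGD(learning_rate=0.01)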

What’s next?

This concludes all the things I would have liked to see in one place when I first started wrapping my head around the new TensorFlow. There are, however, other interesting concepts that I think everyone should familiarize themselves with. I encourage you to explore them, and you will likely learn more about them in one of my next articles.

I hope you’ve found this useful and as always feel free to leave your thoughts in the comments box below!
