Jax Vs PyTorch [Key Differences]

In this Python tutorial, we will learn about Jax Vs PyTorch in Python. Jax is a library for expressing and composing transformations of numerical programs. It can also compile numerical programs for the CPU or for accelerators such as the GPU.

PyTorch is an open-source machine learning library that is mostly used for computer vision and natural language processing in Python. We will also cover different examples related to Jax vs PyTorch. These are the topics that we are going to discuss in this tutorial.

  • Introduction to Jax
  • Introduction to PyTorch
  • Jax Vs PyTorch
  • Jax Vs PyTorch Vs TensorFlow
  • Jax Vs PyTorch benchmark

Introduction to Jax

In this section, we will learn about what JAX is and how it works in Python.

JAX stands for Just After eXecution. It is a machine learning library developed by Google. JAX uses a JIT (Just In Time) compiler focused on harnessing the maximum number of FLOPs to generate optimized code while keeping the simplicity of pure Python.

JAX expresses and composes transformations of numerical programs, and it can compile those programs for the CPU or for accelerators such as the GPU.

JAX lets NumPy-style code run not only on the CPU but on the GPU and TPU as well.
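
To see what "transformations of numerical programs" means in practice, here is a minimal sketch (the function loss below is just a made-up example, not part of the original tutorial) that applies jax.grad and jax.jit to an ordinary Python function. The example under "Code:" below then focuses on the NumPy-style API and timing.

# A minimal sketch of JAX transformations (illustrative only)
import jax
import jax.numpy as jnp

def loss(w):
    # Plain NumPy-style math written with jax.numpy
    return jnp.sum(jnp.tanh(w) ** 2)

grad_loss = jax.grad(loss)   # transformation: automatic differentiation
fast_loss = jax.jit(loss)    # transformation: JIT compilation via XLA

w = jnp.arange(3.0)
print(loss(w))               # original function
print(grad_loss(w))          # gradient with respect to w
print(fast_loss(w))          # compiled version, same result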

Code:

  • In the following code, we will import all the necessary libraries, such as jax.numpy as jnp, and grad, jit, vmap, and random from jax.

  • rd = random.PRNGKey(0) is used to generate the random state, which is described by two unsigned 32-bit integers that we call a key.
  • y = random.normal(rd, (8,)) is used to generate a sample of numbers drawn from the normal distribution.
  • print(y) is used to print the y values using the print() function.
  • size = 2899 is used to set the size of the matrices.
  • random.normal(rd, (size, size), dtype=jnp.float32) is used to generate a matrix of numbers drawn from the normal distribution.
  • %timeit jnp.dot(y, y.T).block_until_ready() times the matrix multiplication (the %timeit magic assumes the code is run in Jupyter/IPython); it runs on the GPU when a GPU is available, otherwise it runs on the CPU.
# Importing libraries
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
# Generating random data
rd = random.PRNGKey(0)
y = random.normal(rd, (8,))
print(y)
# Multiply two matrices
size = 2899
y = random.normal(rd, (size, size), dtype=jnp.float32)
# Runs on the GPU if one is available, otherwise on the CPU
%timeit jnp.dot(y, y.T).block_until_ready()  

Output:

After running the above code, we get the following output, in which we can see the generated values and the timing of the matrix multiplication printed on the screen.

Introduction to JAX (output screenshot)

So with this, we understood JAX and how it works in Python.

Read: PyTorch Activation Function

Introduction to PyTorch

In this section, we will learn about what PyTorch is and how we can work with it in Python.

PyTorch is an open-source machine learning library that is mostly used for computer vision and natural language processing in Python. It was developed by the Facebook AI (Artificial Intelligence) Research lab.

It is software released under the modified BSD license. It is built on Python and supports the computation of tensors on the GPU (Graphics Processing Unit).

PyTorch is easy to use, has efficient memory usage, builds a dynamic computational graph, is flexible, and keeps coding simple while increasing processing speed. PyTorch is one of the most recommended libraries for deep learning and artificial intelligence.
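
To make the dynamic computational graph and GPU support concrete, here is a minimal sketch (the tensor values are made up) in which autograd builds the graph on the fly and the code picks the GPU when one is available. The longer example below fits a polynomial with manually computed gradients instead of autograd.

# A minimal sketch of PyTorch's dynamic graph (autograd) and device handling
import torch

# Use the GPU if it is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x = torch.tensor([1.0, 2.0, 3.0], device=device, requires_grad=True)
y = (x ** 2).sum()   # the graph is built dynamically as this line runs

y.backward()         # autograd walks the graph backward
print(x.grad)        # tensor([2., 4., 6.])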

Code:

In the following code, we will import all the necessary libraries such as import torch and import math.

  • y = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) is used to create the input data, and z = torch.sin(y) creates the output data.
  • m = torch.randn((), device=device, dtype=dtype) is used to randomly initialize the weights.
  • z_pred = m + n * y + o * y ** 2 + p * y ** 3 is used as the forward pass to compute the predicted z.
  • loss = (z_pred - z).pow(2).sum().item() is used to compute the loss.
  • print(i, loss) is used to print the iteration number and the loss.
  • grad_m = grad_z_pred.sum(): Here we apply backpropagation to compute the gradients of m, n, o, and p with respect to the loss.
  • m -= learning_rate * grad_m is used to update the weights using gradient descent.
  • print(f'Result: z = {m.item()} + {n.item()} y + {o.item()} y^2 + {p.item()} y^3') is used to print the result using the print() function.
# Importing libraries
import torch 
import math 

# Device Used
dtype = torch.float 
device = torch.device("cpu") 

# Create input and output data 
y = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) 
z = torch.sin(y)  

# Randomly initialize weights 
m = torch.randn((), device=device, dtype=dtype) 
n = torch.randn((), device=device, dtype=dtype) 
o = torch.randn((), device=device, dtype=dtype) 
p = torch.randn((), device=device, dtype=dtype) 


learning_rate = 1e-6 
for i in range(2000): 
    # Forward pass: compute predicted z 
    z_pred = m + n * y + o * y ** 2 + p * y ** 3 
 
    # Compute and print loss 
    loss = (z_pred - z).pow(2).sum().item() 
    if i % 100 == 99: 
        print(i, loss) 

    # Backprop to compute gradients of m, n, o, p with respect to loss 
    grad_z_pred = 2.0 * (z_pred - z) 
    grad_m = grad_z_pred.sum() 
    grad_n = (grad_z_pred * y).sum() 
    grad_o = (grad_z_pred * y ** 2).sum() 
    grad_p = (grad_z_pred * y ** 3).sum() 
 
    # Update weights using gradient descent 
    m -= learning_rate * grad_m 
    n -= learning_rate * grad_n 
    o -= learning_rate * grad_o 
    p -= learning_rate * grad_p 
 
# Print the result
print(f'Result: z = {m.item()} + {n.item()} y + {o.item()} y^2 + {p.item()} y^3') 

Output:

In the output below, you can see that the fitted result is printed on the screen.

Introduction to PyTorch (output screenshot)

So, with this, we understood PyTorch and its implementation.

Read: PyTorch fully connected layer

Jax Vs PyTorch

In this section, we will learn about the Key differences between Jax and PyTorch in python.

Jax | PyTorch
Jax was released in December 2018. | PyTorch was released in October 2016.
Jax is developed by Google. | PyTorch is developed by Facebook.
Its graph creation is static. | Its graph creation is dynamic.
The target audience is researchers. | The target audience is researchers and developers.
The Jax implementation has linear run-time complexity. | The PyTorch implementation has quadratic run-time complexity.
Jax is more flexible than PyTorch because it allows you to define functions and then automatically calculates the derivatives of those functions. | PyTorch is flexible.
The development stage is Developing (v0.1.55). | The development stage is Mature (v1.8.0).
Jax is more efficient than PyTorch because it can automatically parallelize our code across multiple CPUs. | PyTorch is efficient.

Jax Vs PyTorch
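
To illustrate the rows about automatic derivatives and automatic parallelization, here is a minimal sketch (f is a made-up scalar function): jax.grad computes the derivative automatically, and jax.vmap vectorizes it over a batch of inputs; on multi-device hardware, jax.pmap plays the analogous role of parallelizing across devices.

# A minimal sketch: automatic derivatives and automatic vectorization in JAX
import jax
import jax.numpy as jnp

def f(x):
    # Made-up scalar function
    return x ** 3 + 2.0 * x

df = jax.grad(f)             # derivative of f, computed automatically
batched_df = jax.vmap(df)    # vectorize the derivative over a batch of inputs

xs = jnp.array([0.0, 1.0, 2.0])
print(batched_df(xs))        # [ 2.  5. 14.]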

So, with this, we understood the key differences between Jax and PyTorch.

Read: PyTorch Logistic Regression

Jax Vs PyTorch Vs TensorFlow

In this section, we will learn about the key differences between Jax Vs PyTorch Vs TensorFlow in python.

Jax | PyTorch | TensorFlow
Jax is developed by Google. | PyTorch is developed by Facebook. | TensorFlow is developed by Google.
Jax is flexible. | PyTorch is flexible. | TensorFlow is not flexible.
Jax's target audience is researchers. | PyTorch's target audience is researchers and developers. | TensorFlow's target audience is researchers and developers.
Jax creates static graphs. | PyTorch creates dynamic graphs. | TensorFlow creates both static and dynamic graphs.
Jax has both high-level and low-level APIs. | PyTorch has a low-level API. | TensorFlow has a high-level API.
Jax is more efficient than PyTorch and TensorFlow. | PyTorch is less efficient than Jax. | TensorFlow is also less efficient than Jax.
Jax's development stage is Developing (v0.1.55). | PyTorch's development stage is Mature (v1.8.0). | TensorFlow's development stage is Mature (v2.4.1).

Jax Vs PyTorch Vs TensorFlow

So, with this, we understood Jax Vs PyTorch Vs TensorFlow.

Read: PyTorch Dataloader + Examples

Jax Vs PyTorch benchmark

In this section, we will learn about the Jax Vs PyTorch benchmark in python.

Jax is a machine learning library for transforming numerical functions. It can compile numerical programs for the CPU or for accelerators such as the GPU.

Code:

  • In the following code, we will import all the necessary libraries, such as jax.numpy as jnp, and grad, jit, vmap, and random from jax.
  • m = random.PRNGKey(0) is used to generate the random state, which is described by two unsigned 32-bit integers that we call a key.
  • i = random.normal(m, (10,)) is used to generate a sample of numbers drawn from the normal distribution.
  • print(i) is used to print the values of i using the print() function.
  • siz = 2800 is used to set the size of the matrices.
  • i = random.normal(m, (siz, siz), dtype=jnp.float32) is used to generate a matrix of numbers drawn from the normal distribution.
  • %timeit jnp.dot(i, i.T).block_until_ready() times the matrix multiplication; it runs on the GPU when a GPU is available, otherwise it runs on the CPU. We call block_until_ready() because JAX uses asynchronous execution.
  • i = num.random.normal(size=(siz, siz)).astype(num.float32) creates a regular NumPy array, and i = device_put(i) transfers the data to the GPU.
# Importing Libraries
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

# Generating random data
m = random.PRNGKey(0)
i = random.normal(m, (10,))
print(i)

# Multiply two big matrices
siz = 2800
i = random.normal(m, (siz, siz), dtype=jnp.float32)
%timeit jnp.dot(i, i.T).block_until_ready()  

# JAX NumPy functions work on regular NumPy arrays
import numpy as num
i = num.random.normal(size=(siz, siz)).astype(num.float32)
%timeit jnp.dot(i, i.T).block_until_ready()

# Transfer the data to GPU
from jax import device_put

i = num.random.normal(size=(siz, siz)).astype(num.float32)
i = device_put(i)
%timeit jnp.dot(i, i.T).block_until_ready()

Output:

After running the above code, we get the following output, in which we can see the timings of the matrix multiplications with JAX printed on the screen.

Jax Vs PyTorch benchmark (output screenshot)

PyTorch benchmark

A PyTorch benchmark helps us validate that our code meets performance expectations and compare different approaches to solving the same problem.
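
The example below uses Python's built-in timeit module. As a side note, PyTorch also ships its own benchmarking utility, torch.utils.benchmark.Timer, which handles warm-up and threading details; here is a minimal sketch of it (batched_dot is a made-up stand-in for whatever function you want to time).

# A minimal sketch of torch.utils.benchmark.Timer (an alternative to plain timeit)
import torch
import torch.utils.benchmark as benchmark

def batched_dot(a, b):
    # Made-up function to time: row-wise dot product
    return a.mul(b).sum(-1)

x = torch.randn(1000, 62)

t = benchmark.Timer(
    stmt='batched_dot(x, x)',
    globals={'batched_dot': batched_dot, 'x': x})

print(t.timeit(100))   # prints the mean time per run and related statistics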

Code:

In the following code, we will import all the necessary libraries such as import torch and import timeit.

  • return m.mul(n).sum(-1) is used to calculate the batched dot product by multiplying and summing.
  • m = m.reshape(-1, 1, m.shape[-1]) is used to calculate the batched dot product by reducing it to bmm.
  • i = torch.randn(1000, 62) is used as the input for benchmarking.
  • print(f'multiply_sum(i, i): {j.timeit(100) / 100 * 1e6:>5.1f} us') is used to print the multiply-and-sum timing.
  • print(f'bmm(i, i): {j1.timeit(100) / 100 * 1e6:>5.1f} us') is used to print the bmm timing.
# Import library
import torch
import timeit


# Define the Model
def batcheddot_multiply_sum(m, n):
    # Calculates batched dot by multiplying and sum
    return m.mul(n).sum(-1)


def batcheddot_bmm(m, n):
    #Calculates batched dot by reducing to bmm
    m = m.reshape(-1, 1, m.shape[-1])
    n = n.reshape(-1, n.shape[-1], 1)
    return torch.bmm(m, n).flatten(-3)


# Input for benchmarking
i = torch.randn(1000, 62)

# Ensure that both functions compute the same output
assert batcheddot_multiply_sum(i, i).allclose(batcheddot_bmm(i, i))


# Using timeit.Timer() method
j = timeit.Timer(
    stmt='batcheddot_multiply_sum(i, i)',
    setup='from __main__ import batcheddot_multiply_sum',
    globals={'i': i})

j1 = timeit.Timer(
    stmt='batcheddot_bmm(i, i)',
    setup='from __main__ import batcheddot_bmm',
    globals={'i': i})

print(f'multiply_sum(i, i):  {j.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(i, i):      {j1.timeit(100) / 100 * 1e6:>5.1f} us')

Output:

After running the above code, we get the following output, in which we can see the timings of the multiply-and-sum and bmm approaches from the PyTorch benchmark printed on the screen.

Jax Vs PyTorch benchmark (output screenshot)

So, with this, we understood the Jax Vs PyTorch benchmarks in Python.

Also, take a look at some more PyTorch tutorials.

So, in this tutorial, we discussed Jax Vs PyTorch, and we have also covered different examples related to their implementation. Here is the list of topics that we have covered.

  • Introduction to Jax
  • Introduction to PyTorch
  • Jax Vs PyTorch
  • Jax Vs PyTorch Vs TensorFlow
  • Jax Vs PyTorch benchmark