In this Python tutorial, we will learn about **Jax Vs PyTorch** in Python. Jax is a Python library for transforming numerical programs, for example by differentiating, vectorizing, or compiling them. It can compile numerical programs to run on the CPU or on accelerators such as GPUs and TPUs.

PyTorch is an open-source machine learning library that is mostly used for computer vision and natural language processing in Python. We will also cover different examples related to **Jax Vs PyTorch**. These are the topics that we are going to discuss in this tutorial.

- Introduction to Jax
- Introduction to PyTorch
- Jax Vs PyTorch
- Jax Vs PyTorch Vs TensorFlow
- Jax Vs PyTorch benchmark

## Introduction to Jax

In this section, we will learn **what JAX is** and how it works in Python.

The name JAX is sometimes expanded as "Just After eXecution"; the project itself describes JAX as bringing together Autograd (automatic differentiation) and XLA (an optimizing compiler for linear algebra). It is a numerical computing and machine learning library developed by Google. Jax includes a JIT (just-in-time) compiler that turns Python functions into optimized code, with the goal of getting as many FLOPS as possible out of the hardware while you keep writing Python.

Jax lets NumPy-style code run not only on the CPU but on GPUs and TPUs as well.
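As a minimal sketch of what these transformations look like in practice, the snippet below applies `grad`, `jit`, and `vmap` to one small function. The function `square_sum` and the input values are made up for illustration and are not part of the original tutorial.

```python
import jax.numpy as jnp
from jax import grad, jit, vmap


def square_sum(x):
    # Scalar-valued function: sum of the squares of a vector (illustrative only)
    return jnp.sum(x ** 2)


# grad returns a new function that computes d(square_sum)/dx, which is 2 * x
grad_fn = grad(square_sum)

# jit returns a version of the function that is compiled with XLA on first call
fast_square_sum = jit(square_sum)

# vmap returns a version that maps the function over the leading (batch) axis
batched_square_sum = vmap(square_sum)

x = jnp.array([1.0, 2.0, 3.0])
batch = jnp.stack([x, 2.0 * x])

gradient = grad_fn(x)                 # [2. 4. 6.]
total = fast_square_sum(x)            # 14.0
per_row = batched_square_sum(batch)   # [14. 56.]

print(gradient, total, per_row)
```

Each transformation takes a plain Python function and returns another function, so they can also be combined, for example `jit(grad(square_sum))`.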

**Code:**

In the following code, we will import all the necessary libraries such as **import jax.numpy as jnp**, **from jax import grad, jit, vmap**, and **from jax import random**.

- **rd = random.PRNGKey(0)** creates the random key. The random state is described by two unsigned 32-bit integers that we call a key.
- **y = random.normal(rd, (8,))** is used to generate a sample of 8 numbers drawn from the normal distribution.
- **print(y)** is used to print the y values using the print() function.
- **size = 2899** is used to give the size of the matrix.
- **random.normal(rd, (size, size), dtype=jnp.float32)** is used to generate a size x size matrix of numbers drawn from the normal distribution.
- **%timeit jnp.dot(y, y.T).block_until_ready()** times the matrix multiplication. It runs on the GPU when a GPU is available, otherwise it runs on the CPU. %timeit is an IPython magic, so this line has to be run in IPython or a Jupyter notebook.

```
# Importing libraries
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
# Generating random data
rd = random.PRNGKey(0)
y = random.normal(rd, (8,))
print(y)
# Multiply two matrices
size = 2899
y = random.normal(rd, (size, size), dtype=jnp.float32)
# Runs on the GPU if one is available, otherwise on the CPU (%timeit needs IPython or Jupyter)
%timeit jnp.dot(y, y.T).block_until_ready()
```

**Output:**

After running the above code, we get the following output, in which we can see the random values of y and the time taken to multiply the two matrices printed on the screen.
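The example above passes the same key `rd` to both `random.normal` calls. Jax's random numbers are fully determined by the key, so a common pattern is to split the key before each draw. The sketch below illustrates this; the variable names are chosen here for illustration.

```python
from jax import random

key = random.PRNGKey(0)

# Reusing one key with the same shape gives exactly the same numbers
a = random.normal(key, (3,))
b = random.normal(key, (3,))

# Splitting produces new subkeys that give independent draws
key, subkey1, subkey2 = random.split(key, 3)
c = random.normal(subkey1, (3,))
d = random.normal(subkey2, (3,))

same_key_equal = bool((a == b).all())     # True: same key, same numbers
split_keys_equal = bool((c == d).all())   # different subkeys, different numbers

print(same_key_equal, split_keys_equal)
```

Because the key is passed explicitly, results are reproducible across runs, which is convenient when comparing timings or debugging.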

So, with this, we understood Jax and how it works in Python.


## Introduction to PyTorch

In this section, we will learn **what PyTorch is** and how we can work with PyTorch in Python.

PyTorch is an open-source machine learning library that is mostly used for computer vision and natural language processing in Python. It is developed by the Facebook AI (Artificial Intelligence) Research Lab.

It is released under the Modified BSD License. It is built on Python and supports tensor computation on the GPU (Graphics Processing Unit).

PyTorch is easy to use, uses memory efficiently, builds its computational graph dynamically, and is flexible, which makes code quicker to write and experiments faster to run. PyTorch is one of the most widely recommended libraries for deep learning and artificial intelligence.

**Code:**

In the following code, we will import all the necessary libraries such as import torch and import math. The code fits a third-degree polynomial to sin(y) using gradient descent.

- **y = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)** is used to create the input data: 2000 evenly spaced points between -pi and pi. The output data is z = torch.sin(y).
- **m = torch.randn((), device=device, dtype=dtype)** is used to randomly initialize the weights.
- **z_pred = m + n * y + o * y ** 2 + p * y ** 3** is used as a forward pass to compute the predicted z.
- **loss = (z_pred - z).pow(2).sum().item()** is used to compute the loss.
- **print(i, loss)** is used to print the iteration number and the loss.
- **grad_m = grad_z_pred.sum()**: Here we apply backpropagation by hand to compute the gradients of the loss with respect to m, n, o, and p.
- **m -= learning_rate * grad_m** is used to update the weights using gradient descent.
- **print(f'Result: z = {m.item()} + {n.item()} y + {o.item()} y^2 + {p.item()} y^3')** is used to print the result using the print() function.

```
# Importing libraries
import torch
import math

# Device used
dtype = torch.float
device = torch.device("cpu")

# Create input and output data
y = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
z = torch.sin(y)

# Randomly initialize weights
m = torch.randn((), device=device, dtype=dtype)
n = torch.randn((), device=device, dtype=dtype)
o = torch.randn((), device=device, dtype=dtype)
p = torch.randn((), device=device, dtype=dtype)

learning_rate = 1e-6
for i in range(2000):
    # Forward pass: compute predicted z
    z_pred = m + n * y + o * y ** 2 + p * y ** 3

    # Compute and print loss every 100 iterations
    loss = (z_pred - z).pow(2).sum().item()
    if i % 100 == 99:
        print(i, loss)

    # Backprop to compute gradients of m, n, o, p with respect to loss
    grad_z_pred = 2.0 * (z_pred - z)
    grad_m = grad_z_pred.sum()
    grad_n = (grad_z_pred * y).sum()
    grad_o = (grad_z_pred * y ** 2).sum()
    grad_p = (grad_z_pred * y ** 3).sum()

    # Update weights using gradient descent
    m -= learning_rate * grad_m
    n -= learning_rate * grad_n
    o -= learning_rate * grad_o
    p -= learning_rate * grad_p

# Print the result
print(f'Result: z = {m.item()} + {n.item()} y + {o.item()} y^2 + {p.item()} y^3')
```

**Output:**

In the output below, you can see the loss printed every 100 iterations, followed by the fitted polynomial coefficients.
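The loop above computes the gradients by hand. As a hedged sketch, the same gradients can also be obtained from PyTorch's autograd and compared against the manual formulas. Stacking the four weights into one tensor `w`, fixing the seed, and using float64 (to keep the comparison numerically tight) are choices made here for illustration, not part of the original code.

```python
import math
import torch

torch.manual_seed(0)  # assumption: fixed seed, only for reproducibility

# float64 is an assumption made here so the two results agree closely
y = torch.linspace(-math.pi, math.pi, 2000, dtype=torch.float64)
z = torch.sin(y)

# The weights m, n, o, p stacked in one tensor that tracks gradients
w = torch.randn(4, dtype=torch.float64, requires_grad=True)

# Forward pass and loss, same formulas as in the loop above
z_pred = w[0] + w[1] * y + w[2] * y ** 2 + w[3] * y ** 3
loss = (z_pred - z).pow(2).sum()

# Autograd fills w.grad with d(loss)/dw
loss.backward()

# Manual gradients, written exactly as in the loop above
with torch.no_grad():
    grad_z_pred = 2.0 * (z_pred - z)
    manual = torch.stack([
        grad_z_pred.sum(),
        (grad_z_pred * y).sum(),
        (grad_z_pred * y ** 2).sum(),
        (grad_z_pred * y ** 3).sum(),
    ])

matches = torch.allclose(w.grad, manual)
print(matches)
```

With autograd, the four `grad_*` lines in the training loop are no longer needed, which is how PyTorch code is usually written once models grow beyond a few parameters.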

So, with this, we understood PyTorch and its implementation.


## Jax Vs PyTorch

In this section, we will learn about **the key differences between Jax and PyTorch** in Python.

| Jax | PyTorch |
| --- | --- |
| Jax was released in December 2018. | PyTorch was released in 2016. |
| Jax is developed by Google. | PyTorch is developed by Facebook. |
| Its graph creation is static: functions are traced once when compiled with jit. | Its graph creation is dynamic: the graph is built as the code runs. |
| The target audience is researchers. | The target audience is researchers and developers. |
| Run time depends on the algorithm; functions compiled with jit are optimized by XLA. | Run time depends on the algorithm; operations run eagerly by default. |
| Jax is flexible: you define plain Python functions and Jax automatically calculates their derivatives with grad. | PyTorch is flexible: derivatives are calculated by autograd on tensors. |
| The development stage is developing (v0.1.55 at the time of writing). | The development stage is mature (v1.8.0 at the time of writing). |
| Jax can parallelize code across multiple devices, which can make it more efficient on some workloads. | PyTorch is efficient; the relative speed depends on the workload. |

So, with this, we understood the key differences between Jax and PyTorch.
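The "static graph" entry for Jax can be made concrete with a small sketch. When a function is wrapped in `jit`, its Python body runs only while Jax traces it for a given input shape; later calls with the same shape reuse the compiled result. The names `scaled_sum` and `trace_log` are hypothetical and used only for this illustration.

```python
import jax.numpy as jnp
from jax import jit

trace_log = []  # records the input shape every time the Python body actually runs


@jit
def scaled_sum(x):
    # Python side effect: executed only during tracing, not on every call
    trace_log.append(x.shape)
    return jnp.sum(x * 2.0)


x = jnp.ones(4)
first = scaled_sum(x)    # traces and compiles for shape (4,)
second = scaled_sum(x)   # reuses the compiled function, body is not re-run
traces_after_same_shape = len(trace_log)

third = scaled_sum(jnp.ones(8))   # new shape, so the function is traced again
traces_after_new_shape = len(trace_log)

print(traces_after_same_shape, traces_after_new_shape)
```

In PyTorch's default eager mode, by contrast, the Python body of a function runs on every call, which is what "dynamic graph" refers to in the table.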


## Jax Vs PyTorch Vs TensorFlow

In this section, we will learn about **the key differences between Jax, PyTorch, and TensorFlow** in Python.

| Jax | PyTorch | TensorFlow |
| --- | --- | --- |
| Jax is developed by Google. | PyTorch is developed by Facebook. | TensorFlow is developed by Google. |
| Jax is flexible. | PyTorch is flexible. | TensorFlow is generally considered less flexible. |
| Jax's target audience is researchers. | PyTorch's target audience is researchers and developers. | TensorFlow's target audience is researchers and developers. |
| Jax creates static graphs (when functions are compiled with jit). | PyTorch creates dynamic graphs. | TensorFlow creates both static and dynamic graphs. |
| Jax itself is a low-level API; high-level APIs come from libraries built on top of it. | PyTorch has a low-level tensor API and the higher-level torch.nn API. | TensorFlow has a high-level API (Keras) as well as low-level operations. |
| Jax can be more efficient than PyTorch and TensorFlow on some workloads because of XLA compilation. | PyTorch can be less efficient than Jax on those workloads. | TensorFlow can also be less efficient than Jax on those workloads. |
| Jax's development stage is developing (v0.1.55 at the time of writing). | PyTorch's development stage is mature (v1.8.0 at the time of writing). | TensorFlow's development stage is mature (v2.4.1 at the time of writing). |

So, with this, we understood Jax Vs PyTorch Vs TensorFlow.


## Jax Vs PyTorch benchmark

In this section, we will learn about the **Jax Vs PyTorch benchmark** in Python.

Jax is a machine learning library for transforming numerical functions. It can compile numerical programs for the CPU or for accelerators such as the GPU.

**Code:**

In the following code, we will import all the necessary libraries such as import jax.numpy as jnp, from jax import grad, jit, vmap, and from jax import random.

- **m = random.PRNGKey(0)** creates the random key. The random state is described by two unsigned 32-bit integers that we call a key.
- **i = random.normal(m, (10,))** is used to generate a sample of 10 numbers drawn from the normal distribution.
- **print(i)** is used to print the i values using the print() function.
- **siz = 2800** is used to give the size of the matrix.
- **i = random.normal(m, (siz, siz), dtype=jnp.float32)** is used to generate a siz x siz matrix of numbers drawn from the normal distribution.
- **%timeit jnp.dot(i, i.T).block_until_ready()** times the matrix multiplication. It runs on the GPU when a GPU is available, otherwise it runs on the CPU. Here we are using block_until_ready because Jax uses asynchronous execution, so without it we would only time how long it takes to dispatch the work. %timeit is an IPython magic and needs IPython or a Jupyter notebook.
- **i = num.random.normal(size=(siz, siz)).astype(num.float32)** creates a regular NumPy array. Jax NumPy functions accept it, but the data has to be transferred to the device on every call.
- **i = device_put(i)** is used to transfer the data to the GPU once, before timing.

```
# Importing Libraries
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
# Multiplying Matrices
m = random.PRNGKey(0)
i = random.normal(m, (10,))
print(i)
# Multiply two big matrices
siz = 2800
i = random.normal(m, (siz, siz), dtype=jnp.float32)
%timeit jnp.dot(i, i.T).block_until_ready()
# Jax Numpy function work on regular numpy array
import numpy as num
i = num.random.normal(size=(siz, siz)).astype(num.float32)
%timeit jnp.dot(i, i.T).block_until_ready()
# Transfer the data to GPU
from jax import device_put
i = num.random.normal(size=(siz, siz)).astype(num.float32)
i = device_put(i)
%timeit jnp.dot(i, i.T).block_until_ready()
```

**Output:**

After running the above code, we get the following output, in which we can see the random values and the time taken by each matrix multiplication using Jax printed on the screen.
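Since `%timeit` only works in IPython or Jupyter, the sketch below shows one way the same kind of measurement might be taken in a plain Python script with `time.perf_counter`. The helper `time_call`, the jitted function `gram`, and the smaller 500 x 500 matrix are assumptions made here to keep the example short; they are not part of the original benchmark.

```python
import time

import jax.numpy as jnp
from jax import jit, random


def time_call(fn, *args, repeats=5):
    # Average wall-clock seconds per call; block_until_ready waits for the
    # asynchronous computation to finish before the clock is read
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


@jit
def gram(a):
    # Same operation as in the benchmark above: a matrix times its transpose
    return jnp.dot(a, a.T)


key = random.PRNGKey(0)
a = random.normal(key, (500, 500), dtype=jnp.float32)  # assumption: small size for a quick run

# The first call includes tracing and compilation, so it is timed separately
start = time.perf_counter()
gram(a).block_until_ready()
first_call = time.perf_counter() - start

later_calls = time_call(gram, a)
print(f"first call: {first_call * 1e3:.2f} ms, later calls: {later_calls * 1e3:.2f} ms")
```

Separating the first call from the later ones matters when comparing against PyTorch, because compilation time would otherwise be counted as run time.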

**PyTorch benchmark**

Benchmarking PyTorch code helps us validate that our code meets performance expectations and compare different approaches for solving the same problem.

**Code:**

In the following code, we will import all the necessary libraries such as import torch and import timeit.

- **return m.mul(n).sum(-1)** is used to calculate the batched dot product by multiplying and summing.
- **m = m.reshape(-1, 1, m.shape[-1])** is used to reshape the input so that the batched dot product can be calculated with bmm (batched matrix multiplication).
- **i = torch.randn(1000, 62)** is used as input for benchmarking.
- **print(f'multiply_sum(i, i): {j.timeit(100) / 100 * 1e6:>5.1f} us')** is used to print the average time of the multiply and sum version in microseconds.
- **print(f'bmm(i, i): {j1.timeit(100) / 100 * 1e6:>5.1f} us')** is used to print the average time of the bmm version in microseconds.

```
# Import library
import torch
import timeit

# Define the functions to compare
def batcheddot_multiply_sum(m, n):
    # Calculates batched dot by multiplying and sum
    return m.mul(n).sum(-1)

def batcheddot_bmm(m, n):
    # Calculates batched dot by reducing to bmm
    m = m.reshape(-1, 1, m.shape[-1])
    n = n.reshape(-1, n.shape[-1], 1)
    return torch.bmm(m, n).flatten(-3)

# Input for benchmarking
i = torch.randn(1000, 62)

# Ensure that both functions compute the same output
assert batcheddot_multiply_sum(i, i).allclose(batcheddot_bmm(i, i))

# Using timeit.Timer() method
j = timeit.Timer(
    stmt='batcheddot_multiply_sum(i, i)',
    setup='from __main__ import batcheddot_multiply_sum',
    globals={'i': i})
j1 = timeit.Timer(
    stmt='batcheddot_bmm(i, i)',
    setup='from __main__ import batcheddot_bmm',
    globals={'i': i})

print(f'multiply_sum(i, i): {j.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(i, i): {j1.timeit(100) / 100 * 1e6:>5.1f} us')
```

**Output:**

After running the above code, we get the following output, in which we can see the average time in microseconds of the multiply and sum version and of the bmm version printed on the screen.
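The code above uses Python's built-in `timeit`. PyTorch also ships its own timer in `torch.utils.benchmark`, which has a similar interface. The sketch below reruns the same two functions with it; the loop and the `results` dictionary are choices made here for illustration.

```python
import torch
import torch.utils.benchmark as benchmark


def batcheddot_multiply_sum(m, n):
    # Calculates batched dot by multiplying and sum
    return m.mul(n).sum(-1)


def batcheddot_bmm(m, n):
    # Calculates batched dot by reducing to bmm
    m = m.reshape(-1, 1, m.shape[-1])
    n = n.reshape(-1, n.shape[-1], 1)
    return torch.bmm(m, n).flatten(-3)


# Input for benchmarking
i = torch.randn(1000, 62)

results = {}  # mean seconds per call, keyed by function name
for name, fn in [("multiply_sum", batcheddot_multiply_sum), ("bmm", batcheddot_bmm)]:
    timer = benchmark.Timer(
        stmt="fn(i, i)",
        globals={"fn": fn, "i": i},  # the statement only sees these names
    )
    measurement = timer.timeit(100)  # run the statement 100 times
    results[name] = measurement.mean
    print(f"{name}: {measurement.mean * 1e6:>5.1f} us")
```

Passing the function through `globals` avoids the `from __main__ import ...` setup string, so the snippet also works when it is not run as the main script.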

So, with this, we understood the Jax Vs PyTorch benchmarks in Python.

Also, take a look at some more PyTorch tutorials.

- PyTorch Pretrained Model
- PyTorch Stack Tutorial
- PyTorch Tensor to Numpy
- PyTorch Batch Normalization
- PyTorch Hyperparameter Tuning
- PyTorch Load Model

So, in this tutorial, we discussed **Jax Vs PyTorch** and we have also covered different examples related to their implementation. Here is the list of topics that we have covered.

- Introduction to Jax
- Introduction to PyTorch
- Jax Vs PyTorch
- Jax Vs PyTorch Vs TensorFlow
- Jax Vs PyTorch benchmark

I am Bijay Kumar, a Microsoft MVP in SharePoint. Apart from SharePoint, I have been working on Python, machine learning, and artificial intelligence for the last 5 years. During this time I have gained expertise in various Python libraries such as Tkinter, Pandas, NumPy, Turtle, Django, Matplotlib, TensorFlow, SciPy, and Scikit-Learn, working for various clients in the United States, Canada, the United Kingdom, Australia, New Zealand, etc.