PyTorch FX: a toolkit in the PyTorch machine learning library

lathashivkumar22 · 10 slides · Oct 11, 2024


PYTORCH FX

AGENDA

- Execution modes
- Torch FX
- Torch FX example
- torch.compile

EXECUTION MODES

PyTorch is an open-source machine learning library for Python, built on the Torch library. PyTorch supports two execution modes:
1. Eager mode
2. Graph mode

EAGER MODE

In eager mode (also called define-by-run), the operators in a model are executed immediately as they are encountered. You do not first build a graph and then execute it in a Session, as in TensorFlow 1.x.

GRAPH MODE

Graph execution (also called define-and-run) extracts the tensor computations from Python and builds a graph before evaluating it. Seeing the whole graph in advance enables optimizations such as operator fusion, in which one operator is merged with another to reduce memory reads and writes and keep data local.
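A minimal sketch of the difference between the two modes, using torch.fx as the graph-capturing step. The function f and its operators (relu, then multiply by 2) are hypothetical choices for illustration only: calling f directly runs each operator right away, while symbolic_trace first records the operators into a graph that is evaluated afterwards.

```python
import torch
import torch.fx

# Hypothetical toy function used only for illustration.
def f(x):
    y = torch.relu(x)  # eager mode: this runs right away and y holds real values
    return y * 2

x = torch.tensor([-1.0, 2.0])

# Eager mode (define-by-run): each operator executes as Python reaches it.
eager_out = f(x)

# Graph mode (define-and-run) via torch.fx: first record the operators into a graph...
gm = torch.fx.symbolic_trace(f)
ops = [node.target for node in gm.graph.nodes if node.op == "call_function"]
print(ops)  # the recorded operators: torch.relu, then the multiplication

# ...then evaluate the recorded graph.
graph_out = gm(x)
print(eager_out, graph_out)
```

Both paths compute the same values; the difference is that the graph version has the whole computation available before it runs, which is what makes graph-level optimizations possible.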

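TORCH FX

torch.fx is a toolkit, shipped as part of the PyTorch package, that supports graph mode. It can:
- build a DAG (directed acyclic graph) representation of PyTorch programs,
- transform and optimize PyTorch code programmatically, by editing that graph,
- serve as a building block of torch.compile, whose backends receive the captured program as an FX graph.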
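To make "transform PyTorch code" concrete, here is a minimal sketch of a graph rewrite pass, modelled on the usual FX pattern of editing nodes and regenerating code. The function f and the choice to swap torch.add for torch.mul are hypothetical and only illustrate the mechanism.

```python
import torch
import torch.fx

# Hypothetical function to transform.
def f(x, y):
    return torch.add(x, y)

gm = torch.fx.symbolic_trace(f)

# Rewrite pass: swap every torch.add call for torch.mul.
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target == torch.add:
        node.target = torch.mul

gm.graph.lint()   # check that the edited graph is still well-formed
gm.recompile()    # regenerate gm.code and gm.forward from the edited graph
print(gm.code)

result = gm(torch.tensor([2.0]), torch.tensor([3.0]))
print(result)
```

The same loop-over-nodes structure is how larger optimization passes (for example, fusing or replacing operators) are typically written with FX.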

TORCH FX EXAMPLE

    import torch
    import torch.fx

    def add_eg(x, y):
        a = torch.sin(x) ** 2
        b = torch.cos(y) ** 2
        return torch.add(a, b)

    # Symbolically trace the function into a GraphModule
    traced = torch.fx.symbolic_trace(add_eg)
    print(traced.graph)            # the captured graph (output below)
    traced.graph.print_tabular()   # the same graph as a table (needs the tabulate package)
    print(traced.code)             # Python code regenerated from the graph

    graph():
        %x : [num_users=1] = placeholder[target=x]
        %y : [num_users=1] = placeholder[target=y]
        %sin : [num_users=1] = call_function[target=torch.sin](args = (%x,), kwargs = {})
        %pow_1 : [num_users=1] = call_function[target=operator.pow](args = (%sin, 2), kwargs = {})
        %cos : [num_users=1] = call_function[target=torch.cos](args = (%y,), kwargs = {})
        %pow_2 : [num_users=1] = call_function[target=operator.pow](args = (%cos, 2), kwargs = {})
        %add : [num_users=1] = call_function[target=torch.add](args = (%pow_1, %pow_2), kwargs = {})
        return add

[Figure: DAG of add_eg. x → sin → sin²(x) and y → cos → cos²(y), both feeding ADD → OUTPUT]
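The printed graph and the DAG above can also be explored programmatically. A minimal sketch: iterate over the nodes of the traced add_eg and list, for each node, its opcode, its name, and the names of the nodes it reads from (its incoming DAG edges). The rows list and the print format are illustrative choices, not part of the original example.

```python
import torch
import torch.fx

def add_eg(x, y):
    a = torch.sin(x) ** 2
    b = torch.cos(y) ** 2
    return torch.add(a, b)

traced = torch.fx.symbolic_trace(add_eg)

# Each Node records its opcode, a unique name, and its args.
# Args that are themselves Nodes are the edges of the DAG.
rows = []
for node in traced.graph.nodes:
    inputs = [a.name for a in node.args if isinstance(a, torch.fx.Node)]
    rows.append((node.op, node.name, inputs))
    print(f"{node.op:13} {node.name:6} <- {inputs}")

# The traced GraphModule computes the same result as the original function.
x, y = torch.randn(3), torch.randn(3)
same = torch.equal(traced(x, y), add_eg(x, y))
print(same)
```

Reading the rows top to bottom gives the same order as the graph printout: two placeholders, the five call_function nodes, and the output node that returns add.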

TORCH COMPILE

torch.compile is a function that JIT-compiles PyTorch code into optimized kernels.

EXAMPLE

    import torch

    def foo(x, y):
        a = torch.sin(x)
        b = torch.cos(y)
        return a + b

    opt_foo1 = torch.compile(foo)
    print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))

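Choosing a different backend:

By default torch.compile uses the "inductor" backend. The backend argument also accepts a callable that receives the captured FX graph (a torch.fx.GraphModule) together with example inputs and returns a callable to run in its place. The custom_backend below is a minimal such backend that prints the graph and runs it unchanged.

    import torch
    import torch._dynamo

    # Minimal custom backend: print the FX graph that was captured,
    # then return its forward method so it runs as-is.
    def custom_backend(gm: torch.fx.GraphModule, example_inputs):
        print(gm.graph)
        return gm.forward

    # Reset since we are compiling foo again with a different backend.
    torch._dynamo.reset()
    opt_model = torch.compile(foo, backend=custom_backend)

    # Use the same inputs for both runs so the results can be compared.
    x, y = torch.randn(10, 10), torch.randn(10, 10)
    eager = foo(x, y)
    graph_res = opt_model(x, y)
    print(torch.allclose(eager, graph_res))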
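A custom backend is also a handy way to see when compilation actually happens. A minimal sketch, assuming a hypothetical recording_backend that stores every FX graph it is given in a graphs list: the backend runs on the first call, and a second call with inputs of the same shape and dtype reuses the cached compiled code instead of calling the backend again.

```python
import torch
import torch._dynamo

graphs = []  # hypothetical: collects every FX graph handed to the backend

def recording_backend(gm: torch.fx.GraphModule, example_inputs):
    # Called by torch.compile for each captured graph; record it and run it unchanged.
    graphs.append(gm)
    return gm.forward

def foo(x, y):
    return torch.sin(x) + torch.cos(y)

torch._dynamo.reset()  # start from empty compilation caches
opt_foo = torch.compile(foo, backend=recording_backend)

x, y = torch.randn(10, 10), torch.randn(10, 10)
first = opt_foo(x, y)   # compiles: the backend is called here
second = opt_foo(x, y)  # same shapes and dtypes: cached code is reused
print(len(graphs))
```

Passing inputs with a different shape or dtype can fail the cached guards and trigger a recompile, which would call the backend again.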