Reinforcement Learning with Human Feedback ODSC talk

Luis Serrano, PhD · Sep 05, 2024

About This Presentation

Reinforcement Learning with Human Feedback, PPO, and DPO


Slide Content

Reinforcement Learning with Human Feedback
(RLHF)
Luis Serrano

Topics
• Large Language Models
• How to fine-tune them with reinforcement learning
• Quick intro to reinforcement learning
• PPO (the reinforcement learning technique to fine-tune LLMs)
• DPO (the non-reinforcement learning technique to fine-tune LLMs)

Large Language Models

Transformers
I generate text… one word at a time

Transformers
Hello, how are you doing

Transformers
Write a story. Once

Transformers
Write a story. Once upon

Transformers
Write a story. Once upon a

How to train transformer models?
[Diagram: text from the internet and curated datasets is used to train the transformer]

Is that the end of the story? (Nope)

Transformers
What color is the sky?
Blue / Red / Banana
Transformers need to be fine-tuned by humans!

Transformers need to be fine-tuned by humans!
What is the capital of Sudan?
It's, ehhm… Cairo?
No! It's Khartoum!

Quick intro to Reinforcement Learning

Gridworld
[Grid diagram with terminal rewards +4, +5, and -1]

States
[Grid diagram: each cell of the grid is a state]

Actions
[Grid diagrams: from each state the agent can move to a neighboring cell]

It costs 1 to move one spot
[Grid diagram: rewards +4, +5, -1; every move costs 1]

What's the best I can do from here?
[Grid diagrams: the value of a state is the best reward it can reach minus the number of moves needed to get there, e.g. 3 and 2 next to the +4 exit, going down to -4, -3, -2 far from it]

How to calculate all the values?
[Grid diagram: every cell filled with its value, obtained by working backward from the rewards and subtracting 1 per move]

Policy

Policy if it's free to move
[Grid diagram: with no cost per move, every state's arrow points toward the largest reward]

Where is the best place to move?
[Grid diagram: the value of every cell, with arrows from each state toward the neighbor with the highest value]

Stochastic Policy
[Grid diagram: instead of a single best move, each state assigns a probability to each possible move]

Neural networks

Problem solved! (Or is it?)
[Grid diagrams: the complete table of values and the resulting policy, side by side]

Value network
[Diagram: a state (x, y) goes into the value neural network, which outputs a value, e.g. score = 3]

Policy network
[Diagram: a state (x, y) goes into the policy neural network, which outputs a probability for each move, e.g. 0.60, 0.25, 0.10, 0.05]
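As a rough sketch of what these two networks could look like for the gridworld (the architecture and sizes here are assumptions, not from the talk): both take the coordinates of a state, one outputs a single value, the other a distribution over the four moves.

```python
import torch
import torch.nn as nn

class ValueNetwork(nn.Module):
    """Maps a state (x, y) to a single estimated value."""
    def __init__(self, hidden=32):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2, hidden), nn.ReLU(), nn.Linear(hidden, 1))

    def forward(self, state):                     # state: (batch, 2)
        return self.net(state).squeeze(-1)        # -> (batch,) values

class PolicyNetwork(nn.Module):
    """Maps a state (x, y) to probabilities over the four moves."""
    def __init__(self, hidden=32, n_actions=4):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2, hidden), nn.ReLU(), nn.Linear(hidden, n_actions))

    def forward(self, state):
        return torch.softmax(self.net(state), dim=-1)   # -> (batch, 4), sums to 1

state = torch.tensor([[2.0, 1.0]])
print(ValueNetwork()(state))    # e.g. a value near 0 before any training
print(PolicyNetwork()(state))   # e.g. roughly uniform probabilities before training
```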

Value neural network
[Grid diagrams: the true values (whole numbers such as 2, 3, 4) next to the value network's approximations (numbers such as 3.1, 1.4, 2.5)]

Policy neural network
[Grid diagrams: a deterministic policy (exploit) and a stochastic policy (explore-exploit)]

RLHF

What color is the sky?
Blue / Red / Banana
Transformers need to be fine-tuned by humans!

Training the value network
Prompt: "What color is the sky?"
The policy neural network generates completions such as Blue, Red, and Banana,
and the value neural network assigns a score to each prompt-completion pair (e.g. +3, +2, +1).
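A toy sketch of a value model for text (everything here — tokenization, sizes, names — is invented for illustration): it maps a prompt-plus-completion token sequence to a single score, which is what gets fit to the human ratings.

```python
import torch
import torch.nn as nn

class TextValueModel(nn.Module):
    """Scores a (prompt + completion) token sequence with a single number."""
    def __init__(self, vocab_size=1000, dim=64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.score_head = nn.Linear(dim, 1)

    def forward(self, token_ids):                     # token_ids: (batch, seq_len)
        pooled = self.embed(token_ids).mean(dim=1)    # crude pooling over the tokens
        return self.score_head(pooled).squeeze(-1)    # -> (batch,) scores

# Hypothetical token ids for "What color is the sky? Blue" and "... Banana"
good = torch.tensor([[11, 42, 7, 3, 99, 120]])
bad  = torch.tensor([[11, 42, 7, 3, 99, 555]])

model = TextValueModel()
print(model(good), model(bad))  # untrained scores; training nudges them toward the human ratings
```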

Longer paths
Prompt: "Roses are red,"
The policy neural network extends the prompt one word at a time, producing full responses such as
"Roses are red, violets are blue <end>", "Roses are red, violets are humans too <end>",
and "Roses are red, because they like ketchup <end>".
The value neural network assigns a score to each path (e.g. +3, +2, +1).

Obviously it's not a grid
[Diagram: the states are pieces of text, e.g. "What color is the sky?", "Hello, how are you?", "The Iliad", "Once upon a time…"; an action appends a word, e.g. "Hello, how are you?" → "Good"]

Obviously it's not a grid
[Diagram: starting from the prompt "What color is the sky?", the model can follow a good path ("The" → "The sky" → "The sky is" → "The sky is blue") or a bad one ("Banana" → "Banana purple" → "Banana purple monkey" → "Banana purple monkey dishwasher"); other prompts like "Hello, how are you?" branch the same way]

Obviously it's not a grid
[Same diagram: the value neural network scores the endpoints of the paths (marked +3 and +1 on the slide), and the policy neural network chooses which word to append at each state]

Proximal Policy Optimization
How to train the two networks:

Training the value neural network

Two neural networks
[Grid diagrams: the table of values and the policy]
Value neural network: approximates the values
Policy neural network: approximates the policy

Loss function for training the value network

$L_{\text{value}}(\theta) = \mathbb{E}\left[\left(V_\theta(s_t) - R_t\right)^2\right]$

Training the value neural network
[Grid diagrams: along a path generated by the policy neural network, the value neural network predicts 0.4, 0.5, 1.2, -0.3, 0.6, 0.8, while the actual values are 2, 3, 4, 0, 1, -1]

Training the value neural network
Value neural network predictions: 0.4, 0.5, 1.2, -0.3, 0.6, 0.8
Actual values: 2, 3, 4, 0, 1, -1
Squared differences: 2.56, 6.25, 7.84, 0.09, 0.16, 3.24

Training the value neural network
Value neural network predictions: 0.4, 0.5, 1.2, -0.3, 0.6, 0.8
Actual values: 2, 3, 4, 0, 1, -1
Loss = (1/6)(7.84 + 6.25 + 2.56 + 0.16 + 0.09 + 3.24) = 3.36

Training the value neural network
Loss = (1/6)(7.84 + 6.25 + 2.56 + 0.16 + 0.09 + 3.24) = 3.36
Use this loss to train the value neural network.
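The same computation written out as a minimal sketch; in practice `predicted` would come from the value neural network rather than being typed in.

```python
import torch

predicted = torch.tensor([0.4, 0.5, 1.2, -0.3, 0.6, 0.8], requires_grad=True)
actual    = torch.tensor([2.0, 3.0, 4.0,  0.0, 1.0, -1.0])

loss = ((predicted - actual) ** 2).mean()   # mean squared error over the 6 states
print(loss)                                 # ~3.36, the number from the slide

loss.backward()                             # the gradient of this loss is what updates
print(predicted.grad)                       # the value network's weights
```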

Training the value neural network
[Grid diagrams: after a training step the value neural network's predictions along the path move closer to the actual values (e.g. from 0.4, 0.5, 1.2, -0.3, 0.6, 0.8 toward 0.6, 0.9, 1.8, -0.2, 0.7, 0.6), and the predictions over the whole grid improve as well]

Formula for the Loss function

$L_{\text{value}}(\theta) = \mathbb{E}\left[\left(V_\theta(s_t) - R_t\right)^2\right]$

• $\mathbb{E}$: the expected value (the average — the 1/6 factor in the example above)
• $V_\theta(s_t)$: the values given by the value neural network
• $R_t$: the actual values, calculated using the path
Use this loss to train the value neural network.

Training the policy neural network: The surrogate

Two neural networks
[Grid diagrams: the perfect values and the perfect policy]
Value neural network → approximates the perfect values
Policy neural network → approximates the perfect policy

Loss function for training the policy network

$L_{\text{policy}}(\theta) = \dfrac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)} \, A_t$

Training the value neural network
[Grid diagrams: along a path generated by the policy neural network, the value neural network's predictions (0.4, 0.5, 1.2, 1, 1.3, 0.8) next to the actual values (2, 3, 4, 0, 1, -1)]

Training the value neural network
[Diagram: comparing the value neural network's prediction for a state (e.g. 1.2) with the actual value obtained on the path (e.g. 4); the difference is the gain $A_t$]

Gain

$L_{\text{policy}}(\theta) = \dfrac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)} \, A_t$

Impulse

Impulse (Momentum)
[Diagram: Scenario 1: a force changes the velocity from v = 10 to v = 20; Scenario 2: the same force changes it from v = -10 to v = 5; the impulse is the resulting change in momentum]

Impulse (Momentum)
[Plots: probability versus iteration for Scenarios 1-4; in each, the gain plays the role of the force and the impulse is how much the probability ends up changing]

Impulse (Momentum)
Scenario 1: iteration 99, p = 0.2 → iteration 100, p = 0.4 (gain = 10) → iteration 101, p = 0.8
Scenario 2: iteration 99, p = 0.8 → iteration 100, p = 0.4 (gain = 10) → iteration 101, p = 0.5

Dividing by the previous probability
Scenario 1: iteration 99, p = 0.2; iteration 100, p = 0.4; gain = 10
Surrogate objective function: (0.4 / 0.2) × 10 = 20
Scenario 2: iteration 99, p = 0.8; iteration 100, p = 0.4; gain = 10
Surrogate objective function: (0.4 / 0.8) × 10 = 5

Surrogate Objective Function

$L_{\text{policy}}(\theta) = \dfrac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)} \, A_t$

• $\pi_\theta(a_t \mid s_t)$: probability at iteration 100
• $\pi_{\theta_{\text{old}}}(a_t \mid s_t)$: probability at iteration 99
• $A_t$: gain (difference of value) at iteration 100
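A minimal sketch of this surrogate, exactly the expression above (note that full PPO additionally clips the probability ratio, which the talk does not go into):

```python
def surrogate(new_prob, old_prob, gain):
    """(pi_theta / pi_theta_old) * A_t, the surrogate objective from the slide."""
    return (new_prob / old_prob) * gain

# The two scenarios from the previous slide: same gain, different history
print(surrogate(new_prob=0.4, old_prob=0.2, gain=10))   # ~20
print(surrogate(new_prob=0.4, old_prob=0.8, gain=10))   # ~5

# In training, new_prob comes from the current policy network, old_prob is the
# frozen probability from the previous iteration, and gain is the advantage A_t;
# the policy is updated in the direction that increases the average surrogate.
```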

Direct Preference Optimization

Reinforcement Learning with Human Feedback (RLHF)
[Diagram: the prompt "Once upon a" goes into the transformer model (the one that talks), which produces completions such as "time" and "banana"; the reward model (the one that guesses how the human will rate outputs) scores them, e.g. Reward: 5 and Reward: 2]
Direct Preference Optimization (DPO) replaces the reward model with a loss function.

DPO

$\mathcal{L}_{\text{DPO}}(\pi_\theta; \pi_{\text{ref}}) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\left[\log \sigma\!\left(\beta \log \dfrac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \dfrac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right)\right]$

Direct Preference Optimization (DPO)
[Same diagram: the transformer model (the one that talks) and the reward model (the one that guesses how the human will rate outputs), with the completions "time" and "banana" scored 5 and 2]

$\max_{\pi_\theta} \; \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_\theta(y \mid x)}\left[r_\phi(x, y)\right] - \beta\, \mathbb{D}_{\text{KL}}\!\left[\pi_\theta(y \mid x) \,\|\, \pi_{\text{ref}}(y \mid x)\right]$

DPO main equation

$\max_{\pi_\theta} \; \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_\theta(y \mid x)}\left[r_\phi(x, y)\right] - \beta\, \mathbb{D}_{\text{KL}}\!\left[\pi_\theta(y \mid x) \,\|\, \pi_{\text{ref}}(y \mid x)\right]$

• The first term maximizes the rewards.
• The KL term prevents the model from changing too drastically.
Here x is the prompt (e.g. "Once upon a") and y is a response (e.g. "time" or "banana").

Two things

$\max_{\pi_\theta} \; \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_\theta(y \mid x)}\left[r_\phi(x, y)\right] - \beta\, \mathbb{D}_{\text{KL}}\!\left[\pi_\theta(y \mid x) \,\|\, \pi_{\text{ref}}(y \mid x)\right]$

1. How to get rid of this reward? (Not get rid of it, but turn it into a probability.) → Bradley-Terry Model
2. How to measure how different two models are (the original model vs. the improved model)? → KL Divergence
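For the second question, a minimal sketch of the KL divergence between two next-word distributions (the distributions below are invented for illustration):

```python
import math

def kl_divergence(p, q):
    """D_KL(p || q) for two discrete distributions over the same outcomes."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

# Hypothetical next-word probabilities for the prompt "Once upon a"
# over the words ["time", "banana", "the"]
improved = [0.80, 0.05, 0.15]   # pi_theta
original = [0.60, 0.20, 0.20]   # pi_ref

print(kl_divergence(improved, original))   # ~0.12: the improved model hasn't drifted far
```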

Turning rewards (scores) into probabilities

$\max_{\pi_\theta} \; \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_\theta(y \mid x)}\left[r_\phi(x, y)\right]$

How to get rid of this reward? (Not get rid of it, but turn it into a probability.)
Bradley-Terry Model: $\sigma(r_w - r_l)$
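A quick check of the Bradley-Terry model with the rewards from the earlier example (5 for the preferred response, 2 for the other); the two expressions below compute the same probability.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

r_w, r_l = 5.0, 2.0   # reward of the winning and the losing response

print(sigmoid(r_w - r_l))                                # ~0.953
print(math.exp(r_w) / (math.exp(r_w) + math.exp(r_l)))   # same number: softmax over the two rewards
```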

Mathematical Manipulation
Starting from the main objective

$\max_{\pi_\theta} \; \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_\theta(y \mid x)}\left[r_\phi(x, y)\right] - \beta\, \mathbb{D}_{\text{KL}}\!\left[\pi_\theta(y \mid x) \,\|\, \pi_{\text{ref}}(y \mid x)\right]$

and the Bradley-Terry model

$p(y_1 > y_2) = \dfrac{e^{r(x, y_1)}}{e^{r(x, y_1)} + e^{r(x, y_2)}}$

one obtains the DPO loss function:

$\mathcal{L}_{\text{DPO}}(\pi_\theta; \pi_{\text{ref}}) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\left[\log \sigma\!\left(\beta \log \dfrac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \dfrac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right)\right]$
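Sketch of the intermediate steps (these follow the standard DPO derivation; they are not spelled out on the slides):

```latex
\begin{align*}
&\text{The KL-regularized objective above is maximized by}\quad
  \pi^{*}(y \mid x) = \tfrac{1}{Z(x)}\,\pi_{\text{ref}}(y \mid x)\,e^{r(x,y)/\beta},\\
&\text{so the reward can be written as}\quad
  r(x,y) = \beta \log \frac{\pi^{*}(y \mid x)}{\pi_{\text{ref}}(y \mid x)} + \beta \log Z(x).\\
&\text{Substituting into Bradley--Terry, the } \beta \log Z(x) \text{ terms cancel:}\\
&\qquad p(y_w > y_l \mid x)
  = \sigma\!\left(\beta \log \frac{\pi^{*}(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)}
    - \beta \log \frac{\pi^{*}(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right).
\end{align*}
```

Maximizing the log-likelihood of the human preference data under this probability, with $\pi_\theta$ in place of $\pi^{*}$, gives exactly the DPO loss above.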

The DPO loss function

$\mathcal{L}_{\text{DPO}}(\pi_\theta; \pi_{\text{ref}}) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\left[\log \sigma\!\left(\beta \log \dfrac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \dfrac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right)\right]$

• The expectation averages over the generated responses.
• The $y_w$ term is maximized: maximize the probability of the good response.
• The $y_l$ term is minimized: minimize the probability of the bad response.
• Dividing by $\pi_{\text{ref}}$ keeps the model from changing too much.
• The $\log \sigma(\cdot)$ comes from maximizing a log-likelihood: summing $\log p_i$ over the preference pairs corresponds to the probability of all of those (independent) events together.
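A minimal sketch of this loss in code (the log-probabilities here are invented; in practice they are the summed token log-probabilities of each response under the trained model $\pi_\theta$ and the frozen reference model $\pi_{\text{ref}}$):

```python
import torch
import torch.nn.functional as F

def dpo_loss(logp_w, logp_l, ref_logp_w, ref_logp_l, beta=0.1):
    """-log sigma( beta*[log pi/pi_ref](y_w) - beta*[log pi/pi_ref](y_l) ), averaged."""
    ratio_w = logp_w - ref_logp_w   # log pi_theta(y_w|x) - log pi_ref(y_w|x)
    ratio_l = logp_l - ref_logp_l   # log pi_theta(y_l|x) - log pi_ref(y_l|x)
    return -F.logsigmoid(beta * (ratio_w - ratio_l)).mean()

# One hypothetical preference pair (x, y_w, y_l)
logp_w     = torch.tensor([-12.0])   # log pi_theta(y_w | x)
logp_l     = torch.tensor([-15.0])   # log pi_theta(y_l | x)
ref_logp_w = torch.tensor([-13.0])   # log pi_ref(y_w | x)
ref_logp_l = torch.tensor([-14.0])   # log pi_ref(y_l | x)

print(dpo_loss(logp_w, logp_l, ref_logp_w, ref_logp_l))   # scalar loss to minimize
```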

Series of 4 videos on Reinforcement Learning

Grokking Machine Learning

llm.university