Reinforcement Learning with Human Feedback ODSC talk
Luis Serrano, PhD
Sep 05, 2024
Reinforcement Learning with Human Feedback
(RLHF)
Luis Serrano
Topics
• Large Language Models
• How to fine-tune them with reinforcement learning
• Quick intro to reinforcement learning
• PPO (the reinforcement learning technique to fine-tune LLMs)
• DPO (the non-reinforcement learning technique to fine-tune LLMs)
Large Language Models
Transformers
I generate text… one word at a time
Transformers
Hello, how are you doing
Transformers
Write a story. Once
Transformers
Write a story. Once upon
Transformers
Write a story. Once upon a
How to train transformer models?
The internet and curated datasets
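To make the one-word-at-a-time loop concrete, here is a minimal Python sketch. This is not code from the talk; the next_word function is a purely hypothetical stand-in for a transformer, which would really return a probability distribution over its vocabulary.

```python
# Sketch of the "one word at a time" loop described above. The transformer is
# replaced by a hypothetical next_word() function with a few canned answers.

def next_word(prompt: str) -> str:
    # Hypothetical stand-in for a transformer forward pass.
    canned = {
        "Write a story.": "Once",
        "Write a story. Once": "upon",
        "Write a story. Once upon": "a",
    }
    return canned.get(prompt, "<end>")

def generate(prompt: str, max_words: int = 10) -> str:
    text = prompt
    for _ in range(max_words):
        word = next_word(text)
        if word == "<end>":
            break
        text = f"{text} {word}"   # append the predicted word and feed the longer text back in
    return text

print(generate("Write a story."))   # -> "Write a story. Once upon a"
```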
Is that the end of the story? (Nope)
Transformers need to be fine-tuned by humans!
[diagram: Q: 'What color is the sky?'; candidate answers from the Transformer: Blue, Red, Banana]
Transformers need to be fine-tuned by humans!
[diagram: Q: 'What is the capital of Sudan?'; Transformer: 'It's, ehhm… Cairo?'; Human: 'No! It's Khartoum!']
Quick intro to Reinforcement Learning
Gridworld
[grid diagram: a board with terminal rewards +5, +4, and -1]
States
[grid diagram: each square of the board]
Actions
[grid diagram: moves between neighboring squares]
It costs 1 to move one spot
What's the best I can do from here?
[grid diagram: values along the paths toward the rewards, e.g. 4, 3, 2, 1, 0, -1]
How to calculate all the values?
[grid diagram: the value of every square]
Policy
Policy if it's free to move
Where is the best place to move?
[grid diagram: the value of every square and the best move from each one]
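The cell values in these grid diagrams can be reproduced with a short value-iteration sketch. This is not the talk's code: the 3x4 layout and the terminal positions below are assumptions for illustration; only the rewards +5, +4, -1 and the move cost of 1 come from the slides.

```python
# Minimal value-iteration sketch for a gridworld like the one in the slides.
STEP_COST = 1
ROWS, COLS = 3, 4
TERMINALS = {(0, 3): 5, (2, 3): 4, (1, 3): -1}   # (row, col) -> reward; hypothetical layout

def value_iteration(iters: int = 50) -> dict:
    values = {(r, c): 0.0 for r in range(ROWS) for c in range(COLS)}
    for cell, reward in TERMINALS.items():
        values[cell] = reward
    for _ in range(iters):
        for cell in values:
            if cell in TERMINALS:
                continue
            r, c = cell
            neighbors = [(r + dr, c + dc) for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1))]
            neighbors = [n for n in neighbors if n in values]
            # Bellman backup: best neighbor value minus the cost of moving one spot.
            values[cell] = max(values[n] for n in neighbors) - STEP_COST
    return values

if __name__ == "__main__":
    for cell, v in sorted(value_iteration().items()):
        print(cell, v)
```

With this layout the squares next to the +5 reward end up at 4, the next ones at 3, and so on, which is the pattern shown in the slides.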
Transformers need to be fine-tuned by humans!
[diagram repeated: 'What color is the sky?' with candidate answers Blue, Red, Banana]
Training the value network
[diagram: the prompt 'What color is the sky?' continued with Blue, Red, or Banana; a value neural network assigns values such as +2, +3, +1 to these continuations, and a policy neural network produces the words]
Longer paths
[diagram: a tree of partial generations growing from 'Roses are red,' e.g. 'Roses are red, violets are blue <end>', 'Roses are red, violets are humans too <end>', and 'Roses are red, because they like ketchup <end>', with values such as +2, +3, +1; value neural network and policy neural network]
Obviously it's not a grid
[diagram: the states are prompts and partial responses, e.g. 'What color is the sky?', 'Hello, how are you?', 'The Iliad', 'Once upon a time…']
Obviously it's not a grid
[diagram: from the state 'What color is the sky?', responses grow one word at a time, e.g. 'The' → 'The sky' → 'The sky is' → 'The sky is blue', or 'Banana' → 'Banana purple' → 'Banana purple monkey' → 'Banana purple monkey dishwasher']
Obviously it's not a grid
[diagram: the same tree of completions with the value neural network and policy neural network, and values such as +3 and +1]
Proximal Policy Optimization
How to train the two networks:
Loss function for training the value network
L_\text{value}(\theta) = \mathbb{E}\big[(V_\theta(s_t) - R_t)^2\big]
Training the value neural network
[grid diagram: the value neural network's predicted value for each square next to the actual values; for six example squares the predictions are 0.4, 0.5, 1.2, -0.3, 0.6, 0.8 and the actual values are 2, 3, 4, 0, 1, -1]
Squared differences: (2 - 0.4)^2 = 2.56, (3 - 0.5)^2 = 6.25, (4 - 1.2)^2 = 7.84, (0 - (-0.3))^2 = 0.09, (1 - 0.6)^2 = 0.16, (-1 - 0.8)^2 = 3.24
Loss = (2.56 + 6.25 + 7.84 + 0.09 + 0.16 + 3.24) / 6 = 3.36
Use this loss to train the value neural network
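A minimal sketch of the same computation in Python, using the six predicted and actual values from the worked example above:

```python
# Mean squared error between the value network's predictions and the actual values.
predicted = [0.4, 0.5, 1.2, -0.3, 0.6, 0.8]
actual    = [2.0, 3.0, 4.0,  0.0, 1.0, -1.0]

squared_errors = [(a - p) ** 2 for p, a in zip(predicted, actual)]
loss = sum(squared_errors) / len(squared_errors)

print([round(e, 2) for e in squared_errors])   # [2.56, 6.25, 7.84, 0.09, 0.16, 3.24]
print(round(loss, 2))                          # 3.36, the loss used to train the value network
```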
Loss function for training the policy network
L_\text{policy}(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\text{old}}(a_t \mid s_t)} \, A_t
Training the value neural network
[diagram: the value neural network's predictions (e.g. 0.4, 0.5, 1.2, 1, 1.3, 0.8) next to the actual values (2, 3, 4, 0, 1, -1), used to compute the gain of an action]
\frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\text{old}}(a_t \mid s_t)} \, A_t \qquad (A_t: \text{the gain})
Impulse (Momentum)
[diagram: Scenario 1 (v = 10 → v = 20) and Scenario 2 (v = -10 → v = 5), each pushed by a force; the arrows are labeled Impulse]
Impulse (Momentum)
[plots: Probability vs. Iteration for Scenarios 1 and 2, with the Gain and the Impulse marked]
Impulse (Momentum)
[plots: Probability vs. Iteration for Scenarios 3 and 4, with the Gain and the Impulse marked]
Dividing by the previous probability
Scenario 1: Iteration 99, p = 0.2; Iteration 100, p = 0.4; Gain = 10
Scenario 2: Iteration 99, p = 0.8; Iteration 100, p = 0.4; Gain = 10
Surrogate objective function: (0.4 / 0.2) · 10 = 20 for Scenario 1, and (0.4 / 0.8) · 10 = 5 for Scenario 2
Surrogate Objective Function
L_\text{policy}(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\text{old}}(a_t \mid s_t)} \, A_t
\pi_\theta(a_t \mid s_t): probability at iteration 100
\pi_{\theta_\text{old}}(a_t \mid s_t): probability at iteration 99
A_t: gain (difference of value) at iteration 100
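Putting numbers on this, a small sketch that reproduces the two scenarios from the "Dividing by the previous probability" slide (assuming, as in the slide, that the gain of 10 is measured at iteration 100):

```python
# Surrogate objective from the slide: new probability of the action divided by the
# old probability, multiplied by the gain (advantage) A_t.

def surrogate(p_new: float, p_old: float, gain: float) -> float:
    ratio = p_new / p_old          # pi_theta(a|s) / pi_theta_old(a|s)
    return ratio * gain            # times the advantage A_t

print(surrogate(0.4, 0.2, 10))     # Scenario 1: (0.4 / 0.2) * 10 = 20
print(surrogate(0.4, 0.8, 10))     # Scenario 2: (0.4 / 0.8) * 10 = 5
```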
Direct Preference Optimization
Reinforcement Learning with Human Feedback (RLHF)
[diagram: the Transformer Model (the one that talks) completes the prompt 'Once upon a' with 'time' (Score: 5) or 'banana' (Score: 2); the Reward Model (the one that guesses how the human will rate outputs) assigns Reward: 5 and Reward: 2]
Direct Preference Optimization (DPO)
Loss function
\mathcal{L}_\text{DPO}(\pi_\theta; \pi_\text{ref}) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\left[\log \sigma\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_\text{ref}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_\text{ref}(y_l \mid x)}\right)\right]
[diagram repeated: the Transformer Model (the one that talks) completes 'Once upon a' with 'time' (Score: 5, Reward: 5) or 'banana' (Score: 2, Reward: 2); the Reward Model guesses how the human will rate outputs]
Direct Preference Optimization (DPO)
\max_{\pi_\theta} \; \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_\theta(y \mid x)}\big[r_\phi(x, y)\big] - \beta\, \mathbb{D}_\text{KL}\big[\pi_\theta(y \mid x) \,\|\, \pi_\text{ref}(y \mid x)\big]
DPO main equation
\max_{\pi_\theta} \; \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_\theta(y \mid x)}\big[r_\phi(x, y)\big] - \beta\, \mathbb{D}_\text{KL}\big[\pi_\theta(y \mid x) \,\|\, \pi_\text{ref}(y \mid x)\big]
The first term maximizes the rewards; the KL term prevents the model from changing too drastically.
Example: x = 'Once upon a', y = 'time' or y = 'banana'
Two things
\max_{\pi_\theta} \; \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_\theta(y \mid x)}\big[r_\phi(x, y)\big] - \beta\, \mathbb{D}_\text{KL}\big[\pi_\theta(y \mid x) \,\|\, \pi_\text{ref}(y \mid x)\big]
1. How to get rid of this reward? (not get rid of it, but turn it into a probability) → Bradley-Terry Model
2. How to measure how different two models are (the original model vs. the improved model)? → KL Divergence
Turning rewards (scores) into probabilities
\max_{\pi_\theta} \; \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_\theta(y \mid x)}\big[r_\phi(x, y)\big]
How to get rid of this reward? (not get rid of it, but turn it into a probability)
Bradley-Terry Model: \sigma(r_w - r_l)
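As a quick numerical check of this step, the sketch below shows that the Bradley-Terry probability e^{r_w} / (e^{r_w} + e^{r_l}) is the same number as σ(r_w - r_l); the rewards 5 and 2 are taken from the earlier 'time' vs. 'banana' example, purely for illustration:

```python
# Bradley-Terry: turn two rewards into the probability that the preferred response wins.
import math

def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))

def bradley_terry(r_w: float, r_l: float) -> float:
    # exp(r_w) / (exp(r_w) + exp(r_l)) is algebraically the same as sigmoid(r_w - r_l)
    return math.exp(r_w) / (math.exp(r_w) + math.exp(r_l))

print(bradley_terry(5, 2))   # ~0.953
print(sigmoid(5 - 2))        # same value, ~0.953
```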
Mathematical Manipulation
Bradley-Terry Model: p(y_1 > y_2) = \frac{e^{r(x, y_1)}}{e^{r(x, y_1)} + e^{r(x, y_2)}}
\max_{\pi_\theta} \; \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_\theta(y \mid x)}\big[r_\phi(x, y)\big] - \beta\, \mathbb{D}_\text{KL}\big[\pi_\theta(y \mid x) \,\|\, \pi_\text{ref}(y \mid x)\big]
The DPO loss function:
\mathcal{L}_\text{DPO}(\pi_\theta; \pi_\text{ref}) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\left[\log \sigma\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_\text{ref}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_\text{ref}(y_l \mid x)}\right)\right]
The DPO loss function
\mathcal{L}_\text{DPO}(\pi_\theta; \pi_\text{ref}) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\left[\log \sigma\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_\text{ref}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_\text{ref}(y_l \mid x)}\right)\right]
• \mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}: average over generated responses
• \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_\text{ref}(y_w \mid x)}: maximize the probability of the good response, without changing the model too much
• -\beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_\text{ref}(y_l \mid x)}: minimize the probability of the bad response
• \log \prod p_i = \sum \log p_i: the probability of all events (independently) is a product of the probabilities of each event; taking the log turns it into a sum
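A minimal sketch of this loss in PyTorch (not the talk's code): the log-probabilities below are made-up numbers standing in for sums of per-token log-probabilities from the trained model (π_θ) and the frozen reference model (π_ref), and β = 0.1 is only an illustrative choice.

```python
# Sketch of the DPO loss for a batch of (prompt, preferred, rejected) triples.
import torch
import torch.nn.functional as F

def dpo_loss(logp_w, logp_l, ref_logp_w, ref_logp_l, beta=0.1):
    # beta * [ log pi_theta(y_w|x)/pi_ref(y_w|x) - log pi_theta(y_l|x)/pi_ref(y_l|x) ]
    margin = beta * ((logp_w - ref_logp_w) - (logp_l - ref_logp_l))
    # -log sigmoid(margin), averaged over the batch of preference pairs
    return -F.logsigmoid(margin).mean()

# Hypothetical batch of two preference pairs.
logp_w     = torch.tensor([-12.0, -20.0])   # log pi_theta(y_w | x)
logp_l     = torch.tensor([-15.0, -18.0])   # log pi_theta(y_l | x)
ref_logp_w = torch.tensor([-13.0, -21.0])   # log pi_ref(y_w | x)
ref_logp_l = torch.tensor([-14.0, -17.0])   # log pi_ref(y_l | x)

print(dpo_loss(logp_w, logp_l, ref_logp_w, ref_logp_l))
```

The loss goes down when the trained model raises the probability of the preferred response relative to the reference model and lowers it for the rejected one, which is exactly the reading of the terms listed above.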
Series of 4 videos on Reinforcement Learning