DPO RL with Implicit Feedback

NYU Courant
Deep Decision Making and Reinforcement Learning, Spring 2024

Abstract

The use of Large Language Models (LLMs) in healthcare challenges traditional notions of trust, because generated outputs are untraceable and sometimes inaccurate [1]. LLMs are trained on conflicting sources of information through the maximum likelihood objective in an unsupervised fashion [3]. This is typically followed by (1) supervised fine-tuning, (2) preference sampling and reward learning, and (3) reinforcement learning optimization (RLHF) [4]. Prior work such as [5] and [6] identified biases in human annotator preferences during the RLHF stage as a source of trickle-down biases in LLMs. Furthermore, RLHF is a complex and often unstable procedure, and obtaining healthcare alignment data at scale is inherently expensive [2]. In this project, we propose to explore updating the LLM policy network with the family of memory-efficient Direct Preference Optimization (DPO) algorithms [7]. To circumvent the need for human annotation, we propose to leverage multiple-choice questions and generate an answer for each option, such that the answer corresponding to the correct option is the de facto "chosen" response. To benchmark, we propose comparing against the base LLM on the in-distribution test set. Moreover, rigorous testing against MMLU [8] splits for anatomy, clinical knowledge, medical genetics, etc. will enable an understanding of generalization trends.

Direct Preference Optimization (DPO)

DPO is a method used primarily in preference-based reinforcement learning. In this approach, instead of receiving explicit rewards, the learning algorithm is guided by preferences between different responses provided by an external source. The method is derived from the Proximal Policy Optimization (PPO) setup for Large Language Models by relaxing the need for a separate reward and critic model. Instead, the loss function is given as:

\mathcal{L}_{\mathrm{DPO}}(\pi_\theta; \pi_{\mathrm{ref}}) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right)\right]

where the updated policy's distribution over preferred and non-preferred responses is explicitly constrained with respect to the original (reference) policy. Here, beta corresponds to the strength of the KL constraint in the original PPO objective, given by:

\max_{\pi_\theta} \; \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_\theta(y \mid x)}\left[r_\phi(x, y)\right] - \beta\, \mathbb{D}_{\mathrm{KL}}\left[\pi_\theta(y \mid x) \,\|\, \pi_{\mathrm{ref}}(y \mid x)\right]

Thus, beta controls the trade-off between maximizing the implicit reward and staying close to the reference policy: higher values of beta constrain the policy more tightly to the reference, which in practice preserves the diversity of its outputs, while lower values let the policy drift further from the reference in pursuit of reward.
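The DPO loss above can be sketched numerically. The function below is a minimal, framework-free illustration (the function and argument names are ours); a real implementation would sum per-token log-probabilities of each response under the trained policy and a frozen reference model, and batch the computation in a tensor library.

```python
import math

def dpo_loss(policy_chosen_lp, policy_rejected_lp,
             ref_chosen_lp, ref_rejected_lp, beta=0.1):
    """Per-example DPO loss from summed response log-probabilities.

    Each argument is log pi(y | x) summed over the response tokens,
    under either the policy being trained or the frozen reference.
    """
    # Implicit rewards: beta-scaled log-ratios against the reference.
    chosen_logratio = policy_chosen_lp - ref_chosen_lp
    rejected_logratio = policy_rejected_lp - ref_rejected_lp
    margin = beta * (chosen_logratio - rejected_logratio)
    # -log sigmoid(margin): small when the policy prefers the chosen response.
    return -math.log(1.0 / (1.0 + math.exp(-margin)))

# At initialization the policy equals the reference, so the margin is 0
# and the loss is -log(0.5) = log 2, regardless of beta.
loss_at_init = dpo_loss(-10.0, -12.0, -10.0, -12.0, beta=0.1)
```

Raising the policy's log-probability on the chosen response (or lowering it on the rejected one) increases the margin and drives the loss below log 2, which is the gradient signal DPO trains on.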

Experiment Setup

We look to further relax the requirement for a preference dataset in two settings: Reinforcement Learning from AI Feedback (RLAIF) and Reinforcement Learning from Implicit Feedback (RLIF). We believe this will utilize existing and abundant resources, such as multiple-choice question-and-answer datasets and pre-trained models. Our intuition is that this lowers cost, speeds up development, and removes human annotator biases from the LLM training pipeline.

RLAIF - Open World Problem

The stanford-nlp IMDB reviews dataset is a collection of movie reviews from IMDB. The dataset contains a total of 50,000 reviews, each annotated with a sentiment label (positive or negative), and is divided into a training set and a test set. The positive reviews are used for supervised fine-tuning of a base gpt2-large model as a next-word-prediction task; the model thus learns to generate positive movie reviews given a few starting tokens as a prior. Next, the first few tokens (2-8) of each review in the entire IMDB dataset (positive and negative) are used to prompt the supervised fine-tuned (SFT) gpt2-large model to generate positive reviews. Each prompt generates four reviews, which vary due to the use of different decoding strategies such as temperature settings, beam search, top-k, and top-p sampling. Each review is fed into a sentiment-roberta-large model to ascertain its positive-sentiment score. These scores are used to create 6 pairs from the 4 individual reviews, with the higher-scoring review of each pair as the "chosen"/"accepted"/"preferred" response and the other as the "rejected"/"unaccepted"/"disliked" response. Together with the input prompt tokens, this creates our preference dataset, which is used to run Direct Preference Optimization on the SFT gpt2-large model, producing our final aligned gpt2-large model.
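The pairing step above can be sketched as follows: from four scored completions per prompt we form all C(4,2) = 6 ordered pairs, with the higher-scoring review of each pair taken as "chosen". The function name and row schema are illustrative, not the exact format used in training.

```python
from itertools import combinations

def build_preference_pairs(prompt, reviews, scores):
    """Turn 4 scored completions into C(4,2) = 6 preference rows.

    reviews: completions sampled with different decoding strategies.
    scores:  positive-sentiment scores from the judge model.
    """
    rows = []
    for i, j in combinations(range(len(reviews)), 2):
        # The higher-scoring review of the pair is the "chosen" response.
        win, lose = (i, j) if scores[i] >= scores[j] else (j, i)
        rows.append({"prompt": prompt,
                     "chosen": reviews[win],
                     "rejected": reviews[lose]})
    return rows

pairs = build_preference_pairs(
    "This movie was",
    ["r0", "r1", "r2", "r3"],
    [0.9, 0.2, 0.7, 0.4],
)
# 4 reviews -> 6 pairs; e.g. the first pair prefers r0 over r1.
```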

Figure 1: RLAIF - Open World Problem
Figure 2: RLAIF - Reward Hacking

As showcased in Figure 3, the value of beta affects this trade-off. This is apparent in the eval_ngrams plot, where the number of unique n-grams is plotted against training steps for various values of beta. We notice that for beta=0.01 the number of unique n-grams is significantly lower than for higher values of beta, a sign that the weakly constrained policy has collapsed into repetitive, reward-hacked generations.
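The unique n-gram diagnostic can be sketched as a few lines of Python; the function name and the example strings are ours, chosen to mirror the degenerate "out of 10" repetition discussed below.

```python
def unique_ngrams(text, n=3):
    """Count distinct word n-grams in a generation.

    A low count relative to the text length signals repetitive,
    reward-hacked output such as "out of 10 out of 10 ...".
    """
    tokens = text.split()
    grams = {tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)}
    return len(grams)

# A repetitive (low-beta) generation vs. a more varied one.
degenerate = "out of 10 out of 10 out of 10"
varied = "the acting direction soundtrack and cinematography were all great"
```

Tracking this count over training steps, as in the eval_ngrams plot, makes the diversity collapse at low beta directly visible.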

Figure 3: Reward Hacking Example

Comparing the outputs of models trained with two values of beta, 0.01 and 0.6, we can see that the model trained with beta=0.01 repeats phrases like "out of 10" and "miss miss missing miss", whereas the model with beta=0.6 generates a much more varied and nuanced response, touching on relevant points like "acting", "direction", "soundtrack", and "cinematography".

Figure 4: Reward Margins

RLIF - Closed World Problem

Figure 6: Medical Question Answering
The Med-HALT dataset is a multinational dataset derived from medical examinations across various countries and includes multiple-choice questions (MCQs) from NEET, AIIMS PG, USMLE, TWMLE, etc. The dataset contains a total of 16,000 MCQs (question, options, correct option index). We use a base Llama-2-13B model to generate "explanations"/"reasons" for why each of the options (whether correct or not) should be the correct choice. Llama-2 is at such a junction point in its training and capabilities that it doesn't contradict the prompt (e.g., "Actually, this option cannot be the correct one") but is still able to hallucinate believable explanations for why the option is correct. This data is used to create two datasets. First, the question and options are matched with the explanation for the correct option; this is our SFT dataset. Second, the question and options are matched with a pair consisting of the correct-option reasoning (which remains constant for a question) and one of the three incorrect-option reasonings, giving three such pairs per question. This creates our preference dataset, with the correct-option reasoning as the "chosen"/"accepted"/"preferred" response and the incorrect reasonings as the "rejected"/"unaccepted"/"disliked" responses. As before, we use the SFT dataset to train the base gpt2-large model and then use the preference dataset to run Direct Preference Optimization on the SFT gpt2-large model.
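The two-dataset construction above can be sketched per question: one SFT row pairing the prompt with the correct-option explanation, and three preference rows pairing that same explanation against each incorrect-option explanation. The function name, prompt formatting, and row schema are illustrative.

```python
def build_mcq_datasets(question, options, correct_idx, explanations):
    """From one MCQ and one model-written explanation per option,
    build one SFT row and three preference rows."""
    prompt = question + "\n" + "\n".join(
        f"{chr(65 + i)}. {opt}" for i, opt in enumerate(options))
    # SFT row: the prompt paired with the correct-option explanation.
    sft_row = {"prompt": prompt, "completion": explanations[correct_idx]}
    # Preference rows: the chosen reasoning stays constant across pairs.
    pref_rows = [
        {"prompt": prompt,
         "chosen": explanations[correct_idx],
         "rejected": explanations[i]}
        for i in range(len(options)) if i != correct_idx
    ]
    return sft_row, pref_rows

sft, prefs = build_mcq_datasets(
    "Which vitamin deficiency causes scurvy?",
    ["Vitamin A", "Vitamin C", "Vitamin D", "Vitamin K"],
    1,
    ["e0", "e1", "e2", "e3"],  # one generated explanation per option
)
```

Applied over all 16,000 MCQs, this yields the SFT dataset and a preference dataset three times that size, with no human annotation in the loop.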

Figure 5: RLIF - Closed World Problem

The model trained with DPO on explanations performs the best on the MMLU test dataset, with an improvement of 5.5% over the base GPT model and 1.5% over SFT on options.