# GKD

GKD (Generalized Knowledge Distillation) is a training algorithm proposed in the paper [On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes](https://arxiv.org/pdf/2306.13649). It transfers knowledge from a teacher model to a student model by combining offline and on-policy learning.

## Loss Function

Given an input sequence $x$ and output sequence $y$, the GKD loss function can be written as:

$$
\mathcal{L}_{\text{GKD}}(x, y) = \sum_{t=1}^{|y|} D(P_{\text{teacher}}(\cdot | x, y_{<t}), P_{\text{student}}(\cdot | x, y_{<t}))
$$

Where:
- $y_{<t} = (y_1, y_2, \ldots, y_{t-1})$: sequence of the first $t-1$ tokens
- $P_{\text{teacher}}(\cdot | x, y_{<t})$: output probability distribution of the teacher model given context $x, y_{<t}$
- $P_{\text{student}}(\cdot | x, y_{<t})$: output probability distribution of the student model given context $x, y_{<t}$
- $D(\cdot, \cdot)$: divergence function to measure the difference between two probability distributions
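
To make the per-position sum concrete, here is a minimal pure-Python sketch. It assumes the teacher's and student's next-token distributions at every position of $y$ have already been computed (for example, by a forward pass of each model on $x, y$) and are passed in as lists of probability vectors. The names `gkd_sequence_loss` and `forward_kl` are illustrative helpers, not part of any library; any divergence $D$ can be plugged in.

```python
import math
from typing import Callable, List, Sequence

Distribution = Sequence[float]


def forward_kl(p: Distribution, q: Distribution) -> float:
    """KL(p || q) over a finite vocabulary; terms with p(v) = 0 contribute 0."""
    return sum(pv * math.log(pv / qv) for pv, qv in zip(p, q) if pv > 0)


def gkd_sequence_loss(
    teacher_dists: List[Distribution],
    student_dists: List[Distribution],
    divergence: Callable[[Distribution, Distribution], float],
) -> float:
    """Sum D(P_teacher(.|x, y_<t), P_student(.|x, y_<t)) over t = 1..|y|.

    teacher_dists[t] and student_dists[t] are the two models' next-token
    distributions at position t, both conditioned on the same prefix (x, y_<t).
    """
    if len(teacher_dists) != len(student_dists):
        raise ValueError("teacher and student must cover the same positions")
    return sum(divergence(pt, ps) for pt, ps in zip(teacher_dists, student_dists))


# Toy example: a 2-token output sequence over a 3-token vocabulary.
teacher = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]
student = [[0.5, 0.3, 0.2], [0.1, 0.8, 0.1]]
loss = gkd_sequence_loss(teacher, student, forward_kl)
print(f"{loss:.4f}")  # only the first position contributes; the second matches exactly
```

Passing the divergence as a function mirrors the role of $D$ in the formula: the same summation is reused whichever divergence is chosen.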

## Divergence Metrics

### KL Divergence (Kullback-Leibler Divergence)

KL divergence is an asymmetric measure of the difference between two probability distributions $P$ and $Q$:

$$
\text{KL}(P \| Q) = \sum_v P(v) \log \frac{P(v)}{Q(v)} = \mathbb{E}_{v \sim P}\left[\log \frac{P(v)}{Q(v)}\right]
$$
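
A small pure-Python sketch of this definition on toy distributions (the values of `P` and `Q` are made up for illustration). It follows the usual conventions: terms with $P(v) = 0$ contribute nothing, and the divergence is infinite if $Q(v) = 0$ where $P(v) > 0$. Swapping the arguments shows the asymmetry.

```python
import math


def kl_divergence(p, q):
    """KL(P || Q) = sum_v P(v) * log(P(v) / Q(v)), in nats.

    Terms with P(v) = 0 contribute 0 (convention 0 * log 0 = 0);
    if Q(v) = 0 where P(v) > 0, the divergence is infinite.
    """
    total = 0.0
    for pv, qv in zip(p, q):
        if pv == 0:
            continue
        if qv == 0:
            return math.inf
        total += pv * math.log(pv / qv)
    return total


P = [0.6, 0.3, 0.1]
Q = [0.2, 0.5, 0.3]
print(round(kl_divergence(P, Q), 4))  # → 0.3961
print(round(kl_divergence(Q, P), 4))  # → 0.3653 (different: KL is asymmetric)
print(kl_divergence(P, P))            # → 0.0 (identical distributions)
```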

### Forward KL and Reverse KL

Because KL divergence is asymmetric, knowledge distillation can use it in two directions, depending on which distribution comes first:

#### Forward KL

$$
\text{KL}(P_{\text{teacher}} \| P_{\text{student}}) = \sum_v P_{\text{teacher}}(v) \log \frac{P_{\text{teacher}}(v)}{P_{\text{student}}(v)}
$$

**Characteristics**: Mode-covering
- Expectation is computed under the teacher distribution
- The student model tends to cover the entire teacher distribution, including low-probability regions, because the loss grows large wherever the teacher has mass but the student assigns little

#### Reverse KL

$$
\text{KL}(P_{\text{student}} \| P_{\text{teacher}}) = \sum_v P_{\text{student}}(v) \log \frac{P_{\text{student}}(v)}{P_{\text{teacher}}(v)}
$$

**Characteristics**: Mode-seeking
- Expectation is computed under the student distribution
- The student model tends to concentrate on the peak (high-probability) regions of the teacher, because the loss penalizes student mass placed where the teacher assigns low probability
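
The two behaviors can be seen on a toy example where the student cannot match the teacher exactly. Below, an illustrative bimodal teacher over a 3-token vocabulary is compared with two hypothetical students: a broad one that spreads mass over every token and a peaky one that commits to a single mode. The numbers are made up; the point is only which student each direction of KL prefers.

```python
import math


def kl(p, q):
    """KL(p || q) over a finite vocabulary; terms with p(v) = 0 contribute 0."""
    return sum(pv * math.log(pv / qv) for pv, qv in zip(p, q) if pv > 0)


teacher = [0.49, 0.02, 0.49]  # two modes: tokens 0 and 2
students = {
    "broad": [0.30, 0.40, 0.30],  # covers both modes, but wastes mass on token 1
    "peaky": [0.96, 0.02, 0.02],  # commits to the first mode only
}

forward = {name: kl(teacher, s) for name, s in students.items()}  # KL(teacher || student)
reverse = {name: kl(s, teacher) for name, s in students.items()}  # KL(student || teacher)

# Forward KL prefers the broad student: the peaky one misses the second mode.
print("forward KL:", {k: round(v, 3) for k, v in forward.items()})
# Reverse KL prefers the peaky student: the broad one puts mass where the teacher has little.
print("reverse KL:", {k: round(v, 3) for k, v in reverse.items()})
```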

### Generalized Jensen-Shannon Divergence (Generalized JSD)

GKD uses the generalized JSD as its core divergence; the parameter $\beta \in [0, 1]$ **smoothly interpolates** between forward KL and reverse KL.

For two probability distributions $P$ and $Q$, generalized JSD is defined as:

$$
D_{\text{JSD}(\beta)}(P, Q) = \beta \cdot \text{KL}(P \| M) + (1-\beta) \cdot \text{KL}(Q \| M)
$$

Where the mixture distribution $M$ is defined as:

$$
M = \beta \cdot P + (1-\beta) \cdot Q
$$

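- When $\beta = 0.5$, it reduces to the standard symmetric JSD
- Small $\beta$ behaves like a scaled $\text{KL}(P \| Q)$ and $\beta$ close to 1 like a scaled $\text{KL}(Q \| P)$; with $P$ as the teacher (below), adjusting $\beta$ therefore trades off between mode-covering and mode-seeking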
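
A minimal pure-Python sketch of the definition, on made-up toy distributions, checks three properties numerically: symmetry at $\beta = 0.5$; collapse to zero at $\beta = 0$ and $\beta = 1$, where $M$ equals $Q$ or $P$; and, near the endpoints, the divergence divided by the small weight approaching a single KL term. `generalized_jsd` is an illustrative helper, not a library function.

```python
import math


def kl(p, q):
    """KL(p || q) over a finite vocabulary; terms with p(v) = 0 contribute 0."""
    return sum(pv * math.log(pv / qv) for pv, qv in zip(p, q) if pv > 0)


def generalized_jsd(p, q, beta):
    """beta * KL(P || M) + (1 - beta) * KL(Q || M), with M = beta * P + (1 - beta) * Q."""
    m = [beta * pv + (1 - beta) * qv for pv, qv in zip(p, q)]
    return beta * kl(p, m) + (1 - beta) * kl(q, m)


P = [0.6, 0.3, 0.1]
Q = [0.2, 0.5, 0.3]

# beta = 0.5: the standard JSD, symmetric in P and Q.
print(generalized_jsd(P, Q, 0.5), generalized_jsd(Q, P, 0.5))
# beta = 0 or 1: M coincides with Q or P, so the mixture formula gives 0.
print(generalized_jsd(P, Q, 0.0), generalized_jsd(P, Q, 1.0))
# Near the endpoints, the scaled divergence approaches a single KL term.
eps = 1e-4
print(generalized_jsd(P, Q, eps) / eps, kl(P, Q))      # ~ KL(P || Q)
print(generalized_jsd(P, Q, 1 - eps) / eps, kl(Q, P))  # ~ KL(Q || P)
```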

In GKD, we set $P = P_{\text{teacher}}$ and $Q = P_{\text{student}}$, therefore:

$$
D_{\text{JSD}(\beta)}(P_{\text{teacher}}, P_{\text{student}}) = \beta \cdot \text{KL}(P_{\text{teacher}} \| M) + (1-\beta) \cdot \text{KL}(P_{\text{student}} \| M)
$$

Where $M = \beta \cdot P_{\text{teacher}} + (1-\beta) \cdot P_{\text{student}}$

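> For the extreme cases ($\beta = 0$ or $\beta = 1$), the mixture $M$ coincides with one of the two distributions and the formula above evaluates to zero, so a single KL divergence is computed directly instead: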
> - When $\beta = 0$: directly define $D = \text{KL}(P_{\text{teacher}} \| P_{\text{student}})$ (Forward KL, Mode-covering)
> - When $\beta = 1$: directly define $D = \text{KL}(P_{\text{student}} \| P_{\text{teacher}})$ (Reverse KL, Mode-seeking)
> - When $0 < \beta < 1$: use the above mixture distribution formula for interpolation

In short, $\beta$ moves the divergence from forward KL ($\beta = 0$) through the symmetric JSD ($\beta = 0.5$) to reverse KL ($\beta = 1$).
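
Models output logits rather than probabilities, so a practical implementation usually works with log-probabilities and forms the mixture $M$ in log space for numerical stability. The sketch below illustrates that approach for a single position, with the endpoint branches following the special cases described above. It is written in pure Python for clarity and is not ms-swift's actual code; `log_softmax`, `logsumexp2`, and `gkd_divergence` are hypothetical helpers, and the logits are made up.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    mx = max(logits)
    log_z = mx + math.log(sum(math.exp(z - mx) for z in logits))
    return [z - log_z for z in logits]


def logsumexp2(a, b):
    """log(exp(a) + exp(b)) without overflow."""
    mx = max(a, b)
    return mx + math.log(math.exp(a - mx) + math.exp(b - mx))


def kl_from_logprobs(log_p, log_q):
    """KL(P || Q) from log-probabilities of P and Q."""
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def gkd_divergence(teacher_logits, student_logits, beta):
    """Per-position D_JSD(beta)(P_teacher, P_student) computed from logits."""
    log_t = log_softmax(teacher_logits)
    log_s = log_softmax(student_logits)
    if beta == 0.0:
        return kl_from_logprobs(log_t, log_s)  # forward KL: KL(teacher || student)
    if beta == 1.0:
        return kl_from_logprobs(log_s, log_t)  # reverse KL: KL(student || teacher)
    # log M = log(beta * P_teacher + (1 - beta) * P_student), formed in log space
    log_m = [logsumexp2(math.log(beta) + lt, math.log(1 - beta) + ls)
             for lt, ls in zip(log_t, log_s)]
    return (beta * kl_from_logprobs(log_t, log_m)
            + (1 - beta) * kl_from_logprobs(log_s, log_m))


teacher_logits = [2.0, 0.5, -1.0]
student_logits = [0.5, 1.0, 0.0]
for beta in (0.0, 0.5, 1.0):
    print(beta, round(gkd_divergence(teacher_logits, student_logits, beta), 4))
```

Working in log space avoids underflow when either model assigns very small probabilities, which is common over large vocabularies.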

## Three Training Modes

GKD supports three training modes, which differ in the source of the output sequence $y$.

### Mode Selection Logic

During training, the source of $y$ is selected according to the following priority:

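```python
# Pseudocode: choosing the source of the output sequence y
import random

def select_output_sequence(x, y_ground_truth, student, teacher, lmbda, seq_kd):
    if random.random() < lmbda:
        # Mode 1: on-policy learning, y is sampled from the student model
        return student.generate(x), "student"
    if seq_kd:
        # Mode 2: sequential KD, y is sampled from the teacher model
        return teacher.generate(x), "teacher"
    # Mode 3: offline learning, y is the ground-truth sequence from the dataset
    return y_ground_truth, "dataset"

# The loss is the same in every mode: the divergence between the teacher's and
# the student's next-token distributions, summed over the positions of y:
# loss = sum_t D_JSD(P_teacher(. | x, y_<t), P_student(. | x, y_<t))
```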

### Mode 1: On-Policy Learning
Triggered with probability $\lambda$ (parameter `lmbda`): the output sequence is sampled from the student model, $y \sim P_{\text{student}}(\cdot | x)$

- The student model learns from **sequences it generated itself**
- It is exposed to the mistakes it actually makes and learns to **self-correct and recover from them**
- The training distribution matches the distribution the student sees at inference time
- This improves robustness and performance in real use

**Applicable Scenarios**:
- The student model can already generate reasonable outputs
- You want to improve performance under real inference conditions

### Mode 2: Sequential KD (`seq_kd=True` and on-policy not triggered)
With `seq_kd=True`, whenever on-policy sampling is not triggered, the output sequence is sampled from the teacher model.

**Data Source**: $y \sim P_{\text{teacher}}(\cdot | x)$

### Mode 3: Offline Learning (other cases)

**Data Source**: $y = y^* \sim \text{Dataset}$

- The student model learns from **annotated sequences in the dataset**


## Parameter Settings

We can perform GKD training by setting the following parameters:

| Parameter | Type | Default | Range | Description |
|------|------|--------|---------|------|
| `--teacher_model` | str | Required | - | Teacher model path or model ID |
| `--beta` | float | 0.5 | [0.0, 1.0] | Divergence interpolation coefficient<br>• 0.0: Forward KL <br>• 0.5: JSD (balanced)<br>• 1.0: Reverse KL |
| `--lmbda` | float | 0.5 | [0.0, 1.0] | On-Policy learning trigger probability<br>• 0.0: Pure Offline<br>• 0.5: Mixed strategy (**recommended**)<br>• 1.0: Pure On-Policy |
| `--seq_kd` | bool | False | True/False | Whether to use teacher-generated sequences<br>• False: Use dataset when not on-policy<br>• True: Use teacher generation when not on-policy |
| `--temperature` | float | 0.9 | > 0 | Generation sampling temperature, controls randomness |
| `--sft_alpha` | float | 0 | >= 0 | Mix in a proportion of SFT loss; applied to non-student-generated completions |
| `--max_completion_length` | int | 512 | > 0 | Maximum number of tokens during generation |
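
The table describes `--sft_alpha` only briefly. The sketch below assumes, without having checked the ms-swift source, that the SFT loss is added to the distillation loss with weight `sft_alpha` and skipped when the completion was sampled from the student. `combine_losses` is a hypothetical helper used only to illustrate that rule.

```python
from typing import Literal

Source = Literal["student", "teacher", "dataset"]


def combine_losses(jsd_loss: float, sft_loss: float, sft_alpha: float, source: Source) -> float:
    """Hypothetical combination of the GKD loss with an auxiliary SFT loss.

    Assumption: total = jsd_loss + sft_alpha * sft_loss, where the SFT term
    is applied only when the completion was NOT generated by the student,
    i.e. when it came from the teacher or from the dataset.
    """
    if sft_alpha < 0:
        raise ValueError("sft_alpha must be >= 0")
    if source == "student" or sft_alpha == 0:
        return jsd_loss
    return jsd_loss + sft_alpha * sft_loss


print(combine_losses(0.8, 2.0, 0.1, "dataset"))  # 0.8 + 0.1 * 2.0
print(combine_losses(0.8, 2.0, 0.1, "student"))  # SFT term skipped
```

Skipping the SFT term for student-generated completions is a sensible design: those sequences are the student's own samples rather than reference targets, so imitating them with a cross-entropy loss could reinforce the student's own mistakes.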

## Sampling Acceleration

In GKD training, there are two types of online sampling scenarios:

1. **Student model sampling** (when `lmbda > 0`): triggered with probability $\lambda$
2. **Teacher model sampling** (when `seq_kd=True`): triggered with probability $1-\lambda$
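
Each time the selection rule is applied, at most one of these two kinds of sampling happens, so their expected frequencies follow directly from `lmbda` and `seq_kd`. The sketch below uses hypothetical helpers: it computes these frequencies exactly and cross-checks them with a seeded simulation of the same priority rule used in the mode-selection logic. This can help estimate how much of training will be spent on online generation.

```python
import random


def expected_mode_fractions(lmbda: float, seq_kd: bool) -> dict:
    """Expected share of selections that take each source of y."""
    off_policy = 1.0 - lmbda
    return {
        "student": lmbda,                          # online student sampling
        "teacher": off_policy if seq_kd else 0.0,  # online teacher sampling
        "dataset": 0.0 if seq_kd else off_policy,  # no online sampling needed
    }


def simulate_mode_fractions(lmbda: float, seq_kd: bool, steps: int, seed: int = 0) -> dict:
    """Apply the selection rule `steps` times and return the observed shares."""
    rng = random.Random(seed)
    counts = {"student": 0, "teacher": 0, "dataset": 0}
    for _ in range(steps):
        if rng.random() < lmbda:
            counts["student"] += 1
        elif seq_kd:
            counts["teacher"] += 1
        else:
            counts["dataset"] += 1
    return {k: v / steps for k, v in counts.items()}


print(expected_mode_fractions(0.5, seq_kd=True))
print(simulate_mode_fractions(0.5, seq_kd=True, steps=100_000))
```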

Because online sampling significantly slows down training, two acceleration schemes are available:

### Solution 1: Student Model Sampling Acceleration

**Requirement**: swift >= 3.10.dev

Use vLLM as the inference backend to accelerate student model sampling. Two deployment modes are supported, the same as in GRPO; see the [GRPO documentation](./GRPO/GetStarted/GRPO.md#cluster-support).

> **Note**: vLLM acceleration applies only to on-policy sampling from the student model (`lmbda > 0`). Sequential KD sampling from the teacher model (`seq_kd=True`) still uses Transformers; for that case, the pre-sampling scheme is recommended.

For a training script, see [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/gkd/vllm_server.sh); for the related parameters, see [GRPO vLLM Parameters](./Command-line-parameters.md#vllm_mode).


### Solution 2: Teacher Model Pre-sampling

For teacher model sampling (`seq_kd=True`), **pre-sampling** is recommended: first generate high-quality data offline with the teacher model, then train on that data.

**Step 1: Generate data using teacher model**
```bash
export teacher_model='OpenGVLab/InternVL3-8B'

NPROC_PER_NODE=4 \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
swift infer \
    --model $teacher_model \
    --infer_backend vllm \
    --val_dataset 'modelscope/coco_2014_caption:validation#5000' \
    --vllm_gpu_memory_utilization 0.9 \
    --vllm_max_model_len 8192 \
    --max_new_tokens 2048 \
    --write_batch_size 1000 \
    --result_path teacher_generated_data.jsonl
```

**Step 2: Train using pre-generated data**
```bash
swift rlhf \
    --rlhf_type gkd \
    --model OpenGVLab/InternVL3-2B-Pretrained \
    --teacher_model $teacher_model \
    --dataset 'teacher_generated_data.jsonl' \
    --seq_kd false \
    ...
```

For the full training script, see [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal/rlhf/gkd/fast.sh).


## On-Policy Distillation
By setting the following parameters, we can reproduce the [On-Policy Distillation](https://thinkingmachines.ai/blog/on-policy-distillation/) training described in the Thinking Machines Lab blog:

```bash
--lmbda 1  # always sample from the student (pure on-policy)
--beta 1   # reverse KL
```
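
With these settings, every output sequence comes from the student and each position is scored by the reverse KL $\text{KL}(P_{\text{student}} \| P_{\text{teacher}})$. Our reading of the linked post is that it scores each sampled token by the log-ratio $\log P_{\text{student}}(v) - \log P_{\text{teacher}}(v)$; treat that as an assumption. The sketch below, using made-up logits and a hypothetical `log_softmax` helper, checks that averaging this per-token quantity over tokens sampled from the student recovers the full-vocabulary reverse KL at one position, so the two views agree in expectation.

```python
import math
import random


def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    mx = max(logits)
    log_z = mx + math.log(sum(math.exp(z - mx) for z in logits))
    return [z - log_z for z in logits]


student_logp = log_softmax([1.0, 0.2, -0.5])  # illustrative student logits
teacher_logp = log_softmax([0.3, 1.2, -1.0])  # illustrative teacher logits
student_p = [math.exp(lp) for lp in student_logp]

# Full-vocabulary reverse KL at this position (what beta = 1 computes).
full_reverse_kl = sum(p * (ls - lt) for p, ls, lt in zip(student_p, student_logp, teacher_logp))

# Average per-token log-ratio over tokens sampled from the student (lmbda = 1).
rng = random.Random(0)
samples = rng.choices(range(len(student_p)), weights=student_p, k=200_000)
sampled_estimate = sum(student_logp[v] - teacher_logp[v] for v in samples) / len(samples)

print(f"full reverse KL = {full_reverse_kl:.4f}, sampled estimate = {sampled_estimate:.4f}")
```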

For a complete implementation, refer to the example script [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/on_policy_distillation.sh).
