Official implementation of ThinkTwice, a two-phase extension of Group Relative Policy Optimization (GRPO) that jointly optimizes LLMs to solve reasoning problems and to refine their answers. In each training cycle, ThinkTwice first trains the model on the reasoning task and then on revising its own responses, using the same correctness reward throughout, without external guidance.
- Hardware: 2+ NVIDIA GPUs (tested on A100/H100)
- Software: Linux, CUDA 12.x, Conda
```bash
conda create -n verl python=3.11 -y
conda activate verl
pip install -e verl/
pip install flash-attn --no-build-isolation
```

The evaluation benchmarks are built from HuggingFace datasets. Run the preparation scripts to generate parquet files under `scratch/`:
```bash
conda activate verl
python math_eval/ppc/math500.py
python math_eval/ppc/aime2024.py
python math_eval/ppc/amc.py
python math_eval/ppc/minerva_math.py
python math_eval/ppc/olympiadbench.py
```

The training data (`scratch/hendrycks_math/train.parquet`) and the combined validation set (`scratch/math_combined/test.parquet`) should also be prepared before training.
Download the base model weights to a local directory:
| Model | HuggingFace ID |
|---|---|
| Qwen3-4B-Instruct-2507 | Qwen/Qwen3-4B-Instruct-2507 |
| OLMo-3-7B-Instruct | allenai/OLMo-3-7B-Instruct |
The model paths are configured at the top of each training script via the `actor_rollout_ref.model.path` override. Update them to point to your local copies.
All training scripts are self-contained and runnable with a single command. Each script activates the conda environment, configures Ray, and launches the trainer with the appropriate Hydra overrides.
Trains Qwen3-4B-Instruct-2507 with ThinkTwice:

```bash
bash verl/run_thinktwice_qwen3.sh
```

Trains OLMo-3-7B-Instruct with ThinkTwice:

```bash
bash verl/run_thinktwice_olmo3.sh
```

Generates multiple samples per problem and estimates pass@k (k = 1, 2, 4, 8, 16, 32, and more) for both base responses and self-refinement responses:
```bash
python math_eval/reward/evaluate_passatk.py
```

Evaluates each model as a refinement model applied to base solutions generated by every other model:
```bash
python math_eval/reward/evaluate_cross_refinement.py
```
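For reference, pass@k over n samples with c correct is usually computed with the standard unbiased estimator, 1 - C(n-c, k)/C(n, k). A self-contained sketch (not necessarily how `evaluate_passatk.py` implements it):

```python
from math import comb

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k: probability that at least one of k samples drawn
    from n total (c of them correct) solves the problem."""
    if n - c < k:
        return 1.0  # every size-k subset must contain a correct sample
    return 1.0 - comb(n - c, k) / comb(n, k)
```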