MaskSearch: Training LLMs for Expert-Level Search Capabilities

June 11 · Published in AI Agent Tools

MaskSearch is a pretraining framework designed to strengthen a model's fundamental search capabilities. Rather than fine-tuning for narrow, task-specific behavior, it introduces a Retrieval-Augmented Mask Prediction (RAMP) objective: the model learns to reconstruct masked segments of pretraining text by actively calling search tools. This builds a general foundation for retrieval and reasoning that transfers to large language model (LLM) search agents and improves their downstream performance.
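
To make the RAMP objective concrete, the sketch below turns a passage into a masked example. It replaces chosen character spans with a mask token and keeps the hidden spans as the reconstruction target. It assumes the spans have already been selected. This section does not say how MaskSearch chooses what to mask or which mask token it uses, so `build_ramp_example`, `RampExample`, and `[mask]` are hypothetical names.

```python
"""Hypothetical sketch: build a RAMP-style masked example from a passage.

Assumptions (not from the MaskSearch code): spans are given as character
offsets, the mask token is "[mask]", and the target is the list of hidden
span texts in order of appearance.
"""
from dataclasses import dataclass


@dataclass
class RampExample:
    masked_text: str   # passage with the selected spans replaced by the mask token
    answers: list      # original span texts, in the order they appear


def build_ramp_example(text: str, spans: list, mask_token: str = "[mask]") -> RampExample:
    """Replace each (start, end) character span in `text` with `mask_token`."""
    pieces, answers, cursor = [], [], 0
    for start, end in sorted(spans):
        if start < cursor or end <= start or end > len(text):
            raise ValueError(f"invalid or overlapping span {(start, end)}")
        pieces.append(text[cursor:start])   # keep the unmasked text before the span
        pieces.append(mask_token)           # hide the span from the model
        answers.append(text[start:end])     # remember what must be recovered
        cursor = end
    pieces.append(text[cursor:])
    return RampExample("".join(pieces), answers)


if __name__ == "__main__":
    passage = "Marie Curie won the Nobel Prize in Physics in 1903."
    example = build_ramp_example(passage, [(0, 11), (46, 50)])
    print(example.masked_text)
    print(example.answers)
```

During training, the model would see `masked_text`, issue search queries to gather evidence, and be judged on whether it recovers `answers`.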

Training data is generated by combining agentic workflows with knowledge distillation. First, a multi-agent system made up of a Planner, a Rewriter, and an Observer coordinates multi-step search paths. A self-evolving teacher model then refines these outputs into high-quality reasoning trajectories, which are distilled into the student model.
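
The division of labor among the three agents can be sketched as a simple control loop. In this hypothetical version, each role is a plain Python callable; in practice, each would be a model call. The Planner splits a question into sub-goals. The Rewriter turns each sub-goal into a search query. The Observer inspects the accumulated evidence and returns an answer once it judges that evidence sufficient. The function names, signatures, and toy corpus are illustrative assumptions, and the teacher-refinement stage is not modeled.

```python
"""Hypothetical control loop for a Planner / Rewriter / Observer pipeline."""
import re
from dataclasses import dataclass, field
from typing import Callable, Optional


@dataclass
class Trajectory:
    question: str
    steps: list = field(default_factory=list)  # one dict per search step
    answer: Optional[str] = None


def run_pipeline(question: str,
                 planner: Callable[[str], list],
                 rewriter: Callable[[str], str],
                 search: Callable[[str], str],
                 observer: Callable[[str, list], Optional[str]],
                 max_steps: int = 8) -> Trajectory:
    traj = Trajectory(question)
    for sub_goal in planner(question)[:max_steps]:      # Planner: question -> sub-goals
        query = rewriter(sub_goal)                      # Rewriter: sub-goal -> query
        evidence = search(query)                        # search tool call
        traj.steps.append({"sub_goal": sub_goal, "query": query, "evidence": evidence})
        answer = observer(question, traj.steps)         # Observer: answer or keep going
        if answer is not None:
            traj.answer = answer
            break
    return traj


# Toy stand-ins for the agents and the search tool (illustration only).
TOY_CORPUS = {
    "Eiffel Tower completed": "The Eiffel Tower was completed in 1889.",
    "Statue of Liberty dedicated": "The Statue of Liberty was dedicated in 1886.",
}


def toy_planner(question: str) -> list:
    return ["When was the Eiffel Tower completed?",
            "When was the Statue of Liberty dedicated?"]


def toy_rewriter(sub_goal: str) -> str:
    return sub_goal.removeprefix("When was the ").rstrip("?")


def toy_observer(question: str, steps: list) -> Optional[str]:
    years = [int(m.group()) for s in steps if (m := re.search(r"\d{4}", s["evidence"]))]
    if len(years) < 2:
        return None  # not enough evidence yet; keep searching
    return "Statue of Liberty" if years[1] < years[0] else "Eiffel Tower"


def demo() -> Trajectory:
    return run_pipeline("Which is older, the Eiffel Tower or the Statue of Liberty?",
                        toy_planner, toy_rewriter,
                        lambda q: TOY_CORPUS.get(q, ""), toy_observer)


if __name__ == "__main__":
    t = demo()
    print(t.answer, len(t.steps))
```

Keeping each role behind a callable makes it easy to swap a stub for a real model call without changing the loop itself.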

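Empirical results support this approach. At every model size tested (Qwen2.5-1.5B, 3B, and 7B), a MaskSearch variant achieves the best score on each benchmark, and MaskSearch with RL has the highest average. This means MaskSearch outperforms existing LLM-based search-agent methods on both in-domain and out-of-domain tasks.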

The following table reports results on six open-domain QA benchmarks at three Qwen2.5 model sizes (1.5B, 3B, and 7B):

Method                 Pre-train  Post-train  HotpotQA  FanoutQA  Musique  2Wiki  Bamboogle  FreshQA  Avg.
Qwen2.5-1.5B
  Agent-PE             ✗          ✗           41.23     46.51     23.45    56.12  49.80      61.23    46.39
  RAG-PE               ✗          ✗           38.23     42.18     19.82    49.52  35.20      62.18    41.19
  Distilled Search-R1  ✗          SFT         64.31     52.18     36.41    78.23  63.15      76.12    61.73
  Search-R1            ✗          RL          61.72     42.68     35.63    64.40  64.59      74.44    57.24
  MaskSearch           SFT        SFT         65.18     54.22     38.76    81.23  70.18      77.82    64.57
  MaskSearch           RL         RL          66.23     55.18     39.23    81.76  72.34      79.18    65.65
Qwen2.5-3B
  RAG-PE               ✗          ✗           38.37     41.48     20.78    51.14  37.60      61.55    41.82
  Agent-PE             ✗          ✗           51.17     49.82     25.27    58.14  56.40      67.80    48.10
  Distilled Search-R1  ✗          SFT         67.38     54.00     38.20    79.76  68.05      77.59    64.17
  Search-R1            ✗          RL          69.03     48.55     39.08    78.85  72.53      76.78    64.14
  MaskSearch           SFT        SFT         69.30     56.03     40.12    82.36  74.52      79.84    67.03
  MaskSearch           RL         RL          73.08     53.02     44.48    80.43  80.13      85.07    69.37
Qwen2.5-7B
  RAG-PE               ✗          ✗           43.55     51.92     25.05    53.86  44.60      64.40    47.23
  Agent-PE             ✗          ✗           61.75     55.69     34.25    68.77  63.25      75.81    58.25
  Distilled Search-R1  ✗          SFT         69.55     57.24     41.06    83.84  73.07      78.97    67.29
  Search-R1            ✗          RL          70.59     56.25     41.29    80.50  79.33      78.46    67.74
  MaskSearch           SFT        SFT         70.44     60.85     41.76    84.65  80.13      81.12    69.83
  MaskSearch           RL         RL          75.61     60.98     45.15    85.32  82.10      84.23    72.23

*✗ indicates that no dedicated pre-training or post-training stage was used. SFT = Supervised Fine-Tuning. RL = Reinforcement Learning. RAG-PE and Agent-PE are baseline prompting methods.
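
The Avg. column appears to be the unweighted mean of the six benchmark scores. The short check below recomputes it for the Search-R1 and MaskSearch (RL) rows and reports the average-score gain at each model size. The choice of rows and the unweighted-mean assumption are ours; the scores are copied from the table above.

```python
"""Recompute Avg. for two rows per model size and report MaskSearch RL's gain.

Assumption: Avg. is the unweighted mean of the six benchmarks, rounded to
two decimals. Scores are copied from the results table.
"""
from statistics import mean

# (size, method) -> ([HotpotQA, FanoutQA, Musique, 2Wiki, Bamboogle, FreshQA], reported Avg.)
ROWS = {
    ("1.5B", "Search-R1"): ([61.72, 42.68, 35.63, 64.40, 64.59, 74.44], 57.24),
    ("1.5B", "MaskSearch RL"): ([66.23, 55.18, 39.23, 81.76, 72.34, 79.18], 65.65),
    ("3B", "Search-R1"): ([69.03, 48.55, 39.08, 78.85, 72.53, 76.78], 64.14),
    ("3B", "MaskSearch RL"): ([73.08, 53.02, 44.48, 80.43, 80.13, 85.07], 69.37),
    ("7B", "Search-R1"): ([70.59, 56.25, 41.29, 80.50, 79.33, 78.46], 67.74),
    ("7B", "MaskSearch RL"): ([75.61, 60.98, 45.15, 85.32, 82.10, 84.23], 72.23),
}


def average_gains(rows: dict) -> dict:
    """Verify each reported Avg., then return MaskSearch RL minus Search-R1 per size."""
    for key, (scores, reported) in rows.items():
        recomputed = mean(scores)
        if abs(recomputed - reported) >= 0.005:  # tolerance for two-decimal rounding
            raise ValueError(f"Avg. mismatch for {key}: {recomputed:.4f} vs {reported}")
    sizes = sorted({size for size, _ in rows})
    return {s: rows[(s, "MaskSearch RL")][1] - rows[(s, "Search-R1")][1] for s in sizes}


if __name__ == "__main__":
    for size, gain in average_gains(ROWS).items():
        print(f"Qwen2.5-{size}: +{gain:.2f} points")
```

With these numbers, every reported average matches the recomputed mean. The gains come out to 8.41, 5.23, and 4.49 points for the 1.5B, 3B, and 7B models.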

Installation and Setup

Before running the framework, replace the placeholder keys in the following files with your Qwen (DashScope) and Google Search credentials:

  • src/RAMP/model.py
  • src/multi_agent/model.py
  • src/multi_agent/web_news_get.py

Set your API keys as follows:

DASHSCOPE_API_KEY = "YOUR_API_KEY"
GOOGLE_API_KEY = "YOUR_API_KEY"
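
If you would rather not paste keys into source files, one option is to keep the same variable names but read the values from environment variables. The helper below is a hypothetical addition, not part of the repository. It only checks that both variables are set and not left at the `YOUR_API_KEY` placeholder.

```python
"""Hypothetical helper: load the two API keys from environment variables."""
import os
from typing import Mapping, Optional

REQUIRED_KEYS = ("DASHSCOPE_API_KEY", "GOOGLE_API_KEY")
PLACEHOLDER = "YOUR_API_KEY"


def load_api_keys(env: Optional[Mapping[str, str]] = None) -> dict:
    """Return both keys, raising if either is unset or still the placeholder."""
    env = os.environ if env is None else env
    keys = {name: env.get(name, "") for name in REQUIRED_KEYS}
    missing = [name for name, value in keys.items() if value in ("", PLACEHOLDER)]
    if missing:
        raise RuntimeError("Set these environment variables first: " + ", ".join(missing))
    return keys


# In each listed file, the hardcoded assignments could then become, e.g.:
# DASHSCOPE_API_KEY = load_api_keys()["DASHSCOPE_API_KEY"]
# GOOGLE_API_KEY = load_api_keys()["GOOGLE_API_KEY"]
```

You would then run `export DASHSCOPE_API_KEY=...` and `export GOOGLE_API_KEY=...` in your shell before launching the scripts.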

Install the necessary dependencies:

pip install -r requirements.txt

Generate RAMP QA Data

Generate the QA dataset using a Wikipedia dump. Point the --corpus argument to your local Wikipedia directory.

python gen_qa.py \
    --model "$model" \
    --corpus "Wikipedia Directory" \
    --output_path "output_path"
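
The same step can also be launched from Python, for example inside a larger data-preparation script. The sketch below uses only the three flags shown above. The helper names, the check that the corpus path is a directory, and running from the repository root are all assumptions. As in the shell command, `model` is whatever value you would otherwise put in `$model`.

```python
"""Hypothetical wrapper around gen_qa.py using only the documented flags."""
import subprocess
import sys
from pathlib import Path


def gen_qa_command(model: str, corpus_dir: str, output_path: str) -> list:
    """Build the argv for gen_qa.py (assumes the current directory is the repo root)."""
    return [sys.executable, "gen_qa.py",
            "--model", model,
            "--corpus", corpus_dir,
            "--output_path", output_path]


def run_gen_qa(model: str, corpus_dir: str, output_path: str) -> None:
    if not Path(corpus_dir).is_dir():
        raise FileNotFoundError(f"Wikipedia directory not found: {corpus_dir}")
    # check=True raises CalledProcessError if gen_qa.py exits with a non-zero status
    subprocess.run(gen_qa_command(model, corpus_dir, output_path), check=True)
```

Validating the corpus path before launching avoids starting a long job that would fail on a mistyped directory.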

Build CoT Trajectories

Run the multi-agent pipeline to construct Chain-of-Thought (CoT) reasoning steps. This generates the SFT dataset. If necessary, configure your data paths within src/multi_agent/dataset.py.

python cot_construct.py \
    --model "$model" \
    --dataset "dataset" \
    --output_path "output_path"
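
The exact record format written by cot_construct.py is not documented here. As an illustration only, the sketch below shows one way to flatten a search trajectory into a tagged string and store it as a JSON line. Here, a trajectory means interleaved reasoning, search queries, and retrieved results, ending in an answer. The `Step` class, the tag names, and the field names are hypothetical.

```python
"""Hypothetical serialization of one CoT search trajectory to a JSON line."""
import json
from dataclasses import dataclass

ALLOWED_KINDS = ("think", "search", "result")


@dataclass
class Step:
    kind: str  # "think", "search", or "result"
    text: str


def render_trajectory(steps: list, answer: str) -> str:
    """Wrap each step in <kind>...</kind> tags and append the final answer."""
    parts = []
    for step in steps:
        if step.kind not in ALLOWED_KINDS:
            raise ValueError(f"unknown step kind: {step.kind!r}")
        parts.append(f"<{step.kind}>{step.text}</{step.kind}>")
    parts.append(f"<answer>{answer}</answer>")
    return "\n".join(parts)


def to_jsonl_line(question: str, steps: list, answer: str) -> str:
    record = {"question": question,
              "trajectory": render_trajectory(steps, answer),
              "answer": answer}
    return json.dumps(record, ensure_ascii=False)


if __name__ == "__main__":
    steps = [
        Step("think", "I need the first woman to win a Nobel Prize."),
        Step("search", "first woman Nobel Prize"),
        Step("result", "Marie Curie was the first woman to win a Nobel Prize (1903)."),
    ]
    print(to_jsonl_line("[mask] was the first woman to win a Nobel Prize.", steps, "Marie Curie"))
```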

Training

Once the data is prepared, begin the training process. For SFT, follow the standard LLaMA-Factory workflow. For RL training, refer to the configurations used in Search-R1 or ZeroSearch.
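
For the SFT route, LLaMA-Factory reads datasets that are registered in a dataset_info.json file. Its Alpaca-style format uses instruction, input, and output fields. The sketch below assumes each trajectory has already been reduced to a (question, CoT text) pair and writes the pairs in that format. The dataset name masksearch_sft and the file layout are placeholders; check the current LLaMA-Factory data documentation for the exact schema.

```python
"""Hypothetical export of (question, CoT) pairs to an Alpaca-style LLaMA-Factory dataset."""
import json
import tempfile
from pathlib import Path


def to_alpaca(pairs) -> list:
    # Alpaca-style records: instruction / input / output
    return [{"instruction": q, "input": "", "output": cot} for q, cot in pairs]


def write_sft_dataset(pairs, data_dir: str, name: str = "masksearch_sft") -> Path:
    out_dir = Path(data_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    data_file = out_dir / f"{name}.json"
    data_file.write_text(json.dumps(to_alpaca(pairs), ensure_ascii=False, indent=2),
                         encoding="utf-8")
    # Register the file in dataset_info.json, keeping any existing entries.
    info_path = out_dir / "dataset_info.json"
    info = json.loads(info_path.read_text(encoding="utf-8")) if info_path.exists() else {}
    info[name] = {"file_name": data_file.name, "formatting": "alpaca"}
    info_path.write_text(json.dumps(info, ensure_ascii=False, indent=2), encoding="utf-8")
    return data_file


def demo() -> tuple:
    """Write one toy record to a temporary directory and read both files back."""
    with tempfile.TemporaryDirectory() as tmp:
        path = write_sft_dataset(
            [("[mask] was the first woman to win a Nobel Prize.", "<answer>Marie Curie</answer>")],
            tmp)
        records = json.loads(path.read_text(encoding="utf-8"))
        info = json.loads((path.parent / "dataset_info.json").read_text(encoding="utf-8"))
    return records, info


if __name__ == "__main__":
    print(demo())
```

Merging into an existing dataset_info.json, rather than overwriting it, keeps other registered datasets available.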