from trl import GoldTrainer, GoldConfig
from datasets import load_dataset
# Load the JSONL training file as a single `datasets.Dataset`.
# Without `split=`, `load_dataset` returns a `DatasetDict` keyed by split name
# (a single `data_files` string is mapped to the "train" split), which cannot be
# passed directly as a trainer's `train_dataset`.
dataset = load_dataset('json', data_files='training_data.jsonl', split='train')
# Distillation hyper-parameters for the GOLD trainer.
config = GoldConfig(
    # NOTE(review): semantics below follow TRL's GKD docs, which GOLD builds on;
    # confirm against the GoldConfig reference for the installed TRL version.
    beta=0.5,  # interpolation weight of the generalized JSD loss (0.0 ~ forward KL, 1.0 ~ reverse KL)
    lmbda=0.5,  # fraction of training on student-generated (on-policy) outputs vs. dataset outputs
    use_uld_loss=True,  # Universal Logit Distillation: presumably enables teacher/student with different tokenizers
)
# Build the distillation trainer (student learns from teacher) and run training.
# NOTE(review): `student_model` and `teacher_model` are not defined in this
# snippet; they must be created (e.g. loaded models) before this point.
# NOTE(review): TRL's distillation trainers accept the teacher via the
# `teacher_model` keyword (not `teacher`); also verify the import path and
# class names (GoldTrainer/GoldConfig) against the installed TRL version.
trainer = GoldTrainer(
    model=student_model,
    teacher_model=teacher_model,
    args=config,
    train_dataset=dataset,  # DeepFabric data; should be a single split (Dataset), not a DatasetDict
)
trainer.train()