diff --git a/src/art/dev/get_model_config.py b/src/art/dev/get_model_config.py index 1b3a43de..612ebf7d 100644 --- a/src/art/dev/get_model_config.py +++ b/src/art/dev/get_model_config.py @@ -38,6 +38,11 @@ def get_model_config( disable_log_requests=True, enable_sleep_mode=enable_sleep_mode, generation_config="vllm", + # Tie vLLM context/gpu utilization defaults to Unsloth init args to avoid + # zeroed KV cache calculations on large cards (e.g., H100) when defaults + # are missing. + gpu_memory_utilization=init_args["gpu_memory_utilization"], # type: ignore + max_model_len=init_args["max_seq_length"], # type: ignore ) engine_args.update(config.get("engine_args", {})) init_args.update(config.get("init_args", {}))