diff --git a/inference/huggingface/text-generation/arguments.py b/inference/huggingface/text-generation/arguments.py
index 87d2def0e..60fc6698e 100644
--- a/inference/huggingface/text-generation/arguments.py
+++ b/inference/huggingface/text-generation/arguments.py
@@ -17,4 +17,5 @@
 parser.add_argument("--test_performance", action='store_true', help="enable latency, bandwidth, and throughout testing")
 parser.add_argument("--local_rank", type=int, default=int(os.getenv("LOCAL_RANK", "0")), help="local rank")
 parser.add_argument("--world_size", type=int, default=int(os.getenv("WORLD_SIZE", "1")), help="world_size")
-parser.add_argument("--test_hybrid_engine", action='store_true', help="enable hybrid engine testing")
\ No newline at end of file
+parser.add_argument("--test_hybrid_engine", action='store_true', help="enable hybrid engine testing")
+parser.add_argument("--quantize_groups", type=int, required=False, default=0, help="number of weight quantization groups to use")
\ No newline at end of file
diff --git a/inference/huggingface/text-generation/inference-test.py b/inference/huggingface/text-generation/inference-test.py
index f8e1dc548..07a6ed072 100644
--- a/inference/huggingface/text-generation/inference-test.py
+++ b/inference/huggingface/text-generation/inference-test.py
@@ -49,6 +49,7 @@
                                           replace_with_kernel_inject=args.use_kernel,
                                           max_tokens=args.max_tokens,
                                           save_mp_checkpoint_path=args.save_mp_checkpoint_path,
+                                          quantize_groups=args.quantize_groups,
                                           **ds_kwargs
                                           )
 