|
1 | 1 | import logging |
2 | 2 | import os |
3 | 3 | import json |
| 4 | +import shlex |
4 | 5 | from typing import List, Optional |
5 | 6 | import ray |
6 | 7 | import ray._private.runtime_env.constants as runtime_env_constants |
|
26 | 27 | async def _create_impl(image_uri: str, logger: logging.Logger): |
27 | 28 | # Pull image if it doesn't exist |
28 | 29 | # Also get path to `default_worker.py` inside the image. |
29 | | - pull_image_cmd = [ |
30 | | - "podman", |
31 | | - "run", |
32 | | - "--rm", |
33 | | - image_uri, |
34 | | - "python", |
35 | | - "-c", |
36 | | - ( |
37 | | - "import ray._private.workers.default_worker as default_worker; " |
38 | | - "print(default_worker.__file__)" |
39 | | - ), |
40 | | - ] |
41 | | - logger.info("Pulling image %s", image_uri) |
42 | | - worker_path = await check_output_cmd(pull_image_cmd, logger=logger) |
| 30 | + custom_pull_cmd = os.getenv("RAY_PODMAN_PULL_CMD", "") |
| 31 | + if custom_pull_cmd: |
| 32 | + logger.info("Using custom pull command: %s", custom_pull_cmd) |
| 33 | + shell_cmd = ["sh", "-c", custom_pull_cmd] |
| 34 | + worker_path = await check_output_cmd(shell_cmd, logger=logger) |
| 35 | + else: |
| 36 | + pull_image_cmd = [ |
| 37 | + "podman", |
| 38 | + "run", |
| 39 | + "--rm", |
| 40 | + image_uri, |
| 41 | + "python", |
| 42 | + "-c", |
| 43 | + ( |
| 44 | + "import ray._private.workers.default_worker as default_worker; " |
| 45 | + "print(default_worker.__file__)" |
| 46 | + ), |
| 47 | + ] |
| 48 | + logger.info("Pulling image %s", image_uri) |
| 49 | + worker_path = await check_output_cmd(pull_image_cmd, logger=logger) |
| 50 | + |
| 51 | + output_filter = os.getenv("RAY_PODMAN_OUTPUT_FILTER", "") |
| 52 | + if output_filter: |
| 53 | + worker_path = await _apply_output_filter(logger, worker_path, output_filter) |
43 | 54 | return worker_path.strip() |
44 | 55 |
|
| 56 | +async def _apply_output_filter(logger, worker_path, output_filter): |
| 57 | + safe_worker_path = shlex.quote(worker_path) |
| 58 | + filter_cmd = ["sh", "-c", f"printf '%s' {safe_worker_path} | {output_filter}"] |
| 59 | + filtered_path = await check_output_cmd(filter_cmd, logger=logger) |
| 60 | + worker_path = filtered_path |
| 61 | + return worker_path |
45 | 62 |
|
46 | 63 | def _modify_container_context_impl( |
47 | 64 | runtime_env: "RuntimeEnv", # noqa: F821 |
@@ -157,9 +174,9 @@ def _modify_container_context_impl( |
157 | 174 |
|
158 | 175 | redirected_pyenv_folder = None |
159 | 176 | if container_install_ray or container_pip_packages: |
160 | | - container_to_host_mount_dict[ |
161 | | - container_dependencies_installer_path |
162 | | - ] = get_dependencies_installer_path() |
| 177 | + container_to_host_mount_dict[container_dependencies_installer_path] = ( |
| 178 | + get_dependencies_installer_path() |
| 179 | + ) |
163 | 180 | if runtime_env_constants.RAY_PODMAN_UES_WHL_PACKAGE: |
164 | 181 | container_to_host_mount_dict[get_ray_whl_dir()] = get_ray_whl_dir() |
165 | 182 |
|
@@ -253,6 +270,16 @@ def _modify_context_impl( |
253 | 270 | ray_tmp_dir: str, |
254 | 271 | ): |
255 | 272 | context.override_worker_entrypoint = worker_path |
| 273 | + custom_container_cmd = os.getenv("RAY_PODMAN_CONTAINER_CMD", "") |
| 274 | + if custom_container_cmd: |
| 275 | + custom_container_cmd_str = custom_container_cmd.format( |
| 276 | + ray_tmp_dir=ray_tmp_dir, image_uri=image_uri |
| 277 | + ) |
| 278 | + logger.info( |
| 279 | + f"Starting worker in container with prefix {custom_container_cmd_str}" |
| 280 | + ) |
| 281 | + context.py_executable = custom_container_cmd_str |
| 282 | + return |
256 | 283 |
|
257 | 284 | container_driver = "podman" |
258 | 285 | container_command = [ |
@@ -285,6 +312,12 @@ def _modify_context_impl( |
285 | 312 | if env_var_name.startswith("RAY_"): |
286 | 313 | env_vars[env_var_name] = env_var_value |
287 | 314 |
|
| 315 | + extra_env_keys = os.getenv("RAY_PODMAN_EXTRA_ENV_KEYS", "") |
| 316 | + if extra_env_keys: |
| 317 | + for key in (k.strip() for k in extra_env_keys.split(",")): |
| 318 | + if key and key in os.environ: |
| 319 | + env_vars[key] = os.environ[key] |
| 320 | + |
288 | 321 | # Support for runtime_env['env_vars'] |
289 | 322 | env_vars.update(context.env_vars) |
290 | 323 |
|
|
0 commit comments