refactor(inspect): Improve pipeline visualization and add max_epochs #3201
@@ -67,6 +67,7 @@ async def _run_training_job(cls, job: Job, job_service: JobService) -> Model | None
         device = job.payload.get("device")
         snapshot_id_ = job.payload.get("dataset_snapshot_id")
         snapshot_id = UUID(snapshot_id_) if snapshot_id_ else None
+        max_epochs = job.payload.get("max_epochs", 200)

         if model_name is None:
             raise ValueError(f"Job {job.id} payload must contain 'model_name'")
@@ -105,6 +106,7 @@ async def _run_training_job(cls, job: Job, job_service: JobService) -> Model | None
             model=model,
             device=device,
             synchronization_parameters=synchronization_parameters,
+            max_epochs=max_epochs,
Should we add this to the model training dialog? cc @ashwinvaidya17, @MarkRedeman

Contributor: It would be nice if we could already expose this (perhaps next to the training-device picker from the new designs). However, there are some edge cases we would need to deal with: some models overwrite the max epochs to 1, since they only need to do a single pass.

Contributor: Let's expose this and just add a note saying something like, "some models only extract features, so max_epochs will be overridden for those models."
             dataset_root=dataset_root,
         )
@@ -155,6 +157,7 @@ def _train_model(
         model: Model,
         synchronization_parameters: ProgressSyncParams,
         dataset_root: str,
+        max_epochs: int,
         device: str | None = None,
     ) -> Model | None:
         """
@@ -208,7 +211,7 @@ def _train_model(
             default_root_dir=model.export_path,
             logger=[trackio, tensorboard],
             devices=[0],  # Only single GPU training is supported for now
-            max_epochs=10,
+            max_epochs=max_epochs,
             callbacks=[GetiInspectProgressCallback(synchronization_parameters)],
             accelerator=training_device,
         )
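For context, a minimal sketch of how the new setting flows from the job payload into the trainer. Only the payload keys mirror what `_run_training_job` reads above; the concrete values, and the assumption that the `Trainer` in the diff is `lightning.pytorch.Trainer` (its keyword arguments match), are illustrative.

```python
from lightning.pytorch import Trainer  # assumption: the Trainer in the diff is Lightning's

# Hypothetical payload for a training job; values are made up for this example.
payload = {
    "model_name": "padim",
    "device": "cuda",
    "dataset_snapshot_id": None,
    "max_epochs": 50,  # optional; _run_training_job falls back to 200 when omitted
}

max_epochs = payload.get("max_epochs", 200)

# Previously max_epochs was hard-coded to 10; it is now taken from the payload.
trainer = Trainer(max_epochs=max_epochs, devices=1, accelerator="auto")
```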
@@ -7,7 +7,7 @@
 import numpy as np
 from loguru import logger

-from pydantic_models import PredictionResponse
+from pydantic_models import PredictionLabel, PredictionResponse


 class Visualizer:
@@ -46,7 +46,7 @@ def overlay_predictions(
     def overlay_anomaly_heatmap(
Contributor: Something for the future, but maybe we should consolidate the visualization methods here and the ones in anomalib. That way we won't have to maintain two separate visualizers.
         base_image: np.ndarray,
         prediction: PredictionResponse,
-        threshold_value: int = 128,
+        threshold: float = 0.5,
         alpha: float = 0.25,
     ) -> np.ndarray:
         """Overlay the anomaly heatmap onto the image.
@@ -58,66 +58,58 @@ def overlay_anomaly_heatmap(
         - Blend onto the base image using alpha
         """
         try:
-            anomaly_map_base64 = prediction.anomaly_map
-            result = base_image.copy()
-            try:
-                anomaly_png_bytes = base64.b64decode(anomaly_map_base64)
-                anomaly_np = np.frombuffer(anomaly_png_bytes, dtype=np.uint8)
-                anomaly_img = cv2.imdecode(anomaly_np, cv2.IMREAD_UNCHANGED)
-            except Exception:
-                return result
+            # Decode anomaly map
+            anomaly_bytes = base64.b64decode(prediction.anomaly_map)
+            anomaly_img = cv2.imdecode(np.frombuffer(anomaly_bytes, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)

             if anomaly_img is None:
-                return result
-
-            try:
-                if anomaly_img.ndim == 3 and anomaly_img.shape[2] > 1:
-                    anomaly_gray = cv2.cvtColor(anomaly_img, cv2.COLOR_BGR2GRAY)
-                else:
-                    anomaly_gray = anomaly_img
+                return base_image

-                if anomaly_gray.dtype != np.uint8:
-                    anomaly_gray = anomaly_gray.astype(np.uint8)
+            # Resize to match base image
+            h, w = base_image.shape[:2]
+            anomaly_gray = cv2.resize(anomaly_img, (w, h))

-                heatmap = cv2.applyColorMap(anomaly_gray, cv2.COLORMAP_JET)
-                heatmap_resized = cv2.resize(heatmap, (result.shape[1], result.shape[0]))
+            # Apply colormap and create threshold mask
+            heatmap = cv2.applyColorMap(anomaly_gray, cv2.COLORMAP_JET)
+            mask = anomaly_gray >= (threshold * 255)

-                mask_gray = cv2.resize(anomaly_gray, (result.shape[1], result.shape[0]))
-                mask_bool = mask_gray >= threshold_value
-
-                masked_heatmap = np.zeros_like(heatmap_resized)
-                try:
-                    masked_heatmap[mask_bool] = heatmap_resized[mask_bool]
-                except Exception as e:
-                    logger.debug(f"Failed to apply heatmap mask: {e}")
+            # Create masked heatmap (only show where above threshold)
+            masked_heatmap = np.zeros_like(heatmap)
+            masked_heatmap[mask] = heatmap[mask]

-                result = cv2.addWeighted(result, 1.0, masked_heatmap, alpha, 0)
-            except Exception as e:
-                logger.debug(f"Failed to overlay heatmap: {e}")
-            return result
+            # Blend onto base image
+            result = base_image.copy()
+            return cv2.addWeighted(result, 1.0, masked_heatmap, alpha, 0)
         except Exception as e:
-            logger.debug(f"Failed in overlay_anomaly_heatmap: {e}")
+            logger.debug(f"Failed to overlay heatmap: {e}")
             return base_image
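To see the refactored flow end to end (decode the base64 PNG, resize to the frame, colormap, keep only pixels above the threshold, alpha-blend), here is a self-contained sketch. The helper name and the synthetic gradient input are illustrative; the real method takes a `PredictionResponse` and reads `prediction.anomaly_map`.

```python
import base64

import cv2
import numpy as np


def overlay_heatmap_sketch(
    base_image: np.ndarray, anomaly_map_b64: str, threshold: float = 0.5, alpha: float = 0.25
) -> np.ndarray:
    """Mirrors the refactored overlay logic, without the PredictionResponse wrapper."""
    anomaly_bytes = base64.b64decode(anomaly_map_b64)
    anomaly_img = cv2.imdecode(np.frombuffer(anomaly_bytes, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
    if anomaly_img is None:
        return base_image

    # Resize the grayscale anomaly map to the base image size
    h, w = base_image.shape[:2]
    anomaly_gray = cv2.resize(anomaly_img, (w, h))

    # Colormap, then keep only pixels whose 0-255 score clears the 0-1 threshold
    heatmap = cv2.applyColorMap(anomaly_gray, cv2.COLORMAP_JET)
    mask = anomaly_gray >= threshold * 255
    masked_heatmap = np.zeros_like(heatmap)
    masked_heatmap[mask] = heatmap[mask]

    # Alpha-blend the masked heatmap onto a copy of the base image
    return cv2.addWeighted(base_image.copy(), 1.0, masked_heatmap, alpha, 0)


# Synthetic example: a 64x64 horizontal gradient encoded as PNG, overlaid on a gray frame
gradient = np.tile(np.linspace(0, 255, 64, dtype=np.uint8), (64, 1))
_, png = cv2.imencode(".png", gradient)
result = overlay_heatmap_sketch(np.full((128, 128, 3), 127, np.uint8), base64.b64encode(png.tobytes()).decode())
```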
     @staticmethod
     def draw_prediction_label(
         base_image: np.ndarray,
         prediction: PredictionResponse,
         *,
-        position: tuple[int, int] = (10, 20),
-        font_scale: float = 2.0,
-        thickness: int = 3,
-        text_color: tuple[int, int, int] = (0, 255, 0),
-        background_color: tuple[int, int, int] = (0, 0, 0),
+        position: tuple[int, int] = (5, 5),
+        font_scale: float = 1.0,
+        thickness: int = 2,
     ) -> np.ndarray:
         """Draw the prediction label with a background rectangle for readability."""
+        alpha = 0.85
+        text_color = (36, 37, 40)
+        green = (139, 174, 70)
+        red = (255, 86, 98)
+        background_color: tuple[int, int, int] = green if prediction.label == PredictionLabel.NORMAL else red
         try:
-            label_text = f"{prediction.label.value} ({prediction.score:.3f})"
+            label_text = f"{prediction.label.value} {int(prediction.score * 100)}%"
             result = base_image.copy()
             font = cv2.FONT_HERSHEY_SIMPLEX
             (text_w, text_h), _ = cv2.getTextSize(label_text, font, font_scale, thickness)
             x, y = position[0], position[1] + text_h
-            cv2.rectangle(result, (x - 8, y - text_h - 8), (x - 8 + text_w + 16, y + 8), background_color, -1)
+            # Create overlay for transparent background
+            overlay = result.copy()
+            cv2.rectangle(overlay, (x - 8, y - text_h - 8), (x - 8 + text_w + 16, y + 8), background_color[::-1], -1)
+            result = cv2.addWeighted(result, 1.0 - alpha, overlay, alpha, 0)

             cv2.putText(result, label_text, (x, y), font, font_scale, text_color, thickness, cv2.LINE_AA)
             return result
         except Exception as e:
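The semi-transparent label background in the new code is achieved by drawing the filled rectangle on a copy of the frame and alpha-blending it back before the text is drawn on top. A minimal sketch of that technique follows; the helper name and the `is_normal` flag are illustrative stand-ins for the real method's `prediction.label == PredictionLabel.NORMAL` check.

```python
import cv2
import numpy as np


def draw_label_sketch(
    image: np.ndarray,
    text: str,
    is_normal: bool,
    position: tuple[int, int] = (5, 5),
    font_scale: float = 1.0,
    thickness: int = 2,
) -> np.ndarray:
    """Label with a semi-transparent colored background, mirroring the diff."""
    alpha = 0.85
    text_color = (36, 37, 40)
    background = (139, 174, 70) if is_normal else (255, 86, 98)  # green for normal, red otherwise (RGB)

    result = image.copy()
    font = cv2.FONT_HERSHEY_SIMPLEX
    (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, thickness)
    x, y = position[0], position[1] + text_h

    # Draw the filled rectangle on a copy, then blend it back for transparency.
    # The color tuple is reversed (RGB -> BGR), as in the diff's background_color[::-1].
    overlay = result.copy()
    cv2.rectangle(overlay, (x - 8, y - text_h - 8), (x - 8 + text_w + 16, y + 8), background[::-1], -1)
    result = cv2.addWeighted(result, 1.0 - alpha, overlay, alpha, 0)

    cv2.putText(result, text, (x, y), font, font_scale, text_color, thickness, cv2.LINE_AA)
    return result


# Example: a "NORMAL 97%"-style label drawn on a blank frame
labeled = draw_label_sketch(np.full((240, 320, 3), 255, np.uint8), "NORMAL 97%", is_normal=True)
```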
How are we planning on dealing with the rest of the configurable parameters?

We don't, for the time being. We're likely going to remove this once we have a proper design for configurable parameters.