Merged
Changes from 4 commits
1 change: 1 addition & 0 deletions application/backend/src/pydantic_models/job.py
@@ -64,3 +64,4 @@ class TrainJobPayload(BaseModel):
model_name: str
device: str | None = Field(default=None)
dataset_snapshot_id: str | None = Field(default=None) # used because UUID is not JSON serializable
max_epochs: int | None = Field(default=None, ge=1)
Contributor:
How are we planning on dealing with the rest of the configurable parameters?

Contributor Author:
We don't for the time being. We're likely going to remove this once we have a proper design for configurable parameters.

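For reference, a minimal sketch of how the new field behaves; only the fields visible in this hunk are included, and the example model name is made up.

```python
from pydantic import BaseModel, Field, ValidationError


class TrainJobPayload(BaseModel):
    model_name: str
    device: str | None = Field(default=None)
    dataset_snapshot_id: str | None = Field(default=None)  # used because UUID is not JSON serializable
    max_epochs: int | None = Field(default=None, ge=1)


# Omitting max_epochs keeps the previous behaviour; the training service applies its own default.
payload = TrainJobPayload(model_name="padim")
assert payload.max_epochs is None

# ge=1 rejects non-positive values at validation time.
try:
    TrainJobPayload(model_name="padim", max_epochs=0)
except ValidationError as err:
    print(err)  # max_epochs: Input should be greater than or equal to 1
```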
13 changes: 7 additions & 6 deletions application/backend/src/services/model_service.py
@@ -253,6 +253,7 @@ async def predict_image(
image_bytes: bytes,
cached_models: dict[UUID, OpenVINOInferencer] | None = None,
device: str | None = None,
is_bgr: bool = False,
) -> PredictionResponse:
"""Run prediction on an image using the specified model.

@@ -264,6 +265,7 @@ async def predict_image(
image_bytes: Raw image bytes from uploaded file
cached_models: Optional dict to cache loaded models (for performance)
device: Optional string indicating the device to use for inference
is_bgr: Whether the image is in BGR format

Returns:
PredictionResponse: Structured prediction results
@@ -285,20 +287,19 @@ async def predict_image(

# Run entire prediction pipeline in a single thread
# This includes image processing, model inference, and result processing
response_data = await asyncio.to_thread(cls._run_prediction_pipeline, inference_model, image_bytes)
response_data = await asyncio.to_thread(cls._run_prediction_pipeline, inference_model, image_bytes, is_bgr)

return PredictionResponse(**response_data)

@staticmethod
def _run_prediction_pipeline(inference_model: OpenVINOInferencer, image_bytes: bytes) -> dict:
def _run_prediction_pipeline(inference_model: OpenVINOInferencer, image_bytes: bytes, is_bgr: bool = False) -> dict:
"""Run the complete prediction pipeline in a single thread."""
# Process image
npd = np.frombuffer(image_bytes, np.uint8)
bgr_image = cv2.imdecode(npd, -1)
if bgr_image is None:
image = cv2.imdecode(npd, -1)
if image is None:
raise ValueError("Failed to decode image")

numpy_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
numpy_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) if is_bgr else image

# Run prediction
pred = inference_model.predict(numpy_image)
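Condensed, the updated decode step reads as below; a sketch in which the standalone helper name is illustrative and the rest mirrors the hunk.

```python
import cv2
import numpy as np


def decode_image(image_bytes: bytes, is_bgr: bool = False) -> np.ndarray:
    """Sketch of the updated decode step: the BGR->RGB swap is now applied only
    when the caller flags the bytes as BGR (e.g. frames sent by the inference
    worker), instead of unconditionally."""
    npd = np.frombuffer(image_bytes, np.uint8)
    image = cv2.imdecode(npd, -1)  # decodes as-is; colour PNGs/JPEGs come out BGR
    if image is None:
        raise ValueError("Failed to decode image")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) if is_bgr else image
```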
5 changes: 4 additions & 1 deletion application/backend/src/services/training_service.py
@@ -67,6 +67,7 @@ async def _run_training_job(cls, job: Job, job_service: JobService) -> Model | N
device = job.payload.get("device")
snapshot_id_ = job.payload.get("dataset_snapshot_id")
snapshot_id = UUID(snapshot_id_) if snapshot_id_ else None
max_epochs = job.payload.get("max_epochs", 200)

if model_name is None:
raise ValueError(f"Job {job.id} payload must contain 'model_name'")
@@ -105,6 +106,7 @@ async def _run_training_job(cls, job: Job, job_service: JobService) -> Model | N
model=model,
device=device,
synchronization_parameters=synchronization_parameters,
max_epochs=max_epochs,

Should we add this to the model training dialog? cc @ashwinvaidya17, @MarkRedeman

Contributor:
It would be nice if we could already expose this (perhaps next to the training-device picker from the new designs).

However, there are some edge cases we would need to deal with: some models overwrite max epochs to 1 since they only need to do a single pass.
We could hardcode the UI to not show the input for those models, but that might get messy.

Contributor:
Let's expose this and just have a note saying something like, "some models only extract features, so max_epochs will be overridden for those models".

dataset_root=dataset_root,
)

@@ -155,6 +157,7 @@ def _train_model(
model: Model,
synchronization_parameters: ProgressSyncParams,
dataset_root: str,
max_epochs: int,
device: str | None = None,
) -> Model | None:
"""
@@ -208,7 +211,7 @@ def _train_model(
default_root_dir=model.export_path,
logger=[trackio, tensorboard],
devices=[0], # Only single GPU training is supported for now
max_epochs=10,
max_epochs=max_epochs,
callbacks=[GetiInspectProgressCallback(synchronization_parameters)],
accelerator=training_device,
)
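Pieced together, the new max_epochs flow is roughly the sketch below; an illustration only — loggers, callbacks, and cancellation handling are omitted, the helper name is made up, and the `Trainer` import assumes Lightning 2.x.

```python
from lightning.pytorch import Trainer


def build_trainer(payload: dict, export_path: str, training_device: str) -> Trainer:
    # The service falls back to 200 epochs when the job payload omits max_epochs.
    max_epochs = payload.get("max_epochs", 200)
    return Trainer(
        default_root_dir=export_path,
        devices=[0],              # only single-GPU training is supported for now
        max_epochs=max_epochs,    # previously hardcoded to 10
        accelerator=training_device,
    )
```

As noted in the thread above, some models may still override this value internally, since they only need a single pass over the data.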
74 changes: 33 additions & 41 deletions application/backend/src/utils/visualization.py
@@ -7,7 +7,7 @@
import numpy as np
from loguru import logger

from pydantic_models import PredictionResponse
from pydantic_models import PredictionLabel, PredictionResponse


class Visualizer:
@@ -46,7 +46,7 @@ def overlay_predictions(
def overlay_anomaly_heatmap(
Contributor:
Something for the future, but maybe we should consolidate the visualization methods here with the ones in anomalib. That way we won't have to maintain two separate visualizers.

base_image: np.ndarray,
prediction: PredictionResponse,
threshold_value: int = 128,
threshold: float = 0.5,
alpha: float = 0.25,
) -> np.ndarray:
"""Overlay the anomaly heatmap onto the image.
@@ -58,66 +58,58 @@ def overlay_anomaly_heatmap(
- Blend onto the base image using alpha
"""
try:
anomaly_map_base64 = prediction.anomaly_map
result = base_image.copy()
try:
anomaly_png_bytes = base64.b64decode(anomaly_map_base64)
anomaly_np = np.frombuffer(anomaly_png_bytes, dtype=np.uint8)
anomaly_img = cv2.imdecode(anomaly_np, cv2.IMREAD_UNCHANGED)
except Exception:
return result
# Decode anomaly map
anomaly_bytes = base64.b64decode(prediction.anomaly_map)
anomaly_img = cv2.imdecode(np.frombuffer(anomaly_bytes, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)

if anomaly_img is None:
return result

try:
if anomaly_img.ndim == 3 and anomaly_img.shape[2] > 1:
anomaly_gray = cv2.cvtColor(anomaly_img, cv2.COLOR_BGR2GRAY)
else:
anomaly_gray = anomaly_img
return base_image

if anomaly_gray.dtype != np.uint8:
anomaly_gray = anomaly_gray.astype(np.uint8)
# Resize to match base image
h, w = base_image.shape[:2]
anomaly_gray = cv2.resize(anomaly_img, (w, h))

heatmap = cv2.applyColorMap(anomaly_gray, cv2.COLORMAP_JET)
heatmap_resized = cv2.resize(heatmap, (result.shape[1], result.shape[0]))
# Apply colormap and create threshold mask
heatmap = cv2.applyColorMap(anomaly_gray, cv2.COLORMAP_JET)
mask = anomaly_gray >= (threshold * 255)

mask_gray = cv2.resize(anomaly_gray, (result.shape[1], result.shape[0]))
mask_bool = mask_gray >= threshold_value

masked_heatmap = np.zeros_like(heatmap_resized)
try:
masked_heatmap[mask_bool] = heatmap_resized[mask_bool]
except Exception as e:
logger.debug(f"Failed to apply heatmap mask: {e}")
# Create masked heatmap (only show where above threshold)
masked_heatmap = np.zeros_like(heatmap)
masked_heatmap[mask] = heatmap[mask]

result = cv2.addWeighted(result, 1.0, masked_heatmap, alpha, 0)
except Exception as e:
logger.debug(f"Failed to overlay heatmap: {e}")
return result
# Blend onto base image
result = base_image.copy()
return cv2.addWeighted(result, 1.0, masked_heatmap, alpha, 0)
except Exception as e:
logger.debug(f"Failed in overlay_anomaly_heatmap: {e}")
logger.debug(f"Failed to overlay heatmap: {e}")
return base_image

@staticmethod
def draw_prediction_label(
base_image: np.ndarray,
prediction: PredictionResponse,
*,
position: tuple[int, int] = (10, 20),
font_scale: float = 2.0,
thickness: int = 3,
text_color: tuple[int, int, int] = (0, 255, 0),
background_color: tuple[int, int, int] = (0, 0, 0),
position: tuple[int, int] = (5, 5),
font_scale: float = 1.0,
thickness: int = 2,
) -> np.ndarray:
"""Draw the prediction label with a background rectangle for readability."""
alpha = 0.85
text_color = (36, 37, 40)
green = (139, 174, 70)
red = (255, 86, 98)
background_color: tuple[int, int, int] = green if prediction.label == PredictionLabel.NORMAL else red
try:
label_text = f"{prediction.label.value} ({prediction.score:.3f})"
label_text = f"{prediction.label.value} {int(prediction.score * 100)}%"
result = base_image.copy()
font = cv2.FONT_HERSHEY_SIMPLEX
(text_w, text_h), _ = cv2.getTextSize(label_text, font, font_scale, thickness)
x, y = position[0], position[1] + text_h
cv2.rectangle(result, (x - 8, y - text_h - 8), (x - 8 + text_w + 16, y + 8), background_color, -1)
# Create overlay for transparent background
overlay = result.copy()
cv2.rectangle(overlay, (x - 8, y - text_h - 8), (x - 8 + text_w + 16, y + 8), background_color[::-1], -1)
result = cv2.addWeighted(result, 1.0 - alpha, overlay, alpha, 0)

cv2.putText(result, label_text, (x, y), font, font_scale, text_color, thickness, cv2.LINE_AA)
return result
except Exception as e:
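Because the hunk above interleaves removed and added lines, here is the added overlay path in one piece; a sketch that takes the base64 anomaly map directly instead of a PredictionResponse, but otherwise follows the new code.

```python
import base64

import cv2
import numpy as np


def overlay_anomaly_heatmap(base_image: np.ndarray, anomaly_map_b64: str,
                            threshold: float = 0.5, alpha: float = 0.25) -> np.ndarray:
    # Decode the base64 PNG directly as grayscale.
    anomaly_bytes = base64.b64decode(anomaly_map_b64)
    anomaly_img = cv2.imdecode(np.frombuffer(anomaly_bytes, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
    if anomaly_img is None:
        return base_image

    # Resize to the base image, colour-map, and keep only pixels above the threshold.
    h, w = base_image.shape[:2]
    anomaly_gray = cv2.resize(anomaly_img, (w, h))
    heatmap = cv2.applyColorMap(anomaly_gray, cv2.COLORMAP_JET)
    mask = anomaly_gray >= (threshold * 255)
    masked_heatmap = np.zeros_like(heatmap)
    masked_heatmap[mask] = heatmap[mask]

    # Blend the masked heatmap onto a copy of the base image.
    return cv2.addWeighted(base_image.copy(), 1.0, masked_heatmap, alpha, 0)
```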
1 change: 1 addition & 0 deletions application/backend/src/workers/inference.py
@@ -190,6 +190,7 @@ async def _run_inference(self, image_bytes: bytes) -> Any | None:
self._loaded_model.model,
image_bytes,
self._cached_models, # type: ignore[arg-type]
is_bgr=True,
)
except Exception as e:
logger.error(f"Inference failed: {e}", exc_info=True)
@@ -231,7 +231,7 @@ def test_train_pending_job_cleanup_on_failure(
with patch("services.training_service.asyncio.to_thread") as mock_to_thread:
# Mock the training to succeed first, setting export_path, then fail
def mock_train_model(
cls, model, synchronization_parameters: ProgressSyncParams, device=None, dataset_root=None
cls, model, synchronization_parameters: ProgressSyncParams, device=None, dataset_root=None, max_epochs=1
):
model.export_path = "/path/to/model"
raise Exception("Training failed")
@@ -263,7 +263,7 @@ def test_train_model_success(
# Call the method
with patch.object(TrainingService, "_compute_export_size", return_value=123):
result = TrainingService._train_model(
fxt_model, synchronization_parameters=ProgressSyncParams(), dataset_root="/tmp/dataset"
fxt_model, synchronization_parameters=ProgressSyncParams(), dataset_root="/tmp/dataset", max_epochs=42
)

# Verify the result
@@ -282,7 +282,7 @@ def test_train_model_success(
assert call_args[1]["default_root_dir"] == "/path/to/model"
assert "logger" in call_args[1]
assert len(call_args[1]["logger"]) == 2 # trackio and tensorboard
assert call_args[1]["max_epochs"] == 10
assert call_args[1]["max_epochs"] == 42

fxt_mock_anomalib_components["engine"].fit.assert_called_once_with(
model=fxt_mock_anomalib_components["anomalib_model"], datamodule=fxt_mock_anomalib_components["folder"]
@@ -332,7 +332,7 @@ def test_train_model_cancelled_before_start(
sync_params.set_cancel_training_event()

result = TrainingService._train_model(
fxt_model, synchronization_parameters=sync_params, dataset_root="/tmp/dataset"
fxt_model, synchronization_parameters=sync_params, dataset_root="/tmp/dataset", max_epochs=1
)

assert result is None