""" Unified SegEarth pipeline: OV, OV-2 (CLIP-based), OV-3 (SAM3-based). Training-free open-vocabulary segmentation for remote sensing. """ import contextlib from pathlib import Path from typing import List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from PIL import Image from torchvision import transforms try: from .upsamplers import get_upsampler, FEATUP_CHECKPOINTS except ImportError: from upsamplers import get_upsampler, FEATUP_CHECKPOINTS try: from .prompts.imagenet_template import openai_imagenet_template, sub_imagenet_template except ImportError: openai_imagenet_template = [ lambda c: f"a photo of a {c}.", lambda c: f"a bad photo of a {c}.", lambda c: f"a photo of many {c}.", lambda c: f"a photo of the large {c}.", lambda c: f"a photo of the small {c}.", ] sub_imagenet_template = openai_imagenet_template[:7] def get_cls_idx(path: Union[str, Path]) -> Tuple[List[str], List[int]]: """Parse class list file (one line per class, comma-separated synonyms).""" path = Path(path) with open(path) as f: lines = f.readlines() class_names, class_indices = [], [] for idx, line in enumerate(lines): names_i = [n.strip() for n in line.strip().split(",")] class_names.extend(names_i) class_indices.extend([idx] * len(names_i)) return class_names, class_indices class SegEarthPipelineCLIP: """ CLIP-based SegEarth pipeline (OV, OV-2). Uses transformers.CLIPModel + SimFeatUp for dense prediction. """ def __init__( self, model_id: str = "openai/clip-vit-base-patch16", featup_model: str = "jbu_one", featup_weights_path: Optional[Union[str, Path]] = None, class_names_path: Optional[Union[str, Path]] = None, device: str = "cuda", dtype: torch.dtype = torch.float16, cls_token_lambda: float = -0.3, logit_scale: float = 50.0, prob_thd: float = 0.0, bg_idx: int = 0, slide_crop: int = 0, slide_stride: int = 112, template_set: str = "openai", ): from transformers import CLIPModel, CLIPProcessor self.device = device self.dtype = dtype self.cls_token_lambda = cls_token_lambda self.logit_scale = logit_scale self.prob_thd = prob_thd self.bg_idx = bg_idx self.slide_crop = slide_crop self.slide_stride = slide_stride self.output_cls_token = cls_token_lambda != 0 self.templates = sub_imagenet_template if template_set == "sub" else openai_imagenet_template self.clip = CLIPModel.from_pretrained(model_id).to(device).to(dtype).eval() try: self.processor = CLIPProcessor.from_pretrained(model_id) except Exception: # Fallback: use tokenizer only (CLIPProcessor can trigger mistral_common compat in some envs) from transformers import CLIPTokenizer self.processor = None self._tokenizer = CLIPTokenizer.from_pretrained(model_id) self.patch_size = 16 self.feat_dim = 512 # Resolve featup path: self-contained repo only (OV/OV-2/weights/featup) ckpt_name = FEATUP_CHECKPOINTS.get(featup_model, "").split("/")[-1] repo_dir = Path(__file__).parent _candidates = [ Path(featup_weights_path) if featup_weights_path else None, repo_dir / "OV" / "weights" / "featup" / ckpt_name, repo_dir / "OV-2" / "weights" / "featup" / ckpt_name, repo_dir / "weights" / "featup" / ckpt_name, ] featup_path = next((p for p in _candidates if p and p.exists()), None) self.use_featup = featup_path is not None and featup_path.exists() upsampler_name = "bilinear" if not self.use_featup else featup_model.replace("_maskclip", "") self.upsampler = get_upsampler(upsampler_name, self.feat_dim).to(device).to(dtype).eval() if self.use_featup: ckpt = torch.load(featup_path, map_location="cpu") sd = ckpt.get("state_dict", ckpt) weights = {k[10:]: v for k, v in sd.items() if k.startswith("upsampler.")} self.upsampler.load_state_dict(weights, strict=True) repo_dir = Path(__file__).parent cls_path = class_names_path or (repo_dir / "configs" / "cls_openearthmap_sar.txt") cls_path = Path(cls_path) if cls_path.exists(): self.class_names, self.class_indices = get_cls_idx(cls_path) else: self.class_names = ["building", "road", "water", "vegetation", "bare soil"] self.class_indices = list(range(len(self.class_names))) self.num_classes = max(self.class_indices) + 1 self.num_queries = len(self.class_indices) self.query_idx = torch.tensor(self.class_indices, dtype=torch.int64, device=device) self._build_query_features() def _build_query_features(self): query_features = [] with torch.no_grad(): tokenizer = getattr(self, "_tokenizer", None) or (self.processor.tokenizer if self.processor else None) for name in self.class_names: texts = [t(name) for t in self.templates] inputs = tokenizer(text=texts, return_tensors="pt", padding=True, truncation=True) inputs = {k: v.to(self.device) for k, v in inputs.items()} out = self.clip.get_text_features(**inputs) if hasattr(out, "shape"): feat_t = out elif hasattr(out, "pooler_output") and out.pooler_output is not None: feat_t = out.pooler_output else: feat_t = out.last_hidden_state.mean(1) feat = feat_t.mean(0) / feat_t.mean(0).norm() query_features.append(feat.unsqueeze(0)) self.query_features = torch.cat(query_features, dim=0).to(self.dtype) def _encode_image_patches(self, pixel_values: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: out = self.clip.vision_model(pixel_values) hidden = out.last_hidden_state proj = self.clip.visual_projection.weight patch_tokens = hidden[:, 1:, :] patch_feats = patch_tokens @ proj.T cls_token = None if self.output_cls_token: cls_tok = hidden[:, 0:1, :] cls_token = (cls_tok @ proj.T).squeeze(1) cls_token = F.normalize(cls_token, dim=-1) return patch_feats, cls_token def _preprocess_image(self, image: Image.Image, size: Optional[int] = 224, keep_size: bool = False) -> torch.Tensor: t = transforms.Compose([ transforms.ToTensor(), transforms.Normalize( [0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711], ), ]) x = t(image.convert("RGB")) if not keep_size and size: x = transforms.functional.resize(x, (size, size)) return x.unsqueeze(0).to(self.device).to(self.dtype) def _compute_padsize(self, H: int, W: int) -> Tuple[int, int, int, int]: l, r, t, b = 0, 0, 0, 0 if W % self.patch_size: lr = self.patch_size - (W % self.patch_size) l = lr // 2 r = lr - l if H % self.patch_size: tb = self.patch_size - (H % self.patch_size) t = tb // 2 b = tb - t return l, r, t, b def _forward_single_crop(self, img_tensor: torch.Tensor) -> torch.Tensor: B, C, H, W = img_tensor.shape patch_h, patch_w = H // self.patch_size, W // self.patch_size patch_feats, cls_token = self._encode_image_patches(img_tensor) patch_feats = patch_feats.permute(0, 2, 1).view(B, self.feat_dim, patch_h, patch_w) patch_feats = patch_feats.to(self.dtype) img_tensor = img_tensor.to(self.dtype) patch_feats = self.upsampler(patch_feats, img_tensor) out_h, out_w = H, W patch_feats = patch_feats.view(B, self.feat_dim, -1).permute(0, 2, 1) patch_feats = F.normalize(patch_feats, dim=-1) logits = patch_feats @ self.query_features.T if self.output_cls_token and cls_token is not None: cls_logits = cls_token @ self.query_features.T logits = logits + cls_logits.unsqueeze(1) * self.cls_token_lambda logits = logits.permute(0, 2, 1).view(B, self.num_queries, out_h, out_w) return logits[0] def _forward_slide(self, img_tensor: torch.Tensor, ori_shape: Tuple[int, int]) -> torch.Tensor: B, _, h_img, w_img = img_tensor.shape stride = (self.slide_stride, self.slide_stride) crop = (self.slide_crop, self.slide_crop) h_stride, w_stride = stride h_crop, w_crop = crop h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 preds = img_tensor.new_zeros((B, self.num_queries, h_img, w_img)) count_mat = img_tensor.new_zeros((B, 1, h_img, w_img)) for h_idx in range(h_grids): for w_idx in range(w_grids): y1 = h_idx * h_stride x1 = w_idx * w_stride y2 = min(y1 + h_crop, h_img) x2 = min(x1 + w_crop, w_img) y1 = max(y2 - h_crop, 0) x1 = max(x2 - w_crop, 0) crop_img = img_tensor[:, :, y1:y2, x1:x2] H, W = crop_img.shape[2:] l, r, t, b = self._compute_padsize(H, W) if any([l, r, t, b]): crop_img = F.pad(crop_img, (l, r, t, b)) crop_logits = self._forward_single_crop(crop_img) if any([l, r, t, b]): crop_logits = crop_logits[:, t : t + H, l : l + W] pad_crop = F.pad( crop_logits.unsqueeze(0), (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2)), ) preds += pad_crop count_mat[:, :, y1:y2, x1:x2] += 1 preds = preds / count_mat.clamp(min=1) logits = F.interpolate(preds, size=ori_shape, mode="bilinear") return logits[0] def _postprocess(self, logits: torch.Tensor) -> torch.Tensor: logits = logits * self.logit_scale probs = logits.softmax(0) if self.num_classes != self.num_queries: cls_idx = F.one_hot(self.query_idx, self.num_classes) cls_idx = cls_idx.T.view(self.num_classes, self.num_queries, 1, 1) probs = (probs.unsqueeze(0) * cls_idx).max(1)[0] seg_pred = probs.argmax(0, keepdim=True) if self.prob_thd > 0: max_prob = probs.max(0, keepdim=True)[0] seg_pred[max_prob < self.prob_thd] = self.bg_idx return seg_pred.squeeze(0) @torch.no_grad() def __call__(self, image: Union[Image.Image, torch.Tensor], return_logits: bool = False) -> torch.Tensor: if isinstance(image, Image.Image): use_slide = self.slide_crop > 0 keep_size = use_slide img_tensor = self._preprocess_image(image, size=224, keep_size=keep_size) else: img_tensor = image.to(self.device).to(self.dtype) if img_tensor.dim() == 3: img_tensor = img_tensor.unsqueeze(0) B, C, H, W = img_tensor.shape ori_shape = (H, W) use_slide = self.slide_crop > 0 and (H > self.slide_crop or W > self.slide_crop) if use_slide: logits = self._forward_slide(img_tensor, ori_shape) else: l, r, t, b = self._compute_padsize(H, W) if any([l, r, t, b]): img_tensor = F.pad(img_tensor, (l, r, t, b)) out_h, out_w = img_tensor.shape[2], img_tensor.shape[3] else: out_h, out_w = H, W logits = self._forward_single_crop(img_tensor) if any([l, r, t, b]): logits = logits[:, t : t + H, l : l + W] if (out_h, out_w) != ori_shape: logits = F.interpolate(logits.unsqueeze(0), size=ori_shape, mode="bilinear").squeeze(0) if return_logits: if self.num_classes != self.num_queries: cls_idx = F.one_hot(self.query_idx, self.num_classes) cls_idx = cls_idx.T.view(self.num_classes, self.num_queries, 1, 1) logits = (logits.unsqueeze(0) * cls_idx).max(1)[0] return logits return self._postprocess(logits) class SegEarthPipelineSAM3: """ SAM3-based SegEarth pipeline (OV-3). Uses sam3 package for open-vocabulary segmentation. Requires: pip install sam3 (or transformers>=4.45 for Sam3Model) """ def __init__( self, model_id: str = "facebook/sam3", local_checkpoint: Optional[Union[str, Path]] = None, class_names_path: Optional[Union[str, Path]] = None, device: str = "cuda", prob_thd: float = 0.0, bg_idx: int = 0, slide_crop: int = 0, slide_stride: int = 112, confidence_threshold: float = 0.5, use_sem_seg: bool = True, use_presence_score: bool = True, use_transformer_decoder: bool = True, ): self.device = device self.prob_thd = prob_thd self.bg_idx = bg_idx self.slide_crop = slide_crop self.slide_stride = slide_stride self.confidence_threshold = confidence_threshold self.use_sem_seg = use_sem_seg self.use_presence_score = use_presence_score self.use_transformer_decoder = use_transformer_decoder # Workaround for cuDNN "No execution plans support the graph" with SDPA if device == "cuda": if hasattr(torch.backends.cuda, "enable_flash_sdp"): torch.backends.cuda.enable_flash_sdp(False) torch.backends.cuda.enable_mem_efficient_sdp(False) if hasattr(torch.backends.cuda, "enable_math_sdp"): torch.backends.cuda.enable_math_sdp(True) try: from sam3 import build_sam3_image_model from sam3.model.sam3_image_processor import Sam3Processor except ImportError: raise ImportError( "SegEarth OV-3 requires the sam3 package. Install from: " "https://github.com/facebookresearch/sam3 or use transformers.Sam3Model.from_pretrained('facebook/sam3')" ) ckpt_path = Path(local_checkpoint) if local_checkpoint else None if ckpt_path and not ckpt_path.is_absolute(): ckpt_path = Path(__file__).parent / "OV-3" / ckpt_path use_safetensors = ckpt_path and str(ckpt_path).endswith(".safetensors") and ckpt_path.exists() use_pt = ckpt_path and (str(ckpt_path).endswith(".pt") or str(ckpt_path).endswith(".bin")) and ckpt_path.exists() if use_safetensors: self.model = build_sam3_image_model(checkpoint_path=None, load_from_HF=False, device=device) from safetensors.torch import load_file state_dict = load_file(str(ckpt_path)) # HF model.safetensors uses "detector_model." prefix; sam3 expects "detector." -> stripped state_dict = {k.replace("detector_model.", ""): v for k, v in state_dict.items()} self.model.load_state_dict(state_dict, strict=False) elif use_pt: self.model = build_sam3_image_model(checkpoint_path=str(ckpt_path), load_from_HF=False, device=device) else: self.model = build_sam3_image_model(checkpoint_path=None, load_from_HF=True, device=device) self.processor = Sam3Processor(self.model, confidence_threshold=confidence_threshold, device=device) repo_dir = Path(__file__).parent cls_path = class_names_path or (repo_dir / "configs" / "cls_openearthmap_sar.txt") cls_path = Path(cls_path) if cls_path.exists(): self.class_names, self.class_indices = get_cls_idx(cls_path) else: self.class_names = ["building", "road", "water", "vegetation", "bare soil"] self.class_indices = list(range(len(self.class_names))) self.num_classes = max(self.class_indices) + 1 self.num_queries = len(self.class_indices) self.query_idx = torch.tensor(self.class_indices, dtype=torch.int64, device=device) def _inference_single_view(self, image: Image.Image) -> torch.Tensor: w, h = image.size seg_logits = torch.zeros((self.num_queries, h, w), device=self.device) sdp_ctx = ( torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False, enable_cudnn=False) if self.device == "cuda" and hasattr(torch.backends.cuda, "sdp_kernel") else contextlib.nullcontext() ) with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16), sdp_ctx: inference_state = self.processor.set_image(image) for query_idx, query_word in enumerate(self.class_names): self.processor.reset_all_prompts(inference_state) inference_state = self.processor.set_text_prompt(state=inference_state, prompt=query_word) if self.use_transformer_decoder and inference_state.get("masks_logits") is not None: inst_len = inference_state["masks_logits"].shape[0] for inst_id in range(inst_len): instance_logits = inference_state["masks_logits"][inst_id].squeeze() instance_score = inference_state["object_score"][inst_id] if instance_logits.shape != (h, w): instance_logits = F.interpolate( instance_logits.view(1, 1, *instance_logits.shape), size=(h, w), mode="bilinear", align_corners=False ).squeeze() seg_logits[query_idx] = torch.max(seg_logits[query_idx], instance_logits * instance_score) if self.use_sem_seg and inference_state.get("semantic_mask_logits") is not None: semantic_logits = inference_state["semantic_mask_logits"] if semantic_logits.shape != (h, w): semantic_logits = F.interpolate( semantic_logits.view(1, 1, *semantic_logits.shape) if semantic_logits.dim() == 2 else semantic_logits.unsqueeze(0), size=(h, w), mode="bilinear", align_corners=False ).squeeze() seg_logits[query_idx] = torch.max(seg_logits[query_idx], semantic_logits) if self.use_presence_score and inference_state.get("presence_score") is not None: seg_logits[query_idx] = seg_logits[query_idx] * inference_state["presence_score"] return seg_logits def slide_inference(self, image: Image.Image) -> torch.Tensor: w_img, h_img = image.size stride = (self.slide_stride, self.slide_stride) crop = (self.slide_crop, self.slide_crop) h_stride, w_stride = stride h_crop, w_crop = crop h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 preds = torch.zeros((self.num_queries, h_img, w_img), device=self.device) count_mat = torch.zeros((1, h_img, w_img), device=self.device) for h_idx in range(h_grids): for w_idx in range(w_grids): y1 = h_idx * h_stride x1 = w_idx * w_stride y2 = min(y1 + h_crop, h_img) x2 = min(x1 + w_crop, w_img) y1 = max(y2 - h_crop, 0) x1 = max(x2 - w_crop, 0) crop_img = image.crop((x1, y1, x2, y2)) crop_seg = self._inference_single_view(crop_img) preds[:, y1:y2, x1:x2] += crop_seg count_mat[:, y1:y2, x1:x2] += 1 return preds / count_mat.clamp(min=1) @torch.no_grad() def __call__(self, image: Union[Image.Image, torch.Tensor]) -> torch.Tensor: if isinstance(image, torch.Tensor): image = transforms.functional.to_pil_image(image) image = image.convert("RGB") if self.slide_crop > 0 and (image.size[0] > self.slide_crop or image.size[1] > self.slide_crop): seg_logits = self.slide_inference(image) else: seg_logits = self._inference_single_view(image) if self.num_classes != self.num_queries: cls_idx = F.one_hot(self.query_idx, self.num_classes) cls_idx = cls_idx.T.view(self.num_classes, self.num_queries, 1, 1) seg_logits = (seg_logits.unsqueeze(0) * cls_idx).max(1)[0] seg_pred = seg_logits.argmax(0, keepdim=True) if self.prob_thd > 0: max_prob = seg_logits.max(0, keepdim=True)[0] seg_pred[max_prob < self.prob_thd] = self.bg_idx return seg_pred.squeeze(0) def SegEarthPipeline( variant: str = "OV-2", model_id: Optional[str] = None, **kwargs, ): """ Factory for SegEarth pipelines. Load from self-contained subfolders OV/, OV-2/, OV-3/. Args: variant: One of OV, OV-2, OV-3 (or legacy: ov_clip_openai_vitb16, ov2_alignearth_sar, ov3_sam3) model_id: Override HF model ID **kwargs: Passed to pipeline constructor """ import json repo_dir = Path(__file__).parent variant_map = {"ov_clip_openai_vitb16": "OV", "ov2_alignearth_sar": "OV-2", "ov3_sam3": "OV-3"} subfolder = variant_map.get(variant, variant) sub_path = repo_dir / subfolder / "pipeline.py" if sub_path.exists(): import importlib.util spec = importlib.util.spec_from_file_location(f"segearth_{subfolder}", sub_path) mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) return mod.load(**kwargs) if model_id is None else mod.load(model_id=model_id, **kwargs) # Fallback: legacy flat config if model_id is None: model_id = "BiliSakura/AlignEarth-SAR-ViT-B-16" if variant in ("ov3_sam3", "OV-3"): return SegEarthPipelineSAM3(model_id=model_id or "facebook/sam3", **kwargs) return SegEarthPipelineCLIP(model_id=model_id, **kwargs)