Extending
All base classes and relevant types are importable from wsi_patching.custom_component.
from wsi_patching.custom_component import (
    Stage,
    StreamWriterBase,
    MaterializeWriterBase,
    PipelineContext,
    # types you may need for annotations:
    CollatedPatchBatch,
    EncodedCollatedPatchBatch,
    Slide,
    SlideWithROIs,
)
Custom Stage
A Stage is a callable that accepts an iterable and yields transformed items. Implement it as a generator.
from typing import Iterable
from wsi_patching.custom_component import Stage, PipelineContext, CollatedPatchBatch
class MyFilter(Stage):
    input_type = CollatedPatchBatch   # checked against the preceding stage's output_type
    output_type = CollatedPatchBatch  # checked against the succeeding stage's input_type
    def __init__(self, threshold: float = 0.5):
        self.threshold = threshold
    def export_context(self, ctx: PipelineContext) -> None:
        # Optional: write values that other stages may need
        ctx["my_filter_threshold"] = self.threshold
    def validate(self) -> None:
        # Optional: assert preconditions using the fully-populated context
        self.ctx.require_key("use_gpu")  # raises KeyError with a clear message if missing
    def __call__(self, it: Iterable[CollatedPatchBatch]) -> Iterable[CollatedPatchBatch]:
        for batch in it:
            filtered = self._apply(batch)
            if filtered is not None:
                yield filtered
    def _apply(self, batch: CollatedPatchBatch):
        ...
Key rules:
- input_type and output_type are class attributes used for preflight type checking. Set them to the exact types from wsi_patching.custom_component (or object to opt out of checking).
- export_context and validate are optional; omit them if you have nothing to export or validate.
- validate runs after the full context is populated, so self.ctx["key"] is safe to read there.
- The __call__ method must be a generator (use yield). Do not return a list.
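The generator rule can be illustrated without the library. The sketch below uses a plain Python class in place of a Stage subclass and integers in place of CollatedPatchBatch; EvenFilter and source are hypothetical names invented for this illustration. It shows that a generator __call__ pulls items from upstream one at a time instead of draining the whole iterable first.

```python
from typing import Iterable, Iterator, List


class EvenFilter:
    """Stand-in for a Stage subclass; not part of wsi_patching."""

    def __call__(self, it: Iterable[int]) -> Iterator[int]:
        for item in it:
            if item % 2 == 0:  # drop items that fail the filter
                yield item


def source(log: List[str]) -> Iterator[int]:
    """Hypothetical upstream stage that records when each item is produced."""
    for i in range(4):
        log.append(f"produced {i}")
        yield i


log: List[str] = []
stage_output = EvenFilter()(source(log))

# Calling the stage only builds a generator; upstream has not run yet.
produced_before_iteration = len(log)

# Pulling one item makes upstream produce only as much as is needed.
first = next(stage_output)
produced_after_first = len(log)

# Draining the rest consumes the remaining upstream items.
rest = list(stage_output)
```

A __call__ that built and returned a list would force every upstream item to be produced before the next stage saw anything, which is presumably why the rule above forbids it.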
Custom StreamWriter
Implement stream() as a generator that yields whatever your downstream consumer expects.
from typing import Any, Iterator
from wsi_patching.custom_component import StreamWriterBase, CollatedPatchBatch
class MyStreamWriter(StreamWriterBase):
    input_type = CollatedPatchBatch
    def stream(self, batch: CollatedPatchBatch) -> Iterator[tuple[str, Any]]:
        # Called once per batch produced by the pipeline
        yield batch.wsi_id, batch.patches
stream() is called once per queue item. The pipeline handles threading and queue management; you only need to focus on the transformation.
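As a rough mental model of "called once per queue item", the sketch below drains a queue and forwards everything each stream() call yields. The drive function, the FakeBatch dataclass, and the None end-of-stream sentinel are all hypothetical stand-ins; the pipeline's actual threading and queue handling are not described in this document and may differ.

```python
import queue
from dataclasses import dataclass
from typing import Any, Iterator, List, Tuple


@dataclass
class FakeBatch:
    """Hypothetical stand-in for CollatedPatchBatch."""
    wsi_id: str
    patches: List[int]


class MyStreamWriter:
    """Same shape as the writer above, minus the library base class."""

    def stream(self, batch: FakeBatch) -> Iterator[Tuple[str, Any]]:
        yield batch.wsi_id, batch.patches


def drive(writer: MyStreamWriter, q: "queue.Queue") -> Iterator[Tuple[str, Any]]:
    """Hypothetical driver: call stream() once per item until a None sentinel."""
    while True:
        item = q.get()
        if item is None:
            break
        yield from writer.stream(item)


q: "queue.Queue" = queue.Queue()
q.put(FakeBatch("slide_a", [1, 2]))
q.put(FakeBatch("slide_b", [3]))
q.put(None)  # assumed sentinel marking the end of the stream

results = list(drive(MyStreamWriter(), q))
```

Because stream() is itself a generator, one queue item may yield zero, one, or many outputs without the driver needing to know which.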
Custom MaterializeWriter
Implement open(), write(), and close() — called once, repeatedly, and once respectively.
import csv
from typing import Any, Optional
from wsi_patching.custom_component import MaterializeWriterBase, CollatedPatchBatch
class CsvWriter(MaterializeWriterBase):
    input_type = CollatedPatchBatch
    def __init__(self, path: str):
        super().__init__()
        self.path = path
        self._file = None
        self._writer = None
    def open(self) -> None:
        self._file = open(self.path, "w", newline="")
        self._writer = csv.writer(self._file)
        self._writer.writerow(["wsi_id", "x", "y"])
    def write(self, batch: CollatedPatchBatch) -> None:
        for coord in batch.coords:
            self._writer.writerow([batch.wsi_id, coord[0], coord[1]])
    def close(self) -> None:
        if self._file:
            self._file.close()
    def get_output(self) -> Optional[str]:
        # Optional: return something after materialization completes
        return self.path
Note
close() is always called, even if an error occurred during write(). Use it for cleanup (closing file handles, flushing buffers).
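The open / write / close lifecycle, including the guarantee that close() runs after a failed write(), can be pictured as a try/finally loop. The sketch below is an assumption about the shape of that loop, not the library's code: RecordingWriter and materialize are hypothetical names, and integers stand in for batches.

```python
from typing import Iterable, List


class RecordingWriter:
    """Hypothetical writer that records lifecycle calls instead of doing I/O."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def open(self) -> None:
        self.calls.append("open")

    def write(self, batch: int) -> None:
        if batch < 0:
            raise ValueError("bad batch")
        self.calls.append(f"write {batch}")

    def close(self) -> None:
        self.calls.append("close")


def materialize(writer: RecordingWriter, batches: Iterable[int]) -> None:
    """Hypothetical driver: open once, write repeatedly, always close."""
    writer.open()
    try:
        for batch in batches:
            writer.write(batch)
    finally:
        writer.close()  # runs whether or not write() raised


ok = RecordingWriter()
materialize(ok, [1, 2])

failing = RecordingWriter()
try:
    materialize(failing, [1, -1, 2])
    error_seen = False
except ValueError:
    error_seen = True
```

Since close() can run after a partial write, it is safest to write it so that it tolerates half-initialized state, as the `if self._file:` check in CsvWriter does.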
Profiling your component
To appear in the pipeline's profile output, instrument your __call__ method with the built-in profiler:
import time
from typing import Iterable
from wsi_patching.custom_component import Stage, CollatedPatchBatch
class MyStage(Stage):
    input_type = CollatedPatchBatch
    output_type = CollatedPatchBatch
    def __call__(self, it: Iterable[CollatedPatchBatch]) -> Iterable[CollatedPatchBatch]:
        prof = self.get_current_profiler()  # None when profile=False
        for batch in it:
            t0 = time.perf_counter()
            result = self._process(batch)
            dt = time.perf_counter() - t0
            if result is not None:
                if prof:  # guard: prof is None when profiling is disabled
                    prof.add_time("MyStage", dt, yielded=True)
                yield result
            elif prof:
                prof.add_time("MyStage", dt, yielded=False)
    def _process(self, batch: CollatedPatchBatch):
        ...
- yielded=True: this timing window produced output (the yield branch).
- yielded=False: this timing window filtered the item (no yield).
- The profiler is per-slide and per-thread; the pipeline aggregates results automatically.
- When profile=False, get_current_profiler() returns None. Guard with "if prof:" if you call methods on it conditionally, or just pass it directly to add_time (it no-ops when disabled).
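To see what the yielded flag separates, the sketch below runs the same timing pattern against a stand-in profiler. FakeProfiler and timed_positive_filter are hypothetical; only the add_time(name, dt, yielded=...) call shape is taken from the example above, and how the real profiler stores or aggregates timings is an assumption.

```python
import time
from collections import defaultdict
from typing import Dict, Iterable, Iterator, Optional


class FakeProfiler:
    """Hypothetical stand-in that only mimics the add_time call shape."""

    def __init__(self) -> None:
        self.seconds: Dict[str, float] = defaultdict(float)
        self.yielded: Dict[str, int] = defaultdict(int)
        self.filtered: Dict[str, int] = defaultdict(int)

    def add_time(self, name: str, dt: float, yielded: bool) -> None:
        self.seconds[name] += dt
        if yielded:
            self.yielded[name] += 1
        else:
            self.filtered[name] += 1


def timed_positive_filter(
    it: Iterable[int], prof: Optional[FakeProfiler]
) -> Iterator[int]:
    """Keep positive items, timing each decision when a profiler is present."""
    for item in it:
        t0 = time.perf_counter()
        keep = item > 0
        dt = time.perf_counter() - t0
        if prof is not None:  # profiler is None when profiling is disabled
            prof.add_time("PositiveFilter", dt, yielded=keep)
        if keep:
            yield item


prof = FakeProfiler()
with_profiling = list(timed_positive_filter([3, -1, 5, 0], prof))
without_profiling = list(timed_positive_filter([3, -1, 5, 0], None))
```

Recording filtered windows separately keeps time spent rejecting items from being attributed to the items that were actually emitted.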
Type flow
The pipeline validates that each stage's output_type is a subclass of the next stage's input_type. This check runs at pipeline construction (in .then() and .to()), not at run time.
WSIGrid → output_type = Slide
AttachROIs → input_type = Slide, output_type = SlideWithROIs
PatchExtractor → input_type = SlideWithROIs, output_type = CollatedPatchBatch
PNGEncoder → input_type = CollatedPatchBatch, output_type = EncodedCollatedPatchBatch
WebDatasetWriter → input_type = EncodedCollatedPatchBatch
Set input_type = object or output_type = object on your custom component to opt out of the check for that boundary.
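The boundary check described above amounts to an issubclass test between neighbouring stages. The following is a minimal sketch of that idea under stated assumptions: StageSpec, check_boundary, check_chain, and the Item, SpecialItem, and Unrelated types are hypothetical, and the library's real check (run in .then() and .to()) may be implemented differently.

```python
from typing import List


class Item:
    """Hypothetical base payload type."""


class SpecialItem(Item):
    """Hypothetical subclass of Item."""


class Unrelated:
    """Hypothetical type with no relation to Item."""


class StageSpec:
    """Hypothetical record holding only what the check needs."""

    def __init__(self, name: str, input_type: type, output_type: type) -> None:
        self.name = name
        self.input_type = input_type
        self.output_type = output_type


def check_boundary(prev: StageSpec, nxt: StageSpec) -> None:
    """Raise TypeError unless prev's output can feed nxt's input."""
    # `object` on either side opts that boundary out of checking.
    if prev.output_type is object or nxt.input_type is object:
        return
    if not issubclass(prev.output_type, nxt.input_type):
        raise TypeError(
            f"{prev.name} outputs {prev.output_type.__name__}, "
            f"but {nxt.name} expects {nxt.input_type.__name__}"
        )


def check_chain(stages: List[StageSpec]) -> None:
    """Check every adjacent pair of stages, in order."""
    for prev, nxt in zip(stages, stages[1:]):
        check_boundary(prev, nxt)


producer = StageSpec("Producer", object, SpecialItem)
consumer = StageSpec("Consumer", Item, Item)
mismatched = StageSpec("Mismatched", Unrelated, Unrelated)
permissive = StageSpec("Permissive", object, object)

check_chain([producer, consumer])                # SpecialItem is an Item
check_chain([producer, permissive, mismatched])  # object opts out on both sides

try:
    check_chain([producer, mismatched])
    error_message = ""
except TypeError as exc:
    error_message = str(exc)
```

The explicit `is object` test on the output side matters in this sketch: issubclass(object, X) is False for any other X, so a plain issubclass check alone would not give output_type = object the opt-out behaviour described above.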