Coverage for src / bioimageio / core / tensor.py: 83%

257 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-18 12:35 +0000

1from __future__ import annotations 

2 

3import collections.abc 

4from itertools import permutations 

5from typing import ( 

6 TYPE_CHECKING, 

7 Any, 

8 Callable, 

9 Dict, 

10 Iterator, 

11 Mapping, 

12 Optional, 

13 Sequence, 

14 Tuple, 

15 Union, 

16 cast, 

17 get_args, 

18) 

19 

20import numpy as np 

21import xarray as xr 

22from loguru import logger 

23from numpy.typing import DTypeLike, NDArray 

24from typing_extensions import Self, assert_never 

25 

26from bioimageio.spec.model import v0_5 

27 

28from ._magic_tensor_ops import MagicTensorOpsMixin 

29from .axis import AxisId, AxisInfo, AxisLike, PerAxis 

30from .common import ( 

31 CropWhere, 

32 DTypeStr, 

33 PadMode, 

34 PadWhere, 

35 PadWidth, 

36 PadWidthLike, 

37 QuantileMethod, 

38 SliceInfo, 

39) 

40 

41if TYPE_CHECKING: 

42 from numpy.typing import ArrayLike, NDArray 

43 

44 

45_ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"] # TODO: add "DaskArray" 

46 

47 

48def _resolve_pad_mode(mode: PadMode): 

49 constant_value = None 

50 if isinstance(mode, str): 

51 mode_name = mode 

52 elif isinstance(mode, v0_5.ConstantPadding): 

53 mode_name = mode.mode 

54 constant_value = mode.value 

55 elif isinstance( 

56 mode, (v0_5.EdgePadding, v0_5.ReflectPadding, v0_5.SymmetricPadding) 

57 ): 

58 mode_name = mode.mode 

59 else: 

60 assert_never(mode) 

61 

62 return mode_name, constant_value 

63 

64 

65# TODO: complete docstrings 

66# TODO: in the long run---with improved typing in xarray---we should probably replace `Tensor` with xr.DataArray 

class Tensor(MagicTensorOpsMixin):
    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
    and improved type annotations.

    The wrapped `xr.DataArray` is exposed through the `data` property; its
    dimension names are the tensor's axis ids (see `dims`).
    """

    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]

    def __init__(
        self,
        array: NDArray[Any],
        dims: Sequence[Union[AxisId, AxisLike]],
    ) -> None:
        """Wrap `array` in an `xr.DataArray` whose dimensions are named by `dims`.

        Args:
            array: the array data.
            dims: one entry per array axis. `AxisId`s are used as given; any
                other entry is converted with `AxisInfo.create(...).id`.
        """
        super().__init__()
        axes = tuple(
            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
        )
        self._data = xr.DataArray(array, dims=axes)

    def __repr__(self) -> str:
        return f"<Tensor {repr(self._data)}>"

    def __array__(self, dtype: DTypeLike = None):
        """Return the tensor data as a numpy array (numpy array protocol)."""
        return np.asarray(self._data, dtype=dtype)

    def __getitem__(
        self,
        key: Union[
            SliceInfo,
            slice,
            int,
            PerAxis[Union[SliceInfo, slice, int]],
            Tensor,
            xr.DataArray,
        ],
    ) -> Self:
        """Index the tensor and return the selection as a new tensor.

        `SliceInfo`s (also as values of a per-axis mapping) are converted to
        `slice` objects and a `Tensor` key is replaced by its data array
        before the key is handed to the wrapped `xr.DataArray`.
        """
        if isinstance(key, SliceInfo):
            key = slice(*key)
        elif isinstance(key, collections.abc.Mapping):
            key = {
                a: s if isinstance(s, (int, slice)) else slice(*s)
                for a, s in key.items()
            }
        elif isinstance(key, Tensor):
            key = key._data

        return self.__class__.from_xarray(self._data[key])

    def __setitem__(
        self,
        key: Union[PerAxis[Union[SliceInfo, slice]], Tensor, xr.DataArray],
        value: Union[Tensor, xr.DataArray, float, int],
    ) -> None:
        """Assign `value` to the selection described by `key` (in place).

        `Tensor` keys/values are replaced by their data arrays; a per-axis
        mapping has its `SliceInfo` values converted to `slice` objects.
        """
        if isinstance(key, Tensor):
            key = key._data
        elif isinstance(key, xr.DataArray):
            pass
        else:
            key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}

        if isinstance(value, Tensor):
            value = value._data

        self._data[key] = value

    def __len__(self) -> int:
        return len(self.data)

    def _iter(self: Any) -> Iterator[Any]:
        """Yield `self[0]`, `self[1]`, ... up to `len(self)`."""
        for n in range(len(self)):
            yield self[n]

    def __iter__(self: Any) -> Iterator[Any]:
        # raise eagerly (not on first `next`) for tensors without any axis
        if self.ndim == 0:
            raise TypeError("iteration over a 0-d array")
        return self._iter()

    def _binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
        reflexive: bool = False,
    ) -> Self:
        """Apply the binary operation `f` via the wrapped data array.

        A `Tensor` operand is replaced by its data array; the result is
        wrapped in a new tensor.
        """
        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
            (other._data if isinstance(other, Tensor) else other),
            f,
            reflexive,
        )
        return self.__class__.from_xarray(data)

    def _inplace_binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
    ) -> Self:
        """Apply the binary operation `f` in place and return `self`.

        An operand whose `data` attribute is an `xr.DataArray` (e.g. another
        `Tensor`) is replaced by that data array; any other operand is
        forwarded unchanged.
        """
        # default of None: Python scalars have no `data` attribute and must
        # not raise an AttributeError here
        other_data = getattr(other, "data", None)
        operand = other_data if isinstance(other_data, xr.DataArray) else other
        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
            operand,
            f,
        )
        return self

    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
        """Apply the unary operation `f` via the wrapped data array."""
        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
            f, *args, **kwargs
        )
        return self.__class__.from_xarray(data)

    @classmethod
    def from_xarray(cls, data_array: xr.DataArray) -> Self:
        """create a `Tensor` from an xarray data array

        note for internal use: this factory method is round-trip safe
            for any `Tensor`'s `data` property (an xarray.DataArray).
        """
        return cls(
            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
        )

    @classmethod
    def from_numpy(
        cls,
        array: NDArray[Any],
        *,
        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
    ) -> Self:
        """create a `Tensor` from a numpy array

        Args:
            array: the nd numpy array
            dims: A description of the array's axes.
                If None axes are guessed (which might fail and raise a ValueError.)
                If dims do not match array shape, permutations and singleton dimensions are tried to find a match.
        Raises:
            ValueError: if `dims` is None and dims guessing fails,
                or if no view of `array` matches the given `dims`.
        """

        if dims is None:
            return cls._interprete_array_wo_known_axes(array)
        elif isinstance(dims, AxisId):
            # a single axis id; checked before the generic sequence test so an
            # id that is itself a string-like sequence is not split into
            # characters (NOTE(review): assumes AxisId subclasses str - confirm)
            dim_seq = [dims]
        elif isinstance(dims, collections.abc.Sequence):
            dim_seq = list(dims)
        else:
            dim_seq = [dims]

        axis_infos = [AxisInfo.create(a) for a in dim_seq]
        original_shape = tuple(array.shape)

        successful_view = _get_array_view(array, axis_infos)
        if successful_view is None:
            raise ValueError(
                f"Array shape {original_shape} does not map to axes {dims}"
            )

        # `cls` (not `Tensor`) for consistency with the other factory methods
        return cls(successful_view, dims=tuple(a.id for a in axis_infos))

    @property
    def data(self):
        """The wrapped `xr.DataArray`."""
        return self._data

    @property
    def dims(self):  # TODO: rename to `axes`?
        """Tuple of dimension names associated with this tensor."""
        return cast(Tuple[AxisId, ...], self._data.dims)

    @property
    def dtype(self) -> DTypeStr:
        """Name of the data type; asserted to be one of the `DTypeStr` literals."""
        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
        assert dt in get_args(DTypeStr)
        return dt  # pyright: ignore[reportReturnType]

    @property
    def ndim(self):
        """Number of tensor dimensions."""
        return self._data.ndim

    @property
    def shape(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def shape_tuple(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def size(self):
        """Number of elements in the tensor.

        Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
        """
        return self._data.size

    @property
    def sizes(self):
        """Ordered, immutable mapping from axis ids to axis lengths."""
        return cast(Mapping[AxisId, int], self.data.sizes)

    @property
    def tagged_shape(self):
        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
        return self.sizes

    def to_numpy(self) -> NDArray[Any]:
        """Return the data of this tensor as a numpy array."""
        return self.data.to_numpy()  # pyright: ignore[reportUnknownVariableType]

    def argmax(self) -> Mapping[AxisId, int]:
        """Return the index of the maximum along every axis, keyed by axis id."""
        ret = self._data.argmax(...)
        assert isinstance(ret, dict)
        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}

    def astype(self, dtype: DTypeStr, *, copy: bool = False):
        """Return tensor cast to `dtype`

        note: if dtype is already satisfied copy if `copy`"""
        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))

    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
        """Return a tensor whose values are limited to [min, max].
        At least one of max or min must be given."""
        return self.__class__.from_xarray(self._data.clip(min, max))

    def crop_to(
        self,
        sizes: PerAxis[int],
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
    ) -> Self:
        """crop to match `sizes`

        Args:
            sizes: target length per axis; axes not listed keep their length.
            crop_where: which side(s) to remove data from, either one value
                for all axes or one value per axis.

        Axes that are already shorter than requested are left untouched and
        a warning is logged.

        Raises:
            ValueError: if an axis has to be cropped but `crop_where` has no
                entry for it.
        """
        if isinstance(crop_where, str):
            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
        else:
            crop_axis_where = crop_where

        slices: Dict[AxisId, SliceInfo] = {}

        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pass
            elif sizes[a] > s_is:
                logger.warning(
                    "Cannot crop axis {} of size {} to larger size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in crop_axis_where:
                raise ValueError(
                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
                )
            else:
                crop_this_axis_where = crop_axis_where[a]
                if crop_this_axis_where == "left":
                    # remove from the left: keep the last `sizes[a]` elements
                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
                elif crop_this_axis_where == "right":
                    # remove from the right: keep the first `sizes[a]` elements
                    slices[a] = SliceInfo(0, sizes[a])
                elif crop_this_axis_where == "left_and_right":
                    # an odd surplus removes one element more on the right
                    start = (s_is - sizes[a]) // 2
                    slices[a] = SliceInfo(start, sizes[a] + start)
                else:
                    assert_never(crop_this_axis_where)

        return self[slices]

    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
        """Return a tensor with the additional axes `dims`."""
        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))

    def item(
        self,
        key: Union[
            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
        ] = None,
    ):
        """Copy a tensor element to a standard Python scalar and return it."""
        if key is None:
            ret = self._data.item()
        else:
            ret = self[key]._data.item()

        assert isinstance(ret, (bool, float, int))
        return ret

    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying mean along some dimension(s)."""
        return self.__class__.from_xarray(self._data.mean(dim=dim))

    def pad(
        self,
        pad_width: PerAxis[PadWidthLike],
        mode: PadMode = "symmetric",
    ) -> Self:
        """Return a tensor padded by `pad_width` per axis using `mode`."""
        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
        mode_name, constant_value = _resolve_pad_mode(mode)
        return self.__class__.from_xarray(
            self._data.pad(
                pad_width=pad_width, mode=mode_name, constant_values=constant_value
            )
        )

    def pad_to(
        self,
        sizes: PerAxis[int],
        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
        mode: PadMode = "symmetric",
    ) -> Self:
        """pad `tensor` to match `sizes`

        Args:
            sizes: target length per axis; axes not listed keep their length.
            pad_where: which side(s) to pad, either one value for all axes or
                one value per axis.
            mode: the pad mode forwarded to `pad`.

        Axes that are already longer than requested are left untouched and
        a warning is logged.

        Raises:
            ValueError: if an axis has to be padded but `pad_where` has no
                entry for it.
        """
        if isinstance(pad_where, str):
            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
        else:
            pad_axis_where = pad_where

        pad_width: Dict[AxisId, PadWidth] = {}
        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pad_width[a] = PadWidth(0, 0)
            elif s_is > sizes[a]:
                pad_width[a] = PadWidth(0, 0)
                logger.warning(
                    "Cannot pad axis {} of size {} to smaller size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in pad_axis_where:
                raise ValueError(
                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
                )
            else:
                pad_this_axis_where = pad_axis_where[a]
                d = sizes[a] - s_is
                if pad_this_axis_where == "left":
                    pad_width[a] = PadWidth(d, 0)
                elif pad_this_axis_where == "right":
                    pad_width[a] = PadWidth(0, d)
                elif pad_this_axis_where == "left_and_right":
                    # an odd difference pads one element more on the right
                    left = d // 2
                    pad_width[a] = PadWidth(left, d - left)
                else:
                    assert_never(pad_this_axis_where)

        return self.pad(pad_width, mode)

    def quantile(
        self,
        q: Union[float, Sequence[float]],
        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
        method: QuantileMethod = "linear",
    ) -> Self:
        """Compute the quantile(s) `q` (each within [0, 1]) along `dim`.

        `dim` must not be or contain an axis named "quantile".
        """
        qs = (q,) if isinstance(q, (float, int)) else tuple(q)
        assert all(qq >= 0.0 for qq in qs)
        assert all(qq <= 1.0 for qq in qs)
        if dim is not None:
            quantile_dim = AxisId("quantile")
            # a single axis id is compared as a whole, not per character
            reduced = {dim} if isinstance(dim, str) else set(dim)
            assert quantile_dim != dim and quantile_dim not in reduced

        return self.__class__.from_xarray(
            self._data.quantile(q, dim=dim, method=method)
        )

    def resize_to(
        self,
        sizes: PerAxis[int],
        *,
        pad_where: Union[
            PadWhere,
            PerAxis[PadWhere],
        ] = "left_and_right",
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
        pad_mode: PadMode = "symmetric",
    ):
        """return cropped/padded tensor with `sizes`

        Existing axes longer than requested are cropped, shorter ones are
        padded; axes in `sizes` that the tensor does not have yet are added
        with `expand_dims`.
        """
        crop_to_sizes: Dict[AxisId, int] = {}
        pad_to_sizes: Dict[AxisId, int] = {}
        # what remains in `new_axes` after the loop are axes to be added
        new_axes = dict(sizes)
        for a, s_is in self.sizes.items():
            a = AxisId(str(a))
            _ = new_axes.pop(a, None)
            if a not in sizes or sizes[a] == s_is:
                pass
            elif s_is > sizes[a]:
                crop_to_sizes[a] = sizes[a]
            else:
                pad_to_sizes[a] = sizes[a]

        tensor = self
        if crop_to_sizes:
            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)

        if pad_to_sizes:
            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)

        if new_axes:
            tensor = tensor.expand_dims(new_axes)

        return tensor

    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying std along some dimension(s)."""
        return self.__class__.from_xarray(self._data.std(dim=dim))

    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying sum along some dimension(s)."""
        return self.__class__.from_xarray(self._data.sum(dim=dim))

    def transpose(
        self,
        axes: Sequence[AxisId],
    ) -> Self:
        """return a transposed tensor

        Args:
            axes: the desired tensor axes
        """
        # expand missing tensor axes
        missing_axes = tuple(a for a in axes if a not in self.dims)
        array = self._data
        if missing_axes:
            array = array.expand_dims(missing_axes)

        # transpose to the correct axis order
        return self.__class__.from_xarray(array.transpose(*axes))

    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying var along some dimension(s)."""
        return self.__class__.from_xarray(self._data.var(dim=dim))

    @classmethod
    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
        """Create a tensor from `array`, guessing its axes from the number of dimensions.

        Guessed layouts:
            2d: (y, x)
            3d with any axis of length <= 3: (channel, y, x)
            3d otherwise: (z, y, x)
            4d: (channel, z, y, x)
            5d: (batch, channel, z, y, x)

        Raises:
            ValueError: for any other number of dimensions.
        """
        ndim = array.ndim
        if ndim == 2:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
            )
        elif ndim == 3 and any(s <= 3 for s in array.shape):
            # NOTE(review): the channel axis is always taken from shape[0],
            # even if the short axis that triggered this branch is another one
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 3:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 4:
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
            )
        elif ndim == 5:
            current_axes = (
                v0_5.BatchAxis(),
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
            )
        else:
            raise ValueError(f"Could not guess an axis mapping for {array.shape}")

        return cls(array, dims=tuple(a.id for a in current_axes))

560 

561 

562def _add_singletons(arr: NDArray[Any], axis_infos: Sequence[AxisInfo]): 

563 if len(arr.shape) > len(axis_infos): 

564 # remove singletons 

565 for i, s in enumerate(arr.shape): 

566 if s == 1: 

567 arr = np.take(arr, 0, axis=i) 

568 if len(arr.shape) == len(axis_infos): 

569 break 

570 

571 # add singletons if nececsary 

572 for i, a in enumerate(axis_infos): 

573 if len(arr.shape) >= len(axis_infos): 

574 break 

575 

576 if a.size.min == 1: 

577 arr = np.expand_dims(arr, i) 

578 

579 return arr 

580 

581 

def _get_array_view(
    original_array: NDArray[Any], axis_infos: Sequence[AxisInfo]
) -> Optional[NDArray[Any]]:
    """Find an arrangement of `original_array` that matches `axis_infos`.

    Axis permutations of `original_array` are tried in the order produced by
    `itertools.permutations` (the unchanged axis order first). Each candidate
    is rank-adjusted with `_add_singletons` and accepted if every axis length
    is within `[size.min, size.max]` of the corresponding axis description and
    `size.step` (if given) evenly divides the length's distance from `size.min`.

    Args:
        original_array: the array to map onto the axes.
        axis_infos: description of the desired axes.

    Returns:
        The first matching candidate, or None if the rank cannot be matched
        or no permutation satisfies the size constraints.
    """
    # iterate lazily: usually an early permutation matches and the remaining
    # ones never have to be generated
    for perm in permutations(range(len(original_array.shape))):
        view = _add_singletons(original_array.transpose(perm), axis_infos)
        if len(view.shape) != len(axis_infos):
            # rank mismatch after singleton adjustment: give up
            return None

        sizes_fit = all(
            s >= a.size.min
            and (a.size.max is None or s <= a.size.max)
            and (a.size.step is None or (s - a.size.min) % a.size.step == 0)
            for s, a in zip(view.shape, axis_infos)
        )
        if sizes_fit:
            return view

    return None