Coverage for src / bioimageio / core / tensor.py: 83%

247 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-22 18:38 +0000

1from __future__ import annotations 

2 

3import collections.abc 

4from itertools import permutations 

5from typing import ( 

6 TYPE_CHECKING, 

7 Any, 

8 Callable, 

9 Dict, 

10 Iterator, 

11 Mapping, 

12 Optional, 

13 Sequence, 

14 Tuple, 

15 Union, 

16 cast, 

17 get_args, 

18) 

19 

20import numpy as np 

21import xarray as xr 

22from loguru import logger 

23from numpy.typing import DTypeLike, NDArray 

24from typing_extensions import Self, assert_never 

25 

26from bioimageio.spec.model import v0_5 

27 

28from ._magic_tensor_ops import MagicTensorOpsMixin 

29from .axis import AxisId, AxisInfo, AxisLike, PerAxis 

30from .common import ( 

31 CropWhere, 

32 DTypeStr, 

33 PadMode, 

34 PadWhere, 

35 PadWidth, 

36 PadWidthLike, 

37 QuantileMethod, 

38 SliceInfo, 

39) 

40 

41if TYPE_CHECKING: 

42 from numpy.typing import ArrayLike, NDArray 

43 

44 

# Scalar / array operand types that `Tensor` arithmetic accepts in addition to
# `Tensor` and `xr.DataArray` (see `Tensor._Compatible`).
# "ArrayLike" has to remain a string forward reference because it is only
# imported under `TYPE_CHECKING`.
_ScalarOrArray = Union[
    "ArrayLike",
    np.generic,
    "NDArray[Any]",
]  # TODO: add "DaskArray"

46 

47 

# TODO: complete docstrings
# TODO: in the long run---with improved typing in xarray---we should probably replace `Tensor` with xr.DataArray
class Tensor(MagicTensorOpsMixin):
    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
    and improved type annotations."""

    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]

    def __init__(
        self,
        array: NDArray[Any],
        dims: Sequence[Union[AxisId, AxisLike]],
    ) -> None:
        """Wrap `array` in an `xr.DataArray` with named dimensions.

        Args:
            array: the n-dimensional data.
            dims: one entry per axis of `array`; entries that are not already an
                `AxisId` are converted via `AxisInfo.create(...).id`.
        """
        super().__init__()
        axes = tuple(
            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
        )
        self._data = xr.DataArray(array, dims=axes)

    def __repr__(self) -> str:
        return f"<Tensor {repr(self._data)}>"

    def __array__(self, dtype: DTypeLike = None, copy: Optional[bool] = None):
        """NumPy array protocol.

        Args:
            dtype: optional target dtype.
            copy: accepted for NumPy 2 compatibility (NumPy 2 passes this keyword
                to `__array__`). `None` (default): copy only if needed.
                Otherwise forwarded to `np.array(..., copy=copy)`.
        """
        if copy is None:
            # no `copy` keyword here: `np.asarray` only accepts it from NumPy 2 on
            return np.asarray(self._data, dtype=dtype)

        return np.array(self._data, dtype=dtype, copy=copy)

    def __getitem__(
        self,
        key: Union[
            SliceInfo,
            slice,
            int,
            PerAxis[Union[SliceInfo, slice, int]],
            Tensor,
            xr.DataArray,
        ],
    ) -> Self:
        """Index this tensor and return the result as a new tensor.

        A `SliceInfo` (as the key or as a mapping value) is converted to a
        `slice`; a mapping from axis id to index/slice is passed on to xarray's
        dict-based indexing; a `Tensor` key is unwrapped to its `xr.DataArray`.
        """
        if isinstance(key, SliceInfo):
            key = slice(*key)
        elif isinstance(key, collections.abc.Mapping):
            key = {
                a: s if isinstance(s, (int, slice)) else slice(*s)
                for a, s in key.items()
            }
        elif isinstance(key, Tensor):
            key = key._data

        return self.__class__.from_xarray(self._data[key])

    def __setitem__(
        self,
        key: Union[PerAxis[Union[SliceInfo, slice, int]], Tensor, xr.DataArray],
        value: Union[Tensor, xr.DataArray, float, int],
    ) -> None:
        """Assign `value` to the selection `key` in place.

        `key` is a `Tensor`/`xr.DataArray` or a mapping from axis id to `slice`,
        `SliceInfo` or integer index (the same mapping values `__getitem__`
        accepts). A `Tensor` value is unwrapped to its `xr.DataArray`.
        """
        if isinstance(key, Tensor):
            key = key._data
        elif not isinstance(key, xr.DataArray):
            # integers are kept as-is (previously `slice(*int)` raised a TypeError)
            key = {
                a: s if isinstance(s, (int, slice)) else slice(*s)
                for a, s in key.items()
            }

        if isinstance(value, Tensor):
            value = value._data

        self._data[key] = value

    def __len__(self) -> int:
        """Length of the first axis (delegates to `len()` of the xr.DataArray)."""
        return len(self.data)

    def _iter(self: Any) -> Iterator[Any]:
        for n in range(len(self)):
            yield self[n]

    def __iter__(self: Any) -> Iterator[Any]:
        """Iterate over the first axis, yielding `self[0]`, `self[1]`, ..."""
        if self.ndim == 0:
            raise TypeError("iteration over a 0-d array")
        return self._iter()

    def _binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
        reflexive: bool = False,
    ) -> Self:
        """Apply the binary operator `f` to this tensor and `other`.

        A `Tensor` operand is unwrapped to its `xr.DataArray`; `reflexive` is
        forwarded unchanged to xarray's `_binary_op`. Returns a new tensor.
        """
        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
            (other._data if isinstance(other, Tensor) else other),
            f,
            reflexive,
        )
        return self.__class__.from_xarray(data)

    def _inplace_binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
    ) -> Self:
        """Apply the in-place operator `f` to this tensor's data and return `self`."""
        # Unwrap operands whose `.data` is an `xr.DataArray` (e.g. a `Tensor`).
        # `getattr` needs a default: plain Python scalars (`1`, `1.0`) have no
        # `.data` attribute and would otherwise raise an AttributeError.
        other_data = getattr(other, "data", None)
        operand = other_data if isinstance(other_data, xr.DataArray) else other

        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
            operand,
            f,
        )
        return self

    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
        """Apply the unary operator `f` (via xarray) and return a new tensor."""
        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
            f, *args, **kwargs
        )
        return self.__class__.from_xarray(data)

    @classmethod
    def from_xarray(cls, data_array: xr.DataArray) -> Self:
        """create a `Tensor` from an xarray data array

        note for internal use: this factory method is round-trip safe
        for any `Tensor`'s `data` property (an xarray.DataArray).
        """
        return cls(
            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
        )

    @classmethod
    def from_numpy(
        cls,
        array: NDArray[Any],
        *,
        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
    ) -> Tensor:
        """create a `Tensor` from a numpy array

        Args:
            array: the nd numpy array
            dims: A description of the array's axes,
                if None axes are guessed (which might fail and raise a ValueError.)
                Note: a string counts as a sequence, i.e. each of its characters
                describes one axis.

        The array may be transposed and singleton axes may be removed/added to
        match `dims` (see `_get_array_view`).

        Raises:
            ValueError: if `dims` is None and dims guessing fails,
                or if the array shape cannot be mapped to `dims`.
        """

        if dims is None:
            return cls._interprete_array_wo_known_axes(array)
        elif isinstance(dims, collections.abc.Sequence):
            dim_seq = list(dims)
        else:
            dim_seq = [dims]

        axis_infos = [AxisInfo.create(a) for a in dim_seq]
        original_shape = tuple(array.shape)

        successful_view = _get_array_view(array, axis_infos)
        if successful_view is None:
            raise ValueError(
                f"Array shape {original_shape} does not map to axes {dims}"
            )

        # `cls` (not `Tensor`) so subclasses get instances of their own type,
        # consistent with the `dims is None` branch and `from_xarray`
        return cls(successful_view, dims=tuple(a.id for a in axis_infos))

    @property
    def data(self):
        """The wrapped `xr.DataArray`."""
        return self._data

    @property
    def dims(self):  # TODO: rename to `axes`?
        """Tuple of dimension names associated with this tensor."""
        return cast(Tuple[AxisId, ...], self._data.dims)

    @property
    def dtype(self) -> DTypeStr:
        """The data type as a string (asserted to be a valid `DTypeStr`)."""
        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
        assert dt in get_args(DTypeStr)
        return dt  # pyright: ignore[reportReturnType]

    @property
    def ndim(self):
        """Number of tensor dimensions."""
        return self._data.ndim

    @property
    def shape(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def shape_tuple(self):
        """(alias for `shape`) Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def size(self):
        """Number of elements in the tensor.

        Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
        """
        return self._data.size

    @property
    def sizes(self):
        """Ordered, immutable mapping from axis ids to axis lengths."""
        return cast(Mapping[AxisId, int], self.data.sizes)

    @property
    def tagged_shape(self):
        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
        return self.sizes

    def to_numpy(self) -> NDArray[Any]:
        """Return the data of this tensor as a numpy array."""
        return self.data.to_numpy()  # pyright: ignore[reportUnknownVariableType]

    def argmax(self) -> Mapping[AxisId, int]:
        """Return the index of the (first) maximum as a mapping axis id -> index."""
        ret = self._data.argmax(...)
        assert isinstance(ret, dict)
        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}

    def astype(self, dtype: DTypeStr, *, copy: bool = False):
        """Return tensor cast to `dtype`

        note: if dtype is already satisfied copy if `copy`"""
        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))

    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
        """Return a tensor whose values are limited to [min, max].
        At least one of max or min must be given."""
        return self.__class__.from_xarray(self._data.clip(min, max))

    def crop_to(
        self,
        sizes: PerAxis[int],
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
    ) -> Self:
        """crop to match `sizes`

        Args:
            sizes: target size per axis; axes not listed are left unchanged.
            crop_where: where to remove elements, for all axes or per axis:
                "left" removes from the start of an axis, "right" from its end,
                "left_and_right" from both (an odd remainder is removed at the end).

        Axes whose target size is larger than their current size are left
        unchanged and a warning is logged.

        Raises:
            ValueError: if an axis needs cropping but has no `crop_where` entry.
        """
        if isinstance(crop_where, str):
            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
        else:
            crop_axis_where = crop_where

        slices: Dict[AxisId, SliceInfo] = {}

        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pass
            elif sizes[a] > s_is:
                logger.warning(
                    "Cannot crop axis {} of size {} to larger size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in crop_axis_where:
                raise ValueError(
                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
                )
            else:
                crop_this_axis_where = crop_axis_where[a]
                if crop_this_axis_where == "left":
                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
                elif crop_this_axis_where == "right":
                    slices[a] = SliceInfo(0, sizes[a])
                elif crop_this_axis_where == "left_and_right":
                    slices[a] = SliceInfo(
                        start := (s_is - sizes[a]) // 2, sizes[a] + start
                    )
                else:
                    assert_never(crop_this_axis_where)

        return self[slices]

    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
        """Insert new axes (given as ids, or as a mapping axis id -> size)."""
        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))

    def item(
        self,
        key: Union[
            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
        ] = None,
    ):
        """Copy a tensor element to a standard Python scalar and return it.

        If `key` is given, the element is selected via `self[key]` first.
        """
        if key is None:
            ret = self._data.item()
        else:
            ret = self[key]._data.item()

        assert isinstance(ret, (bool, float, int))
        return ret

    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Mean over `dim` (all axes if None)."""
        return self.__class__.from_xarray(self._data.mean(dim=dim))

    def pad(
        self,
        pad_width: PerAxis[PadWidthLike],
        mode: PadMode = "symmetric",
    ) -> Self:
        """Pad axes by `pad_width` (normalized via `PadWidth.create`) using `mode`."""
        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
        return self.__class__.from_xarray(
            self._data.pad(pad_width=pad_width, mode=mode)
        )

    def pad_to(
        self,
        sizes: PerAxis[int],
        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
        mode: PadMode = "symmetric",
    ) -> Self:
        """pad `tensor` to match `sizes`

        Args:
            sizes: target size per axis; axes not listed are left unchanged.
            pad_where: where to add elements, for all axes or per axis:
                "left" pads at the start, "right" at the end, "left_and_right"
                at both (an odd remainder is added at the end).
            mode: padding mode passed on to `pad`.

        Axes whose target size is smaller than their current size are left
        unchanged and a warning is logged.

        Raises:
            ValueError: if an axis needs padding but has no `pad_where` entry.
        """
        if isinstance(pad_where, str):
            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
        else:
            pad_axis_where = pad_where

        pad_width: Dict[AxisId, PadWidth] = {}
        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pad_width[a] = PadWidth(0, 0)
            elif s_is > sizes[a]:
                pad_width[a] = PadWidth(0, 0)
                logger.warning(
                    "Cannot pad axis {} of size {} to smaller size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in pad_axis_where:
                raise ValueError(
                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
                )
            else:
                pad_this_axis_where = pad_axis_where[a]
                d = sizes[a] - s_is
                if pad_this_axis_where == "left":
                    pad_width[a] = PadWidth(d, 0)
                elif pad_this_axis_where == "right":
                    pad_width[a] = PadWidth(0, d)
                elif pad_this_axis_where == "left_and_right":
                    pad_width[a] = PadWidth(left := d // 2, d - left)
                else:
                    assert_never(pad_this_axis_where)

        return self.pad(pad_width, mode)

    def quantile(
        self,
        q: Union[float, Sequence[float]],
        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
        method: QuantileMethod = "linear",
    ) -> Self:
        """Compute the quantile(s) `q` (each in [0, 1]) over `dim` (all if None).

        `dim` must not contain the axis id "quantile" (asserted).
        """
        assert (
            isinstance(q, (float, int))
            and q >= 0.0
            or not isinstance(q, (float, int))
            and all(qq >= 0.0 for qq in q)
        )
        assert (
            isinstance(q, (float, int))
            and q <= 1.0
            or not isinstance(q, (float, int))
            and all(qq <= 1.0 for qq in q)
        )
        assert dim is None or (
            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
        )
        return self.__class__.from_xarray(
            self._data.quantile(q, dim=dim, method=method)
        )

    def resize_to(
        self,
        sizes: PerAxis[int],
        *,
        pad_where: Union[
            PadWhere,
            PerAxis[PadWhere],
        ] = "left_and_right",
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
        pad_mode: PadMode = "symmetric",
    ):
        """return cropped/padded tensor with `sizes`

        Axes larger than requested are cropped (`crop_to`), smaller ones padded
        (`pad_to`), and axes in `sizes` that this tensor lacks are added with
        `expand_dims`.
        """
        crop_to_sizes: Dict[AxisId, int] = {}
        pad_to_sizes: Dict[AxisId, int] = {}
        new_axes = dict(sizes)
        for a, s_is in self.sizes.items():
            a = AxisId(str(a))
            _ = new_axes.pop(a, None)
            if a not in sizes or sizes[a] == s_is:
                pass
            elif s_is > sizes[a]:
                crop_to_sizes[a] = sizes[a]
            else:
                pad_to_sizes[a] = sizes[a]

        tensor = self
        if crop_to_sizes:
            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)

        if pad_to_sizes:
            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)

        if new_axes:
            tensor = tensor.expand_dims(new_axes)

        return tensor

    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Standard deviation over `dim` (all axes if None)."""
        return self.__class__.from_xarray(self._data.std(dim=dim))

    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying sum along some dimension(s)."""
        return self.__class__.from_xarray(self._data.sum(dim=dim))

    def transpose(
        self,
        axes: Sequence[AxisId],
    ) -> Self:
        """return a transposed tensor

        Axes in `axes` that this tensor lacks are added (as singletons via
        `expand_dims`) before transposing.

        Args:
            axes: the desired tensor axes
        """
        # expand missing tensor axes
        missing_axes = tuple(a for a in axes if a not in self.dims)
        array = self._data
        if missing_axes:
            array = array.expand_dims(missing_axes)

        # transpose to the correct axis order
        return self.__class__.from_xarray(array.transpose(*axes))

    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Variance over `dim` (all axes if None)."""
        return self.__class__.from_xarray(self._data.var(dim=dim))

    @classmethod
    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
        """Guess axes from the number of dimensions of `array`.

        - 2D: y, x
        - 3D with first axis > 3 and last axis <= 3 (e.g. an RGB image of shape
          (H, W, 3)): y, x, channel
        - other 3D with any axis <= 3: channel, y, x
        - other 3D: z, y, x
        - 4D: channel, z, y, x
        - 5D: batch, channel, z, y, x

        Raises:
            ValueError: for any other number of dimensions.
        """
        ndim = array.ndim
        if ndim == 2:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
            )
        elif ndim == 3 and array.shape[0] > 3 and array.shape[2] <= 3:
            # channel-last: previously the (large) first axis was taken as channel
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[2])
                    ]
                ),
            )
        elif ndim == 3 and any(s <= 3 for s in array.shape):
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 3:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 4:
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
            )
        elif ndim == 5:
            current_axes = (
                v0_5.BatchAxis(),
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
            )
        else:
            raise ValueError(f"Could not guess an axis mapping for {array.shape}")

        return cls(array, dims=tuple(a.id for a in current_axes))

540 

541 

def _add_singletons(arr: NDArray[Any], axis_infos: Sequence[AxisInfo]):
    """Move the number of axes of `arr` towards `len(axis_infos)`.

    - Too many axes: size-1 axes are removed, leftmost first, until the counts
      match or no size-1 axis is left.
    - Too few axes: for each axis info flagged `maybe_singleton` (in order), a
      new size-1 axis is inserted at that info's position, until the counts match.

    The result may still have a different number of axes than `axis_infos`;
    callers have to check.
    """
    n_target = len(axis_infos)
    if len(arr.shape) > n_target:
        # Determine all axes to drop up front and remove them in one step.
        # (Removing them one by one while enumerating the original shape shifts
        # the positions of later axes, so data axes were dropped or an
        # out-of-range axis was indexed.)
        singleton_axes = [i for i, s in enumerate(arr.shape) if s == 1]
        excess = len(arr.shape) - n_target
        to_remove = tuple(singleton_axes[:excess])
        if to_remove:
            arr = np.squeeze(arr, axis=to_remove)

    # add singletons if necessary
    for i, a in enumerate(axis_infos):
        if len(arr.shape) >= n_target:
            break

        if a.maybe_singleton:
            arr = np.expand_dims(arr, i)

    return arr

560 

561 

def _get_array_view(
    original_array: NDArray[Any], axis_infos: Sequence[AxisInfo]
) -> Optional[NDArray[Any]]:
    """Find an axis order of `original_array` that fits `axis_infos`.

    Axis permutations are tried in turn: the identity first, the full reversal
    (A.T) second, then the remaining ones in `itertools.permutations` order.
    Each candidate is passed through `_add_singletons`. The first candidate
    with one axis per entry of `axis_infos`, and no size-1 axis where the axis
    info is not `maybe_singleton`, is returned.

    Returns:
        The adjusted array, or None if no permutation fits.
    """
    orders = list(permutations(range(len(original_array.shape))))
    # move the full reversal right behind the identity
    orders.insert(1, orders.pop())

    for order in orders:
        candidate = _add_singletons(original_array.transpose(order), axis_infos)
        if len(candidate.shape) != len(axis_infos):
            # the number of axes `_add_singletons` yields does not depend on the
            # axis order, so no other permutation can fit either
            return None

        fits = all(
            size != 1 or info.maybe_singleton
            for size, info in zip(candidate.shape, axis_infos)
        )
        if fits:
            return candidate

    return None