Coverage for src / bioimageio / core / tensor.py: 83%

245 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 13:23 +0000

1from __future__ import annotations 

2 

3import collections.abc 

4from itertools import permutations 

5from typing import ( 

6 TYPE_CHECKING, 

7 Any, 

8 Callable, 

9 Dict, 

10 Iterator, 

11 Mapping, 

12 Optional, 

13 Sequence, 

14 Tuple, 

15 Union, 

16 cast, 

17 get_args, 

18) 

19 

20import numpy as np 

21import xarray as xr 

22from loguru import logger 

23from numpy.typing import DTypeLike, NDArray 

24from typing_extensions import Self, assert_never 

25 

26from bioimageio.spec.model import v0_5 

27 

28from ._magic_tensor_ops import MagicTensorOpsMixin 

29from .axis import AxisId, AxisInfo, AxisLike, PerAxis 

30from .common import ( 

31 CropWhere, 

32 DTypeStr, 

33 PadMode, 

34 PadWhere, 

35 PadWidth, 

36 PadWidthLike, 

37 QuantileMethod, 

38 SliceInfo, 

39) 

40 

if TYPE_CHECKING:
    # type-checking-only import; NDArray is also imported above for runtime use
    from numpy.typing import ArrayLike, NDArray


# scalar- or array-like values accepted alongside `Tensor` in operations
_ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"]  # TODO: add "DaskArray"

46 

47 

# TODO: complete docstrings
# TODO: in the long run---with improved typing in xarray---we should probably replace `Tensor` with xr.DataArray
class Tensor(MagicTensorOpsMixin):
    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
    and improved type annotations."""

    # types a `Tensor` may be combined with in binary operations
    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]

55 

56 def __init__( 

57 self, 

58 array: NDArray[Any], 

59 dims: Sequence[Union[AxisId, AxisLike]], 

60 ) -> None: 

61 super().__init__() 

62 axes = tuple( 

63 a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims 

64 ) 

65 self._data = xr.DataArray(array, dims=axes) 

66 

67 def __array__(self, dtype: DTypeLike = None): 

68 return np.asarray(self._data, dtype=dtype) 

69 

70 def __getitem__( 

71 self, 

72 key: Union[ 

73 SliceInfo, 

74 slice, 

75 int, 

76 PerAxis[Union[SliceInfo, slice, int]], 

77 Tensor, 

78 xr.DataArray, 

79 ], 

80 ) -> Self: 

81 if isinstance(key, SliceInfo): 

82 key = slice(*key) 

83 elif isinstance(key, collections.abc.Mapping): 

84 key = { 

85 a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s) 

86 for a, s in key.items() 

87 } 

88 elif isinstance(key, Tensor): 

89 key = key._data 

90 

91 return self.__class__.from_xarray(self._data[key]) 

92 

93 def __setitem__( 

94 self, 

95 key: Union[PerAxis[Union[SliceInfo, slice]], Tensor, xr.DataArray], 

96 value: Union[Tensor, xr.DataArray, float, int], 

97 ) -> None: 

98 if isinstance(key, Tensor): 

99 key = key._data 

100 elif isinstance(key, xr.DataArray): 

101 pass 

102 else: 

103 key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()} 

104 

105 if isinstance(value, Tensor): 

106 value = value._data 

107 

108 self._data[key] = value 

109 

    def __len__(self) -> int:
        """Return the length of the first (outermost) axis."""
        return len(self.data)

112 

113 def _iter(self: Any) -> Iterator[Any]: 

114 for n in range(len(self)): 

115 yield self[n] 

116 

117 def __iter__(self: Any) -> Iterator[Any]: 

118 if self.ndim == 0: 

119 raise TypeError("iteration over a 0-d array") 

120 return self._iter() 

121 

122 def _binary_op( 

123 self, 

124 other: _Compatible, 

125 f: Callable[[Any, Any], Any], 

126 reflexive: bool = False, 

127 ) -> Self: 

128 data = self._data._binary_op( # pyright: ignore[reportPrivateUsage] 

129 (other._data if isinstance(other, Tensor) else other), 

130 f, 

131 reflexive, 

132 ) 

133 return self.__class__.from_xarray(data) 

134 

135 def _inplace_binary_op( 

136 self, 

137 other: _Compatible, 

138 f: Callable[[Any, Any], Any], 

139 ) -> Self: 

140 _ = self._data._inplace_binary_op( # pyright: ignore[reportPrivateUsage] 

141 ( 

142 other_d 

143 if (other_d := getattr(other, "data")) is not None 

144 and isinstance( 

145 other_d, 

146 xr.DataArray, 

147 ) 

148 else other 

149 ), 

150 f, 

151 ) 

152 return self 

153 

154 def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self: 

155 data = self._data._unary_op( # pyright: ignore[reportPrivateUsage] 

156 f, *args, **kwargs 

157 ) 

158 return self.__class__.from_xarray(data) 

159 

160 @classmethod 

161 def from_xarray(cls, data_array: xr.DataArray) -> Self: 

162 """create a `Tensor` from an xarray data array 

163 

164 note for internal use: this factory method is round-trip save 

165 for any `Tensor`'s `data` property (an xarray.DataArray). 

166 """ 

167 return cls( 

168 array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims) 

169 ) 

170 

171 @classmethod 

172 def from_numpy( 

173 cls, 

174 array: NDArray[Any], 

175 *, 

176 dims: Optional[Union[AxisLike, Sequence[AxisLike]]], 

177 ) -> Tensor: 

178 """create a `Tensor` from a numpy array 

179 

180 Args: 

181 array: the nd numpy array 

182 dims: A description of the array's axes, 

183 if None axes are guessed (which might fail and raise a ValueError.) 

184 

185 Raises: 

186 ValueError: if `dims` is None and dims guessing fails. 

187 """ 

188 

189 if dims is None: 

190 return cls._interprete_array_wo_known_axes(array) 

191 elif isinstance(dims, collections.abc.Sequence): 

192 dim_seq = list(dims) 

193 else: 

194 dim_seq = [dims] 

195 

196 axis_infos = [AxisInfo.create(a) for a in dim_seq] 

197 original_shape = tuple(array.shape) 

198 

199 successful_view = _get_array_view(array, axis_infos) 

200 if successful_view is None: 

201 raise ValueError( 

202 f"Array shape {original_shape} does not map to axes {dims}" 

203 ) 

204 

205 return Tensor(successful_view, dims=tuple(a.id for a in axis_infos)) 

206 

    @property
    def data(self):
        """The wrapped xr.DataArray holding this tensor's values."""
        return self._data

210 

    @property
    def dims(self):  # TODO: rename to `axes`?
        """Tuple of dimension names associated with this tensor."""
        return cast(Tuple[AxisId, ...], self._data.dims)

215 

    @property
    def dtype(self) -> DTypeStr:
        """Data type string of the underlying array.

        Must be one of the supported `DTypeStr` values (checked via assert).
        """
        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
        assert dt in get_args(DTypeStr)
        return dt  # pyright: ignore[reportReturnType]

221 

    @property
    def ndim(self):
        """Number of tensor dimensions."""
        return self._data.ndim

226 

    @property
    def shape(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

231 

    @property
    def shape_tuple(self):
        """Tuple of tensor axes lengths (alias for `shape`)."""
        return self._data.shape

236 

    @property
    def size(self):
        """Number of elements in the tensor.

        Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
        """
        return self._data.size

244 

    @property
    def sizes(self):
        """Ordered, immutable mapping from axis ids to axis lengths."""
        return cast(Mapping[AxisId, int], self.data.sizes)

249 

    @property
    def tagged_shape(self):
        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
        return self.sizes

254 

    def to_numpy(self) -> NDArray[Any]:
        """Return the data of this tensor as a numpy array."""
        return self.data.to_numpy()  # pyright: ignore[reportUnknownVariableType]

258 

259 def argmax(self) -> Mapping[AxisId, int]: 

260 ret = self._data.argmax(...) 

261 assert isinstance(ret, dict) 

262 return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()} 

263 

    def astype(self, dtype: DTypeStr, *, copy: bool = False):
        """Return tensor cast to `dtype`

        note: if dtype is already satisfied copy if `copy`"""
        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))

269 

    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
        """Return a tensor whose values are limited to [min, max].
        At least one of max or min must be given."""
        return self.__class__.from_xarray(self._data.clip(min, max))

274 

275 def crop_to( 

276 self, 

277 sizes: PerAxis[int], 

278 crop_where: Union[ 

279 CropWhere, 

280 PerAxis[CropWhere], 

281 ] = "left_and_right", 

282 ) -> Self: 

283 """crop to match `sizes`""" 

284 if isinstance(crop_where, str): 

285 crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims} 

286 else: 

287 crop_axis_where = crop_where 

288 

289 slices: Dict[AxisId, SliceInfo] = {} 

290 

291 for a, s_is in self.sizes.items(): 

292 if a not in sizes or sizes[a] == s_is: 

293 pass 

294 elif sizes[a] > s_is: 

295 logger.warning( 

296 "Cannot crop axis {} of size {} to larger size {}", 

297 a, 

298 s_is, 

299 sizes[a], 

300 ) 

301 elif a not in crop_axis_where: 

302 raise ValueError( 

303 f"Don't know where to crop axis {a}, `crop_where`={crop_where}" 

304 ) 

305 else: 

306 crop_this_axis_where = crop_axis_where[a] 

307 if crop_this_axis_where == "left": 

308 slices[a] = SliceInfo(s_is - sizes[a], s_is) 

309 elif crop_this_axis_where == "right": 

310 slices[a] = SliceInfo(0, sizes[a]) 

311 elif crop_this_axis_where == "left_and_right": 

312 slices[a] = SliceInfo( 

313 start := (s_is - sizes[a]) // 2, sizes[a] + start 

314 ) 

315 else: 

316 assert_never(crop_this_axis_where) 

317 

318 return self[slices] 

319 

    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
        """Return a tensor with additional axes `dims` (with sizes if given as a mapping)."""
        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))

322 

323 def item( 

324 self, 

325 key: Union[ 

326 None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]] 

327 ] = None, 

328 ): 

329 """Copy a tensor element to a standard Python scalar and return it.""" 

330 if key is None: 

331 ret = self._data.item() 

332 else: 

333 ret = self[key]._data.item() 

334 

335 assert isinstance(ret, (bool, float, int)) 

336 return ret 

337 

    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying mean along some dimension(s)."""
        return self.__class__.from_xarray(self._data.mean(dim=dim))

340 

341 def pad( 

342 self, 

343 pad_width: PerAxis[PadWidthLike], 

344 mode: PadMode = "symmetric", 

345 ) -> Self: 

346 pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()} 

347 return self.__class__.from_xarray( 

348 self._data.pad(pad_width=pad_width, mode=mode) 

349 ) 

350 

351 def pad_to( 

352 self, 

353 sizes: PerAxis[int], 

354 pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right", 

355 mode: PadMode = "symmetric", 

356 ) -> Self: 

357 """pad `tensor` to match `sizes`""" 

358 if isinstance(pad_where, str): 

359 pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims} 

360 else: 

361 pad_axis_where = pad_where 

362 

363 pad_width: Dict[AxisId, PadWidth] = {} 

364 for a, s_is in self.sizes.items(): 

365 if a not in sizes or sizes[a] == s_is: 

366 pad_width[a] = PadWidth(0, 0) 

367 elif s_is > sizes[a]: 

368 pad_width[a] = PadWidth(0, 0) 

369 logger.warning( 

370 "Cannot pad axis {} of size {} to smaller size {}", 

371 a, 

372 s_is, 

373 sizes[a], 

374 ) 

375 elif a not in pad_axis_where: 

376 raise ValueError( 

377 f"Don't know where to pad axis {a}, `pad_where`={pad_where}" 

378 ) 

379 else: 

380 pad_this_axis_where = pad_axis_where[a] 

381 d = sizes[a] - s_is 

382 if pad_this_axis_where == "left": 

383 pad_width[a] = PadWidth(d, 0) 

384 elif pad_this_axis_where == "right": 

385 pad_width[a] = PadWidth(0, d) 

386 elif pad_this_axis_where == "left_and_right": 

387 pad_width[a] = PadWidth(left := d // 2, d - left) 

388 else: 

389 assert_never(pad_this_axis_where) 

390 

391 return self.pad(pad_width, mode) 

392 

393 def quantile( 

394 self, 

395 q: Union[float, Sequence[float]], 

396 dim: Optional[Union[AxisId, Sequence[AxisId]]] = None, 

397 method: QuantileMethod = "linear", 

398 ) -> Self: 

399 assert ( 

400 isinstance(q, (float, int)) 

401 and q >= 0.0 

402 or not isinstance(q, (float, int)) 

403 and all(qq >= 0.0 for qq in q) 

404 ) 

405 assert ( 

406 isinstance(q, (float, int)) 

407 and q <= 1.0 

408 or not isinstance(q, (float, int)) 

409 and all(qq <= 1.0 for qq in q) 

410 ) 

411 assert dim is None or ( 

412 (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim) 

413 ) 

414 return self.__class__.from_xarray( 

415 self._data.quantile(q, dim=dim, method=method) 

416 ) 

417 

418 def resize_to( 

419 self, 

420 sizes: PerAxis[int], 

421 *, 

422 pad_where: Union[ 

423 PadWhere, 

424 PerAxis[PadWhere], 

425 ] = "left_and_right", 

426 crop_where: Union[ 

427 CropWhere, 

428 PerAxis[CropWhere], 

429 ] = "left_and_right", 

430 pad_mode: PadMode = "symmetric", 

431 ): 

432 """return cropped/padded tensor with `sizes`""" 

433 crop_to_sizes: Dict[AxisId, int] = {} 

434 pad_to_sizes: Dict[AxisId, int] = {} 

435 new_axes = dict(sizes) 

436 for a, s_is in self.sizes.items(): 

437 a = AxisId(str(a)) 

438 _ = new_axes.pop(a, None) 

439 if a not in sizes or sizes[a] == s_is: 

440 pass 

441 elif s_is > sizes[a]: 

442 crop_to_sizes[a] = sizes[a] 

443 else: 

444 pad_to_sizes[a] = sizes[a] 

445 

446 tensor = self 

447 if crop_to_sizes: 

448 tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where) 

449 

450 if pad_to_sizes: 

451 tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode) 

452 

453 if new_axes: 

454 tensor = tensor.expand_dims(new_axes) 

455 

456 return tensor 

457 

    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying standard deviation along some dimension(s)."""
        return self.__class__.from_xarray(self._data.std(dim=dim))

460 

    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying sum along some dimension(s)."""
        return self.__class__.from_xarray(self._data.sum(dim=dim))

464 

465 def transpose( 

466 self, 

467 axes: Sequence[AxisId], 

468 ) -> Self: 

469 """return a transposed tensor 

470 

471 Args: 

472 axes: the desired tensor axes 

473 """ 

474 # expand missing tensor axes 

475 missing_axes = tuple(a for a in axes if a not in self.dims) 

476 array = self._data 

477 if missing_axes: 

478 array = array.expand_dims(missing_axes) 

479 

480 # transpose to the correct axis order 

481 return self.__class__.from_xarray(array.transpose(*axes)) 

482 

    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying variance along some dimension(s)."""
        return self.__class__.from_xarray(self._data.var(dim=dim))

485 

    @classmethod
    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
        """Guess an axis mapping for `array` from its rank (and shape).

        Heuristics (NOTE(review): channel position for small 3d arrays is
        assumed to be the first axis — confirm against callers):
        - 2d: (y, x)
        - 3d with any axis length <= 3: (channel, y, x)
        - 3d otherwise: (z, y, x)
        - 4d: (channel, z, y, x)
        - 5d: (batch, channel, z, y, x)

        Raises:
            ValueError: for any other rank.
        """
        ndim = array.ndim
        if ndim == 2:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
            )
        elif ndim == 3 and any(s <= 3 for s in array.shape):
            # a small axis suggests channels; first axis is taken as channel
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 3:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 4:
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
            )
        elif ndim == 5:
            current_axes = (
                v0_5.BatchAxis(),
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
            )
        else:
            raise ValueError(f"Could not guess an axis mapping for {array.shape}")

        return cls(array, dims=tuple(a.id for a in current_axes))

537 

538 

539def _add_singletons(arr: NDArray[Any], axis_infos: Sequence[AxisInfo]): 

540 if len(arr.shape) > len(axis_infos): 

541 # remove singletons 

542 for i, s in enumerate(arr.shape): 

543 if s == 1: 

544 arr = np.take(arr, 0, axis=i) 

545 if len(arr.shape) == len(axis_infos): 

546 break 

547 

548 # add singletons if nececsary 

549 for i, a in enumerate(axis_infos): 

550 if len(arr.shape) >= len(axis_infos): 

551 break 

552 

553 if a.maybe_singleton: 

554 arr = np.expand_dims(arr, i) 

555 

556 return arr 

557 

558 

def _get_array_view(
    original_array: NDArray[Any], axis_infos: Sequence[AxisInfo]
) -> Optional[NDArray[Any]]:
    """Return a transposed (and singleton-adjusted) view of `original_array`
    compatible with `axis_infos`, or None if no axis permutation matches.

    A view matches if its rank equals `len(axis_infos)` and it has no
    singleton axis where the corresponding axis info forbids one.
    """
    perms = list(permutations(range(len(original_array.shape))))
    perms.insert(1, perms.pop())  # try A and A.T first

    for perm in perms:
        view = original_array.transpose(perm)
        view = _add_singletons(view, axis_infos)
        if len(view.shape) != len(axis_infos):
            # the achievable rank does not depend on the permutation,
            # so no other permutation can match either
            return None

        for s, a in zip(view.shape, axis_infos):
            if s == 1 and not a.maybe_singleton:
                break  # singleton where none is allowed -> try next permutation
        else:
            return view

    return None

577 return None