Coverage for src/bioimageio/core/tensor.py: 82%

243 statements  

coverage.py v7.10.7, created at 2025-10-14 08:35 +0000

from __future__ import annotations

import collections.abc
from itertools import permutations
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterator,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Union,
    cast,
    get_args,
)

import numpy as np
import xarray as xr
from bioimageio.spec.model import v0_5
from loguru import logger
from numpy.typing import DTypeLike, NDArray
from typing_extensions import Self, assert_never

from ._magic_tensor_ops import MagicTensorOpsMixin
from .axis import AxisId, AxisInfo, AxisLike, PerAxis
from .common import (
    CropWhere,
    DTypeStr,
    PadMode,
    PadWhere,
    PadWidth,
    PadWidthLike,
    SliceInfo,
)

if TYPE_CHECKING:
    from numpy.typing import ArrayLike, NDArray


_ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"]  # TODO: add "DaskArray"


# TODO: complete docstrings
# TODO: in the long run---with improved typing in xarray---we should probably replace `Tensor` with xr.DataArray
class Tensor(MagicTensorOpsMixin):
    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
    and improved type annotations."""

    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]

    def __init__(
        self,
        array: NDArray[Any],
        dims: Sequence[Union[AxisId, AxisLike]],
    ) -> None:
        super().__init__()
        axes = tuple(
            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
        )
        self._data = xr.DataArray(array, dims=axes)

    def __array__(self, dtype: DTypeLike = None):
        return np.asarray(self._data, dtype=dtype)

    def __getitem__(
        self,
        key: Union[
            SliceInfo,
            slice,
            int,
            PerAxis[Union[SliceInfo, slice, int]],
            Tensor,
            xr.DataArray,
        ],
    ) -> Self:
        if isinstance(key, SliceInfo):
            key = slice(*key)
        elif isinstance(key, collections.abc.Mapping):
            key = {
                a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
                for a, s in key.items()
            }
        elif isinstance(key, Tensor):
            key = key._data

        return self.__class__.from_xarray(self._data[key])

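    # Editor's note: a minimal, illustrative indexing sketch (not part of the
    # original module), assuming a small 2-d tensor with axes "y" and "x".
    # Keys may be plain ints/slices or a per-axis mapping:
    #
    #   >>> t = Tensor(np.zeros((4, 6)), dims=(AxisId("y"), AxisId("x")))
    #   >>> t[{AxisId("y"): slice(0, 2)}].shape
    #   (2, 6)
    #   >>> t[{AxisId("x"): 3}].shape  # integer index drops the "x" axis
    #   (4,)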

    def __setitem__(
        self,
        key: Union[PerAxis[Union[SliceInfo, slice]], Tensor, xr.DataArray],
        value: Union[Tensor, xr.DataArray, float, int],
    ) -> None:
        if isinstance(key, Tensor):
            key = key._data
        elif isinstance(key, xr.DataArray):
            pass
        else:
            key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}

        if isinstance(value, Tensor):
            value = value._data

        self._data[key] = value

    def __len__(self) -> int:
        return len(self.data)

    def _iter(self: Any) -> Iterator[Any]:
        for n in range(len(self)):
            yield self[n]

    def __iter__(self: Any) -> Iterator[Any]:
        if self.ndim == 0:
            raise TypeError("iteration over a 0-d array")
        return self._iter()

    def _binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
        reflexive: bool = False,
    ) -> Self:
        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
            (other._data if isinstance(other, Tensor) else other),
            f,
            reflexive,
        )
        return self.__class__.from_xarray(data)

    def _inplace_binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
    ) -> Self:
        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
            (
                other_d
                if (other_d := getattr(other, "data", None)) is not None
                and isinstance(
                    other_d,
                    xr.DataArray,
                )
                else other
            ),
            f,
        )
        return self

    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
            f, *args, **kwargs
        )
        return self.__class__.from_xarray(data)

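    # Editor's note: an illustrative sketch (not part of the original module) of
    # how the hooks above are exercised. The arithmetic operators provided by
    # `MagicTensorOpsMixin` funnel into `_binary_op`/`_unary_op`, and xarray
    # aligns operands by axis id rather than by position:
    #
    #   >>> a = Tensor(np.ones((2, 3)), dims=(AxisId("y"), AxisId("x")))
    #   >>> b = Tensor(np.full((3,), 2.0), dims=(AxisId("x"),))
    #   >>> (a + b).shape  # broadcast over the shared "x" axis
    #   (2, 3)
    #   >>> (a + b).mean().item()
    #   3.0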

    @classmethod
    def from_xarray(cls, data_array: xr.DataArray) -> Self:
        """create a `Tensor` from an xarray data array

        note for internal use: this factory method is round-trip safe
        for any `Tensor`'s `data` property (an xarray.DataArray).
        """
        return cls(
            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
        )

    @classmethod
    def from_numpy(
        cls,
        array: NDArray[Any],
        *,
        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
    ) -> Tensor:
        """create a `Tensor` from a numpy array

        Args:
            array: the n-dimensional numpy array
            dims: a description of the array's axes;
                if None, the axes are guessed (which might fail and raise a ValueError).

        Raises:
            ValueError: if `dims` is None and axis guessing fails.
        """

        if dims is None:
            return cls._interprete_array_wo_known_axes(array)
        elif isinstance(dims, collections.abc.Sequence):
            dim_seq = list(dims)
        else:
            dim_seq = [dims]

        axis_infos = [AxisInfo.create(a) for a in dim_seq]
        original_shape = tuple(array.shape)

        successful_view = _get_array_view(array, axis_infos)
        if successful_view is None:
            raise ValueError(
                f"Array shape {original_shape} does not map to axes {dims}"
            )

        return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))

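    # Editor's note: an illustrative sketch (not part of the original module),
    # assuming plain axis ids like AxisId("y") are accepted as AxisLike. With
    # explicit `dims` the array is permuted (and singletons added/removed) to
    # match; with `dims=None` the axes are guessed from the array's rank via
    # `_interprete_array_wo_known_axes`:
    #
    #   >>> t = Tensor.from_numpy(np.zeros((5, 4)), dims=(AxisId("y"), AxisId("x")))
    #   >>> t.shape
    #   (5, 4)
    #   >>> Tensor.from_numpy(np.zeros((3, 5, 4)), dims=None).shape  # guessed as (channel, y, x)
    #   (3, 5, 4)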

    @property
    def data(self):
        return self._data

    @property
    def dims(self):  # TODO: rename to `axes`?
        """Tuple of dimension names associated with this tensor."""
        return cast(Tuple[AxisId, ...], self._data.dims)

    @property
    def dtype(self) -> DTypeStr:
        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
        assert dt in get_args(DTypeStr)
        return dt  # pyright: ignore[reportReturnType]

    @property
    def ndim(self):
        """Number of tensor dimensions."""
        return self._data.ndim

    @property
    def shape(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def shape_tuple(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def size(self):
        """Number of elements in the tensor.

        Equal to math.prod(tensor.shape), i.e., the product of the tensor's dimensions.
        """
        return self._data.size

    @property
    def sizes(self):
        """Ordered, immutable mapping from axis ids to axis lengths."""
        return cast(Mapping[AxisId, int], self.data.sizes)

    @property
    def tagged_shape(self):
        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
        return self.sizes

    def argmax(self) -> Mapping[AxisId, int]:
        ret = self._data.argmax(...)
        assert isinstance(ret, dict)
        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}

    def astype(self, dtype: DTypeStr, *, copy: bool = False):
        """Return tensor cast to `dtype`

        note: if `dtype` is already satisfied, a copy is only made if `copy` is True"""
        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))

    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
        """Return a tensor whose values are limited to [min, max].
        At least one of max or min must be given."""
        return self.__class__.from_xarray(self._data.clip(min, max))

    def crop_to(
        self,
        sizes: PerAxis[int],
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
    ) -> Self:
        """crop to match `sizes`"""
        if isinstance(crop_where, str):
            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
        else:
            crop_axis_where = crop_where

        slices: Dict[AxisId, SliceInfo] = {}

        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pass
            elif sizes[a] > s_is:
                logger.warning(
                    "Cannot crop axis {} of size {} to larger size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in crop_axis_where:
                raise ValueError(
                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
                )
            else:
                crop_this_axis_where = crop_axis_where[a]
                if crop_this_axis_where == "left":
                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
                elif crop_this_axis_where == "right":
                    slices[a] = SliceInfo(0, sizes[a])
                elif crop_this_axis_where == "left_and_right":
                    slices[a] = SliceInfo(
                        start := (s_is - sizes[a]) // 2, sizes[a] + start
                    )
                else:
                    assert_never(crop_this_axis_where)

        return self[slices]

    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))

    def item(
        self,
        key: Union[
            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
        ] = None,
    ):
        """Copy a tensor element to a standard Python scalar and return it."""
        if key is None:
            ret = self._data.item()
        else:
            ret = self[key]._data.item()

        assert isinstance(ret, (bool, float, int))
        return ret

    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.mean(dim=dim))

    def pad(
        self,
        pad_width: PerAxis[PadWidthLike],
        mode: PadMode = "symmetric",
    ) -> Self:
        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
        return self.__class__.from_xarray(
            self._data.pad(pad_width=pad_width, mode=mode)
        )

    def pad_to(
        self,
        sizes: PerAxis[int],
        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
        mode: PadMode = "symmetric",
    ) -> Self:
        """pad `tensor` to match `sizes`"""
        if isinstance(pad_where, str):
            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
        else:
            pad_axis_where = pad_where

        pad_width: Dict[AxisId, PadWidth] = {}
        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pad_width[a] = PadWidth(0, 0)
            elif s_is > sizes[a]:
                pad_width[a] = PadWidth(0, 0)
                logger.warning(
                    "Cannot pad axis {} of size {} to smaller size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in pad_axis_where:
                raise ValueError(
                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
                )
            else:
                pad_this_axis_where = pad_axis_where[a]
                d = sizes[a] - s_is
                if pad_this_axis_where == "left":
                    pad_width[a] = PadWidth(d, 0)
                elif pad_this_axis_where == "right":
                    pad_width[a] = PadWidth(0, d)
                elif pad_this_axis_where == "left_and_right":
                    pad_width[a] = PadWidth(left := d // 2, d - left)
                else:
                    assert_never(pad_this_axis_where)

        return self.pad(pad_width, mode)

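    # Editor's note: an illustrative sketch (not part of the original module).
    # `pad_to` grows axes to the requested sizes; `pad_where` controls which
    # side receives the padding (globally or per axis):
    #
    #   >>> t = Tensor(np.ones((2, 3)), dims=(AxisId("y"), AxisId("x")))
    #   >>> t.pad_to({AxisId("x"): 5}, pad_where="right").shape
    #   (2, 5)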

    def quantile(
        self,
        q: Union[float, Sequence[float]],
        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
    ) -> Self:
        assert (
            isinstance(q, (float, int))
            and q >= 0.0
            or not isinstance(q, (float, int))
            and all(qq >= 0.0 for qq in q)
        )
        assert (
            isinstance(q, (float, int))
            and q <= 1.0
            or not isinstance(q, (float, int))
            and all(qq <= 1.0 for qq in q)
        )
        assert dim is None or (
            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
        )
        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))

    def resize_to(
        self,
        sizes: PerAxis[int],
        *,
        pad_where: Union[
            PadWhere,
            PerAxis[PadWhere],
        ] = "left_and_right",
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
        pad_mode: PadMode = "symmetric",
    ):
        """return cropped/padded tensor with `sizes`"""
        crop_to_sizes: Dict[AxisId, int] = {}
        pad_to_sizes: Dict[AxisId, int] = {}
        new_axes = dict(sizes)
        for a, s_is in self.sizes.items():
            a = AxisId(str(a))
            _ = new_axes.pop(a, None)
            if a not in sizes or sizes[a] == s_is:
                pass
            elif s_is > sizes[a]:
                crop_to_sizes[a] = sizes[a]
            else:
                pad_to_sizes[a] = sizes[a]

        tensor = self
        if crop_to_sizes:
            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)

        if pad_to_sizes:
            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)

        if new_axes:
            tensor = tensor.expand_dims(new_axes)

        return tensor

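    # Editor's note: an illustrative sketch (not part of the original module).
    # `resize_to` crops axes that are too large, pads axes that are too small,
    # and expands axes that are missing entirely:
    #
    #   >>> t = Tensor(np.zeros((8, 5)), dims=(AxisId("y"), AxisId("x")))
    #   >>> t.resize_to({AxisId("y"): 4, AxisId("x"): 7}).shape
    #   (4, 7)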

    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.std(dim=dim))

    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying sum along some dimension(s)."""
        return self.__class__.from_xarray(self._data.sum(dim=dim))

    def transpose(
        self,
        axes: Sequence[AxisId],
    ) -> Self:
        """return a transposed tensor

        Args:
            axes: the desired tensor axes
        """
        # expand missing tensor axes
        missing_axes = tuple(a for a in axes if a not in self.dims)
        array = self._data
        if missing_axes:
            array = array.expand_dims(missing_axes)

        # transpose to the correct axis order
        return self.__class__.from_xarray(array.transpose(*axes))

    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.var(dim=dim))

    @classmethod
    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
        ndim = array.ndim
        if ndim == 2:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
            )
        elif ndim == 3 and any(s <= 3 for s in array.shape):
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 3:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 4:
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
            )
        elif ndim == 5:
            current_axes = (
                v0_5.BatchAxis(),
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
            )
        else:
            raise ValueError(f"Could not guess an axis mapping for {array.shape}")

        return cls(array, dims=tuple(a.id for a in current_axes))


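# Editor's note (illustrative, not part of the original module): the axis
# guessing in `Tensor._interprete_array_wo_known_axes` maps array rank to axis
# ids as follows:
#
#   ndim 2                   -> (y, x)
#   ndim 3, any size <= 3    -> (channel, y, x)
#   ndim 3, otherwise        -> (z, y, x)
#   ndim 4                   -> (channel, z, y, x)
#   ndim 5                   -> (batch, channel, z, y, x)
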

def _add_singletons(arr: NDArray[Any], axis_infos: Sequence[AxisInfo]):
    if len(arr.shape) > len(axis_infos):
        # remove singletons
        for i, s in enumerate(arr.shape):
            if s == 1:
                arr = np.take(arr, 0, axis=i)
                if len(arr.shape) == len(axis_infos):
                    break

    # add singletons if necessary
    for i, a in enumerate(axis_infos):
        if len(arr.shape) >= len(axis_infos):
            break

        if a.maybe_singleton:
            arr = np.expand_dims(arr, i)

    return arr


def _get_array_view(
    original_array: NDArray[Any], axis_infos: Sequence[AxisInfo]
) -> Optional[NDArray[Any]]:
    perms = list(permutations(range(len(original_array.shape))))
    perms.insert(1, perms.pop())  # try A and A.T first

    for perm in perms:
        view = original_array.transpose(perm)
        view = _add_singletons(view, axis_infos)
        if len(view.shape) != len(axis_infos):
            return None

        for s, a in zip(view.shape, axis_infos):
            if s == 1 and not a.maybe_singleton:
                break
        else:
            return view

    return None
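
# Editor's note: an illustrative sketch (not part of the original module) of how
# `_get_array_view` permutes an array to match requested axes, assuming
# `AxisInfo.create` accepts plain `AxisId`s:
#
#   >>> infos = [AxisInfo.create(AxisId(a)) for a in ("y", "x")]
#   >>> _get_array_view(np.zeros((4, 7)), infos).shape
#   (4, 7)
#   >>> _get_array_view(np.zeros((2, 3, 4)), infos) is None  # rank cannot be matched
#   True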