Coverage for src/bioimageio/core/tensor.py: 82%

242 statements  

coverage.py v7.10.7, created at 2025-09-22 09:21 +0000

from __future__ import annotations

import collections.abc
from itertools import permutations
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterator,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Union,
    cast,
    get_args,
)

import numpy as np
import xarray as xr
from loguru import logger
from numpy.typing import DTypeLike, NDArray
from typing_extensions import Self, assert_never

from bioimageio.spec.model import v0_5

from ._magic_tensor_ops import MagicTensorOpsMixin
from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis
from .common import (
    CropWhere,
    DTypeStr,
    PadMode,
    PadWhere,
    PadWidth,
    PadWidthLike,
    SliceInfo,
)

if TYPE_CHECKING:
    from numpy.typing import ArrayLike, NDArray


_ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"]  # TODO: add "DaskArray"


# TODO: complete docstrings
# TODO: in the long run---with improved typing in xarray---we should probably replace `Tensor` with xr.DataArray
class Tensor(MagicTensorOpsMixin):
    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
    and improved type annotations."""

    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]

    def __init__(
        self,
        array: NDArray[Any],
        dims: Sequence[Union[AxisId, AxisLike]],
    ) -> None:
        super().__init__()
        axes = tuple(
            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
        )
        self._data = xr.DataArray(array, dims=axes)
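
    # Illustrative usage sketch (not part of the original module; variable names are
    # hypothetical): constructing a `Tensor` from a numpy array with explicit axis ids.
    #
    #     arr = np.zeros((512, 512), dtype="float32")
    #     t = Tensor(arr, dims=(AxisId("y"), AxisId("x")))
    #     t.dims          # tuple of axis ids, here ("y", "x")
    #     t.tagged_shape  # ordered mapping of axis id to length, here {"y": 512, "x": 512}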

    def __array__(self, dtype: DTypeLike = None):
        return np.asarray(self._data, dtype=dtype)

    def __getitem__(
        self,
        key: Union[
            SliceInfo,
            slice,
            int,
            PerAxis[Union[SliceInfo, slice, int]],
            Tensor,
            xr.DataArray,
        ],
    ) -> Self:
        if isinstance(key, SliceInfo):
            key = slice(*key)
        elif isinstance(key, collections.abc.Mapping):
            key = {
                a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
                for a, s in key.items()
            }
        elif isinstance(key, Tensor):
            key = key._data

        return self.__class__.from_xarray(self._data[key])

    def __setitem__(
        self,
        key: Union[PerAxis[Union[SliceInfo, slice]], Tensor, xr.DataArray],
        value: Union[Tensor, xr.DataArray, float, int],
    ) -> None:
        if isinstance(key, Tensor):
            key = key._data
        elif isinstance(key, xr.DataArray):
            pass
        else:
            key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}

        if isinstance(value, Tensor):
            value = value._data

        self._data[key] = value
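
    # Illustrative indexing sketch (not part of the original module): both
    # `__getitem__` and `__setitem__` accept a per-axis mapping of
    # `SliceInfo`/`slice`/`int` keyed by axis id, as well as a plain slice or a
    # boolean `Tensor`/`xr.DataArray`. With a hypothetical tensor `t`:
    #
    #     crop = t[{AxisId("y"): SliceInfo(0, 256), AxisId("x"): slice(0, 256)}]
    #     t[{AxisId("y"): slice(0, 1)}] = 0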

    def __len__(self) -> int:
        return len(self.data)

    def _iter(self: Any) -> Iterator[Any]:
        for n in range(len(self)):
            yield self[n]

    def __iter__(self: Any) -> Iterator[Any]:
        if self.ndim == 0:
            raise TypeError("iteration over a 0-d array")
        return self._iter()

    def _binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
        reflexive: bool = False,
    ) -> Self:
        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
            (other._data if isinstance(other, Tensor) else other),
            f,
            reflexive,
        )
        return self.__class__.from_xarray(data)

    def _inplace_binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
    ) -> Self:
        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
            (
                other_d
                if (other_d := getattr(other, "data")) is not None
                and isinstance(
                    other_d,
                    xr.DataArray,
                )
                else other
            ),
            f,
        )
        return self

    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
            f, *args, **kwargs
        )
        return self.__class__.from_xarray(data)

    @classmethod
    def from_xarray(cls, data_array: xr.DataArray) -> Self:
        """create a `Tensor` from an xarray data array

        note for internal use: this factory method is round-trip safe
        for any `Tensor`'s `data` property (an xarray.DataArray).
        """
        return cls(
            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
        )

    @classmethod
    def from_numpy(
        cls,
        array: NDArray[Any],
        *,
        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
    ) -> Tensor:
        """create a `Tensor` from a numpy array

        Args:
            array: the nd numpy array
            dims: a description of the array's axes;
                if None, axes are guessed (which might fail and raise a ValueError).

        Raises:
            ValueError: if `dims` is None and axis guessing fails.
        """

        if dims is None:
            return cls._interprete_array_wo_known_axes(array)
        elif isinstance(dims, (str, Axis, v0_5.AxisBase)):
            dims = [dims]

        axis_infos = [AxisInfo.create(a) for a in dims]
        original_shape = tuple(array.shape)

        successful_view = _get_array_view(array, axis_infos)
        if successful_view is None:
            raise ValueError(
                f"Array shape {original_shape} does not map to axes {dims}"
            )

        return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))
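
    # Illustrative sketch (not part of the original module): with `dims=None`,
    # `from_numpy` guesses axes via `_interprete_array_wo_known_axes` (e.g. a 2d
    # array becomes y/x); with explicit axes it searches permutations and singleton
    # dimensions via `_get_array_view`. Assumes plain strings are accepted as
    # `AxisLike`:
    #
    #     t_guessed = Tensor.from_numpy(np.zeros((256, 256)), dims=None)
    #     t_explicit = Tensor.from_numpy(np.zeros((3, 256, 256)), dims=("c", "y", "x"))
    #     # raises ValueError if the array shape cannot be mapped to the given axes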

    @property
    def data(self):
        return self._data

    @property
    def dims(self):  # TODO: rename to `axes`?
        """Tuple of dimension names associated with this tensor."""
        return cast(Tuple[AxisId, ...], self._data.dims)

    @property
    def dtype(self) -> DTypeStr:
        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
        assert dt in get_args(DTypeStr)
        return dt  # pyright: ignore[reportReturnType]

    @property
    def ndim(self):
        """Number of tensor dimensions."""
        return self._data.ndim

    @property
    def shape(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def shape_tuple(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def size(self):
        """Number of elements in the tensor.

        Equal to math.prod(tensor.shape), i.e., the product of the tensor's dimensions.
        """
        return self._data.size

    @property
    def sizes(self):
        """Ordered, immutable mapping from axis ids to axis lengths."""
        return cast(Mapping[AxisId, int], self.data.sizes)

    @property
    def tagged_shape(self):
        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
        return self.sizes

    def argmax(self) -> Mapping[AxisId, int]:
        ret = self._data.argmax(...)
        assert isinstance(ret, dict)
        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}

    def astype(self, dtype: DTypeStr, *, copy: bool = False):
        """Return tensor cast to `dtype`

        note: if `dtype` is already satisfied, a copy is only made if `copy` is True"""
        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))

    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
        """Return a tensor whose values are limited to [min, max].
        At least one of max or min must be given."""
        return self.__class__.from_xarray(self._data.clip(min, max))

    def crop_to(
        self,
        sizes: PerAxis[int],
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
    ) -> Self:
        """crop to match `sizes`"""
        if isinstance(crop_where, str):
            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
        else:
            crop_axis_where = crop_where

        slices: Dict[AxisId, SliceInfo] = {}

        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pass
            elif sizes[a] > s_is:
                logger.warning(
                    "Cannot crop axis {} of size {} to larger size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in crop_axis_where:
                raise ValueError(
                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
                )
            else:
                crop_this_axis_where = crop_axis_where[a]
                if crop_this_axis_where == "left":
                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
                elif crop_this_axis_where == "right":
                    slices[a] = SliceInfo(0, sizes[a])
                elif crop_this_axis_where == "left_and_right":
                    slices[a] = SliceInfo(
                        start := (s_is - sizes[a]) // 2, sizes[a] + start
                    )
                else:
                    assert_never(crop_this_axis_where)

        return self[slices]
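
    # Illustrative sketch (not part of the original module): cropping to per-axis
    # target sizes; axes missing from `sizes` are left untouched and a target larger
    # than the current size only logs a warning. With a hypothetical tensor `t`:
    #
    #     small = t.crop_to({AxisId("y"): 128, AxisId("x"): 128}, crop_where="left")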

    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))

    def item(
        self,
        key: Union[
            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
        ] = None,
    ):
        """Copy a tensor element to a standard Python scalar and return it."""
        if key is None:
            ret = self._data.item()
        else:
            ret = self[key]._data.item()

        assert isinstance(ret, (bool, float, int))
        return ret

    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.mean(dim=dim))

    def pad(
        self,
        pad_width: PerAxis[PadWidthLike],
        mode: PadMode = "symmetric",
    ) -> Self:
        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
        return self.__class__.from_xarray(
            self._data.pad(pad_width=pad_width, mode=mode)
        )

    def pad_to(
        self,
        sizes: PerAxis[int],
        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
        mode: PadMode = "symmetric",
    ) -> Self:
        """pad `tensor` to match `sizes`"""
        if isinstance(pad_where, str):
            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
        else:
            pad_axis_where = pad_where

        pad_width: Dict[AxisId, PadWidth] = {}
        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pad_width[a] = PadWidth(0, 0)
            elif s_is > sizes[a]:
                pad_width[a] = PadWidth(0, 0)
                logger.warning(
                    "Cannot pad axis {} of size {} to smaller size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in pad_axis_where:
                raise ValueError(
                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
                )
            else:
                pad_this_axis_where = pad_axis_where[a]
                d = sizes[a] - s_is
                if pad_this_axis_where == "left":
                    pad_width[a] = PadWidth(d, 0)
                elif pad_this_axis_where == "right":
                    pad_width[a] = PadWidth(0, d)
                elif pad_this_axis_where == "left_and_right":
                    pad_width[a] = PadWidth(left := d // 2, d - left)
                else:
                    assert_never(pad_this_axis_where)

        return self.pad(pad_width, mode)

    def quantile(
        self,
        q: Union[float, Sequence[float]],
        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
    ) -> Self:
        assert (
            isinstance(q, (float, int))
            and q >= 0.0
            or not isinstance(q, (float, int))
            and all(qq >= 0.0 for qq in q)
        )
        assert (
            isinstance(q, (float, int))
            and q <= 1.0
            or not isinstance(q, (float, int))
            and all(qq <= 1.0 for qq in q)
        )
        assert dim is None or (
            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
        )
        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))

    def resize_to(
        self,
        sizes: PerAxis[int],
        *,
        pad_where: Union[
            PadWhere,
            PerAxis[PadWhere],
        ] = "left_and_right",
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
        pad_mode: PadMode = "symmetric",
    ):
        """return cropped/padded tensor with `sizes`"""
        crop_to_sizes: Dict[AxisId, int] = {}
        pad_to_sizes: Dict[AxisId, int] = {}
        new_axes = dict(sizes)
        for a, s_is in self.sizes.items():
            a = AxisId(str(a))
            _ = new_axes.pop(a, None)
            if a not in sizes or sizes[a] == s_is:
                pass
            elif s_is > sizes[a]:
                crop_to_sizes[a] = sizes[a]
            else:
                pad_to_sizes[a] = sizes[a]

        tensor = self
        if crop_to_sizes:
            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)

        if pad_to_sizes:
            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)

        if new_axes:
            tensor = tensor.expand_dims(new_axes)

        return tensor
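
    # Illustrative sketch (not part of the original module): `resize_to` combines
    # `crop_to` (axes that are too large), `pad_to` (axes that are too small) and
    # `expand_dims` (axes not present yet). With a hypothetical tensor `t`:
    #
    #     fitted = t.resize_to(
    #         {AxisId("y"): 256, AxisId("x"): 256, AxisId("c"): 1},
    #         pad_mode="symmetric",
    #     )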

    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.std(dim=dim))

    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying sum along some dimension(s)."""
        return self.__class__.from_xarray(self._data.sum(dim=dim))

    def transpose(
        self,
        axes: Sequence[AxisId],
    ) -> Self:
        """return a transposed tensor

        Args:
            axes: the desired tensor axes
        """
        # expand missing tensor axes
        missing_axes = tuple(a for a in axes if a not in self.dims)
        array = self._data
        if missing_axes:
            array = array.expand_dims(missing_axes)

        # transpose to the correct axis order
        return self.__class__.from_xarray(array.transpose(*axes))
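
    # Illustrative sketch (not part of the original module): transposing to a desired
    # axis order; requested axes that are missing are first added as singleton
    # dimensions. With a hypothetical tensor `t`:
    #
    #     bcyx = t.transpose((AxisId("batch"), AxisId("c"), AxisId("y"), AxisId("x")))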

    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.var(dim=dim))

    @classmethod
    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
        ndim = array.ndim
        if ndim == 2:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
            )
        elif ndim == 3 and any(s <= 3 for s in array.shape):
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 3:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 4:
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
            )
        elif ndim == 5:
            current_axes = (
                v0_5.BatchAxis(),
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
            )
        else:
            raise ValueError(f"Could not guess an axis mapping for {array.shape}")

        return cls(array, dims=tuple(a.id for a in current_axes))


def _add_singletons(arr: NDArray[Any], axis_infos: Sequence[AxisInfo]):
    if len(arr.shape) > len(axis_infos):
        # remove singletons
        for i, s in enumerate(arr.shape):
            if s == 1:
                arr = np.take(arr, 0, axis=i)
                if len(arr.shape) == len(axis_infos):
                    break

    # add singletons if necessary
    for i, a in enumerate(axis_infos):
        if len(arr.shape) >= len(axis_infos):
            break

        if a.maybe_singleton:
            arr = np.expand_dims(arr, i)

    return arr


def _get_array_view(
    original_array: NDArray[Any], axis_infos: Sequence[AxisInfo]
) -> Optional[NDArray[Any]]:
    perms = list(permutations(range(len(original_array.shape))))
    perms.insert(1, perms.pop())  # try A and A.T first

    for perm in perms:
        view = original_array.transpose(perm)
        view = _add_singletons(view, axis_infos)
        if len(view.shape) != len(axis_infos):
            return None

        for s, a in zip(view.shape, axis_infos):
            if s == 1 and not a.maybe_singleton:
                break
        else:
            return view

    return None
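
# Illustrative sketch (not part of the original module) of how the two helpers above
# serve `Tensor.from_numpy`: `_get_array_view` walks axis permutations of the input
# array (identity and full reversal first), lets `_add_singletons` drop or add size-1
# dimensions to match the number of requested axes, and returns the first view that
# has no size-1 dimension on an axis that may not be a singleton.
#
#     axis_infos = [AxisInfo.create(a) for a in ("b", "y", "x")]  # hypothetical axes
#     view = _get_array_view(np.zeros((256, 256)), axis_infos)
#     # view.shape == (1, 256, 256), assuming the batch axis may be a singleton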