Coverage for bioimageio/core/tensor.py: 85%

234 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 09:51 +0000

1from __future__ import annotations 

2 

3import collections.abc 

4from itertools import permutations 

5from typing import ( 

6 TYPE_CHECKING, 

7 Any, 

8 Callable, 

9 Dict, 

10 Iterator, 

11 Mapping, 

12 Optional, 

13 Sequence, 

14 Tuple, 

15 Union, 

16 cast, 

17 get_args, 

18) 

19 

20import numpy as np 

21import xarray as xr 

22from loguru import logger 

23from numpy.typing import DTypeLike, NDArray 

24from typing_extensions import Self, assert_never 

25 

26from bioimageio.spec.model import v0_5 

27 

28from ._magic_tensor_ops import MagicTensorOpsMixin 

29from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis 

30from .common import ( 

31 CropWhere, 

32 DTypeStr, 

33 PadMode, 

34 PadWhere, 

35 PadWidth, 

36 PadWidthLike, 

37 SliceInfo, 

38) 

39 

40if TYPE_CHECKING: 

41 from numpy.typing import ArrayLike, NDArray 

42 

43 

# Array-like operand types that, together with `Tensor` and `xr.DataArray`,
# make up `Tensor._Compatible` (the operand type of the tensor's binary ops).
_ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"]  # TODO: add "DaskArray"

45 

46 

47# TODO: complete docstrings 

class Tensor(MagicTensorOpsMixin):
    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
    and improved type annotations."""

    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]

    def __init__(
        self,
        array: NDArray[Any],
        dims: Sequence[Union[AxisId, AxisLike]],
    ) -> None:
        """Wrap `array` in an `xr.DataArray` with dimensions named after `dims`.

        Args:
            array: the data to wrap
            dims: names for the array's dimensions; entries that are not already
                an `AxisId` are converted via `AxisInfo.create(...).id`
        """
        super().__init__()
        axes = tuple(
            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
        )
        self._data = xr.DataArray(array, dims=axes)

    def __array__(self, dtype: DTypeLike = None):
        """Return the wrapped data as a numpy array, optionally cast to `dtype`."""
        return np.asarray(self._data, dtype=dtype)

    def __getitem__(
        self, key: Union[SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]]
    ) -> Self:
        """Index the tensor and return the selection as a new tensor.

        `SliceInfo` objects (given directly or as values of a per-axis mapping)
        are unpacked into builtin slices before the underlying data array is
        indexed.
        """
        if isinstance(key, SliceInfo):
            key = slice(*key)
        elif isinstance(key, collections.abc.Mapping):
            key = {
                a: s if isinstance(s, (int, slice)) else slice(*s)
                for a, s in key.items()
            }
        return self.__class__.from_xarray(self._data[key])

    def __setitem__(self, key: PerAxis[Union[SliceInfo, slice]], value: Tensor) -> None:
        """Assign the data of `value` to the region selected by the per-axis `key`."""
        key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}
        self._data[key] = value._data

    def __len__(self) -> int:
        """Return `len()` of the underlying data array."""
        return len(self.data)

    def _iter(self: Any) -> Iterator[Any]:
        """Yield `self[n]` for every `n` in `range(len(self))`."""
        for n in range(len(self)):
            yield self[n]

    def __iter__(self: Any) -> Iterator[Any]:
        """Iterate over `self[0]`, `self[1]`, ...

        Raises:
            TypeError: if the tensor is 0-dimensional.
        """
        if self.ndim == 0:
            raise TypeError("iteration over a 0-d array")
        return self._iter()

    def _binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
        reflexive: bool = False,
    ) -> Self:
        """Apply the binary operation `f` via the wrapped data array.

        A `Tensor` operand is unwrapped to its data array; any other operand is
        forwarded unchanged. The result is returned as a new tensor.
        """
        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
            (other._data if isinstance(other, Tensor) else other),
            f,
            reflexive,
        )
        return self.__class__.from_xarray(data)

    def _inplace_binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
    ) -> Self:
        """Apply the binary operation `f` in place on the wrapped data array.

        An operand exposing an `xr.DataArray` as its `data` attribute (e.g. a
        `Tensor`) is unwrapped; any other operand is forwarded unchanged.

        Returns:
            `self`
        """
        # A default is required here: plain Python scalars (e.g. in `tensor += 1`)
        # have no `data` attribute and must be passed through as they are.
        other_data = getattr(other, "data", None)
        operand = other_data if isinstance(other_data, xr.DataArray) else other
        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
            operand,
            f,
        )
        return self

    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
        """Apply the unary operation `f` via the wrapped data array."""
        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
            f, *args, **kwargs
        )
        return self.__class__.from_xarray(data)

    @classmethod
    def from_xarray(cls, data_array: xr.DataArray) -> Self:
        """create a `Tensor` from an xarray data array

        note for internal use: this factory method is round-trip safe
            for any `Tensor`'s `data` property (an xarray.DataArray).
        """
        return cls(
            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
        )

    @classmethod
    def from_numpy(
        cls,
        array: NDArray[Any],
        *,
        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
    ) -> Tensor:
        """create a `Tensor` from a numpy array

        Args:
            array: the nd numpy array
            dims: A description of the array's axes,
                if None axes are guessed (which might fail and raise a ValueError.)

        Raises:
            ValueError: if `dims` is None and axes guessing fails,
                or if the array shape cannot be mapped to the given `dims`.
        """

        if dims is None:
            return cls._interprete_array_wo_known_axes(array)
        elif isinstance(dims, (str, Axis, v0_5.AxisBase)):
            # a single axis description was given
            dims = [dims]

        axis_infos = [AxisInfo.create(a) for a in dims]
        original_shape = tuple(array.shape)

        successful_view = _get_array_view(array, axis_infos)
        if successful_view is None:
            raise ValueError(
                f"Array shape {original_shape} does not map to axes {dims}"
            )

        # use `cls` (as the `dims is None` path and `from_xarray` do) so that
        # subclasses get an instance of their own type
        return cls(successful_view, dims=tuple(a.id for a in axis_infos))

    @property
    def data(self):
        """The wrapped `xr.DataArray`."""
        return self._data

    @property
    def dims(self):  # TODO: rename to `axes`?
        """Tuple of dimension names associated with this tensor."""
        return cast(Tuple[AxisId, ...], self._data.dims)

    @property
    def dtype(self) -> DTypeStr:
        """The data type of the wrapped data as a string (one of `DTypeStr`)."""
        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
        assert dt in get_args(DTypeStr)
        return dt  # pyright: ignore[reportReturnType]

    @property
    def ndim(self):
        """Number of tensor dimensions."""
        return self._data.ndim

    @property
    def shape(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def shape_tuple(self):
        """Tuple of tensor axes lengths (same value as `shape`)"""
        return self._data.shape

    @property
    def size(self):
        """Number of elements in the tensor.

        Equal to math.prod(tensor.shape), i.e., the product of the tensors' dimensions.
        """
        return self._data.size

    @property
    def sizes(self):
        """Ordered, immutable mapping from axis ids to axis lengths."""
        return cast(Mapping[AxisId, int], self.data.sizes)

    @property
    def tagged_shape(self):
        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
        return self.sizes

    def argmax(self) -> Mapping[AxisId, int]:
        """Return the per-axis index of the maximum, taken over all dimensions."""
        ret = self._data.argmax(...)
        assert isinstance(ret, dict)
        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}

    def astype(self, dtype: DTypeStr, *, copy: bool = False):
        """Return tensor cast to `dtype`

        note: if dtype is already satisfied copy if `copy`"""
        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))

    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
        """Return a tensor whose values are limited to [min, max].
        At least one of max or min must be given."""
        return self.__class__.from_xarray(self._data.clip(min, max))

    def crop_to(
        self,
        sizes: PerAxis[int],
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
    ) -> Self:
        """crop to match `sizes`

        Axes missing from `sizes` or already of the requested length are left
        untouched. An axis that is smaller than requested cannot be cropped; a
        warning is logged and the axis is left as it is.

        Args:
            sizes: target length per axis
            crop_where: where to remove elements, for all axes or per axis

        Raises:
            ValueError: if an axis needs cropping, but `crop_where` has no entry
                for it.
        """
        if isinstance(crop_where, str):
            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
        else:
            crop_axis_where = crop_where

        slices: Dict[AxisId, SliceInfo] = {}

        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pass
            elif sizes[a] > s_is:
                logger.warning(
                    "Cannot crop axis {} of size {} to larger size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in crop_axis_where:
                raise ValueError(
                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
                )
            else:
                crop_this_axis_where = crop_axis_where[a]
                if crop_this_axis_where == "left":
                    # remove from the left, i.e. keep the rightmost elements
                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
                elif crop_this_axis_where == "right":
                    slices[a] = SliceInfo(0, sizes[a])
                elif crop_this_axis_where == "left_and_right":
                    start = (s_is - sizes[a]) // 2
                    slices[a] = SliceInfo(start, sizes[a] + start)
                else:
                    assert_never(crop_this_axis_where)

        return self[slices]

    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
        """Return a tensor with the additional axes `dims`.

        Args:
            dims: the new axes, either as a sequence of axis ids or as a mapping
                from axis id to axis length
        """
        # Pass `dims` positionally: the xarray parameter is called `dim`, a
        # `dims=` keyword would be interpreted as a new dimension named "dims".
        return self.__class__.from_xarray(self._data.expand_dims(dims))

    def item(
        self,
        key: Union[
            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
        ] = None,
    ):
        """Copy a tensor element to a standard Python scalar and return it."""
        if key is None:
            ret = self._data.item()
        else:
            ret = self[key]._data.item()

        assert isinstance(ret, (bool, float, int))
        return ret

    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying mean along some dimension(s)."""
        return self.__class__.from_xarray(self._data.mean(dim=dim))

    def pad(
        self,
        pad_width: PerAxis[PadWidthLike],
        mode: PadMode = "symmetric",
    ) -> Self:
        """Return a tensor padded by `pad_width` per axis using padding `mode`."""
        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
        return self.__class__.from_xarray(
            self._data.pad(pad_width=pad_width, mode=mode)
        )

    def pad_to(
        self,
        sizes: PerAxis[int],
        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
        mode: PadMode = "symmetric",
    ) -> Self:
        """pad `tensor` to match `sizes`

        Axes missing from `sizes` or already of the requested length are not
        padded. An axis that is larger than requested cannot be padded; a
        warning is logged and the axis is left as it is.

        Args:
            sizes: target length per axis
            pad_where: where to add elements, for all axes or per axis
            mode: padding mode forwarded to `pad`

        Raises:
            ValueError: if an axis needs padding, but `pad_where` has no entry
                for it.
        """
        if isinstance(pad_where, str):
            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
        else:
            pad_axis_where = pad_where

        pad_width: Dict[AxisId, PadWidth] = {}
        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pad_width[a] = PadWidth(0, 0)
            elif s_is > sizes[a]:
                pad_width[a] = PadWidth(0, 0)
                logger.warning(
                    "Cannot pad axis {} of size {} to smaller size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in pad_axis_where:
                raise ValueError(
                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
                )
            else:
                pad_this_axis_where = pad_axis_where[a]
                d = sizes[a] - s_is
                if pad_this_axis_where == "left":
                    pad_width[a] = PadWidth(d, 0)
                elif pad_this_axis_where == "right":
                    pad_width[a] = PadWidth(0, d)
                elif pad_this_axis_where == "left_and_right":
                    # an odd difference puts the extra element on the right
                    left = d // 2
                    pad_width[a] = PadWidth(left, d - left)
                else:
                    assert_never(pad_this_axis_where)

        return self.pad(pad_width, mode)

    def quantile(
        self,
        q: Union[float, Sequence[float]],
        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
    ) -> Self:
        """Compute the quantile(s) `q` (each within [0, 1]) along `dim`.

        `dim` must not be (or contain) an axis named "quantile".
        """
        assert (
            isinstance(q, (float, int))
            and q >= 0.0
            or not isinstance(q, (float, int))
            and all(qq >= 0.0 for qq in q)
        )
        assert (
            isinstance(q, (float, int))
            and q <= 1.0
            or not isinstance(q, (float, int))
            and all(qq <= 1.0 for qq in q)
        )
        assert dim is None or (
            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
        )
        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))

    def resize_to(
        self,
        sizes: PerAxis[int],
        *,
        pad_where: Union[
            PadWhere,
            PerAxis[PadWhere],
        ] = "left_and_right",
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
        pad_mode: PadMode = "symmetric",
    ):
        """return cropped/padded tensor with `sizes`

        Existing axes are cropped (via `crop_to`) or padded (via `pad_to`) as
        needed; axes in `sizes` that the tensor does not have yet are added
        with `expand_dims`.
        """
        crop_to_sizes: Dict[AxisId, int] = {}
        pad_to_sizes: Dict[AxisId, int] = {}
        # entries remaining in `new_axes` after the loop are axes to be added
        new_axes = dict(sizes)
        for a, s_is in self.sizes.items():
            a = AxisId(str(a))
            _ = new_axes.pop(a, None)
            if a not in sizes or sizes[a] == s_is:
                pass
            elif s_is > sizes[a]:
                crop_to_sizes[a] = sizes[a]
            else:
                pad_to_sizes[a] = sizes[a]

        tensor = self
        if crop_to_sizes:
            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)

        if pad_to_sizes:
            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)

        if new_axes:
            tensor = tensor.expand_dims(new_axes)

        return tensor

    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying std along some dimension(s)."""
        return self.__class__.from_xarray(self._data.std(dim=dim))

    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying sum along some dimension(s)."""
        return self.__class__.from_xarray(self._data.sum(dim=dim))

    def transpose(
        self,
        axes: Sequence[AxisId],
    ) -> Self:
        """return a transposed tensor

        Args:
            axes: the desired tensor axes
        """
        # expand missing tensor axes
        missing_axes = tuple(a for a in axes if a not in self.dims)
        array = self._data
        if missing_axes:
            array = array.expand_dims(missing_axes)

        # transpose to the correct axis order
        return self.__class__.from_xarray(array.transpose(*axes))

    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying var along some dimension(s)."""
        return self.__class__.from_xarray(self._data.var(dim=dim))

    @classmethod
    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
        """Create a tensor from `array`, guessing its axes from `array.ndim`.

        The guess is based on the number of dimensions only:

        - 2d: (y, x)
        - 3d with any axis length <= 3: (channel, y, x)
        - other 3d: (z, y, x)
        - 4d: (channel, z, y, x)
        - 5d: (batch, channel, z, y, x)

        Note that a channel axis is always assumed to come first (after batch),
        independent of which axis is the short one.

        Raises:
            ValueError: for any other number of dimensions.
        """
        ndim = array.ndim
        if ndim == 2:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
            )
        elif ndim == 3 and any(s <= 3 for s in array.shape):
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 3:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 4:
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
            )
        elif ndim == 5:
            current_axes = (
                v0_5.BatchAxis(),
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
            )
        else:
            raise ValueError(f"Could not guess an axis mapping for {array.shape}")

        return cls(array, dims=tuple(a.id for a in current_axes))

502 

503 

def _add_singletons(arr: NDArray[Any], axis_infos: Sequence[AxisInfo]):
    """Adjust the number of dimensions of `arr` towards `len(axis_infos)`.

    If `arr` has too many dimensions, leading length-1 dimensions are removed
    (at most as many as there are surplus dimensions). If it has too few,
    length-1 dimensions are inserted at the positions of axes in `axis_infos`
    that may be singletons, until the number of dimensions matches.

    Args:
        arr: the array to adjust
        axis_infos: the desired axes; only their `maybe_singleton` attribute and
            their number are used here

    Returns:
        The adjusted array. It may still have a different number of dimensions
        than `axis_infos` if not enough singleton dimensions could be removed or
        added; callers have to check.
    """
    n_surplus = arr.ndim - len(axis_infos)
    if n_surplus > 0:
        # remove singletons
        # Select all axes to drop before touching `arr`: removing them one by one
        # would shift the positions of the remaining axes.
        drop = tuple(i for i, s in enumerate(arr.shape) if s == 1)[:n_surplus]
        arr = np.squeeze(arr, axis=drop)

    # add singletons if necessary
    for i, a in enumerate(axis_infos):
        if arr.ndim >= len(axis_infos):
            break

        if a.maybe_singleton:
            arr = np.expand_dims(arr, i)

    return arr

522 

523 

def _get_array_view(
    original_array: NDArray[Any], axis_infos: Sequence[AxisInfo]
) -> Optional[NDArray[Any]]:
    """Find an axis order of `original_array` that is compatible with `axis_infos`.

    Every permutation of the array's axes is tried (the unchanged order and the
    fully reversed order first). Each candidate is passed through
    `_add_singletons`; it is accepted if no length-1 dimension ends up at an
    axis that must not be a singleton.

    Returns:
        The first accepted candidate, or None if the number of dimensions cannot
        be matched to `axis_infos` or no permutation is accepted.
    """
    n_axes = len(axis_infos)
    orders = list(permutations(range(len(original_array.shape))))
    orders.insert(1, orders.pop())  # try A and A.T first

    for order in orders:
        candidate = _add_singletons(original_array.transpose(order), axis_infos)
        if len(candidate.shape) != n_axes:
            return None

        forbidden_singleton = any(
            length == 1 and not info.maybe_singleton
            for length, info in zip(candidate.shape, axis_infos)
        )
        if not forbidden_singleton:
            return candidate

    return None