Struct Tensor

pub struct Tensor<B>
where B: Backend,
{ /* private fields */ }
An n-dimensional array of numbers on a specific backend.

Tensors are the fundamental data type in Shrew. All neural network operations accept and return tensors.

§Type Parameter

  • B: Backend — the compute backend (e.g., CpuBackend, CudaBackend)

§Example

use shrew_core::{DType, Tensor};
use shrew_cpu::{CpuBackend, CpuDevice}; // DType and CpuDevice paths assumed to be re-exported from these crate roots

let a = Tensor::<CpuBackend>::from_slice(&[1.0, 2.0, 3.0, 4.0], (2, 2))?;
let b = Tensor::<CpuBackend>::ones((2, 2), DType::F32, &CpuDevice)?;
let c = a.add(&b)?; // element-wise sum of two 2×2 tensors

Implementations

impl<B> Tensor<B>
where B: Backend,

pub fn id(&self) -> TensorId

Unique tensor ID.

pub fn shape(&self) -> &Shape

The shape of this tensor.

pub fn dims(&self) -> &[usize]

The dimensions as a slice (shortcut for shape().dims()).

pub fn rank(&self) -> usize

Number of dimensions (rank).

pub fn elem_count(&self) -> usize

Total number of elements.

pub fn dtype(&self) -> DType

Data type of the elements.

pub fn device(&self) -> &<B as Backend>::Device

The device this tensor is on.

pub fn layout(&self) -> &Layout

The memory layout (shape + strides + offset).

pub fn is_contiguous(&self) -> bool

Whether this tensor is contiguous in memory.

pub fn is_variable(&self) -> bool

Whether this tensor tracks gradients.

pub fn storage(&self) -> RwLockReadGuard<'_, <B as Backend>::Storage>

Access the underlying storage (read lock).

pub fn op(&self) -> &Op<B>

The op that created this tensor.

pub fn update_data_inplace(&self, new_data: &[f64]) -> Result<(), Error>

Update the underlying storage data in place.

This writes new_data directly into the existing Arc<RwLock<Storage>>, so any other tensor sharing this storage (e.g., a clone held by a Module) will also see the updated values.

This is the mechanism that makes optimizer parameter updates visible to model layers without needing to re-assign parameters.

§Safety (logical)

The new data must have the same number of elements and dtype as the current storage. The shape is not changed.
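A sketch of the sharing behavior (dev assumed to be the target device, as in the other examples):

let w = Tensor::from_f64_slice(&[1.0, 2.0], 2, DType::F64, &dev)?.set_variable();
let w_in_layer = w.clone();           // e.g. the clone held by a Module
w.update_data_inplace(&[0.9, 1.8])?;  // an optimizer writes updated values
// w_in_layer now reads [0.9, 1.8] too, since both tensors share the same storage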

pub fn zeros( shape: impl Into<Shape>, dtype: DType, device: &<B as Backend>::Device, ) -> Result<Tensor<B>, Error>

Create a tensor filled with zeros.

pub fn ones( shape: impl Into<Shape>, dtype: DType, device: &<B as Backend>::Device, ) -> Result<Tensor<B>, Error>

Create a tensor filled with ones.

pub fn full( shape: impl Into<Shape>, val: f64, dtype: DType, device: &<B as Backend>::Device, ) -> Result<Tensor<B>, Error>

Create a tensor filled with a constant value.

pub fn from_f64_slice( data: &[f64], shape: impl Into<Shape>, dtype: DType, device: &<B as Backend>::Device, ) -> Result<Tensor<B>, Error>

Create a tensor from a flat slice of f64 values. The data is converted to the specified dtype.
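For example (dev assumed to be the target device):

let x = Tensor::from_f64_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], (2, 3), DType::F32, &dev)?;
// a 2×3 tensor; the f64 input values are stored as F32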

pub fn rand( shape: impl Into<Shape>, dtype: DType, device: &<B as Backend>::Device, ) -> Result<Tensor<B>, Error>

Create a tensor with random uniform values in [0, 1).

pub fn randn( shape: impl Into<Shape>, dtype: DType, device: &<B as Backend>::Device, ) -> Result<Tensor<B>, Error>

Create a tensor with random normal values (mean=0, std=1).

pub fn linspace( start: f64, end: f64, steps: usize, dtype: DType, device: &<B as Backend>::Device, ) -> Result<Tensor<B>, Error>

Create a 1-D tensor with steps evenly spaced values from start to end (inclusive).

let t = Tensor::linspace(0.0, 1.0, 5, DType::F64, &dev)?;
// => [0.0, 0.25, 0.5, 0.75, 1.0]

pub fn eye( n: usize, dtype: DType, device: &<B as Backend>::Device, ) -> Result<Tensor<B>, Error>

Create an identity matrix of size n × n.

let I = Tensor::eye(3, DType::F64, &dev)?;
// [[1, 0, 0],
//  [0, 1, 0],
//  [0, 0, 1]]

pub fn zeros_like(other: &Tensor<B>) -> Result<Tensor<B>, Error>

Create a tensor of zeros with the same shape, dtype, and device as other.

pub fn ones_like(other: &Tensor<B>) -> Result<Tensor<B>, Error>

Create a tensor of ones with the same shape, dtype, and device as other.

pub fn full_like(other: &Tensor<B>, val: f64) -> Result<Tensor<B>, Error>

Create a tensor filled with val, with the same shape, dtype, and device as other.

pub fn set_variable(self) -> Tensor<B>

Mark this tensor as a variable (trainable parameter). Variables accumulate gradients during backward().

pub fn transpose(&self, dim0: usize, dim1: usize) -> Result<Tensor<B>, Error>

Transpose two dimensions (no data copy).

pub fn t(&self) -> Result<Tensor<B>, Error>

Transpose a 2D matrix (shorthand for transpose(0, 1)).

pub fn narrow( &self, dim: usize, start: usize, len: usize, ) -> Result<Tensor<B>, Error>

Narrow (slice) along a dimension.

pub fn reshape(&self, new_shape: impl Into<Shape>) -> Result<Tensor<B>, Error>

Reshape to a new shape. The new shape must have the same total elements. If the tensor is not contiguous, it will be made contiguous first.
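For example (dev assumed as in the other examples):

let x = Tensor::arange(6, DType::F64, &dev)?;   // shape [6]: [0, 1, 2, 3, 4, 5]
let m = x.reshape((2, 3))?;                     // shape [2, 3], same 6 elements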

pub fn contiguous(&self) -> Result<Tensor<B>, Error>

Ensure the tensor is contiguous in memory. If already contiguous, returns a clone (cheap Arc copy). Otherwise, copies the data into a new contiguous storage.

pub fn unsqueeze(&self, dim: usize) -> Result<Tensor<B>, Error>

Add a dimension of size 1 at the given position.

unsqueeze(0) on [3, 4] → [1, 3, 4]
unsqueeze(2) on [3, 4] → [3, 4, 1]

pub fn squeeze_all(&self) -> Tensor<B>

Remove all dimensions of size 1.

squeeze_all on [1, 3, 1, 4] → [3, 4]

pub fn squeeze(&self, dim: usize) -> Result<Tensor<B>, Error>

Remove a specific dimension of size 1.

squeeze(1) on [3, 1, 4] → [3, 4]

Returns an error if the specified dimension is not size 1.

pub fn permute(&self, dims: &[usize]) -> Result<Tensor<B>, Error>

Permute the dimensions of this tensor.

permute(&[2, 0, 1]) on [A, B, C] → [C, A, B]

This is a generalization of transpose to arbitrary dimension orderings. No data copy — just changes strides.

pub fn cumsum(&self, dim: usize) -> Result<Tensor<B>, Error>

Cumulative sum along dimension dim.

// [1, 2, 3] → [1, 3, 6]
let y = x.cumsum(0)?;

pub fn sort( &self, dim: usize, descending: bool, ) -> Result<(Tensor<B>, Tensor<B>), Error>

Sort along a dimension. Returns (sorted_values, sorted_indices).

let (vals, idxs) = x.sort(0, false)?; // ascending along dim 0

pub fn argsort(&self, dim: usize, descending: bool) -> Result<Tensor<B>, Error>

Argsort: returns indices that would sort the tensor along dim.

let indices = x.argsort(0, false)?; // ascending

pub fn add(&self, rhs: &Tensor<B>) -> Result<Tensor<B>, Error>

Element-wise addition: self + rhs.

pub fn sub(&self, rhs: &Tensor<B>) -> Result<Tensor<B>, Error>

Element-wise subtraction: self - rhs.

pub fn mul(&self, rhs: &Tensor<B>) -> Result<Tensor<B>, Error>

Element-wise multiplication: self * rhs.

pub fn div(&self, rhs: &Tensor<B>) -> Result<Tensor<B>, Error>

Element-wise division: self / rhs.

pub fn eq(&self, rhs: &Tensor<B>) -> Result<Tensor<B>, Error>

Element-wise equal: self == rhs. Returns a U8 tensor (0 or 1).

pub fn ne(&self, rhs: &Tensor<B>) -> Result<Tensor<B>, Error>

Element-wise not-equal: self != rhs. Returns a U8 tensor (0 or 1).

pub fn gt(&self, rhs: &Tensor<B>) -> Result<Tensor<B>, Error>

Element-wise greater-than: self > rhs. Returns a U8 tensor (0 or 1).

pub fn ge(&self, rhs: &Tensor<B>) -> Result<Tensor<B>, Error>

Element-wise greater-or-equal: self >= rhs. Returns a U8 tensor (0 or 1).

pub fn lt(&self, rhs: &Tensor<B>) -> Result<Tensor<B>, Error>

Element-wise less-than: self < rhs. Returns a U8 tensor (0 or 1).

pub fn le(&self, rhs: &Tensor<B>) -> Result<Tensor<B>, Error>

Element-wise less-or-equal: self <= rhs. Returns a U8 tensor (0 or 1).

pub fn neg(&self) -> Result<Tensor<B>, Error>

Element-wise negation: -self.

pub fn abs(&self) -> Result<Tensor<B>, Error>

Element-wise absolute value.

pub fn exp(&self) -> Result<Tensor<B>, Error>

Element-wise exponential: e^x.

pub fn log(&self) -> Result<Tensor<B>, Error>

Element-wise natural logarithm.

pub fn sqrt(&self) -> Result<Tensor<B>, Error>

Element-wise square root.

pub fn square(&self) -> Result<Tensor<B>, Error>

Element-wise square: x².

pub fn relu(&self) -> Result<Tensor<B>, Error>

ReLU activation: max(0, x).

pub fn sigmoid(&self) -> Result<Tensor<B>, Error>

Sigmoid activation: 1 / (1 + e^(-x)).

pub fn tanh(&self) -> Result<Tensor<B>, Error>

Tanh activation.

pub fn gelu(&self) -> Result<Tensor<B>, Error>

GELU activation (Gaussian Error Linear Unit).

pub fn silu(&self) -> Result<Tensor<B>, Error>

SiLU / Swish activation: x * sigmoid(x).

pub fn sin(&self) -> Result<Tensor<B>, Error>

Element-wise sine.

pub fn cos(&self) -> Result<Tensor<B>, Error>

Element-wise cosine.

pub fn floor(&self) -> Result<Tensor<B>, Error>

Element-wise floor: largest integer ≤ x.

pub fn ceil(&self) -> Result<Tensor<B>, Error>

Element-wise ceiling: smallest integer ≥ x.

pub fn round(&self) -> Result<Tensor<B>, Error>

Element-wise round to nearest integer.

pub fn powf(&self, exponent: f64) -> Result<Tensor<B>, Error>

Element-wise power: self^exponent.

pub fn clamp(&self, min: f64, max: f64) -> Result<Tensor<B>, Error>

Element-wise clamp to [min, max].

pub fn where_cond( mask: &Tensor<B>, on_true: &Tensor<B>, on_false: &Tensor<B>, ) -> Result<Tensor<B>, Error>

Conditional select: result[i] = if mask[i] != 0 { on_true[i] } else { on_false[i] }.

mask is typically a U8 tensor from comparison ops. on_true and on_false must have the same shape and dtype.
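A sketch that keeps positive entries and zeroes the rest (dev assumed as in the other examples):

let x = Tensor::from_f64_slice(&[-1.0, 2.0, -3.0, 4.0], (2, 2), DType::F64, &dev)?;
let zeros = Tensor::zeros_like(&x)?;
let mask = x.gt(&zeros)?;                        // U8 mask: 1 where x > 0
let y = Tensor::where_cond(&mask, &x, &zeros)?;  // [[0, 2], [0, 4]]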

pub fn gather(&self, dim: usize, index: &Tensor<B>) -> Result<Tensor<B>, Error>

Gather elements along dim using an index tensor.

output[i][j][k] = input[index[i][j][k]][j][k] (when dim=0)

The index tensor must have the same number of dimensions as self. The output has the same shape as the index tensor.
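A sketch that picks the row-wise maxima by combining argmax and gather (dev assumed as elsewhere):

let x = Tensor::from_f64_slice(&[10.0, 20.0, 40.0, 30.0], (2, 2), DType::F64, &dev)?;
let idx = x.argmax(1, true)?;    // i64 indices, shape [2, 1]: [[1], [0]]
let best = x.gather(1, &idx)?;   // shape [2, 1]: [[20.0], [40.0]]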

pub fn masked_fill( &self, mask: &Tensor<B>, value: f64, ) -> Result<Tensor<B>, Error>

Fill elements where mask != 0 with value, keeping other elements.

result[i] = if mask[i] != 0 { value } else { self[i] }

This is implemented via where_cond so autograd is automatic.

pub fn pad( &self, padding: &[[usize; 2]], value: f64, ) -> Result<Tensor<B>, Error>

Pad the last N dimensions with constant value.

padding is a list of [before, after] pairs, one per dimension, applied to the last dimensions of the tensor.

Example: pad(&[[1, 1], [2, 2]], 0.0) pads the last 2 dims.

pub fn topk( &self, k: usize, dim: usize, ) -> Result<(Tensor<B>, Vec<usize>), Error>

Return the k largest elements along dim.

Returns (values, indices): values has the same shape as self except that dimension dim has size k, and indices gives the positions of those elements along dim.

Non-differentiable (returns detached values).

pub fn sum_all(&self) -> Result<Tensor<B>, Error>

Sum all elements, returning a scalar tensor.

pub fn sum(&self, dim: usize, keep_dim: bool) -> Result<Tensor<B>, Error>

Sum along a specific dimension.

pub fn mean_all(&self) -> Result<Tensor<B>, Error>

Mean of all elements, returning a scalar tensor.

pub fn mean(&self, dim: usize, keep_dim: bool) -> Result<Tensor<B>, Error>

Mean along a specific dimension.

pub fn max(&self, dim: usize, keep_dim: bool) -> Result<Tensor<B>, Error>

Max along a specific dimension.

pub fn min(&self, dim: usize, keep_dim: bool) -> Result<Tensor<B>, Error>

Min along a specific dimension.

pub fn argmax(&self, dim: usize, keep_dim: bool) -> Result<Tensor<B>, Error>

ArgMax along a specific dimension (returns i64 indices).

pub fn argmin(&self, dim: usize, keep_dim: bool) -> Result<Tensor<B>, Error>

ArgMin along a specific dimension (returns i64 indices).

pub fn softmax(&self, dim: usize) -> Result<Tensor<B>, Error>

Softmax along a dimension: softmax(x)_i = exp(x_i) / sum(exp(x_j))

Uses the numerically stable trick: subtract max before exp. This is built from existing differentiable ops (exp, sum, div, sub) so gradients flow through automatically.
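For example (dev assumed as in the other examples):

let logits = Tensor::from_f64_slice(&[1.0, 2.0, 3.0], (1, 3), DType::F64, &dev)?;
let probs = logits.softmax(1)?;   // ≈ [[0.090, 0.245, 0.665]], each row sums to 1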

pub fn log_softmax(&self, dim: usize) -> Result<Tensor<B>, Error>

Log-softmax along a dimension: log(softmax(x)) but numerically stable.

log_softmax(x)_i = x_i - max(x) - log(sum(exp(x - max(x))))

pub fn var(&self, dim: usize, keep_dim: bool) -> Result<Tensor<B>, Error>

Variance along a dimension: var(x) = mean((x - mean(x))²)

pub fn cat(tensors: &[Tensor<B>], dim: usize) -> Result<Tensor<B>, Error>

Concatenate tensors along a dimension.

All tensors must have the same shape except in the concatenation dimension. This creates a new tensor by copying data from all inputs.
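For example (dev assumed):

let a = Tensor::zeros((2, 3), DType::F32, &dev)?;
let b = Tensor::ones((4, 3), DType::F32, &dev)?;
let c = Tensor::cat(&[a, b], 0)?;   // shape [6, 3]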

pub fn chunk(&self, n: usize, dim: usize) -> Result<Vec<Tensor<B>>, Error>

Split a tensor into n equal chunks along a dimension. If the dimension size is not evenly divisible, the last chunk is smaller.

pub fn expand(&self, target_shape: impl Into<Shape>) -> Result<Tensor<B>, Error>

Expand a tensor to a larger shape by repeating data along size-1 dims. Only dims that are currently size 1 can be expanded. A size of -1 (usize::MAX) means don’t change that dim.
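For example (dev assumed as in the other examples):

let row = Tensor::from_f64_slice(&[1.0, 2.0, 3.0], (1, 3), DType::F64, &dev)?;
let tiled = row.expand((4, 3))?;   // [1, 3] → [4, 3]: the single row is repeated 4 times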

pub fn stack(tensors: &[Tensor<B>], dim: usize) -> Result<Tensor<B>, Error>

Stack tensors along a new dimension.

All tensors must have the same shape. Inserts a new dimension at dim. stack([a, b], dim=0) where a,b are shape [2,3] → [2, 2, 3].

pub fn arange( n: usize, dtype: DType, device: &<B as Backend>::Device, ) -> Result<Tensor<B>, Error>

Create a 1-D tensor with values [0, 1, …, n-1].

pub fn arange_step( start: f64, end: f64, step: f64, dtype: DType, device: &<B as Backend>::Device, ) -> Result<Tensor<B>, Error>

Create a 1-D tensor with values start, start+step, start+2*step, …, stopping before end (end is exclusive).

pub fn triu( n: usize, m: usize, diagonal: i64, dtype: DType, device: &<B as Backend>::Device, ) -> Result<Tensor<B>, Error>

Upper triangular mask: returns a 2-D tensor of shape [n, m] where elements on and above the diagonal-th diagonal are 1.0, rest 0.0.

diagonal = 0 → main diagonal. diagonal > 0 → above. diagonal < 0 → below.
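A sketch of the typical causal-attention use, combining triu with masked_fill (dev assumed; -1e9 stands in for negative infinity):

let scores = Tensor::rand((4, 4), DType::F32, &dev)?;   // attention scores
let mask = Tensor::triu(4, 4, 1, DType::F32, &dev)?;    // 1.0 strictly above the diagonal
let causal = scores.masked_fill(&mask, -1e9)?;          // block attention to future positions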

pub fn tril( n: usize, m: usize, diagonal: i64, dtype: DType, device: &<B as Backend>::Device, ) -> Result<Tensor<B>, Error>

Lower triangular mask: returns a 2-D tensor of shape [n, m] where elements on and below the diagonal-th diagonal are 1.0, rest 0.0.

pub fn matmul(&self, rhs: &Tensor<B>) -> Result<Tensor<B>, Error>

Matrix multiplication: self @ rhs.

  • [m, k] @ [k, n] → [m, n]
  • Batched: [b, m, k] @ [b, k, n] → [b, m, n]
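For example (dev assumed; random values stand in for real data):

let a = Tensor::rand((2, 3), DType::F32, &dev)?;
let b = Tensor::rand((3, 4), DType::F32, &dev)?;
let c = a.matmul(&b)?;   // shape [2, 4]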

pub fn conv2d( &self, weight: &Tensor<B>, bias: Option<&Tensor<B>>, stride: [usize; 2], padding: [usize; 2], ) -> Result<Tensor<B>, Error>

2D convolution: applies convolution filters to a 4D input tensor.

  • self (input): [N, C_in, H, W]
  • weight: [C_out, C_in, kH, kW]
  • bias: optional [C_out]
  • stride: [sH, sW]
  • padding: [pH, pW]

Returns tensor of shape [N, C_out, H_out, W_out] where H_out = (H + 2*pH - kH) / sH + 1.
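A shape-focused sketch (dev assumed; random values stand in for a real image batch and kernel):

let x = Tensor::rand((1, 3, 32, 32), DType::F32, &dev)?;   // N=1, C_in=3, 32×32 input
let w = Tensor::rand((8, 3, 3, 3), DType::F32, &dev)?;     // C_out=8, 3×3 kernels
let y = x.conv2d(&w, None, [1, 1], [1, 1])?;               // shape [1, 8, 32, 32]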

pub fn max_pool2d( &self, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], ) -> Result<Tensor<B>, Error>

2D max pooling on a 4D input tensor [N, C, H, W].

Returns the pooled output tensor; the argmax positions (flat indices into the input) are recorded for the backward pass.

pub fn avg_pool2d( &self, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], ) -> Result<Tensor<B>, Error>

Apply 2D average pooling to a 4D tensor [N, C, H, W].

pub fn conv1d( &self, weight: &Tensor<B>, bias: Option<&Tensor<B>>, stride: usize, padding: usize, ) -> Result<Tensor<B>, Error>

Apply 1D convolution to a 3D tensor [N, C_in, L]. weight: [C_out, C_in, K]

pub fn affine(&self, mul: f64, add: f64) -> Result<Tensor<B>, Error>

Affine transform: result[i] = self[i] * mul + add. Useful for normalization and scaling.
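For example, scaling byte-valued pixels into [0, 1] (dev assumed):

let bytes = Tensor::from_f64_slice(&[0.0, 127.5, 255.0], 3, DType::F32, &dev)?;
let scaled = bytes.affine(1.0 / 255.0, 0.0)?;   // [0.0, 0.5, 1.0]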

pub fn to_f64_vec(&self) -> Result<Vec<f64>, Error>

Extract all elements as a flat Vec.

pub fn to_scalar_f64(&self) -> Result<f64, Error>

Extract a scalar value (tensor must have exactly 1 element).

pub fn to_dtype(&self, dtype: DType) -> Result<Tensor<B>, Error>

Convert this tensor to a different dtype.

Returns a new tensor with the same shape but different element type. Uses the backend’s on-device cast when available, avoiding host round-trips. Records Op::ToDtype so gradients flow back through dtype conversions.
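For example (dev assumed):

let x = Tensor::rand((2, 2), DType::F32, &dev)?;
let x64 = x.to_dtype(DType::F64)?;   // same shape and values, stored as F64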

pub fn to_string_with_data(&self) -> Result<String, Error>

Display the tensor contents in a human-readable format.

pub fn backward(&self) -> Result<GradStore<B>, Error>

Compute gradients via reverse-mode automatic differentiation.

This tensor must be a scalar (single element). Returns a GradStore containing gradients for all tensors in the computation graph.

§Example
let a = Tensor::from_f64_slice(&[2.0], 1, DType::F32, &dev)?.set_variable();
let b = Tensor::from_f64_slice(&[3.0], 1, DType::F32, &dev)?.set_variable();
let c = a.mul(&b)?;
let grads = c.backward()?;
// grad_a = b = 3.0, grad_b = a = 2.0

pub fn detach(&self) -> Tensor<B>

Create a detached copy: same data but no gradient tracking. The new tensor has Op::None and a fresh TensorId.

pub fn freeze(&self) -> Tensor<B>

Freeze this tensor: same data and id, but is_variable = false.

Frozen tensors do NOT accumulate gradients during backward(). This is the functional equivalent of PyTorch’s param.requires_grad_(False).
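A sketch (dev assumed as in the other examples):

let w = Tensor::rand((4, 4), DType::F32, &dev)?.set_variable();
let frozen = w.freeze();            // same data and id, excluded from gradient accumulation
let trainable = frozen.unfreeze();  // re-enable gradient accumulation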

pub fn unfreeze(&self) -> Tensor<B>

Unfreeze this tensor: same data and id, but is_variable = true.

This is the opposite of freeze().

pub fn index_select( &self, dim: usize, indices: &Tensor<B>, ) -> Result<Tensor<B>, Error>

Select entries along dim using the given 1-D index tensor.

The output has the same rank, with dim resized to indices.len(). Wraps the Backend::index_select kernel.
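A sketch of an embedding-style row lookup; it assumes an I64 dtype is available for the index tensor (dev assumed as elsewhere):

let table = Tensor::rand((10, 4), DType::F32, &dev)?;                     // e.g. an embedding table
let ids = Tensor::from_f64_slice(&[3.0, 3.0, 7.0], 3, DType::I64, &dev)?; // row indices
let rows = table.index_select(0, &ids)?;                                  // shape [3, 4]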

pub fn split( &self, split_size: usize, dim: usize, ) -> Result<Vec<Tensor<B>>, Error>

Split a tensor into chunks of split_size along dim.

The last chunk may be smaller if the dimension is not evenly divisible.

pub fn flatten( &self, start_dim: usize, end_dim: usize, ) -> Result<Tensor<B>, Error>

Flatten dimensions start_dim..=end_dim into a single dimension.

Negative-style indexing is not supported; both bounds are inclusive and zero-based.

pub fn std(&self, dim: usize, keep_dim: bool) -> Result<Tensor<B>, Error>

Standard deviation along a dimension.

Computed as sqrt(var(x, dim)).

pub fn reciprocal(&self) -> Result<Tensor<B>, Error>

Element-wise reciprocal: 1 / x.

pub fn rsqrt(&self) -> Result<Tensor<B>, Error>

Element-wise reciprocal square-root: 1 / sqrt(x).

pub fn sign(&self) -> Result<Tensor<B>, Error>

Element-wise sign: returns -1, 0, or +1.

Implemented via x / (|x| + eps) clamped to [-1, 1], with exact 0 for inputs that are exactly zero.

pub fn logsumexp(&self, dim: usize, keep_dim: bool) -> Result<Tensor<B>, Error>

Log-sum-exp along a dimension (numerically stable).

logsumexp(x, d) = max(x,d) + log(sum(exp(x - max(x,d)), d))

pub fn prod(&self, dim: usize, keep_dim: bool) -> Result<Tensor<B>, Error>

Product of elements along a dimension.

Computed as exp(sum(log(|x|))) with sign correction. Warning: undefined for inputs containing zero.

Trait Implementations

impl<B> Clone for Tensor<B>
where B: Backend,

fn clone(&self) -> Tensor<B>

Returns a duplicate of the value.

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source.

impl<B> Debug for Tensor<B>
where B: Backend,

fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error>

Formats the value using the given formatter.

Auto Trait Implementations

impl<B> Freeze for Tensor<B>

impl<B> RefUnwindSafe for Tensor<B>
where <B as Backend>::Device: RefUnwindSafe,

impl<B> Send for Tensor<B>

impl<B> Sync for Tensor<B>

impl<B> Unpin for Tensor<B>

impl<B> UnwindSafe for Tensor<B>
where <B as Backend>::Device: RefUnwindSafe,

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> CloneToUninit for T
where T: Clone,

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬 This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> IntoEither for T

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.

impl<T> Pointable for T

const ALIGN: usize

The alignment of pointer.

type Init = T

The type for initializers.

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a value with the given initializer.

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer.

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer.

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer.

impl<T> ToOwned for T
where T: Clone,

type Owned = T

The resulting type after obtaining ownership.

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning.

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V