Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/mir/eval/shim/simd.rs')
| -rw-r--r-- | crates/hir-ty/src/mir/eval/shim/simd.rs | 177 |
1 files changed, 177 insertions, 0 deletions
diff --git a/crates/hir-ty/src/mir/eval/shim/simd.rs b/crates/hir-ty/src/mir/eval/shim/simd.rs new file mode 100644 index 0000000000..ec74631048 --- /dev/null +++ b/crates/hir-ty/src/mir/eval/shim/simd.rs @@ -0,0 +1,177 @@ +//! Shim implementation for simd intrinsics + +use std::cmp::Ordering; + +use crate::TyKind; + +use super::*; + +macro_rules! from_bytes { + ($ty:tt, $value:expr) => { + ($ty::from_le_bytes(match ($value).try_into() { + Ok(it) => it, + Err(_) => return Err(MirEvalError::TypeError("mismatched size")), + })) + }; +} + +macro_rules! not_supported { + ($it: expr) => { + return Err(MirEvalError::NotSupported(format!($it))) + }; +} + +impl Evaluator<'_> { + fn detect_simd_ty(&self, ty: &Ty) -> Result<(usize, Ty)> { + match ty.kind(Interner) { + TyKind::Adt(id, subst) => { + let len = match subst.as_slice(Interner).get(1).and_then(|it| it.constant(Interner)) + { + Some(len) => len, + _ => { + if let AdtId::StructId(id) = id.0 { + let struct_data = self.db.struct_data(id); + let fields = struct_data.variant_data.fields(); + let Some((first_field, _)) = fields.iter().next() else { + not_supported!("simd type with no field"); + }; + let field_ty = self.db.field_types(id.into())[first_field] + .clone() + .substitute(Interner, subst); + return Ok((fields.len(), field_ty)); + } + return Err(MirEvalError::TypeError("simd type with no len param")); + } + }; + match try_const_usize(self.db, len) { + Some(len) => { + let Some(ty) = subst.as_slice(Interner).get(0).and_then(|it| it.ty(Interner)) else { + return Err(MirEvalError::TypeError("simd type with no ty param")); + }; + Ok((len as usize, ty.clone())) + } + None => Err(MirEvalError::TypeError("simd type with unevaluatable len param")), + } + } + _ => Err(MirEvalError::TypeError("simd type which is not a struct")), + } + } + + pub(super) fn exec_simd_intrinsic( + &mut self, + name: &str, + args: &[IntervalAndTy], + _generic_args: &Substitution, + destination: Interval, + _locals: &Locals, + _span: MirSpan, + ) -> Result<()> { + match name { + "and" | "or" | "xor" => { + let [left, right] = args else { + return Err(MirEvalError::TypeError("simd bit op args are not provided")); + }; + let result = left + .get(self)? + .iter() + .zip(right.get(self)?) + .map(|(&it, &y)| match name { + "and" => it & y, + "or" => it | y, + "xor" => it ^ y, + _ => unreachable!(), + }) + .collect::<Vec<_>>(); + destination.write_from_bytes(self, &result) + } + "eq" | "ne" | "lt" | "le" | "gt" | "ge" => { + let [left, right] = args else { + return Err(MirEvalError::TypeError("simd args are not provided")); + }; + let (len, ty) = self.detect_simd_ty(&left.ty)?; + let is_signed = matches!(ty.as_builtin(), Some(BuiltinType::Int(_))); + let size = left.interval.size / len; + let dest_size = destination.size / len; + let mut destination_bytes = vec![]; + let vector = left.get(self)?.chunks(size).zip(right.get(self)?.chunks(size)); + for (l, r) in vector { + let mut result = Ordering::Equal; + for (l, r) in l.iter().zip(r).rev() { + let it = l.cmp(r); + if it != Ordering::Equal { + result = it; + break; + } + } + if is_signed { + if let Some((&l, &r)) = l.iter().zip(r).rev().next() { + if l != r { + result = (l as i8).cmp(&(r as i8)); + } + } + } + let result = match result { + Ordering::Less => ["lt", "le", "ne"].contains(&name), + Ordering::Equal => ["ge", "le", "eq"].contains(&name), + Ordering::Greater => ["ge", "gt", "ne"].contains(&name), + }; + let result = if result { 255 } else { 0 }; + destination_bytes.extend(std::iter::repeat(result).take(dest_size)); + } + + destination.write_from_bytes(self, &destination_bytes) + } + "bitmask" => { + let [op] = args else { + return Err(MirEvalError::TypeError("simd_bitmask args are not provided")); + }; + let (op_len, _) = self.detect_simd_ty(&op.ty)?; + let op_count = op.interval.size / op_len; + let mut result: u64 = 0; + for (i, val) in op.get(self)?.chunks(op_count).enumerate() { + if !val.iter().all(|&it| it == 0) { + result |= 1 << i; + } + } + destination.write_from_bytes(self, &result.to_le_bytes()[0..destination.size]) + } + "shuffle" => { + let [left, right, index] = args else { + return Err(MirEvalError::TypeError("simd_shuffle args are not provided")); + }; + let TyKind::Array(_, index_len) = index.ty.kind(Interner) else { + return Err(MirEvalError::TypeError( + "simd_shuffle index argument has non-array type", + )); + }; + let index_len = match try_const_usize(self.db, index_len) { + Some(it) => it as usize, + None => { + return Err(MirEvalError::TypeError( + "simd type with unevaluatable len param", + )) + } + }; + let (left_len, _) = self.detect_simd_ty(&left.ty)?; + let left_size = left.interval.size / left_len; + let vector = + left.get(self)?.chunks(left_size).chain(right.get(self)?.chunks(left_size)); + let mut result = vec![]; + for index in index.get(self)?.chunks(index.interval.size / index_len) { + let index = from_bytes!(u32, index) as usize; + let val = match vector.clone().nth(index) { + Some(it) => it, + None => { + return Err(MirEvalError::TypeError( + "out of bound access in simd shuffle", + )) + } + }; + result.extend(val); + } + destination.write_from_bytes(self, &result) + } + _ => not_supported!("unknown simd intrinsic {name}"), + } + } +} |