Function torch_sys::c_generated::atg__scaled_dot_product_attention
pub unsafe extern "C" fn atg__scaled_dot_product_attention(
out__: *mut *mut C_tensor,
query_: *mut C_tensor,
key_: *mut C_tensor,
value_: *mut C_tensor,
attn_mask_: *mut C_tensor,
dropout_p_: f64,
need_attn_weights_: c_int,
is_causal_: c_int
)