diff --git a/kernel/src/syscall/mod.rs b/kernel/src/syscall/mod.rs index ee12c2e7e..20d7a5ef4 100644 --- a/kernel/src/syscall/mod.rs +++ b/kernel/src/syscall/mod.rs @@ -25,6 +25,8 @@ use crate::{ }, }; +use self::user_access::UserBufferWriter; + pub mod user_access; #[repr(i32)] @@ -651,18 +653,18 @@ impl Syscall { SYS_CLOCK => Self::clock(), SYS_PIPE => { let pipefd = args[0] as *mut c_int; - let virt_pipefd = VirtAddr::new(pipefd as usize); - if from_user - && verify_area(virt_pipefd, core::mem::size_of::<[c_int; 2]>() as usize) - .is_err() - { - Err(SystemError::EFAULT) - } else if pipefd.is_null() { - Err(SystemError::EFAULT) - } else { - let pipefd = unsafe { core::slice::from_raw_parts_mut(pipefd, 2) }; - Self::pipe(pipefd) + match UserBufferWriter::new( + pipefd, + core::mem::size_of::<[c_int; 2]>() as usize, + from_user, + ) { + Err(e) => Err(e), + Ok(mut user_buffer) => match user_buffer.buffer::(0) { + Err(e) => Err(e), + Ok(pipefd) => Self::pipe(pipefd), + }, } + } SYS_UNLINK_AT => { @@ -853,27 +855,21 @@ impl Syscall { SYS_RECVMSG => { let msg = args[1] as *mut crate::net::syscall::MsgHdr; let flags = args[2] as u32; - let virt_msg = VirtAddr::new(msg as usize); - let security_check = || { - // 验证msg的地址是否合法 - if verify_area( - virt_msg, - core::mem::size_of::() as usize, - ) - .is_err() - { - // 地址空间超出了用户空间的范围,不合法 - return Err(SystemError::EFAULT); + match UserBufferWriter::new( + msg, + core::mem::size_of::(), + true, + ) { + Err(e) => Err(e), + Ok(mut user_buffer_writer) => { + match user_buffer_writer.buffer::(0) { + Err(e) => Err(e), + Ok(buffer) => { + let msg = &mut buffer[0]; + Self::recvmsg(args[0], msg, flags) + } + } } - let msg = unsafe { msg.as_mut() }.ok_or(SystemError::EFAULT)?; - return Ok(msg); - }; - let r = security_check(); - if r.is_err() { - Err(r.unwrap_err()) - } else { - let msg = r.unwrap(); - Self::recvmsg(args[0], msg, flags) } } @@ -889,32 +885,23 @@ impl Syscall { SYS_GETTIMEOFDAY => { let timeval = args[0] as *mut PosixTimeval; let timezone_ptr = args[1] as *mut PosixTimeZone; - let virt_timeval = VirtAddr::new(timeval as usize); - let virt_timezone_ptr = VirtAddr::new(timezone_ptr as usize); - let security_check = || { - if verify_area(virt_timeval, core::mem::size_of::() as usize) - .is_err() - { - return Err(SystemError::EFAULT); - } - if verify_area( - virt_timezone_ptr, - core::mem::size_of::() as usize, - ) - .is_err() - { - return Err(SystemError::EFAULT); - } - return Ok(()); - }; - let r = security_check(); - if r.is_err() { - Err(r.unwrap_err()) - } else { - if !timeval.is_null() { - Self::gettimeofday(timeval, timezone_ptr) - } else { - Err(SystemError::EFAULT) + match UserBufferWriter::new(timeval, core::mem::size_of::(), true) { + Err(e) => Err(e), + Ok(_) => { + match UserBufferWriter::new( + timezone_ptr, + core::mem::size_of::(), + true, + ) { + Err(e) => Err(e), + Ok(_) => { + if !timeval.is_null() { + Self::gettimeofday(timeval, timezone_ptr) + } else { + Err(SystemError::EFAULT) + } + } + } } } } diff --git a/kernel/src/syscall/user_access.rs b/kernel/src/syscall/user_access.rs index 1078e4171..006a30e0a 100644 --- a/kernel/src/syscall/user_access.rs +++ b/kernel/src/syscall/user_access.rs @@ -1,5 +1,9 @@ //! 这个文件用于放置一些内核态访问用户态数据的函数 -use core::mem::size_of; + +use core::{ + mem::size_of, + slice::{from_raw_parts, from_raw_parts_mut}, +}; use alloc::{string::String, vec::Vec}; @@ -139,3 +143,199 @@ pub fn check_and_clone_cstr_array(user: *const *const u8) -> Result, return Ok(buffer); } } + +#[derive(Debug)] +pub struct UserBufferWriter<'a> { + buffer: &'a mut [u8], +} + +#[derive(Debug)] +pub struct UserBufferReader<'a> { + buffer: &'a [u8], +} + +#[allow(dead_code)] +impl<'a> UserBufferReader<'a> { + /// 构造一个指向用户空间位置的BufferReader,为了兼容类似传入 *const u8 的情况,使用单独的泛型来进行初始化 + /// + /// @param addr 用户空间指针 + /// @param len 缓冲区的字节长度 + /// @param frm_user 代表是否要检验地址来自用户空间 + /// @return 构造成功返回UserbufferReader实例,否则返回错误码 + /// + pub fn new(addr: *const U, len: usize, from_user: bool) -> Result { + if from_user && verify_area(VirtAddr::new(addr as usize), len).is_err() { + return Err(SystemError::EFAULT); + } + return Ok(Self { + buffer: unsafe { core::slice::from_raw_parts(addr as *const u8, len) }, + }); + } + + /// 从用户空间读取数据(到变量中) + /// + /// @param offset 字节偏移量 + /// @return 返回用户空间数据的切片(对单个结构体就返回长度为一的切片) + /// + pub fn read_from_user(&self, offset: usize) -> Result<&[T], SystemError> { + return self.convert_with_offset(&self.buffer, offset); + } + /// 从用户空间读取一个指定偏移量的数据(到变量中) + /// + /// @param offset 字节偏移量 + /// @return 返回用户空间数据的引用 + /// + pub fn read_one_from_user(&self, offset: usize) -> Result<&T, SystemError> { + return self.convert_one_with_offset(&self.buffer, offset); + } + + /// 从用户空间拷贝数据(到指定地址中) + /// + /// @param dst 目标地址指针 + /// @return 拷贝成功的话返回拷贝的元素数量 + /// + pub fn copy_from_user( + &self, + dst: &mut [T], + offset: usize, + ) -> Result { + let data = self.convert_with_offset(&self.buffer, offset)?; + dst.copy_from_slice(data); + return Ok(dst.len()); + } + + /// 从用户空间拷贝数据(到指定地址中) + /// + /// @param dst 目标地址指针 + /// @return 拷贝成功的话返回拷贝的元素数量 + /// + pub fn copy_one_from_user( + &self, + dst: &mut T, + offset: usize, + ) -> Result<(), SystemError> { + let data = self.convert_one_with_offset::(&self.buffer, offset)?; + dst.clone_from(data); + return Ok(()); + } + + fn convert_with_offset(&self, src: &[u8], offset: usize) -> Result<&[T], SystemError> { + if offset >= src.len() { + return Err(SystemError::EINVAL); + } + let byte_buffer: &[u8] = &src[offset..]; + if byte_buffer.len() % core::mem::size_of::() != 0 || byte_buffer.is_empty() { + return Err(SystemError::EINVAL); + } + + let chunks = unsafe { + from_raw_parts( + byte_buffer.as_ptr() as *const T, + byte_buffer.len() / core::mem::size_of::(), + ) + }; + return Ok(chunks); + } + + fn convert_one_with_offset(&self, src: &[u8], offset: usize) -> Result<&T, SystemError> { + if offset + core::mem::size_of::() > src.len() { + return Err(SystemError::EINVAL); + } + let byte_buffer: &[u8] = &src[offset..offset + core::mem::size_of::()]; + + let chunks = unsafe { from_raw_parts(byte_buffer.as_ptr() as *const T, 1) }; + let data = &chunks[0]; + return Ok(data); + } +} + +#[allow(dead_code)] +impl<'a> UserBufferWriter<'a> { + /// 构造一个指向用户空间位置的BufferWriter + /// + /// @param addr 用户空间指针 + /// @param len 缓冲区的字节长度 + /// @return 构造成功返回UserbufferWriter实例,否则返回错误码 + /// + pub fn new(addr: *mut U, len: usize, from_user: bool) -> Result { + if from_user + && verify_area( + VirtAddr::new(addr as usize), + (len * core::mem::size_of::()) as usize, + ) + .is_err() + { + return Err(SystemError::EFAULT); + } + return Ok(Self { + buffer: unsafe { + core::slice::from_raw_parts_mut(addr as *mut u8, len * core::mem::size_of::()) + }, + }); + } + + /// 从指定地址写入数据到用户空间 + /// + /// @param data 要写入的数据地址 + /// @param offset 在UserBuffer中的字节偏移量 + /// @return 返回写入元素的数量 + /// + pub fn copy_to_user( + &'a mut self, + src: &'a [T], + offset: usize, + ) -> Result { + let dst = Self::convert_with_offset(self.buffer, offset)?; + dst.copy_from_slice(&src); + return Ok(src.len()); + } + + /// 从指定地址写入一个数据到用户空间 + /// + /// @param data 要写入的数据地址 + /// @param offset 在UserBuffer中的字节偏移量 + /// @return 返回写入元素的数量 + /// + pub fn copy_one_to_user( + &'a mut self, + src: &'a T, + offset: usize, + ) -> Result<(), SystemError> { + let dst = Self::convert_one_with_offset::(self.buffer, offset)?; + dst.clone_from(src); + return Ok(()); + } + + pub fn buffer(&'a mut self, offset: usize) -> Result<&mut [T], SystemError> { + Ok(Self::convert_with_offset::(self.buffer, offset).map_err(|_| SystemError::EINVAL)?) + } + + fn convert_with_offset(src: &mut [u8], offset: usize) -> Result<&mut [T], SystemError> { + if offset >= src.len() { + return Err(SystemError::EINVAL); + } + let byte_buffer: &mut [u8] = &mut src[offset..]; + if byte_buffer.len() % core::mem::size_of::() != 0 || byte_buffer.is_empty() { + return Err(SystemError::EINVAL); + } + + let chunks = unsafe { + from_raw_parts_mut( + byte_buffer.as_mut_ptr() as *mut T, + byte_buffer.len() / core::mem::size_of::(), + ) + }; + return Ok(chunks); + } + + fn convert_one_with_offset(src: &mut [u8], offset: usize) -> Result<&mut T, SystemError> { + if offset + core::mem::size_of::() > src.len() { + return Err(SystemError::EINVAL); + } + let byte_buffer: &mut [u8] = &mut src[offset..offset + core::mem::size_of::()]; + + let chunks = unsafe { from_raw_parts_mut(byte_buffer.as_mut_ptr() as *mut T, 1) }; + let data = &mut chunks[0]; + return Ok(data); + } +}