diff --git a/src/userprog/syscall.c b/src/userprog/syscall.c index 6a1997d..931d177 100644 --- a/src/userprog/syscall.c +++ b/src/userprog/syscall.c @@ -18,7 +18,9 @@ static struct lock filesys_lock; static int get_byte (const uint8_t *uaddr); static uint32_t get_word (const uint32_t *uaddr); -static void validate_ptr (const uint8_t *uaddr); +static bool put_byte (uint8_t *udst, uint8_t byte); +static void validate_ptr_read (const uint8_t *uaddr, unsigned); +static void validate_ptr_write (uint8_t *udst, unsigned size); static void syscall_handler (struct intr_frame *); @@ -42,53 +44,6 @@ syscall_init (void) intr_register_int (0x30, 3, INTR_ON, syscall_handler, "syscall"); } -/* Reads a byte at user virtual address UADDR. - Returns the byte value if successful, -1 if a segfault - occured or UADDR is not in the user space. */ -static int -get_byte (const uint8_t *uaddr) -{ - if (!is_user_vaddr (uaddr)) - return -1; - int result; - asm ("movl $1f, %0; movzbl %1, %0; 1:" - : "=&a" (result) : "m" (*uaddr)); - return result; -} - -/* Reads a word at user virtual address ADDR. - Returns the word value if successful, calls exit system - call if not. */ -static uint32_t -get_word (const uint32_t *uaddr) -{ - uint32_t res; - int byte; - int i; - for (i = 0; i < 4; i++) - { - byte = get_byte ((uint8_t *) uaddr + i); - if (byte == -1) - { - syscall_exit (-1); - NOT_REACHED (); - } - *((uint8_t *) &res + i) = (uint8_t) byte; - } - return res; -} - -/* Validates a given user virtual address. */ -static void -validate_ptr (const uint8_t *uaddr) -{ - if (get_byte (uaddr) == -1) - { - syscall_exit (-1); - NOT_REACHED (); - } -} - /* Handler which matches the appropriate system call. */ static void syscall_handler (struct intr_frame *f UNUSED) @@ -181,7 +136,7 @@ static pid_t syscall_exec (const char *cmd_line) { /* Check the validity. */ - validate_ptr (cmd_line); + validate_ptr_read (cmd_line, 1); /* Create a new process. */ pid_t pid = process_execute (cmd_line); @@ -216,7 +171,7 @@ static bool syscall_create (const char *file, unsigned init_size) { /* Check the validity. */ - validate_ptr (file); + validate_ptr_read (file, 1); /* Create a new file. */ lock_acquire (&filesys_lock); @@ -230,7 +185,7 @@ static bool syscall_remove (const char *file) { /* Check the validity. */ - validate_ptr (file); + validate_ptr_read (file, 1); /* Create a new file. */ lock_acquire (&filesys_lock); @@ -245,7 +200,7 @@ static int syscall_open (const char *file) { /* Check the validity. */ - validate_ptr (file); + validate_ptr_read (file, 1); /* Open the file. */ lock_acquire (&filesys_lock); @@ -292,8 +247,9 @@ static int syscall_read (int fd, void *buffer, unsigned size) { /* Check the validity. */ - validate_ptr (buffer); uint8_t *bf = (uint8_t *) buffer; + validate_ptr_read (bf, size); + validate_ptr_write (bf, size); /* Read from STDIN. */ unsigned bytes = 0; @@ -331,7 +287,7 @@ static int syscall_write (int fd, void *buffer, unsigned size) { /* Check the validity. */ - validate_ptr (buffer); + validate_ptr_read (buffer, size); /* Write to STDOUT. */ if (fd == STDOUT_FILENO) @@ -421,3 +377,87 @@ syscall_close (int fd) } lock_release (&filesys_lock); } + +/* Reads a byte at user virtual address UADDR. + Returns the byte value if successful, -1 if a segfault + occured or UADDR is not in the user space. */ +static int +get_byte (const uint8_t *uaddr) +{ + if (!is_user_vaddr (uaddr)) + return -1; + int result; + asm ("movl $1f, %0; movzbl %1, %0; 1:" + : "=&a" (result) : "m" (*uaddr)); + return result; +} + +/* Reads a word at user virtual address ADDR. + Returns the word value if successful, calls exit system + call if not. */ +static uint32_t +get_word (const uint32_t *uaddr) +{ + uint32_t res; + int byte; + int i; + for (i = 0; i < 4; i++) + { + byte = get_byte ((uint8_t *) uaddr + i); + if (byte == -1) + { + syscall_exit (-1); + NOT_REACHED (); + } + *((uint8_t *) &res + i) = (uint8_t) byte; + } + return res; +} + +/* Writes BYTE to user address UDST. + Returns true if successful, false if a segfault occurred + or UADDR is not in the user space. */ +static bool +put_byte (uint8_t *udst, uint8_t byte) +{ + if (!is_user_vaddr (udst)) + return false; + int error_code; + asm ("movl $1f, %0; movb %b2, %1; 1:" + : "=&a" (error_code), "=m" (*udst) : "r" (byte)); + return error_code != -1; +} + +/* Validates reading to a given user virtual address UADDR up to + size SIZE. */ +static void +validate_ptr_read (const uint8_t *uaddr, size_t size) +{ + uint8_t *ptr; + for (ptr = pg_round_down (uaddr); ptr < uaddr + size; ptr += PGSIZE) + { + if (get_byte (ptr) == -1) + { + syscall_exit (-1); + NOT_REACHED (); + } + } +} + +/* Validates writing to a given user virtual address UADDR up to + size SIZE. + + Use this method after validate_ptr_read(). */ +static void +validate_ptr_write (uint8_t *udst, unsigned size) +{ + uint8_t *ptr; + for (ptr = pg_round_down (udst); ptr < udst + size; ptr += PGSIZE) + { + if (!put_byte (ptr, get_byte (ptr))) + { + syscall_exit (-1); + NOT_REACHED (); + } + } +}