Commit 41a44f2f authored by Stefan Lankes's avatar Stefan Lankes Committed by Stefan Lankes
Browse files

protect network driver by an SpinlockIrqSave

parent defbf8e4
...@@ -12,9 +12,8 @@ use crate::arch::x86_64::kernel::virtio_fs::VirtioFsDriver; ...@@ -12,9 +12,8 @@ use crate::arch::x86_64::kernel::virtio_fs::VirtioFsDriver;
use crate::arch::x86_64::kernel::virtio_net::VirtioNetDriver; use crate::arch::x86_64::kernel::virtio_net::VirtioNetDriver;
use crate::synch::spinlock::SpinlockIrqSave; use crate::synch::spinlock::SpinlockIrqSave;
use crate::x86::io::*; use crate::x86::io::*;
use alloc::rc::Rc;
use alloc::vec::Vec; use alloc::vec::Vec;
use core::cell::RefCell; use core::cell::UnsafeCell;
use core::convert::TryInto; use core::convert::TryInto;
use core::{fmt, u32, u8}; use core::{fmt, u32, u8};
...@@ -125,8 +124,8 @@ pub struct MemoryBar { ...@@ -125,8 +124,8 @@ pub struct MemoryBar {
} }
pub enum PciDriver<'a> { pub enum PciDriver<'a> {
VirtioFs(Rc<RefCell<VirtioFsDriver<'a>>>), VirtioFs(UnsafeCell<SpinlockIrqSave<VirtioFsDriver<'a>>>),
VirtioNet(Rc<RefCell<VirtioNetDriver<'a>>>), VirtioNet(UnsafeCell<SpinlockIrqSave<VirtioNetDriver<'a>>>),
} }
pub fn register_driver(drv: PciDriver<'static>) { pub fn register_driver(drv: PciDriver<'static>) {
...@@ -134,12 +133,11 @@ pub fn register_driver(drv: PciDriver<'static>) { ...@@ -134,12 +133,11 @@ pub fn register_driver(drv: PciDriver<'static>) {
drivers.push(drv); drivers.push(drv);
} }
pub fn get_network_driver() -> Option<Rc<RefCell<VirtioNetDriver<'static>>>> { pub fn get_network_driver() -> Option<&'static SpinlockIrqSave<VirtioNetDriver<'static>>> {
let drivers = PCI_DRIVERS.lock(); for i in PCI_DRIVERS.lock().iter() {
for i in drivers.iter() {
match &*i { match &*i {
PciDriver::VirtioNet(nic_driver) => { PciDriver::VirtioNet(nic_driver) => {
return Some(nic_driver.clone()); return Some(unsafe { &*nic_driver.get() });
} }
_ => {} _ => {}
} }
......
...@@ -16,11 +16,13 @@ use crate::arch::x86_64::kernel::virtio_net; ...@@ -16,11 +16,13 @@ use crate::arch::x86_64::kernel::virtio_net;
use crate::arch::x86_64::mm::paging; use crate::arch::x86_64::mm::paging;
use crate::config::VIRTIO_MAX_QUEUE_SIZE; use crate::config::VIRTIO_MAX_QUEUE_SIZE;
use crate::synch::spinlock::SpinlockIrqSave;
use alloc::boxed::Box; use alloc::boxed::Box;
use alloc::rc::Rc; use alloc::rc::Rc;
use alloc::vec::Vec; use alloc::vec::Vec;
use core::cell::RefCell; use core::cell::RefCell;
use core::cell::UnsafeCell;
use core::convert::TryInto; use core::convert::TryInto;
use core::sync::atomic::spin_loop_hint; use core::sync::atomic::spin_loop_hint;
use core::sync::atomic::{fence, Ordering}; use core::sync::atomic::{fence, Ordering};
...@@ -828,7 +830,9 @@ pub fn init_virtio_device(adapter: &pci::PciAdapter) { ...@@ -828,7 +830,9 @@ pub fn init_virtio_device(adapter: &pci::PciAdapter) {
PciNetworkControllerSubclass::EthernetController => { PciNetworkControllerSubclass::EthernetController => {
// TODO: proper error handling on driver creation fail // TODO: proper error handling on driver creation fail
let drv = virtio_net::create_virtionet_driver(adapter).unwrap(); let drv = virtio_net::create_virtionet_driver(adapter).unwrap();
pci::register_driver(PciDriver::VirtioNet(drv)); pci::register_driver(PciDriver::VirtioNet(UnsafeCell::new(
SpinlockIrqSave::new(drv),
)));
} }
_ => { _ => {
warn!("Virtio device is NOT supported, skipping!"); warn!("Virtio device is NOT supported, skipping!");
...@@ -872,7 +876,7 @@ extern "x86-interrupt" fn virtio_irqhandler(_stack_frame: &mut ExceptionStackFra ...@@ -872,7 +876,7 @@ extern "x86-interrupt" fn virtio_irqhandler(_stack_frame: &mut ExceptionStackFra
increment_irq_counter((32 + unsafe { VIRTIO_IRQ_NO }).into()); increment_irq_counter((32 + unsafe { VIRTIO_IRQ_NO }).into());
let check_scheduler = match get_network_driver() { let check_scheduler = match get_network_driver() {
Some(driver) => driver.borrow_mut().handle_interrupt(), Some(driver) => driver.lock().handle_interrupt(),
_ => false, _ => false,
}; };
......
...@@ -462,9 +462,7 @@ impl<'a> VirtioNetDriver<'a> { ...@@ -462,9 +462,7 @@ impl<'a> VirtioNetDriver<'a> {
} }
} }
pub fn create_virtionet_driver( pub fn create_virtionet_driver(adapter: &pci::PciAdapter) -> Option<VirtioNetDriver<'static>> {
adapter: &pci::PciAdapter,
) -> Option<Rc<RefCell<VirtioNetDriver<'static>>>> {
// Scan capabilities to get common config, which we need to reset the device and get basic info. // Scan capabilities to get common config, which we need to reset the device and get basic info.
// also see https://elixir.bootlin.com/linux/latest/source/drivers/virtio/virtio_pci_modern.c#L581 (virtio_pci_modern_probe) // also see https://elixir.bootlin.com/linux/latest/source/drivers/virtio/virtio_pci_modern.c#L581 (virtio_pci_modern_probe)
// Read status register // Read status register
...@@ -533,7 +531,7 @@ pub fn create_virtionet_driver( ...@@ -533,7 +531,7 @@ pub fn create_virtionet_driver(
// TODO: also load the other cap types (?). // TODO: also load the other cap types (?).
// Instanciate driver on heap, so it outlives this function // Instanciate driver on heap, so it outlives this function
let drv = Rc::new(RefCell::new(VirtioNetDriver { let mut drv = VirtioNetDriver {
tx_buffers: Vec::new(), tx_buffers: Vec::new(),
rx_buffers: Vec::new(), rx_buffers: Vec::new(),
common_cfg, common_cfg,
...@@ -541,10 +539,10 @@ pub fn create_virtionet_driver( ...@@ -541,10 +539,10 @@ pub fn create_virtionet_driver(
isr_cfg, isr_cfg,
notify_cfg, notify_cfg,
vqueues: None, vqueues: None,
})); };
trace!("Driver before init: {:?}", drv); trace!("Driver before init: {:?}", drv);
drv.borrow_mut().init(); drv.init();
trace!("Driver after init: {:?}", drv); trace!("Driver after init: {:?}", drv);
if device_cfg.status & VIRTIO_NET_S_LINK_UP == VIRTIO_NET_S_LINK_UP { if device_cfg.status & VIRTIO_NET_S_LINK_UP == VIRTIO_NET_S_LINK_UP {
...@@ -552,10 +550,7 @@ pub fn create_virtionet_driver( ...@@ -552,10 +550,7 @@ pub fn create_virtionet_driver(
} else { } else {
info!("Virtio-Net link is down"); info!("Virtio-Net link is down");
} }
info!( info!("Virtio-Net status: 0x{:x}", drv.common_cfg.device_status);
"Virtio-Net status: 0x{:x}",
drv.borrow().common_cfg.device_status
);
Some(drv) Some(drv)
} }
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
// http://opensource.org/licenses/MIT>, at your option. This file may not be // http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms. // copied, modified, or distributed except according to those terms.
use crate::arch::irq;
use crate::arch::kernel::percore::*; use crate::arch::kernel::percore::*;
use crate::scheduler::task::TaskHandle; use crate::scheduler::task::TaskHandle;
use crate::synch::semaphore::*; use crate::synch::semaphore::*;
...@@ -25,14 +24,10 @@ const POLL_PERIOD: u64 = 20_000; ...@@ -25,14 +24,10 @@ const POLL_PERIOD: u64 = 20_000;
fn set_polling_mode(value: bool) { fn set_polling_mode(value: bool) {
// is the driver already in polling mode? // is the driver already in polling mode?
if POLLING.swap(value, Ordering::SeqCst) != value { if POLLING.swap(value, Ordering::SeqCst) != value {
let irq = irq::nested_disable();
if let Some(driver) = crate::arch::kernel::pci::get_network_driver() { if let Some(driver) = crate::arch::kernel::pci::get_network_driver() {
driver.borrow_mut().set_polling_mode(value); driver.lock().set_polling_mode(value);
} }
irq::nested_enable(irq);
// wakeup network thread to sleep for longer time // wakeup network thread to sleep for longer time
NET_SEM.release(); NET_SEM.release();
} }
...@@ -70,13 +65,9 @@ pub fn netwait_and_wakeup(handles: &[usize], millis: Option<u64>) { ...@@ -70,13 +65,9 @@ pub fn netwait_and_wakeup(handles: &[usize], millis: Option<u64>) {
} }
if reset_nic { if reset_nic {
let irq = irq::nested_disable();
if let Some(driver) = crate::arch::kernel::pci::get_network_driver() { if let Some(driver) = crate::arch::kernel::pci::get_network_driver() {
driver.borrow_mut().set_polling_mode(false); driver.lock().set_polling_mode(false);
} }
irq::nested_enable(irq);
} else { } else {
NET_SEM.acquire(millis); NET_SEM.acquire(millis);
} }
......
...@@ -17,7 +17,6 @@ use crate::arch; ...@@ -17,7 +17,6 @@ use crate::arch;
use crate::console; use crate::console;
use crate::environment; use crate::environment;
use crate::errno::*; use crate::errno::*;
use crate::synch::spinlock::SpinlockIrqSave;
use crate::syscalls::fs::{self, FilePerms, PosixFile, SeekWhence}; use crate::syscalls::fs::{self, FilePerms, PosixFile, SeekWhence};
use crate::util; use crate::util;
...@@ -27,8 +26,6 @@ pub use self::uhyve::*; ...@@ -27,8 +26,6 @@ pub use self::uhyve::*;
mod generic; mod generic;
mod uhyve; mod uhyve;
static DRIVER_LOCK: SpinlockIrqSave<()> = SpinlockIrqSave::new(());
const SEEK_SET: i32 = 0; const SEEK_SET: i32 = 0;
const SEEK_CUR: i32 = 1; const SEEK_CUR: i32 = 1;
const SEEK_END: i32 = 2; const SEEK_END: i32 = 2;
...@@ -119,65 +116,51 @@ pub trait SyscallInterface: Send + Sync { ...@@ -119,65 +116,51 @@ pub trait SyscallInterface: Send + Sync {
} }
fn get_mac_address(&self) -> Result<[u8; 6], ()> { fn get_mac_address(&self) -> Result<[u8; 6], ()> {
let _lock = DRIVER_LOCK.lock();
match arch::kernel::pci::get_network_driver() { match arch::kernel::pci::get_network_driver() {
Some(driver) => Ok(driver.borrow().get_mac_address()), Some(driver) => Ok(driver.lock().get_mac_address()),
_ => Err(()), _ => Err(()),
} }
} }
fn get_mtu(&self) -> Result<u16, ()> { fn get_mtu(&self) -> Result<u16, ()> {
let _lock = DRIVER_LOCK.lock();
match arch::kernel::pci::get_network_driver() { match arch::kernel::pci::get_network_driver() {
Some(driver) => Ok(driver.borrow().get_mtu()), Some(driver) => Ok(driver.lock().get_mtu()),
_ => Err(()), _ => Err(()),
} }
} }
fn has_packet(&self) -> bool { fn has_packet(&self) -> bool {
let _lock = DRIVER_LOCK.lock();
match arch::kernel::pci::get_network_driver() { match arch::kernel::pci::get_network_driver() {
Some(driver) => driver.borrow().has_packet(), Some(driver) => driver.lock().has_packet(),
_ => false, _ => false,
} }
} }
fn get_tx_buffer(&self, len: usize) -> Result<(*mut u8, usize), ()> { fn get_tx_buffer(&self, len: usize) -> Result<(*mut u8, usize), ()> {
let _lock = DRIVER_LOCK.lock();
match arch::kernel::pci::get_network_driver() { match arch::kernel::pci::get_network_driver() {
Some(driver) => driver.borrow_mut().get_tx_buffer(len), Some(driver) => driver.lock().get_tx_buffer(len),
_ => Err(()), _ => Err(()),
} }
} }
fn send_tx_buffer(&self, handle: usize, len: usize) -> Result<(), ()> { fn send_tx_buffer(&self, handle: usize, len: usize) -> Result<(), ()> {
let _lock = DRIVER_LOCK.lock();
match arch::kernel::pci::get_network_driver() { match arch::kernel::pci::get_network_driver() {
Some(driver) => driver.borrow_mut().send_tx_buffer(handle, len), Some(driver) => driver.lock().send_tx_buffer(handle, len),
_ => Err(()), _ => Err(()),
} }
} }
fn receive_rx_buffer(&self) -> Result<&'static [u8], ()> { fn receive_rx_buffer(&self) -> Result<&'static [u8], ()> {
let _lock = DRIVER_LOCK.lock();
match arch::kernel::pci::get_network_driver() { match arch::kernel::pci::get_network_driver() {
Some(driver) => driver.borrow().receive_rx_buffer(), Some(driver) => driver.lock().receive_rx_buffer(),
_ => Err(()), _ => Err(()),
} }
} }
fn rx_buffer_consumed(&self) -> Result<(), ()> { fn rx_buffer_consumed(&self) -> Result<(), ()> {
let _lock = DRIVER_LOCK.lock();
match arch::kernel::pci::get_network_driver() { match arch::kernel::pci::get_network_driver() {
Some(driver) => { Some(driver) => {
driver.borrow_mut().rx_buffer_consumed(); driver.lock().rx_buffer_consumed();
Ok(()) Ok(())
} }
_ => Err(()), _ => Err(()),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment