Skip to main content

proka_kernel/libs/
initrd.rs

1//! This code is originally from https://github.com/rcore-os/cpio/blob/main/src/lib.rs, with minor modifications made here.
2
3extern crate alloc;
4use crate::fs::fs_impl::MemFs;
5use crate::fs::vfs::{File, FileSystem, VNodeType, VfsError, VFS};
6use alloc::format;
7use alloc::string::{String, ToString};
8use alloc::vec::Vec;
9use log::{debug, error, info, warn};
10
/// A reader for CPIO archives in the newc format (magic `070701`).
///
/// The reader borrows the whole archive and iterates over its entries in
/// order, yielding `Ok(Object)` for each one until the `TRAILER!!!` entry is
/// reached. Every returned [`Object`] borrows its name and data directly from
/// the archive buffer; nothing is copied.
///
/// Iteration should stop at the first `Err`: the reader's position after a
/// parse error is unspecified, and a buffer that is too short for a header
/// keeps reporting `ReadError::BufTooShort` on every call.
///
/// # Example
///
/// ```rust,ignore
/// // `archive: &[u8]` holds the raw bytes of a newc CPIO archive.
/// for entry in CpioNewcReader::new(archive) {
///     let entry = entry?;
///     log::info!("{} ({} bytes)", entry.name, entry.data.len());
/// }
/// ```
#[derive(Debug, Clone)]
pub struct CpioNewcReader<'a> {
    /// The not-yet-consumed remainder of the archive.
    buf: &'a [u8],
}
26
27impl<'a> CpioNewcReader<'a> {
28    /// Creates a new CPIO reader on the buffer.
29    pub fn new(buf: &'a [u8]) -> Self {
30        Self { buf }
31    }
32}
33
34/// File system object in CPIO file.
35pub struct Object<'a> {
36    /// The file metadata.
37    pub metadata: Metadata,
38    /// The full pathname.
39    pub name: &'a str,
40    /// The file data.
41    pub data: &'a [u8],
42}
43
44impl<'a> Iterator for CpioNewcReader<'a> {
45    type Item = Result<Object<'a>, ReadError>;
46
47    fn next(&mut self) -> Option<Self::Item> {
48        // SAFETY: To workaround lifetime
49        let s: &'a mut Self = unsafe { core::mem::transmute(self) };
50        match inner(&mut s.buf) {
51            Ok(Object {
52                name: "TRAILER!!!", ..
53            }) => None,
54            res => Some(res),
55        }
56    }
57}
58
59fn inner<'a>(buf: &'a mut &'a [u8]) -> Result<Object<'a>, ReadError> {
60    const HEADER_LEN: usize = 110;
61    const MAGIC_NUMBER: &[u8] = b"070701";
62
63    if buf.len() < HEADER_LEN {
64        return Err(ReadError::BufTooShort);
65    }
66    let magic = buf.read_bytes(6)?;
67    if magic != MAGIC_NUMBER {
68        return Err(ReadError::InvalidMagic);
69    }
70    let ino = buf.read_hex_u32()?;
71    let mode = buf.read_hex_u32()?;
72    let uid = buf.read_hex_u32()?;
73    let gid = buf.read_hex_u32()?;
74    let nlink = buf.read_hex_u32()?;
75    let mtime = buf.read_hex_u32()?;
76    let file_size = buf.read_hex_u32()?;
77    let dev_major = buf.read_hex_u32()?;
78    let dev_minor = buf.read_hex_u32()?;
79    let rdev_major = buf.read_hex_u32()?;
80    let rdev_minor = buf.read_hex_u32()?;
81    let name_size = buf.read_hex_u32()? as usize;
82    let _check = buf.read_hex_u32()?;
83    let metadata = Metadata {
84        ino,
85        mode,
86        uid,
87        gid,
88        nlink,
89        mtime,
90        file_size,
91        dev_major,
92        dev_minor,
93        rdev_major,
94        rdev_minor,
95    };
96    let name_with_nul = buf.read_bytes(name_size)?;
97    if name_with_nul.last() != Some(&0) {
98        return Err(ReadError::InvalidName);
99    }
100    let name = core::str::from_utf8(&name_with_nul[..name_size - 1])
101        .map_err(|_| ReadError::InvalidName)?;
102    buf.read_bytes(pad_to_4(HEADER_LEN + name_size))?;
103
104    let data = buf.read_bytes(file_size as usize)?;
105    buf.read_bytes(pad_to_4(file_size as usize))?;
106
107    Ok(Object {
108        metadata,
109        name,
110        data,
111    })
112}
113
114trait BufExt<'a> {
115    fn read_hex_u32(&mut self) -> Result<u32, ReadError>;
116    fn read_bytes(&mut self, len: usize) -> Result<&'a [u8], ReadError>;
117}
118
119impl<'a> BufExt<'a> for &'a [u8] {
120    fn read_hex_u32(&mut self) -> Result<u32, ReadError> {
121        let (hex, rest) = self.split_at(8);
122        *self = rest;
123        let str = core::str::from_utf8(hex).map_err(|_| ReadError::InvalidASCII)?;
124        let value = u32::from_str_radix(str, 16).map_err(|_| ReadError::InvalidASCII)?;
125        Ok(value)
126    }
127
128    fn read_bytes(&mut self, len: usize) -> Result<&'a [u8], ReadError> {
129        if self.len() < len {
130            return Err(ReadError::BufTooShort);
131        }
132        let (bytes, rest) = self.split_at(len);
133        *self = rest;
134        Ok(bytes)
135    }
136}
137
/// Returns how many padding bytes must follow `len` bytes so that the total
/// becomes a multiple of 4 (zero when `len` is already aligned).
fn pad_to_4(len: usize) -> usize {
    let remainder = len % 4;
    (4 - remainder) % 4
}
145
/// The error type returned by the CPIO reader.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReadError {
    /// A numeric header field could not be parsed as hexadecimal.
    InvalidASCII,
    /// The entry does not start with the newc magic number `070701`.
    InvalidMagic,
    /// The entry name is empty, not NUL-terminated, or not valid UTF-8.
    InvalidName,
    /// The buffer ends before the entry (header, name, data or padding) does.
    BufTooShort,
}

impl core::fmt::Display for ReadError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let msg = match self {
            ReadError::InvalidASCII => "invalid hexadecimal field in CPIO header",
            ReadError::InvalidMagic => "invalid CPIO newc magic number",
            ReadError::InvalidName => "invalid CPIO entry name",
            ReadError::BufTooShort => "CPIO archive is truncated",
        };
        f.write_str(msg)
    }
}
154
/// The numeric header fields of one CPIO (newc) entry, in header order.
///
/// Field meanings follow the newc format definition.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct Metadata {
    /// Inode number.
    pub ino: u32,
    /// File type and permission bits; the file type is `mode & CPIO_S_IFMT`.
    pub mode: u32,
    /// Owner user id.
    pub uid: u32,
    /// Owner group id.
    pub gid: u32,
    /// Number of links to the entry.
    pub nlink: u32,
    /// Modification time (per the newc format, seconds since the Unix epoch).
    pub mtime: u32,
    /// Length in bytes of the entry's data, which follows the name.
    pub file_size: u32,
    /// Major number of the device that held the file.
    pub dev_major: u32,
    /// Minor number of the device that held the file.
    pub dev_minor: u32,
    /// Major number of the device the entry represents (device nodes only).
    pub rdev_major: u32,
    /// Minor number of the device the entry represents (device nodes only).
    pub rdev_minor: u32,
}
170
// File-type bits stored in the `mode` field of a CPIO header.

/// Mask that isolates the file-type bits of a mode.
const CPIO_S_IFMT: u32 = 0o170000;
/// File type: directory.
const CPIO_S_IFDIR: u32 = 0o040000;
/// File type: regular file.
const CPIO_S_IFREG: u32 = 0o100000;
/// File type: symbolic link.
const CPIO_S_IFLNK: u32 = 0o120000;
176
177/// Loads the initial RAM disk (initrd) into the Virtual File System (VFS).
178///
179/// This function parses a CPIO archive provided as raw bytes, extracts its
180/// contents, and recreates the file and directory structure within the VFS.
181///
182/// # Arguments
183/// * `initrd_data` - A byte slice containing the CPIO archive data.
184///
185/// # Returns
186/// A `Result` indicating success or a `VfsError` if any operation fails.
187pub fn load_cpio(initrd_data: &[u8]) -> Result<(), VfsError> {
188    let reader = CpioNewcReader::new(initrd_data);
189    let vfs = &*VFS;
190    debug!("Loading CPIO archive...");
191    for obj_result in reader {
192        let obj = obj_result.map_err(|e| {
193            error!("CPIO read error: {:?}", e);
194            VfsError::IoError
195        })?;
196
197        let path = obj.name;
198        if path == "TRAILER!!!" {
199            continue; // Skip the trailer entry, already handled by iterator but good for explicit check
200        }
201
202        // Normalize path: CPIO paths are often like "foo/bar" or "./foo/bar".
203        // We want absolute paths in VFS, e.g., "/foo/bar".
204        let canonical_path = if path.starts_with('/') {
205            path.to_string()
206        } else if let Some(stripped) = path.strip_prefix("./") {
207            format!("/{}", stripped)
208        } else {
209            format!("/{}", path)
210        };
211
212        // Remove trailing slash unless it's the root itself.
213        let final_path = if canonical_path.len() > 1 && canonical_path.ends_with('/') {
214            canonical_path.trim_end_matches('/').to_string()
215        } else {
216            canonical_path
217        };
218
219        let node_type_mode = obj.metadata.mode & CPIO_S_IFMT;
220
221        // Ensure all parent directories exist for the current object's path.
222        // This loop iterates through path components and creates intermediate
223        // directories if they don't already exist.
224        let mut current_dir_segment = String::new();
225        let components: Vec<&str> = final_path.split('/').filter(|&s| !s.is_empty()).collect();
226
227        for (i, component) in components.iter().enumerate() {
228            current_dir_segment.push('/');
229            current_dir_segment.push_str(component);
230
231            // If it's an intermediate component OR the last component is a directory itself,
232            // ensure it exists and is a directory.
233            if i < components.len() - 1 || node_type_mode == CPIO_S_IFDIR {
234                match vfs.lookup(&current_dir_segment) {
235                    Ok(node) => {
236                        if node.node_type() != VNodeType::Dir {
237                            error!(
238                                "Path component '{}' for '{}' is not a directory!",
239                                current_dir_segment, final_path
240                            );
241                            return Err(VfsError::AlreadyExists); // Or a specific error
242                        }
243                    }
244                    Err(VfsError::NotFound) => {
245                        vfs.create_dir(&current_dir_segment).inspect_err(|e| {
246                            error!(
247                                "Failed to create directory {}: {:?}",
248                                current_dir_segment, e
249                            );
250                        })?;
251                    }
252                    Err(e) => {
253                        error!("Error checking path {}: {:?}", current_dir_segment, e);
254                        return Err(e);
255                    }
256                }
257            }
258        }
259        debug!("Created parent directories for {}", final_path);
260        // Now, handle the actual CPIO object based on its type
261        match node_type_mode {
262            CPIO_S_IFREG => {
263                let file_inode = vfs.create_file(&final_path)?;
264                debug!("Created file {}", final_path);
265                let file_handle = File::new(file_inode);
266                file_handle.write(obj.data)?;
267            }
268            CPIO_S_IFLNK => {
269                let target_path =
270                    core::str::from_utf8(obj.data).map_err(|_| VfsError::InvalidArgument)?;
271                vfs.create_symlink(target_path, &final_path)?;
272            }
273            _ => {}
274        }
275    }
276    Ok(())
277}
278
279pub fn load_initrd() {
280    let memfs = MemFs;
281    let root = memfs
282        .mount(None, None)
283        .expect("Failed to initialize root filesystem");
284    VFS.init_root(root);
285
286    // Load initrd
287    if let Some(initrd_response) = crate::MODULE_REQUEST.get_response() {
288        if let Some(inir) = initrd_response.modules().first() {
289            unsafe {
290                let slice: &[u8] = core::slice::from_raw_parts(inir.addr(), inir.size() as usize);
291                match load_cpio(slice) {
292                    Ok(_) => info!("Initrd loaded successfully."),
293                    Err(e) => error!("Failed to load initrd: {:?}", e),
294                }
295            }
296        } else {
297            warn!("No initrd module found.");
298        }
299    } else {
300        warn!("Initrd module request failed.");
301    }
302}