Skip to content

Commit

Permalink
No longer validate Wasm module header in Module::new_unchecked (#1025)
Browse files Browse the repository at this point in the history
* do not validate data section in header-only validation mode

* fix bug

* no longer validate header for Module::new_unchecked

* properly flag ModuleParser functions as unsafe
  • Loading branch information
Robbepop committed May 10, 2024
1 parent 6214e51 commit 2e7a28f
Showing 1 changed file with 92 additions and 54 deletions.
146 changes: 92 additions & 54 deletions crates/wasmi/src/module/parser.rs
Expand Up @@ -63,11 +63,11 @@ pub unsafe fn parse_unchecked(engine: &Engine, stream: impl Read) -> Result<Modu
}

/// Context used to construct a WebAssembly module from a stream of bytes.
pub struct ModuleParser {
struct ModuleParser {
/// The engine used for translation.
engine: Engine,
/// The Wasm validator used throughout stream parsing.
validator: Validator,
validator: Option<Validator>,
/// The underlying Wasm parser.
parser: WasmParser,
/// The number of compiled or processed functions.
Expand All @@ -76,23 +76,13 @@ pub struct ModuleParser {
eof: bool,
}

/// The mode of Wasm validation when parsing a Wasm module.
#[derive(Debug, Copy, Clone)]
pub enum ValidationMode {
/// Perform Wasm validation on the entire Wasm module including Wasm function bodies.
All,
/// Perform Wasm validation only on the Wasm header but not on Wasm function bodies.
HeaderOnly,
}

impl ModuleParser {
/// Creates a new [`ModuleParser`] for the given [`Engine`].
fn new(engine: &Engine) -> Self {
let validator = Validator::new_with_features(engine.config().wasm_features());
let parser = WasmParser::new(0);
Self {
engine: engine.clone(),
validator,
validator: None,
parser,
compiled_funcs: 0,
eof: false,
Expand All @@ -106,37 +96,50 @@ impl ModuleParser {
/// # Errors
///
/// If the Wasm bytecode stream fails to validate.
pub fn parse(self, stream: impl Read) -> Result<Module, Error> {
self.parse_impl(ValidationMode::All, stream)
pub fn parse(mut self, stream: impl Read) -> Result<Module, Error> {
let features = self.engine.config().wasm_features();
self.validator = Some(Validator::new_with_features(features));
// SAFETY: we just pre-populated the Wasm module parser with a validator
// thus calling this method is safe.
unsafe { self.parse_impl(stream) }
}

/// Starts parsing and validating the Wasm bytecode stream.
///
/// Returns the compiled and validated Wasm [`Module`] upon success.
///
/// # Safety
///
/// The caller is responsible to make sure that the provided
/// `stream` yields valid WebAssembly bytecode.
///
/// # Errors
///
/// If the Wasm bytecode stream fails to validate.
pub unsafe fn parse_unchecked(self, stream: impl Read) -> Result<Module, Error> {
self.parse_impl(ValidationMode::HeaderOnly, stream)
unsafe { self.parse_impl(stream) }
}

/// Starts parsing and validating the Wasm bytecode stream.
///
/// Returns the compiled and validated Wasm [`Module`] upon success.
///
/// # Safety
///
/// The caller is responsible to either
///
/// 1) Populate the [`ModuleParser`] with a [`Validator`] prior to calling this method, OR;
/// 2) Make sure that the provided `stream` yields valid WebAssembly bytecode.
///
/// Otherwise this method has undefined behavior.
///
/// # Errors
///
/// If the Wasm bytecode stream fails to validate.
fn parse_impl(
mut self,
validation_mode: ValidationMode,
mut stream: impl Read,
) -> Result<Module, Error> {
unsafe fn parse_impl(mut self, mut stream: impl Read) -> Result<Module, Error> {
let mut buffer = Vec::new();
let header = Self::parse_header(&mut self, &mut stream, &mut buffer)?;
let builder =
Self::parse_code(&mut self, validation_mode, &mut stream, &mut buffer, header)?;
let builder = Self::parse_code(&mut self, &mut stream, &mut buffer, header)?;
let module = Self::parse_data(&mut self, &mut stream, &mut buffer, builder)?;
Ok(module)
}
Expand Down Expand Up @@ -234,7 +237,6 @@ impl ModuleParser {
/// If the Wasm bytecode stream fails to parse or validate.
fn parse_code(
&mut self,
validation_mode: ValidationMode,
stream: &mut impl Read,
buffer: &mut Vec<u8>,
header: ModuleHeader,
Expand All @@ -254,7 +256,7 @@ impl ModuleParser {
let remaining = func_body.get_binary_reader().bytes_remaining();
let start = consumed - remaining;
let bytes = &buffer[start..consumed];
self.process_code_entry(func_body, validation_mode, bytes, &header)?;
self.process_code_entry(func_body, bytes, &header)?;
}
Payload::CustomSection { .. } => {}
Payload::UnknownSection { id, range, .. } => {
Expand Down Expand Up @@ -330,7 +332,12 @@ impl ModuleParser {

/// Processes the end of the Wasm binary.
fn process_end(&mut self, offset: usize) -> Result<(), Error> {
self.validator.end(offset)?;
if let Some(validator) = &mut self.validator {
// This only checks if the number of code section entries and data segments match
// their expected numbers thus we must avoid this check in header-only mode because
// otherwise we will receive errors for unmatched data section entries.
validator.end(offset)?;
}
Ok(())
}

Expand All @@ -341,9 +348,10 @@ impl ModuleParser {
encoding: Encoding,
range: Range<usize>,
) -> Result<(), Error> {
self.validator
.version(num, encoding, &range)
.map_err(Into::into)
if let Some(validator) = &mut self.validator {
validator.version(num, encoding, &range)?;
}
Ok(())
}

/// Processes the Wasm type section.
Expand All @@ -360,7 +368,9 @@ impl ModuleParser {
section: TypeSectionReader,
header: &mut ModuleHeaderBuilder,
) -> Result<(), Error> {
self.validator.type_section(&section)?;
if let Some(validator) = &mut self.validator {
validator.type_section(&section)?;
}
let limits = self.engine.config().get_engine_limits();
let func_types = section.into_iter().map(|result| {
let wasmparser::Type::Func(ty) = result?;
Expand Down Expand Up @@ -397,7 +407,9 @@ impl ModuleParser {
section: ImportSectionReader,
header: &mut ModuleHeaderBuilder,
) -> Result<(), Error> {
self.validator.import_section(&section)?;
if let Some(validator) = &mut self.validator {
validator.import_section(&section)?;
}
let imports = section
.into_iter()
.map(|import| import.map(Import::from).map_err(Error::from));
Expand All @@ -415,9 +427,10 @@ impl ModuleParser {
&mut self,
section: wasmparser::InstanceSectionReader,
) -> Result<(), Error> {
self.validator
.instance_section(&section)
.map_err(Into::into)
if let Some(validator) = &mut self.validator {
validator.instance_section(&section)?;
}
Ok(())
}

/// Process module function declarations.
Expand All @@ -439,7 +452,9 @@ impl ModuleParser {
return Err(Error::from(EnforcedLimitsError::TooManyFunctions { limit }));
}
}
self.validator.function_section(&section)?;
if let Some(validator) = &mut self.validator {
validator.function_section(&section)?;
}
let funcs = section
.into_iter()
.map(|func| func.map(FuncTypeIdx::from).map_err(Error::from));
Expand All @@ -466,7 +481,9 @@ impl ModuleParser {
return Err(Error::from(EnforcedLimitsError::TooManyTables { limit }));
}
}
self.validator.table_section(&section)?;
if let Some(validator) = &mut self.validator {
validator.table_section(&section)?;
}
let tables = section
.into_iter()
.map(|table| table.map(TableType::from_wasmparser).map_err(Error::from));
Expand All @@ -493,7 +510,9 @@ impl ModuleParser {
return Err(Error::from(EnforcedLimitsError::TooManyMemories { limit }));
}
}
self.validator.memory_section(&section)?;
if let Some(validator) = &mut self.validator {
validator.memory_section(&section)?;
}
let memories = section
.into_iter()
.map(|memory| memory.map(MemoryType::from_wasmparser).map_err(Error::from));
Expand All @@ -508,7 +527,10 @@ impl ModuleParser {
/// This is part of the module linking Wasm proposal and not yet supported
/// by Wasmi.
fn process_tags(&mut self, section: wasmparser::TagSectionReader) -> Result<(), Error> {
self.validator.tag_section(&section).map_err(Into::into)
if let Some(validator) = &mut self.validator {
validator.tag_section(&section)?;
}
Ok(())
}

/// Process module global variable declarations.
Expand All @@ -530,7 +552,9 @@ impl ModuleParser {
return Err(Error::from(EnforcedLimitsError::TooManyGlobals { limit }));
}
}
self.validator.global_section(&section)?;
if let Some(validator) = &mut self.validator {
validator.global_section(&section)?;
}
let globals = section
.into_iter()
.map(|global| global.map(Global::from).map_err(Error::from));
Expand All @@ -552,7 +576,9 @@ impl ModuleParser {
section: ExportSectionReader,
header: &mut ModuleHeaderBuilder,
) -> Result<(), Error> {
self.validator.export_section(&section)?;
if let Some(validator) = &mut self.validator {
validator.export_section(&section)?;
}
let exports = section.into_iter().map(|export| {
let export = export?;
let field: Box<str> = export.name.into();
Expand All @@ -578,7 +604,9 @@ impl ModuleParser {
range: Range<usize>,
header: &mut ModuleHeaderBuilder,
) -> Result<(), Error> {
self.validator.start_section(func, &range)?;
if let Some(validator) = &mut self.validator {
validator.start_section(func, &range)?;
}
header.set_start(FuncIdx::from(func));
Ok(())
}
Expand Down Expand Up @@ -609,7 +637,9 @@ impl ModuleParser {
}));
}
}
self.validator.element_section(&section)?;
if let Some(validator) = &mut self.validator {
validator.element_section(&section)?;
}
let segments = section
.into_iter()
.map(|segment| segment.map(ElementSegment::from).map_err(Error::from));
Expand All @@ -631,9 +661,10 @@ impl ModuleParser {
}));
}
}
self.validator
.data_count_section(count, &range)
.map_err(Into::into)
if let Some(validator) = &mut self.validator {
validator.data_count_section(count, &range)?;
}
Ok(())
}

/// Process module linear memory data segments.
Expand All @@ -657,7 +688,12 @@ impl ModuleParser {
}));
}
}
self.validator.data_section(&section)?;
if let Some(validator) = &mut self.validator {
// Note: data section does not belong to the Wasm module header.
//
// Also benchmarks show that validation of the data section can be very costly.
validator.data_section(&section)?;
}
builder.reserve_data_segments(section.count() as usize);
for segment in section {
builder.push_data_segment(segment?)?;
Expand Down Expand Up @@ -700,7 +736,9 @@ impl ModuleParser {
}
}
}
self.validator.code_section_start(count, &range)?;
if let Some(validator) = &mut self.validator {
validator.code_section_start(count, &range)?;
}
Ok(())
}

Expand Down Expand Up @@ -731,16 +769,15 @@ impl ModuleParser {
fn process_code_entry(
&mut self,
func_body: FunctionBody,
validation_mode: ValidationMode,
bytes: &[u8],
header: &ModuleHeader,
) -> Result<(), Error> {
let (func, compiled_func) = self.next_func(header);
let module = header.clone();
let offset = func_body.get_binary_reader().original_position();
let func_to_validate = match validation_mode {
ValidationMode::All => Some(self.validator.code_section_entry(&func_body)?),
ValidationMode::HeaderOnly => None,
let func_to_validate = match &mut self.validator {
Some(validator) => Some(validator.code_section_entry(&func_body)?),
None => None,
};
self.engine
.translate_func(func, compiled_func, offset, bytes, module, func_to_validate)?;
Expand All @@ -753,8 +790,9 @@ impl ModuleParser {
///
/// This generally will be treated as an error for now.
fn process_unknown(&mut self, id: u8, range: Range<usize>) -> Result<(), Error> {
self.validator
.unknown_section(id, &range)
.map_err(Into::into)
if let Some(validator) = &mut self.validator {
validator.unknown_section(id, &range)?;
}
Ok(())
}
}

0 comments on commit 2e7a28f

Please sign in to comment.