summaryrefslogtreecommitdiff
path: root/examples/lua.rs
blob: f75348ac2684c77d95cd1076bc03050177268c4b (plain)
use std::sync::Arc;

use snob::csets::CharacterSet;
use snob::{csets, Scanner};

/// Sample Lua chunk that `main` feeds to [`LuaScanner`].
///
/// It exercises line comments, identifiers, punctuators (`==`, `*`, `-`, `=`,
/// parentheses), single-quoted string literals and integer literals.
const EXAMPLE_LUA_PROGRAM: &str = r"
-- defines a factorial function
function fact (n)
	if n == 0 then
		return 1
	else
		return n * fact(n - 1)
	end
end

print('enter a number:')
a = io.read('*number')          -- read a number
print(fact(a))
";

/// The category of a scanned Lua lexeme, together with any payload it carries.
///
/// Keywords (`function`, `if`, `end`, ...) have no variants of their own; they
/// are reported as [`TokenKind::Identifier`].
#[derive(Debug, Clone)]
enum TokenKind {
	// --- tokens that carry text or a value ---
	/// Comment text, without the leading `--`.
	Comment(Arc<str>),
	/// A name (identifier or keyword).
	Identifier(Arc<str>),
	/// The contents of a string literal, escape sequences already decoded.
	StringLiteral(Arc<str>),
	/// The value of a numeric literal.
	NumberLiteral(f64),

	// --- comparison and assignment ---
	/// `==`
	EqualEqual,
	/// `~=`
	NotEqual,
	/// `<`
	LessThan,
	/// `<=`
	LessEqual,
	/// `>`
	GreaterThan,
	/// `>=`
	GreaterEqual,
	/// `=`
	Assignment,

	// --- arithmetic ---
	/// `+`
	Plus,
	/// `-`
	Minus,
	/// `*`
	Star,
	/// `/`
	Slash,
	/// `%`
	Percent,

	// --- brackets ---
	/// `(`
	LeftParenthesis,
	/// `)`
	RightParenthesis,
	/// `[`
	LeftSquareBracket,
	/// `]`
	RightSquareBracket,
	/// `{`
	LeftCurlyBrace,
	/// `}`
	RightCurlyBrace,

	// --- separators and dots ---
	/// `;`
	Semicolon,
	/// `,`
	Comma,
	/// `.`
	Dot,
	/// `..`
	DotDot,
	/// `...`
	DotDotDot,
}

/// A successfully scanned lexeme and the span of source it came from.
///
/// Positions use whatever unit `snob::Scanner::position` reports (presumably a
/// byte or char offset — TODO confirm against the `snob` docs).
#[derive(Debug, Clone)]
struct Token {
	/// Scanner position where the token's text begins.
	start: usize,
	/// Scanner position once the token's text was consumed (for a line
	/// comment this is past the trailing newline the scanner skips).
	end: usize,
	/// What was scanned, plus any payload (name, text, value).
	kind: TokenKind,
}

/// The reason a span of input could not be turned into a [`Token`].
#[derive(Debug, Clone)]
enum TokenErrorKind {
	/// The text does not form any token this scanner understands.
	InvalidToken,
	/// A string literal is missing its closing delimiter.
	UnterminatedString,
}

/// A span of source text the scanner rejected, and why.
#[derive(Debug, Clone)]
struct TokenError {
	/// Scanner position where the offending text begins.
	start: usize,
	/// Scanner position at the moment the error was reported (the scanner has
	/// already moved past the text it gave up on).
	end: usize,
	/// The reason the text was rejected.
	kind: TokenErrorKind,
}

/// A tokenizer for Lua source text, built on a `snob` [`Scanner`].
///
/// Tokens (and token errors) are produced lazily through its `Iterator`
/// implementation.
struct LuaScanner {
	/// Cursor over the source text being tokenized.
	scanner: Scanner,
}

impl LuaScanner {
	/// Creates a scanner positioned at the start of `source`.
	fn new(source: &str) -> Self {
		Self {
			scanner: Scanner::new(source),
		}
	}

	/// Builds a successful token spanning `start` up to the current position.
	fn create_token(&self, start: usize, kind: TokenKind) -> Result<Token, TokenError> {
		Ok(Token {
			start,
			end: self.scanner.position(),
			kind,
		})
	}

	/// Builds an error spanning `start` up to the current position.
	fn token_error(&self, start: usize, kind: TokenErrorKind) -> Result<Token, TokenError> {
		Err(TokenError {
			start,
			end: self.scanner.position(),
			kind,
		})
	}

	/// Moves the cursor to `position` and returns the text passed over.
	///
	/// Panics if the underlying scanner rejects `position`; every caller passes
	/// a position obtained from the scanner itself.
	fn goto(&mut self, position: usize) -> String {
		self.scanner.goto(position).expect("a valid position")
	}

	/// Decodes one escape sequence; the cursor must sit just after the `\`.
	///
	/// Follows Lua's rules: `\ddd` is up to three *decimal* digits; the named
	/// escapes are `\a \b \f \n \r \t \v`; an escaped line break (`\n`, `\r`,
	/// `\r\n` or `\n\r`) becomes `\n`; any other character (including `\\`,
	/// `\"` and `\'`) stands for itself.
	///
	/// Returns `None` only when the input ends right after the backslash.
	fn escape_code(&mut self) -> Option<char> {
		let mut code = 0;
		let mut digits = 0;
		while digits < 3 && self.scanner.any(csets::AsciiDigits).is_some() {
			let digit = self.scanner.advance_char().expect("another character");
			code = code * 10 + digit.to_digit(10).expect("an ASCII digit");
			digits += 1;
		}

		if digits > 0 {
			// NOTE: Lua rejects codes above 255; this example instead maps them
			// to the Unicode scalar value with the same number.
			return char::from_u32(code);
		}

		let escape = self.scanner.advance_char()?;
		Some(match escape {
			'a' => '\x07',
			'b' => '\x08',
			'f' => '\x0c',
			'n' => '\n',
			'r' => '\r',
			't' => '\t',
			'v' => '\x0b',
			'\r' | '\n' => {
				// swallow the second half of a two-character line break
				let partner = if escape == '\r' { "\n" } else { "\r" };
				let _ = self.scanner.advance_if_starts_with(partner);
				'\n'
			}
			// `\\`, `\"`, `\'` and any unrecognised escape
			other => other,
		})
	}

	/// Scans the rest of a single-quoted string literal; the cursor must be
	/// just past the opening `'`. Equivalent to `quoted_string(start, '\'')`.
	fn string_literal(&mut self, start: usize) -> Result<Token, TokenError> {
		self.quoted_string(start, '\'')
	}

	/// Scans the rest of a short string delimited by `quote` (`'` or `"`);
	/// the cursor must be just past the opening delimiter.
	///
	/// Escapes are decoded with [`Self::escape_code`]. An unescaped line break
	/// yields [`TokenErrorKind::UnterminatedString`] and scanning resumes after
	/// it; running out of input yields the same error and skips to the end.
	fn quoted_string(&mut self, start: usize, quote: char) -> Result<Token, TokenError> {
		debug_assert!(matches!(quote, '\'' | '"'));

		// characters that interrupt a run of plain text: escapes, the closing
		// delimiter, and line breaks (not allowed raw in short strings)
		let stops = if quote == '"' { "\\\"\r\n" } else { "\\'\r\n" };
		let mut builder = String::new();

		while let Some(position) = self.scanner.upto(stops) {
			builder.push_str(&self.goto(position));

			match self.scanner.advance_char().expect("another character") {
				'\\' => {
					if let Some(escaped) = self.escape_code() {
						builder.push(escaped);
					}
				}
				'\n' | '\r' => {
					return self.token_error(start, TokenErrorKind::UnterminatedString);
				}
				_ => {
					// the only remaining stop character is the closing delimiter
					return self.create_token(start, TokenKind::StringLiteral(builder.into()));
				}
			}
		}

		// unterminated string: skip the rest of the chunk
		self.goto(self.scanner.len());
		self.token_error(start, TokenErrorKind::UnterminatedString)
	}

	/// Scans the body of a long string `[[ ... ]]`; the cursor must be just
	/// past the opening `[[`, and `start` is the position of that bracket.
	///
	/// A line break directly after the opening bracket is dropped, as in Lua.
	/// Nested `[[ ... ]]` pairs are allowed and kept verbatim in the string;
	/// lone `[` / `]` characters are ordinary content. Level brackets such as
	/// `[==[` are not supported.
	fn bracketed_string(&mut self, start: usize) -> Result<Token, TokenError> {
		let mut builder = String::new();
		let mut nesting = 1;

		if self.scanner.advance_if_starts_with("\r\n").is_none() {
			let _ = self.scanner.advance_if_starts_with("\n");
		}

		while let Some(position) = self.scanner.upto("[]") {
			builder.push_str(&self.goto(position));

			if self.scanner.advance_if_starts_with("[[").is_some() {
				nesting += 1;
				builder.push_str("[[");
			} else if self.scanner.advance_if_starts_with("]]").is_some() {
				nesting -= 1;
				if nesting == 0 {
					return self.create_token(start, TokenKind::StringLiteral(builder.into()));
				}
				builder.push_str("]]");
			} else {
				// a lone bracket: consume it so the scan keeps making progress
				let bracket = self.scanner.advance_char().expect("a bracket character");
				builder.push(bracket);
			}
		}

		// unterminated string: skip the rest of the chunk
		self.goto(self.scanner.len());
		self.token_error(start, TokenErrorKind::UnterminatedString)
	}
}

impl Iterator for LuaScanner {
	type Item = Result<Token, TokenError>;

	/// Scans the next token, or returns `None` once only whitespace remains.
	///
	/// Bad input is reported as an `Err` item rather than ending iteration, so
	/// a caller sees every problem in the chunk.
	fn next(&mut self) -> Option<Self::Item> {
		// Longest spellings first, so `...` is not read as `..` `.`, `<=` as
		// `<` `=`, and so on.
		const PUNCTUATORS: [(&str, TokenKind); 23] = [
			("...", TokenKind::DotDotDot),
			("..", TokenKind::DotDot),
			("~=", TokenKind::NotEqual),
			("<=", TokenKind::LessEqual),
			(">=", TokenKind::GreaterEqual),
			("==", TokenKind::EqualEqual),
			(".", TokenKind::Dot),
			("<", TokenKind::LessThan),
			(">", TokenKind::GreaterThan),
			("=", TokenKind::Assignment),
			("+", TokenKind::Plus),
			("-", TokenKind::Minus),
			("*", TokenKind::Star),
			("/", TokenKind::Slash),
			("%", TokenKind::Percent),
			("(", TokenKind::LeftParenthesis),
			(")", TokenKind::RightParenthesis),
			("[", TokenKind::LeftSquareBracket),
			("]", TokenKind::RightSquareBracket),
			("{", TokenKind::LeftCurlyBrace),
			("}", TokenKind::RightCurlyBrace),
			(";", TokenKind::Semicolon),
			(",", TokenKind::Comma),
		];

		// shebang: a first line starting with `#` is ignored
		if self.scanner.position() == 0 && self.scanner.advance_if_starts_with("#").is_some() {
			let position = self.scanner.upto('\n').unwrap_or(self.scanner.len());
			self.goto(position);
		}

		// skip whitespace
		if let Some(position) = self.scanner.many(csets::AsciiWhitespace) {
			self.goto(position);
		}

		if self.scanner.is_at_end() {
			return None;
		}

		let start = self.scanner.position();

		// comments
		if self.scanner.advance_if_starts_with("--").is_some() {
			// `--[[ ... ]]` is a long comment that may span several lines
			if self.scanner.advance_if_starts_with("[[").is_some() {
				let comment = self.bracketed_string(start).map(|Token { start, end, kind }| {
					let kind = match kind {
						TokenKind::StringLiteral(text) => TokenKind::Comment(text),
						other => other,
					};
					Token { start, end, kind }
				});
				return Some(comment);
			}

			let position = self.scanner.upto('\n').unwrap_or(self.scanner.len());
			let comment = self.goto(position);
			self.scanner.advance_or_goto_end(1); // skip the newline
			return Some(self.create_token(start, TokenKind::Comment(comment.into())));
		}

		// identifiers (keywords included)
		if self.scanner.any(csets::Alphabetic.union('_')).is_some() {
			let identifier = self.goto(
				self.scanner
					.many(csets::Alphanumeric.union('_'))
					.expect("alphanumeric characters"),
			);
			return Some(self.create_token(start, TokenKind::Identifier(identifier.into())));
		}

		// long strings; must precede the `[` punctuator, which would otherwise
		// swallow the first bracket
		if self.scanner.advance_if_starts_with("[[").is_some() {
			return Some(self.bracketed_string(start));
		}

		// short strings
		if self.scanner.advance_if_starts_with("'").is_some() {
			return Some(self.string_literal(start));
		}
		if self.scanner.advance_if_starts_with("\"").is_some() {
			return Some(self.quoted_string(start, '"'));
		}

		// punctuators
		for (text, kind) in &PUNCTUATORS {
			if self.scanner.advance_if_starts_with(*text).is_some() {
				return Some(self.create_token(start, kind.clone()));
			}
		}

		// numbers: digits, an optional fraction, an optional exponent
		if let Some(position) = self.scanner.many(csets::AsciiDigits) {
			let mut literal = self.goto(position);

			if self.scanner.advance_if_starts_with(".").is_some() {
				literal.push('.');
				if let Some(position) = self.scanner.many(csets::AsciiDigits) {
					literal.push_str(&self.goto(position));
				}
			}

			if let Some(position) = self.scanner.any("Ee") {
				self.goto(position);
				literal.push('e');
				if let Some(position) = self.scanner.any("+-") {
					literal.push_str(&self.goto(position));
				}
				if let Some(position) = self.scanner.many(csets::AsciiDigits) {
					literal.push_str(&self.goto(position));
				}
			}

			// Parsing fails only for an exponent without digits (`1e`, `2e+`),
			// which Lua also rejects as a malformed number.
			return Some(match literal.parse::<f64>() {
				Ok(number) => self.create_token(start, TokenKind::NumberLiteral(number)),
				Err(_) => self.token_error(start, TokenErrorKind::InvalidToken),
			});
		}

		// invalid tokens: skip ahead to something that could start a token.
		// The current character is never in this set (whitespace, names and
		// digits were handled above), so the scan always makes progress.
		let next_token_cset = csets::AsciiAlphanumeric
			.union(csets::AsciiWhitespace)
			.union('_');
		let position = self
			.scanner
			.upto(next_token_cset)
			.unwrap_or(self.scanner.len());
		self.goto(position);
		Some(self.token_error(start, TokenErrorKind::InvalidToken))
	}
}

fn main() {
	// Gather every scan result (tokens and errors alike) first, then dump the
	// whole list with its `Debug` representation on a single line.
	let results: Vec<Result<Token, TokenError>> = LuaScanner::new(EXAMPLE_LUA_PROGRAM).collect();
	println!("{results:?}");
}