summaryrefslogtreecommitdiff
path: root/src/tokens.rs
blob: 4654ddf6eed50d686f31c06fc0351a279aca279f (plain)
use std::fmt::Display;
use std::sync::Arc;

use rust_decimal::{Decimal, MathematicalOps};
use snob::csets;

/// Splits source text into [`Token`]s; iterate it (it implements [`Iterator`]) to obtain
/// them in input order.
#[derive(Debug)]
pub struct Lexer {
	/// Cursor over the input text; every token is scanned from its current position.
	scanner: snob::Scanner,
}

/// A single lexical token plus the region of the input it was scanned from.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Token {
	/// Where in the input this token was found.
	pub span: Span,
	/// The kind of token, with any text or value it carries.
	pub ty: TokenType,
}

/// A half-open range `start..end` of scanner positions covered by a token.
///
/// NOTE(review): positions are whatever `snob::Scanner::position` reports — confirm whether
/// these are byte or character offsets before using them to slice the source.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Span {
	/// Scanner position at which the token begins.
	pub start: usize,
	/// Scanner position just past the token's last character.
	pub end: usize,
}

/// The kind of a [`Token`], together with the text or value it carries.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenType {
	/// A run of whitespace characters, kept verbatim.
	Whitespace(Arc<str>),
	/// A `;` comment: the text after the `;` up to the end of the line.
	LineComment(Arc<str>),
	/// A `#| … |#` comment. `comment` excludes both delimiters; `terminated` is `false`
	/// when the input ended before the closing `|#`.
	BlockComment { comment: Arc<str>, terminated: bool },

	/// `(`
	LeftParenthesis,
	/// `)`
	RightParenthesis,

	/// `'`
	Apostrophe,
	/// `#` (when it does not open a `#|` block comment)
	Pound,
	/// `.`
	Dot,

	/// Any other run of characters up to whitespace or a delimiter, exactly as written.
	Identifier(Arc<str>),
	/// A `"`-delimited string. `content` is the raw text between the quotes with escape
	/// sequences not yet decoded (see [`Token::computed_string`]); `terminated` is `false`
	/// when the input ended before the closing quote.
	String { content: Arc<str>, terminated: bool },
	/// The value of a numeric literal.
	Number(Decimal),
}

impl Display for TokenType {
	/// Writes the token's textual form.
	///
	/// Whitespace, comments, identifiers and strings write the text they store. Punctuation
	/// writes its fixed spelling. Numbers write their decimal value.
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		let text: &str = match self {
			Self::Whitespace(text) | Self::LineComment(text) | Self::Identifier(text) => text,
			Self::BlockComment { comment, .. } => comment,
			Self::String { content, .. } => content,
			Self::LeftParenthesis => "(",
			Self::RightParenthesis => ")",
			Self::Apostrophe => "'",
			Self::Pound => "#",
			Self::Dot => ".",
			// `write!` formats the number with default options, exactly like `to_string()`,
			// but without building an intermediate `String`.
			Self::Number(number) => return write!(f, "{number}"),
		};
		f.write_str(text)
	}
}

impl Display for Token {
	/// Formats the token exactly as its [`TokenType`] is formatted; the span is not shown.
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		let Self { ty, .. } = self;
		Display::fmt(ty, f)
	}
}

impl Token {
	/// Creates a token of kind `ty` covering scanner positions `start..end`.
	fn new(start: usize, end: usize, ty: TokenType) -> Self {
		Self {
			span: Span { start, end },
			ty,
		}
	}

	/// Returns `true` for whitespace tokens.
	pub fn is_whitespace(&self) -> bool {
		matches!(self.ty, TokenType::Whitespace(_))
	}

	/// Returns `true` if this is a whitespace token containing at least one `\n`.
	pub fn contains_newline(&self) -> bool {
		matches!(&self.ty, TokenType::Whitespace(space) if space.contains('\n'))
	}

	/// Returns `true` for both line comments and block comments.
	pub fn is_comment(&self) -> bool {
		matches!(
			self.ty,
			TokenType::LineComment(_) | TokenType::BlockComment { .. }
		)
	}

	/// Returns `true` for a `(` token.
	pub fn is_left_parenthesis(&self) -> bool {
		self.ty == TokenType::LeftParenthesis
	}

	/// Returns `true` for a `)` token.
	pub fn is_right_parenthesis(&self) -> bool {
		self.ty == TokenType::RightParenthesis
	}

	/// Returns `true` for a `'` token.
	pub fn is_apostrophe(&self) -> bool {
		self.ty == TokenType::Apostrophe
	}

	/// Returns `true` for a `#` token.
	pub fn is_pound(&self) -> bool {
		self.ty == TokenType::Pound
	}

	/// Returns `true` for a `.` token.
	pub fn is_dot(&self) -> bool {
		self.ty == TokenType::Dot
	}

	/// Returns `true` for identifier tokens.
	pub fn is_identifier(&self) -> bool {
		self.raw_identifier().is_some()
	}

	/// Returns `true` for number tokens.
	pub fn is_number(&self) -> bool {
		self.number().is_some()
	}

	/// Returns `true` for string tokens.
	pub fn is_string(&self) -> bool {
		self.raw_string().is_some()
	}

	/// Returns the identifier's text exactly as written, or `None` for other tokens.
	pub fn raw_identifier(&self) -> Option<&str> {
		if let TokenType::Identifier(identifier) = &self.ty {
			Some(identifier)
		} else {
			None
		}
	}

	/// Returns the identifier converted to uppercase, or `None` for other tokens.
	///
	/// NOTE(review): presumably this is how identifiers are made case-insensitive — confirm
	/// with callers.
	pub fn identifier(&self) -> Option<String> {
		self.raw_identifier().map(str::to_uppercase)
	}

	/// Returns the numeric value of a number token, or `None` for other tokens.
	pub fn number(&self) -> Option<&Decimal> {
		if let TokenType::Number(number) = &self.ty {
			Some(number)
		} else {
			None
		}
	}

	/// Returns a string token's content with escape sequences left undecoded, or `None`
	/// for other tokens.
	pub fn raw_string(&self) -> Option<&str> {
		if let TokenType::String { content, .. } = &self.ty {
			Some(content)
		} else {
			None
		}
	}

	/// Returns a string token's content with escape sequences decoded, or `None` for
	/// other tokens.
	///
	/// Escapes:
	/// - `\n` and `\t` produce a newline and a tab.
	/// - `\x` is followed by up to 2 hexadecimal digits. `\u` is followed by up to 6.
	///   Either produces the character with that code point, or U+FFFD if the value is not
	///   a valid `char`. The escape ends early at the first non-hex character, which is
	///   then treated as ordinary content. An escape with no digits at all produces nothing.
	/// - A backslash before any other character produces that character literally.
	///
	/// A lone backslash at the very end of the content produces nothing.
	pub fn computed_string(&self) -> Option<String> {
		/// Where the decoder is relative to an escape sequence.
		#[derive(Clone, Copy)]
		enum State {
			/// Copying characters through unchanged.
			Default,
			/// Just saw a `\`; the next character selects the escape.
			Backslash,
			/// Inside a `\x`/`\u` escape: up to `remaining` more hex digits may follow,
			/// `digits` have been read so far, and `value` is the code point accumulated.
			Hex {
				remaining: usize,
				digits: usize,
				value: u32,
			},
		}

		/// Appends the character with code point `value`, or U+FFFD if it is not a `char`.
		fn push_code_point(computed: &mut String, value: u32) {
			computed.push(char::from_u32(value).unwrap_or(char::REPLACEMENT_CHARACTER));
		}

		let TokenType::String { content, .. } = &self.ty else {
			return None;
		};

		let mut computed = String::with_capacity(content.len());
		let mut state = State::Default;
		for ch in content.chars() {
			// A non-hex character ends a hex escape early: emit what was read so far (if
			// anything), then fall through and treat the character as ordinary content.
			if let State::Hex { digits, value, .. } = state {
				if !ch.is_ascii_hexdigit() {
					if digits > 0 {
						push_code_point(&mut computed, value);
					}
					state = State::Default;
				}
			}

			state = match state {
				State::Default if ch == '\\' => State::Backslash,
				State::Default => {
					computed.push(ch);
					State::Default
				}
				State::Backslash => match ch {
					'u' => State::Hex {
						remaining: 6,
						digits: 0,
						value: 0,
					},
					'x' => State::Hex {
						remaining: 2,
						digits: 0,
						value: 0,
					},
					'n' => {
						computed.push('\n');
						State::Default
					}
					't' => {
						computed.push('\t');
						State::Default
					}
					other => {
						computed.push(other);
						State::Default
					}
				},
				State::Hex {
					remaining,
					digits,
					value,
				} => {
					let digit = ch
						.to_digit(16)
						.expect("non-hex characters end the escape before this match");
					// At most 6 hex digits are accumulated, so this cannot overflow a u32.
					let value = value * 16 + digit;
					if remaining == 1 {
						// That was the last digit this escape allows.
						push_code_point(&mut computed, value);
						State::Default
					} else {
						State::Hex {
							remaining: remaining - 1,
							digits: digits + 1,
							value,
						}
					}
				}
			};
		}

		// The content may end in the middle of a hex escape; keep what was read.
		if let State::Hex { digits, value, .. } = state {
			if digits > 0 {
				push_code_point(&mut computed, value);
			}
		}

		Some(computed)
	}
}

impl Lexer {
	pub fn new(str: &str) -> Self {
		Self {
			scanner: snob::Scanner::new(str),
		}
	}

	fn goto(&mut self, position: usize) -> String {
		self.scanner
			.goto(position)
			.expect("The position should be valid")
	}

	fn simple_token(&self, start: usize, ty: TokenType) -> Token {
		Token::new(start, self.scanner.position(), ty)
	}

	fn scan_block_comment(&mut self, start: usize) -> Token {
		self.scanner.advance_if_starts_with("#|");
		let mut comment = String::new();
		let mut terminated = false;

		while let Some(position) = self.scanner.upto('|') {
			comment.push_str(&self.goto(position));

			if self.scanner.advance_if_starts_with("|#").is_some() {
				terminated = true;
				break;
			}
		}

		if !terminated {
			comment.push_str(&self.goto(self.scanner.len()));
		}

		self.simple_token(
			start,
			TokenType::BlockComment {
				comment: comment.into(),
				terminated,
			},
		)
	}

	fn scan_string(&mut self, start: usize) -> Token {
		let mut content = String::new();
		let mut terminated = false;

		if let Some(position) = self.scanner.any('"') {
			self.goto(position);
		}

		while let Some(position) = self.scanner.upto("\\\"") {
			content.push_str(&self.goto(position));

			if self.scanner.advance_if_starts_with("\"").is_some() {
				terminated = true;
				break;
			}

			let backslash = self.scanner.advance(1).expect("we found a backslash");
			content.push_str(&backslash);
			if let Some(c) = self.scanner.advance(1) {
				content.push_str(&c)
			}
		}

		if !terminated {
			content.push_str(&self.goto(self.scanner.len()));
		}

		self.simple_token(
			start,
			TokenType::String {
				content: content.into(),
				terminated,
			},
		)
	}

	fn scan_digit(&mut self) -> Option<Decimal> {
		let digit = self.scanner.advance(1)?;
		let digit = match digit.as_str() {
			"0" => Decimal::from(0),
			"1" => Decimal::from(1),
			"2" => Decimal::from(2),
			"3" => Decimal::from(3),
			"4" => Decimal::from(4),
			"5" => Decimal::from(5),
			"6" => Decimal::from(6),
			"7" => Decimal::from(7),
			"8" => Decimal::from(8),
			"9" => Decimal::from(9),
			_ => return None,
		};

		Some(digit)
	}

	fn scan_decimal_number(&mut self) -> Decimal {
		let mut number = Decimal::ZERO;
		while self.scanner.any(csets::AsciiDigits).is_some() {
			let digit = self.scan_digit().expect("we saw there's a digit here");
			number = number * Decimal::TEN + digit;
		}

		number
	}

	fn scan_octal_number(&mut self) -> Decimal {
		let mut number = Decimal::ZERO;
		while self.scanner.any("01234567").is_some() {
			let digit = self.scan_digit().expect("we saw there's a digit here");
			number = number * Decimal::TEN + digit;
		}

		number
	}

	fn scan_number(&mut self, start: usize) -> Token {
		let mut sign = Decimal::ONE;
		let mut fraction_numerator = Decimal::ZERO;
		let mut fraction_denominator = Decimal::ONE;
		let mut exponent = Decimal::ZERO;
		let mut octal_exponent = Decimal::ZERO;

		self.scanner.advance_if_starts_with("+");
		if self.scanner.advance_if_starts_with("-").is_some() {
			sign = Decimal::NEGATIVE_ONE;
		}

		let whole_part = if self.scanner.advance_if_starts_with("0o").is_some() {
			self.scan_octal_number()
		} else {
			self.scan_decimal_number()
		};

		if self.scanner.advance_if_starts_with(".").is_some() {
			while self.scanner.any(csets::AsciiDigits).is_some() {
				let digit = self.scan_digit().expect("we saw that there's a digit here");
				fraction_numerator = fraction_numerator * Decimal::TEN + digit;
				fraction_denominator *= Decimal::TEN;
			}
		}

		if let Some(position) = self.scanner.any("eE") {
			let mut is_negative = false;

			self.goto(position);
			self.scanner.advance_if_starts_with("+");
			if self.scanner.advance_if_starts_with("-").is_some() {
				is_negative = true;
			}

			exponent = self.scan_decimal_number();
			if is_negative {
				exponent *= Decimal::NEGATIVE_ONE;
			}
		}

		if let Some(position) = self.scanner.any("qQ") {
			let mut is_negative = false;

			self.goto(position);
			self.scanner.advance_if_starts_with("+");
			if self.scanner.advance_if_starts_with("-").is_some() {
				is_negative = true;
			}

			octal_exponent = self.scan_decimal_number();
			if is_negative {
				octal_exponent *= Decimal::NEGATIVE_ONE;
			}
		}

		let number = sign
			* (whole_part + fraction_numerator / fraction_denominator)
			* (Decimal::TEN.powd(exponent))
			* (Decimal::from(8).powd(octal_exponent));
		self.simple_token(start, TokenType::Number(number))
	}
}

impl Iterator for Lexer {
	type Item = Token;

	/// Scans the next token, or returns `None` once the input is exhausted.
	///
	/// Alternatives are tried in this order:
	/// 1. whitespace;
	/// 2. `;` line comments;
	/// 3. `#|` block comments;
	/// 4. the punctuation `(` `)` `'` `#` `.`;
	/// 5. strings;
	/// 6. numbers (a digit, or `+`/`-` immediately followed by a digit);
	/// 7. identifiers, which extend up to the next whitespace or delimiter character.
	fn next(&mut self) -> Option<Token> {
		// Whitespace: space, tab, CR, LF, vertical tab (`\x0B`) and form feed (`\x0C`).
		// NOTE(review): `\x11` (DC1) is kept for compatibility with earlier versions of this
		// lexer, where it was presumably meant to be decimal 11 — the vertical tab.
		const WHITESPACE: &str = " \t\r\n\x0B\x0C\x11";
		// Characters that end an identifier: the whitespace above plus every character that
		// starts another kind of token. Each of these is handled by an earlier branch, so
		// the identifier branch always starts on a non-terminator character.
		const IDENTIFIER_TERMINATORS: &str = " \t\r\n\x0B\x0C\x11;#()'.\"";

		let start = self.scanner.position();

		if self.scanner.is_at_end() {
			return None;
		}

		let token = if let Some(end) = self.scanner.many(WHITESPACE) {
			let whitespace = self.goto(end);
			self.simple_token(start, TokenType::Whitespace(whitespace.into()))
		} else if let Some(after_semicolon) = self.scanner.any(';') {
			// The `;` is consumed but not stored; the comment runs to the end of the line.
			self.goto(after_semicolon);
			let end = self.scanner.upto('\n').unwrap_or(self.scanner.len());
			let comment = self.goto(end);
			self.simple_token(start, TokenType::LineComment(comment.into()))
		} else if self.scanner.starts_with("#|").is_some() {
			// Must be tested before the lone `#` below.
			self.scan_block_comment(start)
		} else if self.scanner.advance_if_starts_with("(").is_some() {
			self.simple_token(start, TokenType::LeftParenthesis)
		} else if self.scanner.advance_if_starts_with(")").is_some() {
			self.simple_token(start, TokenType::RightParenthesis)
		} else if self.scanner.advance_if_starts_with("'").is_some() {
			self.simple_token(start, TokenType::Apostrophe)
		} else if self.scanner.advance_if_starts_with("#").is_some() {
			self.simple_token(start, TokenType::Pound)
		} else if self.scanner.advance_if_starts_with(".").is_some() {
			self.simple_token(start, TokenType::Dot)
		} else if self.scanner.any('"').is_some() {
			self.scan_string(start)
		} else if self.scanner.any(csets::AsciiDigits).is_some()
			|| (self.scanner.any("+-").is_some()
				&& self
					.scanner
					.char_at(self.scanner.position() + 1)
					.is_some_and(|char| char.is_ascii_digit()))
		{
			// A sign only starts a number when a digit follows it directly; otherwise `+`
			// and `-` fall through to the identifier branch.
			self.scan_number(start)
		} else {
			let end = self
				.scanner
				.upto(IDENTIFIER_TERMINATORS)
				.unwrap_or(self.scanner.len());
			let identifier = self.goto(end);
			self.simple_token(start, TokenType::Identifier(identifier.into()))
		};

		Some(token)
	}
}