From 12840c12f9025b747cadf993f2b597b10c1a9e62 Mon Sep 17 00:00:00 2001 From: Yaossg Date: Mon, 18 Nov 2024 09:48:28 +0800 Subject: [PATCH] non-pointer arith check --- README.md | 2 ++ boot.c | 91 ++++++++++++++++++++++++++++++++--------------------- demo/sort.c | 2 ++ 3 files changed, 60 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index 387d771..75c4ba8 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,8 @@ $ sh boot.sh | boot.c boot-lib.h | boot1.out | boot2.s | boot2.out | 自举自举自制编译器 | | boot.c boot-lib.h | boot2.out | boot3.s | | 验证自举自举自制编译器 | +后三次编译时,boot-lib.h 的内容被手动导入 boot.c 开头进行编译,boot-lib.c 提供的库通过链接引入。 + 自举的目标为 boot1.s == boot2.s == boot3.s ## 语言文档 diff --git a/boot.c b/boot.c index 6714bd6..025452d 100644 --- a/boot.c +++ b/boot.c @@ -80,6 +80,7 @@ const int TYPE_INT_PTR = 17; const int TYPE_CHAR_PTR = 18; const int TYPE_PTR_MASK = 16; +const int TYPE_FUNC_MASK = 32; const int TYPE_TOKEN_MASK = 128; int parse_int(int ch) { @@ -435,7 +436,6 @@ int max_local_id = 2; const int MARKER_TEMP = 0; const int MARKER_SCALAR = 1; const int MARKER_ARRAY = 2; -const int MARKER_FUNCTION = 3; int local_marker[4096]; int global_marker[4096]; @@ -553,7 +553,7 @@ void load(int rd, int id) { const char* op = "lw"; // int if (type == TYPE_CHAR) { op = "lb"; - } else if (type & TYPE_PTR_MASK) { + } else if (type & TYPE_PTR_MASK || type & TYPE_FUNC_MASK) { op = "ld"; } printf(" %s t%d, 0(t%d) # id: type %d\n", op, rd, rd, type); @@ -565,7 +565,7 @@ void store_t0(int id) { const char* op = "sw"; // int if (type == TYPE_CHAR) { op = "sb"; - } else if (type & TYPE_PTR_MASK) { + } else if (type & TYPE_PTR_MASK || type & TYPE_FUNC_MASK) { op = "sd"; } printf(" %s t0, 0(t1) # id: type %d\n", op, type); @@ -599,9 +599,6 @@ int lookup(int id) { if (global_marker[id] == MARKER_SCALAR) { reg = dereference(reg); } - if (global_marker[id] == MARKER_FUNCTION) { - indirection[reg] = 1; - } return reg; } eprintf("unresolved identifier: %s\n", name); @@ -619,30 +616,58 @@ int asm_label(int label) { return label; } -int is_not_reusable(int rs1) { - return indirection[rs1] || local_marker[rs1] != MARKER_TEMP; +int is_not_reusable(int rs1, int expected_type) { + return indirection[rs1] || local_marker[rs1] != MARKER_TEMP || local_type[rs1] != expected_type; } int asm_r(const char* op, int rs1) { load(0, rs1); printf(" %s t0, t0\n", op); int rd = rs1; - if (is_not_reusable(rs1)) rd = next_reg(local_type[rs1]); + if (is_not_reusable(rs1, TYPE_INT)) { + rd = next_reg(TYPE_INT); + } store_t0(rd); return rd; } +int asm_r_arith(const char* op, int rs1) { + if (local_type[rs1] & TYPE_PTR_MASK || local_type[rs1] & TYPE_FUNC_MASK) { + eprintf("pointer cannot be arithmetically operated by %s\n", op); + exit(1); + } + return asm_r(op, rs1); +} + int asm_rr(const char* op, int rs1, int rs2) { load(0, rs1); load(1, rs2); printf(" %s t0, t0, t1\n", op); int rd = rs1; - if (is_not_reusable(rs1)) rd = rs2; - if (is_not_reusable(rs2)) rd = next_reg(local_type[rs1]); + if (is_not_reusable(rd, TYPE_INT)) { + rd = rs2; + if (is_not_reusable(rd, TYPE_INT)) { + rd = next_reg(TYPE_INT); + } + } store_t0(rd); return rd; } +int asm_rr_arith(const char* op, int rs1, int rs2) { + if (local_type[rs1] & TYPE_PTR_MASK || local_type[rs2] & TYPE_PTR_MASK + || local_type[rs1] & TYPE_FUNC_MASK || local_type[rs2] & TYPE_FUNC_MASK) { + eprintf("pointer cannot be arithmetically operated by %s\n", op); + exit(1); + } + return asm_rr(op, rs1, rs2); +} + +int asm_rr_cmp(const char* op, int rs1, int rs2) { + // since NULL is virtually 0, it is considered valid example of a pointer comparing with an integer + return asm_rr(op, rs1, rs2); +} + void asm_beqz(int rs1, int label) { load(0, rs1); printf(" beqz t0, L%d\n", label); @@ -718,7 +743,7 @@ int asm_add(int lhs, int rhs) { return materialize_t0(ptr_type); } if (type1 && type2) { - eprintf("operands cannot be both pointers\n"); + eprintf("operands of addition cannot be both pointers\n"); exit(1); } return asm_rr("add", lhs, rhs); @@ -745,13 +770,9 @@ int asm_sub(int lhs, int rhs) { return materialize_t0(TYPE_INT); } if (type1) { - int neg = asm_r("neg", rhs); + int neg = asm_r_arith("neg", rhs); return asm_add(lhs, neg); } - if (type2) { - eprintf("a number cannot be subtracted by a pointer\n"); - exit(1); - } return asm_rr("sub", lhs, rhs); } @@ -775,7 +796,7 @@ int parse_primary_expr() { expect_token(TOKEN_PAREN_RIGHT); return reg; } else { - eprintf("unexpected token: %d\n", token_type); + eprintf("unexpected token in primary expression: %d\n", token_type); exit(1); } } @@ -878,10 +899,10 @@ int parse_prefix_expr() { return dereference(materialize_t0(type)); } else if (token_type == TOKEN_MINUS) { int reg = parse_postfix_expr(); - return asm_r("neg", reg); + return asm_r_arith("neg", reg); } else if (token_type == TOKEN_COMPL) { int reg = parse_postfix_expr(); - return asm_r("not", reg); + return asm_r_arith("not", reg); } else if (token_type == TOKEN_NOT) { int reg = parse_postfix_expr(); return asm_r("seqz", reg); @@ -909,13 +930,13 @@ int parse_mul_expr() { next_token(); if (token_type == TOKEN_STAR) { int rhs = parse_prefix_expr(); - lhs = asm_rr("mul", lhs, rhs); + lhs = asm_rr_arith("mul", lhs, rhs); } else if (token_type == TOKEN_DIV) { int rhs = parse_prefix_expr(); - lhs = asm_rr("div", lhs, rhs); + lhs = asm_rr_arith("div", lhs, rhs); } else if (token_type == TOKEN_REM) { int rhs = parse_prefix_expr(); - lhs = asm_rr("rem", lhs, rhs); + lhs = asm_rr_arith("rem", lhs, rhs); } else { unget_token(); break; @@ -948,10 +969,10 @@ int parse_shift_expr() { next_token(); if (token_type == TOKEN_LSHIFT) { int rhs = parse_add_expr(); - lhs = asm_rr("sll", lhs, rhs); + lhs = asm_rr_arith("sll", lhs, rhs); } else if (token_type == TOKEN_RSHIFT) { int rhs = parse_add_expr(); - lhs = asm_rr("sra", lhs, rhs); + lhs = asm_rr_arith("sra", lhs, rhs); } else { unget_token(); break; @@ -966,17 +987,17 @@ int parse_cmp_expr() { next_token(); if (token_type == TOKEN_LT) { int rhs = parse_shift_expr(); - lhs = asm_rr("slt", lhs, rhs); + lhs = asm_rr_cmp("slt", lhs, rhs); } else if (token_type == TOKEN_GT) { int rhs = parse_shift_expr(); - lhs = asm_rr("sgt", lhs, rhs); + lhs = asm_rr_cmp("sgt", lhs, rhs); } else if (token_type == TOKEN_LE) { int rhs = parse_shift_expr(); - int sgt = asm_rr("sgt", lhs, rhs); + int sgt = asm_rr_cmp("sgt", lhs, rhs); lhs = asm_r("seqz", sgt); } else if (token_type == TOKEN_GE) { int rhs = parse_shift_expr(); - int slt = asm_rr("slt", lhs, rhs); + int slt = asm_rr_cmp("slt", lhs, rhs); lhs = asm_r("seqz", slt); } else { unget_token(); @@ -992,11 +1013,11 @@ int parse_eq_expr() { next_token(); if (token_type == TOKEN_EQ) { int rhs = parse_cmp_expr(); - int xor0 = asm_rr("xor", lhs, rhs); + int xor0 = asm_rr_cmp("xor", lhs, rhs); lhs = asm_r("seqz", xor0); } else if (token_type == TOKEN_NE) { int rhs = parse_cmp_expr(); - int xor0 = asm_rr("xor", lhs, rhs); + int xor0 = asm_rr_cmp("xor", lhs, rhs); lhs = asm_r("snez", xor0); } else { unget_token(); @@ -1012,7 +1033,7 @@ int parse_bitwise_and_expr() { next_token(); if (token_type == TOKEN_AND) { int rhs = parse_eq_expr(); - lhs = asm_rr("and", lhs, rhs); + lhs = asm_rr_arith("and", lhs, rhs); } else { unget_token(); break; @@ -1028,7 +1049,7 @@ int parse_bitwise_xor_expr() { next_token(); if (token_type == TOKEN_XOR) { int rhs = parse_bitwise_and_expr(); - lhs = asm_rr("xor", lhs, rhs); + lhs = asm_rr_arith("xor", lhs, rhs); } else { unget_token(); break; @@ -1043,7 +1064,7 @@ int parse_bitwise_or_expr() { next_token(); if (token_type == TOKEN_OR) { int rhs = parse_bitwise_xor_expr(); - lhs = asm_rr("or", lhs, rhs); + lhs = asm_rr_arith("or", lhs, rhs); } else { unget_token(); break; @@ -1420,7 +1441,7 @@ void parse_global_declaration() { char* name = id_table + id_lut[id]; next_token(); if (token_type == TOKEN_PAREN_LEFT) { - declare_global(id, MARKER_FUNCTION, type); + declare_global(id, MARKER_SCALAR, type | TYPE_FUNC_MASK); parse_function(name); } else { declare_global(id, MARKER_SCALAR, type); diff --git a/demo/sort.c b/demo/sort.c index 45fe69b..b1db39d 100644 --- a/demo/sort.c +++ b/demo/sort.c @@ -16,7 +16,9 @@ void sort(int a[], int n) { int main() { int n; int a[100]; + printf("Enter the number of elements in the array: "); scanf("%d", &n); + printf("Enter the elements of the array: "); for (int i = 0; i < n; i++) { scanf("%d", &a[i]); }