2024-12-10

「低レイヤを知りたい人のためのCコンパイラ作成入門」第十五回勉強会まとめ

学習範囲

ステップ7: 比較演算子のアセンブリコードの生成

分割コンパイルとリンク

アセンブリコードの生成

gen関数に<、<=、>、>=、==、!=に対応した、アセンブリコードが出力されるように追記します。

私はM1を使用しているので、ARM64 GCC 14.2.0環境でのアセンブリコードとなります。

こちらをもとにコードを記述しました。

    case ND_EQ:
      printf(" cmp w0, w1\n");
      printf(" cset w0, eq\n");
      break;
    case ND_NE:
      printf(" cmp w0, w1\n");
      printf(" cset w0, ne\n");
      break;
    case ND_LT:
      printf(" cmp w0, w1\n");
      printf(" cset w0, lt\n");
      break;
    case ND_LTE:
      printf(" cmp w0, w1\n");
      printf(" cset w0, le\n");
      break;

gen関数の全体のコードは以下のようになります。

void gen(Node *node) {
  printf("// kind %d\n", node->kind);
  if (node->kind == ND_NUM) {
    printf(" mov w%d, %d\n", global_register_count, node->val);
    printf(" str w%d, [sp, #-16]!\n", global_register_count);
    // global_register_count;
    return;
  }

  gen(node->lhs);
  gen(node->rhs);

  printf(" ldr w1, [sp], #16\n");
  printf(" ldr w0, [sp], #16\n");

  switch (node->kind) {
    case ND_ADD:
      printf(" add w0, w0, w1\n");
      break;
    case ND_SUB:
      printf(" sub w0, w0, w1\n");
      break;
    case ND_MUL:
      printf(" mul w0, w0, w1\n");
      break;
    case ND_DIV:
      printf(" sdiv w0, w0, w1\n");
      break;
    case ND_EQ:
      printf(" cmp w0, w1\n");
      printf(" cset w0, eq\n");
      break;
    case ND_NE:
      printf(" cmp w0, w1\n");
      printf(" cset w0, ne\n");
      break;
    case ND_LT:
      printf(" cmp w0, w1\n");
      printf(" cset w0, lt\n");
      break;
    case ND_LTE:
      printf(" cmp w0, w1\n");
      printf(" cset w0, le\n");
      break;
  }
  printf(" str w0, [sp, #-16]!\n");
}

テストしてみます

assert 1 "0==0"
assert 1 "2!=1"
assert 1 "2<3"
assert 1 "2<=3"
assert 0 "2>3"
assert 0 "2>=3"

これで比較演算子を含む高度な計算ができるようになりました!

分割コンパイル

これまで、全てのCコードを9cc.cに記述していました。コードを役割ごとにまとめて管理しやすくなるように複数のCファイルに分割します。

9cc.h: ヘッダファイル
main.c: main関数
parse.c: パーサ
codegen.c: コードジェネレータ

9cc.h

#include <stdbool.h>

typedef struct Node Node;

typedef enum {
  TK_RESERVED,
  TK_NUM,
  TK_EOF,
} TokenKind;

typedef struct Token Token;

typedef enum {
    ND_ADD, //+
    ND_SUB, //-
    ND_MUL, //*
    ND_DIV, // /
    ND_NUM, // 整数
    ND_EQ, // ==
    ND_NE, // !=
    ND_LT, // <
    ND_LTE, // <=
} NodeKind;

typedef struct Node Node;

// トークン型
struct Token {
  TokenKind kind; //トークンの型
  Token *next;    //次の入力トークン
  int val;        //kindがTK_NUMの場合、その数値
  char *str;      //トークン文字列
  int len;        //トークンの長さ
};

//抽象構文木のノードの型
struct Node {
  NodeKind kind; //ノードの型
  Node *lhs; //左辺
  Node *rhs; //右辺
  int val; //kindがNO_NUMの場合のみ使う
};

Node *add();
bool consume(char *op);
Node *unary();
Node *mul();
Node *relational();
Node *equality();

main.c

#include "9cc.h"

char *user_input;

Token *token;
Node *node;

int main(int argc, char **argv) {
  if (argc != 2) {
    error("引数の個数が正しくありません");
    return 1;
  }

  //トークナイズしてパースする
  user_input = argv[1];
  token = tokenize(argv[1]);
  Node *node = expr();

  //アセンブリの前半部分を出力
  printf(".globl main\n");
  printf("main:\n");

  //抽象構文木を下りながらコード生成
  gen(node);

  //スタックトップに式全体の値が残っているはずなので
  printf(" ret\n");
  return 0;
}

parse.c

#include "9cc.h"
#include <stdarg.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include "main.c"

//エラー箇所を報告する
void error_at(char *loc, char *fmt, ...) {
  va_list ap;
  va_start(ap, fmt);

  int pos = loc - user_input;
  fprintf(stderr, "%s\n", user_input);
  fprintf(stderr, "%*s", pos, " ");//pos個の空白を出力
  fprintf(stderr, "^ ");
  vfprintf(stderr, fmt, ap);
  fprintf(stderr, "\n");
  exit(1);
}

// エラーを報告するための関数
// printfと同じ引数を取る
void error(char *fmt, ...) {
  va_list ap;
  va_start(ap, fmt);
  vfprintf(stderr, fmt, ap);
  fprintf(stderr, "\n");
  exit(1);
}

//次のトークンが期待している記号のときには、トークンを1つ読み進めて
//真を返す。それ以外の場合には偽をけ返す。
bool consume(char *op) {
  if (token->kind != TK_RESERVED ||
      strlen(op) != token->len ||
      //メモリーコンペア  
      memcmp(token->str, op, token->len))
    return false;
  token = token->next;
  return true;
}

//次の次のトークンが期待している記号のときには、トークンを1つ読み進める。
//それ以外の場合にはエラーを報告する。
void expect(char *op) {
  if (token->kind != TK_RESERVED || strlen(op) != token->len ||
    memcmp(token->str, op, token->len))
    error_at(token->str, "'%c'ではありません", op);
  token = token->next;
}

//次のトークンが数値の場合、トークンを1つ読み進めてその数値を返す。
//ソレ以外の場合にはエラーを報告する。
int expect_number() {
  if (token->kind != TK_NUM)
    error_at(token->str, "数ではありません");
  int val = token->val;
  token = token->next;
  return val;

}

bool at_eof() {
  return token->kind == TK_EOF;
}

//新しいトークンを作成してcurに繋げる
Token *new_token(TokenKind kind, Token *cur, char *str, int len) {
  Token *tok = calloc(1, sizeof(Token));
  tok->kind = kind;
  tok->str = str;
  tok->len = len;
  cur->next = tok;
  return tok;
}

bool startswith(char *p, char *q) {
  return memcmp(p, q, strlen(q)) == 0;
}

//入力文字列pをトークナイズしてそれを返す
Token *tokenize(char *p) {
  Token head;
  head.next = NULL;
  Token *cur = &head;

  while (*p) {
    //空白文字をスキップ
    if (isspace(*p)) {
      p++;
      continue;
    }

    // Multi-letter punctuator
    if (startswith(p, "==") || startswith(p, "!=") ||
        startswith(p, "<=") || startswith(p, ">=")) {
      cur = new_token(TK_RESERVED, cur, p, 2);
      p += 2;
      continue;
    }

     // Single-letter punctuator
    if (strchr("+-*/()<>", *p)) {
      cur = new_token(TK_RESERVED, cur, p++, 1);
      continue;
    }

    // Integer literal
    if (isdigit(*p)) {
      cur = new_token(TK_NUM, cur, p, 0);
      char *q = p;
      cur->val = strtol(p, &p, 10);
      cur->len = p - q;
      continue;
    }
    // printf("foo");
    error_at(p, "トークナイズできません");
  }

  new_token(TK_EOF, cur, p, 0);
  return head.next;
}

Node *new_node(NodeKind kind, Node *lhs, Node *rhs) {
  Node *node = calloc(1, sizeof(Node));
  node->kind = kind;
  node->lhs = lhs;
  node->rhs = rhs;
  return node;
}

Node *new_node_num(int val) {
  Node *node = calloc(1, sizeof(Node));
  node->kind = ND_NUM;
  node->val = val;
  return node;
}

Node *expr() {
  return equality();
}

Node *primary() {
  //次のトークンが"("なら、 "(" expr ")"のはず
  if (consume("(")) {
    Node *node = expr();
    expect(")");
    return node;
  }

  //そうでなければ数値のはず
  return new_node_num(expect_number());
}

Node *unary() {
  if (consume("+"))
    return primary();
  if (consume("-"))
    return new_node(ND_SUB, new_node_num(0), primary());
  return primary();
}

Node *mul() {
  Node *node = unary();

  for(;;) {
    if(consume("*"))
      node = new_node(ND_MUL, node, unary());
    else if (consume("/"))
      node = new_node(ND_DIV, node, unary());
    else
      return node;
  }
}

Node *add() {
  Node *node = mul();

  for (;;) {
    if (consume("+"))
      node = new_node(ND_ADD, node, mul());
    else if (consume("-"))
      node = new_node(ND_SUB, node, mul());
    else
      return node;
  }
}

Node *relational() {
  Node *node = add();

  for (;;) {
    if (consume("<"))
      node = new_node(ND_LT, node, add());
    else if (consume("<="))
      node = new_node(ND_LTE, node, add());
    else if (consume(">"))
      node = new_node(ND_LT, add(), node);
    else if (consume(">="))
      node = new_node(ND_LTE, add(), node);
    else
      return node;
  }
}

Node *equality() {
  Node *node = relational();

  for (;;) {
    if (consume("=="))
      node = new_node(ND_EQ, node, relational());
    else if (consume("!="))
      node = new_node(ND_NE, node, relational());
    else
      return node;
  }
}

codegen.c

#include "9cc.h"

int global_register_count = 0;
void gen(Node *node) {
  printf("// kind %d\n", node->kind);
  if (node->kind == ND_NUM) {
    printf(" mov w%d, %d\n", global_register_count, node->val);
    printf(" str w%d, [sp, #-16]!\n", global_register_count);
    // global_register_count;
    return;
  }

  gen(node->lhs);
  gen(node->rhs);

  printf(" ldr w1, [sp], #16\n");
  printf(" ldr w0, [sp], #16\n");

  switch (node->kind) {
    case ND_ADD:
      printf(" add w0, w0, w1\n");
      break;
    case ND_SUB:
      printf(" sub w0, w0, w1\n");
      break;
    case ND_MUL:
      printf(" mul w0, w0, w1\n");
      break;
    case ND_DIV:
      printf(" sdiv w0, w0, w1\n");
      break;
    case ND_EQ:
      printf(" cmp w0, w1\n");
      printf(" cset w0, eq\n");
      break;
    case ND_NE:
      printf(" cmp w0, w1\n");
      printf(" cset w0, ne\n");
      break;
    case ND_LT:
      printf(" cmp w0, w1\n");
      printf(" cset w0, lt\n");
      break;
    case ND_LTE:
      printf(" cmp w0, w1\n");
      printf(" cset w0, le\n");
      break;
  }
  printf(" str w0, [sp, #-16]!\n");
}

分割コンパイルとは

分割コンパイルとは、1つの大きなプログラムを複数のソースファイル(.cや.cppなど)に分割し、それぞれを個別にコンパイルする手法です。それぞれのソースファイルをコンパイルすると、「オブジェクトファイル」と呼ばれる中間成果物(.oや.obj)が生成されます。これらのオブジェクトファイルは単体では実行できまん。これらをまとめて1つの実行ファイルにするプログラムのことを「リンカ」といいます。

なぜ分割コンパイルが必要なのか?

ビルド時間の短縮

大規模プロジェクトを考えてみてください。もし、数千万行のコードを1つのソースファイルにまとめ、何かを少し修正するたびにすべてをコンパイルし直すとなると、途方もない時間がかかってしまいます。
一方、分割コンパイルでは、変更のあったソースファイルだけを再コンパイルし、残りは過去にビルド済みのオブジェクトファイルを再利用できます。これによりビルド時間を大幅に短縮できるようになります。

コード管理の容易化

すべての関数や変数を1つのファイルに詰め込むと、関数がどこで定義されているのか、変数がどのように使われているのか全体を見通すのが困難になります。
複数のソースファイルに分割してモジュール化することでコードの見通しが良くなり、保守性が向上します。

標準ライブラリなど既存のコードを有効活用

標準ライブラリの関数(例:printf)など既にコンパイル済みのライブラリを毎回ソースコードからコンパイルする必要があると想像してみてください。何度も同じコードをコンパイルするのは非効率です。
実際には、標準ライブラリはあらかじめコンパイル済みのオブジェクトファイルやライブラリファイル(.aや.so、.libなど)として配布されており、手元で毎回コンパイルし直さなくても良くなっています。

まとめ

<、<=、>、>=、==、!=といった比較演算子に対応できるようになったことで、単純な四則演算を超えてより複雑な計算処理が可能になりました。これにより一般的な電卓の計算力を超え、柔軟で拡張性の高い計算ができるようになりました。

また、今まで9cc.cに全ての処理を記述していましたが、役割ごとにファイルを分割したことでコードが読みやすくなりました。エラーの箇所や追加で新しい処理を記述するときにすぐに該当するファイルや関数を見つけ出せることができるので、スムーズに開発を進められるようになりました。